diff options
| author | John Oxley <john.oxley@gmail.com> | 2026-05-14 22:53:32 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-05-14 14:53:32 -0700 |
| commit | d98ecd58af9f6d850461b8bef430dfef70111692 (patch) | |
| tree | 7b83a7718fa3b3a742134a9700b29330deafc400 | |
| parent | fix(ci): fossier install in scan workflow (#3485) (diff) | |
| download | atuin-d98ecd58af9f6d850461b8bef430dfef70111692.zip | |
refactor: Implement From<sqlx::Error> and clean up fix_error (#3484)
In the database crates for atuin-server, there is `fn fix_error`. This
PR implements `From<sqlx::Error>` on `DbError` which makes it possible
to mostly use `?` to bubble up the errors.
There are cases where `?` is not being used e.g.
```rust
async fn get_session(&self, token: &str) -> DbResult<Session> {
sqlx::query_as("select id, user_id, token from sessions where token = $1")
.bind(token)
.fetch_one(&self.pool)
.await
.map_err(fix_error)
.map(|DbSession(session)| session)
}
```
There are two options
## 1. Use `Into::into`
```rust
async fn get_session(&self, token: &str) -> DbResult<Session> {
sqlx::query_as("select id, user_id, token from sessions where token = $1")
.bind(token)
.fetch_one(&self.pool)
.await
.map_err(fix_error)
.map(|DbSession(session)| session)
}
```
## 2. Create a variable and use `?`
```rust
async fn get_session(&self, token: &str) -> DbResult<Session> {
let session = sqlx::query_as("select id, user_id, token from sessions where token = $1")
.bind(token)
.fetch_one(&self.pool)
.await
.map(|DbSession(session)| session)?;
Ok(session)
}
```
I chose to do option 1 as it was just a find/replace but say the word
and I'll convert them all to option 2
## Checks
- [X] I am happy for maintainers to push small adjustments to this PR,
to speed up the review cycle
- [X] I have checked that there are no existing pull requests for the
same thing
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | crates/atuin-server-database/Cargo.toml | 7 | ||||
| -rw-r--r-- | crates/atuin-server-database/src/lib.rs | 21 | ||||
| -rw-r--r-- | crates/atuin-server-postgres/src/lib.rs | 119 | ||||
| -rw-r--r-- | crates/atuin-server-sqlite/src/lib.rs | 80 |
5 files changed, 94 insertions, 134 deletions
@@ -573,6 +573,7 @@ dependencies = [ "atuin-common", "eyre", "serde", + "sqlx", "time", "tracing", "url", diff --git a/crates/atuin-server-database/Cargo.toml b/crates/atuin-server-database/Cargo.toml index e361c68b..52ccbf97 100644 --- a/crates/atuin-server-database/Cargo.toml +++ b/crates/atuin-server-database/Cargo.toml @@ -12,9 +12,10 @@ repository = { workspace = true } [dependencies] atuin-common = { path = "../atuin-common", version = "18.16.1" } -tracing = { workspace = true } -time = { workspace = true } +async-trait = { workspace = true } eyre = { workspace = true } serde = { workspace = true } -async-trait = { workspace = true } +sqlx = { workspace = true } +time = { workspace = true } +tracing = { workspace = true } url = "2.5.2" diff --git a/crates/atuin-server-database/src/lib.rs b/crates/atuin-server-database/src/lib.rs index 6000a530..9dd95eef 100644 --- a/crates/atuin-server-database/src/lib.rs +++ b/crates/atuin-server-database/src/lib.rs @@ -31,9 +31,24 @@ impl Display for DbError { } } -impl<T: std::error::Error + Into<time::error::Error>> From<T> for DbError { - fn from(value: T) -> Self { - DbError::Other(value.into().into()) +impl From<time::error::ComponentRange> for DbError { + fn from(error: time::error::ComponentRange) -> Self { + DbError::Other(error.into()) + } +} + +impl From<time::error::Error> for DbError { + fn from(error: time::error::Error) -> Self { + DbError::Other(error.into()) + } +} + +impl From<sqlx::Error> for DbError { + fn from(error: sqlx::Error) -> Self { + match error { + sqlx::Error::RowNotFound => DbError::NotFound, + error => DbError::Other(error.into()), + } } } diff --git a/crates/atuin-server-postgres/src/lib.rs b/crates/atuin-server-postgres/src/lib.rs index ce101d8d..eeb9da14 100644 --- a/crates/atuin-server-postgres/src/lib.rs +++ b/crates/atuin-server-postgres/src/lib.rs @@ -35,33 +35,24 @@ impl Postgres { } } -fn fix_error(error: sqlx::Error) -> DbError { - match error { - sqlx::Error::RowNotFound => DbError::NotFound, - error => DbError::Other(error.into()), - } -} - #[async_trait] impl Database for Postgres { async fn new(settings: &DbSettings) -> DbResult<Self> { let pool = PgPoolOptions::new() .max_connections(100) .connect(settings.db_uri.as_str()) - .await - .map_err(fix_error)?; + .await?; // Call server_version_num to get the DB server's major version number // The call returns None for servers older than 8.x. - let pg_major_version: u32 = pool - .acquire() - .await - .map_err(fix_error)? - .server_version_num() - .ok_or(DbError::Other(eyre::Report::msg( - "could not get PostgreSQL version", - )))? - / 10000; + let pg_major_version: u32 = + pool.acquire() + .await? + .server_version_num() + .ok_or(DbError::Other(eyre::Report::msg( + "could not get PostgreSQL version", + )))? + / 10000; if pg_major_version < MIN_PG_VERSION { return Err(DbError::Other(eyre::Report::msg(format!( @@ -80,14 +71,12 @@ impl Database for Postgres { let read_pool = PgPoolOptions::new() .max_connections(100) .connect(read_db_uri.as_str()) - .await - .map_err(fix_error)?; + .await?; // Verify the read replica is also a supported PostgreSQL version let read_pg_major_version: u32 = read_pool .acquire() - .await - .map_err(fix_error)? + .await? .server_version_num() .ok_or(DbError::Other(eyre::Report::msg( "could not get PostgreSQL version from read replica", @@ -114,7 +103,7 @@ impl Database for Postgres { .bind(token) .fetch_one(self.read_pool()) .await - .map_err(fix_error) + .map_err(Into::into) .map(|DbSession(session)| session) } @@ -124,7 +113,7 @@ impl Database for Postgres { .bind(username) .fetch_one(self.read_pool()) .await - .map_err(fix_error) + .map_err(Into::into) .map(|DbUser(user)| user) } @@ -139,7 +128,7 @@ impl Database for Postgres { .bind(token) .fetch_one(self.read_pool()) .await - .map_err(fix_error) + .map_err(Into::into) .map(|DbUser(user)| user) } @@ -155,8 +144,7 @@ impl Database for Postgres { ) .bind(user.id) .fetch_one(self.read_pool()) - .await - .map_err(fix_error)?; + .await?; Ok(res.0) } @@ -169,14 +157,13 @@ impl Database for Postgres { ) .bind(user.id) .fetch_one(self.read_pool()) - .await - .map_err(fix_error)?; + .await?; Ok(res.0 as i64) } async fn delete_store(&self, user: &User) -> DbResult<()> { - let mut tx = self.pool.begin().await.map_err(fix_error)?; + let mut tx = self.pool.begin().await?; sqlx::query( "delete from store @@ -184,8 +171,7 @@ impl Database for Postgres { ) .bind(user.id) .execute(&mut *tx) - .await - .map_err(fix_error)?; + .await?; sqlx::query( "delete from store_idx_cache @@ -193,10 +179,9 @@ impl Database for Postgres { ) .bind(user.id) .execute(&mut *tx) - .await - .map_err(fix_error)?; + .await?; - tx.commit().await.map_err(fix_error)?; + tx.commit().await?; Ok(()) } @@ -213,8 +198,7 @@ impl Database for Postgres { .bind(id) .bind(OffsetDateTime::now_utc()) .fetch_all(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(()) } @@ -232,8 +216,7 @@ impl Database for Postgres { ) .bind(user.id) .fetch_all(self.read_pool()) - .await - .map_err(fix_error)?; + .await?; let res = res .iter() @@ -259,8 +242,7 @@ impl Database for Postgres { .bind(into_utc(range.start)) .bind(into_utc(range.end)) .fetch_one(self.read_pool()) - .await - .map_err(fix_error)?; + .await?; Ok(res.0) } @@ -291,15 +273,14 @@ impl Database for Postgres { .fetch(self.read_pool()) .map_ok(|DbHistory(h)| h) .try_collect() - .await - .map_err(fix_error)?; + .await?; Ok(res) } #[instrument(skip_all)] async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { - let mut tx = self.pool.begin().await.map_err(fix_error)?; + let mut tx = self.pool.begin().await?; for i in history { let client_id: &str = &i.client_id; @@ -319,11 +300,10 @@ impl Database for Postgres { .bind(i.timestamp) .bind(data) .execute(&mut *tx) - .await - .map_err(fix_error)?; + .await?; } - tx.commit().await.map_err(fix_error)?; + tx.commit().await?; Ok(()) } @@ -333,32 +313,27 @@ impl Database for Postgres { sqlx::query("delete from sessions where user_id = $1") .bind(u.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; sqlx::query("delete from history where user_id = $1") .bind(u.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; sqlx::query("delete from store where user_id = $1") .bind(u.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; sqlx::query("delete from total_history_count_user where user_id = $1") .bind(u.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; sqlx::query("delete from users where id = $1") .bind(u.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(()) } @@ -373,8 +348,7 @@ impl Database for Postgres { .bind(&user.password) .bind(user.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(()) } @@ -395,8 +369,7 @@ impl Database for Postgres { .bind(email) .bind(password) .fetch_one(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(res.0) } @@ -413,8 +386,7 @@ impl Database for Postgres { .bind(session.user_id) .bind(token) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(()) } @@ -425,7 +397,7 @@ impl Database for Postgres { .bind(u.id) .fetch_one(self.read_pool()) .await - .map_err(fix_error) + .map_err(Into::into) .map(|DbSession(session)| session) } @@ -440,13 +412,13 @@ impl Database for Postgres { .bind(user.id) .fetch_one(self.read_pool()) .await - .map_err(fix_error) + .map_err(Into::into) .map(|DbHistory(h)| h) } #[instrument(skip_all)] async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { - let mut tx = self.pool.begin().await.map_err(fix_error)?; + let mut tx = self.pool.begin().await?; // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max // idx without having to make further database queries. Doing the query on this small @@ -478,8 +450,7 @@ impl Database for Postgres { .bind(&i.data.content_encryption_key) .bind(user.id) .execute(&mut *tx) - .await - .map_err(fix_error)?; + .await?; // Only update heads if we actually inserted the record if result.rows_affected() > 0 { @@ -509,10 +480,10 @@ impl Database for Postgres { .bind(idx as i64) .execute(&mut *tx) .await - .map_err(fix_error)?; + ?; } - tx.commit().await.map_err(fix_error)?; + tx.commit().await?; Ok(()) } @@ -545,7 +516,7 @@ impl Database for Postgres { .bind(count as i64) .fetch_all(self.read_pool()) .await - .map_err(fix_error); + .map_err(Into::into); let ret = match records { Ok(records) => { @@ -588,15 +559,13 @@ impl Database for Postgres { sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1") .bind(user.id) .fetch_all(self.read_pool()) - .await - .map_err(fix_error)? + .await? } else { tracing::debug!("using aggregate query for user {}", user.id); sqlx::query_as(STATUS_SQL) .bind(user.id) .fetch_all(self.read_pool()) - .await - .map_err(fix_error)? + .await? }; res.sort(); diff --git a/crates/atuin-server-sqlite/src/lib.rs b/crates/atuin-server-sqlite/src/lib.rs index d69258c4..7d4dcb86 100644 --- a/crates/atuin-server-sqlite/src/lib.rs +++ b/crates/atuin-server-sqlite/src/lib.rs @@ -23,25 +23,14 @@ pub struct Sqlite { pool: sqlx::Pool<sqlx::sqlite::Sqlite>, } -fn fix_error(error: sqlx::Error) -> DbError { - match error { - sqlx::Error::RowNotFound => DbError::NotFound, - error => DbError::Other(error.into()), - } -} - #[async_trait] impl Database for Sqlite { async fn new(settings: &DbSettings) -> DbResult<Self> { - let opts = SqliteConnectOptions::from_str(&settings.db_uri) - .map_err(fix_error)? + let opts = SqliteConnectOptions::from_str(&settings.db_uri)? .journal_mode(SqliteJournalMode::Wal) .create_if_missing(true); - let pool = SqlitePoolOptions::new() - .connect_with(opts) - .await - .map_err(fix_error)?; + let pool = SqlitePoolOptions::new().connect_with(opts).await?; sqlx::migrate!("./migrations") .run(&pool) @@ -57,7 +46,7 @@ impl Database for Sqlite { .bind(token) .fetch_one(&self.pool) .await - .map_err(fix_error) + .map_err(Into::into) .map(|DbSession(session)| session) } @@ -72,7 +61,7 @@ impl Database for Sqlite { .bind(token) .fetch_one(&self.pool) .await - .map_err(fix_error) + .map_err(Into::into) .map(|DbUser(user)| user) } @@ -88,8 +77,7 @@ impl Database for Sqlite { .bind(session.user_id) .bind(token) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(()) } @@ -100,7 +88,7 @@ impl Database for Sqlite { .bind(username) .fetch_one(&self.pool) .await - .map_err(fix_error) + .map_err(Into::into) .map(|DbUser(user)| user) } @@ -110,7 +98,7 @@ impl Database for Sqlite { .bind(u.id) .fetch_one(&self.pool) .await - .map_err(fix_error) + .map_err(Into::into) .map(|DbSession(session)| session) } @@ -130,8 +118,7 @@ impl Database for Sqlite { .bind(email) .bind(password) .fetch_one(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(res.0) } @@ -146,8 +133,7 @@ impl Database for Sqlite { .bind(&user.password) .bind(user.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(()) } @@ -164,8 +150,7 @@ impl Database for Sqlite { ) .bind(user.id) .fetch_one(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(res.0) } @@ -180,20 +165,17 @@ impl Database for Sqlite { sqlx::query("delete from sessions where user_id = $1") .bind(u.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; sqlx::query("delete from users where id = $1") .bind(u.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; sqlx::query("delete from history where user_id = $1") .bind(u.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(()) } @@ -210,8 +192,7 @@ impl Database for Sqlite { .bind(id) .bind(time::OffsetDateTime::now_utc()) .fetch_all(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(()) } @@ -229,8 +210,7 @@ impl Database for Sqlite { ) .bind(user.id) .fetch_all(&self.pool) - .await - .map_err(fix_error)?; + .await?; let res = res.iter().map(|row| row.get("client_id")).collect(); @@ -244,15 +224,14 @@ impl Database for Sqlite { ) .bind(user.id) .execute(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(()) } #[instrument(skip_all)] async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { - let mut tx = self.pool.begin().await.map_err(fix_error)?; + let mut tx = self.pool.begin().await?; for i in records { let id = atuin_common::utils::uuid_v7(); @@ -275,11 +254,10 @@ impl Database for Sqlite { .bind(&i.data.content_encryption_key) .bind(user.id) .execute(&mut *tx) - .await - .map_err(fix_error)?; + .await?; } - tx.commit().await.map_err(fix_error)?; + tx.commit().await?; Ok(()) } @@ -312,7 +290,7 @@ impl Database for Sqlite { .bind(count as i64) .fetch_all(&self.pool) .await - .map_err(fix_error); + .map_err(Into::into); let ret = match records { Ok(records) => { @@ -343,8 +321,7 @@ impl Database for Sqlite { let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) .bind(user.id) .fetch_all(&self.pool) - .await - .map_err(fix_error)?; + .await?; let mut status = RecordStatus::new(); @@ -371,8 +348,7 @@ impl Database for Sqlite { .bind(into_utc(range.start)) .bind(into_utc(range.end)) .fetch_one(&self.pool) - .await - .map_err(fix_error)?; + .await?; Ok(res.0) } @@ -403,15 +379,14 @@ impl Database for Sqlite { .fetch(&self.pool) .map_ok(|DbHistory(h)| h) .try_collect() - .await - .map_err(fix_error)?; + .await?; Ok(res) } #[instrument(skip_all)] async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { - let mut tx = self.pool.begin().await.map_err(fix_error)?; + let mut tx = self.pool.begin().await?; for i in history { let client_id: &str = &i.client_id; @@ -431,11 +406,10 @@ impl Database for Sqlite { .bind(i.timestamp) .bind(data) .execute(&mut *tx) - .await - .map_err(fix_error)?; + .await?; } - tx.commit().await.map_err(fix_error)?; + tx.commit().await?; Ok(()) } @@ -451,7 +425,7 @@ impl Database for Sqlite { .bind(user.id) .fetch_one(&self.pool) .await - .map_err(fix_error) + .map_err(Into::into) .map(|DbHistory(h)| h) } } |
