aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-server-postgres/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-server-postgres/src/lib.rs')
-rw-r--r--crates/atuin-server-postgres/src/lib.rs119
1 files changed, 44 insertions, 75 deletions
diff --git a/crates/atuin-server-postgres/src/lib.rs b/crates/atuin-server-postgres/src/lib.rs
index ce101d8d..eeb9da14 100644
--- a/crates/atuin-server-postgres/src/lib.rs
+++ b/crates/atuin-server-postgres/src/lib.rs
@@ -35,33 +35,24 @@ impl Postgres {
}
}
-fn fix_error(error: sqlx::Error) -> DbError {
- match error {
- sqlx::Error::RowNotFound => DbError::NotFound,
- error => DbError::Other(error.into()),
- }
-}
-
#[async_trait]
impl Database for Postgres {
async fn new(settings: &DbSettings) -> DbResult<Self> {
let pool = PgPoolOptions::new()
.max_connections(100)
.connect(settings.db_uri.as_str())
- .await
- .map_err(fix_error)?;
+ .await?;
// Call server_version_num to get the DB server's major version number
// The call returns None for servers older than 8.x.
- let pg_major_version: u32 = pool
- .acquire()
- .await
- .map_err(fix_error)?
- .server_version_num()
- .ok_or(DbError::Other(eyre::Report::msg(
- "could not get PostgreSQL version",
- )))?
- / 10000;
+ let pg_major_version: u32 =
+ pool.acquire()
+ .await?
+ .server_version_num()
+ .ok_or(DbError::Other(eyre::Report::msg(
+ "could not get PostgreSQL version",
+ )))?
+ / 10000;
if pg_major_version < MIN_PG_VERSION {
return Err(DbError::Other(eyre::Report::msg(format!(
@@ -80,14 +71,12 @@ impl Database for Postgres {
let read_pool = PgPoolOptions::new()
.max_connections(100)
.connect(read_db_uri.as_str())
- .await
- .map_err(fix_error)?;
+ .await?;
// Verify the read replica is also a supported PostgreSQL version
let read_pg_major_version: u32 = read_pool
.acquire()
- .await
- .map_err(fix_error)?
+ .await?
.server_version_num()
.ok_or(DbError::Other(eyre::Report::msg(
"could not get PostgreSQL version from read replica",
@@ -114,7 +103,7 @@ impl Database for Postgres {
.bind(token)
.fetch_one(self.read_pool())
.await
- .map_err(fix_error)
+ .map_err(Into::into)
.map(|DbSession(session)| session)
}
@@ -124,7 +113,7 @@ impl Database for Postgres {
.bind(username)
.fetch_one(self.read_pool())
.await
- .map_err(fix_error)
+ .map_err(Into::into)
.map(|DbUser(user)| user)
}
@@ -139,7 +128,7 @@ impl Database for Postgres {
.bind(token)
.fetch_one(self.read_pool())
.await
- .map_err(fix_error)
+ .map_err(Into::into)
.map(|DbUser(user)| user)
}
@@ -155,8 +144,7 @@ impl Database for Postgres {
)
.bind(user.id)
.fetch_one(self.read_pool())
- .await
- .map_err(fix_error)?;
+ .await?;
Ok(res.0)
}
@@ -169,14 +157,13 @@ impl Database for Postgres {
)
.bind(user.id)
.fetch_one(self.read_pool())
- .await
- .map_err(fix_error)?;
+ .await?;
Ok(res.0 as i64)
}
async fn delete_store(&self, user: &User) -> DbResult<()> {
- let mut tx = self.pool.begin().await.map_err(fix_error)?;
+ let mut tx = self.pool.begin().await?;
sqlx::query(
"delete from store
@@ -184,8 +171,7 @@ impl Database for Postgres {
)
.bind(user.id)
.execute(&mut *tx)
- .await
- .map_err(fix_error)?;
+ .await?;
sqlx::query(
"delete from store_idx_cache
@@ -193,10 +179,9 @@ impl Database for Postgres {
)
.bind(user.id)
.execute(&mut *tx)
- .await
- .map_err(fix_error)?;
+ .await?;
- tx.commit().await.map_err(fix_error)?;
+ tx.commit().await?;
Ok(())
}
@@ -213,8 +198,7 @@ impl Database for Postgres {
.bind(id)
.bind(OffsetDateTime::now_utc())
.fetch_all(&self.pool)
- .await
- .map_err(fix_error)?;
+ .await?;
Ok(())
}
@@ -232,8 +216,7 @@ impl Database for Postgres {
)
.bind(user.id)
.fetch_all(self.read_pool())
- .await
- .map_err(fix_error)?;
+ .await?;
let res = res
.iter()
@@ -259,8 +242,7 @@ impl Database for Postgres {
.bind(into_utc(range.start))
.bind(into_utc(range.end))
.fetch_one(self.read_pool())
- .await
- .map_err(fix_error)?;
+ .await?;
Ok(res.0)
}
@@ -291,15 +273,14 @@ impl Database for Postgres {
.fetch(self.read_pool())
.map_ok(|DbHistory(h)| h)
.try_collect()
- .await
- .map_err(fix_error)?;
+ .await?;
Ok(res)
}
#[instrument(skip_all)]
async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
- let mut tx = self.pool.begin().await.map_err(fix_error)?;
+ let mut tx = self.pool.begin().await?;
for i in history {
let client_id: &str = &i.client_id;
@@ -319,11 +300,10 @@ impl Database for Postgres {
.bind(i.timestamp)
.bind(data)
.execute(&mut *tx)
- .await
- .map_err(fix_error)?;
+ .await?;
}
- tx.commit().await.map_err(fix_error)?;
+ tx.commit().await?;
Ok(())
}
@@ -333,32 +313,27 @@ impl Database for Postgres {
sqlx::query("delete from sessions where user_id = $1")
.bind(u.id)
.execute(&self.pool)
- .await
- .map_err(fix_error)?;
+ .await?;
sqlx::query("delete from history where user_id = $1")
.bind(u.id)
.execute(&self.pool)
- .await
- .map_err(fix_error)?;
+ .await?;
sqlx::query("delete from store where user_id = $1")
.bind(u.id)
.execute(&self.pool)
- .await
- .map_err(fix_error)?;
+ .await?;
sqlx::query("delete from total_history_count_user where user_id = $1")
.bind(u.id)
.execute(&self.pool)
- .await
- .map_err(fix_error)?;
+ .await?;
sqlx::query("delete from users where id = $1")
.bind(u.id)
.execute(&self.pool)
- .await
- .map_err(fix_error)?;
+ .await?;
Ok(())
}
@@ -373,8 +348,7 @@ impl Database for Postgres {
.bind(&user.password)
.bind(user.id)
.execute(&self.pool)
- .await
- .map_err(fix_error)?;
+ .await?;
Ok(())
}
@@ -395,8 +369,7 @@ impl Database for Postgres {
.bind(email)
.bind(password)
.fetch_one(&self.pool)
- .await
- .map_err(fix_error)?;
+ .await?;
Ok(res.0)
}
@@ -413,8 +386,7 @@ impl Database for Postgres {
.bind(session.user_id)
.bind(token)
.execute(&self.pool)
- .await
- .map_err(fix_error)?;
+ .await?;
Ok(())
}
@@ -425,7 +397,7 @@ impl Database for Postgres {
.bind(u.id)
.fetch_one(self.read_pool())
.await
- .map_err(fix_error)
+ .map_err(Into::into)
.map(|DbSession(session)| session)
}
@@ -440,13 +412,13 @@ impl Database for Postgres {
.bind(user.id)
.fetch_one(self.read_pool())
.await
- .map_err(fix_error)
+ .map_err(Into::into)
.map(|DbHistory(h)| h)
}
#[instrument(skip_all)]
async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
- let mut tx = self.pool.begin().await.map_err(fix_error)?;
+ let mut tx = self.pool.begin().await?;
// We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max
// idx without having to make further database queries. Doing the query on this small
@@ -478,8 +450,7 @@ impl Database for Postgres {
.bind(&i.data.content_encryption_key)
.bind(user.id)
.execute(&mut *tx)
- .await
- .map_err(fix_error)?;
+ .await?;
// Only update heads if we actually inserted the record
if result.rows_affected() > 0 {
@@ -509,10 +480,10 @@ impl Database for Postgres {
.bind(idx as i64)
.execute(&mut *tx)
.await
- .map_err(fix_error)?;
+ ?;
}
- tx.commit().await.map_err(fix_error)?;
+ tx.commit().await?;
Ok(())
}
@@ -545,7 +516,7 @@ impl Database for Postgres {
.bind(count as i64)
.fetch_all(self.read_pool())
.await
- .map_err(fix_error);
+ .map_err(Into::into);
let ret = match records {
Ok(records) => {
@@ -588,15 +559,13 @@ impl Database for Postgres {
sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
.bind(user.id)
.fetch_all(self.read_pool())
- .await
- .map_err(fix_error)?
+ .await?
} else {
tracing::debug!("using aggregate query for user {}", user.id);
sqlx::query_as(STATUS_SQL)
.bind(user.id)
.fetch_all(self.read_pool())
- .await
- .map_err(fix_error)?
+ .await?
};
res.sort();