diff options
Diffstat (limited to 'crates/turtle/src/atuin_server/database/db')
| -rw-r--r-- | crates/turtle/src/atuin_server/database/db/mod.rs | 217 | ||||
| -rw-r--r-- | crates/turtle/src/atuin_server/database/db/wrappers.rs | 15 |
2 files changed, 43 insertions, 189 deletions
diff --git a/crates/turtle/src/atuin_server/database/db/mod.rs b/crates/turtle/src/atuin_server/database/db/mod.rs index e0c6b736..4ec51bf1 100644 --- a/crates/turtle/src/atuin_server/database/db/mod.rs +++ b/crates/turtle/src/atuin_server/database/db/mod.rs @@ -4,16 +4,13 @@ use rand::Rng; use crate::{ atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}, - atuin_server::database::{ - DbError, DbResult, DbSettings, - models::{NewSession, NewUser, Session, User}, - }, + atuin_server::database::{DbError, DbResult, DbSettings, models::User}, }; use sqlx::postgres::PgPoolOptions; use tracing::instrument; use uuid::Uuid; -use wrappers::{DbRecord, DbSession, DbUser}; +use wrappers::DbRecord; mod wrappers; @@ -96,148 +93,6 @@ impl Database { } #[instrument(skip_all)] - pub(crate) async fn get_user(&self, username: &str) -> DbResult<User> { - sqlx::query_as("select id, username, email, password from users where username = $1") - .bind(username) - .fetch_one(self.read_pool()) - .await - .map_err(Into::into) - .map(|DbUser(user)| user) - } - - #[instrument(skip_all)] - pub(crate) async fn get_session_user(&self, token: &str) -> DbResult<User> { - sqlx::query_as( - "select users.id, users.username, users.email, users.password from users - inner join sessions - on users.id = sessions.user_id - and sessions.token = $1", - ) - .bind(token) - .fetch_one(self.read_pool()) - .await - .map_err(Into::into) - .map(|DbUser(user)| user) - } - - pub(crate) async fn delete_store(&self, user: &User) -> DbResult<()> { - let mut tx = self.pool.begin().await?; - - sqlx::query( - "delete from store - where user_id = $1", - ) - .bind(user.id) - .execute(&mut *tx) - .await?; - - sqlx::query( - "delete from store_idx_cache - where user_id = $1", - ) - .bind(user.id) - .execute(&mut *tx) - .await?; - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - pub(crate) async fn delete_user(&self, u: &User) -> DbResult<()> { - sqlx::query("delete from sessions where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from history where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from store where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from total_history_count_user where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from users where id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - pub(crate) async fn update_user_password(&self, user: &User) -> DbResult<()> { - sqlx::query( - "update users - set password = $1 - where id = $2", - ) - .bind(&user.password) - .bind(user.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - pub(crate) async fn add_user(&self, user: &NewUser) -> DbResult<i64> { - let email: &str = &user.email; - let username: &str = &user.username; - let password: &str = &user.password; - - let res: (i64,) = sqlx::query_as( - "insert into users - (username, email, password) - values($1, $2, $3) - returning id", - ) - .bind(username) - .bind(email) - .bind(password) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - pub(crate) async fn add_session(&self, session: &NewSession) -> DbResult<()> { - let token: &str = &session.token; - - sqlx::query( - "insert into sessions - (user_id, token) - values($1, $2)", - ) - .bind(session.user_id) - .bind(token) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - pub(crate) async fn get_user_session(&self, u: &User) -> DbResult<Session> { - sqlx::query_as("select id, user_id, token from sessions where user_id = $1") - .bind(u.id) - .fetch_one(self.read_pool()) - .await - .map_err(Into::into) - .map(|DbSession(session)| session) - } - - #[instrument(skip_all)] pub(crate) async fn add_records( &self, user: &User, @@ -258,10 +113,10 @@ impl Database { let id = crate::atuin_common::utils::uuid_v7(); let result = sqlx::query( - "insert into store - (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - on conflict do nothing + " + INSERT INTO store (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON conflict DO nothing ", ) .bind(id) @@ -293,10 +148,11 @@ impl Database { // we've built the map of heads for this push, so commit it to the database for ((host, tag), idx) in heads { sqlx::query( - "insert into store_idx_cache - (user_id, host, tag, idx) - values ($1, $2, $3, $4) - on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4) + " + INSERT INTO store_idx_cache (user_id, host, tag, idx) + VALUES ($1, $2, $3, $4) + ON conflict(user_id, host, tag) DO update + SET idx = greatest(store_idx_cache.idx, $4) ", ) .bind(user.id) @@ -304,8 +160,7 @@ impl Database { .bind(tag) .bind(idx as i64) .execute(&mut *tx) - .await - ?; + .await?; } tx.commit().await?; @@ -326,13 +181,15 @@ impl Database { let start = start.unwrap_or(0); let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as( - "select client_id, host, idx, timestamp, version, tag, data, cek from store - where user_id = $1 - and tag = $2 - and host = $3 - and idx >= $4 - order by idx asc - limit $5", + " + SELECT client_id, host, idx, timestamp, version, tag, data, cek FROM store + WHERE user_id = $1 + AND tag = $2 + AND host = $3 + AND idx >= $4 + ORDER BY idx asc + LIMIT $5 + ", ) .bind(user.id) .bind(tag.clone()) @@ -366,9 +223,6 @@ impl Database { } pub(crate) async fn status(&self, user: &User) -> DbResult<RecordStatus> { - const STATUS_SQL: &str = - "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; - // If IDX_CACHE_ROLLOUT is set, then we // 1. Read the value of the var, use it as a % chance of using the cache // 2. If we use the cache, just read from the cache table @@ -381,16 +235,29 @@ impl Database { let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache { tracing::debug!("using idx cache for user {}", user.id); - sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1") - .bind(user.id) - .fetch_all(self.read_pool()) - .await? + sqlx::query_as( + " + SELECT host, tag, idx + FROM store_idx_cache + WHERE user_id = $1 + ", + ) + .bind(user.id) + .fetch_all(self.read_pool()) + .await? } else { tracing::debug!("using aggregate query for user {}", user.id); - sqlx::query_as(STATUS_SQL) - .bind(user.id) - .fetch_all(self.read_pool()) - .await? + sqlx::query_as( + " + SELECT host, tag, max(idx) + FROM store + WHERE user_id = $1 + GROUP BY host, tag + ", + ) + .bind(user.id) + .fetch_all(self.read_pool()) + .await? }; res.sort(); diff --git a/crates/turtle/src/atuin_server/database/db/wrappers.rs b/crates/turtle/src/atuin_server/database/db/wrappers.rs index c0633202..40fd5b4a 100644 --- a/crates/turtle/src/atuin_server/database/db/wrappers.rs +++ b/crates/turtle/src/atuin_server/database/db/wrappers.rs @@ -1,25 +1,12 @@ use crate::{ atuin_common::record::{EncryptedData, Host, Record}, - atuin_server::database::models::{Session, User}, + atuin_server::database::models::Session, }; -use ::sqlx::{FromRow, Result}; use sqlx::{Row, postgres::PgRow}; -pub struct DbUser(pub User); pub struct DbSession(pub Session); pub struct DbRecord(pub Record<EncryptedData>); -impl<'a> FromRow<'a, PgRow> for DbUser { - fn from_row(row: &'a PgRow) -> Result<Self> { - Ok(Self(User { - id: row.try_get("id")?, - username: row.try_get("username")?, - email: row.try_get("email")?, - password: row.try_get("password")?, - })) - } -} - impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession { fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { Ok(Self(Session { |
