diff options
Diffstat (limited to '')
18 files changed, 292 insertions, 1180 deletions
diff --git a/crates/turtle/src/atuin_server_database/calendar.rs b/crates/turtle/src/atuin_server/database/calendar.rs index f1c78262..f1c78262 100644 --- a/crates/turtle/src/atuin_server_database/calendar.rs +++ b/crates/turtle/src/atuin_server/database/calendar.rs diff --git a/crates/turtle/src/atuin_server_postgres/mod.rs b/crates/turtle/src/atuin_server/database/db/mod.rs index e06f8721..22d69d3c 100644 --- a/crates/turtle/src/atuin_server_postgres/mod.rs +++ b/crates/turtle/src/atuin_server/database/db/mod.rs @@ -3,17 +3,20 @@ use std::ops::Range; use rand::Rng; -use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; -use crate::atuin_server_database::models::{ - History, NewHistory, NewSession, NewUser, Session, User, +use crate::{ + atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}, + atuin_server::database::{ + DbError, DbResult, DbSettings, + calendar::{TimePeriod, TimePeriodInfo}, + into_utc, + models::{History, NewHistory, NewSession, NewUser, Session, User}, + }, }; -use crate::atuin_server_database::{Database, DbError, DbResult, DbSettings, into_utc}; -use async_trait::async_trait; use futures_util::TryStreamExt; use sqlx::Row; use sqlx::postgres::PgPoolOptions; +use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset}; -use time::OffsetDateTime; use tracing::instrument; use uuid::Uuid; use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; @@ -23,13 +26,13 @@ mod wrappers; const MIN_PG_VERSION: u32 = 14; #[derive(Clone)] -pub(crate) struct Postgres { +pub struct Database { pool: sqlx::Pool<sqlx::postgres::Postgres>, /// Optional read replica pool for read-only queries read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>, } -impl Postgres { +impl Database { /// Returns the appropriate pool for read operations. /// Uses read_pool if available, otherwise falls back to the primary pool. fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> { @@ -37,9 +40,8 @@ impl Postgres { } } -#[async_trait] -impl Database for Postgres { - async fn new(settings: &DbSettings) -> DbResult<Self> { +impl Database { + pub(crate) async fn new(settings: &DbSettings) -> DbResult<Self> { let pool = PgPoolOptions::new() .max_connections(100) .connect(settings.db_uri.as_str()) @@ -100,7 +102,85 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn get_session(&self, token: &str) -> DbResult<Session> { + pub(crate) async fn calendar( + &self, + user: &User, + period: TimePeriod, + tz: UtcOffset, + ) -> DbResult<HashMap<u64, TimePeriodInfo>> { + let mut ret = HashMap::new(); + let iter: Box<dyn Iterator<Item = DbResult<(u64, Range<Date>)>> + Send> = match period { + TimePeriod::Year => { + // First we need to work out how far back to calculate. Get the + // oldest history item + let oldest = self + .oldest_history(user) + .await? + .timestamp + .to_offset(tz) + .year(); + let current_year = OffsetDateTime::now_utc().to_offset(tz).year(); + + // All the years we need to get data for + // The upper bound is exclusive, so include current +1 + let years = oldest..current_year + 1; + + Box::new(years.map(|year| { + let start = Date::from_calendar_date(year, time::Month::January, 1)?; + let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?; + + Ok((year as u64, start..end)) + })) + } + + TimePeriod::Month { year } => { + let months = + std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12); + + Box::new(months.map(move |month| { + let start = Date::from_calendar_date(year, month, 1)?; + let days = start.month().length(year); + let end = start + Duration::days(days as i64); + + Ok((month as u64, start..end)) + })) + } + + TimePeriod::Day { year, month } => { + let days = 1..month.length(year); + Box::new(days.map(move |day| { + let start = Date::from_calendar_date(year, month, day)?; + let end = start + .next_day() + .ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?; + + Ok((day as u64, start..end)) + })) + } + }; + + for x in iter { + let (index, range) = x?; + + let start = range.start.with_time(Time::MIDNIGHT).assume_offset(tz); + let end = range.end.with_time(Time::MIDNIGHT).assume_offset(tz); + + let count = self.count_history_range(user, start..end).await?; + + ret.insert( + index, + TimePeriodInfo { + count: count as u64, + hash: "".to_string(), + }, + ); + } + + Ok(ret) + } + + #[instrument(skip_all)] + pub(crate) async fn get_session(&self, token: &str) -> DbResult<Session> { sqlx::query_as("select id, user_id, token from sessions where token = $1") .bind(token) .fetch_one(self.read_pool()) @@ -110,7 +190,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn get_user(&self, username: &str) -> DbResult<User> { + pub(crate) async fn get_user(&self, username: &str) -> DbResult<User> { sqlx::query_as("select id, username, email, password from users where username = $1") .bind(username) .fetch_one(self.read_pool()) @@ -120,7 +200,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn get_session_user(&self, token: &str) -> DbResult<User> { + pub(crate) async fn get_session_user(&self, token: &str) -> DbResult<User> { sqlx::query_as( "select users.id, users.username, users.email, users.password from users inner join sessions @@ -135,7 +215,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn count_history(&self, user: &User) -> DbResult<i64> { + pub(crate) async fn count_history(&self, user: &User) -> DbResult<i64> { // The cache is new, and the user might not yet have a cache value. // They will have one as soon as they post up some new history, but handle that // edge case. @@ -152,7 +232,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn count_history_cached(&self, user: &User) -> DbResult<i64> { + pub(crate) async fn count_history_cached(&self, user: &User) -> DbResult<i64> { let res: (i32,) = sqlx::query_as( "select total from total_history_count_user where user_id = $1", @@ -164,7 +244,7 @@ impl Database for Postgres { Ok(res.0 as i64) } - async fn delete_store(&self, user: &User) -> DbResult<()> { + pub(crate) async fn delete_store(&self, user: &User) -> DbResult<()> { let mut tx = self.pool.begin().await?; sqlx::query( @@ -188,7 +268,7 @@ impl Database for Postgres { Ok(()) } - async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { + pub(crate) async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { sqlx::query( "update history set deleted_at = $3 @@ -206,7 +286,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> { + pub(crate) async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> { // The cache is new, and the user might not yet have a cache value. // They will have one as soon as they post up some new history, but handle that // edge case. @@ -229,7 +309,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn count_history_range( + pub(crate) async fn count_history_range( &self, user: &User, range: Range<OffsetDateTime>, @@ -250,7 +330,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn list_history( + pub(crate) async fn list_history( &self, user: &User, created_after: OffsetDateTime, @@ -281,7 +361,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { + pub(crate) async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { let mut tx = self.pool.begin().await?; for i in history { @@ -311,7 +391,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn delete_user(&self, u: &User) -> DbResult<()> { + pub(crate) async fn delete_user(&self, u: &User) -> DbResult<()> { sqlx::query("delete from sessions where user_id = $1") .bind(u.id) .execute(&self.pool) @@ -341,7 +421,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn update_user_password(&self, user: &User) -> DbResult<()> { + pub(crate) async fn update_user_password(&self, user: &User) -> DbResult<()> { sqlx::query( "update users set password = $1 @@ -356,7 +436,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn add_user(&self, user: &NewUser) -> DbResult<i64> { + pub(crate) async fn add_user(&self, user: &NewUser) -> DbResult<i64> { let email: &str = &user.email; let username: &str = &user.username; let password: &str = &user.password; @@ -377,7 +457,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn add_session(&self, session: &NewSession) -> DbResult<()> { + pub(crate) async fn add_session(&self, session: &NewSession) -> DbResult<()> { let token: &str = &session.token; sqlx::query( @@ -394,7 +474,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn get_user_session(&self, u: &User) -> DbResult<Session> { + pub(crate) async fn get_user_session(&self, u: &User) -> DbResult<Session> { sqlx::query_as("select id, user_id, token from sessions where user_id = $1") .bind(u.id) .fetch_one(self.read_pool()) @@ -404,7 +484,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn oldest_history(&self, user: &User) -> DbResult<History> { + pub(crate) async fn oldest_history(&self, user: &User) -> DbResult<History> { sqlx::query_as( "select id, client_id, user_id, hostname, timestamp, data, created_at from history where user_id = $1 @@ -419,7 +499,11 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { + pub(crate) async fn add_records( + &self, + user: &User, + records: &[Record<EncryptedData>], + ) -> DbResult<()> { let mut tx = self.pool.begin().await?; // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max @@ -491,7 +575,7 @@ impl Database for Postgres { } #[instrument(skip_all)] - async fn next_records( + pub(crate) async fn next_records( &self, user: &User, host: HostId, @@ -542,7 +626,7 @@ impl Database for Postgres { Ok(ret) } - async fn status(&self, user: &User) -> DbResult<RecordStatus> { + pub(crate) async fn status(&self, user: &User) -> DbResult<RecordStatus> { const STATUS_SQL: &str = "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; diff --git a/crates/turtle/src/atuin_server_postgres/wrappers.rs b/crates/turtle/src/atuin_server/database/db/wrappers.rs index ba7a9435..de4c5814 100644 --- a/crates/turtle/src/atuin_server_postgres/wrappers.rs +++ b/crates/turtle/src/atuin_server/database/db/wrappers.rs @@ -1,13 +1,15 @@ +use crate::{ + atuin_common::record::{EncryptedData, Host, Record}, + atuin_server::database::models::{History, Session, User}, +}; use ::sqlx::{FromRow, Result}; -use crate::atuin_common::record::{EncryptedData, Host, Record}; -use crate::atuin_server_database::models::{History, Session, User}; use sqlx::{Row, postgres::PgRow}; use time::PrimitiveDateTime; -pub(crate) struct DbUser(pub(crate) User); -pub(crate) struct DbSession(pub(crate) Session); -pub(crate) struct DbHistory(pub(crate) History); -pub(crate) struct DbRecord(pub(crate) Record<EncryptedData>); +pub struct DbUser(pub User); +pub struct DbSession(pub Session); +pub struct DbHistory(pub History); +pub struct DbRecord(pub Record<EncryptedData>); impl<'a> FromRow<'a, PgRow> for DbUser { fn from_row(row: &'a PgRow) -> Result<Self> { diff --git a/crates/turtle/src/atuin_server/database/mod.rs b/crates/turtle/src/atuin_server/database/mod.rs new file mode 100644 index 00000000..845d67d7 --- /dev/null +++ b/crates/turtle/src/atuin_server/database/mod.rs @@ -0,0 +1,123 @@ +pub(crate) mod calendar; +pub(crate) mod db; +pub(crate) mod models; + +use std::fmt::{Debug, Display}; + +use serde::{Deserialize, Serialize}; +use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; + +#[derive(Debug)] +pub(crate) enum DbError { + NotFound, + Other(eyre::Report), +} + +impl Display for DbError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl From<time::error::ComponentRange> for DbError { + fn from(error: time::error::ComponentRange) -> Self { + DbError::Other(error.into()) + } +} + +impl From<time::error::Error> for DbError { + fn from(error: time::error::Error) -> Self { + DbError::Other(error.into()) + } +} + +impl From<sqlx::Error> for DbError { + fn from(error: sqlx::Error) -> Self { + match error { + sqlx::Error::RowNotFound => DbError::NotFound, + error => DbError::Other(error.into()), + } + } +} + +impl std::error::Error for DbError {} + +pub(crate) type DbResult<T> = Result<T, DbError>; + +#[derive(Debug, PartialEq)] +pub(crate) enum DbType { + Postgres, + Unknown, +} + +#[derive(Clone, Deserialize, Serialize)] +pub(crate) struct DbSettings { + pub(crate) db_uri: String, + /// Optional URI for read replicas. If set, read-only queries will use this connection. + pub(crate) read_db_uri: Option<String>, +} + +impl DbSettings { + pub(crate) fn db_type(&self) -> DbType { + if self.db_uri.starts_with("postgres://") || self.db_uri.starts_with("postgresql://") { + DbType::Postgres + } else { + DbType::Unknown + } + } +} + +fn redact_db_uri(uri: &str) -> String { + url::Url::parse(uri) + .map(|mut url| { + let _ = url.set_password(Some("****")); + url.to_string() + }) + .unwrap_or_else(|_| uri.to_string()) +} + +// Do our best to redact passwords so they're not logged in the event of an error. +impl Debug for DbSettings { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.db_type() == DbType::Postgres { + let redacted_uri = redact_db_uri(&self.db_uri); + let redacted_read_uri = self.read_db_uri.as_ref().map(|uri| redact_db_uri(uri)); + f.debug_struct("DbSettings") + .field("db_uri", &redacted_uri) + .field("read_db_uri", &redacted_read_uri) + .finish() + } else { + f.debug_struct("DbSettings") + .field("db_uri", &self.db_uri) + .field("read_db_uri", &self.read_db_uri) + .finish() + } + } +} + +pub(crate) fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime { + let x = x.to_offset(UtcOffset::UTC); + PrimitiveDateTime::new(x.date(), x.time()) +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use super::into_utc; + + #[test] + fn utc() { + let dt = datetime!(2023-09-26 15:11:02 +05:30); + assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + + let dt = datetime!(2023-09-26 15:11:02 -07:00); + assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + + let dt = datetime!(2023-09-26 15:11:02 +00:00); + assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + } +} diff --git a/crates/turtle/src/atuin_server_database/models.rs b/crates/turtle/src/atuin_server/database/models.rs index e47d614d..e47d614d 100644 --- a/crates/turtle/src/atuin_server_database/models.rs +++ b/crates/turtle/src/atuin_server/database/models.rs diff --git a/crates/turtle/src/atuin_server/handlers/history.rs b/crates/turtle/src/atuin_server/handlers/history.rs deleted file mode 100644 index e5057bcb..00000000 --- a/crates/turtle/src/atuin_server/handlers/history.rs +++ /dev/null @@ -1,237 +0,0 @@ -use std::{collections::HashMap, convert::TryFrom}; - -use axum::{ - Json, - extract::{Path, Query, State}, - http::{HeaderMap, StatusCode}, -}; -use metrics::counter; -use time::{Month, UtcOffset}; -use tracing::{debug, error, instrument}; - -use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::atuin_server::{ - router::{AppState, UserAuth}, - utils::client_version_min, -}; -use crate::atuin_server_database::{ - Database, - calendar::{TimePeriod, TimePeriodInfo}, - models::NewHistory, -}; - -use crate::atuin_common::api::*; - -#[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn count<DB: Database>( - UserAuth(user): UserAuth, - state: State<AppState<DB>>, -) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> { - let db = &state.0.database; - match db.count_history_cached(&user).await { - // By default read out the cached value - Ok(count) => Ok(Json(CountResponse { count })), - - // If that fails, fallback on a full COUNT. Cache is built on a POST - // only - Err(_) => match db.count_history(&user).await { - Ok(count) => Ok(Json(CountResponse { count })), - Err(_) => Err(ErrorResponse::reply("failed to query history count") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)), - }, - } -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn list<DB: Database>( - req: Query<SyncHistoryRequest>, - UserAuth(user): UserAuth, - headers: HeaderMap, - state: State<AppState<DB>>, -) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> { - let db = &state.0.database; - - let agent = headers - .get("user-agent") - .map_or("", |v| v.to_str().unwrap_or("")); - - let variable_page_size = client_version_min(agent, ">=15.0.0").unwrap_or(false); - - let page_size = if variable_page_size { - state.settings.page_size - } else { - 100 - }; - - if req.sync_ts.unix_timestamp_nanos() < 0 || req.history_ts.unix_timestamp_nanos() < 0 { - error!("client asked for history from < epoch 0"); - counter!("atuin_history_epoch_before_zero").increment(1); - - return Err( - ErrorResponse::reply("asked for history from before epoch 0") - .with_status(StatusCode::BAD_REQUEST), - ); - } - - let history = db - .list_history(&user, req.sync_ts, req.history_ts, &req.host, page_size) - .await; - - if let Err(e) = history { - error!("failed to load history: {}", e); - return Err(ErrorResponse::reply("failed to load history") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - - let history: Vec<String> = history - .unwrap() - .iter() - .map(|i| i.data.to_string()) - .collect(); - - debug!( - "loaded {} items of history for user {}", - history.len(), - user.id - ); - - counter!("atuin_history_returned").increment(history.len() as u64); - - Ok(Json(SyncHistoryResponse { history })) -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn delete<DB: Database>( - UserAuth(user): UserAuth, - state: State<AppState<DB>>, - Json(req): Json<DeleteHistoryRequest>, -) -> Result<Json<MessageResponse>, ErrorResponseStatus<'static>> { - let db = &state.0.database; - - // user_id is the ID of the history, as set by the user (the server has its own ID) - let deleted = db.delete_history(&user, req.client_id).await; - - if let Err(e) = deleted { - error!("failed to delete history: {}", e); - return Err(ErrorResponse::reply("failed to delete history") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - - Ok(Json(MessageResponse { - message: String::from("deleted OK"), - })) -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn add<DB: Database>( - UserAuth(user): UserAuth, - state: State<AppState<DB>>, - Json(req): Json<Vec<AddHistoryRequest>>, -) -> Result<(), ErrorResponseStatus<'static>> { - let State(AppState { database, settings }) = state; - - debug!("request to add {} history items", req.len()); - counter!("atuin_history_uploaded").increment(req.len() as u64); - - let mut history: Vec<NewHistory> = req - .into_iter() - .map(|h| NewHistory { - client_id: h.id, - user_id: user.id, - hostname: h.hostname, - timestamp: h.timestamp, - data: h.data, - }) - .collect(); - - history.retain(|h| { - // keep if within limit, or limit is 0 (unlimited) - let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0; - - // Don't return an error here. We want to insert as much of the - // history list as we can, so log the error and continue going. - if !keep { - counter!("atuin_history_too_long").increment(1); - - tracing::warn!( - "history too long, got length {}, max {}", - h.data.len(), - settings.max_history_length - ); - } - - keep - }); - - if let Err(e) = database.add_history(&history).await { - error!("failed to add history: {}", e); - - return Err(ErrorResponse::reply("failed to add history") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - }; - - Ok(()) -} - -#[derive(serde::Deserialize, Debug)] -pub(crate) struct CalendarQuery { - #[serde(default = "serde_calendar::zero")] - year: i32, - #[serde(default = "serde_calendar::one")] - month: u8, - - #[serde(default = "serde_calendar::utc")] - tz: UtcOffset, -} - -mod serde_calendar { - use time::UtcOffset; - - pub(crate) fn zero() -> i32 { - 0 - } - - pub(crate) fn one() -> u8 { - 1 - } - - pub(crate) fn utc() -> UtcOffset { - UtcOffset::UTC - } -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn calendar<DB: Database>( - Path(focus): Path<String>, - Query(params): Query<CalendarQuery>, - UserAuth(user): UserAuth, - state: State<AppState<DB>>, -) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> { - let focus = focus.as_str(); - - let year = params.year; - let month = Month::try_from(params.month).map_err(|e| ErrorResponseStatus { - error: ErrorResponse { - reason: e.to_string().into(), - }, - status: StatusCode::BAD_REQUEST, - })?; - - let period = match focus { - "year" => TimePeriod::Year, - "month" => TimePeriod::Month { year }, - "day" => TimePeriod::Day { year, month }, - _ => { - return Err(ErrorResponse::reply("invalid focus: use year/month/day") - .with_status(StatusCode::BAD_REQUEST)); - } - }; - - let db = &state.0.database; - let focus = db.calendar(&user, period, params.tz).await.map_err(|_| { - ErrorResponse::reply("failed to query calendar") - .with_status(StatusCode::INTERNAL_SERVER_ERROR) - })?; - - Ok(Json(focus)) -} diff --git a/crates/turtle/src/atuin_server/handlers/mod.rs b/crates/turtle/src/atuin_server/handlers/mod.rs index 322324c4..3b935834 100644 --- a/crates/turtle/src/atuin_server/handlers/mod.rs +++ b/crates/turtle/src/atuin_server/handlers/mod.rs @@ -1,19 +1,16 @@ use crate::atuin_common::api::{ErrorResponse, IndexResponse}; -use crate::atuin_server_database::Database; use axum::{Json, extract::State, http, response::IntoResponse}; use crate::atuin_server::router::AppState; pub(crate) mod health; -pub(crate) mod history; pub(crate) mod record; -pub(crate) mod status; pub(crate) mod user; pub(crate) mod v0; const VERSION: &str = env!("CARGO_PKG_VERSION"); -pub(crate) async fn index<DB: Database>(state: State<AppState<DB>>) -> Json<IndexResponse> { +pub(crate) async fn index(state: State<AppState>) -> Json<IndexResponse> { let homage = r#""Through the fathomless deeps of space swims the star turtle Great A'Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld." -- Sir Terry Pratchett"#; let version = state diff --git a/crates/turtle/src/atuin_server/handlers/status.rs b/crates/turtle/src/atuin_server/handlers/status.rs deleted file mode 100644 index 59be1e5c..00000000 --- a/crates/turtle/src/atuin_server/handlers/status.rs +++ /dev/null @@ -1,45 +0,0 @@ -use axum::{Json, extract::State, http::StatusCode}; -use tracing::instrument; - -use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::atuin_server::router::{AppState, UserAuth}; -use crate::atuin_server_database::Database; - -use crate::atuin_common::api::*; - -const VERSION: &str = env!("CARGO_PKG_VERSION"); - -#[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn status<DB: Database>( - UserAuth(user): UserAuth, - state: State<AppState<DB>>, -) -> Result<Json<StatusResponse>, ErrorResponseStatus<'static>> { - let db = &state.0.database; - - let deleted = db.deleted_history(&user).await.unwrap_or(vec![]); - - let count = match db.count_history_cached(&user).await { - // By default read out the cached value - Ok(count) => count, - - // If that fails, fallback on a full COUNT. Cache is built on a POST - // only - Err(_) => match db.count_history(&user).await { - Ok(count) => count, - Err(_) => { - return Err(ErrorResponse::reply("failed to query history count") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - }, - }; - - tracing::debug!(user = user.username, "requested sync status"); - - Ok(Json(StatusResponse { - count, - deleted, - username: user.username, - version: VERSION.to_string(), - page_size: state.settings.page_size, - })) -} diff --git a/crates/turtle/src/atuin_server/handlers/user.rs b/crates/turtle/src/atuin_server/handlers/user.rs index 7708d43e..28cebfab 100644 --- a/crates/turtle/src/atuin_server/handlers/user.rs +++ b/crates/turtle/src/atuin_server/handlers/user.rs @@ -16,14 +16,16 @@ use metrics::counter; use rand::rngs::OsRng; use tracing::{debug, error, info, instrument}; -use crate::atuin_common::tls::ensure_crypto_provider; +use crate::{ + atuin_common::tls::ensure_crypto_provider, + atuin_server::database::{ + DbError, + models::{NewSession, NewUser}, + }, +}; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::atuin_server::router::{AppState, UserAuth}; -use crate::atuin_server_database::{ - Database, DbError, - models::{NewSession, NewUser}, -}; use reqwest::header::CONTENT_TYPE; @@ -63,9 +65,9 @@ async fn send_register_hook(url: &str, username: String, registered: String) { } #[instrument(skip_all, fields(user.username = username.as_str()))] -pub(crate) async fn get<DB: Database>( +pub(crate) async fn get( Path(username): Path<String>, - state: State<AppState<DB>>, + state: State<AppState>, ) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> { let db = &state.0.database; let user = match db.get_user(username.as_ref()).await { @@ -87,8 +89,8 @@ pub(crate) async fn get<DB: Database>( } #[instrument(skip_all)] -pub(crate) async fn register<DB: Database>( - state: State<AppState<DB>>, +pub(crate) async fn register( + state: State<AppState>, Json(register): Json<RegisterRequest>, ) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> { if !state.settings.open_registration { @@ -163,9 +165,9 @@ pub(crate) async fn register<DB: Database>( } #[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn delete<DB: Database>( +pub(crate) async fn delete( UserAuth(user): UserAuth, - state: State<AppState<DB>>, + state: State<AppState>, ) -> Result<Json<DeleteUserResponse>, ErrorResponseStatus<'static>> { debug!("request to delete user {}", user.id); @@ -183,9 +185,9 @@ pub(crate) async fn delete<DB: Database>( } #[instrument(skip_all, fields(user.id = user.id, change_password))] -pub(crate) async fn change_password<DB: Database>( +pub(crate) async fn change_password( UserAuth(mut user): UserAuth, - state: State<AppState<DB>>, + state: State<AppState>, Json(change_password): Json<ChangePasswordRequest>, ) -> Result<Json<ChangePasswordResponse>, ErrorResponseStatus<'static>> { let db = &state.0.database; @@ -213,8 +215,8 @@ pub(crate) async fn change_password<DB: Database>( } #[instrument(skip_all, fields(user.username = login.username.as_str()))] -pub(crate) async fn login<DB: Database>( - state: State<AppState<DB>>, +pub(crate) async fn login( + state: State<AppState>, login: Json<LoginRequest>, ) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> { let db = &state.0.database; diff --git a/crates/turtle/src/atuin_server/handlers/v0/record.rs b/crates/turtle/src/atuin_server/handlers/v0/record.rs index 2cc09118..88027547 100644 --- a/crates/turtle/src/atuin_server/handlers/v0/record.rs +++ b/crates/turtle/src/atuin_server/handlers/v0/record.rs @@ -7,14 +7,13 @@ use crate::atuin_server::{ handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, router::{AppState, UserAuth}, }; -use crate::atuin_server_database::Database; use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; #[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn post<DB: Database>( +pub(crate) async fn post( UserAuth(user): UserAuth, - state: State<AppState<DB>>, + state: State<AppState>, Json(records): Json<Vec<Record<EncryptedData>>>, ) -> Result<(), ErrorResponseStatus<'static>> { let State(AppState { database, settings }) = state; @@ -51,9 +50,9 @@ pub(crate) async fn post<DB: Database>( } #[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn index<DB: Database>( +pub(crate) async fn index( UserAuth(user): UserAuth, - state: State<AppState<DB>>, + state: State<AppState>, ) -> Result<Json<RecordStatus>, ErrorResponseStatus<'static>> { let State(AppState { database, @@ -84,10 +83,10 @@ pub(crate) struct NextParams { } #[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn next<DB: Database>( +pub(crate) async fn next( params: Query<NextParams>, UserAuth(user): UserAuth, - state: State<AppState<DB>>, + state: State<AppState>, ) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> { let State(AppState { database, diff --git a/crates/turtle/src/atuin_server/handlers/v0/store.rs b/crates/turtle/src/atuin_server/handlers/v0/store.rs index 8269d6b3..f0aa1b36 100644 --- a/crates/turtle/src/atuin_server/handlers/v0/store.rs +++ b/crates/turtle/src/atuin_server/handlers/v0/store.rs @@ -7,16 +7,15 @@ use crate::atuin_server::{ handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, router::{AppState, UserAuth}, }; -use crate::atuin_server_database::Database; #[derive(Deserialize)] pub(crate) struct DeleteParams {} #[instrument(skip_all, fields(user.id = user.id))] -pub(crate) async fn delete<DB: Database>( +pub(crate) async fn delete( _params: Query<DeleteParams>, UserAuth(user): UserAuth, - state: State<AppState<DB>>, + state: State<AppState>, ) -> Result<(), ErrorResponseStatus<'static>> { let State(AppState { database, diff --git a/crates/turtle/src/atuin_server/mod.rs b/crates/turtle/src/atuin_server/mod.rs index ad480e1d..c96a13bc 100644 --- a/crates/turtle/src/atuin_server/mod.rs +++ b/crates/turtle/src/atuin_server/mod.rs @@ -1,14 +1,14 @@ use std::future::Future; use std::net::SocketAddr; -use crate::atuin_server_database::Database; use axum::{Router, serve}; +use database::db::Database; use eyre::{Context, Result}; +pub(crate) mod database; mod handlers; mod metrics; mod router; -mod utils; pub(crate) use settings::Settings; @@ -31,8 +31,8 @@ async fn shutdown_signal() { eprintln!("Shutting down gracefully..."); } -pub(crate) async fn launch<Db: Database>(settings: Settings, addr: SocketAddr) -> Result<()> { - launch_with_tcp_listener::<Db>( +pub(crate) async fn launch(settings: Settings, addr: SocketAddr) -> Result<()> { + launch_with_tcp_listener( settings, TcpListener::bind(addr) .await @@ -42,12 +42,12 @@ pub(crate) async fn launch<Db: Database>(settings: Settings, addr: SocketAddr) - .await } -pub(crate) async fn launch_with_tcp_listener<Db: Database>( +pub(crate) async fn launch_with_tcp_listener( settings: Settings, listener: TcpListener, shutdown: impl Future<Output = ()> + Send + 'static, ) -> Result<()> { - let r = make_router::<Db>(settings).await?; + let r = make_router(settings).await?; serve(listener, r.into_make_service()) .with_graceful_shutdown(shutdown) @@ -77,8 +77,8 @@ pub(crate) async fn launch_metrics_server(host: String, port: u16) -> Result<()> Ok(()) } -async fn make_router<Db: Database>(settings: Settings) -> Result<Router, eyre::Error> { - let db = Db::new(&settings.db_settings) +async fn make_router(settings: Settings) -> Result<Router, eyre::Error> { + let db = Database::new(&settings.db_settings) .await .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; let r = router::router(db, settings); diff --git a/crates/turtle/src/atuin_server/router.rs b/crates/turtle/src/atuin_server/router.rs index ed3d1e55..778e699a 100644 --- a/crates/turtle/src/atuin_server/router.rs +++ b/crates/turtle/src/atuin_server/router.rs @@ -1,4 +1,7 @@ -use crate::atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse}; +use crate::{ + atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse}, + atuin_server::database::{DbError, db::Database, models::User}, +}; use axum::{ Router, extract::{FromRequestParts, Request}, @@ -17,19 +20,15 @@ use crate::atuin_server::{ metrics, settings::Settings, }; -use crate::atuin_server_database::{Database, DbError, models::User}; pub(crate) struct UserAuth(pub(crate) User); -impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for UserAuth -where - DB: Database, -{ +impl FromRequestParts<AppState> for UserAuth { type Rejection = ErrorResponseStatus<'static>; async fn from_request_parts( req: &mut Parts, - state: &AppState<DB>, + state: &AppState, ) -> Result<Self, Self::Rejection> { let auth_header = req .headers @@ -78,18 +77,6 @@ async fn teapot() -> impl IntoResponse { (http::StatusCode::NOT_FOUND, "404 not found") } -async fn clacks_overhead(request: Request, next: Next) -> Response { - let mut response = next.run(request).await; - - let gnu_terry_value = "GNU Terry Pratchett, Kris Nova"; - let gnu_terry_header = "X-Clacks-Overhead"; - - response - .headers_mut() - .insert(gnu_terry_header, gnu_terry_value.parse().unwrap()); - response -} - /// Ensure that we only try and sync with clients on the same major version async fn semver(request: Request, next: Next) -> Response { let mut response = next.run(request).await; @@ -101,27 +88,16 @@ async fn semver(request: Request, next: Next) -> Response { } #[derive(Clone)] -pub(crate) struct AppState<DB: Database> { - pub(crate) database: DB, +pub(crate) struct AppState { + pub(crate) database: Database, pub(crate) settings: Settings, } -pub(crate) fn router<DB: Database>(database: DB, settings: Settings) -> Router { - let mut routes = Router::new() +pub(crate) fn router(database: Database, settings: Settings) -> Router { + let routes = Router::new() .route("/", get(handlers::index)) .route("/healthz", get(handlers::health::health_check)); - // Sync v1 routes - can be disabled in favor of record-based sync - if settings.sync_v1_enabled { - routes = routes - .route("/sync/count", get(handlers::history::count)) - .route("/sync/history", get(handlers::history::list)) - .route("/sync/calendar/{focus}", get(handlers::history::calendar)) - .route("/sync/status", get(handlers::status::status)) - .route("/history", post(handlers::history::add)) - .route("/history", delete(handlers::history::delete)); - } - let routes = routes .route("/user/{username}", get(handlers::user::get)) .route("/account", delete(handlers::user::delete)) @@ -147,7 +123,6 @@ pub(crate) fn router<DB: Database>(database: DB, settings: Settings) -> Router { .with_state(AppState { database, settings }) .layer( ServiceBuilder::new() - .layer(axum::middleware::from_fn(clacks_overhead)) .layer(TraceLayer::new_for_http()) .layer(axum::middleware::from_fn(metrics::track_metrics)) .layer(axum::middleware::from_fn(semver)), diff --git a/crates/turtle/src/atuin_server/settings.rs b/crates/turtle/src/atuin_server/settings.rs index 1d0ac2d0..b62f24e1 100644 --- a/crates/turtle/src/atuin_server/settings.rs +++ b/crates/turtle/src/atuin_server/settings.rs @@ -1,11 +1,12 @@ use std::{io::prelude::*, path::PathBuf}; -use crate::atuin_server_database::DbSettings; use config::{Config, Environment, File as ConfigFile, FileFormat}; use eyre::{Result, eyre}; use fs_err::{File, create_dir_all}; use serde::{Deserialize, Serialize}; +use crate::atuin_server::database::DbSettings; + #[derive(Clone, Debug, Deserialize, Serialize)] pub(crate) struct Metrics { #[serde(alias = "enabled")] @@ -37,10 +38,6 @@ pub(crate) struct Settings { pub(crate) register_webhook_username: String, pub(crate) metrics: Metrics, - /// Enable legacy sync v1 routes (history-based sync) - /// Set to false to use only the newer record-based sync - pub(crate) sync_v1_enabled: bool, - /// Advertise a version that is not what we are _actually_ running /// Many clients compare their version with api.atuin.sh, and if they differ, notify the user /// that an update is available. @@ -78,7 +75,6 @@ impl Settings { .set_default("metrics.enable", false)? .set_default("metrics.host", "127.0.0.1")? .set_default("metrics.port", 9001)? - .set_default("sync_v1_enabled", true)? .add_source( Environment::with_prefix("atuin") .prefix_separator("_") diff --git a/crates/turtle/src/atuin_server/utils.rs b/crates/turtle/src/atuin_server/utils.rs deleted file mode 100644 index cceef3ed..00000000 --- a/crates/turtle/src/atuin_server/utils.rs +++ /dev/null @@ -1,15 +0,0 @@ -use eyre::Result; -use semver::{Version, VersionReq}; - -pub(crate) fn client_version_min(user_agent: &str, req: &str) -> Result<bool> { - if user_agent.is_empty() { - return Ok(false); - } - - let version = user_agent.replace("atuin/", ""); - - let req = VersionReq::parse(req)?; - let version = Version::parse(version.as_str())?; - - Ok(req.matches(&version)) -} diff --git a/crates/turtle/src/atuin_server_database/mod.rs b/crates/turtle/src/atuin_server_database/mod.rs deleted file mode 100644 index e4672bb0..00000000 --- a/crates/turtle/src/atuin_server_database/mod.rs +++ /dev/null @@ -1,266 +0,0 @@ -pub(crate) mod calendar; -pub(crate) mod models; - -use std::{ - collections::HashMap, - fmt::{Debug, Display}, - ops::Range, -}; - -use self::{ - calendar::{TimePeriod, TimePeriodInfo}, - models::{History, NewHistory, NewSession, NewUser, Session, User}, -}; -use async_trait::async_trait; -use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; -use serde::{Deserialize, Serialize}; -use time::{Date, Duration, Month, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; -use tracing::instrument; - -#[derive(Debug)] -pub(crate) enum DbError { - NotFound, - Other(eyre::Report), -} - -impl Display for DbError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") - } -} - -impl From<time::error::ComponentRange> for DbError { - fn from(error: time::error::ComponentRange) -> Self { - DbError::Other(error.into()) - } -} - -impl From<time::error::Error> for DbError { - fn from(error: time::error::Error) -> Self { - DbError::Other(error.into()) - } -} - -impl From<sqlx::Error> for DbError { - fn from(error: sqlx::Error) -> Self { - match error { - sqlx::Error::RowNotFound => DbError::NotFound, - error => DbError::Other(error.into()), - } - } -} - -impl std::error::Error for DbError {} - -pub(crate) type DbResult<T> = Result<T, DbError>; - -#[derive(Debug, PartialEq)] -pub(crate) enum DbType { - Postgres, - Sqlite, - Unknown, -} - -#[derive(Clone, Deserialize, Serialize)] -pub(crate) struct DbSettings { - pub(crate) db_uri: String, - /// Optional URI for read replicas. If set, read-only queries will use this connection. - pub(crate) read_db_uri: Option<String>, -} - -impl DbSettings { - pub(crate) fn db_type(&self) -> DbType { - if self.db_uri.starts_with("postgres://") || self.db_uri.starts_with("postgresql://") { - DbType::Postgres - } else if self.db_uri.starts_with("sqlite://") { - DbType::Sqlite - } else { - DbType::Unknown - } - } -} - -fn redact_db_uri(uri: &str) -> String { - url::Url::parse(uri) - .map(|mut url| { - let _ = url.set_password(Some("****")); - url.to_string() - }) - .unwrap_or_else(|_| uri.to_string()) -} - -// Do our best to redact passwords so they're not logged in the event of an error. -impl Debug for DbSettings { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.db_type() == DbType::Postgres { - let redacted_uri = redact_db_uri(&self.db_uri); - let redacted_read_uri = self.read_db_uri.as_ref().map(|uri| redact_db_uri(uri)); - f.debug_struct("DbSettings") - .field("db_uri", &redacted_uri) - .field("read_db_uri", &redacted_read_uri) - .finish() - } else { - f.debug_struct("DbSettings") - .field("db_uri", &self.db_uri) - .field("read_db_uri", &self.read_db_uri) - .finish() - } - } -} - -#[async_trait] -pub(crate) trait Database: Sized + Clone + Send + Sync + 'static { - async fn new(settings: &DbSettings) -> DbResult<Self>; - - async fn get_session(&self, token: &str) -> DbResult<Session>; - async fn get_session_user(&self, token: &str) -> DbResult<User>; - async fn add_session(&self, session: &NewSession) -> DbResult<()>; - - async fn get_user(&self, username: &str) -> DbResult<User>; - async fn get_user_session(&self, u: &User) -> DbResult<Session>; - async fn add_user(&self, user: &NewUser) -> DbResult<i64>; - - async fn update_user_password(&self, u: &User) -> DbResult<()>; - - async fn count_history(&self, user: &User) -> DbResult<i64>; - async fn count_history_cached(&self, user: &User) -> DbResult<i64>; - - async fn delete_user(&self, u: &User) -> DbResult<()>; - async fn delete_history(&self, user: &User, id: String) -> DbResult<()>; - async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>; - async fn delete_store(&self, user: &User) -> DbResult<()>; - - async fn add_records(&self, user: &User, record: &[Record<EncryptedData>]) -> DbResult<()>; - async fn next_records( - &self, - user: &User, - host: HostId, - tag: String, - start: Option<RecordIdx>, - count: u64, - ) -> DbResult<Vec<Record<EncryptedData>>>; - - // Return the tail record ID for each store, so (HostID, Tag, TailRecordID) - async fn status(&self, user: &User) -> DbResult<RecordStatus>; - - async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>) - -> DbResult<i64>; - - async fn list_history( - &self, - user: &User, - created_after: OffsetDateTime, - since: OffsetDateTime, - host: &str, - page_size: i64, - ) -> DbResult<Vec<History>>; - - async fn add_history(&self, history: &[NewHistory]) -> DbResult<()>; - - async fn oldest_history(&self, user: &User) -> DbResult<History>; - - #[instrument(skip_all)] - async fn calendar( - &self, - user: &User, - period: TimePeriod, - tz: UtcOffset, - ) -> DbResult<HashMap<u64, TimePeriodInfo>> { - let mut ret = HashMap::new(); - let iter: Box<dyn Iterator<Item = DbResult<(u64, Range<Date>)>> + Send> = match period { - TimePeriod::Year => { - // First we need to work out how far back to calculate. Get the - // oldest history item - let oldest = self - .oldest_history(user) - .await? - .timestamp - .to_offset(tz) - .year(); - let current_year = OffsetDateTime::now_utc().to_offset(tz).year(); - - // All the years we need to get data for - // The upper bound is exclusive, so include current +1 - let years = oldest..current_year + 1; - - Box::new(years.map(|year| { - let start = Date::from_calendar_date(year, time::Month::January, 1)?; - let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?; - - Ok((year as u64, start..end)) - })) - } - - TimePeriod::Month { year } => { - let months = - std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12); - - Box::new(months.map(move |month| { - let start = Date::from_calendar_date(year, month, 1)?; - let days = start.month().length(year); - let end = start + Duration::days(days as i64); - - Ok((month as u64, start..end)) - })) - } - - TimePeriod::Day { year, month } => { - let days = 1..month.length(year); - Box::new(days.map(move |day| { - let start = Date::from_calendar_date(year, month, day)?; - let end = start - .next_day() - .ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?; - - Ok((day as u64, start..end)) - })) - } - }; - - for x in iter { - let (index, range) = x?; - - let start = range.start.with_time(Time::MIDNIGHT).assume_offset(tz); - let end = range.end.with_time(Time::MIDNIGHT).assume_offset(tz); - - let count = self.count_history_range(user, start..end).await?; - - ret.insert( - index, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } -} - -pub(crate) fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime { - let x = x.to_offset(UtcOffset::UTC); - PrimitiveDateTime::new(x.date(), x.time()) -} - -#[cfg(test)] -mod tests { - use time::macros::datetime; - - use crate::into_utc; - - #[test] - fn utc() { - let dt = datetime!(2023-09-26 15:11:02 +05:30); - assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02)); - assert_eq!(into_utc(dt).assume_utc(), dt); - - let dt = datetime!(2023-09-26 15:11:02 -07:00); - assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02)); - assert_eq!(into_utc(dt).assume_utc(), dt); - - let dt = datetime!(2023-09-26 15:11:02 +00:00); - assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02)); - assert_eq!(into_utc(dt).assume_utc(), dt); - } -} diff --git a/crates/turtle/src/atuin_server_sqlite/mod.rs b/crates/turtle/src/atuin_server_sqlite/mod.rs deleted file mode 100644 index b1de511d..00000000 --- a/crates/turtle/src/atuin_server_sqlite/mod.rs +++ /dev/null @@ -1,430 +0,0 @@ -use std::str::FromStr; - -use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; -use crate::atuin_server_database::{ - Database, DbError, DbResult, DbSettings, into_utc, - models::{History, NewHistory, NewSession, NewUser, Session, User}, -}; -use async_trait::async_trait; -use futures_util::TryStreamExt; -use sqlx::{ - Row, - sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions}, - types::Uuid, -}; -use tracing::instrument; -use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; - -mod wrappers; - -#[derive(Clone)] -pub(crate) struct Sqlite { - pool: sqlx::Pool<sqlx::sqlite::Sqlite>, -} - -#[async_trait] -impl Database for Sqlite { - async fn new(settings: &DbSettings) -> DbResult<Self> { - let opts = SqliteConnectOptions::from_str(&settings.db_uri)? - .journal_mode(SqliteJournalMode::Wal) - .create_if_missing(true); - - let pool = SqlitePoolOptions::new().connect_with(opts).await?; - - sqlx::migrate!("./db/server-sqlite-migrations") - .run(&pool) - .await - .map_err(|error| DbError::Other(error.into()))?; - - Ok(Self { pool }) - } - - #[instrument(skip_all)] - async fn get_session(&self, token: &str) -> DbResult<Session> { - sqlx::query_as("select id, user_id, token from sessions where token = $1") - .bind(token) - .fetch_one(&self.pool) - .await - .map_err(Into::into) - .map(|DbSession(session)| session) - } - - #[instrument(skip_all)] - async fn get_session_user(&self, token: &str) -> DbResult<User> { - sqlx::query_as( - "select users.id, users.username, users.email, users.password from users - inner join sessions - on users.id = sessions.user_id - and sessions.token = $1", - ) - .bind(token) - .fetch_one(&self.pool) - .await - .map_err(Into::into) - .map(|DbUser(user)| user) - } - - #[instrument(skip_all)] - async fn add_session(&self, session: &NewSession) -> DbResult<()> { - let token: &str = &session.token; - - sqlx::query( - "insert into sessions - (user_id, token) - values($1, $2)", - ) - .bind(session.user_id) - .bind(token) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn get_user(&self, username: &str) -> DbResult<User> { - sqlx::query_as("select id, username, email, password from users where username = $1") - .bind(username) - .fetch_one(&self.pool) - .await - .map_err(Into::into) - .map(|DbUser(user)| user) - } - - #[instrument(skip_all)] - async fn get_user_session(&self, u: &User) -> DbResult<Session> { - sqlx::query_as("select id, user_id, token from sessions where user_id = $1") - .bind(u.id) - .fetch_one(&self.pool) - .await - .map_err(Into::into) - .map(|DbSession(session)| session) - } - - #[instrument(skip_all)] - async fn add_user(&self, user: &NewUser) -> DbResult<i64> { - let email: &str = &user.email; - let username: &str = &user.username; - let password: &str = &user.password; - - let res: (i64,) = sqlx::query_as( - "insert into users - (username, email, password) - values($1, $2, $3) - returning id", - ) - .bind(username) - .bind(email) - .bind(password) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn update_user_password(&self, user: &User) -> DbResult<()> { - sqlx::query( - "update users - set password = $1 - where id = $2", - ) - .bind(&user.password) - .bind(user.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn count_history(&self, user: &User) -> DbResult<i64> { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn count_history_cached(&self, _user: &User) -> DbResult<i64> { - Err(DbError::NotFound) - } - - #[instrument(skip_all)] - async fn delete_user(&self, u: &User) -> DbResult<()> { - sqlx::query("delete from sessions where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from users where id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from history where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { - sqlx::query( - "update history - set deleted_at = $3 - where user_id = $1 - and client_id = $2 - and deleted_at is null", // don't just keep setting it - ) - .bind(user.id) - .bind(id) - .bind(time::OffsetDateTime::now_utc()) - .fetch_all(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res = sqlx::query( - "select client_id from history - where user_id = $1 - and deleted_at is not null", - ) - .bind(user.id) - .fetch_all(&self.pool) - .await?; - - let res = res.iter().map(|row| row.get("client_id")).collect(); - - Ok(res) - } - - async fn delete_store(&self, user: &User) -> DbResult<()> { - sqlx::query( - "delete from store - where user_id = $1", - ) - .bind(user.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { - let mut tx = self.pool.begin().await?; - - for i in records { - let id = crate::atuin_common::utils::uuid_v7(); - - sqlx::query( - "insert into store - (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - on conflict do nothing - ", - ) - .bind(id) - .bind(i.id) - .bind(i.host.id) - .bind(i.idx as i64) - .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time - .bind(&i.version) - .bind(&i.tag) - .bind(&i.data.data) - .bind(&i.data.content_encryption_key) - .bind(user.id) - .execute(&mut *tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn next_records( - &self, - user: &User, - host: HostId, - tag: String, - start: Option<RecordIdx>, - count: u64, - ) -> DbResult<Vec<Record<EncryptedData>>> { - tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); - let start = start.unwrap_or(0); - - let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as( - "select client_id, host, idx, timestamp, version, tag, data, cek from store - where user_id = $1 - and tag = $2 - and host = $3 - and idx >= $4 - order by idx asc - limit $5", - ) - .bind(user.id) - .bind(tag.clone()) - .bind(host) - .bind(start as i64) - .bind(count as i64) - .fetch_all(&self.pool) - .await - .map_err(Into::into); - - let ret = match records { - Ok(records) => { - let records: Vec<Record<EncryptedData>> = records - .into_iter() - .map(|f| { - let record: Record<EncryptedData> = f.into(); - record - }) - .collect(); - - records - } - Err(DbError::NotFound) => { - tracing::debug!("no records found in store: {:?}/{}", host, tag); - return Ok(vec![]); - } - Err(e) => return Err(e), - }; - - Ok(ret) - } - - async fn status(&self, user: &User) -> DbResult<RecordStatus> { - const STATUS_SQL: &str = - "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; - - let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) - .bind(user.id) - .fetch_all(&self.pool) - .await?; - - let mut status = RecordStatus::new(); - - for i in res { - status.set_raw(HostId(i.0), i.1, i.2 as u64); - } - - Ok(status) - } - - #[instrument(skip_all)] - async fn count_history_range( - &self, - user: &User, - range: std::ops::Range<time::OffsetDateTime>, - ) -> DbResult<i64> { - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1 - and timestamp >= $2::date - and timestamp < $3::date", - ) - .bind(user.id) - .bind(into_utc(range.start)) - .bind(into_utc(range.end)) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn list_history( - &self, - user: &User, - created_after: time::OffsetDateTime, - since: time::OffsetDateTime, - host: &str, - page_size: i64, - ) -> DbResult<Vec<History>> { - let res = sqlx::query_as( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - and hostname != $2 - and created_at >= $3 - and timestamp >= $4 - order by timestamp asc - limit $5", - ) - .bind(user.id) - .bind(host) - .bind(into_utc(created_after)) - .bind(into_utc(since)) - .bind(page_size) - .fetch(&self.pool) - .map_ok(|DbHistory(h)| h) - .try_collect() - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { - let mut tx = self.pool.begin().await?; - - for i in history { - let client_id: &str = &i.client_id; - let hostname: &str = &i.hostname; - let data: &str = &i.data; - - sqlx::query( - "insert into history - (client_id, user_id, hostname, timestamp, data) - values ($1, $2, $3, $4, $5) - on conflict do nothing - ", - ) - .bind(client_id) - .bind(i.user_id) - .bind(hostname) - .bind(i.timestamp) - .bind(data) - .execute(&mut *tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn oldest_history(&self, user: &User) -> DbResult<History> { - sqlx::query_as( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - order by timestamp asc - limit 1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await - .map_err(Into::into) - .map(|DbHistory(h)| h) - } -} diff --git a/crates/turtle/src/atuin_server_sqlite/wrappers.rs b/crates/turtle/src/atuin_server_sqlite/wrappers.rs deleted file mode 100644 index e7380bce..00000000 --- a/crates/turtle/src/atuin_server_sqlite/wrappers.rs +++ /dev/null @@ -1,72 +0,0 @@ -use ::sqlx::{FromRow, Result}; -use crate::atuin_common::record::{EncryptedData, Host, Record}; -use crate::atuin_server_database::models::{History, Session, User}; -use sqlx::{Row, sqlite::SqliteRow}; - -pub(crate) struct DbUser(pub(crate) User); -pub(crate) struct DbSession(pub(crate) Session); -pub(crate) struct DbHistory(pub(crate) History); -pub(crate) struct DbRecord(pub(crate) Record<EncryptedData>); - -impl<'a> FromRow<'a, SqliteRow> for DbUser { - fn from_row(row: &'a SqliteRow) -> Result<Self> { - Ok(Self(User { - id: row.try_get("id")?, - username: row.try_get("username")?, - email: row.try_get("email")?, - password: row.try_get("password")?, - })) - } -} - -impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbSession { - fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { - Ok(Self(Session { - id: row.try_get("id")?, - user_id: row.try_get("user_id")?, - token: row.try_get("token")?, - })) - } -} - -impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbHistory { - fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { - Ok(Self(History { - id: row.try_get("id")?, - client_id: row.try_get("client_id")?, - user_id: row.try_get("user_id")?, - hostname: row.try_get("hostname")?, - timestamp: row.try_get("timestamp")?, - data: row.try_get("data")?, - created_at: row.try_get("created_at")?, - })) - } -} - -impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbRecord { - fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { - let idx: i64 = row.try_get("idx")?; - let timestamp: i64 = row.try_get("timestamp")?; - - let data = EncryptedData { - data: row.try_get("data")?, - content_encryption_key: row.try_get("cek")?, - }; - - Ok(Self(Record { - id: row.try_get("client_id")?, - host: Host::new(row.try_get("host")?), - idx: idx as u64, - timestamp: timestamp as u64, - version: row.try_get("version")?, - tag: row.try_get("tag")?, - data, - })) - } -} - -impl From<DbRecord> for Record<EncryptedData> { - fn from(other: DbRecord) -> Record<EncryptedData> { - Record { ..other.0 } - } -} |
