diff options
Diffstat (limited to 'crates')
21 files changed, 802 insertions, 48 deletions
diff --git a/crates/atuin-server-database/Cargo.toml b/crates/atuin-server-database/Cargo.toml index e3e38e3f..823b5d39 100644 --- a/crates/atuin-server-database/Cargo.toml +++ b/crates/atuin-server-database/Cargo.toml @@ -17,3 +17,4 @@ time = { workspace = true } eyre = { workspace = true } serde = { workspace = true } async-trait = { workspace = true } +url = "2.5.2" diff --git a/crates/atuin-server-database/src/lib.rs b/crates/atuin-server-database/src/lib.rs index 1c577f59..9df36d14 100644 --- a/crates/atuin-server-database/src/lib.rs +++ b/crates/atuin-server-database/src/lib.rs @@ -15,7 +15,7 @@ use self::{ }; use async_trait::async_trait; use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; -use serde::{Serialize, de::DeserializeOwned}; +use serde::{Deserialize, Serialize}; use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset}; use tracing::instrument; @@ -41,10 +41,54 @@ impl std::error::Error for DbError {} pub type DbResult<T> = Result<T, DbError>; +#[derive(Debug, PartialEq)] +pub enum DbType { + Postgres, + Sqlite, + Unknown, +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct DbSettings { + pub db_uri: String, +} + +impl DbSettings { + pub fn db_type(&self) -> DbType { + if self.db_uri.starts_with("postgres://") { + DbType::Postgres + } else if self.db_uri.starts_with("sqlite://") { + DbType::Sqlite + } else { + DbType::Unknown + } + } +} + +// Do our best to redact passwords so they're not logged in the event of an error. +impl Debug for DbSettings { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.db_type() == DbType::Postgres { + let redacted_uri = url::Url::parse(&self.db_uri) + .map(|mut url| { + let _ = url.set_password(Some("****")); + url.to_string() + }) + .unwrap_or(self.db_uri.clone()); + f.debug_struct("DbSettings") + .field("db_uri", &redacted_uri) + .finish() + } else { + f.debug_struct("DbSettings") + .field("db_uri", &self.db_uri) + .finish() + } + } +} + #[async_trait] pub trait Database: Sized + Clone + Send + Sync + 'static { - type Settings: Debug + Clone + DeserializeOwned + Serialize + Send + Sync + 'static; - async fn new(settings: &Self::Settings) -> DbResult<Self>; + async fn new(settings: &DbSettings) -> DbResult<Self>; async fn get_session(&self, token: &str) -> DbResult<Session>; async fn get_session_user(&self, token: &str) -> DbResult<User>; diff --git a/crates/atuin-server-postgres/Cargo.toml b/crates/atuin-server-postgres/Cargo.toml index 9eccca50..f5e472cb 100644 --- a/crates/atuin-server-postgres/Cargo.toml +++ b/crates/atuin-server-postgres/Cargo.toml @@ -22,4 +22,3 @@ async-trait = { workspace = true } uuid = { workspace = true } metrics = "0.21.1" futures-util = "0.3" -url = "2.5.2" diff --git a/crates/atuin-server-postgres/src/lib.rs b/crates/atuin-server-postgres/src/lib.rs index 7c6d8f9a..005e8765 100644 --- a/crates/atuin-server-postgres/src/lib.rs +++ b/crates/atuin-server-postgres/src/lib.rs @@ -1,15 +1,13 @@ use std::collections::HashMap; -use std::fmt::Debug; use std::ops::Range; use async_trait::async_trait; use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; use atuin_common::utils::crypto_random_string; use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; -use atuin_server_database::{Database, DbError, DbResult}; +use atuin_server_database::{Database, DbError, DbResult, DbSettings}; use futures_util::TryStreamExt; use metrics::counter; -use serde::{Deserialize, Serialize}; use sqlx::Row; use sqlx::postgres::PgPoolOptions; @@ -27,26 +25,6 @@ pub struct Postgres { pool: sqlx::Pool<sqlx::postgres::Postgres>, } -#[derive(Clone, Deserialize, Serialize)] -pub struct PostgresSettings { - pub db_uri: String, -} - -// Do our best to redact passwords so they're not logged in the event of an error. -impl Debug for PostgresSettings { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let redacted_uri = url::Url::parse(&self.db_uri) - .map(|mut url| { - let _ = url.set_password(Some("****")); - url.to_string() - }) - .unwrap_or(self.db_uri.clone()); - f.debug_struct("PostgresSettings") - .field("db_uri", &redacted_uri) - .finish() - } -} - fn fix_error(error: sqlx::Error) -> DbError { match error { sqlx::Error::RowNotFound => DbError::NotFound, @@ -56,8 +34,7 @@ fn fix_error(error: sqlx::Error) -> DbError { #[async_trait] impl Database for Postgres { - type Settings = PostgresSettings; - async fn new(settings: &PostgresSettings) -> DbResult<Self> { + async fn new(settings: &DbSettings) -> DbResult<Self> { let pool = PgPoolOptions::new() .max_connections(100) .connect(settings.db_uri.as_str()) diff --git a/crates/atuin-server-sqlite/Cargo.toml b/crates/atuin-server-sqlite/Cargo.toml new file mode 100644 index 00000000..c04604e7 --- /dev/null +++ b/crates/atuin-server-sqlite/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "atuin-server-sqlite" +edition = "2024" +description = "server sqlite database library for atuin" + +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +atuin-common = { path = "../atuin-common", version = "18.6.1" } +atuin-server-database = { path = "../atuin-server-database", version = "18.6.1" } + +eyre = { workspace = true } +tracing = { workspace = true } +time = { workspace = true } +serde = { workspace = true } +sqlx = { workspace = true } +async-trait = { workspace = true } +uuid = { workspace = true } +metrics = "0.21.1" +futures-util = "0.3" diff --git a/crates/atuin-server-sqlite/build.rs b/crates/atuin-server-sqlite/build.rs new file mode 100644 index 00000000..d5068697 --- /dev/null +++ b/crates/atuin-server-sqlite/build.rs @@ -0,0 +1,5 @@ +// generated by `sqlx migrate build-script` +fn main() { + // trigger recompilation when a new migration is added + println!("cargo:rerun-if-changed=migrations"); +} diff --git a/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql b/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql new file mode 100644 index 00000000..ca19ed62 --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql @@ -0,0 +1,17 @@ +create table store ( + id text primary key, -- remember to use uuidv7 for happy indices <3 + client_id text not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically + host text not null, -- a unique identifier for the host + idx bigint not null, -- the index of the record in this store, identified by (host, tag) + timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision + version text not null, + tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host + data text not null, -- store the actual history data, encrypted. I don't wanna know! + cek text not null, + + user_id bigint not null, -- allow multiple users + created_at timestamp not null default current_timestamp +); + +create unique index record_uniq ON store(user_id, host, tag, idx); + diff --git a/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql b/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql new file mode 100644 index 00000000..7bd653ba --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql @@ -0,0 +1,15 @@ +create table history ( + id integer primary key autoincrement, + client_id text not null unique, -- the client-generated ID + user_id bigserial not null, -- allow multiple users + hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever) + timestamp timestamp not null, -- one of the few non-encrypted metadatas + + data text not null, -- store the actual history data, encrypted. I don't wanna know! + + created_at timestamp not null default current_timestamp, + deleted_at timestamp +); + +create unique index history_deleted_index on history(client_id, user_id, deleted_at); + diff --git a/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql b/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql new file mode 100644 index 00000000..3120c35d --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql @@ -0,0 +1,6 @@ +create table sessions ( + id integer primary key autoincrement, + user_id integer, + token text unique not null +); + diff --git a/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql b/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql new file mode 100644 index 00000000..852c159d --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql @@ -0,0 +1,12 @@ +create table users ( + id integer primary key autoincrement, -- also store our own ID + username text not null unique, -- being able to contact users is useful + email text not null unique, -- being able to contact users is useful + password text not null unique, + created_at timestamp not null default (datetime('now','localtime')), + verified_at timestamp with time zone default null +); + +-- the prior index is case sensitive :( +CREATE UNIQUE INDEX email_unique_idx on users (LOWER(email)); +CREATE UNIQUE INDEX username_unique_idx on users (LOWER(username)); diff --git a/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql b/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql new file mode 100644 index 00000000..36eb14de --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql @@ -0,0 +1,6 @@ +create table user_verification_token( + id integer primary key autoincrement, + user_id bigint unique references users(id), + token text, + valid_until timestamp with time zone +); diff --git a/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql b/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql new file mode 100644 index 00000000..cd54cb18 --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql @@ -0,0 +1,10 @@ +create table store_idx_cache( + id integer primary key autoincrement, + user_id bigint, + + host uuid, + tag text, + idx bigint +); + +create unique index store_idx_cache_uniq on store_idx_cache(user_id, host, tag); diff --git a/crates/atuin-server-sqlite/src/lib.rs b/crates/atuin-server-sqlite/src/lib.rs new file mode 100644 index 00000000..9cc1e8a7 --- /dev/null +++ b/crates/atuin-server-sqlite/src/lib.rs @@ -0,0 +1,552 @@ +use std::str::FromStr; + +use async_trait::async_trait; +use atuin_common::{ + record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}, + utils::crypto_random_string, +}; +use atuin_server_database::{ + Database, DbError, DbResult, DbSettings, + models::{History, NewHistory, NewSession, NewUser, Session, User}, +}; +use futures_util::TryStreamExt; +use sqlx::{ + Row, + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions}, + types::Uuid, +}; +use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; +use tracing::instrument; +use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; + +mod wrappers; + +#[derive(Clone)] +pub struct Sqlite { + pool: sqlx::Pool<sqlx::sqlite::Sqlite>, +} + +fn fix_error(error: sqlx::Error) -> DbError { + match error { + sqlx::Error::RowNotFound => DbError::NotFound, + error => DbError::Other(error.into()), + } +} + +#[async_trait] +impl Database for Sqlite { + async fn new(settings: &DbSettings) -> DbResult<Self> { + let opts = SqliteConnectOptions::from_str(&settings.db_uri) + .map_err(fix_error)? + .journal_mode(SqliteJournalMode::Wal) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .connect_with(opts) + .await + .map_err(fix_error)?; + + sqlx::migrate!("./migrations") + .run(&pool) + .await + .map_err(|error| DbError::Other(error.into()))?; + + Ok(Self { pool }) + } + + #[instrument(skip_all)] + async fn get_session(&self, token: &str) -> DbResult<Session> { + sqlx::query_as("select id, user_id, token from sessions where token = $1") + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn get_session_user(&self, token: &str) -> DbResult<User> { + sqlx::query_as( + "select users.id, users.username, users.email, users.password, users.verified_at from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn add_session(&self, session: &NewSession) -> DbResult<()> { + let token: &str = &session.token; + + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(token) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn get_user(&self, username: &str) -> DbResult<User> { + sqlx::query_as( + "select id, username, email, password, verified_at from users where username = $1", + ) + .bind(username) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn get_user_session(&self, u: &User) -> DbResult<Session> { + sqlx::query_as("select id, user_id, token from sessions where user_id = $1") + .bind(u.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn add_user(&self, user: &NewUser) -> DbResult<i64> { + let email: &str = &user.email; + let username: &str = &user.username; + let password: &str = &user.password; + + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(username) + .bind(email) + .bind(password) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn user_verified(&self, id: i64) -> DbResult<bool> { + let res: (bool,) = + sqlx::query_as("select verified_at is not null from users where id = $1") + .bind(id) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn verify_user(&self, id: i64) -> DbResult<()> { + sqlx::query( + "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1", + ) + .bind(id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn user_verification_token(&self, id: i64) -> DbResult<String> { + const TOKEN_VALID_MINUTES: i64 = 15; + + // First we check if there is a verification token + let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as( + "select token, valid_until from user_verification_token where user_id = $1", + ) + .bind(id) + .fetch_optional(&self.pool) + .await + .map_err(fix_error)?; + + let token = if let Some((token, valid_until)) = token { + // We have a token, AND it's still valid + if valid_until > time::OffsetDateTime::now_utc() { + token + } else { + // token has expired. generate a new one, return it + let token = crypto_random_string::<24>(); + + sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1") + .bind(id) + .bind(&token) + .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES)) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + token + } + } else { + // No token in the database! Generate one, insert it + let token = crypto_random_string::<24>(); + + sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)") + .bind(id) + .bind(&token) + .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES)) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + token + }; + + Ok(token) + } + + #[instrument(skip_all)] + async fn update_user_password(&self, user: &User) -> DbResult<()> { + sqlx::query( + "update users + set password = $1 + where id = $2", + ) + .bind(&user.password) + .bind(user.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn total_history(&self) -> DbResult<i64> { + let res: (i64,) = sqlx::query_as("select count(1) from history") + .fetch_optional(&self.pool) + .await + .map_err(fix_error)? + .unwrap_or((0,)); + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history(&self, user: &User) -> DbResult<i64> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history_cached(&self, _user: &User) -> DbResult<i64> { + Err(DbError::NotFound) + } + + #[instrument(skip_all)] + async fn delete_user(&self, u: &User) -> DbResult<()> { + sqlx::query("delete from sessions where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from users where id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from history where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { + sqlx::query( + "update history + set deleted_at = $3 + where user_id = $1 + and client_id = $2 + and deleted_at is null", // don't just keep setting it + ) + .bind(user.id) + .bind(id) + .bind(time::OffsetDateTime::now_utc()) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res = sqlx::query( + "select client_id from history + where user_id = $1 + and deleted_at is not null", + ) + .bind(user.id) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + let res = res.iter().map(|row| row.get("client_id")).collect(); + + Ok(res) + } + + async fn delete_store(&self, user: &User) -> DbResult<()> { + sqlx::query( + "delete from store + where user_id = $1", + ) + .bind(user.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in records { + let id = atuin_common::utils::uuid_v7(); + + sqlx::query( + "insert into store + (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + on conflict do nothing + ", + ) + .bind(id) + .bind(i.id) + .bind(i.host.id) + .bind(i.idx as i64) + .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time + .bind(&i.version) + .bind(&i.tag) + .bind(&i.data.data) + .bind(&i.data.content_encryption_key) + .bind(user.id) + .execute(&mut *tx) + .await + .map_err(fix_error)?; + } + + tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option<RecordIdx>, + count: u64, + ) -> DbResult<Vec<Record<EncryptedData>>> { + tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); + let start = start.unwrap_or(0); + + let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as( + "select client_id, host, idx, timestamp, version, tag, data, cek from store + where user_id = $1 + and tag = $2 + and host = $3 + and idx >= $4 + order by idx asc + limit $5", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(start as i64) + .bind(count as i64) + .fetch_all(&self.pool) + .await + .map_err(fix_error); + + let ret = match records { + Ok(records) => { + let records: Vec<Record<EncryptedData>> = records + .into_iter() + .map(|f| { + let record: Record<EncryptedData> = f.into(); + record + }) + .collect(); + + records + } + Err(DbError::NotFound) => { + tracing::debug!("no records found in store: {:?}/{}", host, tag); + return Ok(vec![]); + } + Err(e) => return Err(e), + }; + + Ok(ret) + } + + async fn status(&self, user: &User) -> DbResult<RecordStatus> { + const STATUS_SQL: &str = + "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; + + let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) + .bind(user.id) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + let mut status = RecordStatus::new(); + + for i in res { + status.set_raw(HostId(i.0), i.1, i.2 as u64); + } + + Ok(status) + } + + #[instrument(skip_all)] + async fn count_history_range( + &self, + user: &User, + range: std::ops::Range<time::OffsetDateTime>, + ) -> DbResult<i64> { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1 + and timestamp >= $2::date + and timestamp < $3::date", + ) + .bind(user.id) + .bind(into_utc(range.start)) + .bind(into_utc(range.end)) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn list_history( + &self, + user: &User, + created_after: time::OffsetDateTime, + since: time::OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult<Vec<History>> { + let res = sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(into_utc(created_after)) + .bind(into_utc(since)) + .bind(page_size) + .fetch(&self.pool) + .map_ok(|DbHistory(h)| h) + .try_collect() + .await + .map_err(fix_error)?; + + Ok(res) + } + + #[instrument(skip_all)] + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in history { + let client_id: &str = &i.client_id; + let hostname: &str = &i.hostname; + let data: &str = &i.data; + + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(client_id) + .bind(i.user_id) + .bind(hostname) + .bind(i.timestamp) + .bind(data) + .execute(&mut *tx) + .await + .map_err(fix_error)?; + } + + tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn oldest_history(&self, user: &User) -> DbResult<History> { + sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + order by timestamp asc + limit 1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbHistory(h)| h) + } +} + +fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime { + let x = x.to_offset(UtcOffset::UTC); + PrimitiveDateTime::new(x.date(), x.time()) +} diff --git a/crates/atuin-server-sqlite/src/wrappers.rs b/crates/atuin-server-sqlite/src/wrappers.rs new file mode 100644 index 00000000..3f2262c3 --- /dev/null +++ b/crates/atuin-server-sqlite/src/wrappers.rs @@ -0,0 +1,73 @@ +use ::sqlx::{FromRow, Result}; +use atuin_common::record::{EncryptedData, Host, Record}; +use atuin_server_database::models::{History, Session, User}; +use sqlx::{Row, sqlite::SqliteRow}; + +pub struct DbUser(pub User); +pub struct DbSession(pub Session); +pub struct DbHistory(pub History); +pub struct DbRecord(pub Record<EncryptedData>); + +impl<'a> FromRow<'a, SqliteRow> for DbUser { + fn from_row(row: &'a SqliteRow) -> Result<Self> { + Ok(Self(User { + id: row.try_get("id")?, + username: row.try_get("username")?, + email: row.try_get("email")?, + password: row.try_get("password")?, + verified: row.try_get("verified_at")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbSession { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { + Ok(Self(Session { + id: row.try_get("id")?, + user_id: row.try_get("user_id")?, + token: row.try_get("token")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbHistory { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { + Ok(Self(History { + id: row.try_get("id")?, + client_id: row.try_get("client_id")?, + user_id: row.try_get("user_id")?, + hostname: row.try_get("hostname")?, + timestamp: row.try_get("timestamp")?, + data: row.try_get("data")?, + created_at: row.try_get("created_at")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbRecord { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { + let idx: i64 = row.try_get("idx")?; + let timestamp: i64 = row.try_get("timestamp")?; + + let data = EncryptedData { + data: row.try_get("data")?, + content_encryption_key: row.try_get("cek")?, + }; + + Ok(Self(Record { + id: row.try_get("client_id")?, + host: Host::new(row.try_get("host")?), + idx: idx as u64, + timestamp: timestamp as u64, + version: row.try_get("version")?, + tag: row.try_get("tag")?, + data, + })) + } +} + +impl From<DbRecord> for Record<EncryptedData> { + fn from(other: DbRecord) -> Record<EncryptedData> { + Record { ..other.0 } + } +} diff --git a/crates/atuin-server/server.toml b/crates/atuin-server/server.toml index 946769c9..1eff5b72 100644 --- a/crates/atuin-server/server.toml +++ b/crates/atuin-server/server.toml @@ -9,6 +9,7 @@ ## URI for postgres (using development creds here) # db_uri="postgres://username:password@localhost/atuin" +# db_uri="sqlite:///config/atuin-server.db" ## Maximum size for one history entry # max_history_length = 8192 diff --git a/crates/atuin-server/src/lib.rs b/crates/atuin-server/src/lib.rs index 7a0e982b..f1d616f2 100644 --- a/crates/atuin-server/src/lib.rs +++ b/crates/atuin-server/src/lib.rs @@ -45,10 +45,7 @@ async fn shutdown_signal() { eprintln!("Shutting down gracefully..."); } -pub async fn launch<Db: Database>( - settings: Settings<Db::Settings>, - addr: SocketAddr, -) -> Result<()> { +pub async fn launch<Db: Database>(settings: Settings, addr: SocketAddr) -> Result<()> { if settings.tls.enable { launch_with_tls::<Db>(settings, addr, shutdown_signal()).await } else { @@ -64,7 +61,7 @@ pub async fn launch<Db: Database>( } pub async fn launch_with_tcp_listener<Db: Database>( - settings: Settings<Db::Settings>, + settings: Settings, listener: TcpListener, shutdown: impl Future<Output = ()> + Send + 'static, ) -> Result<()> { @@ -78,7 +75,7 @@ pub async fn launch_with_tcp_listener<Db: Database>( } async fn launch_with_tls<Db: Database>( - settings: Settings<Db::Settings>, + settings: Settings, addr: SocketAddr, shutdown: impl Future<Output = ()>, ) -> Result<()> { @@ -135,9 +132,7 @@ pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> { Ok(()) } -async fn make_router<Db: Database>( - settings: Settings<<Db as Database>::Settings>, -) -> Result<Router, eyre::Error> { +async fn make_router<Db: Database>(settings: Settings) -> Result<Router, eyre::Error> { let db = Db::new(&settings.db_settings) .await .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; diff --git a/crates/atuin-server/src/router.rs b/crates/atuin-server/src/router.rs index ae63e1e8..6d168f63 100644 --- a/crates/atuin-server/src/router.rs +++ b/crates/atuin-server/src/router.rs @@ -105,10 +105,10 @@ async fn semver(request: Request, next: Next) -> Response { #[derive(Clone)] pub struct AppState<DB: Database> { pub database: DB, - pub settings: Settings<DB::Settings>, + pub settings: Settings, } -pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> Router { +pub fn router<DB: Database>(database: DB, settings: Settings) -> Router { let routes = Router::new() .route("/", get(handlers::index)) .route("/healthz", get(handlers::health::health_check)) diff --git a/crates/atuin-server/src/settings.rs b/crates/atuin-server/src/settings.rs index d5070dae..7221d4dd 100644 --- a/crates/atuin-server/src/settings.rs +++ b/crates/atuin-server/src/settings.rs @@ -1,9 +1,10 @@ use std::{io::prelude::*, path::PathBuf}; +use atuin_server_database::DbSettings; use config::{Config, Environment, File as ConfigFile, FileFormat}; use eyre::{Result, eyre}; use fs_err::{File, create_dir_all}; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde::{Deserialize, Serialize}; static EXAMPLE_CONFIG: &str = include_str!("../server.toml"); @@ -53,7 +54,7 @@ impl Default for Metrics { } #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Settings<DbSettings> { +pub struct Settings { pub host: String, pub port: u16, pub path: String, @@ -78,7 +79,7 @@ pub struct Settings<DbSettings> { pub db_settings: DbSettings, } -impl<DbSettings: DeserializeOwned> Settings<DbSettings> { +impl Settings { pub fn new() -> Result<Self> { let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { PathBuf::from(p) diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml index 487df85a..a6c8dbec 100644 --- a/crates/atuin/Cargo.toml +++ b/crates/atuin/Cargo.toml @@ -37,12 +37,19 @@ default = ["client", "sync", "server", "clipboard", "check-update", "daemon"] client = ["atuin-client"] sync = ["atuin-client/sync"] daemon = ["atuin-client/daemon", "atuin-daemon"] -server = ["atuin-server", "atuin-server-postgres"] +server = [ + "atuin-server", + "atuin-server-database", + "atuin-server-postgres", + "atuin-server-sqlite", +] clipboard = ["arboard"] check-update = ["atuin-client/check-update"] [dependencies] +atuin-server-database = { path = "../atuin-server-database", version = "18.6.1", optional = true } atuin-server-postgres = { path = "../atuin-server-postgres", version = "18.6.1", optional = true } +atuin-server-sqlite = { path = "../atuin-server-sqlite", version = "18.6.1", optional = true } atuin-server = { path = "../atuin-server", version = "18.6.1", optional = true } atuin-client = { path = "../atuin-client", version = "18.6.1", optional = true, default-features = false } atuin-common = { path = "../atuin-common", version = "18.6.1" } diff --git a/crates/atuin/src/command/server.rs b/crates/atuin/src/command/server.rs index 8611fb56..fc09bd27 100644 --- a/crates/atuin/src/command/server.rs +++ b/crates/atuin/src/command/server.rs @@ -1,10 +1,12 @@ use std::net::SocketAddr; +use atuin_server_database::DbType; use atuin_server_postgres::Postgres; +use atuin_server_sqlite::Sqlite; use tracing_subscriber::{EnvFilter, fmt, prelude::*}; use clap::Parser; -use eyre::{Context, Result}; +use eyre::{Context, Result, eyre}; use atuin_server::{Settings, example_config, launch, launch_metrics_server}; @@ -50,7 +52,13 @@ impl Cmd { )); } - launch::<Postgres>(settings, addr).await + match settings.db_settings.db_type() { + DbType::Postgres => launch::<Postgres>(settings, addr).await, + DbType::Sqlite => launch::<Sqlite>(settings, addr).await, + DbType::Unknown => { + Err(eyre!("db_uri must start with postgres:// or sqlite://")) + } + } } Self::DefaultConfig => { println!("{}", example_config()); diff --git a/crates/atuin/tests/common/mod.rs b/crates/atuin/tests/common/mod.rs index f947d164..d79c13d6 100644 --- a/crates/atuin/tests/common/mod.rs +++ b/crates/atuin/tests/common/mod.rs @@ -3,7 +3,8 @@ use std::{env, time::Duration}; use atuin_client::api_client; use atuin_common::utils::uuid_v7; use atuin_server::{Settings as ServerSettings, launch_with_tcp_listener}; -use atuin_server_postgres::{Postgres, PostgresSettings}; +use atuin_server_database::DbSettings; +use atuin_server_postgres::Postgres; use futures_util::TryFutureExt; use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle}; use tracing::{Dispatch, dispatcher}; @@ -35,7 +36,7 @@ pub async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandl page_size: 1100, register_webhook_url: None, register_webhook_username: String::new(), - db_settings: PostgresSettings { db_uri }, + db_settings: DbSettings { db_uri }, metrics: atuin_server::settings::Metrics::default(), tls: atuin_server::settings::Tls::default(), mail: atuin_server::settings::Mail::default(), |
