From 7f868711f0a7c77c868a2dd956fcc594d3d95ec8 Mon Sep 17 00:00:00 2001 From: Scotte Zinn Date: Mon, 23 Jun 2025 07:31:55 -0400 Subject: feat: Add sqlite server support for self-hosting (#2770) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Move db_uri setting to DbSettings * WIP: sqlite crate framework * WIP: Migrations * WIP: sqlite implementation * Add sqlite3 to Docker image * verified_at needed for user query * chore(deps): bump debian (#2772) Bumps debian from bookworm-20250428-slim to bookworm-20250520-slim. --- updated-dependencies: - dependency-name: debian dependency-version: bookworm-20250520-slim dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * fix(doctor): mention the required ble.sh version (#2774) References: https://forum.atuin.sh/t/1047 * fix: Don't print errors in `zsh_autosuggest` helper (#2780) Previously, this would result in long multi-line errors when typing, making it hard to see the shell prompt: ``` $ Error: could not load client settings Caused by: 0: could not create config file 1: failed to create file `/home/jyn/.config/atuin/config.toml` 2: Required key not available (os error 126) Location: atuin-client/src/settings.rs:675:54 fError: could not load client settings Caused by: 0: could not create config file 1: failed to create file `/home/jyn/.config/atuin/config.toml` 2: Required key not available (os error 126) Location: atuin-client/src/settings.rs:675:54 faError: could not load client settings ``` Silence these in autosuggestions, such that they only show up when explicitly invoking atuin. * fix: `atuin.nu` enchancements (#2778) * PR feedback * Remove sqlite3 package * fix(search): prevent panic on malformed format strings (#2776) (#2777) * fix(search): prevent panic on malformed format strings (#2776) - Wrap format operations in panic catcher for graceful error handling - Improve error messages with context-aware guidance for common issues - Let runtime-format parser handle validation to avoid blocking valid formats Fixes crash when using malformed format strings by catching formatting errors gracefully and providing actionable guidance without restricting legitimate format patterns like {command} or {time}. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * Satisfy cargo fmt * test(search): add regression tests for format string panic (#2776) - Add test for malformed JSON format strings that previously caused panics - Add test to ensure valid format strings continue to work - Prevent future regressions of the format string panic issue 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --------- Co-authored-by: Claude --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Koichi Murase Co-authored-by: jyn Co-authored-by: Tyarel8 <98483313+Tyarel8@users.noreply.github.com> Co-authored-by: Brian Cosgrove Co-authored-by: Claude --- crates/atuin-server-sqlite/src/lib.rs | 552 ++++++++++++++++++++++++++++++++++ 1 file changed, 552 insertions(+) create mode 100644 crates/atuin-server-sqlite/src/lib.rs (limited to 'crates/atuin-server-sqlite/src/lib.rs') diff --git a/crates/atuin-server-sqlite/src/lib.rs b/crates/atuin-server-sqlite/src/lib.rs new file mode 100644 index 00000000..9cc1e8a7 --- /dev/null +++ b/crates/atuin-server-sqlite/src/lib.rs @@ -0,0 +1,552 @@ +use std::str::FromStr; + +use async_trait::async_trait; +use atuin_common::{ + record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}, + utils::crypto_random_string, +}; +use atuin_server_database::{ + Database, DbError, DbResult, DbSettings, + models::{History, NewHistory, NewSession, NewUser, Session, User}, +}; +use futures_util::TryStreamExt; +use sqlx::{ + Row, + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions}, + types::Uuid, +}; +use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; +use tracing::instrument; +use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; + +mod wrappers; + +#[derive(Clone)] +pub struct Sqlite { + pool: sqlx::Pool, +} + +fn fix_error(error: sqlx::Error) -> DbError { + match error { + sqlx::Error::RowNotFound => DbError::NotFound, + error => DbError::Other(error.into()), + } +} + +#[async_trait] +impl Database for Sqlite { + async fn new(settings: &DbSettings) -> DbResult { + let opts = SqliteConnectOptions::from_str(&settings.db_uri) + .map_err(fix_error)? + .journal_mode(SqliteJournalMode::Wal) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .connect_with(opts) + .await + .map_err(fix_error)?; + + sqlx::migrate!("./migrations") + .run(&pool) + .await + .map_err(|error| DbError::Other(error.into()))?; + + Ok(Self { pool }) + } + + #[instrument(skip_all)] + async fn get_session(&self, token: &str) -> DbResult { + sqlx::query_as("select id, user_id, token from sessions where token = $1") + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn get_session_user(&self, token: &str) -> DbResult { + sqlx::query_as( + "select users.id, users.username, users.email, users.password, users.verified_at from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn add_session(&self, session: &NewSession) -> DbResult<()> { + let token: &str = &session.token; + + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(token) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn get_user(&self, username: &str) -> DbResult { + sqlx::query_as( + "select id, username, email, password, verified_at from users where username = $1", + ) + .bind(username) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn get_user_session(&self, u: &User) -> DbResult { + sqlx::query_as("select id, user_id, token from sessions where user_id = $1") + .bind(u.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn add_user(&self, user: &NewUser) -> DbResult { + let email: &str = &user.email; + let username: &str = &user.username; + let password: &str = &user.password; + + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(username) + .bind(email) + .bind(password) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn user_verified(&self, id: i64) -> DbResult { + let res: (bool,) = + sqlx::query_as("select verified_at is not null from users where id = $1") + .bind(id) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn verify_user(&self, id: i64) -> DbResult<()> { + sqlx::query( + "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1", + ) + .bind(id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn user_verification_token(&self, id: i64) -> DbResult { + const TOKEN_VALID_MINUTES: i64 = 15; + + // First we check if there is a verification token + let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as( + "select token, valid_until from user_verification_token where user_id = $1", + ) + .bind(id) + .fetch_optional(&self.pool) + .await + .map_err(fix_error)?; + + let token = if let Some((token, valid_until)) = token { + // We have a token, AND it's still valid + if valid_until > time::OffsetDateTime::now_utc() { + token + } else { + // token has expired. generate a new one, return it + let token = crypto_random_string::<24>(); + + sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1") + .bind(id) + .bind(&token) + .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES)) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + token + } + } else { + // No token in the database! Generate one, insert it + let token = crypto_random_string::<24>(); + + sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)") + .bind(id) + .bind(&token) + .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES)) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + token + }; + + Ok(token) + } + + #[instrument(skip_all)] + async fn update_user_password(&self, user: &User) -> DbResult<()> { + sqlx::query( + "update users + set password = $1 + where id = $2", + ) + .bind(&user.password) + .bind(user.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn total_history(&self) -> DbResult { + let res: (i64,) = sqlx::query_as("select count(1) from history") + .fetch_optional(&self.pool) + .await + .map_err(fix_error)? + .unwrap_or((0,)); + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history(&self, user: &User) -> DbResult { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history_cached(&self, _user: &User) -> DbResult { + Err(DbError::NotFound) + } + + #[instrument(skip_all)] + async fn delete_user(&self, u: &User) -> DbResult<()> { + sqlx::query("delete from sessions where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from users where id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from history where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { + sqlx::query( + "update history + set deleted_at = $3 + where user_id = $1 + and client_id = $2 + and deleted_at is null", // don't just keep setting it + ) + .bind(user.id) + .bind(id) + .bind(time::OffsetDateTime::now_utc()) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn deleted_history(&self, user: &User) -> DbResult> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res = sqlx::query( + "select client_id from history + where user_id = $1 + and deleted_at is not null", + ) + .bind(user.id) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + let res = res.iter().map(|row| row.get("client_id")).collect(); + + Ok(res) + } + + async fn delete_store(&self, user: &User) -> DbResult<()> { + sqlx::query( + "delete from store + where user_id = $1", + ) + .bind(user.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn add_records(&self, user: &User, records: &[Record]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in records { + let id = atuin_common::utils::uuid_v7(); + + sqlx::query( + "insert into store + (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + on conflict do nothing + ", + ) + .bind(id) + .bind(i.id) + .bind(i.host.id) + .bind(i.idx as i64) + .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time + .bind(&i.version) + .bind(&i.tag) + .bind(&i.data.data) + .bind(&i.data.content_encryption_key) + .bind(user.id) + .execute(&mut *tx) + .await + .map_err(fix_error)?; + } + + tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option, + count: u64, + ) -> DbResult>> { + tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); + let start = start.unwrap_or(0); + + let records: Result, DbError> = sqlx::query_as( + "select client_id, host, idx, timestamp, version, tag, data, cek from store + where user_id = $1 + and tag = $2 + and host = $3 + and idx >= $4 + order by idx asc + limit $5", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(start as i64) + .bind(count as i64) + .fetch_all(&self.pool) + .await + .map_err(fix_error); + + let ret = match records { + Ok(records) => { + let records: Vec> = records + .into_iter() + .map(|f| { + let record: Record = f.into(); + record + }) + .collect(); + + records + } + Err(DbError::NotFound) => { + tracing::debug!("no records found in store: {:?}/{}", host, tag); + return Ok(vec![]); + } + Err(e) => return Err(e), + }; + + Ok(ret) + } + + async fn status(&self, user: &User) -> DbResult { + const STATUS_SQL: &str = + "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; + + let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) + .bind(user.id) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + let mut status = RecordStatus::new(); + + for i in res { + status.set_raw(HostId(i.0), i.1, i.2 as u64); + } + + Ok(status) + } + + #[instrument(skip_all)] + async fn count_history_range( + &self, + user: &User, + range: std::ops::Range, + ) -> DbResult { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1 + and timestamp >= $2::date + and timestamp < $3::date", + ) + .bind(user.id) + .bind(into_utc(range.start)) + .bind(into_utc(range.end)) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn list_history( + &self, + user: &User, + created_after: time::OffsetDateTime, + since: time::OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult> { + let res = sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(into_utc(created_after)) + .bind(into_utc(since)) + .bind(page_size) + .fetch(&self.pool) + .map_ok(|DbHistory(h)| h) + .try_collect() + .await + .map_err(fix_error)?; + + Ok(res) + } + + #[instrument(skip_all)] + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in history { + let client_id: &str = &i.client_id; + let hostname: &str = &i.hostname; + let data: &str = &i.data; + + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(client_id) + .bind(i.user_id) + .bind(hostname) + .bind(i.timestamp) + .bind(data) + .execute(&mut *tx) + .await + .map_err(fix_error)?; + } + + tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn oldest_history(&self, user: &User) -> DbResult { + sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + order by timestamp asc + limit 1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbHistory(h)| h) + } +} + +fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime { + let x = x.to_offset(UtcOffset::UTC); + PrimitiveDateTime::new(x.date(), x.time()) +} -- cgit v1.3.1