diff options
Diffstat (limited to 'crates/atuin-server-sqlite')
10 files changed, 720 insertions, 0 deletions
diff --git a/crates/atuin-server-sqlite/Cargo.toml b/crates/atuin-server-sqlite/Cargo.toml new file mode 100644 index 00000000..c04604e7 --- /dev/null +++ b/crates/atuin-server-sqlite/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "atuin-server-sqlite" +edition = "2024" +description = "server sqlite database library for atuin" + +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +atuin-common = { path = "../atuin-common", version = "18.6.1" } +atuin-server-database = { path = "../atuin-server-database", version = "18.6.1" } + +eyre = { workspace = true } +tracing = { workspace = true } +time = { workspace = true } +serde = { workspace = true } +sqlx = { workspace = true } +async-trait = { workspace = true } +uuid = { workspace = true } +metrics = "0.21.1" +futures-util = "0.3" diff --git a/crates/atuin-server-sqlite/build.rs b/crates/atuin-server-sqlite/build.rs new file mode 100644 index 00000000..d5068697 --- /dev/null +++ b/crates/atuin-server-sqlite/build.rs @@ -0,0 +1,5 @@ +// generated by `sqlx migrate build-script` +fn main() { + // trigger recompilation when a new migration is added + println!("cargo:rerun-if-changed=migrations"); +} diff --git a/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql b/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql new file mode 100644 index 00000000..ca19ed62 --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql @@ -0,0 +1,17 @@ +create table store ( + id text primary key, -- remember to use uuidv7 for happy indices <3 + client_id text not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically + host text not null, -- a unique identifier for the host + idx bigint not null, -- the index of the record in this store, identified by (host, tag) + timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision + version text not null, + tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host + data text not null, -- store the actual history data, encrypted. I don't wanna know! + cek text not null, + + user_id bigint not null, -- allow multiple users + created_at timestamp not null default current_timestamp +); + +create unique index record_uniq ON store(user_id, host, tag, idx); + diff --git a/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql b/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql new file mode 100644 index 00000000..7bd653ba --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql @@ -0,0 +1,15 @@ +create table history ( + id integer primary key autoincrement, + client_id text not null unique, -- the client-generated ID + user_id bigserial not null, -- allow multiple users + hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever) + timestamp timestamp not null, -- one of the few non-encrypted metadatas + + data text not null, -- store the actual history data, encrypted. I don't wanna know! + + created_at timestamp not null default current_timestamp, + deleted_at timestamp +); + +create unique index history_deleted_index on history(client_id, user_id, deleted_at); + diff --git a/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql b/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql new file mode 100644 index 00000000..3120c35d --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql @@ -0,0 +1,6 @@ +create table sessions ( + id integer primary key autoincrement, + user_id integer, + token text unique not null +); + diff --git a/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql b/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql new file mode 100644 index 00000000..852c159d --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql @@ -0,0 +1,12 @@ +create table users ( + id integer primary key autoincrement, -- also store our own ID + username text not null unique, -- being able to contact users is useful + email text not null unique, -- being able to contact users is useful + password text not null unique, + created_at timestamp not null default (datetime('now','localtime')), + verified_at timestamp with time zone default null +); + +-- the prior index is case sensitive :( +CREATE UNIQUE INDEX email_unique_idx on users (LOWER(email)); +CREATE UNIQUE INDEX username_unique_idx on users (LOWER(username)); diff --git a/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql b/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql new file mode 100644 index 00000000..36eb14de --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql @@ -0,0 +1,6 @@ +create table user_verification_token( + id integer primary key autoincrement, + user_id bigint unique references users(id), + token text, + valid_until timestamp with time zone +); diff --git a/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql b/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql new file mode 100644 index 00000000..cd54cb18 --- /dev/null +++ b/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql @@ -0,0 +1,10 @@ +create table store_idx_cache( + id integer primary key autoincrement, + user_id bigint, + + host uuid, + tag text, + idx bigint +); + +create unique index store_idx_cache_uniq on store_idx_cache(user_id, host, tag); diff --git a/crates/atuin-server-sqlite/src/lib.rs b/crates/atuin-server-sqlite/src/lib.rs new file mode 100644 index 00000000..9cc1e8a7 --- /dev/null +++ b/crates/atuin-server-sqlite/src/lib.rs @@ -0,0 +1,552 @@ +use std::str::FromStr; + +use async_trait::async_trait; +use atuin_common::{ + record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}, + utils::crypto_random_string, +}; +use atuin_server_database::{ + Database, DbError, DbResult, DbSettings, + models::{History, NewHistory, NewSession, NewUser, Session, User}, +}; +use futures_util::TryStreamExt; +use sqlx::{ + Row, + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions}, + types::Uuid, +}; +use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; +use tracing::instrument; +use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; + +mod wrappers; + +#[derive(Clone)] +pub struct Sqlite { + pool: sqlx::Pool<sqlx::sqlite::Sqlite>, +} + +fn fix_error(error: sqlx::Error) -> DbError { + match error { + sqlx::Error::RowNotFound => DbError::NotFound, + error => DbError::Other(error.into()), + } +} + +#[async_trait] +impl Database for Sqlite { + async fn new(settings: &DbSettings) -> DbResult<Self> { + let opts = SqliteConnectOptions::from_str(&settings.db_uri) + .map_err(fix_error)? + .journal_mode(SqliteJournalMode::Wal) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .connect_with(opts) + .await + .map_err(fix_error)?; + + sqlx::migrate!("./migrations") + .run(&pool) + .await + .map_err(|error| DbError::Other(error.into()))?; + + Ok(Self { pool }) + } + + #[instrument(skip_all)] + async fn get_session(&self, token: &str) -> DbResult<Session> { + sqlx::query_as("select id, user_id, token from sessions where token = $1") + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn get_session_user(&self, token: &str) -> DbResult<User> { + sqlx::query_as( + "select users.id, users.username, users.email, users.password, users.verified_at from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn add_session(&self, session: &NewSession) -> DbResult<()> { + let token: &str = &session.token; + + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(token) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn get_user(&self, username: &str) -> DbResult<User> { + sqlx::query_as( + "select id, username, email, password, verified_at from users where username = $1", + ) + .bind(username) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn get_user_session(&self, u: &User) -> DbResult<Session> { + sqlx::query_as("select id, user_id, token from sessions where user_id = $1") + .bind(u.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn add_user(&self, user: &NewUser) -> DbResult<i64> { + let email: &str = &user.email; + let username: &str = &user.username; + let password: &str = &user.password; + + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(username) + .bind(email) + .bind(password) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn user_verified(&self, id: i64) -> DbResult<bool> { + let res: (bool,) = + sqlx::query_as("select verified_at is not null from users where id = $1") + .bind(id) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn verify_user(&self, id: i64) -> DbResult<()> { + sqlx::query( + "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1", + ) + .bind(id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn user_verification_token(&self, id: i64) -> DbResult<String> { + const TOKEN_VALID_MINUTES: i64 = 15; + + // First we check if there is a verification token + let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as( + "select token, valid_until from user_verification_token where user_id = $1", + ) + .bind(id) + .fetch_optional(&self.pool) + .await + .map_err(fix_error)?; + + let token = if let Some((token, valid_until)) = token { + // We have a token, AND it's still valid + if valid_until > time::OffsetDateTime::now_utc() { + token + } else { + // token has expired. generate a new one, return it + let token = crypto_random_string::<24>(); + + sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1") + .bind(id) + .bind(&token) + .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES)) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + token + } + } else { + // No token in the database! Generate one, insert it + let token = crypto_random_string::<24>(); + + sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)") + .bind(id) + .bind(&token) + .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES)) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + token + }; + + Ok(token) + } + + #[instrument(skip_all)] + async fn update_user_password(&self, user: &User) -> DbResult<()> { + sqlx::query( + "update users + set password = $1 + where id = $2", + ) + .bind(&user.password) + .bind(user.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn total_history(&self) -> DbResult<i64> { + let res: (i64,) = sqlx::query_as("select count(1) from history") + .fetch_optional(&self.pool) + .await + .map_err(fix_error)? + .unwrap_or((0,)); + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history(&self, user: &User) -> DbResult<i64> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history_cached(&self, _user: &User) -> DbResult<i64> { + Err(DbError::NotFound) + } + + #[instrument(skip_all)] + async fn delete_user(&self, u: &User) -> DbResult<()> { + sqlx::query("delete from sessions where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from users where id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from history where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { + sqlx::query( + "update history + set deleted_at = $3 + where user_id = $1 + and client_id = $2 + and deleted_at is null", // don't just keep setting it + ) + .bind(user.id) + .bind(id) + .bind(time::OffsetDateTime::now_utc()) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res = sqlx::query( + "select client_id from history + where user_id = $1 + and deleted_at is not null", + ) + .bind(user.id) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + let res = res.iter().map(|row| row.get("client_id")).collect(); + + Ok(res) + } + + async fn delete_store(&self, user: &User) -> DbResult<()> { + sqlx::query( + "delete from store + where user_id = $1", + ) + .bind(user.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in records { + let id = atuin_common::utils::uuid_v7(); + + sqlx::query( + "insert into store + (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + on conflict do nothing + ", + ) + .bind(id) + .bind(i.id) + .bind(i.host.id) + .bind(i.idx as i64) + .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time + .bind(&i.version) + .bind(&i.tag) + .bind(&i.data.data) + .bind(&i.data.content_encryption_key) + .bind(user.id) + .execute(&mut *tx) + .await + .map_err(fix_error)?; + } + + tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option<RecordIdx>, + count: u64, + ) -> DbResult<Vec<Record<EncryptedData>>> { + tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); + let start = start.unwrap_or(0); + + let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as( + "select client_id, host, idx, timestamp, version, tag, data, cek from store + where user_id = $1 + and tag = $2 + and host = $3 + and idx >= $4 + order by idx asc + limit $5", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(start as i64) + .bind(count as i64) + .fetch_all(&self.pool) + .await + .map_err(fix_error); + + let ret = match records { + Ok(records) => { + let records: Vec<Record<EncryptedData>> = records + .into_iter() + .map(|f| { + let record: Record<EncryptedData> = f.into(); + record + }) + .collect(); + + records + } + Err(DbError::NotFound) => { + tracing::debug!("no records found in store: {:?}/{}", host, tag); + return Ok(vec![]); + } + Err(e) => return Err(e), + }; + + Ok(ret) + } + + async fn status(&self, user: &User) -> DbResult<RecordStatus> { + const STATUS_SQL: &str = + "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; + + let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) + .bind(user.id) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + let mut status = RecordStatus::new(); + + for i in res { + status.set_raw(HostId(i.0), i.1, i.2 as u64); + } + + Ok(status) + } + + #[instrument(skip_all)] + async fn count_history_range( + &self, + user: &User, + range: std::ops::Range<time::OffsetDateTime>, + ) -> DbResult<i64> { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1 + and timestamp >= $2::date + and timestamp < $3::date", + ) + .bind(user.id) + .bind(into_utc(range.start)) + .bind(into_utc(range.end)) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn list_history( + &self, + user: &User, + created_after: time::OffsetDateTime, + since: time::OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult<Vec<History>> { + let res = sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(into_utc(created_after)) + .bind(into_utc(since)) + .bind(page_size) + .fetch(&self.pool) + .map_ok(|DbHistory(h)| h) + .try_collect() + .await + .map_err(fix_error)?; + + Ok(res) + } + + #[instrument(skip_all)] + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in history { + let client_id: &str = &i.client_id; + let hostname: &str = &i.hostname; + let data: &str = &i.data; + + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(client_id) + .bind(i.user_id) + .bind(hostname) + .bind(i.timestamp) + .bind(data) + .execute(&mut *tx) + .await + .map_err(fix_error)?; + } + + tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn oldest_history(&self, user: &User) -> DbResult<History> { + sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + order by timestamp asc + limit 1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbHistory(h)| h) + } +} + +fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime { + let x = x.to_offset(UtcOffset::UTC); + PrimitiveDateTime::new(x.date(), x.time()) +} diff --git a/crates/atuin-server-sqlite/src/wrappers.rs b/crates/atuin-server-sqlite/src/wrappers.rs new file mode 100644 index 00000000..3f2262c3 --- /dev/null +++ b/crates/atuin-server-sqlite/src/wrappers.rs @@ -0,0 +1,73 @@ +use ::sqlx::{FromRow, Result}; +use atuin_common::record::{EncryptedData, Host, Record}; +use atuin_server_database::models::{History, Session, User}; +use sqlx::{Row, sqlite::SqliteRow}; + +pub struct DbUser(pub User); +pub struct DbSession(pub Session); +pub struct DbHistory(pub History); +pub struct DbRecord(pub Record<EncryptedData>); + +impl<'a> FromRow<'a, SqliteRow> for DbUser { + fn from_row(row: &'a SqliteRow) -> Result<Self> { + Ok(Self(User { + id: row.try_get("id")?, + username: row.try_get("username")?, + email: row.try_get("email")?, + password: row.try_get("password")?, + verified: row.try_get("verified_at")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbSession { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { + Ok(Self(Session { + id: row.try_get("id")?, + user_id: row.try_get("user_id")?, + token: row.try_get("token")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbHistory { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { + Ok(Self(History { + id: row.try_get("id")?, + client_id: row.try_get("client_id")?, + user_id: row.try_get("user_id")?, + hostname: row.try_get("hostname")?, + timestamp: row.try_get("timestamp")?, + data: row.try_get("data")?, + created_at: row.try_get("created_at")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbRecord { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { + let idx: i64 = row.try_get("idx")?; + let timestamp: i64 = row.try_get("timestamp")?; + + let data = EncryptedData { + data: row.try_get("data")?, + content_encryption_key: row.try_get("cek")?, + }; + + Ok(Self(Record { + id: row.try_get("client_id")?, + host: Host::new(row.try_get("host")?), + idx: idx as u64, + timestamp: timestamp as u64, + version: row.try_get("version")?, + tag: row.try_get("tag")?, + data, + })) + } +} + +impl From<DbRecord> for Record<EncryptedData> { + fn from(other: DbRecord) -> Record<EncryptedData> { + Record { ..other.0 } + } +} |
