aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--Cargo.lock21
-rw-r--r--crates/atuin-server-database/Cargo.toml1
-rw-r--r--crates/atuin-server-database/src/lib.rs50
-rw-r--r--crates/atuin-server-postgres/Cargo.toml1
-rw-r--r--crates/atuin-server-postgres/src/lib.rs27
-rw-r--r--crates/atuin-server-sqlite/Cargo.toml24
-rw-r--r--crates/atuin-server-sqlite/build.rs5
-rw-r--r--crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql17
-rw-r--r--crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql15
-rw-r--r--crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql6
-rw-r--r--crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql12
-rw-r--r--crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql6
-rw-r--r--crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql10
-rw-r--r--crates/atuin-server-sqlite/src/lib.rs552
-rw-r--r--crates/atuin-server-sqlite/src/wrappers.rs73
-rw-r--r--crates/atuin-server/server.toml1
-rw-r--r--crates/atuin-server/src/lib.rs13
-rw-r--r--crates/atuin-server/src/router.rs4
-rw-r--r--crates/atuin-server/src/settings.rs7
-rw-r--r--crates/atuin/Cargo.toml9
-rw-r--r--crates/atuin/src/command/server.rs12
-rw-r--r--crates/atuin/tests/common/mod.rs5
23 files changed, 824 insertions, 49 deletions
diff --git a/.gitignore b/.gitignore
index 3b61025f..e4e42b21 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,5 @@ publish.sh
ui/backend/target
ui/backend/gen
+
+sqlite-server.db*
diff --git a/Cargo.lock b/Cargo.lock
index 5688d2dc..df1b0b21 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -233,7 +233,9 @@ dependencies = [
"atuin-kv",
"atuin-scripts",
"atuin-server",
+ "atuin-server-database",
"atuin-server-postgres",
+ "atuin-server-sqlite",
"clap",
"clap_complete",
"clap_complete_nushell",
@@ -473,6 +475,7 @@ dependencies = [
"serde",
"time",
"tracing",
+ "url",
]
[[package]]
@@ -489,7 +492,23 @@ dependencies = [
"sqlx",
"time",
"tracing",
- "url",
+ "uuid",
+]
+
+[[package]]
+name = "atuin-server-sqlite"
+version = "18.6.1"
+dependencies = [
+ "async-trait",
+ "atuin-common",
+ "atuin-server-database",
+ "eyre",
+ "futures-util",
+ "metrics",
+ "serde",
+ "sqlx",
+ "time",
+ "tracing",
"uuid",
]
diff --git a/crates/atuin-server-database/Cargo.toml b/crates/atuin-server-database/Cargo.toml
index e3e38e3f..823b5d39 100644
--- a/crates/atuin-server-database/Cargo.toml
+++ b/crates/atuin-server-database/Cargo.toml
@@ -17,3 +17,4 @@ time = { workspace = true }
eyre = { workspace = true }
serde = { workspace = true }
async-trait = { workspace = true }
+url = "2.5.2"
diff --git a/crates/atuin-server-database/src/lib.rs b/crates/atuin-server-database/src/lib.rs
index 1c577f59..9df36d14 100644
--- a/crates/atuin-server-database/src/lib.rs
+++ b/crates/atuin-server-database/src/lib.rs
@@ -15,7 +15,7 @@ use self::{
};
use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
-use serde::{Serialize, de::DeserializeOwned};
+use serde::{Deserialize, Serialize};
use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
use tracing::instrument;
@@ -41,10 +41,54 @@ impl std::error::Error for DbError {}
pub type DbResult<T> = Result<T, DbError>;
+#[derive(Debug, PartialEq)]
+pub enum DbType {
+ Postgres,
+ Sqlite,
+ Unknown,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub struct DbSettings {
+ pub db_uri: String,
+}
+
+impl DbSettings {
+ pub fn db_type(&self) -> DbType {
+ if self.db_uri.starts_with("postgres://") {
+ DbType::Postgres
+ } else if self.db_uri.starts_with("sqlite://") {
+ DbType::Sqlite
+ } else {
+ DbType::Unknown
+ }
+ }
+}
+
+// Do our best to redact passwords so they're not logged in the event of an error.
+impl Debug for DbSettings {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ if self.db_type() == DbType::Postgres {
+ let redacted_uri = url::Url::parse(&self.db_uri)
+ .map(|mut url| {
+ let _ = url.set_password(Some("****"));
+ url.to_string()
+ })
+ .unwrap_or(self.db_uri.clone());
+ f.debug_struct("DbSettings")
+ .field("db_uri", &redacted_uri)
+ .finish()
+ } else {
+ f.debug_struct("DbSettings")
+ .field("db_uri", &self.db_uri)
+ .finish()
+ }
+ }
+}
+
#[async_trait]
pub trait Database: Sized + Clone + Send + Sync + 'static {
- type Settings: Debug + Clone + DeserializeOwned + Serialize + Send + Sync + 'static;
- async fn new(settings: &Self::Settings) -> DbResult<Self>;
+ async fn new(settings: &DbSettings) -> DbResult<Self>;
async fn get_session(&self, token: &str) -> DbResult<Session>;
async fn get_session_user(&self, token: &str) -> DbResult<User>;
diff --git a/crates/atuin-server-postgres/Cargo.toml b/crates/atuin-server-postgres/Cargo.toml
index 9eccca50..f5e472cb 100644
--- a/crates/atuin-server-postgres/Cargo.toml
+++ b/crates/atuin-server-postgres/Cargo.toml
@@ -22,4 +22,3 @@ async-trait = { workspace = true }
uuid = { workspace = true }
metrics = "0.21.1"
futures-util = "0.3"
-url = "2.5.2"
diff --git a/crates/atuin-server-postgres/src/lib.rs b/crates/atuin-server-postgres/src/lib.rs
index 7c6d8f9a..005e8765 100644
--- a/crates/atuin-server-postgres/src/lib.rs
+++ b/crates/atuin-server-postgres/src/lib.rs
@@ -1,15 +1,13 @@
use std::collections::HashMap;
-use std::fmt::Debug;
use std::ops::Range;
use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
use atuin_common::utils::crypto_random_string;
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
-use atuin_server_database::{Database, DbError, DbResult};
+use atuin_server_database::{Database, DbError, DbResult, DbSettings};
use futures_util::TryStreamExt;
use metrics::counter;
-use serde::{Deserialize, Serialize};
use sqlx::Row;
use sqlx::postgres::PgPoolOptions;
@@ -27,26 +25,6 @@ pub struct Postgres {
pool: sqlx::Pool<sqlx::postgres::Postgres>,
}
-#[derive(Clone, Deserialize, Serialize)]
-pub struct PostgresSettings {
- pub db_uri: String,
-}
-
-// Do our best to redact passwords so they're not logged in the event of an error.
-impl Debug for PostgresSettings {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- let redacted_uri = url::Url::parse(&self.db_uri)
- .map(|mut url| {
- let _ = url.set_password(Some("****"));
- url.to_string()
- })
- .unwrap_or(self.db_uri.clone());
- f.debug_struct("PostgresSettings")
- .field("db_uri", &redacted_uri)
- .finish()
- }
-}
-
fn fix_error(error: sqlx::Error) -> DbError {
match error {
sqlx::Error::RowNotFound => DbError::NotFound,
@@ -56,8 +34,7 @@ fn fix_error(error: sqlx::Error) -> DbError {
#[async_trait]
impl Database for Postgres {
- type Settings = PostgresSettings;
- async fn new(settings: &PostgresSettings) -> DbResult<Self> {
+ async fn new(settings: &DbSettings) -> DbResult<Self> {
let pool = PgPoolOptions::new()
.max_connections(100)
.connect(settings.db_uri.as_str())
diff --git a/crates/atuin-server-sqlite/Cargo.toml b/crates/atuin-server-sqlite/Cargo.toml
new file mode 100644
index 00000000..c04604e7
--- /dev/null
+++ b/crates/atuin-server-sqlite/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "atuin-server-sqlite"
+edition = "2024"
+description = "server sqlite database library for atuin"
+
+version = { workspace = true }
+authors = { workspace = true }
+license = { workspace = true }
+homepage = { workspace = true }
+repository = { workspace = true }
+
+[dependencies]
+atuin-common = { path = "../atuin-common", version = "18.6.1" }
+atuin-server-database = { path = "../atuin-server-database", version = "18.6.1" }
+
+eyre = { workspace = true }
+tracing = { workspace = true }
+time = { workspace = true }
+serde = { workspace = true }
+sqlx = { workspace = true }
+async-trait = { workspace = true }
+uuid = { workspace = true }
+metrics = "0.21.1"
+futures-util = "0.3"
diff --git a/crates/atuin-server-sqlite/build.rs b/crates/atuin-server-sqlite/build.rs
new file mode 100644
index 00000000..d5068697
--- /dev/null
+++ b/crates/atuin-server-sqlite/build.rs
@@ -0,0 +1,5 @@
+// generated by `sqlx migrate build-script`
+fn main() {
+ // trigger recompilation when a new migration is added
+ println!("cargo:rerun-if-changed=migrations");
+}
diff --git a/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql b/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql
new file mode 100644
index 00000000..ca19ed62
--- /dev/null
+++ b/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql
@@ -0,0 +1,17 @@
+create table store (
+ id text primary key, -- remember to use uuidv7 for happy indices <3
+ client_id text not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically
+ host text not null, -- a unique identifier for the host
+ idx bigint not null, -- the index of the record in this store, identified by (host, tag)
+ timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision
+ version text not null,
+ tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host
+ data text not null, -- store the actual history data, encrypted. I don't wanna know!
+ cek text not null,
+
+ user_id bigint not null, -- allow multiple users
+ created_at timestamp not null default current_timestamp
+);
+
+create unique index record_uniq ON store(user_id, host, tag, idx);
+
diff --git a/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql b/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql
new file mode 100644
index 00000000..7bd653ba
--- /dev/null
+++ b/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql
@@ -0,0 +1,15 @@
+create table history (
+ id integer primary key autoincrement,
+ client_id text not null unique, -- the client-generated ID
+ user_id bigserial not null, -- allow multiple users
+ hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever)
+ timestamp timestamp not null, -- one of the few non-encrypted metadatas
+
+ data text not null, -- store the actual history data, encrypted. I don't wanna know!
+
+ created_at timestamp not null default current_timestamp,
+ deleted_at timestamp
+);
+
+create unique index history_deleted_index on history(client_id, user_id, deleted_at);
+
diff --git a/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql b/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql
new file mode 100644
index 00000000..3120c35d
--- /dev/null
+++ b/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql
@@ -0,0 +1,6 @@
+create table sessions (
+ id integer primary key autoincrement,
+ user_id integer,
+ token text unique not null
+);
+
diff --git a/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql b/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql
new file mode 100644
index 00000000..852c159d
--- /dev/null
+++ b/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql
@@ -0,0 +1,12 @@
+create table users (
+ id integer primary key autoincrement, -- also store our own ID
+ username text not null unique, -- being able to contact users is useful
+ email text not null unique, -- being able to contact users is useful
+ password text not null unique,
+ created_at timestamp not null default (datetime('now','localtime')),
+ verified_at timestamp with time zone default null
+);
+
+-- the prior index is case sensitive :(
+CREATE UNIQUE INDEX email_unique_idx on users (LOWER(email));
+CREATE UNIQUE INDEX username_unique_idx on users (LOWER(username));
diff --git a/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql b/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql
new file mode 100644
index 00000000..36eb14de
--- /dev/null
+++ b/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql
@@ -0,0 +1,6 @@
+create table user_verification_token(
+ id integer primary key autoincrement,
+ user_id bigint unique references users(id),
+ token text,
+ valid_until timestamp with time zone
+);
diff --git a/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql b/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql
new file mode 100644
index 00000000..cd54cb18
--- /dev/null
+++ b/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql
@@ -0,0 +1,10 @@
+create table store_idx_cache(
+ id integer primary key autoincrement,
+ user_id bigint,
+
+ host uuid,
+ tag text,
+ idx bigint
+);
+
+create unique index store_idx_cache_uniq on store_idx_cache(user_id, host, tag);
diff --git a/crates/atuin-server-sqlite/src/lib.rs b/crates/atuin-server-sqlite/src/lib.rs
new file mode 100644
index 00000000..9cc1e8a7
--- /dev/null
+++ b/crates/atuin-server-sqlite/src/lib.rs
@@ -0,0 +1,552 @@
+use std::str::FromStr;
+
+use async_trait::async_trait;
+use atuin_common::{
+ record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus},
+ utils::crypto_random_string,
+};
+use atuin_server_database::{
+ Database, DbError, DbResult, DbSettings,
+ models::{History, NewHistory, NewSession, NewUser, Session, User},
+};
+use futures_util::TryStreamExt;
+use sqlx::{
+ Row,
+ sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
+ types::Uuid,
+};
+use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
+use tracing::instrument;
+use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
+
+mod wrappers;
+
+#[derive(Clone)]
+pub struct Sqlite {
+ pool: sqlx::Pool<sqlx::sqlite::Sqlite>,
+}
+
+fn fix_error(error: sqlx::Error) -> DbError {
+ match error {
+ sqlx::Error::RowNotFound => DbError::NotFound,
+ error => DbError::Other(error.into()),
+ }
+}
+
+#[async_trait]
+impl Database for Sqlite {
+ async fn new(settings: &DbSettings) -> DbResult<Self> {
+ let opts = SqliteConnectOptions::from_str(&settings.db_uri)
+ .map_err(fix_error)?
+ .journal_mode(SqliteJournalMode::Wal)
+ .create_if_missing(true);
+
+ let pool = SqlitePoolOptions::new()
+ .connect_with(opts)
+ .await
+ .map_err(fix_error)?;
+
+ sqlx::migrate!("./migrations")
+ .run(&pool)
+ .await
+ .map_err(|error| DbError::Other(error.into()))?;
+
+ Ok(Self { pool })
+ }
+
+ #[instrument(skip_all)]
+ async fn get_session(&self, token: &str) -> DbResult<Session> {
+ sqlx::query_as("select id, user_id, token from sessions where token = $1")
+ .bind(token)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)
+ .map(|DbSession(session)| session)
+ }
+
+ #[instrument(skip_all)]
+ async fn get_session_user(&self, token: &str) -> DbResult<User> {
+ sqlx::query_as(
+ "select users.id, users.username, users.email, users.password, users.verified_at from users
+ inner join sessions
+ on users.id = sessions.user_id
+ and sessions.token = $1",
+ )
+ .bind(token)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)
+ .map(|DbUser(user)| user)
+ }
+
+ #[instrument(skip_all)]
+ async fn add_session(&self, session: &NewSession) -> DbResult<()> {
+ let token: &str = &session.token;
+
+ sqlx::query(
+ "insert into sessions
+ (user_id, token)
+ values($1, $2)",
+ )
+ .bind(session.user_id)
+ .bind(token)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn get_user(&self, username: &str) -> DbResult<User> {
+ sqlx::query_as(
+ "select id, username, email, password, verified_at from users where username = $1",
+ )
+ .bind(username)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)
+ .map(|DbUser(user)| user)
+ }
+
+ #[instrument(skip_all)]
+ async fn get_user_session(&self, u: &User) -> DbResult<Session> {
+ sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
+ .bind(u.id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)
+ .map(|DbSession(session)| session)
+ }
+
+ #[instrument(skip_all)]
+ async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
+ let email: &str = &user.email;
+ let username: &str = &user.username;
+ let password: &str = &user.password;
+
+ let res: (i64,) = sqlx::query_as(
+ "insert into users
+ (username, email, password)
+ values($1, $2, $3)
+ returning id",
+ )
+ .bind(username)
+ .bind(email)
+ .bind(password)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn user_verified(&self, id: i64) -> DbResult<bool> {
+ let res: (bool,) =
+ sqlx::query_as("select verified_at is not null from users where id = $1")
+ .bind(id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn verify_user(&self, id: i64) -> DbResult<()> {
+ sqlx::query(
+ "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
+ )
+ .bind(id)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn user_verification_token(&self, id: i64) -> DbResult<String> {
+ const TOKEN_VALID_MINUTES: i64 = 15;
+
+ // First we check if there is a verification token
+ let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
+ "select token, valid_until from user_verification_token where user_id = $1",
+ )
+ .bind(id)
+ .fetch_optional(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ let token = if let Some((token, valid_until)) = token {
+ // We have a token, AND it's still valid
+ if valid_until > time::OffsetDateTime::now_utc() {
+ token
+ } else {
+ // token has expired. generate a new one, return it
+ let token = crypto_random_string::<24>();
+
+ sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
+ .bind(id)
+ .bind(&token)
+ .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ token
+ }
+ } else {
+ // No token in the database! Generate one, insert it
+ let token = crypto_random_string::<24>();
+
+ sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
+ .bind(id)
+ .bind(&token)
+ .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ token
+ };
+
+ Ok(token)
+ }
+
+ #[instrument(skip_all)]
+ async fn update_user_password(&self, user: &User) -> DbResult<()> {
+ sqlx::query(
+ "update users
+ set password = $1
+ where id = $2",
+ )
+ .bind(&user.password)
+ .bind(user.id)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn total_history(&self) -> DbResult<i64> {
+ let res: (i64,) = sqlx::query_as("select count(1) from history")
+ .fetch_optional(&self.pool)
+ .await
+ .map_err(fix_error)?
+ .unwrap_or((0,));
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn count_history(&self, user: &User) -> DbResult<i64> {
+ // The cache is new, and the user might not yet have a cache value.
+ // They will have one as soon as they post up some new history, but handle that
+ // edge case.
+
+ let res: (i64,) = sqlx::query_as(
+ "select count(1) from history
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn count_history_cached(&self, _user: &User) -> DbResult<i64> {
+ Err(DbError::NotFound)
+ }
+
+ #[instrument(skip_all)]
+ async fn delete_user(&self, u: &User) -> DbResult<()> {
+ sqlx::query("delete from sessions where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ sqlx::query("delete from users where id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ sqlx::query("delete from history where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
+ sqlx::query(
+ "update history
+ set deleted_at = $3
+ where user_id = $1
+ and client_id = $2
+ and deleted_at is null", // don't just keep setting it
+ )
+ .bind(user.id)
+ .bind(id)
+ .bind(time::OffsetDateTime::now_utc())
+ .fetch_all(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
+ // The cache is new, and the user might not yet have a cache value.
+ // They will have one as soon as they post up some new history, but handle that
+ // edge case.
+
+ let res = sqlx::query(
+ "select client_id from history
+ where user_id = $1
+ and deleted_at is not null",
+ )
+ .bind(user.id)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ let res = res.iter().map(|row| row.get("client_id")).collect();
+
+ Ok(res)
+ }
+
+ async fn delete_store(&self, user: &User) -> DbResult<()> {
+ sqlx::query(
+ "delete from store
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .execute(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
+ let mut tx = self.pool.begin().await.map_err(fix_error)?;
+
+ for i in records {
+ let id = atuin_common::utils::uuid_v7();
+
+ sqlx::query(
+ "insert into store
+ (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
+ values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
+ on conflict do nothing
+ ",
+ )
+ .bind(id)
+ .bind(i.id)
+ .bind(i.host.id)
+ .bind(i.idx as i64)
+ .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
+ .bind(&i.version)
+ .bind(&i.tag)
+ .bind(&i.data.data)
+ .bind(&i.data.content_encryption_key)
+ .bind(user.id)
+ .execute(&mut *tx)
+ .await
+ .map_err(fix_error)?;
+ }
+
+ tx.commit().await.map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn next_records(
+ &self,
+ user: &User,
+ host: HostId,
+ tag: String,
+ start: Option<RecordIdx>,
+ count: u64,
+ ) -> DbResult<Vec<Record<EncryptedData>>> {
+ tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
+ let start = start.unwrap_or(0);
+
+ let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
+ "select client_id, host, idx, timestamp, version, tag, data, cek from store
+ where user_id = $1
+ and tag = $2
+ and host = $3
+ and idx >= $4
+ order by idx asc
+ limit $5",
+ )
+ .bind(user.id)
+ .bind(tag.clone())
+ .bind(host)
+ .bind(start as i64)
+ .bind(count as i64)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(fix_error);
+
+ let ret = match records {
+ Ok(records) => {
+ let records: Vec<Record<EncryptedData>> = records
+ .into_iter()
+ .map(|f| {
+ let record: Record<EncryptedData> = f.into();
+ record
+ })
+ .collect();
+
+ records
+ }
+ Err(DbError::NotFound) => {
+ tracing::debug!("no records found in store: {:?}/{}", host, tag);
+ return Ok(vec![]);
+ }
+ Err(e) => return Err(e),
+ };
+
+ Ok(ret)
+ }
+
+ async fn status(&self, user: &User) -> DbResult<RecordStatus> {
+ const STATUS_SQL: &str =
+ "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
+
+ let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
+ .bind(user.id)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ let mut status = RecordStatus::new();
+
+ for i in res {
+ status.set_raw(HostId(i.0), i.1, i.2 as u64);
+ }
+
+ Ok(status)
+ }
+
+ #[instrument(skip_all)]
+ async fn count_history_range(
+ &self,
+ user: &User,
+ range: std::ops::Range<time::OffsetDateTime>,
+ ) -> DbResult<i64> {
+ let res: (i64,) = sqlx::query_as(
+ "select count(1) from history
+ where user_id = $1
+ and timestamp >= $2::date
+ and timestamp < $3::date",
+ )
+ .bind(user.id)
+ .bind(into_utc(range.start))
+ .bind(into_utc(range.end))
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn list_history(
+ &self,
+ user: &User,
+ created_after: time::OffsetDateTime,
+ since: time::OffsetDateTime,
+ host: &str,
+ page_size: i64,
+ ) -> DbResult<Vec<History>> {
+ let res = sqlx::query_as(
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ where user_id = $1
+ and hostname != $2
+ and created_at >= $3
+ and timestamp >= $4
+ order by timestamp asc
+ limit $5",
+ )
+ .bind(user.id)
+ .bind(host)
+ .bind(into_utc(created_after))
+ .bind(into_utc(since))
+ .bind(page_size)
+ .fetch(&self.pool)
+ .map_ok(|DbHistory(h)| h)
+ .try_collect()
+ .await
+ .map_err(fix_error)?;
+
+ Ok(res)
+ }
+
+ #[instrument(skip_all)]
+ async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
+ let mut tx = self.pool.begin().await.map_err(fix_error)?;
+
+ for i in history {
+ let client_id: &str = &i.client_id;
+ let hostname: &str = &i.hostname;
+ let data: &str = &i.data;
+
+ sqlx::query(
+ "insert into history
+ (client_id, user_id, hostname, timestamp, data)
+ values ($1, $2, $3, $4, $5)
+ on conflict do nothing
+ ",
+ )
+ .bind(client_id)
+ .bind(i.user_id)
+ .bind(hostname)
+ .bind(i.timestamp)
+ .bind(data)
+ .execute(&mut *tx)
+ .await
+ .map_err(fix_error)?;
+ }
+
+ tx.commit().await.map_err(fix_error)?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn oldest_history(&self, user: &User) -> DbResult<History> {
+ sqlx::query_as(
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ where user_id = $1
+ order by timestamp asc
+ limit 1",
+ )
+ .bind(user.id)
+ .fetch_one(&self.pool)
+ .await
+ .map_err(fix_error)
+ .map(|DbHistory(h)| h)
+ }
+}
+
+fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
+ let x = x.to_offset(UtcOffset::UTC);
+ PrimitiveDateTime::new(x.date(), x.time())
+}
diff --git a/crates/atuin-server-sqlite/src/wrappers.rs b/crates/atuin-server-sqlite/src/wrappers.rs
new file mode 100644
index 00000000..3f2262c3
--- /dev/null
+++ b/crates/atuin-server-sqlite/src/wrappers.rs
@@ -0,0 +1,73 @@
+use ::sqlx::{FromRow, Result};
+use atuin_common::record::{EncryptedData, Host, Record};
+use atuin_server_database::models::{History, Session, User};
+use sqlx::{Row, sqlite::SqliteRow};
+
+pub struct DbUser(pub User);
+pub struct DbSession(pub Session);
+pub struct DbHistory(pub History);
+pub struct DbRecord(pub Record<EncryptedData>);
+
+impl<'a> FromRow<'a, SqliteRow> for DbUser {
+ fn from_row(row: &'a SqliteRow) -> Result<Self> {
+ Ok(Self(User {
+ id: row.try_get("id")?,
+ username: row.try_get("username")?,
+ email: row.try_get("email")?,
+ password: row.try_get("password")?,
+ verified: row.try_get("verified_at")?,
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbSession {
+ fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> {
+ Ok(Self(Session {
+ id: row.try_get("id")?,
+ user_id: row.try_get("user_id")?,
+ token: row.try_get("token")?,
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbHistory {
+ fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> {
+ Ok(Self(History {
+ id: row.try_get("id")?,
+ client_id: row.try_get("client_id")?,
+ user_id: row.try_get("user_id")?,
+ hostname: row.try_get("hostname")?,
+ timestamp: row.try_get("timestamp")?,
+ data: row.try_get("data")?,
+ created_at: row.try_get("created_at")?,
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbRecord {
+ fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> {
+ let idx: i64 = row.try_get("idx")?;
+ let timestamp: i64 = row.try_get("timestamp")?;
+
+ let data = EncryptedData {
+ data: row.try_get("data")?,
+ content_encryption_key: row.try_get("cek")?,
+ };
+
+ Ok(Self(Record {
+ id: row.try_get("client_id")?,
+ host: Host::new(row.try_get("host")?),
+ idx: idx as u64,
+ timestamp: timestamp as u64,
+ version: row.try_get("version")?,
+ tag: row.try_get("tag")?,
+ data,
+ }))
+ }
+}
+
+impl From<DbRecord> for Record<EncryptedData> {
+ fn from(other: DbRecord) -> Record<EncryptedData> {
+ Record { ..other.0 }
+ }
+}
diff --git a/crates/atuin-server/server.toml b/crates/atuin-server/server.toml
index 946769c9..1eff5b72 100644
--- a/crates/atuin-server/server.toml
+++ b/crates/atuin-server/server.toml
@@ -9,6 +9,7 @@
## URI for postgres (using development creds here)
# db_uri="postgres://username:password@localhost/atuin"
+# db_uri="sqlite:///config/atuin-server.db"
## Maximum size for one history entry
# max_history_length = 8192
diff --git a/crates/atuin-server/src/lib.rs b/crates/atuin-server/src/lib.rs
index 7a0e982b..f1d616f2 100644
--- a/crates/atuin-server/src/lib.rs
+++ b/crates/atuin-server/src/lib.rs
@@ -45,10 +45,7 @@ async fn shutdown_signal() {
eprintln!("Shutting down gracefully...");
}
-pub async fn launch<Db: Database>(
- settings: Settings<Db::Settings>,
- addr: SocketAddr,
-) -> Result<()> {
+pub async fn launch<Db: Database>(settings: Settings, addr: SocketAddr) -> Result<()> {
if settings.tls.enable {
launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
} else {
@@ -64,7 +61,7 @@ pub async fn launch<Db: Database>(
}
pub async fn launch_with_tcp_listener<Db: Database>(
- settings: Settings<Db::Settings>,
+ settings: Settings,
listener: TcpListener,
shutdown: impl Future<Output = ()> + Send + 'static,
) -> Result<()> {
@@ -78,7 +75,7 @@ pub async fn launch_with_tcp_listener<Db: Database>(
}
async fn launch_with_tls<Db: Database>(
- settings: Settings<Db::Settings>,
+ settings: Settings,
addr: SocketAddr,
shutdown: impl Future<Output = ()>,
) -> Result<()> {
@@ -135,9 +132,7 @@ pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
Ok(())
}
-async fn make_router<Db: Database>(
- settings: Settings<<Db as Database>::Settings>,
-) -> Result<Router, eyre::Error> {
+async fn make_router<Db: Database>(settings: Settings) -> Result<Router, eyre::Error> {
let db = Db::new(&settings.db_settings)
.await
.wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
diff --git a/crates/atuin-server/src/router.rs b/crates/atuin-server/src/router.rs
index ae63e1e8..6d168f63 100644
--- a/crates/atuin-server/src/router.rs
+++ b/crates/atuin-server/src/router.rs
@@ -105,10 +105,10 @@ async fn semver(request: Request, next: Next) -> Response {
#[derive(Clone)]
pub struct AppState<DB: Database> {
pub database: DB,
- pub settings: Settings<DB::Settings>,
+ pub settings: Settings,
}
-pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> Router {
+pub fn router<DB: Database>(database: DB, settings: Settings) -> Router {
let routes = Router::new()
.route("/", get(handlers::index))
.route("/healthz", get(handlers::health::health_check))
diff --git a/crates/atuin-server/src/settings.rs b/crates/atuin-server/src/settings.rs
index d5070dae..7221d4dd 100644
--- a/crates/atuin-server/src/settings.rs
+++ b/crates/atuin-server/src/settings.rs
@@ -1,9 +1,10 @@
use std::{io::prelude::*, path::PathBuf};
+use atuin_server_database::DbSettings;
use config::{Config, Environment, File as ConfigFile, FileFormat};
use eyre::{Result, eyre};
use fs_err::{File, create_dir_all};
-use serde::{Deserialize, Serialize, de::DeserializeOwned};
+use serde::{Deserialize, Serialize};
static EXAMPLE_CONFIG: &str = include_str!("../server.toml");
@@ -53,7 +54,7 @@ impl Default for Metrics {
}
#[derive(Clone, Debug, Deserialize, Serialize)]
-pub struct Settings<DbSettings> {
+pub struct Settings {
pub host: String,
pub port: u16,
pub path: String,
@@ -78,7 +79,7 @@ pub struct Settings<DbSettings> {
pub db_settings: DbSettings,
}
-impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
+impl Settings {
pub fn new() -> Result<Self> {
let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") {
PathBuf::from(p)
diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml
index 487df85a..a6c8dbec 100644
--- a/crates/atuin/Cargo.toml
+++ b/crates/atuin/Cargo.toml
@@ -37,12 +37,19 @@ default = ["client", "sync", "server", "clipboard", "check-update", "daemon"]
client = ["atuin-client"]
sync = ["atuin-client/sync"]
daemon = ["atuin-client/daemon", "atuin-daemon"]
-server = ["atuin-server", "atuin-server-postgres"]
+server = [
+ "atuin-server",
+ "atuin-server-database",
+ "atuin-server-postgres",
+ "atuin-server-sqlite",
+]
clipboard = ["arboard"]
check-update = ["atuin-client/check-update"]
[dependencies]
+atuin-server-database = { path = "../atuin-server-database", version = "18.6.1", optional = true }
atuin-server-postgres = { path = "../atuin-server-postgres", version = "18.6.1", optional = true }
+atuin-server-sqlite = { path = "../atuin-server-sqlite", version = "18.6.1", optional = true }
atuin-server = { path = "../atuin-server", version = "18.6.1", optional = true }
atuin-client = { path = "../atuin-client", version = "18.6.1", optional = true, default-features = false }
atuin-common = { path = "../atuin-common", version = "18.6.1" }
diff --git a/crates/atuin/src/command/server.rs b/crates/atuin/src/command/server.rs
index 8611fb56..fc09bd27 100644
--- a/crates/atuin/src/command/server.rs
+++ b/crates/atuin/src/command/server.rs
@@ -1,10 +1,12 @@
use std::net::SocketAddr;
+use atuin_server_database::DbType;
use atuin_server_postgres::Postgres;
+use atuin_server_sqlite::Sqlite;
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
use clap::Parser;
-use eyre::{Context, Result};
+use eyre::{Context, Result, eyre};
use atuin_server::{Settings, example_config, launch, launch_metrics_server};
@@ -50,7 +52,13 @@ impl Cmd {
));
}
- launch::<Postgres>(settings, addr).await
+ match settings.db_settings.db_type() {
+ DbType::Postgres => launch::<Postgres>(settings, addr).await,
+ DbType::Sqlite => launch::<Sqlite>(settings, addr).await,
+ DbType::Unknown => {
+ Err(eyre!("db_uri must start with postgres:// or sqlite://"))
+ }
+ }
}
Self::DefaultConfig => {
println!("{}", example_config());
diff --git a/crates/atuin/tests/common/mod.rs b/crates/atuin/tests/common/mod.rs
index f947d164..d79c13d6 100644
--- a/crates/atuin/tests/common/mod.rs
+++ b/crates/atuin/tests/common/mod.rs
@@ -3,7 +3,8 @@ use std::{env, time::Duration};
use atuin_client::api_client;
use atuin_common::utils::uuid_v7;
use atuin_server::{Settings as ServerSettings, launch_with_tcp_listener};
-use atuin_server_postgres::{Postgres, PostgresSettings};
+use atuin_server_database::DbSettings;
+use atuin_server_postgres::Postgres;
use futures_util::TryFutureExt;
use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle};
use tracing::{Dispatch, dispatcher};
@@ -35,7 +36,7 @@ pub async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandl
page_size: 1100,
register_webhook_url: None,
register_webhook_username: String::new(),
- db_settings: PostgresSettings { db_uri },
+ db_settings: DbSettings { db_uri },
metrics: atuin_server::settings::Metrics::default(),
tls: atuin_server::settings::Tls::default(),
mail: atuin_server::settings::Mail::default(),