aboutsummaryrefslogtreecommitdiffstats
path: root/crates/turtle/src/atuin_server_postgres
diff options
context:
space:
mode:
authorBenedikt Peetz <benedikt.peetz@b-peetz.de>2026-06-11 00:54:30 +0200
committerBenedikt Peetz <benedikt.peetz@b-peetz.de>2026-06-11 00:54:30 +0200
commit5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8 (patch)
treec64baa8d5866c8e339eaf660dd3f94f30a3f7d8a /crates/turtle/src/atuin_server_postgres
parentchore: Somewhat simplify sync code (diff)
downloadatuin-5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8.zip
chore: Move everything into one big crate
That helps remove duplicated code and rustc/cargo will now also show dead code correctly.
Diffstat (limited to 'crates/turtle/src/atuin_server_postgres')
-rw-r--r--crates/turtle/src/atuin_server_postgres/mod.rs583
-rw-r--r--crates/turtle/src/atuin_server_postgres/wrappers.rs77
2 files changed, 660 insertions, 0 deletions
diff --git a/crates/turtle/src/atuin_server_postgres/mod.rs b/crates/turtle/src/atuin_server_postgres/mod.rs
new file mode 100644
index 00000000..f506cf25
--- /dev/null
+++ b/crates/turtle/src/atuin_server_postgres/mod.rs
@@ -0,0 +1,583 @@
+use std::collections::HashMap;
+use std::ops::Range;
+
+use rand::Rng;
+
+use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
+use crate::atuin_server_database::models::{
+ History, NewHistory, NewSession, NewUser, Session, User,
+};
+use crate::atuin_server_database::{Database, DbError, DbResult, DbSettings, into_utc};
+use async_trait::async_trait;
+use futures_util::TryStreamExt;
+use sqlx::Row;
+use sqlx::postgres::PgPoolOptions;
+
+use time::OffsetDateTime;
+use tracing::instrument;
+use uuid::Uuid;
+use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
+
+mod wrappers;
+
+const MIN_PG_VERSION: u32 = 14;
+
+#[derive(Clone)]
+pub struct Postgres {
+ pool: sqlx::Pool<sqlx::postgres::Postgres>,
+ /// Optional read replica pool for read-only queries
+ read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>,
+}
+
+impl Postgres {
+ /// Returns the appropriate pool for read operations.
+ /// Uses read_pool if available, otherwise falls back to the primary pool.
+ fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> {
+ self.read_pool.as_ref().unwrap_or(&self.pool)
+ }
+}
+
+#[async_trait]
+impl Database for Postgres {
+ async fn new(settings: &DbSettings) -> DbResult<Self> {
+ let pool = PgPoolOptions::new()
+ .max_connections(100)
+ .connect(settings.db_uri.as_str())
+ .await?;
+
+ // Call server_version_num to get the DB server's major version number
+ // The call returns None for servers older than 8.x.
+ let pg_major_version: u32 =
+ pool.acquire()
+ .await?
+ .server_version_num()
+ .ok_or(DbError::Other(eyre::Report::msg(
+ "could not get PostgreSQL version",
+ )))?
+ / 10000;
+
+ if pg_major_version < MIN_PG_VERSION {
+ return Err(DbError::Other(eyre::Report::msg(format!(
+ "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}"
+ ))));
+ }
+
+ sqlx::migrate!("./migrations")
+ .run(&pool)
+ .await
+ .map_err(|error| DbError::Other(error.into()))?;
+
+ // Create read replica pool if configured
+ let read_pool = if let Some(read_db_uri) = &settings.read_db_uri {
+ tracing::info!("Connecting to read replica database");
+ let read_pool = PgPoolOptions::new()
+ .max_connections(100)
+ .connect(read_db_uri.as_str())
+ .await?;
+
+ // Verify the read replica is also a supported PostgreSQL version
+ let read_pg_major_version: u32 = read_pool
+ .acquire()
+ .await?
+ .server_version_num()
+ .ok_or(DbError::Other(eyre::Report::msg(
+ "could not get PostgreSQL version from read replica",
+ )))?
+ / 10000;
+
+ if read_pg_major_version < MIN_PG_VERSION {
+ return Err(DbError::Other(eyre::Report::msg(format!(
+ "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}"
+ ))));
+ }
+
+ Some(read_pool)
+ } else {
+ None
+ };
+
+ Ok(Self { pool, read_pool })
+ }
+
+ #[instrument(skip_all)]
+ async fn get_session(&self, token: &str) -> DbResult<Session> {
+ sqlx::query_as("select id, user_id, token from sessions where token = $1")
+ .bind(token)
+ .fetch_one(self.read_pool())
+ .await
+ .map_err(Into::into)
+ .map(|DbSession(session)| session)
+ }
+
+ #[instrument(skip_all)]
+ async fn get_user(&self, username: &str) -> DbResult<User> {
+ sqlx::query_as("select id, username, email, password from users where username = $1")
+ .bind(username)
+ .fetch_one(self.read_pool())
+ .await
+ .map_err(Into::into)
+ .map(|DbUser(user)| user)
+ }
+
+ #[instrument(skip_all)]
+ async fn get_session_user(&self, token: &str) -> DbResult<User> {
+ sqlx::query_as(
+ "select users.id, users.username, users.email, users.password from users
+ inner join sessions
+ on users.id = sessions.user_id
+ and sessions.token = $1",
+ )
+ .bind(token)
+ .fetch_one(self.read_pool())
+ .await
+ .map_err(Into::into)
+ .map(|DbUser(user)| user)
+ }
+
+ #[instrument(skip_all)]
+ async fn count_history(&self, user: &User) -> DbResult<i64> {
+ // The cache is new, and the user might not yet have a cache value.
+ // They will have one as soon as they post up some new history, but handle that
+ // edge case.
+
+ let res: (i64,) = sqlx::query_as(
+ "select count(1) from history
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .fetch_one(self.read_pool())
+ .await?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
+ let res: (i32,) = sqlx::query_as(
+ "select total from total_history_count_user
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .fetch_one(self.read_pool())
+ .await?;
+
+ Ok(res.0 as i64)
+ }
+
+ async fn delete_store(&self, user: &User) -> DbResult<()> {
+ let mut tx = self.pool.begin().await?;
+
+ sqlx::query(
+ "delete from store
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .execute(&mut *tx)
+ .await?;
+
+ sqlx::query(
+ "delete from store_idx_cache
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .execute(&mut *tx)
+ .await?;
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
+ sqlx::query(
+ "update history
+ set deleted_at = $3
+ where user_id = $1
+ and client_id = $2
+ and deleted_at is null", // don't just keep setting it
+ )
+ .bind(user.id)
+ .bind(id)
+ .bind(OffsetDateTime::now_utc())
+ .fetch_all(&self.pool)
+ .await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
+ // The cache is new, and the user might not yet have a cache value.
+ // They will have one as soon as they post up some new history, but handle that
+ // edge case.
+
+ let res = sqlx::query(
+ "select client_id from history
+ where user_id = $1
+ and deleted_at is not null",
+ )
+ .bind(user.id)
+ .fetch_all(self.read_pool())
+ .await?;
+
+ let res = res
+ .iter()
+ .map(|row| row.get::<String, _>("client_id"))
+ .collect();
+
+ Ok(res)
+ }
+
+ #[instrument(skip_all)]
+ async fn count_history_range(
+ &self,
+ user: &User,
+ range: Range<OffsetDateTime>,
+ ) -> DbResult<i64> {
+ let res: (i64,) = sqlx::query_as(
+ "select count(1) from history
+ where user_id = $1
+ and timestamp >= $2::date
+ and timestamp < $3::date",
+ )
+ .bind(user.id)
+ .bind(into_utc(range.start))
+ .bind(into_utc(range.end))
+ .fetch_one(self.read_pool())
+ .await?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn list_history(
+ &self,
+ user: &User,
+ created_after: OffsetDateTime,
+ since: OffsetDateTime,
+ host: &str,
+ page_size: i64,
+ ) -> DbResult<Vec<History>> {
+ let res = sqlx::query_as(
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ where user_id = $1
+ and hostname != $2
+ and created_at >= $3
+ and timestamp >= $4
+ order by timestamp asc
+ limit $5",
+ )
+ .bind(user.id)
+ .bind(host)
+ .bind(into_utc(created_after))
+ .bind(into_utc(since))
+ .bind(page_size)
+ .fetch(self.read_pool())
+ .map_ok(|DbHistory(h)| h)
+ .try_collect()
+ .await?;
+
+ Ok(res)
+ }
+
+ #[instrument(skip_all)]
+ async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
+ let mut tx = self.pool.begin().await?;
+
+ for i in history {
+ let client_id: &str = &i.client_id;
+ let hostname: &str = &i.hostname;
+ let data: &str = &i.data;
+
+ sqlx::query(
+ "insert into history
+ (client_id, user_id, hostname, timestamp, data)
+ values ($1, $2, $3, $4, $5)
+ on conflict do nothing
+ ",
+ )
+ .bind(client_id)
+ .bind(i.user_id)
+ .bind(hostname)
+ .bind(i.timestamp)
+ .bind(data)
+ .execute(&mut *tx)
+ .await?;
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn delete_user(&self, u: &User) -> DbResult<()> {
+ sqlx::query("delete from sessions where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await?;
+
+ sqlx::query("delete from history where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await?;
+
+ sqlx::query("delete from store where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await?;
+
+ sqlx::query("delete from total_history_count_user where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await?;
+
+ sqlx::query("delete from users where id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn update_user_password(&self, user: &User) -> DbResult<()> {
+ sqlx::query(
+ "update users
+ set password = $1
+ where id = $2",
+ )
+ .bind(&user.password)
+ .bind(user.id)
+ .execute(&self.pool)
+ .await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
+ let email: &str = &user.email;
+ let username: &str = &user.username;
+ let password: &str = &user.password;
+
+ let res: (i64,) = sqlx::query_as(
+ "insert into users
+ (username, email, password)
+ values($1, $2, $3)
+ returning id",
+ )
+ .bind(username)
+ .bind(email)
+ .bind(password)
+ .fetch_one(&self.pool)
+ .await?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ async fn add_session(&self, session: &NewSession) -> DbResult<()> {
+ let token: &str = &session.token;
+
+ sqlx::query(
+ "insert into sessions
+ (user_id, token)
+ values($1, $2)",
+ )
+ .bind(session.user_id)
+ .bind(token)
+ .execute(&self.pool)
+ .await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn get_user_session(&self, u: &User) -> DbResult<Session> {
+ sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
+ .bind(u.id)
+ .fetch_one(self.read_pool())
+ .await
+ .map_err(Into::into)
+ .map(|DbSession(session)| session)
+ }
+
+ #[instrument(skip_all)]
+ async fn oldest_history(&self, user: &User) -> DbResult<History> {
+ sqlx::query_as(
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ where user_id = $1
+ order by timestamp asc
+ limit 1",
+ )
+ .bind(user.id)
+ .fetch_one(self.read_pool())
+ .await
+ .map_err(Into::into)
+ .map(|DbHistory(h)| h)
+ }
+
+ #[instrument(skip_all)]
+ async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
+ let mut tx = self.pool.begin().await?;
+
+ // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max
+ // idx without having to make further database queries. Doing the query on this small
+ // amount of data should be much, much faster.
+ //
+ // Worst case, say we get this wrong. We end up caching data that isn't actually the max
+ // idx, so clients upload again. The cache logic can be verified with a sql query anyway :)
+
+ let mut heads = HashMap::<(HostId, &str), u64>::new();
+
+ for i in records {
+ let id = crate::atuin_common::utils::uuid_v7();
+
+ let result = sqlx::query(
+ "insert into store
+ (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
+ values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
+ on conflict do nothing
+ ",
+ )
+ .bind(id)
+ .bind(i.id)
+ .bind(i.host.id)
+ .bind(i.idx as i64)
+ .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
+ .bind(&i.version)
+ .bind(&i.tag)
+ .bind(&i.data.data)
+ .bind(&i.data.content_encryption_key)
+ .bind(user.id)
+ .execute(&mut *tx)
+ .await?;
+
+ // Only update heads if we actually inserted the record
+ if result.rows_affected() > 0 {
+ heads
+ .entry((i.host.id, &i.tag))
+ .and_modify(|e| {
+ if i.idx > *e {
+ *e = i.idx
+ }
+ })
+ .or_insert(i.idx);
+ }
+ }
+
+ // we've built the map of heads for this push, so commit it to the database
+ for ((host, tag), idx) in heads {
+ sqlx::query(
+ "insert into store_idx_cache
+ (user_id, host, tag, idx)
+ values ($1, $2, $3, $4)
+ on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
+ ",
+ )
+ .bind(user.id)
+ .bind(host)
+ .bind(tag)
+ .bind(idx as i64)
+ .execute(&mut *tx)
+ .await
+ ?;
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ async fn next_records(
+ &self,
+ user: &User,
+ host: HostId,
+ tag: String,
+ start: Option<RecordIdx>,
+ count: u64,
+ ) -> DbResult<Vec<Record<EncryptedData>>> {
+ tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
+ let start = start.unwrap_or(0);
+
+ let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
+ "select client_id, host, idx, timestamp, version, tag, data, cek from store
+ where user_id = $1
+ and tag = $2
+ and host = $3
+ and idx >= $4
+ order by idx asc
+ limit $5",
+ )
+ .bind(user.id)
+ .bind(tag.clone())
+ .bind(host)
+ .bind(start as i64)
+ .bind(count as i64)
+ .fetch_all(self.read_pool())
+ .await
+ .map_err(Into::into);
+
+ let ret = match records {
+ Ok(records) => {
+ let records: Vec<Record<EncryptedData>> = records
+ .into_iter()
+ .map(|f| {
+ let record: Record<EncryptedData> = f.into();
+ record
+ })
+ .collect();
+
+ records
+ }
+ Err(DbError::NotFound) => {
+ tracing::debug!("no records found in store: {:?}/{}", host, tag);
+ return Ok(vec![]);
+ }
+ Err(e) => return Err(e),
+ };
+
+ Ok(ret)
+ }
+
+ async fn status(&self, user: &User) -> DbResult<RecordStatus> {
+ const STATUS_SQL: &str =
+ "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
+
+ // If IDX_CACHE_ROLLOUT is set, then we
+ // 1. Read the value of the var, use it as a % chance of using the cache
+ // 2. If we use the cache, just read from the cache table
+ // 3. If we don't use the cache, read from the store table
+ // IDX_CACHE_ROLLOUT should be between 0 and 100.
+
+ let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string());
+ let idx_cache_rollout = idx_cache_rollout.parse::<f64>().unwrap_or(0.0);
+ let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0);
+
+ let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache {
+ tracing::debug!("using idx cache for user {}", user.id);
+ sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
+ .bind(user.id)
+ .fetch_all(self.read_pool())
+ .await?
+ } else {
+ tracing::debug!("using aggregate query for user {}", user.id);
+ sqlx::query_as(STATUS_SQL)
+ .bind(user.id)
+ .fetch_all(self.read_pool())
+ .await?
+ };
+
+ res.sort();
+
+ let mut status = RecordStatus::new();
+
+ for i in res.iter() {
+ status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
+ }
+
+ Ok(status)
+ }
+}
diff --git a/crates/turtle/src/atuin_server_postgres/wrappers.rs b/crates/turtle/src/atuin_server_postgres/wrappers.rs
new file mode 100644
index 00000000..214b255d
--- /dev/null
+++ b/crates/turtle/src/atuin_server_postgres/wrappers.rs
@@ -0,0 +1,77 @@
+use ::sqlx::{FromRow, Result};
+use crate::atuin_common::record::{EncryptedData, Host, Record};
+use crate::atuin_server_database::models::{History, Session, User};
+use sqlx::{Row, postgres::PgRow};
+use time::PrimitiveDateTime;
+
+pub struct DbUser(pub User);
+pub struct DbSession(pub Session);
+pub struct DbHistory(pub History);
+pub struct DbRecord(pub Record<EncryptedData>);
+
+impl<'a> FromRow<'a, PgRow> for DbUser {
+ fn from_row(row: &'a PgRow) -> Result<Self> {
+ Ok(Self(User {
+ id: row.try_get("id")?,
+ username: row.try_get("username")?,
+ email: row.try_get("email")?,
+ password: row.try_get("password")?,
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession {
+ fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
+ Ok(Self(Session {
+ id: row.try_get("id")?,
+ user_id: row.try_get("user_id")?,
+ token: row.try_get("token")?,
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
+ fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
+ Ok(Self(History {
+ id: row.try_get("id")?,
+ client_id: row.try_get("client_id")?,
+ user_id: row.try_get("user_id")?,
+ hostname: row.try_get("hostname")?,
+ timestamp: row
+ .try_get::<PrimitiveDateTime, _>("timestamp")?
+ .assume_utc(),
+ data: row.try_get("data")?,
+ created_at: row
+ .try_get::<PrimitiveDateTime, _>("created_at")?
+ .assume_utc(),
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
+ fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
+ let timestamp: i64 = row.try_get("timestamp")?;
+ let idx: i64 = row.try_get("idx")?;
+
+ let data = EncryptedData {
+ data: row.try_get("data")?,
+ content_encryption_key: row.try_get("cek")?,
+ };
+
+ Ok(Self(Record {
+ id: row.try_get("client_id")?,
+ host: Host::new(row.try_get("host")?),
+ idx: idx as u64,
+ timestamp: timestamp as u64,
+ version: row.try_get("version")?,
+ tag: row.try_get("tag")?,
+ data,
+ }))
+ }
+}
+
+impl From<DbRecord> for Record<EncryptedData> {
+ fn from(other: DbRecord) -> Record<EncryptedData> {
+ Record { ..other.0 }
+ }
+}