aboutsummaryrefslogtreecommitdiffstats
path: root/crates/turtle/src/atuin_server
diff options
context:
space:
mode:
authorBenedikt Peetz <benedikt.peetz@b-peetz.de>2026-06-11 16:10:29 +0200
committerBenedikt Peetz <benedikt.peetz@b-peetz.de>2026-06-11 16:10:29 +0200
commit97f207b771b94c5285faae4810d6eeda1b78926b (patch)
tree4482544233c30e0e9a62be6afcfe92c8e01b0a50 /crates/turtle/src/atuin_server
parentchore: Remove all `pub`s (diff)
downloadatuin-97f207b771b94c5285faae4810d6eeda1b78926b.zip
chore(server): Simplify the database support
Diffstat (limited to 'crates/turtle/src/atuin_server')
-rw-r--r--crates/turtle/src/atuin_server/database/calendar.rs18
-rw-r--r--crates/turtle/src/atuin_server/database/db/mod.rs667
-rw-r--r--crates/turtle/src/atuin_server/database/db/wrappers.rs79
-rw-r--r--crates/turtle/src/atuin_server/database/mod.rs123
-rw-r--r--crates/turtle/src/atuin_server/database/models.rs52
-rw-r--r--crates/turtle/src/atuin_server/handlers/history.rs237
-rw-r--r--crates/turtle/src/atuin_server/handlers/mod.rs5
-rw-r--r--crates/turtle/src/atuin_server/handlers/status.rs45
-rw-r--r--crates/turtle/src/atuin_server/handlers/user.rs32
-rw-r--r--crates/turtle/src/atuin_server/handlers/v0/record.rs13
-rw-r--r--crates/turtle/src/atuin_server/handlers/v0/store.rs5
-rw-r--r--crates/turtle/src/atuin_server/mod.rs16
-rw-r--r--crates/turtle/src/atuin_server/router.rs45
-rw-r--r--crates/turtle/src/atuin_server/settings.rs8
-rw-r--r--crates/turtle/src/atuin_server/utils.rs15
15 files changed, 985 insertions, 375 deletions
diff --git a/crates/turtle/src/atuin_server/database/calendar.rs b/crates/turtle/src/atuin_server/database/calendar.rs
new file mode 100644
index 00000000..f1c78262
--- /dev/null
+++ b/crates/turtle/src/atuin_server/database/calendar.rs
@@ -0,0 +1,18 @@
+// Calendar data
+
+use serde::{Deserialize, Serialize};
+use time::Month;
+
+pub(crate) enum TimePeriod {
+ Year,
+ Month { year: i32 },
+ Day { year: i32, month: Month },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub(crate) struct TimePeriodInfo {
+ pub(crate) count: u64,
+
+ // TODO: Use this for merkle tree magic
+ pub(crate) hash: String,
+}
diff --git a/crates/turtle/src/atuin_server/database/db/mod.rs b/crates/turtle/src/atuin_server/database/db/mod.rs
new file mode 100644
index 00000000..22d69d3c
--- /dev/null
+++ b/crates/turtle/src/atuin_server/database/db/mod.rs
@@ -0,0 +1,667 @@
+use std::collections::HashMap;
+use std::ops::Range;
+
+use rand::Rng;
+
+use crate::{
+ atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus},
+ atuin_server::database::{
+ DbError, DbResult, DbSettings,
+ calendar::{TimePeriod, TimePeriodInfo},
+ into_utc,
+ models::{History, NewHistory, NewSession, NewUser, Session, User},
+ },
+};
+use futures_util::TryStreamExt;
+use sqlx::Row;
+use sqlx::postgres::PgPoolOptions;
+use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
+
+use tracing::instrument;
+use uuid::Uuid;
+use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
+
+mod wrappers;
+
+const MIN_PG_VERSION: u32 = 14;
+
+#[derive(Clone)]
+pub struct Database {
+ pool: sqlx::Pool<sqlx::postgres::Postgres>,
+ /// Optional read replica pool for read-only queries
+ read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>,
+}
+
+impl Database {
+ /// Returns the appropriate pool for read operations.
+ /// Uses read_pool if available, otherwise falls back to the primary pool.
+ fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> {
+ self.read_pool.as_ref().unwrap_or(&self.pool)
+ }
+}
+
+impl Database {
+ pub(crate) async fn new(settings: &DbSettings) -> DbResult<Self> {
+ let pool = PgPoolOptions::new()
+ .max_connections(100)
+ .connect(settings.db_uri.as_str())
+ .await?;
+
+ // Call server_version_num to get the DB server's major version number
+ // The call returns None for servers older than 8.x.
+ let pg_major_version: u32 =
+ pool.acquire()
+ .await?
+ .server_version_num()
+ .ok_or(DbError::Other(eyre::Report::msg(
+ "could not get PostgreSQL version",
+ )))?
+ / 10000;
+
+ if pg_major_version < MIN_PG_VERSION {
+ return Err(DbError::Other(eyre::Report::msg(format!(
+ "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}"
+ ))));
+ }
+
+ sqlx::migrate!("./db/server-pg-migrations")
+ .run(&pool)
+ .await
+ .map_err(|error| DbError::Other(error.into()))?;
+
+ // Create read replica pool if configured
+ let read_pool = if let Some(read_db_uri) = &settings.read_db_uri {
+ tracing::info!("Connecting to read replica database");
+ let read_pool = PgPoolOptions::new()
+ .max_connections(100)
+ .connect(read_db_uri.as_str())
+ .await?;
+
+ // Verify the read replica is also a supported PostgreSQL version
+ let read_pg_major_version: u32 = read_pool
+ .acquire()
+ .await?
+ .server_version_num()
+ .ok_or(DbError::Other(eyre::Report::msg(
+ "could not get PostgreSQL version from read replica",
+ )))?
+ / 10000;
+
+ if read_pg_major_version < MIN_PG_VERSION {
+ return Err(DbError::Other(eyre::Report::msg(format!(
+ "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}"
+ ))));
+ }
+
+ Some(read_pool)
+ } else {
+ None
+ };
+
+ Ok(Self { pool, read_pool })
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn calendar(
+ &self,
+ user: &User,
+ period: TimePeriod,
+ tz: UtcOffset,
+ ) -> DbResult<HashMap<u64, TimePeriodInfo>> {
+ let mut ret = HashMap::new();
+ let iter: Box<dyn Iterator<Item = DbResult<(u64, Range<Date>)>> + Send> = match period {
+ TimePeriod::Year => {
+ // First we need to work out how far back to calculate. Get the
+ // oldest history item
+ let oldest = self
+ .oldest_history(user)
+ .await?
+ .timestamp
+ .to_offset(tz)
+ .year();
+ let current_year = OffsetDateTime::now_utc().to_offset(tz).year();
+
+ // All the years we need to get data for
+ // The upper bound is exclusive, so include current +1
+ let years = oldest..current_year + 1;
+
+ Box::new(years.map(|year| {
+ let start = Date::from_calendar_date(year, time::Month::January, 1)?;
+ let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?;
+
+ Ok((year as u64, start..end))
+ }))
+ }
+
+ TimePeriod::Month { year } => {
+ let months =
+ std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12);
+
+ Box::new(months.map(move |month| {
+ let start = Date::from_calendar_date(year, month, 1)?;
+ let days = start.month().length(year);
+ let end = start + Duration::days(days as i64);
+
+ Ok((month as u64, start..end))
+ }))
+ }
+
+ TimePeriod::Day { year, month } => {
+ let days = 1..month.length(year);
+ Box::new(days.map(move |day| {
+ let start = Date::from_calendar_date(year, month, day)?;
+ let end = start
+ .next_day()
+ .ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?;
+
+ Ok((day as u64, start..end))
+ }))
+ }
+ };
+
+ for x in iter {
+ let (index, range) = x?;
+
+ let start = range.start.with_time(Time::MIDNIGHT).assume_offset(tz);
+ let end = range.end.with_time(Time::MIDNIGHT).assume_offset(tz);
+
+ let count = self.count_history_range(user, start..end).await?;
+
+ ret.insert(
+ index,
+ TimePeriodInfo {
+ count: count as u64,
+ hash: "".to_string(),
+ },
+ );
+ }
+
+ Ok(ret)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn get_session(&self, token: &str) -> DbResult<Session> {
+ sqlx::query_as("select id, user_id, token from sessions where token = $1")
+ .bind(token)
+ .fetch_one(self.read_pool())
+ .await
+ .map_err(Into::into)
+ .map(|DbSession(session)| session)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn get_user(&self, username: &str) -> DbResult<User> {
+ sqlx::query_as("select id, username, email, password from users where username = $1")
+ .bind(username)
+ .fetch_one(self.read_pool())
+ .await
+ .map_err(Into::into)
+ .map(|DbUser(user)| user)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn get_session_user(&self, token: &str) -> DbResult<User> {
+ sqlx::query_as(
+ "select users.id, users.username, users.email, users.password from users
+ inner join sessions
+ on users.id = sessions.user_id
+ and sessions.token = $1",
+ )
+ .bind(token)
+ .fetch_one(self.read_pool())
+ .await
+ .map_err(Into::into)
+ .map(|DbUser(user)| user)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn count_history(&self, user: &User) -> DbResult<i64> {
+ // The cache is new, and the user might not yet have a cache value.
+ // They will have one as soon as they post up some new history, but handle that
+ // edge case.
+
+ let res: (i64,) = sqlx::query_as(
+ "select count(1) from history
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .fetch_one(self.read_pool())
+ .await?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
+ let res: (i32,) = sqlx::query_as(
+ "select total from total_history_count_user
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .fetch_one(self.read_pool())
+ .await?;
+
+ Ok(res.0 as i64)
+ }
+
+ pub(crate) async fn delete_store(&self, user: &User) -> DbResult<()> {
+ let mut tx = self.pool.begin().await?;
+
+ sqlx::query(
+ "delete from store
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .execute(&mut *tx)
+ .await?;
+
+ sqlx::query(
+ "delete from store_idx_cache
+ where user_id = $1",
+ )
+ .bind(user.id)
+ .execute(&mut *tx)
+ .await?;
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ pub(crate) async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
+ sqlx::query(
+ "update history
+ set deleted_at = $3
+ where user_id = $1
+ and client_id = $2
+ and deleted_at is null", // don't just keep setting it
+ )
+ .bind(user.id)
+ .bind(id)
+ .bind(OffsetDateTime::now_utc())
+ .fetch_all(&self.pool)
+ .await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
+ // The cache is new, and the user might not yet have a cache value.
+ // They will have one as soon as they post up some new history, but handle that
+ // edge case.
+
+ let res = sqlx::query(
+ "select client_id from history
+ where user_id = $1
+ and deleted_at is not null",
+ )
+ .bind(user.id)
+ .fetch_all(self.read_pool())
+ .await?;
+
+ let res = res
+ .iter()
+ .map(|row| row.get::<String, _>("client_id"))
+ .collect();
+
+ Ok(res)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn count_history_range(
+ &self,
+ user: &User,
+ range: Range<OffsetDateTime>,
+ ) -> DbResult<i64> {
+ let res: (i64,) = sqlx::query_as(
+ "select count(1) from history
+ where user_id = $1
+ and timestamp >= $2::date
+ and timestamp < $3::date",
+ )
+ .bind(user.id)
+ .bind(into_utc(range.start))
+ .bind(into_utc(range.end))
+ .fetch_one(self.read_pool())
+ .await?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn list_history(
+ &self,
+ user: &User,
+ created_after: OffsetDateTime,
+ since: OffsetDateTime,
+ host: &str,
+ page_size: i64,
+ ) -> DbResult<Vec<History>> {
+ let res = sqlx::query_as(
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ where user_id = $1
+ and hostname != $2
+ and created_at >= $3
+ and timestamp >= $4
+ order by timestamp asc
+ limit $5",
+ )
+ .bind(user.id)
+ .bind(host)
+ .bind(into_utc(created_after))
+ .bind(into_utc(since))
+ .bind(page_size)
+ .fetch(self.read_pool())
+ .map_ok(|DbHistory(h)| h)
+ .try_collect()
+ .await?;
+
+ Ok(res)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
+ let mut tx = self.pool.begin().await?;
+
+ for i in history {
+ let client_id: &str = &i.client_id;
+ let hostname: &str = &i.hostname;
+ let data: &str = &i.data;
+
+ sqlx::query(
+ "insert into history
+ (client_id, user_id, hostname, timestamp, data)
+ values ($1, $2, $3, $4, $5)
+ on conflict do nothing
+ ",
+ )
+ .bind(client_id)
+ .bind(i.user_id)
+ .bind(hostname)
+ .bind(i.timestamp)
+ .bind(data)
+ .execute(&mut *tx)
+ .await?;
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn delete_user(&self, u: &User) -> DbResult<()> {
+ sqlx::query("delete from sessions where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await?;
+
+ sqlx::query("delete from history where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await?;
+
+ sqlx::query("delete from store where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await?;
+
+ sqlx::query("delete from total_history_count_user where user_id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await?;
+
+ sqlx::query("delete from users where id = $1")
+ .bind(u.id)
+ .execute(&self.pool)
+ .await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn update_user_password(&self, user: &User) -> DbResult<()> {
+ sqlx::query(
+ "update users
+ set password = $1
+ where id = $2",
+ )
+ .bind(&user.password)
+ .bind(user.id)
+ .execute(&self.pool)
+ .await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
+ let email: &str = &user.email;
+ let username: &str = &user.username;
+ let password: &str = &user.password;
+
+ let res: (i64,) = sqlx::query_as(
+ "insert into users
+ (username, email, password)
+ values($1, $2, $3)
+ returning id",
+ )
+ .bind(username)
+ .bind(email)
+ .bind(password)
+ .fetch_one(&self.pool)
+ .await?;
+
+ Ok(res.0)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn add_session(&self, session: &NewSession) -> DbResult<()> {
+ let token: &str = &session.token;
+
+ sqlx::query(
+ "insert into sessions
+ (user_id, token)
+ values($1, $2)",
+ )
+ .bind(session.user_id)
+ .bind(token)
+ .execute(&self.pool)
+ .await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn get_user_session(&self, u: &User) -> DbResult<Session> {
+ sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
+ .bind(u.id)
+ .fetch_one(self.read_pool())
+ .await
+ .map_err(Into::into)
+ .map(|DbSession(session)| session)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn oldest_history(&self, user: &User) -> DbResult<History> {
+ sqlx::query_as(
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ where user_id = $1
+ order by timestamp asc
+ limit 1",
+ )
+ .bind(user.id)
+ .fetch_one(self.read_pool())
+ .await
+ .map_err(Into::into)
+ .map(|DbHistory(h)| h)
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn add_records(
+ &self,
+ user: &User,
+ records: &[Record<EncryptedData>],
+ ) -> DbResult<()> {
+ let mut tx = self.pool.begin().await?;
+
+ // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max
+ // idx without having to make further database queries. Doing the query on this small
+ // amount of data should be much, much faster.
+ //
+ // Worst case, say we get this wrong. We end up caching data that isn't actually the max
+ // idx, so clients upload again. The cache logic can be verified with a sql query anyway :)
+
+ let mut heads = HashMap::<(HostId, &str), u64>::new();
+
+ for i in records {
+ let id = crate::atuin_common::utils::uuid_v7();
+
+ let result = sqlx::query(
+ "insert into store
+ (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
+ values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
+ on conflict do nothing
+ ",
+ )
+ .bind(id)
+ .bind(i.id)
+ .bind(i.host.id)
+ .bind(i.idx as i64)
+ .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
+ .bind(&i.version)
+ .bind(&i.tag)
+ .bind(&i.data.data)
+ .bind(&i.data.content_encryption_key)
+ .bind(user.id)
+ .execute(&mut *tx)
+ .await?;
+
+ // Only update heads if we actually inserted the record
+ if result.rows_affected() > 0 {
+ heads
+ .entry((i.host.id, &i.tag))
+ .and_modify(|e| {
+ if i.idx > *e {
+ *e = i.idx
+ }
+ })
+ .or_insert(i.idx);
+ }
+ }
+
+ // we've built the map of heads for this push, so commit it to the database
+ for ((host, tag), idx) in heads {
+ sqlx::query(
+ "insert into store_idx_cache
+ (user_id, host, tag, idx)
+ values ($1, $2, $3, $4)
+ on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
+ ",
+ )
+ .bind(user.id)
+ .bind(host)
+ .bind(tag)
+ .bind(idx as i64)
+ .execute(&mut *tx)
+ .await
+ ?;
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ #[instrument(skip_all)]
+ pub(crate) async fn next_records(
+ &self,
+ user: &User,
+ host: HostId,
+ tag: String,
+ start: Option<RecordIdx>,
+ count: u64,
+ ) -> DbResult<Vec<Record<EncryptedData>>> {
+ tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
+ let start = start.unwrap_or(0);
+
+ let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
+ "select client_id, host, idx, timestamp, version, tag, data, cek from store
+ where user_id = $1
+ and tag = $2
+ and host = $3
+ and idx >= $4
+ order by idx asc
+ limit $5",
+ )
+ .bind(user.id)
+ .bind(tag.clone())
+ .bind(host)
+ .bind(start as i64)
+ .bind(count as i64)
+ .fetch_all(self.read_pool())
+ .await
+ .map_err(Into::into);
+
+ let ret = match records {
+ Ok(records) => {
+ let records: Vec<Record<EncryptedData>> = records
+ .into_iter()
+ .map(|f| {
+ let record: Record<EncryptedData> = f.into();
+ record
+ })
+ .collect();
+
+ records
+ }
+ Err(DbError::NotFound) => {
+ tracing::debug!("no records found in store: {:?}/{}", host, tag);
+ return Ok(vec![]);
+ }
+ Err(e) => return Err(e),
+ };
+
+ Ok(ret)
+ }
+
+ pub(crate) async fn status(&self, user: &User) -> DbResult<RecordStatus> {
+ const STATUS_SQL: &str =
+ "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
+
+ // If IDX_CACHE_ROLLOUT is set, then we
+ // 1. Read the value of the var, use it as a % chance of using the cache
+ // 2. If we use the cache, just read from the cache table
+ // 3. If we don't use the cache, read from the store table
+ // IDX_CACHE_ROLLOUT should be between 0 and 100.
+
+ let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string());
+ let idx_cache_rollout = idx_cache_rollout.parse::<f64>().unwrap_or(0.0);
+ let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0);
+
+ let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache {
+ tracing::debug!("using idx cache for user {}", user.id);
+ sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
+ .bind(user.id)
+ .fetch_all(self.read_pool())
+ .await?
+ } else {
+ tracing::debug!("using aggregate query for user {}", user.id);
+ sqlx::query_as(STATUS_SQL)
+ .bind(user.id)
+ .fetch_all(self.read_pool())
+ .await?
+ };
+
+ res.sort();
+
+ let mut status = RecordStatus::new();
+
+ for i in res.iter() {
+ status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
+ }
+
+ Ok(status)
+ }
+}
diff --git a/crates/turtle/src/atuin_server/database/db/wrappers.rs b/crates/turtle/src/atuin_server/database/db/wrappers.rs
new file mode 100644
index 00000000..de4c5814
--- /dev/null
+++ b/crates/turtle/src/atuin_server/database/db/wrappers.rs
@@ -0,0 +1,79 @@
+use crate::{
+ atuin_common::record::{EncryptedData, Host, Record},
+ atuin_server::database::models::{History, Session, User},
+};
+use ::sqlx::{FromRow, Result};
+use sqlx::{Row, postgres::PgRow};
+use time::PrimitiveDateTime;
+
+pub struct DbUser(pub User);
+pub struct DbSession(pub Session);
+pub struct DbHistory(pub History);
+pub struct DbRecord(pub Record<EncryptedData>);
+
+impl<'a> FromRow<'a, PgRow> for DbUser {
+ fn from_row(row: &'a PgRow) -> Result<Self> {
+ Ok(Self(User {
+ id: row.try_get("id")?,
+ username: row.try_get("username")?,
+ email: row.try_get("email")?,
+ password: row.try_get("password")?,
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession {
+ fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
+ Ok(Self(Session {
+ id: row.try_get("id")?,
+ user_id: row.try_get("user_id")?,
+ token: row.try_get("token")?,
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
+ fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
+ Ok(Self(History {
+ id: row.try_get("id")?,
+ client_id: row.try_get("client_id")?,
+ user_id: row.try_get("user_id")?,
+ hostname: row.try_get("hostname")?,
+ timestamp: row
+ .try_get::<PrimitiveDateTime, _>("timestamp")?
+ .assume_utc(),
+ data: row.try_get("data")?,
+ created_at: row
+ .try_get::<PrimitiveDateTime, _>("created_at")?
+ .assume_utc(),
+ }))
+ }
+}
+
+impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
+ fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
+ let timestamp: i64 = row.try_get("timestamp")?;
+ let idx: i64 = row.try_get("idx")?;
+
+ let data = EncryptedData {
+ data: row.try_get("data")?,
+ content_encryption_key: row.try_get("cek")?,
+ };
+
+ Ok(Self(Record {
+ id: row.try_get("client_id")?,
+ host: Host::new(row.try_get("host")?),
+ idx: idx as u64,
+ timestamp: timestamp as u64,
+ version: row.try_get("version")?,
+ tag: row.try_get("tag")?,
+ data,
+ }))
+ }
+}
+
+impl From<DbRecord> for Record<EncryptedData> {
+ fn from(other: DbRecord) -> Record<EncryptedData> {
+ Record { ..other.0 }
+ }
+}
diff --git a/crates/turtle/src/atuin_server/database/mod.rs b/crates/turtle/src/atuin_server/database/mod.rs
new file mode 100644
index 00000000..845d67d7
--- /dev/null
+++ b/crates/turtle/src/atuin_server/database/mod.rs
@@ -0,0 +1,123 @@
+pub(crate) mod calendar;
+pub(crate) mod db;
+pub(crate) mod models;
+
+use std::fmt::{Debug, Display};
+
+use serde::{Deserialize, Serialize};
+use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
+
+#[derive(Debug)]
+pub(crate) enum DbError {
+ NotFound,
+ Other(eyre::Report),
+}
+
+impl Display for DbError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{self:?}")
+ }
+}
+
+impl From<time::error::ComponentRange> for DbError {
+ fn from(error: time::error::ComponentRange) -> Self {
+ DbError::Other(error.into())
+ }
+}
+
+impl From<time::error::Error> for DbError {
+ fn from(error: time::error::Error) -> Self {
+ DbError::Other(error.into())
+ }
+}
+
+impl From<sqlx::Error> for DbError {
+ fn from(error: sqlx::Error) -> Self {
+ match error {
+ sqlx::Error::RowNotFound => DbError::NotFound,
+ error => DbError::Other(error.into()),
+ }
+ }
+}
+
+impl std::error::Error for DbError {}
+
+pub(crate) type DbResult<T> = Result<T, DbError>;
+
+#[derive(Debug, PartialEq)]
+pub(crate) enum DbType {
+ Postgres,
+ Unknown,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct DbSettings {
+ pub(crate) db_uri: String,
+ /// Optional URI for read replicas. If set, read-only queries will use this connection.
+ pub(crate) read_db_uri: Option<String>,
+}
+
+impl DbSettings {
+ pub(crate) fn db_type(&self) -> DbType {
+ if self.db_uri.starts_with("postgres://") || self.db_uri.starts_with("postgresql://") {
+ DbType::Postgres
+ } else {
+ DbType::Unknown
+ }
+ }
+}
+
+fn redact_db_uri(uri: &str) -> String {
+ url::Url::parse(uri)
+ .map(|mut url| {
+ let _ = url.set_password(Some("****"));
+ url.to_string()
+ })
+ .unwrap_or_else(|_| uri.to_string())
+}
+
+// Do our best to redact passwords so they're not logged in the event of an error.
+impl Debug for DbSettings {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ if self.db_type() == DbType::Postgres {
+ let redacted_uri = redact_db_uri(&self.db_uri);
+ let redacted_read_uri = self.read_db_uri.as_ref().map(|uri| redact_db_uri(uri));
+ f.debug_struct("DbSettings")
+ .field("db_uri", &redacted_uri)
+ .field("read_db_uri", &redacted_read_uri)
+ .finish()
+ } else {
+ f.debug_struct("DbSettings")
+ .field("db_uri", &self.db_uri)
+ .field("read_db_uri", &self.read_db_uri)
+ .finish()
+ }
+ }
+}
+
+pub(crate) fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
+ let x = x.to_offset(UtcOffset::UTC);
+ PrimitiveDateTime::new(x.date(), x.time())
+}
+
+#[cfg(test)]
+mod tests {
+ use time::macros::datetime;
+
+ use super::into_utc;
+
+ #[test]
+ fn utc() {
+ let dt = datetime!(2023-09-26 15:11:02 +05:30);
+ assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02));
+ assert_eq!(into_utc(dt).assume_utc(), dt);
+
+ let dt = datetime!(2023-09-26 15:11:02 -07:00);
+ assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02));
+ assert_eq!(into_utc(dt).assume_utc(), dt);
+
+ let dt = datetime!(2023-09-26 15:11:02 +00:00);
+ assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02));
+ assert_eq!(into_utc(dt).assume_utc(), dt);
+ }
+}
diff --git a/crates/turtle/src/atuin_server/database/models.rs b/crates/turtle/src/atuin_server/database/models.rs
new file mode 100644
index 00000000..e47d614d
--- /dev/null
+++ b/crates/turtle/src/atuin_server/database/models.rs
@@ -0,0 +1,52 @@
+use time::OffsetDateTime;
+
+pub(crate) struct History {
+ pub(crate) id: i64,
+ pub(crate) client_id: String, // a client generated ID
+ pub(crate) user_id: i64,
+ pub(crate) hostname: String,
+ pub(crate) timestamp: OffsetDateTime,
+
+ /// All the data we have about this command, encrypted.
+ ///
+ /// Currently this is an encrypted msgpack object, but this may change in the future.
+ pub(crate) data: String,
+
+ pub(crate) created_at: OffsetDateTime,
+}
+
+pub(crate) struct NewHistory {
+ pub(crate) client_id: String,
+ pub(crate) user_id: i64,
+ pub(crate) hostname: String,
+ pub(crate) timestamp: OffsetDateTime,
+
+ /// All the data we have about this command, encrypted.
+ ///
+ /// Currently this is an encrypted msgpack object, but this may change in the future.
+ pub(crate) data: String,
+}
+
+pub(crate) struct User {
+ pub(crate) id: i64,
+ pub(crate) username: String,
+ pub(crate) email: String,
+ pub(crate) password: String,
+}
+
+pub(crate) struct Session {
+ pub(crate) id: i64,
+ pub(crate) user_id: i64,
+ pub(crate) token: String,
+}
+
+pub(crate) struct NewUser {
+ pub(crate) username: String,
+ pub(crate) email: String,
+ pub(crate) password: String,
+}
+
+pub(crate) struct NewSession {
+ pub(crate) user_id: i64,
+ pub(crate) token: String,
+}
diff --git a/crates/turtle/src/atuin_server/handlers/history.rs b/crates/turtle/src/atuin_server/handlers/history.rs
deleted file mode 100644
index e5057bcb..00000000
--- a/crates/turtle/src/atuin_server/handlers/history.rs
+++ /dev/null
@@ -1,237 +0,0 @@
-use std::{collections::HashMap, convert::TryFrom};
-
-use axum::{
- Json,
- extract::{Path, Query, State},
- http::{HeaderMap, StatusCode},
-};
-use metrics::counter;
-use time::{Month, UtcOffset};
-use tracing::{debug, error, instrument};
-
-use super::{ErrorResponse, ErrorResponseStatus, RespExt};
-use crate::atuin_server::{
- router::{AppState, UserAuth},
- utils::client_version_min,
-};
-use crate::atuin_server_database::{
- Database,
- calendar::{TimePeriod, TimePeriodInfo},
- models::NewHistory,
-};
-
-use crate::atuin_common::api::*;
-
-#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn count<DB: Database>(
- UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
-) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
- let db = &state.0.database;
- match db.count_history_cached(&user).await {
- // By default read out the cached value
- Ok(count) => Ok(Json(CountResponse { count })),
-
- // If that fails, fallback on a full COUNT. Cache is built on a POST
- // only
- Err(_) => match db.count_history(&user).await {
- Ok(count) => Ok(Json(CountResponse { count })),
- Err(_) => Err(ErrorResponse::reply("failed to query history count")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR)),
- },
- }
-}
-
-#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn list<DB: Database>(
- req: Query<SyncHistoryRequest>,
- UserAuth(user): UserAuth,
- headers: HeaderMap,
- state: State<AppState<DB>>,
-) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
- let db = &state.0.database;
-
- let agent = headers
- .get("user-agent")
- .map_or("", |v| v.to_str().unwrap_or(""));
-
- let variable_page_size = client_version_min(agent, ">=15.0.0").unwrap_or(false);
-
- let page_size = if variable_page_size {
- state.settings.page_size
- } else {
- 100
- };
-
- if req.sync_ts.unix_timestamp_nanos() < 0 || req.history_ts.unix_timestamp_nanos() < 0 {
- error!("client asked for history from < epoch 0");
- counter!("atuin_history_epoch_before_zero").increment(1);
-
- return Err(
- ErrorResponse::reply("asked for history from before epoch 0")
- .with_status(StatusCode::BAD_REQUEST),
- );
- }
-
- let history = db
- .list_history(&user, req.sync_ts, req.history_ts, &req.host, page_size)
- .await;
-
- if let Err(e) = history {
- error!("failed to load history: {}", e);
- return Err(ErrorResponse::reply("failed to load history")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR));
- }
-
- let history: Vec<String> = history
- .unwrap()
- .iter()
- .map(|i| i.data.to_string())
- .collect();
-
- debug!(
- "loaded {} items of history for user {}",
- history.len(),
- user.id
- );
-
- counter!("atuin_history_returned").increment(history.len() as u64);
-
- Ok(Json(SyncHistoryResponse { history }))
-}
-
-#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn delete<DB: Database>(
- UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
- Json(req): Json<DeleteHistoryRequest>,
-) -> Result<Json<MessageResponse>, ErrorResponseStatus<'static>> {
- let db = &state.0.database;
-
- // user_id is the ID of the history, as set by the user (the server has its own ID)
- let deleted = db.delete_history(&user, req.client_id).await;
-
- if let Err(e) = deleted {
- error!("failed to delete history: {}", e);
- return Err(ErrorResponse::reply("failed to delete history")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR));
- }
-
- Ok(Json(MessageResponse {
- message: String::from("deleted OK"),
- }))
-}
-
-#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn add<DB: Database>(
- UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
- Json(req): Json<Vec<AddHistoryRequest>>,
-) -> Result<(), ErrorResponseStatus<'static>> {
- let State(AppState { database, settings }) = state;
-
- debug!("request to add {} history items", req.len());
- counter!("atuin_history_uploaded").increment(req.len() as u64);
-
- let mut history: Vec<NewHistory> = req
- .into_iter()
- .map(|h| NewHistory {
- client_id: h.id,
- user_id: user.id,
- hostname: h.hostname,
- timestamp: h.timestamp,
- data: h.data,
- })
- .collect();
-
- history.retain(|h| {
- // keep if within limit, or limit is 0 (unlimited)
- let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0;
-
- // Don't return an error here. We want to insert as much of the
- // history list as we can, so log the error and continue going.
- if !keep {
- counter!("atuin_history_too_long").increment(1);
-
- tracing::warn!(
- "history too long, got length {}, max {}",
- h.data.len(),
- settings.max_history_length
- );
- }
-
- keep
- });
-
- if let Err(e) = database.add_history(&history).await {
- error!("failed to add history: {}", e);
-
- return Err(ErrorResponse::reply("failed to add history")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR));
- };
-
- Ok(())
-}
-
-#[derive(serde::Deserialize, Debug)]
-pub(crate) struct CalendarQuery {
- #[serde(default = "serde_calendar::zero")]
- year: i32,
- #[serde(default = "serde_calendar::one")]
- month: u8,
-
- #[serde(default = "serde_calendar::utc")]
- tz: UtcOffset,
-}
-
-mod serde_calendar {
- use time::UtcOffset;
-
- pub(crate) fn zero() -> i32 {
- 0
- }
-
- pub(crate) fn one() -> u8 {
- 1
- }
-
- pub(crate) fn utc() -> UtcOffset {
- UtcOffset::UTC
- }
-}
-
-#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn calendar<DB: Database>(
- Path(focus): Path<String>,
- Query(params): Query<CalendarQuery>,
- UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
-) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
- let focus = focus.as_str();
-
- let year = params.year;
- let month = Month::try_from(params.month).map_err(|e| ErrorResponseStatus {
- error: ErrorResponse {
- reason: e.to_string().into(),
- },
- status: StatusCode::BAD_REQUEST,
- })?;
-
- let period = match focus {
- "year" => TimePeriod::Year,
- "month" => TimePeriod::Month { year },
- "day" => TimePeriod::Day { year, month },
- _ => {
- return Err(ErrorResponse::reply("invalid focus: use year/month/day")
- .with_status(StatusCode::BAD_REQUEST));
- }
- };
-
- let db = &state.0.database;
- let focus = db.calendar(&user, period, params.tz).await.map_err(|_| {
- ErrorResponse::reply("failed to query calendar")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR)
- })?;
-
- Ok(Json(focus))
-}
diff --git a/crates/turtle/src/atuin_server/handlers/mod.rs b/crates/turtle/src/atuin_server/handlers/mod.rs
index 322324c4..3b935834 100644
--- a/crates/turtle/src/atuin_server/handlers/mod.rs
+++ b/crates/turtle/src/atuin_server/handlers/mod.rs
@@ -1,19 +1,16 @@
use crate::atuin_common::api::{ErrorResponse, IndexResponse};
-use crate::atuin_server_database::Database;
use axum::{Json, extract::State, http, response::IntoResponse};
use crate::atuin_server::router::AppState;
pub(crate) mod health;
-pub(crate) mod history;
pub(crate) mod record;
-pub(crate) mod status;
pub(crate) mod user;
pub(crate) mod v0;
const VERSION: &str = env!("CARGO_PKG_VERSION");
-pub(crate) async fn index<DB: Database>(state: State<AppState<DB>>) -> Json<IndexResponse> {
+pub(crate) async fn index(state: State<AppState>) -> Json<IndexResponse> {
let homage = r#""Through the fathomless deeps of space swims the star turtle Great A'Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld." -- Sir Terry Pratchett"#;
let version = state
diff --git a/crates/turtle/src/atuin_server/handlers/status.rs b/crates/turtle/src/atuin_server/handlers/status.rs
deleted file mode 100644
index 59be1e5c..00000000
--- a/crates/turtle/src/atuin_server/handlers/status.rs
+++ /dev/null
@@ -1,45 +0,0 @@
-use axum::{Json, extract::State, http::StatusCode};
-use tracing::instrument;
-
-use super::{ErrorResponse, ErrorResponseStatus, RespExt};
-use crate::atuin_server::router::{AppState, UserAuth};
-use crate::atuin_server_database::Database;
-
-use crate::atuin_common::api::*;
-
-const VERSION: &str = env!("CARGO_PKG_VERSION");
-
-#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn status<DB: Database>(
- UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
-) -> Result<Json<StatusResponse>, ErrorResponseStatus<'static>> {
- let db = &state.0.database;
-
- let deleted = db.deleted_history(&user).await.unwrap_or(vec![]);
-
- let count = match db.count_history_cached(&user).await {
- // By default read out the cached value
- Ok(count) => count,
-
- // If that fails, fallback on a full COUNT. Cache is built on a POST
- // only
- Err(_) => match db.count_history(&user).await {
- Ok(count) => count,
- Err(_) => {
- return Err(ErrorResponse::reply("failed to query history count")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR));
- }
- },
- };
-
- tracing::debug!(user = user.username, "requested sync status");
-
- Ok(Json(StatusResponse {
- count,
- deleted,
- username: user.username,
- version: VERSION.to_string(),
- page_size: state.settings.page_size,
- }))
-}
diff --git a/crates/turtle/src/atuin_server/handlers/user.rs b/crates/turtle/src/atuin_server/handlers/user.rs
index 7708d43e..28cebfab 100644
--- a/crates/turtle/src/atuin_server/handlers/user.rs
+++ b/crates/turtle/src/atuin_server/handlers/user.rs
@@ -16,14 +16,16 @@ use metrics::counter;
use rand::rngs::OsRng;
use tracing::{debug, error, info, instrument};
-use crate::atuin_common::tls::ensure_crypto_provider;
+use crate::{
+ atuin_common::tls::ensure_crypto_provider,
+ atuin_server::database::{
+ DbError,
+ models::{NewSession, NewUser},
+ },
+};
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::atuin_server::router::{AppState, UserAuth};
-use crate::atuin_server_database::{
- Database, DbError,
- models::{NewSession, NewUser},
-};
use reqwest::header::CONTENT_TYPE;
@@ -63,9 +65,9 @@ async fn send_register_hook(url: &str, username: String, registered: String) {
}
#[instrument(skip_all, fields(user.username = username.as_str()))]
-pub(crate) async fn get<DB: Database>(
+pub(crate) async fn get(
Path(username): Path<String>,
- state: State<AppState<DB>>,
+ state: State<AppState>,
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.database;
let user = match db.get_user(username.as_ref()).await {
@@ -87,8 +89,8 @@ pub(crate) async fn get<DB: Database>(
}
#[instrument(skip_all)]
-pub(crate) async fn register<DB: Database>(
- state: State<AppState<DB>>,
+pub(crate) async fn register(
+ state: State<AppState>,
Json(register): Json<RegisterRequest>,
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
if !state.settings.open_registration {
@@ -163,9 +165,9 @@ pub(crate) async fn register<DB: Database>(
}
#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn delete<DB: Database>(
+pub(crate) async fn delete(
UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
+ state: State<AppState>,
) -> Result<Json<DeleteUserResponse>, ErrorResponseStatus<'static>> {
debug!("request to delete user {}", user.id);
@@ -183,9 +185,9 @@ pub(crate) async fn delete<DB: Database>(
}
#[instrument(skip_all, fields(user.id = user.id, change_password))]
-pub(crate) async fn change_password<DB: Database>(
+pub(crate) async fn change_password(
UserAuth(mut user): UserAuth,
- state: State<AppState<DB>>,
+ state: State<AppState>,
Json(change_password): Json<ChangePasswordRequest>,
) -> Result<Json<ChangePasswordResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.database;
@@ -213,8 +215,8 @@ pub(crate) async fn change_password<DB: Database>(
}
#[instrument(skip_all, fields(user.username = login.username.as_str()))]
-pub(crate) async fn login<DB: Database>(
- state: State<AppState<DB>>,
+pub(crate) async fn login(
+ state: State<AppState>,
login: Json<LoginRequest>,
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.database;
diff --git a/crates/turtle/src/atuin_server/handlers/v0/record.rs b/crates/turtle/src/atuin_server/handlers/v0/record.rs
index 2cc09118..88027547 100644
--- a/crates/turtle/src/atuin_server/handlers/v0/record.rs
+++ b/crates/turtle/src/atuin_server/handlers/v0/record.rs
@@ -7,14 +7,13 @@ use crate::atuin_server::{
handlers::{ErrorResponse, ErrorResponseStatus, RespExt},
router::{AppState, UserAuth},
};
-use crate::atuin_server_database::Database;
use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn post<DB: Database>(
+pub(crate) async fn post(
UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
+ state: State<AppState>,
Json(records): Json<Vec<Record<EncryptedData>>>,
) -> Result<(), ErrorResponseStatus<'static>> {
let State(AppState { database, settings }) = state;
@@ -51,9 +50,9 @@ pub(crate) async fn post<DB: Database>(
}
#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn index<DB: Database>(
+pub(crate) async fn index(
UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
+ state: State<AppState>,
) -> Result<Json<RecordStatus>, ErrorResponseStatus<'static>> {
let State(AppState {
database,
@@ -84,10 +83,10 @@ pub(crate) struct NextParams {
}
#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn next<DB: Database>(
+pub(crate) async fn next(
params: Query<NextParams>,
UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
+ state: State<AppState>,
) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> {
let State(AppState {
database,
diff --git a/crates/turtle/src/atuin_server/handlers/v0/store.rs b/crates/turtle/src/atuin_server/handlers/v0/store.rs
index 8269d6b3..f0aa1b36 100644
--- a/crates/turtle/src/atuin_server/handlers/v0/store.rs
+++ b/crates/turtle/src/atuin_server/handlers/v0/store.rs
@@ -7,16 +7,15 @@ use crate::atuin_server::{
handlers::{ErrorResponse, ErrorResponseStatus, RespExt},
router::{AppState, UserAuth},
};
-use crate::atuin_server_database::Database;
#[derive(Deserialize)]
pub(crate) struct DeleteParams {}
#[instrument(skip_all, fields(user.id = user.id))]
-pub(crate) async fn delete<DB: Database>(
+pub(crate) async fn delete(
_params: Query<DeleteParams>,
UserAuth(user): UserAuth,
- state: State<AppState<DB>>,
+ state: State<AppState>,
) -> Result<(), ErrorResponseStatus<'static>> {
let State(AppState {
database,
diff --git a/crates/turtle/src/atuin_server/mod.rs b/crates/turtle/src/atuin_server/mod.rs
index ad480e1d..c96a13bc 100644
--- a/crates/turtle/src/atuin_server/mod.rs
+++ b/crates/turtle/src/atuin_server/mod.rs
@@ -1,14 +1,14 @@
use std::future::Future;
use std::net::SocketAddr;
-use crate::atuin_server_database::Database;
use axum::{Router, serve};
+use database::db::Database;
use eyre::{Context, Result};
+pub(crate) mod database;
mod handlers;
mod metrics;
mod router;
-mod utils;
pub(crate) use settings::Settings;
@@ -31,8 +31,8 @@ async fn shutdown_signal() {
eprintln!("Shutting down gracefully...");
}
-pub(crate) async fn launch<Db: Database>(settings: Settings, addr: SocketAddr) -> Result<()> {
- launch_with_tcp_listener::<Db>(
+pub(crate) async fn launch(settings: Settings, addr: SocketAddr) -> Result<()> {
+ launch_with_tcp_listener(
settings,
TcpListener::bind(addr)
.await
@@ -42,12 +42,12 @@ pub(crate) async fn launch<Db: Database>(settings: Settings, addr: SocketAddr) -
.await
}
-pub(crate) async fn launch_with_tcp_listener<Db: Database>(
+pub(crate) async fn launch_with_tcp_listener(
settings: Settings,
listener: TcpListener,
shutdown: impl Future<Output = ()> + Send + 'static,
) -> Result<()> {
- let r = make_router::<Db>(settings).await?;
+ let r = make_router(settings).await?;
serve(listener, r.into_make_service())
.with_graceful_shutdown(shutdown)
@@ -77,8 +77,8 @@ pub(crate) async fn launch_metrics_server(host: String, port: u16) -> Result<()>
Ok(())
}
-async fn make_router<Db: Database>(settings: Settings) -> Result<Router, eyre::Error> {
- let db = Db::new(&settings.db_settings)
+async fn make_router(settings: Settings) -> Result<Router, eyre::Error> {
+ let db = Database::new(&settings.db_settings)
.await
.wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
let r = router::router(db, settings);
diff --git a/crates/turtle/src/atuin_server/router.rs b/crates/turtle/src/atuin_server/router.rs
index ed3d1e55..778e699a 100644
--- a/crates/turtle/src/atuin_server/router.rs
+++ b/crates/turtle/src/atuin_server/router.rs
@@ -1,4 +1,7 @@
-use crate::atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse};
+use crate::{
+ atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse},
+ atuin_server::database::{DbError, db::Database, models::User},
+};
use axum::{
Router,
extract::{FromRequestParts, Request},
@@ -17,19 +20,15 @@ use crate::atuin_server::{
metrics,
settings::Settings,
};
-use crate::atuin_server_database::{Database, DbError, models::User};
pub(crate) struct UserAuth(pub(crate) User);
-impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for UserAuth
-where
- DB: Database,
-{
+impl FromRequestParts<AppState> for UserAuth {
type Rejection = ErrorResponseStatus<'static>;
async fn from_request_parts(
req: &mut Parts,
- state: &AppState<DB>,
+ state: &AppState,
) -> Result<Self, Self::Rejection> {
let auth_header = req
.headers
@@ -78,18 +77,6 @@ async fn teapot() -> impl IntoResponse {
(http::StatusCode::NOT_FOUND, "404 not found")
}
-async fn clacks_overhead(request: Request, next: Next) -> Response {
- let mut response = next.run(request).await;
-
- let gnu_terry_value = "GNU Terry Pratchett, Kris Nova";
- let gnu_terry_header = "X-Clacks-Overhead";
-
- response
- .headers_mut()
- .insert(gnu_terry_header, gnu_terry_value.parse().unwrap());
- response
-}
-
/// Ensure that we only try and sync with clients on the same major version
async fn semver(request: Request, next: Next) -> Response {
let mut response = next.run(request).await;
@@ -101,27 +88,16 @@ async fn semver(request: Request, next: Next) -> Response {
}
#[derive(Clone)]
-pub(crate) struct AppState<DB: Database> {
- pub(crate) database: DB,
+pub(crate) struct AppState {
+ pub(crate) database: Database,
pub(crate) settings: Settings,
}
-pub(crate) fn router<DB: Database>(database: DB, settings: Settings) -> Router {
- let mut routes = Router::new()
+pub(crate) fn router(database: Database, settings: Settings) -> Router {
+ let routes = Router::new()
.route("/", get(handlers::index))
.route("/healthz", get(handlers::health::health_check));
- // Sync v1 routes - can be disabled in favor of record-based sync
- if settings.sync_v1_enabled {
- routes = routes
- .route("/sync/count", get(handlers::history::count))
- .route("/sync/history", get(handlers::history::list))
- .route("/sync/calendar/{focus}", get(handlers::history::calendar))
- .route("/sync/status", get(handlers::status::status))
- .route("/history", post(handlers::history::add))
- .route("/history", delete(handlers::history::delete));
- }
-
let routes = routes
.route("/user/{username}", get(handlers::user::get))
.route("/account", delete(handlers::user::delete))
@@ -147,7 +123,6 @@ pub(crate) fn router<DB: Database>(database: DB, settings: Settings) -> Router {
.with_state(AppState { database, settings })
.layer(
ServiceBuilder::new()
- .layer(axum::middleware::from_fn(clacks_overhead))
.layer(TraceLayer::new_for_http())
.layer(axum::middleware::from_fn(metrics::track_metrics))
.layer(axum::middleware::from_fn(semver)),
diff --git a/crates/turtle/src/atuin_server/settings.rs b/crates/turtle/src/atuin_server/settings.rs
index 1d0ac2d0..b62f24e1 100644
--- a/crates/turtle/src/atuin_server/settings.rs
+++ b/crates/turtle/src/atuin_server/settings.rs
@@ -1,11 +1,12 @@
use std::{io::prelude::*, path::PathBuf};
-use crate::atuin_server_database::DbSettings;
use config::{Config, Environment, File as ConfigFile, FileFormat};
use eyre::{Result, eyre};
use fs_err::{File, create_dir_all};
use serde::{Deserialize, Serialize};
+use crate::atuin_server::database::DbSettings;
+
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct Metrics {
#[serde(alias = "enabled")]
@@ -37,10 +38,6 @@ pub(crate) struct Settings {
pub(crate) register_webhook_username: String,
pub(crate) metrics: Metrics,
- /// Enable legacy sync v1 routes (history-based sync)
- /// Set to false to use only the newer record-based sync
- pub(crate) sync_v1_enabled: bool,
-
/// Advertise a version that is not what we are _actually_ running
/// Many clients compare their version with api.atuin.sh, and if they differ, notify the user
/// that an update is available.
@@ -78,7 +75,6 @@ impl Settings {
.set_default("metrics.enable", false)?
.set_default("metrics.host", "127.0.0.1")?
.set_default("metrics.port", 9001)?
- .set_default("sync_v1_enabled", true)?
.add_source(
Environment::with_prefix("atuin")
.prefix_separator("_")
diff --git a/crates/turtle/src/atuin_server/utils.rs b/crates/turtle/src/atuin_server/utils.rs
deleted file mode 100644
index cceef3ed..00000000
--- a/crates/turtle/src/atuin_server/utils.rs
+++ /dev/null
@@ -1,15 +0,0 @@
-use eyre::Result;
-use semver::{Version, VersionReq};
-
-pub(crate) fn client_version_min(user_agent: &str, req: &str) -> Result<bool> {
- if user_agent.is_empty() {
- return Ok(false);
- }
-
- let version = user_agent.replace("atuin/", "");
-
- let req = VersionReq::parse(req)?;
- let version = Version::parse(version.as_str())?;
-
- Ok(req.matches(&version))
-}