use std::collections::HashMap; use rand::Rng; use crate::{ atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}, atuin_server::database::{DbError, DbResult, DbSettings, models::User}, }; use sqlx::postgres::PgPoolOptions; use tracing::instrument; use uuid::Uuid; use wrappers::DbRecord; mod wrappers; const MIN_PG_VERSION: u32 = 14; #[derive(Clone)] pub(crate) struct ServerPostgres { pool: sqlx::Pool, /// Optional read replica pool for read-only queries read_pool: Option>, } impl ServerPostgres { /// Returns the appropriate pool for read operations. /// Uses `read_pool` if available, otherwise falls back to the primary pool. fn read_pool(&self) -> &sqlx::Pool { self.read_pool.as_ref().unwrap_or(&self.pool) } } impl ServerPostgres { pub(crate) async fn new(settings: &DbSettings) -> DbResult { let pool = PgPoolOptions::new() .max_connections(100) .connect(settings.db_uri.as_str()) .await?; // Call server_version_num to get the DB server's major version number // The call returns None for servers older than 8.x. let pg_major_version: u32 = pool.acquire() .await? .server_version_num() .ok_or(DbError::Other(eyre::Report::msg( "could not get PostgreSQL version", )))? / 10000; if pg_major_version < MIN_PG_VERSION { return Err(DbError::Other(eyre::Report::msg(format!( "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}" )))); } sqlx::migrate!("./db/server-pg-migrations") .run(&pool) .await .map_err(|error| DbError::Other(error.into()))?; // Create read replica pool if configured let read_pool = if let Some(read_db_uri) = &settings.read_db_uri { tracing::info!("Connecting to read replica database"); let read_pool = PgPoolOptions::new() .max_connections(100) .connect(read_db_uri.as_str()) .await?; // Verify the read replica is also a supported PostgreSQL version let read_pg_major_version: u32 = read_pool .acquire() .await? .server_version_num() .ok_or(DbError::Other(eyre::Report::msg( "could not get PostgreSQL version from read replica", )))? / 10000; if read_pg_major_version < MIN_PG_VERSION { return Err(DbError::Other(eyre::Report::msg(format!( "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}" )))); } Some(read_pool) } else { None }; Ok(Self { pool, read_pool }) } #[instrument(skip_all)] pub(crate) async fn add_records( &self, user: &User, records: &[Record], ) -> DbResult<()> { let mut tx = self.pool.begin().await?; // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max // idx without having to make further database queries. Doing the query on this small // amount of data should be much, much faster. // // Worst case, say we get this wrong. We end up caching data that isn't actually the max // idx, so clients upload again. The cache logic can be verified with a sql query anyway :) let mut heads = HashMap::<(HostId, &str), u64>::new(); for i in records { let id = crate::atuin_common::utils::uuid_v7(); let result = sqlx::query( " INSERT INTO store (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) ON conflict DO nothing ", ) .bind(id) .bind(i.id) .bind(i.host.id) .bind(i.idx as i64) .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time .bind(&i.version) .bind(&i.tag) .bind(&i.data.data) .bind(&i.data.content_encryption_key) .bind(user.id) .execute(&mut *tx) .await?; // Only update heads if we actually inserted the record if result.rows_affected() > 0 { heads .entry((i.host.id, &i.tag)) .and_modify(|e| { if i.idx > *e { *e = i.idx; } }) .or_insert(i.idx); } } // we've built the map of heads for this push, so commit it to the database for ((host, tag), idx) in heads { sqlx::query( " INSERT INTO store_idx_cache (user_id, host, tag, idx) VALUES ($1, $2, $3, $4) ON conflict(user_id, host, tag) DO update SET idx = greatest(store_idx_cache.idx, $4) ", ) .bind(user.id) .bind(host) .bind(tag) .bind(idx as i64) .execute(&mut *tx) .await?; } tx.commit().await?; Ok(()) } #[instrument(skip_all)] pub(crate) async fn next_records( &self, user: &User, host: HostId, tag: String, start: Option, count: u64, ) -> DbResult>> { tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); let start = start.unwrap_or(0); let records: Result, DbError> = sqlx::query_as( " SELECT client_id, host, idx, timestamp, version, tag, data, cek FROM store WHERE user_id = $1 AND tag = $2 AND host = $3 AND idx >= $4 ORDER BY idx asc LIMIT $5 ", ) .bind(user.id) .bind(tag.clone()) .bind(host) .bind(start as i64) .bind(count as i64) .fetch_all(self.read_pool()) .await .map_err(Into::into); let ret = match records { Ok(records) => { let records: Vec> = records .into_iter() .map(|f| { let record: Record = f.into(); record }) .collect(); records } Err(DbError::NotFound) => { tracing::debug!("no records found in store: {:?}/{}", host, tag); return Ok(vec![]); } Err(e) => return Err(e), }; Ok(ret) } pub(crate) async fn status(&self, user: &User) -> DbResult { // If IDX_CACHE_ROLLOUT is set, then we // 1. Read the value of the var, use it as a % chance of using the cache // 2. If we use the cache, just read from the cache table // 3. If we don't use the cache, read from the store table // IDX_CACHE_ROLLOUT should be between 0 and 100. let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or_else(|_| "0".to_string()); let idx_cache_rollout = idx_cache_rollout.parse::().unwrap_or(0.0); let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0); let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache { tracing::debug!("using idx cache for user {}", user.id); sqlx::query_as( " SELECT host, tag, idx FROM store_idx_cache WHERE user_id = $1 ", ) .bind(user.id) .fetch_all(self.read_pool()) .await? } else { tracing::debug!("using aggregate query for user {}", user.id); sqlx::query_as( " SELECT host, tag, max(idx) FROM store WHERE user_id = $1 GROUP BY host, tag ", ) .bind(user.id) .fetch_all(self.read_pool()) .await? }; res.sort(); let mut status = RecordStatus::new(); for i in &res { status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64); } Ok(status) } }