diff options
| -rw-r--r-- | crates/atuin-server-database/src/lib.rs | 21 | ||||
| -rw-r--r-- | crates/atuin-server-postgres/src/lib.rs | 85 | ||||
| -rw-r--r-- | crates/atuin-server/server.toml | 4 | ||||
| -rw-r--r-- | crates/atuin/tests/common/mod.rs | 5 |
4 files changed, 86 insertions, 29 deletions
diff --git a/crates/atuin-server-database/src/lib.rs b/crates/atuin-server-database/src/lib.rs index e70c755c..db170b50 100644 --- a/crates/atuin-server-database/src/lib.rs +++ b/crates/atuin-server-database/src/lib.rs @@ -51,6 +51,8 @@ pub enum DbType { #[derive(Clone, Deserialize, Serialize)] pub struct DbSettings { pub db_uri: String, + /// Optional URI for read replicas. If set, read-only queries will use this connection. + pub read_db_uri: Option<String>, } impl DbSettings { @@ -65,22 +67,29 @@ impl DbSettings { } } +fn redact_db_uri(uri: &str) -> String { + url::Url::parse(uri) + .map(|mut url| { + let _ = url.set_password(Some("****")); + url.to_string() + }) + .unwrap_or_else(|_| uri.to_string()) +} + // Do our best to redact passwords so they're not logged in the event of an error. impl Debug for DbSettings { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if self.db_type() == DbType::Postgres { - let redacted_uri = url::Url::parse(&self.db_uri) - .map(|mut url| { - let _ = url.set_password(Some("****")); - url.to_string() - }) - .unwrap_or(self.db_uri.clone()); + let redacted_uri = redact_db_uri(&self.db_uri); + let redacted_read_uri = self.read_db_uri.as_ref().map(|uri| redact_db_uri(uri)); f.debug_struct("DbSettings") .field("db_uri", &redacted_uri) + .field("read_db_uri", &redacted_read_uri) .finish() } else { f.debug_struct("DbSettings") .field("db_uri", &self.db_uri) + .field("read_db_uri", &self.read_db_uri) .finish() } } diff --git a/crates/atuin-server-postgres/src/lib.rs b/crates/atuin-server-postgres/src/lib.rs index 39c25256..8c40e6cc 100644 --- a/crates/atuin-server-postgres/src/lib.rs +++ b/crates/atuin-server-postgres/src/lib.rs @@ -24,6 +24,16 @@ const MIN_PG_VERSION: u32 = 14; #[derive(Clone)] pub struct Postgres { pool: sqlx::Pool<sqlx::postgres::Postgres>, + /// Optional read replica pool for read-only queries + read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>, +} + +impl Postgres { + /// Returns the appropriate pool for read operations. + /// Uses read_pool if available, otherwise falls back to the primary pool. + fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> { + self.read_pool.as_ref().unwrap_or(&self.pool) + } } fn fix_error(error: sqlx::Error) -> DbError { @@ -65,14 +75,45 @@ impl Database for Postgres { .await .map_err(|error| DbError::Other(error.into()))?; - Ok(Self { pool }) + // Create read replica pool if configured + let read_pool = if let Some(read_db_uri) = &settings.read_db_uri { + tracing::info!("Connecting to read replica database"); + let read_pool = PgPoolOptions::new() + .max_connections(100) + .connect(read_db_uri.as_str()) + .await + .map_err(fix_error)?; + + // Verify the read replica is also a supported PostgreSQL version + let read_pg_major_version: u32 = read_pool + .acquire() + .await + .map_err(fix_error)? + .server_version_num() + .ok_or(DbError::Other(eyre::Report::msg( + "could not get PostgreSQL version from read replica", + )))? + / 10000; + + if read_pg_major_version < MIN_PG_VERSION { + return Err(DbError::Other(eyre::Report::msg(format!( + "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}" + )))); + } + + Some(read_pool) + } else { + None + }; + + Ok(Self { pool, read_pool }) } #[instrument(skip_all)] async fn get_session(&self, token: &str) -> DbResult<Session> { sqlx::query_as("select id, user_id, token from sessions where token = $1") .bind(token) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error) .map(|DbSession(session)| session) @@ -84,7 +125,7 @@ impl Database for Postgres { "select id, username, email, password, verified_at from users where username = $1", ) .bind(username) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error) .map(|DbUser(user)| user) @@ -95,7 +136,7 @@ impl Database for Postgres { let res: (bool,) = sqlx::query_as("select verified_at is not null from users where id = $1") .bind(id) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error)?; @@ -173,13 +214,13 @@ impl Database for Postgres { #[instrument(skip_all)] async fn get_session_user(&self, token: &str) -> DbResult<User> { sqlx::query_as( - "select users.id, users.username, users.email, users.password, users.verified_at from users - inner join sessions - on users.id = sessions.user_id + "select users.id, users.username, users.email, users.password, users.verified_at from users + inner join sessions + on users.id = sessions.user_id and sessions.token = $1", ) .bind(token) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error) .map(|DbUser(user)| user) @@ -196,7 +237,7 @@ impl Database for Postgres { where user_id = $1", ) .bind(user.id) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error)?; @@ -210,7 +251,7 @@ impl Database for Postgres { // edge case. let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user") - .fetch_optional(&self.pool) + .fetch_optional(self.read_pool()) .await .map_err(fix_error)? .unwrap_or((0,)); @@ -225,7 +266,7 @@ impl Database for Postgres { where user_id = $1", ) .bind(user.id) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error)?; @@ -283,12 +324,12 @@ impl Database for Postgres { // edge case. let res = sqlx::query( - "select client_id from history + "select client_id from history where user_id = $1 and deleted_at is not null", ) .bind(user.id) - .fetch_all(&self.pool) + .fetch_all(self.read_pool()) .await .map_err(fix_error)?; @@ -315,7 +356,7 @@ impl Database for Postgres { .bind(user.id) .bind(into_utc(range.start)) .bind(into_utc(range.end)) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error)?; @@ -332,7 +373,7 @@ impl Database for Postgres { page_size: i64, ) -> DbResult<Vec<History>> { let res = sqlx::query_as( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history + "select id, client_id, user_id, hostname, timestamp, data, created_at from history where user_id = $1 and hostname != $2 and created_at >= $3 @@ -345,7 +386,7 @@ impl Database for Postgres { .bind(into_utc(created_after)) .bind(into_utc(since)) .bind(page_size) - .fetch(&self.pool) + .fetch(self.read_pool()) .map_ok(|DbHistory(h)| h) .try_collect() .await @@ -486,7 +527,7 @@ impl Database for Postgres { async fn get_user_session(&self, u: &User) -> DbResult<Session> { sqlx::query_as("select id, user_id, token from sessions where user_id = $1") .bind(u.id) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error) .map(|DbSession(session)| session) @@ -495,13 +536,13 @@ impl Database for Postgres { #[instrument(skip_all)] async fn oldest_history(&self, user: &User) -> DbResult<History> { sqlx::query_as( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history + "select id, client_id, user_id, hostname, timestamp, data, created_at from history where user_id = $1 order by timestamp asc limit 1", ) .bind(user.id) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error) .map(|DbHistory(h)| h) @@ -606,7 +647,7 @@ impl Database for Postgres { .bind(host) .bind(start as i64) .bind(count as i64) - .fetch_all(&self.pool) + .fetch_all(self.read_pool()) .await .map_err(fix_error); @@ -650,14 +691,14 @@ impl Database for Postgres { tracing::debug!("using idx cache for user {}", user.id); sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1") .bind(user.id) - .fetch_all(&self.pool) + .fetch_all(self.read_pool()) .await .map_err(fix_error)? } else { tracing::debug!("using aggregate query for user {}", user.id); sqlx::query_as(STATUS_SQL) .bind(user.id) - .fetch_all(&self.pool) + .fetch_all(self.read_pool()) .await .map_err(fix_error)? }; diff --git a/crates/atuin-server/server.toml b/crates/atuin-server/server.toml index 1eff5b72..6212de00 100644 --- a/crates/atuin-server/server.toml +++ b/crates/atuin-server/server.toml @@ -11,6 +11,10 @@ # db_uri="postgres://username:password@localhost/atuin" # db_uri="sqlite:///config/atuin-server.db" +## Optional: URI for read replica database +## If set, read-only queries will be routed to this database +# read_db_uri="postgres://username:password@localhost-replica/atuin" + ## Maximum size for one history entry # max_history_length = 8192 diff --git a/crates/atuin/tests/common/mod.rs b/crates/atuin/tests/common/mod.rs index d79c13d6..bf8f85a7 100644 --- a/crates/atuin/tests/common/mod.rs +++ b/crates/atuin/tests/common/mod.rs @@ -36,7 +36,10 @@ pub async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandl page_size: 1100, register_webhook_url: None, register_webhook_username: String::new(), - db_settings: DbSettings { db_uri }, + db_settings: DbSettings { + db_uri: db_uri, + read_db_uri: None, + }, metrics: atuin_server::settings::Metrics::default(), tls: atuin_server::settings::Tls::default(), mail: atuin_server::settings::Mail::default(), |
