diff options
| author | Ellie Huxtable <ellie@atuin.sh> | 2025-12-18 16:12:39 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-12-18 16:12:39 -0500 |
| commit | 96e6bb23472735d4d1dec299d41e19a38f63adbd (patch) | |
| tree | 4a18961d4f7ef39f2d4c1a88d0aad9e7db582719 /crates/atuin-server-postgres | |
| parent | fix: Move thorough search through search.filters w/ workspaces (#2703) (diff) | |
| download | atuin-96e6bb23472735d4d1dec299d41e19a38f63adbd.zip | |
feat: add support for read replicas to postgres (#3029)
Support for routing read queries to read replicas for Postgres
We have very high database usage these days, and now run shell history
sync off of [Planetscale](https://planetscale.com/)
This setup gives us 2x read replicas, meaning we can reduce load on the
primary
I doubt this is required for anyone else's setup - lmk if so.
Diffstat (limited to 'crates/atuin-server-postgres')
| -rw-r--r-- | crates/atuin-server-postgres/src/lib.rs | 85 |
1 files changed, 63 insertions, 22 deletions
diff --git a/crates/atuin-server-postgres/src/lib.rs b/crates/atuin-server-postgres/src/lib.rs index 39c25256..8c40e6cc 100644 --- a/crates/atuin-server-postgres/src/lib.rs +++ b/crates/atuin-server-postgres/src/lib.rs @@ -24,6 +24,16 @@ const MIN_PG_VERSION: u32 = 14; #[derive(Clone)] pub struct Postgres { pool: sqlx::Pool<sqlx::postgres::Postgres>, + /// Optional read replica pool for read-only queries + read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>, +} + +impl Postgres { + /// Returns the appropriate pool for read operations. + /// Uses read_pool if available, otherwise falls back to the primary pool. + fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> { + self.read_pool.as_ref().unwrap_or(&self.pool) + } } fn fix_error(error: sqlx::Error) -> DbError { @@ -65,14 +75,45 @@ impl Database for Postgres { .await .map_err(|error| DbError::Other(error.into()))?; - Ok(Self { pool }) + // Create read replica pool if configured + let read_pool = if let Some(read_db_uri) = &settings.read_db_uri { + tracing::info!("Connecting to read replica database"); + let read_pool = PgPoolOptions::new() + .max_connections(100) + .connect(read_db_uri.as_str()) + .await + .map_err(fix_error)?; + + // Verify the read replica is also a supported PostgreSQL version + let read_pg_major_version: u32 = read_pool + .acquire() + .await + .map_err(fix_error)? + .server_version_num() + .ok_or(DbError::Other(eyre::Report::msg( + "could not get PostgreSQL version from read replica", + )))? + / 10000; + + if read_pg_major_version < MIN_PG_VERSION { + return Err(DbError::Other(eyre::Report::msg(format!( + "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}" + )))); + } + + Some(read_pool) + } else { + None + }; + + Ok(Self { pool, read_pool }) } #[instrument(skip_all)] async fn get_session(&self, token: &str) -> DbResult<Session> { sqlx::query_as("select id, user_id, token from sessions where token = $1") .bind(token) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error) .map(|DbSession(session)| session) @@ -84,7 +125,7 @@ impl Database for Postgres { "select id, username, email, password, verified_at from users where username = $1", ) .bind(username) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error) .map(|DbUser(user)| user) @@ -95,7 +136,7 @@ impl Database for Postgres { let res: (bool,) = sqlx::query_as("select verified_at is not null from users where id = $1") .bind(id) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error)?; @@ -173,13 +214,13 @@ impl Database for Postgres { #[instrument(skip_all)] async fn get_session_user(&self, token: &str) -> DbResult<User> { sqlx::query_as( - "select users.id, users.username, users.email, users.password, users.verified_at from users - inner join sessions - on users.id = sessions.user_id + "select users.id, users.username, users.email, users.password, users.verified_at from users + inner join sessions + on users.id = sessions.user_id and sessions.token = $1", ) .bind(token) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error) .map(|DbUser(user)| user) @@ -196,7 +237,7 @@ impl Database for Postgres { where user_id = $1", ) .bind(user.id) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error)?; @@ -210,7 +251,7 @@ impl Database for Postgres { // edge case. let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user") - .fetch_optional(&self.pool) + .fetch_optional(self.read_pool()) .await .map_err(fix_error)? .unwrap_or((0,)); @@ -225,7 +266,7 @@ impl Database for Postgres { where user_id = $1", ) .bind(user.id) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error)?; @@ -283,12 +324,12 @@ impl Database for Postgres { // edge case. let res = sqlx::query( - "select client_id from history + "select client_id from history where user_id = $1 and deleted_at is not null", ) .bind(user.id) - .fetch_all(&self.pool) + .fetch_all(self.read_pool()) .await .map_err(fix_error)?; @@ -315,7 +356,7 @@ impl Database for Postgres { .bind(user.id) .bind(into_utc(range.start)) .bind(into_utc(range.end)) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error)?; @@ -332,7 +373,7 @@ impl Database for Postgres { page_size: i64, ) -> DbResult<Vec<History>> { let res = sqlx::query_as( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history + "select id, client_id, user_id, hostname, timestamp, data, created_at from history where user_id = $1 and hostname != $2 and created_at >= $3 @@ -345,7 +386,7 @@ impl Database for Postgres { .bind(into_utc(created_after)) .bind(into_utc(since)) .bind(page_size) - .fetch(&self.pool) + .fetch(self.read_pool()) .map_ok(|DbHistory(h)| h) .try_collect() .await @@ -486,7 +527,7 @@ impl Database for Postgres { async fn get_user_session(&self, u: &User) -> DbResult<Session> { sqlx::query_as("select id, user_id, token from sessions where user_id = $1") .bind(u.id) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error) .map(|DbSession(session)| session) @@ -495,13 +536,13 @@ impl Database for Postgres { #[instrument(skip_all)] async fn oldest_history(&self, user: &User) -> DbResult<History> { sqlx::query_as( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history + "select id, client_id, user_id, hostname, timestamp, data, created_at from history where user_id = $1 order by timestamp asc limit 1", ) .bind(user.id) - .fetch_one(&self.pool) + .fetch_one(self.read_pool()) .await .map_err(fix_error) .map(|DbHistory(h)| h) @@ -606,7 +647,7 @@ impl Database for Postgres { .bind(host) .bind(start as i64) .bind(count as i64) - .fetch_all(&self.pool) + .fetch_all(self.read_pool()) .await .map_err(fix_error); @@ -650,14 +691,14 @@ impl Database for Postgres { tracing::debug!("using idx cache for user {}", user.id); sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1") .bind(user.id) - .fetch_all(&self.pool) + .fetch_all(self.read_pool()) .await .map_err(fix_error)? } else { tracing::debug!("using aggregate query for user {}", user.id); sqlx::query_as(STATUS_SQL) .bind(user.id) - .fetch_all(&self.pool) + .fetch_all(self.read_pool()) .await .map_err(fix_error)? }; |
