aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-server-postgres
diff options
context:
space:
mode:
authorEllie Huxtable <ellie@atuin.sh>2025-12-18 16:12:39 -0500
committerGitHub <noreply@github.com>2025-12-18 16:12:39 -0500
commit96e6bb23472735d4d1dec299d41e19a38f63adbd (patch)
tree4a18961d4f7ef39f2d4c1a88d0aad9e7db582719 /crates/atuin-server-postgres
parentfix: Move thorough search through search.filters w/ workspaces (#2703) (diff)
downloadatuin-96e6bb23472735d4d1dec299d41e19a38f63adbd.zip
feat: add support for read replicas to postgres (#3029)
Support for routing read queries to read replicas for Postgres We have very high database usage these days, and now run shell history sync off of [Planetscale](https://planetscale.com/) This setup gives us 2x read replicas, meaning we can reduce load on the primary I doubt this is required for anyone else's setup - lmk if so.
Diffstat (limited to 'crates/atuin-server-postgres')
-rw-r--r--crates/atuin-server-postgres/src/lib.rs85
1 files changed, 63 insertions, 22 deletions
diff --git a/crates/atuin-server-postgres/src/lib.rs b/crates/atuin-server-postgres/src/lib.rs
index 39c25256..8c40e6cc 100644
--- a/crates/atuin-server-postgres/src/lib.rs
+++ b/crates/atuin-server-postgres/src/lib.rs
@@ -24,6 +24,16 @@ const MIN_PG_VERSION: u32 = 14;
#[derive(Clone)]
pub struct Postgres {
pool: sqlx::Pool<sqlx::postgres::Postgres>,
+ /// Optional read replica pool for read-only queries
+ read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>,
+}
+
+impl Postgres {
+ /// Returns the appropriate pool for read operations.
+ /// Uses read_pool if available, otherwise falls back to the primary pool.
+ fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> {
+ self.read_pool.as_ref().unwrap_or(&self.pool)
+ }
}
fn fix_error(error: sqlx::Error) -> DbError {
@@ -65,14 +75,45 @@ impl Database for Postgres {
.await
.map_err(|error| DbError::Other(error.into()))?;
- Ok(Self { pool })
+ // Create read replica pool if configured
+ let read_pool = if let Some(read_db_uri) = &settings.read_db_uri {
+ tracing::info!("Connecting to read replica database");
+ let read_pool = PgPoolOptions::new()
+ .max_connections(100)
+ .connect(read_db_uri.as_str())
+ .await
+ .map_err(fix_error)?;
+
+ // Verify the read replica is also a supported PostgreSQL version
+ let read_pg_major_version: u32 = read_pool
+ .acquire()
+ .await
+ .map_err(fix_error)?
+ .server_version_num()
+ .ok_or(DbError::Other(eyre::Report::msg(
+ "could not get PostgreSQL version from read replica",
+ )))?
+ / 10000;
+
+ if read_pg_major_version < MIN_PG_VERSION {
+ return Err(DbError::Other(eyre::Report::msg(format!(
+ "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}"
+ ))));
+ }
+
+ Some(read_pool)
+ } else {
+ None
+ };
+
+ Ok(Self { pool, read_pool })
}
#[instrument(skip_all)]
async fn get_session(&self, token: &str) -> DbResult<Session> {
sqlx::query_as("select id, user_id, token from sessions where token = $1")
.bind(token)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)
.map(|DbSession(session)| session)
@@ -84,7 +125,7 @@ impl Database for Postgres {
"select id, username, email, password, verified_at from users where username = $1",
)
.bind(username)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)
.map(|DbUser(user)| user)
@@ -95,7 +136,7 @@ impl Database for Postgres {
let res: (bool,) =
sqlx::query_as("select verified_at is not null from users where id = $1")
.bind(id)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)?;
@@ -173,13 +214,13 @@ impl Database for Postgres {
#[instrument(skip_all)]
async fn get_session_user(&self, token: &str) -> DbResult<User> {
sqlx::query_as(
- "select users.id, users.username, users.email, users.password, users.verified_at from users
- inner join sessions
- on users.id = sessions.user_id
+ "select users.id, users.username, users.email, users.password, users.verified_at from users
+ inner join sessions
+ on users.id = sessions.user_id
and sessions.token = $1",
)
.bind(token)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)
.map(|DbUser(user)| user)
@@ -196,7 +237,7 @@ impl Database for Postgres {
where user_id = $1",
)
.bind(user.id)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)?;
@@ -210,7 +251,7 @@ impl Database for Postgres {
// edge case.
let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user")
- .fetch_optional(&self.pool)
+ .fetch_optional(self.read_pool())
.await
.map_err(fix_error)?
.unwrap_or((0,));
@@ -225,7 +266,7 @@ impl Database for Postgres {
where user_id = $1",
)
.bind(user.id)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)?;
@@ -283,12 +324,12 @@ impl Database for Postgres {
// edge case.
let res = sqlx::query(
- "select client_id from history
+ "select client_id from history
where user_id = $1
and deleted_at is not null",
)
.bind(user.id)
- .fetch_all(&self.pool)
+ .fetch_all(self.read_pool())
.await
.map_err(fix_error)?;
@@ -315,7 +356,7 @@ impl Database for Postgres {
.bind(user.id)
.bind(into_utc(range.start))
.bind(into_utc(range.end))
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)?;
@@ -332,7 +373,7 @@ impl Database for Postgres {
page_size: i64,
) -> DbResult<Vec<History>> {
let res = sqlx::query_as(
- "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
where user_id = $1
and hostname != $2
and created_at >= $3
@@ -345,7 +386,7 @@ impl Database for Postgres {
.bind(into_utc(created_after))
.bind(into_utc(since))
.bind(page_size)
- .fetch(&self.pool)
+ .fetch(self.read_pool())
.map_ok(|DbHistory(h)| h)
.try_collect()
.await
@@ -486,7 +527,7 @@ impl Database for Postgres {
async fn get_user_session(&self, u: &User) -> DbResult<Session> {
sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
.bind(u.id)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)
.map(|DbSession(session)| session)
@@ -495,13 +536,13 @@ impl Database for Postgres {
#[instrument(skip_all)]
async fn oldest_history(&self, user: &User) -> DbResult<History> {
sqlx::query_as(
- "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
where user_id = $1
order by timestamp asc
limit 1",
)
.bind(user.id)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)
.map(|DbHistory(h)| h)
@@ -606,7 +647,7 @@ impl Database for Postgres {
.bind(host)
.bind(start as i64)
.bind(count as i64)
- .fetch_all(&self.pool)
+ .fetch_all(self.read_pool())
.await
.map_err(fix_error);
@@ -650,14 +691,14 @@ impl Database for Postgres {
tracing::debug!("using idx cache for user {}", user.id);
sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
.bind(user.id)
- .fetch_all(&self.pool)
+ .fetch_all(self.read_pool())
.await
.map_err(fix_error)?
} else {
tracing::debug!("using aggregate query for user {}", user.id);
sqlx::query_as(STATUS_SQL)
.bind(user.id)
- .fetch_all(&self.pool)
+ .fetch_all(self.read_pool())
.await
.map_err(fix_error)?
};