aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--crates/atuin-server-database/src/lib.rs21
-rw-r--r--crates/atuin-server-postgres/src/lib.rs85
-rw-r--r--crates/atuin-server/server.toml4
-rw-r--r--crates/atuin/tests/common/mod.rs5
4 files changed, 86 insertions, 29 deletions
diff --git a/crates/atuin-server-database/src/lib.rs b/crates/atuin-server-database/src/lib.rs
index e70c755c..db170b50 100644
--- a/crates/atuin-server-database/src/lib.rs
+++ b/crates/atuin-server-database/src/lib.rs
@@ -51,6 +51,8 @@ pub enum DbType {
#[derive(Clone, Deserialize, Serialize)]
pub struct DbSettings {
pub db_uri: String,
+ /// Optional URI for read replicas. If set, read-only queries will use this connection.
+ pub read_db_uri: Option<String>,
}
impl DbSettings {
@@ -65,22 +67,29 @@ impl DbSettings {
}
}
+fn redact_db_uri(uri: &str) -> String {
+ url::Url::parse(uri)
+ .map(|mut url| {
+ let _ = url.set_password(Some("****"));
+ url.to_string()
+ })
+ .unwrap_or_else(|_| uri.to_string())
+}
+
// Do our best to redact passwords so they're not logged in the event of an error.
impl Debug for DbSettings {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.db_type() == DbType::Postgres {
- let redacted_uri = url::Url::parse(&self.db_uri)
- .map(|mut url| {
- let _ = url.set_password(Some("****"));
- url.to_string()
- })
- .unwrap_or(self.db_uri.clone());
+ let redacted_uri = redact_db_uri(&self.db_uri);
+ let redacted_read_uri = self.read_db_uri.as_ref().map(|uri| redact_db_uri(uri));
f.debug_struct("DbSettings")
.field("db_uri", &redacted_uri)
+ .field("read_db_uri", &redacted_read_uri)
.finish()
} else {
f.debug_struct("DbSettings")
.field("db_uri", &self.db_uri)
+ .field("read_db_uri", &self.read_db_uri)
.finish()
}
}
diff --git a/crates/atuin-server-postgres/src/lib.rs b/crates/atuin-server-postgres/src/lib.rs
index 39c25256..8c40e6cc 100644
--- a/crates/atuin-server-postgres/src/lib.rs
+++ b/crates/atuin-server-postgres/src/lib.rs
@@ -24,6 +24,16 @@ const MIN_PG_VERSION: u32 = 14;
#[derive(Clone)]
pub struct Postgres {
pool: sqlx::Pool<sqlx::postgres::Postgres>,
+ /// Optional read replica pool for read-only queries
+ read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>,
+}
+
+impl Postgres {
+ /// Returns the appropriate pool for read operations.
+ /// Uses read_pool if available, otherwise falls back to the primary pool.
+ fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> {
+ self.read_pool.as_ref().unwrap_or(&self.pool)
+ }
}
fn fix_error(error: sqlx::Error) -> DbError {
@@ -65,14 +75,45 @@ impl Database for Postgres {
.await
.map_err(|error| DbError::Other(error.into()))?;
- Ok(Self { pool })
+ // Create read replica pool if configured
+ let read_pool = if let Some(read_db_uri) = &settings.read_db_uri {
+ tracing::info!("Connecting to read replica database");
+ let read_pool = PgPoolOptions::new()
+ .max_connections(100)
+ .connect(read_db_uri.as_str())
+ .await
+ .map_err(fix_error)?;
+
+ // Verify the read replica is also a supported PostgreSQL version
+ let read_pg_major_version: u32 = read_pool
+ .acquire()
+ .await
+ .map_err(fix_error)?
+ .server_version_num()
+ .ok_or(DbError::Other(eyre::Report::msg(
+ "could not get PostgreSQL version from read replica",
+ )))?
+ / 10000;
+
+ if read_pg_major_version < MIN_PG_VERSION {
+ return Err(DbError::Other(eyre::Report::msg(format!(
+ "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}"
+ ))));
+ }
+
+ Some(read_pool)
+ } else {
+ None
+ };
+
+ Ok(Self { pool, read_pool })
}
#[instrument(skip_all)]
async fn get_session(&self, token: &str) -> DbResult<Session> {
sqlx::query_as("select id, user_id, token from sessions where token = $1")
.bind(token)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)
.map(|DbSession(session)| session)
@@ -84,7 +125,7 @@ impl Database for Postgres {
"select id, username, email, password, verified_at from users where username = $1",
)
.bind(username)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)
.map(|DbUser(user)| user)
@@ -95,7 +136,7 @@ impl Database for Postgres {
let res: (bool,) =
sqlx::query_as("select verified_at is not null from users where id = $1")
.bind(id)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)?;
@@ -173,13 +214,13 @@ impl Database for Postgres {
#[instrument(skip_all)]
async fn get_session_user(&self, token: &str) -> DbResult<User> {
sqlx::query_as(
- "select users.id, users.username, users.email, users.password, users.verified_at from users
- inner join sessions
- on users.id = sessions.user_id
+ "select users.id, users.username, users.email, users.password, users.verified_at from users
+ inner join sessions
+ on users.id = sessions.user_id
and sessions.token = $1",
)
.bind(token)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)
.map(|DbUser(user)| user)
@@ -196,7 +237,7 @@ impl Database for Postgres {
where user_id = $1",
)
.bind(user.id)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)?;
@@ -210,7 +251,7 @@ impl Database for Postgres {
// edge case.
let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user")
- .fetch_optional(&self.pool)
+ .fetch_optional(self.read_pool())
.await
.map_err(fix_error)?
.unwrap_or((0,));
@@ -225,7 +266,7 @@ impl Database for Postgres {
where user_id = $1",
)
.bind(user.id)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)?;
@@ -283,12 +324,12 @@ impl Database for Postgres {
// edge case.
let res = sqlx::query(
- "select client_id from history
+ "select client_id from history
where user_id = $1
and deleted_at is not null",
)
.bind(user.id)
- .fetch_all(&self.pool)
+ .fetch_all(self.read_pool())
.await
.map_err(fix_error)?;
@@ -315,7 +356,7 @@ impl Database for Postgres {
.bind(user.id)
.bind(into_utc(range.start))
.bind(into_utc(range.end))
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)?;
@@ -332,7 +373,7 @@ impl Database for Postgres {
page_size: i64,
) -> DbResult<Vec<History>> {
let res = sqlx::query_as(
- "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
where user_id = $1
and hostname != $2
and created_at >= $3
@@ -345,7 +386,7 @@ impl Database for Postgres {
.bind(into_utc(created_after))
.bind(into_utc(since))
.bind(page_size)
- .fetch(&self.pool)
+ .fetch(self.read_pool())
.map_ok(|DbHistory(h)| h)
.try_collect()
.await
@@ -486,7 +527,7 @@ impl Database for Postgres {
async fn get_user_session(&self, u: &User) -> DbResult<Session> {
sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
.bind(u.id)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)
.map(|DbSession(session)| session)
@@ -495,13 +536,13 @@ impl Database for Postgres {
#[instrument(skip_all)]
async fn oldest_history(&self, user: &User) -> DbResult<History> {
sqlx::query_as(
- "select id, client_id, user_id, hostname, timestamp, data, created_at from history
+ "select id, client_id, user_id, hostname, timestamp, data, created_at from history
where user_id = $1
order by timestamp asc
limit 1",
)
.bind(user.id)
- .fetch_one(&self.pool)
+ .fetch_one(self.read_pool())
.await
.map_err(fix_error)
.map(|DbHistory(h)| h)
@@ -606,7 +647,7 @@ impl Database for Postgres {
.bind(host)
.bind(start as i64)
.bind(count as i64)
- .fetch_all(&self.pool)
+ .fetch_all(self.read_pool())
.await
.map_err(fix_error);
@@ -650,14 +691,14 @@ impl Database for Postgres {
tracing::debug!("using idx cache for user {}", user.id);
sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
.bind(user.id)
- .fetch_all(&self.pool)
+ .fetch_all(self.read_pool())
.await
.map_err(fix_error)?
} else {
tracing::debug!("using aggregate query for user {}", user.id);
sqlx::query_as(STATUS_SQL)
.bind(user.id)
- .fetch_all(&self.pool)
+ .fetch_all(self.read_pool())
.await
.map_err(fix_error)?
};
diff --git a/crates/atuin-server/server.toml b/crates/atuin-server/server.toml
index 1eff5b72..6212de00 100644
--- a/crates/atuin-server/server.toml
+++ b/crates/atuin-server/server.toml
@@ -11,6 +11,10 @@
# db_uri="postgres://username:password@localhost/atuin"
# db_uri="sqlite:///config/atuin-server.db"
+## Optional: URI for read replica database
+## If set, read-only queries will be routed to this database
+# read_db_uri="postgres://username:password@localhost-replica/atuin"
+
## Maximum size for one history entry
# max_history_length = 8192
diff --git a/crates/atuin/tests/common/mod.rs b/crates/atuin/tests/common/mod.rs
index d79c13d6..bf8f85a7 100644
--- a/crates/atuin/tests/common/mod.rs
+++ b/crates/atuin/tests/common/mod.rs
@@ -36,7 +36,10 @@ pub async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandl
page_size: 1100,
register_webhook_url: None,
register_webhook_username: String::new(),
- db_settings: DbSettings { db_uri },
+ db_settings: DbSettings {
+ db_uri: db_uri,
+ read_db_uri: None,
+ },
metrics: atuin_server::settings::Metrics::default(),
tls: atuin_server::settings::Tls::default(),
mail: atuin_server::settings::Mail::default(),