diff options
Diffstat (limited to 'crates/atuin-server-database')
| -rw-r--r-- | crates/atuin-server-database/Cargo.toml | 1 | ||||
| -rw-r--r-- | crates/atuin-server-database/src/lib.rs | 50 |
2 files changed, 48 insertions, 3 deletions
diff --git a/crates/atuin-server-database/Cargo.toml b/crates/atuin-server-database/Cargo.toml index e3e38e3f..823b5d39 100644 --- a/crates/atuin-server-database/Cargo.toml +++ b/crates/atuin-server-database/Cargo.toml @@ -17,3 +17,4 @@ time = { workspace = true } eyre = { workspace = true } serde = { workspace = true } async-trait = { workspace = true } +url = "2.5.2" diff --git a/crates/atuin-server-database/src/lib.rs b/crates/atuin-server-database/src/lib.rs index 1c577f59..9df36d14 100644 --- a/crates/atuin-server-database/src/lib.rs +++ b/crates/atuin-server-database/src/lib.rs @@ -15,7 +15,7 @@ use self::{ }; use async_trait::async_trait; use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; -use serde::{Serialize, de::DeserializeOwned}; +use serde::{Deserialize, Serialize}; use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset}; use tracing::instrument; @@ -41,10 +41,54 @@ impl std::error::Error for DbError {} pub type DbResult<T> = Result<T, DbError>; +#[derive(Debug, PartialEq)] +pub enum DbType { + Postgres, + Sqlite, + Unknown, +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct DbSettings { + pub db_uri: String, +} + +impl DbSettings { + pub fn db_type(&self) -> DbType { + if self.db_uri.starts_with("postgres://") { + DbType::Postgres + } else if self.db_uri.starts_with("sqlite://") { + DbType::Sqlite + } else { + DbType::Unknown + } + } +} + +// Do our best to redact passwords so they're not logged in the event of an error. +impl Debug for DbSettings { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.db_type() == DbType::Postgres { + let redacted_uri = url::Url::parse(&self.db_uri) + .map(|mut url| { + let _ = url.set_password(Some("****")); + url.to_string() + }) + .unwrap_or(self.db_uri.clone()); + f.debug_struct("DbSettings") + .field("db_uri", &redacted_uri) + .finish() + } else { + f.debug_struct("DbSettings") + .field("db_uri", &self.db_uri) + .finish() + } + } +} + #[async_trait] pub trait Database: Sized + Clone + Send + Sync + 'static { - type Settings: Debug + Clone + DeserializeOwned + Serialize + Send + Sync + 'static; - async fn new(settings: &Self::Settings) -> DbResult<Self>; + async fn new(settings: &DbSettings) -> DbResult<Self>; async fn get_session(&self, token: &str) -> DbResult<Session>; async fn get_session_user(&self, token: &str) -> DbResult<User>; |
