diff options
Diffstat (limited to '')
| -rw-r--r-- | atuin-server-postgres/src/lib.rs | 102 | ||||
| -rw-r--r-- | atuin-server-postgres/src/wrappers.rs | 29 |
2 files changed, 129 insertions, 2 deletions
diff --git a/atuin-server-postgres/src/lib.rs b/atuin-server-postgres/src/lib.rs index 0dc51daf..404188b0 100644 --- a/atuin-server-postgres/src/lib.rs +++ b/atuin-server-postgres/src/lib.rs @@ -1,14 +1,14 @@ use async_trait::async_trait; +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; use atuin_server_database::{Database, DbError, DbResult}; use futures_util::TryStreamExt; use serde::{Deserialize, Serialize}; use sqlx::postgres::PgPoolOptions; - use sqlx::Row; use tracing::instrument; -use wrappers::{DbHistory, DbSession, DbUser}; +use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; mod wrappers; @@ -329,4 +329,102 @@ impl Database for Postgres { .map_err(fix_error) .map(|DbHistory(h)| h) } + + #[instrument(skip_all)] + async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in records { + let id = atuin_common::utils::uuid_v7(); + + sqlx::query( + "insert into records + (id, client_id, host, parent, timestamp, version, tag, data, cek, user_id) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + on conflict do nothing + ", + ) + .bind(id) + .bind(i.id) + .bind(i.host) + .bind(i.parent) + .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time + .bind(&i.version) + .bind(&i.tag) + .bind(&i.data.data) + .bind(&i.data.content_encryption_key) + .bind(user.id) + .execute(&mut tx) + .await + .map_err(fix_error)?; + } + + tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option<RecordId>, + count: u64, + ) -> DbResult<Vec<Record<EncryptedData>>> { + tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); + let mut ret = Vec::with_capacity(count as usize); + let mut parent = start; + + // yeah let's do something better + for _ in 0..count { + // a very much not ideal query. but it's simple at least? + // we are basically using postgres as a kv store here, so... maybe consider using an actual + // kv store? + let record: Result<DbRecord, DbError> = sqlx::query_as( + "select client_id, host, parent, timestamp, version, tag, data, cek from records + where user_id = $1 + and tag = $2 + and host = $3 + and parent is not distinct from $4", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(parent) + .fetch_one(&self.pool) + .await + .map_err(fix_error); + + match record { + Ok(record) => { + let record: Record<EncryptedData> = record.into(); + ret.push(record.clone()); + + parent = Some(record.id); + } + Err(DbError::NotFound) => { + tracing::debug!("hit tail of store: {:?}/{}", host, tag); + return Ok(ret); + } + Err(e) => return Err(e), + } + } + + Ok(ret) + } + + async fn tail_records(&self, user: &User) -> DbResult<RecordIndex> { + const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0;"; + + let res = sqlx::query_as(TAIL_RECORDS_SQL) + .bind(user.id) + .fetch(&self.pool) + .try_collect() + .await + .map_err(fix_error)?; + + Ok(res) + } } diff --git a/atuin-server-postgres/src/wrappers.rs b/atuin-server-postgres/src/wrappers.rs index cb3d5a96..8bd482b1 100644 --- a/atuin-server-postgres/src/wrappers.rs +++ b/atuin-server-postgres/src/wrappers.rs @@ -1,10 +1,12 @@ use ::sqlx::{FromRow, Result}; +use atuin_common::record::{EncryptedData, Record}; use atuin_server_database::models::{History, Session, User}; use sqlx::{postgres::PgRow, Row}; pub struct DbUser(pub User); pub struct DbSession(pub Session); pub struct DbHistory(pub History); +pub struct DbRecord(pub Record<EncryptedData>); impl<'a> FromRow<'a, PgRow> for DbUser { fn from_row(row: &'a PgRow) -> Result<Self> { @@ -40,3 +42,30 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory { })) } } + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord { + fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { + let timestamp: i64 = row.try_get("timestamp")?; + + let data = EncryptedData { + data: row.try_get("data")?, + content_encryption_key: row.try_get("cek")?, + }; + + Ok(Self(Record { + id: row.try_get("client_id")?, + host: row.try_get("host")?, + parent: row.try_get("parent")?, + timestamp: timestamp as u64, + version: row.try_get("version")?, + tag: row.try_get("tag")?, + data, + })) + } +} + +impl From<DbRecord> for Record<EncryptedData> { + fn from(other: DbRecord) -> Record<EncryptedData> { + Record { ..other.0 } + } +} |
