diff options
| author | Ellie Huxtable <ellie@elliehuxtable.com> | 2023-07-14 20:44:08 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-14 20:44:08 +0100 |
| commit | 97e24d0d41bb743833e457de5ba49c5c233eb3b3 (patch) | |
| tree | f0cfefd9048df83d3029cb0b0d21f1f88813fe2e /atuin-server | |
| parent | Bump semver from 5.7.1 to 5.7.2 in /docs (#1100) (diff) | |
| download | atuin-97e24d0d41bb743833e457de5ba49c5c233eb3b3.zip | |
Add new sync (#1093)
* Add record migration
* Add database functions for inserting history
No real tests yet :( I would like to avoid having to run Postgres for them
* Add index handler, use UUIDs not strings
* Fix a bunch of tests, remove Option<Uuid>
* Add tests, all passing
* Working upload sync
* Record downloading works
* Sync download works
* Don't waste requests
* Use a page size for uploads, make it variable later
* Aaaaaand they're encrypted now too
* Add cek
* Allow reading tail across hosts
* Revert "Allow reading tail across hosts"
Not like that
This reverts commit 7b0c72e7e050c358172f9b53cbd21b9e44cf4931.
* Handle multiple shards properly
* format
* Format and make clippy happy
* use some fancy types (#1098)
* use some fancy types
* fmt
* Goodbye horrible tuple
* Update atuin-server-postgres/migrations/20230623070418_records.sql
Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
* fmt
* Sort tests too because time sucks
* fix features
---------
Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
Diffstat (limited to '')
| -rw-r--r-- | atuin-server-database/src/lib.rs | 18 | ||||
| -rw-r--r-- | atuin-server-postgres/Cargo.toml | 1 | ||||
| -rw-r--r-- | atuin-server-postgres/build.rs | 5 | ||||
| -rw-r--r-- | atuin-server-postgres/migrations/20230623070418_records.sql | 15 | ||||
| -rw-r--r-- | atuin-server-postgres/src/lib.rs | 102 | ||||
| -rw-r--r-- | atuin-server-postgres/src/wrappers.rs | 29 | ||||
| -rw-r--r-- | atuin-server/src/handlers/mod.rs | 1 | ||||
| -rw-r--r-- | atuin-server/src/handlers/record.rs | 104 | ||||
| -rw-r--r-- | atuin-server/src/router.rs | 3 | ||||
| -rw-r--r-- | atuin-server/src/settings.rs | 2 |
10 files changed, 277 insertions, 3 deletions
diff --git a/atuin-server-database/src/lib.rs b/atuin-server-database/src/lib.rs index de33ba44..cdff90a2 100644 --- a/atuin-server-database/src/lib.rs +++ b/atuin-server-database/src/lib.rs @@ -13,7 +13,10 @@ use self::{ models::{History, NewHistory, NewSession, NewUser, Session, User}, }; use async_trait::async_trait; -use atuin_common::utils::get_days_from_month; +use atuin_common::{ + record::{EncryptedData, HostId, Record, RecordId, RecordIndex}, + utils::get_days_from_month, +}; use chrono::{Datelike, TimeZone}; use chronoutil::RelativeDuration; use serde::{de::DeserializeOwned, Serialize}; @@ -55,6 +58,19 @@ pub trait Database: Sized + Clone + Send + Sync + 'static { async fn delete_history(&self, user: &User, id: String) -> DbResult<()>; async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>; + async fn add_records(&self, user: &User, record: &[Record<EncryptedData>]) -> DbResult<()>; + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option<RecordId>, + count: u64, + ) -> DbResult<Vec<Record<EncryptedData>>>; + + // Return the tail record ID for each store, so (HostID, Tag, TailRecordID) + async fn tail_records(&self, user: &User) -> DbResult<RecordIndex>; + async fn count_history_range( &self, user: &User, diff --git a/atuin-server-postgres/Cargo.toml b/atuin-server-postgres/Cargo.toml index 18864f6c..bfec70a2 100644 --- a/atuin-server-postgres/Cargo.toml +++ b/atuin-server-postgres/Cargo.toml @@ -18,4 +18,5 @@ chrono = { workspace = true } serde = { workspace = true } sqlx = { workspace = true } async-trait = { workspace = true } +uuid = { workspace = true } futures-util = "0.3" diff --git a/atuin-server-postgres/build.rs b/atuin-server-postgres/build.rs new file mode 100644 index 00000000..d5068697 --- /dev/null +++ b/atuin-server-postgres/build.rs @@ -0,0 +1,5 @@ +// generated by `sqlx migrate build-script` +fn main() { + // trigger recompilation when a new migration is added + 
println!("cargo:rerun-if-changed=migrations"); +} diff --git a/atuin-server-postgres/migrations/20230623070418_records.sql b/atuin-server-postgres/migrations/20230623070418_records.sql new file mode 100644 index 00000000..22437595 --- /dev/null +++ b/atuin-server-postgres/migrations/20230623070418_records.sql @@ -0,0 +1,15 @@ +-- Add migration script here +create table records ( + id uuid primary key, -- remember to use uuidv7 for happy indices <3 + client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key + host uuid not null, -- a unique identifier for the host + parent uuid default null, -- the ID of the parent record, bearing in mind this is a linked list + timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision + version text not null, + tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host + data text not null, -- store the actual history data, encrypted. I don't wanna know! 
+ cek text not null, + + user_id bigint not null, -- allow multiple users + created_at timestamp not null default current_timestamp +); diff --git a/atuin-server-postgres/src/lib.rs b/atuin-server-postgres/src/lib.rs index 0dc51daf..404188b0 100644 --- a/atuin-server-postgres/src/lib.rs +++ b/atuin-server-postgres/src/lib.rs @@ -1,14 +1,14 @@ use async_trait::async_trait; +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; use atuin_server_database::{Database, DbError, DbResult}; use futures_util::TryStreamExt; use serde::{Deserialize, Serialize}; use sqlx::postgres::PgPoolOptions; - use sqlx::Row; use tracing::instrument; -use wrappers::{DbHistory, DbSession, DbUser}; +use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; mod wrappers; @@ -329,4 +329,102 @@ impl Database for Postgres { .map_err(fix_error) .map(|DbHistory(h)| h) } + + #[instrument(skip_all)] + async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in records { + let id = atuin_common::utils::uuid_v7(); + + sqlx::query( + "insert into records + (id, client_id, host, parent, timestamp, version, tag, data, cek, user_id) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + on conflict do nothing + ", + ) + .bind(id) + .bind(i.id) + .bind(i.host) + .bind(i.parent) + .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time + .bind(&i.version) + .bind(&i.tag) + .bind(&i.data.data) + .bind(&i.data.content_encryption_key) + .bind(user.id) + .execute(&mut tx) + .await + .map_err(fix_error)?; + } + + tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option<RecordId>, + count: u64, + ) -> DbResult<Vec<Record<EncryptedData>>> 
{ + tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); + let mut ret = Vec::with_capacity(count as usize); + let mut parent = start; + + // yeah let's do something better + for _ in 0..count { + // a very much not ideal query. but it's simple at least? + // we are basically using postgres as a kv store here, so... maybe consider using an actual + // kv store? + let record: Result<DbRecord, DbError> = sqlx::query_as( + "select client_id, host, parent, timestamp, version, tag, data, cek from records + where user_id = $1 + and tag = $2 + and host = $3 + and parent is not distinct from $4", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(parent) + .fetch_one(&self.pool) + .await + .map_err(fix_error); + + match record { + Ok(record) => { + let record: Record<EncryptedData> = record.into(); + ret.push(record.clone()); + + parent = Some(record.id); + } + Err(DbError::NotFound) => { + tracing::debug!("hit tail of store: {:?}/{}", host, tag); + return Ok(ret); + } + Err(e) => return Err(e), + } + } + + Ok(ret) + } + + async fn tail_records(&self, user: &User) -> DbResult<RecordIndex> { + const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0;"; + + let res = sqlx::query_as(TAIL_RECORDS_SQL) + .bind(user.id) + .fetch(&self.pool) + .try_collect() + .await + .map_err(fix_error)?; + + Ok(res) + } } diff --git a/atuin-server-postgres/src/wrappers.rs b/atuin-server-postgres/src/wrappers.rs index cb3d5a96..8bd482b1 100644 --- a/atuin-server-postgres/src/wrappers.rs +++ b/atuin-server-postgres/src/wrappers.rs @@ -1,10 +1,12 @@ use ::sqlx::{FromRow, Result}; +use atuin_common::record::{EncryptedData, Record}; use atuin_server_database::models::{History, Session, User}; use sqlx::{postgres::PgRow, Row}; pub struct DbUser(pub User); pub struct DbSession(pub Session); pub struct DbHistory(pub History); +pub struct DbRecord(pub Record<EncryptedData>); impl<'a> 
FromRow<'a, PgRow> for DbUser { fn from_row(row: &'a PgRow) -> Result<Self> { @@ -40,3 +42,30 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory { })) } } + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord { + fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { + let timestamp: i64 = row.try_get("timestamp")?; + + let data = EncryptedData { + data: row.try_get("data")?, + content_encryption_key: row.try_get("cek")?, + }; + + Ok(Self(Record { + id: row.try_get("client_id")?, + host: row.try_get("host")?, + parent: row.try_get("parent")?, + timestamp: timestamp as u64, + version: row.try_get("version")?, + tag: row.try_get("tag")?, + data, + })) + } +} + +impl From<DbRecord> for Record<EncryptedData> { + fn from(other: DbRecord) -> Record<EncryptedData> { + Record { ..other.0 } + } +} diff --git a/atuin-server/src/handlers/mod.rs b/atuin-server/src/handlers/mod.rs index 35d32f6f..2bd782db 100644 --- a/atuin-server/src/handlers/mod.rs +++ b/atuin-server/src/handlers/mod.rs @@ -2,6 +2,7 @@ use atuin_common::api::{ErrorResponse, IndexResponse}; use axum::{response::IntoResponse, Json}; pub mod history; +pub mod record; pub mod status; pub mod user; diff --git a/atuin-server/src/handlers/record.rs b/atuin-server/src/handlers/record.rs new file mode 100644 index 00000000..0100c693 --- /dev/null +++ b/atuin-server/src/handlers/record.rs @@ -0,0 +1,104 @@ +use axum::{extract::Query, extract::State, Json}; +use http::StatusCode; +use serde::Deserialize; +use tracing::{error, instrument}; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::router::{AppState, UserAuth}; +use atuin_server_database::Database; + +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn post<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, + Json(records): Json<Vec<Record<EncryptedData>>>, +) -> Result<(), ErrorResponseStatus<'static>> { + let 
State(AppState { database, settings }) = state; + + tracing::debug!( + count = records.len(), + user = user.username, + "request to add records" + ); + + let too_big = records + .iter() + .any(|r| r.data.data.len() >= settings.max_record_size || settings.max_record_size == 0); + + if too_big { + return Err( + ErrorResponse::reply("could not add records; record too large") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + if let Err(e) = database.add_records(&user, &records).await { + error!("failed to add record: {}", e); + + return Err(ErrorResponse::reply("failed to add record") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + Ok(()) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn index<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<RecordIndex>, ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + + let record_index = match database.tail_records(&user).await { + Ok(index) => index, + Err(e) => { + error!("failed to get record index: {}", e); + + return Err(ErrorResponse::reply("failed to calculate record index") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + Ok(Json(record_index)) +} + +#[derive(Deserialize)] +pub struct NextParams { + host: HostId, + tag: String, + start: Option<RecordId>, + count: u64, +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn next<DB: Database>( + params: Query<NextParams>, + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + let params = params.0; + + let records = match database + .next_records(&user, params.host, params.tag, params.start, params.count) + .await + { + Ok(records) => records, + Err(e) => { + error!("failed to get record index: {}", e); + + return Err(ErrorResponse::reply("failed to calculate record index") 
+ .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + Ok(Json(records)) +} diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs index ec558e78..7dc8a246 100644 --- a/atuin-server/src/router.rs +++ b/atuin-server/src/router.rs @@ -71,6 +71,9 @@ pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> R .route("/sync/status", get(handlers::status::status)) .route("/history", post(handlers::history::add)) .route("/history", delete(handlers::history::delete)) + .route("/record", post(handlers::record::post)) + .route("/record", get(handlers::record::index)) + .route("/record/next", get(handlers::record::next)) .route("/user/:username", get(handlers::user::get)) .route("/account", delete(handlers::user::delete)) .route("/register", post(handlers::user::register)) diff --git a/atuin-server/src/settings.rs b/atuin-server/src/settings.rs index fb5325d4..7e447e9e 100644 --- a/atuin-server/src/settings.rs +++ b/atuin-server/src/settings.rs @@ -12,6 +12,7 @@ pub struct Settings<DbSettings> { pub path: String, pub open_registration: bool, pub max_history_length: usize, + pub max_record_size: usize, pub page_size: i64, pub register_webhook_url: Option<String>, pub register_webhook_username: String, @@ -39,6 +40,7 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> { .set_default("port", 8888)? .set_default("open_registration", false)? .set_default("max_history_length", 8192)? + .set_default("max_record_size", 1024 * 1024 * 1024)? // pretty chonky .set_default("path", "")? .set_default("register_webhook_username", "")? .set_default("page_size", 1100)? |
