diff options
Diffstat (limited to 'atuin-client')
| -rw-r--r-- | atuin-client/Cargo.toml | 1 | ||||
| -rw-r--r-- | atuin-client/record-migrations/20231127090831_create-store.sql | 15 | ||||
| -rw-r--r-- | atuin-client/src/api_client.rs | 46 | ||||
| -rw-r--r-- | atuin-client/src/history.rs | 210 | ||||
| -rw-r--r-- | atuin-client/src/history/store.rs | 219 | ||||
| -rw-r--r-- | atuin-client/src/kv.rs | 100 | ||||
| -rw-r--r-- | atuin-client/src/record/encryption.rs | 29 | ||||
| -rw-r--r-- | atuin-client/src/record/sqlite_store.rs | 250 | ||||
| -rw-r--r-- | atuin-client/src/record/store.rs | 36 | ||||
| -rw-r--r-- | atuin-client/src/record/sync.rs | 486 | ||||
| -rw-r--r-- | atuin-client/src/settings.rs | 9 |
11 files changed, 1011 insertions, 390 deletions
diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml index c3d9cab7..cbb8d016 100644 --- a/atuin-client/Cargo.toml +++ b/atuin-client/Cargo.toml @@ -48,6 +48,7 @@ rmp = { version = "0.8.11" } typed-builder = { workspace = true } tokio = { workspace = true } semver = { workspace = true } +thiserror = { workspace = true } futures = "0.3" crypto_secretbox = "0.1.1" generic-array = { version = "0.14", features = ["serde"] } diff --git a/atuin-client/record-migrations/20231127090831_create-store.sql b/atuin-client/record-migrations/20231127090831_create-store.sql new file mode 100644 index 00000000..53d78860 --- /dev/null +++ b/atuin-client/record-migrations/20231127090831_create-store.sql @@ -0,0 +1,15 @@ +-- Add migration script here +create table if not exists store ( + id text primary key, -- globally unique ID + + idx integer, -- incrementing integer ID unique per (host, tag) + host text not null, -- references the host row + tag text not null, + + timestamp integer not null, + version text not null, + data blob not null, + cek blob not null +); + +create unique index record_uniq ON store(host, tag, idx); diff --git a/atuin-client/src/api_client.rs b/atuin-client/src/api_client.rs index ae8df5ad..9007b9ab 100644 --- a/atuin-client/src/api_client.rs +++ b/atuin-client/src/api_client.rs @@ -13,11 +13,11 @@ use atuin_common::{ AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse, LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse, }, - record::RecordIndex, + record::RecordStatus, }; use atuin_common::{ api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION}, - record::{EncryptedData, HostId, Record, RecordId}, + record::{EncryptedData, HostId, Record, RecordIdx}, }; use semver::Version; use time::format_description::well_known::Rfc3339; @@ -267,10 +267,18 @@ impl<'a> Client<'a> { } pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> { - let url = format!("{}/record", self.sync_addr); + let url = format!("{}/api/v0/record", self.sync_addr); let url = Url::parse(url.as_str())?; - self.client.post(url).json(records).send().await?; + let resp = self.client.post(url).json(records).send().await?; + info!("posted records, got {}", resp.status()); + + if !resp.status().is_success() { + error!( + "failed to post records to server; got: {:?}", + resp.text().await + ); + } Ok(()) } @@ -279,24 +287,22 @@ impl<'a> Client<'a> { &self, host: HostId, tag: String, - start: Option<RecordId>, + start: RecordIdx, count: u64, ) -> Result<Vec<Record<EncryptedData>>> { + debug!( + "fetching record/s from host {}/{}/{}", + host.0.to_string(), + tag, + start + ); + let url = format!( - "{}/record/next?host={}&tag={}&count={}", - self.sync_addr, host.0, tag, count + "{}/api/v0/record/next?host={}&tag={}&count={}&start={}", + self.sync_addr, host.0, tag, count, start ); - let mut url = Url::parse(url.as_str())?; - if let Some(start) = start { - url.set_query(Some( - format!( - "host={}&tag={}&count={}&start={}", - host.0, tag, count, start.0 - ) - .as_str(), - )); - } + let url = Url::parse(url.as_str())?; let resp = self.client.get(url).send().await?; @@ -305,8 +311,8 @@ impl<'a> Client<'a> { Ok(records) } - pub async fn record_index(&self) -> Result<RecordIndex> { - let url = format!("{}/record", self.sync_addr); + pub async fn record_status(&self) -> Result<RecordStatus> { + let url = format!("{}/api/v0/record", self.sync_addr); let url = Url::parse(url.as_str())?; let resp = self.client.get(url).send().await?; @@ -317,6 +323,8 @@ impl<'a> Client<'a> { let index = resp.json().await?; + debug!("got remote index {:?}", index); + Ok(index) } diff --git a/atuin-client/src/history.rs b/atuin-client/src/history.rs index fbcb169c..2b2c41ee 100644 --- a/atuin-client/src/history.rs +++ b/atuin-client/src/history.rs @@ -1,12 +1,21 @@ +use rmp::decode::ValueReadError; +use rmp::{decode::Bytes, Marker}; use std::env; +use atuin_common::record::DecryptedData; use atuin_common::utils::uuid_v7; + +use eyre::{bail, eyre, Result}; use regex::RegexSet; use crate::{secrets::SECRET_PATTERNS, settings::Settings}; use time::OffsetDateTime; mod builder; +pub mod store; + +const HISTORY_VERSION: &str = "v0"; +const HISTORY_TAG: &str = "history"; /// Client-side history entry. /// @@ -81,6 +90,108 @@ impl History { } } + pub fn serialize(&self) -> Result<DecryptedData> { + // This is pretty much the same as what we used for the old history, with one difference - + // it uses integers for timestamps rather than a string format. + + use rmp::encode; + + let mut output = vec![]; + + // write the version + encode::write_u16(&mut output, 0)?; + // INFO: ensure this is updated when adding new fields + encode::write_array_len(&mut output, 9)?; + + encode::write_str(&mut output, &self.id)?; + encode::write_u64(&mut output, self.timestamp.unix_timestamp_nanos() as u64)?; + encode::write_sint(&mut output, self.duration)?; + encode::write_sint(&mut output, self.exit)?; + encode::write_str(&mut output, &self.command)?; + encode::write_str(&mut output, &self.cwd)?; + encode::write_str(&mut output, &self.session)?; + encode::write_str(&mut output, &self.hostname)?; + + match self.deleted_at { + Some(d) => encode::write_u64(&mut output, d.unix_timestamp_nanos() as u64)?, + None => encode::write_nil(&mut output)?, + } + + Ok(DecryptedData(output)) + } + + fn deserialize_v0(bytes: &[u8]) -> Result<History> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(bytes); + + let version = decode::read_u16(&mut bytes).map_err(error_report)?; + + if version != 0 { + bail!("expected decoding v0 record, found v{version}"); + } + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + + if nfields != 9 { + bail!("cannot decrypt history from a different version of Atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + // if we have more fields, try and get the deleted_at + let mut bytes = Bytes::new(bytes); + + let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { + Ok(unix) => (Some(unix), bytes.remaining_slice()), + // we accept null here + Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), + Err(err) => return Err(error_report(err)), + }; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned(), + timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + deleted_at: deleted_at + .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) + .transpose()?, + }) + } + + pub fn deserialize(bytes: &[u8], version: &str) -> Result<History> { + match version { + HISTORY_VERSION => Self::deserialize_v0(bytes), + + _ => bail!("unknown version {version:?}"), + } + } + /// Builder for a history entry that is imported from shell history. /// /// The only two required fields are `timestamp` and `command`. @@ -202,8 +313,9 @@ impl History { #[cfg(test)] mod tests { use regex::RegexSet; + use time::macros::datetime; - use crate::settings::Settings; + use crate::{history::HISTORY_VERSION, settings::Settings}; use super::History; @@ -274,4 +386,100 @@ mod tests { assert!(stripe_key.should_save(&settings)); } + + #[test] + fn test_serialize_deserialize() { + let bytes = [ + 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, + 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, + 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, + 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, + 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, + 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, + 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, + 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, + 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192, + ]; + + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + deleted_at: None, + }; + + let serialized = history.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + assert_eq!(history, deserialized); + + // test the snapshot too + let deserialized = + History::deserialize(&bytes, HISTORY_VERSION).expect("failed to deserialize history"); + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_deleted() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + deleted_at: Some(datetime!(2023-11-19 20:18 +00:00)), + }; + + let serialized = history.serialize().expect("failed to serialize history"); + + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_version() { + // v0 + let bytes_v0 = [ + 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, + 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, + 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, + 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, + 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, + 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, + 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, + 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, + 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192, + ]; + + // some other version + let bytes_v1 = [ + 205, 1, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, + 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, + 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, + 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, + 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, + 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, + 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, + 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, + 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192, + ]; + + let deserialized = History::deserialize(&bytes_v0, HISTORY_VERSION); + assert!(deserialized.is_ok()); + + let deserialized = History::deserialize(&bytes_v1, HISTORY_VERSION); + assert!(deserialized.is_err()); + } } diff --git a/atuin-client/src/history/store.rs b/atuin-client/src/history/store.rs new file mode 100644 index 00000000..bf74a0a8 --- /dev/null +++ b/atuin-client/src/history/store.rs @@ -0,0 +1,219 @@ +use eyre::{bail, eyre, Result}; +use rmp::decode::Bytes; + +use crate::record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store}; +use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordIdx}; + +use super::{History, HISTORY_TAG, HISTORY_VERSION}; + +#[derive(Debug)] +pub struct HistoryStore { + pub store: SqliteStore, + pub host_id: HostId, + pub encryption_key: [u8; 32], +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum HistoryRecord { + Create(History), // Create a history record + Delete(String), // Delete a history record, identified by ID +} + +impl HistoryRecord { + /// Serialize a history record, returning DecryptedData + /// The record will be of a certain type + /// We map those like so: + /// + /// HistoryRecord::Create -> 0 + /// HistoryRecord::Delete-> 1 + /// + /// This numeric identifier is then written as the first byte to the buffer. For history, we + /// append the serialized history right afterwards, to avoid having to handle serialization + /// twice. + /// + /// Deletion simply refers to the history by ID + pub fn serialize(&self) -> Result<DecryptedData> { + // probably don't actually need to use rmp here, but if we ever need to extend it, it's a + // nice wrapper around raw byte stuff + use rmp::encode; + + let mut output = vec![]; + + match self { + HistoryRecord::Create(history) => { + // 0 -> a history create + encode::write_u8(&mut output, 0)?; + + let bytes = history.serialize()?; + + encode::write_bin(&mut output, &bytes.0)?; + } + HistoryRecord::Delete(id) => { + // 1 -> a history delete + encode::write_u8(&mut output, 1)?; + encode::write_str(&mut output, id)?; + } + }; + + Ok(DecryptedData(output)) + } + + pub fn deserialize(bytes: &[u8], version: &str) -> Result<Self> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(bytes); + + let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; + + match record_type { + // 0 -> HistoryRecord::Create + 0 => { + // not super useful to us atm, but perhaps in the future + // written by write_bin above + let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; + + let record = History::deserialize(bytes.remaining_slice(), version)?; + + Ok(HistoryRecord::Create(record)) + } + + // 1 -> HistoryRecord::Delete + 1 => { + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!( + "trailing bytes decoding HistoryRecord::Delete - malformed? got {bytes:?}" + ); + } + + Ok(HistoryRecord::Delete(id.to_string())) + } + + n => { + bail!("unknown HistoryRecord type {n}") + } + } + } +} + +impl HistoryStore { + pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self { + HistoryStore { + store, + host_id, + encryption_key, + } + } + + async fn push_record(&self, record: HistoryRecord) -> Result<RecordIdx> { + let bytes = record.serialize()?; + let idx = self + .store + .last(self.host_id, HISTORY_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(HISTORY_VERSION.to_string()) + .tag(HISTORY_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + self.store + .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) + .await?; + + Ok(idx) + } + + pub async fn delete(&self, id: String) -> Result<RecordIdx> { + let record = HistoryRecord::Delete(id); + + self.push_record(record).await + } + + pub async fn push(&self, history: History) -> Result<RecordIdx> { + // TODO(ellie): move the history store to its own file + // it's tiny rn so fine as is + let record = HistoryRecord::Create(history); + + self.push_record(record).await + } +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use crate::history::{store::HistoryRecord, HISTORY_VERSION}; + + use super::History; + + #[test] + fn test_serialize_deserialize_create() { + let bytes = [ + 204, 0, 196, 141, 205, 0, 0, 153, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, + 55, 53, 55, 99, 100, 50, 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, + 56, 49, 207, 23, 166, 251, 212, 181, 82, 0, 0, 100, 0, 162, 108, 115, 217, 41, 47, 85, + 115, 101, 114, 115, 47, 101, 108, 108, 105, 101, 47, 115, 114, 99, 47, 103, 105, 116, + 104, 117, 98, 46, 99, 111, 109, 47, 97, 116, 117, 105, 110, 115, 104, 47, 97, 116, 117, + 105, 110, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 97, 100, 56, 57, 55, 53, 57, 55, + 56, 53, 50, 53, 50, 55, 97, 51, 49, 99, 57, 57, 56, 48, 53, 57, 170, 98, 111, 111, 112, + 58, 101, 108, 108, 105, 101, 192, + ]; + + let history = History { + id: "018cd4fe81757cd2aee65cd7861f9c81".to_owned(), + timestamp: datetime!(2024-01-04 00:00:00.000000 +00:00), + duration: 100, + exit: 0, + command: "ls".to_owned(), + cwd: "/Users/ellie/src/github.com/atuinsh/atuin".to_owned(), + session: "018cd4fead897597852527a31c998059".to_owned(), + hostname: "boop:ellie".to_owned(), + deleted_at: None, + }; + + let record = HistoryRecord::Create(history); + + let serialized = record.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = HistoryRecord::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + + // check the snapshot too + let deserialized = HistoryRecord::deserialize(&bytes, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + } + + #[test] + fn test_serialize_deserialize_delete() { + let bytes = [ + 204, 1, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, 55, 53, 55, 99, 100, 50, + 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, 56, 49, + ]; + let record = HistoryRecord::Delete("018cd4fe81757cd2aee65cd7861f9c81".to_string()); + + let serialized = record.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = HistoryRecord::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + + let deserialized = HistoryRecord::deserialize(&bytes, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + } +} diff --git a/atuin-client/src/kv.rs b/atuin-client/src/kv.rs index 1ca6b5e8..cee7063d 100644 --- a/atuin-client/src/kv.rs +++ b/atuin-client/src/kv.rs @@ -1,6 +1,6 @@ use std::collections::BTreeMap; -use atuin_common::record::{DecryptedData, HostId}; +use atuin_common::record::{DecryptedData, Host, HostId}; use eyre::{bail, ensure, eyre, Result}; use serde::Deserialize; @@ -89,7 +89,7 @@ impl KvStore { pub async fn set( &self, - store: &mut (impl Store + Send + Sync), + store: &(impl Store + Send + Sync), encryption_key: &[u8; 32], host_id: HostId, namespace: &str, @@ -111,13 +111,16 @@ impl KvStore { let bytes = record.serialize()?; - let parent = store.tail(host_id, KV_TAG).await?.map(|entry| entry.id); + let idx = store + .last(host_id, KV_TAG) + .await? + .map_or(0, |entry| entry.idx + 1); let record = atuin_common::record::Record::builder() - .host(host_id) + .host(Host::new(host_id)) .version(KV_VERSION.to_string()) .tag(KV_TAG.to_string()) - .parent(parent) + .idx(idx) .data(bytes) .build(); @@ -137,43 +140,18 @@ impl KvStore { namespace: &str, key: &str, ) -> Result<Option<KvRecord>> { - // Currently, this is O(n). When we have an actual KV store, it can be better - // Just a poc for now! + // TODO: don't rebuild every time... + let map = self.build_kv(store, encryption_key).await?; - // iterate records to find the value we want - // start at the end, so we get the most recent version - let tails = store.tag_tails(KV_TAG).await?; + let res = map.get(namespace); - if tails.is_empty() { - return Ok(None); - } - - // first, decide on a record. - // try getting the newest first - // we always need a way of deciding the "winner" of a write - // TODO(ellie): something better than last-write-wins, what if two write at the same time? - let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone(); - - loop { - let decrypted = match record.version.as_str() { - KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?, - version => bail!("unknown version {version:?}"), - }; - - let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?; - if kv.key == key && kv.namespace == namespace { - return Ok(Some(kv)); - } + if let Some(ns) = res { + let value = ns.get(key); - if let Some(parent) = decrypted.parent { - record = store.get(parent).await?; - } else { - break; - } + Ok(value.cloned()) + } else { + Ok(None) } - - // if we get here, then... we didn't find the record with that key :( - Ok(None) } // Build a kv map out of the linked list kv store @@ -184,32 +162,30 @@ impl KvStore { &self, store: &impl Store, encryption_key: &[u8; 32], - ) -> Result<BTreeMap<String, BTreeMap<String, String>>> { + ) -> Result<BTreeMap<String, BTreeMap<String, KvRecord>>> { let mut map = BTreeMap::new(); - let tails = store.tag_tails(KV_TAG).await?; - - if tails.is_empty() { - return Ok(map); - } - let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone(); + // TODO: maybe don't load the entire tag into memory to build the kv + // we can be smart about it and only load values since the last build + // or, iterate/paginate + let tagged = store.all_tagged(KV_TAG).await?; - loop { + // iterate through all tags and play each KV record at a time + // this is "last write wins" + // probably good enough for now, but revisit in future + for record in tagged { let decrypted = match record.version.as_str() { KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?, version => bail!("unknown version {version:?}"), }; - let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?; + let kv = KvRecord::deserialize(&decrypted.data, KV_VERSION)?; - let ns = map.entry(kv.namespace).or_insert_with(BTreeMap::new); - ns.entry(kv.key).or_insert_with(|| kv.value); + let ns = map + .entry(kv.namespace.clone()) + .or_insert_with(BTreeMap::new); - if let Some(parent) = decrypted.parent { - record = store.get(parent).await?; - } else { - break; - } + ns.insert(kv.key.clone(), kv); } Ok(map) @@ -261,19 +237,27 @@ mod tests { let map = kv.build_kv(&store, &key).await.unwrap(); assert_eq!( - map.get("test-kv") + *map.get("test-kv") .expect("map namespace not set") .get("foo") .expect("map key not set"), - "bar" + KvRecord { + namespace: String::from("test-kv"), + key: String::from("foo"), + value: String::from("bar") + } ); assert_eq!( - map.get("test-kv") + *map.get("test-kv") .expect("map namespace not set") .get("1") .expect("map key not set"), - "2" + KvRecord { + namespace: String::from("test-kv"), + key: String::from("1"), + value: String::from("2") + } ); } } diff --git a/atuin-client/src/record/encryption.rs b/atuin-client/src/record/encryption.rs index 3074a9c2..c2cdaa6a 100644 --- a/atuin-client/src/record/encryption.rs +++ b/atuin-client/src/record/encryption.rs @@ -1,5 +1,5 @@ use atuin_common::record::{ - AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, + AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx, }; use base64::{engine::general_purpose, Engine}; use eyre::{ensure, Context, Result}; @@ -170,10 +170,10 @@ struct AtuinFooter { #[derive(Debug, Copy, Clone, Serialize)] struct Assertions<'a> { id: &'a RecordId, + idx: &'a RecordIdx, version: &'a str, tag: &'a str, host: &'a HostId, - parent: Option<&'a RecordId>, } impl<'a> From<AdditionalData<'a>> for Assertions<'a> { @@ -183,7 +183,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> { version: ad.version, tag: ad.tag, host: ad.host, - parent: ad.parent, + idx: ad.idx, } } } @@ -196,7 +196,10 @@ impl Assertions<'_> { #[cfg(test)] mod tests { - use atuin_common::{record::Record, utils::uuid_v7}; + use atuin_common::{ + record::{Host, Record}, + utils::uuid_v7, + }; use super::*; @@ -209,7 +212,7 @@ mod tests { version: "v0", tag: "kv", host: &HostId(uuid_v7()), - parent: None, + idx: &0, }; let data = DecryptedData(vec![1, 2, 3, 4]); @@ -228,7 +231,7 @@ mod tests { version: "v0", tag: "kv", host: &HostId(uuid_v7()), - parent: None, + idx: &0, }; let data = DecryptedData(vec![1, 2, 3, 4]); @@ -252,7 +255,7 @@ mod tests { version: "v0", tag: "kv", host: &HostId(uuid_v7()), - parent: None, + idx: &0, }; let data = DecryptedData(vec![1, 2, 3, 4]); @@ -270,7 +273,7 @@ mod tests { version: "v0", tag: "kv", host: &HostId(uuid_v7()), - parent: None, + idx: &0, }; let data = DecryptedData(vec![1, 2, 3, 4]); @@ -294,7 +297,7 @@ mod tests { version: "v0", tag: "kv", host: &HostId(uuid_v7()), - parent: None, + idx: &0, }; let data = DecryptedData(vec![1, 2, 3, 4]); @@ -323,9 +326,10 @@ mod tests { .id(RecordId(uuid_v7())) .version("v0".to_owned()) .tag("kv".to_owned()) - .host(HostId(uuid_v7())) + .host(Host::new(HostId(uuid_v7()))) .timestamp(1687244806000000) .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) .build(); let encrypted = record.encrypt::<PASETO_V4>(&key); @@ -345,15 +349,16 @@ mod tests { .id(RecordId(uuid_v7())) .version("v0".to_owned()) .tag("kv".to_owned()) - .host(HostId(uuid_v7())) + .host(Host::new(HostId(uuid_v7()))) .timestamp(1687244806000000) .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) .build(); let encrypted = record.encrypt::<PASETO_V4>(&key); let mut enc1 = encrypted.clone(); - enc1.host = HostId(uuid_v7()); + enc1.host = Host::new(HostId(uuid_v7())); let _ = enc1 .decrypt::<PASETO_V4>(&key) .expect_err("tampering with the host should result in auth failure"); diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs index db709f20..8112aa96 100644 --- a/atuin-client/src/record/sqlite_store.rs +++ b/atuin-client/src/record/sqlite_store.rs @@ -8,17 +8,20 @@ use std::str::FromStr; use async_trait::async_trait; use eyre::{eyre, Result}; use fs_err as fs; -use futures::TryStreamExt; + use sqlx::{ sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, Row, }; -use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; +use atuin_common::record::{ + EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus, +}; use uuid::Uuid; use super::store::Store; +#[derive(Debug, Clone)] pub struct SqliteStore { pool: SqlitePool, } @@ -38,6 +41,7 @@ impl SqliteStore { let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? .journal_mode(SqliteJournalMode::Wal) + .foreign_keys(true) .create_if_missing(true); let pool = SqlitePoolOptions::new().connect_with(opts).await?; @@ -61,14 +65,14 @@ impl SqliteStore { ) -> Result<()> { // In sqlite, we are "limited" to i64. But that is still fine, until 2262. sqlx::query( - "insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek) + "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek) values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", ) - .bind(r.id.0.as_simple().to_string()) - .bind(r.host.0.as_simple().to_string()) + .bind(r.id.0.as_hyphenated().to_string()) + .bind(r.idx as i64) + .bind(r.host.id.0.as_hyphenated().to_string()) .bind(r.tag.as_str()) .bind(r.timestamp as i64) - .bind(r.parent.map(|p| p.0.as_simple().to_string())) .bind(r.version.as_str()) .bind(r.data.data.as_str()) .bind(r.data.content_encryption_key.as_str()) @@ -79,20 +83,17 @@ impl SqliteStore { } fn query_row(row: SqliteRow) -> Record<EncryptedData> { + let idx: i64 = row.get("idx"); let timestamp: i64 = row.get("timestamp"); // tbh at this point things are pretty fucked so just panic let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); - let parent: Option<&str> = row.get("parent"); - - let parent = parent - .map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB")); Record { id: RecordId(id), - host: HostId(host), - parent: parent.map(RecordId), + idx: idx as u64, + host: Host::new(HostId(host)), timestamp: timestamp as u64, tag: row.get("tag"), version: row.get("version"), @@ -122,8 +123,8 @@ impl Store for SqliteStore { } async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> { - let res = sqlx::query("select * from records where id = ?1") - .bind(id.0.as_simple().to_string()) + let res = sqlx::query("select * from store where store.id = ?1") + .bind(id.0.as_hyphenated().to_string()) .map(Self::query_row) .fetch_one(&self.pool) .await?; @@ -131,20 +132,66 @@ impl Store for SqliteStore { Ok(res) } - async fn len(&self, host: HostId, tag: &str) -> Result<u64> { - let res: (i64,) = - sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2") - .bind(host.0.as_simple().to_string()) + async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { + let res = + sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1") + .bind(host.0.as_hyphenated().to_string()) .bind(tag) + .map(Self::query_row) .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occured: {}", e)), + Ok(record) => Ok(Some(record)), + } + } + + async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { + self.idx(host, tag, 0).await + } + + async fn len(&self, host: HostId, tag: &str) -> Result<u64> { + let last = self.last(host, tag).await?; + + if let Some(last) = last { + return Ok(last.idx + 1); + } + + return Ok(0); + } + + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result<Vec<Record<EncryptedData>>> { + let res = + sqlx::query("select * from store where idx >= ?1 and host = ?2 and tag = ?3 limit ?4") + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .bind(limit as i64) + .map(Self::query_row) + .fetch_all(&self.pool) .await?; - Ok(res.0 as u64) + Ok(res) } - async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> { - let res = sqlx::query("select * from records where parent = ?1") - .bind(record.id.0.as_simple().to_string()) + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result<Option<Record<EncryptedData>>> { + let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3") + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) .map(Self::query_row) .fetch_one(&self.pool) .await; @@ -156,58 +203,36 @@ impl Store for SqliteStore { } } - async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { - let res = sqlx::query( - "select * from records where host = ?1 and tag = ?2 and parent is null limit 1", - ) - .bind(host.0.as_simple().to_string()) - .bind(tag) - .map(Self::query_row) - .fetch_optional(&self.pool) - .await?; + async fn status(&self) -> Result<RecordStatus> { + let mut status = RecordStatus::new(); - Ok(res) - } + let res: Result<Vec<(String, String, i64)>, sqlx::Error> = + sqlx::query_as("select host, tag, max(idx) from store group by host, tag") + .fetch_all(&self.pool) + .await; - async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { - let res = sqlx::query( - "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;", - ) - .bind(tag) - .bind(host.0.as_simple().to_string()) - .map(Self::query_row) - .fetch_optional(&self.pool) - .await?; + let res = match res { + Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)), + Ok(v) => v, + }; - Ok(res) - } + for i in res { + let host = HostId( + Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"), + ); - async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> { - let res = sqlx::query( - "select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;", - ) - .bind(tag) - .map(Self::query_row) - .fetch_all(&self.pool) - .await?; + status.set_raw(host, i.1, i.2 as u64); + } - Ok(res) + Ok(status) } - async fn tail_records(&self) -> Result<RecordIndex> { - let res = sqlx::query( - "select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;", - ) - .map(|row: SqliteRow| { - let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host"); - let tag: String= row.get("tag"); - let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id"); - - (HostId(host), tag, RecordId(id)) - }) - .fetch(&self.pool) - .try_collect() - .await?; + async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc") + .bind(tag) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; Ok(res) } @@ -215,7 +240,7 @@ impl Store for SqliteStore { #[cfg(test)] mod tests { - use atuin_common::record::{EncryptedData, HostId, Record}; + use atuin_common::record::{EncryptedData, Host, HostId, Record}; use crate::record::{encryption::PASETO_V4, store::Store}; @@ -223,13 +248,14 @@ mod tests { fn test_record() -> Record<EncryptedData> { Record::builder() - .host(HostId(atuin_common::utils::uuid_v7())) + .host(Host::new(HostId(atuin_common::utils::uuid_v7()))) .version("v1".into()) .tag(atuin_common::utils::uuid_v7().simple().to_string()) .data(EncryptedData { data: "1234".into(), content_encryption_key: "1234".into(), }) + .idx(0) .build() } @@ -264,13 +290,49 @@ mod tests { } #[tokio::test] + async fn last() { + let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let last = db + .last(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + last.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn first() { + let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let first = db + .first(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + first.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] async fn len() { let db = SqliteStore::new(":memory:").await.unwrap(); let record = test_record(); db.push(&record).await.unwrap(); let len = db - .len(record.host, record.tag.as_str()) + .len(record.host.id, record.tag.as_str()) .await .expect("failed to get store len"); @@ -290,8 +352,8 @@ mod tests { db.push(&first).await.unwrap(); db.push(&second).await.unwrap(); - let first_len = db.len(first.host, first.tag.as_str()).await.unwrap(); - let second_len = db.len(second.host, second.tag.as_str()).await.unwrap(); + let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap(); + let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap(); assert_eq!(first_len, 1, "expected length of 1 after insert"); assert_eq!(second_len, 1, "expected length of 1 after insert"); @@ -305,14 +367,12 @@ mod tests { db.push(&tail).await.expect("failed to push record"); for _ in 1..100 { - tail = tail - .new_child(vec![1, 2, 3, 4]) - .encrypt::<PASETO_V4>(&[0; 32]); + tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]); db.push(&tail).await.unwrap(); } assert_eq!( - db.len(tail.host, tail.tag.as_str()).await.unwrap(), + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), 100, "failed to insert 100 records" ); @@ -328,50 +388,16 @@ mod tests { records.push(tail.clone()); for _ in 1..10000 { - tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]); + tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]); records.push(tail.clone()); } db.push_batch(records.iter()).await.unwrap(); assert_eq!( - db.len(tail.host, tail.tag.as_str()).await.unwrap(), + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), 10000, "failed to insert 10k records" ); } - - #[tokio::test] - async fn test_chain() { - let db = SqliteStore::new(":memory:").await.unwrap(); - - let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(1000); - - let mut tail = test_record(); - records.push(tail.clone()); - - for _ in 1..1000 { - tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]); - records.push(tail.clone()); - } - - db.push_batch(records.iter()).await.unwrap(); - - let mut record = db - .head(tail.host, tail.tag.as_str()) - .await - .expect("in memory sqlite should not fail") - .expect("entry exists"); - - let mut count = 1; - - while let Some(next) = db.next(&record).await.unwrap() { - assert_eq!(record.id, next.clone().parent.unwrap()); - record = next; - - count += 1; - } - - assert_eq!(count, 1000); - } } diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs index 45d554ef..a5c156d6 100644 --- a/atuin-client/src/record/store.rs +++ b/atuin-client/src/record/store.rs @@ -1,8 +1,7 @@ use async_trait::async_trait; use eyre::Result; -use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; - +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; /// A record store stores records /// In more detail - we tend to need to process this into _another_ format to actually query it. /// As is, the record store is intended as the source of truth for arbitratry data, which could @@ -23,19 +22,30 @@ pub trait Store { async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>; async fn len(&self, host: HostId, tag: &str) -> Result<u64>; - /// Get the record that follows this record - async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>; + async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; - /// Get the first record for a given host and tag - async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + /// Get the next `limit` records, after and including the given index + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result<Vec<Record<EncryptedData>>>; - /// Get the last record for a given host and tag - async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + /// Get the first record for a given host and tag + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result<Option<Record<EncryptedData>>>; - // Get the last record for all hosts for a given tag, useful for the read path of apps. - async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>; + async fn status(&self) -> Result<RecordStatus>; - // Get the latest host/tag/record tuple for every set in the store. useful for building an - // index - async fn tail_records(&self) -> Result<RecordIndex>; + /// Get every start record for a given tag, regardless of host. + /// Useful when actually operating on synchronized data, and will often have conflict + /// resolution applied. + async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>; } diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs index 56be0638..2694e0ff 100644 --- a/atuin-client/src/record/sync.rs +++ b/atuin-client/src/record/sync.rs @@ -1,27 +1,51 @@ // do a sync :O +use std::cmp::Ordering; + use eyre::Result; +use thiserror::Error; use super::store::Store; use crate::{api_client::Client, settings::Settings}; -use atuin_common::record::{Diff, HostId, RecordId, RecordIndex}; +use atuin_common::record::{Diff, HostId, RecordIdx, RecordStatus}; + +#[derive(Error, Debug)] +pub enum SyncError { + #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] + LocalAheadOtherHost, + + #[error("an issue with the local database occured")] + LocalStoreError, + + #[error("something has gone wrong with the sync logic: {msg:?}")] + SyncLogicError { msg: String }, + + #[error("a request to the sync server failed")] + RemoteRequestError, +} #[derive(Debug, Eq, PartialEq)] pub enum Operation { - // Either upload or download until the tail matches the below + // Either upload or download until the states matches the below Upload { - tail: RecordId, + local: RecordIdx, + remote: Option<RecordIdx>, host: HostId, tag: String, }, Download { - tail: RecordId, + local: Option<RecordIdx>, + remote: RecordIdx, + host: HostId, + tag: String, + }, + Noop { host: HostId, tag: String, }, } -pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> { +pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec<Diff>, RecordStatus)> { let client = Client::new( &settings.sync_address, &settings.session_token, @@ -29,8 +53,8 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di settings.network_timeout, )?; - let local_index = store.tail_records().await?; - let remote_index = client.record_index().await?; + let local_index = store.status().await?; + let remote_index = client.record_status().await?; let diff = local_index.diff(&remote_index); @@ -41,39 +65,57 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di // With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download. // In theory this could be done as a part of the diffing stage, but it's easier to reason // about and test this way -pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> { +pub async fn operations( + diffs: Vec<Diff>, + _store: &impl Store, +) -> Result<Vec<Operation>, SyncError> { let mut operations = Vec::with_capacity(diffs.len()); for diff in diffs { - // First, try to fetch the tail - // If it exists locally, then that means we need to update the remote - // host until it has the same tail. Ie, upload. - // If it does not exist locally, that means remote is ahead of us. - // Therefore, we need to download until our local tail matches - let record = store.get(diff.tail).await; - - let op = if record.is_ok() { - // if local has the ID, then we should find the actual tail of this - // store, so we know what we need to update the remote to. - let tail = store - .tail(diff.host, diff.tag.as_str()) - .await? - .expect("failed to fetch last record, expected tag/host to exist"); - - // TODO(ellie) update the diffing so that it stores the context of the current tail - // that way, we can determine how much we need to upload. - // For now just keep uploading until tails match + let op = match (diff.local, diff.remote) { + // We both have it! Could be either. Compare. + (Some(local), Some(remote)) => match local.cmp(&remote) { + Ordering::Equal => Operation::Noop { + host: diff.host, + tag: diff.tag, + }, + Ordering::Greater => Operation::Upload { + local, + remote: Some(remote), + host: diff.host, + tag: diff.tag, + }, + Ordering::Less => Operation::Download { + local: Some(local), + remote, + host: diff.host, + tag: diff.tag, + }, + }, - Operation::Upload { - tail: tail.id, + // Remote has it, we don't. Gotta be download + (None, Some(remote)) => Operation::Download { + local: None, + remote, host: diff.host, tag: diff.tag, - } - } else { - Operation::Download { - tail: diff.tail, + }, + + // We have it, remote doesn't. Gotta be upload. + (Some(local), None) => Operation::Upload { + local, + remote: None, host: diff.host, tag: diff.tag, + }, + + // something is pretty fucked. + (None, None) => { + return Err(SyncError::SyncLogicError { + msg: String::from( + "diff has nothing for local or remote - (host, tag) does not exist", + ), + }) } }; @@ -86,149 +128,130 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper // with the same properties operations.sort_by_key(|op| match op { - Operation::Upload { tail, host, .. } => ("upload", *host, *tail), - Operation::Download { tail, host, .. } => ("download", *host, *tail), + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. } => (2, *host, tag.clone()), }); Ok(operations) } async fn sync_upload( - store: &mut impl Store, - remote_index: &RecordIndex, + store: &impl Store, client: &Client<'_>, - op: (HostId, String, RecordId), -) -> Result<i64> { + host: HostId, + tag: String, + local: RecordIdx, + remote: Option<RecordIdx>, +) -> Result<i64, SyncError> { + let remote = remote.unwrap_or(0); + let expected = local - remote; let upload_page_size = 100; - let mut total = 0; - - // so. we have an upload operation, with the tail representing the state - // we want to get the remote to - let current_tail = remote_index.get(op.0, op.1.clone()); + let mut progress = 0; println!( - "Syncing local {:?}/{}/{:?}, remote has {:?}", - op.0, op.1, op.2, current_tail + "Uploading {} records to {}/{}", + expected, + host.0.as_simple(), + tag ); - let start = if let Some(current_tail) = current_tail { - current_tail - } else { - store - .head(op.0, op.1.as_str()) + // preload with the first entry if remote does not know of this store + loop { + let page = store + .next(host, tag.as_str(), remote + progress, upload_page_size) .await - .expect("failed to fetch host/tag head") - .expect("host/tag not in current index") - .id - }; + .map_err(|e| { + error!("failed to read upload page: {e:?}"); - debug!("starting push to remote from: {:?}", start); + SyncError::LocalStoreError + })?; - // we have the start point for sync. it is either the head of the store if - // the remote has no data for it, or the tail that the remote has - // we need to iterate from the remote tail, and keep going until - // remote tail = current local tail + client.post_records(&page).await.map_err(|e| { + error!("failed to post records: {e:?}"); - let mut record = if current_tail.is_some() { - let r = store.get(start).await.unwrap(); - store.next(&r).await? - } else { - Some(store.get(start).await.unwrap()) - }; + SyncError::RemoteRequestError + })?; - let mut buf = Vec::with_capacity(upload_page_size); - - while let Some(r) = record { - if buf.len() < upload_page_size { - buf.push(r.clone()); - } else { - client.post_records(&buf).await?; + println!( + "uploaded {} to remote, progress {}/{}", + page.len(), + progress, + expected + ); + progress += page.len() as u64; - // can we reset what we have? len = 0 but keep capacity - buf = Vec::with_capacity(upload_page_size); + if progress >= expected { + break; } - record = store.next(&r).await?; - - total += 1; } - if !buf.is_empty() { - client.post_records(&buf).await?; - } - - Ok(total) + Ok(progress as i64) } async fn sync_download( - store: &mut impl Store, - remote_index: &RecordIndex, + store: &impl Store, client: &Client<'_>, - op: (HostId, String, RecordId), -) -> Result<i64> { - // TODO(ellie): implement variable page sizing like on history sync - let download_page_size = 1000; + host: HostId, + tag: String, + local: Option<RecordIdx>, + remote: RecordIdx, +) -> Result<i64, SyncError> { + let local = local.unwrap_or(0); + let expected = remote - local; + let download_page_size = 100; + let mut progress = 0; - let mut total = 0; + println!( + "Downloading {} records from {}/{}", + expected, + host.0.as_simple(), + tag + ); - // We know that the remote is ahead of us, so let's keep downloading until both - // 1) The remote stops returning full pages - // 2) The tail equals what we expect - // - // If (1) occurs without (2), then something is wrong with our index calculation - // and we should bail. - let remote_tail = remote_index - .get(op.0, op.1.clone()) - .expect("remote index does not contain expected tail during download"); - let local_tail = store.tail(op.0, op.1.as_str()).await?; - // - // We expect that the operations diff will represent the desired state - // In this case, that contains the remote tail. - assert_eq!(remote_tail, op.2); + // preload with the first entry if remote does not know of this store + loop { + let page = client + .next_records(host, tag.clone(), local + progress, download_page_size) + .await + .map_err(|_| SyncError::RemoteRequestError)?; - println!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2); + store + .push_batch(page.iter()) + .await + .map_err(|_| SyncError::LocalStoreError)?; - let mut records = client - .next_records( - op.0, - op.1.clone(), - local_tail.map(|r| r.id), - download_page_size, - ) - .await?; + println!( + "downloaded {} records from remote, progress {}/{}", + page.len(), + progress, + expected + ); - while !records.is_empty() { - total += std::cmp::min(download_page_size, records.len() as u64); - store.push_batch(records.iter()).await?; + progress += page.len() as u64; - if records.last().unwrap().id == remote_tail { + if progress >= expected { break; } - - records = client - .next_records( - op.0, - op.1.clone(), - records.last().map(|r| r.id), - download_page_size, - ) - .await?; } - Ok(total as i64) + Ok(progress as i64) } pub async fn sync_remote( operations: Vec<Operation>, - remote_index: &RecordIndex, - local_store: &mut impl Store, + local_store: &impl Store, settings: &Settings, -) -> Result<(i64, i64)> { +) -> Result<(i64, i64), SyncError> { let client = Client::new( &settings.sync_address, &settings.session_token, settings.network_connect_timeout, settings.network_timeout, - )?; + ) + .expect("failed to create client"); let mut uploaded = 0; let mut downloaded = 0; @@ -236,14 +259,23 @@ pub async fn sync_remote( // this can totally run in parallel, but lets get it working first for i in operations { match i { - Operation::Upload { tail, host, tag } => { - uploaded += - sync_upload(local_store, remote_index, &client, (host, tag, tail)).await? - } - Operation::Download { tail, host, tag } => { - downloaded += - sync_download(local_store, remote_index, &client, (host, tag, tail)).await? + Operation::Upload { + host, + tag, + local, + remote, + } => uploaded += sync_upload(local_store, &client, host, tag, local, remote).await?, + + Operation::Download { + host, + tag, + local, + remote, + } => { + downloaded += sync_download(local_store, &client, host, tag, local, remote).await? } + + Operation::Noop { .. } => continue, } } @@ -264,13 +296,16 @@ mod tests { fn test_record() -> Record<EncryptedData> { Record::builder() - .host(HostId(atuin_common::utils::uuid_v7())) + .host(atuin_common::record::Host::new(HostId( + atuin_common::utils::uuid_v7(), + ))) .version("v1".into()) .tag(atuin_common::utils::uuid_v7().simple().to_string()) .data(EncryptedData { data: String::new(), content_encryption_key: String::new(), }) + .idx(0) .build() } @@ -296,8 +331,8 @@ mod tests { remote_store.push(&i).await.unwrap(); } - let local_index = local_store.tail_records().await.unwrap(); - let remote_index = remote_store.tail_records().await.unwrap(); + let local_index = local_store.status().await.unwrap(); + let remote_index = remote_store.status().await.unwrap(); let diff = local_index.diff(&remote_index); @@ -320,9 +355,10 @@ mod tests { assert_eq!( operations[0], Operation::Upload { - host: record.host, + host: record.host.id, tag: record.tag, - tail: record.id + local: record.idx, + remote: None, } ); } @@ -333,12 +369,14 @@ mod tests { // another. One upload, one download let shared_record = test_record(); - let remote_ahead = test_record(); + let local_ahead = shared_record - .new_child(vec![1, 2, 3]) + .append(vec![1, 2, 3]) .encrypt::<PASETO_V4>(&[0; 32]); + assert_eq!(local_ahead.idx, 1); + let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store @@ -350,15 +388,19 @@ mod tests { assert_eq!( operations, vec![ - Operation::Download { - tail: remote_ahead.id, - host: remote_ahead.host, - tag: remote_ahead.tag, - }, + // Or in otherwords, local is ahead by one Operation::Upload { - tail: local_ahead.id, - host: local_ahead.host, + host: local_ahead.host.id, tag: local_ahead.tag, + local: 1, + remote: Some(0), + }, + // Or in other words, remote knows of a record in an entirely new store (tag) + Operation::Download { + host: remote_ahead.host.id, + tag: remote_ahead.tag, + local: None, + remote: 0, }, ] ); @@ -371,66 +413,160 @@ mod tests { // One known only by remote let shared_record = test_record(); + let local_only = test_record(); + + let local_only_20 = test_record(); + let local_only_21 = local_only_20 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let local_only_22 = local_only_21 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let local_only_23 = local_only_22 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); - let remote_known = test_record(); - let local_known = test_record(); + let remote_only = test_record(); + + let remote_only_20 = test_record(); + let remote_only_21 = remote_only_20 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_22 = remote_only_21 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_23 = remote_only_22 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_24 = remote_only_23 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); let second_shared = test_record(); let second_shared_remote_ahead = second_shared - .new_child(vec![1, 2, 3]) + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let second_shared_remote_ahead2 = second_shared_remote_ahead + .append(vec![1, 2, 3]) .encrypt::<PASETO_V4>(&[0; 32]); - let local_ahead = shared_record - .new_child(vec![1, 2, 3]) + let third_shared = test_record(); + let third_shared_local_ahead = third_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let third_shared_local_ahead2 = third_shared_local_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let fourth_shared = test_record(); + let fourth_shared_remote_ahead = fourth_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead + .append(vec![1, 2, 3]) .encrypt::<PASETO_V4>(&[0; 32]); let local = vec![ shared_record.clone(), second_shared.clone(), - local_known.clone(), - local_ahead.clone(), + third_shared.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + // single store, only local has it + local_only.clone(), + // bigger store, also only known by local + local_only_20.clone(), + local_only_21.clone(), + local_only_22.clone(), + local_only_23.clone(), + // another shared store, but local is ahead on this one + third_shared_local_ahead.clone(), + third_shared_local_ahead2.clone(), ]; let remote = vec![ + remote_only.clone(), + remote_only_20.clone(), + remote_only_21.clone(), + remote_only_22.clone(), + remote_only_23.clone(), + remote_only_24.clone(), shared_record.clone(), second_shared.clone(), + third_shared.clone(), second_shared_remote_ahead.clone(), - remote_known.clone(), + second_shared_remote_ahead2.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + fourth_shared_remote_ahead2.clone(), ]; // remote knows about the already-synced, and one new record in a new store let (store, diff) = build_test_diff(local, remote).await; let operations = sync::operations(diff, &store).await.unwrap(); - assert_eq!(operations.len(), 4); + assert_eq!(operations.len(), 7); let mut result_ops = vec![ + // We started with a shared record, but the remote knows of two newer records in the + // same store + Operation::Download { + local: Some(0), + remote: 2, + host: second_shared_remote_ahead.host.id, + tag: second_shared_remote_ahead.tag, + }, + // We have a shared record, local knows of the first two but not the last + Operation::Download { + local: Some(1), + remote: 2, + host: fourth_shared_remote_ahead2.host.id, + tag: fourth_shared_remote_ahead2.tag, + }, + // Remote knows of a store with a single record that local does not have Operation::Download { - tail: remote_known.id, - host: remote_known.host, - tag: remote_known.tag, + local: None, + remote: 0, + host: remote_only.host.id, + tag: remote_only.tag, }, + // Remote knows of a store with a bunch of records that local does not have Operation::Download { - tail: second_shared_remote_ahead.id, - host: second_shared.host, - tag: second_shared.tag, + local: None, + remote: 4, + host: remote_only_20.host.id, + tag: remote_only_20.tag, }, + // Local knows of a record in a store that remote does not have Operation::Upload { - tail: local_ahead.id, - host: local_ahead.host, - tag: local_ahead.tag, + local: 0, + remote: None, + host: local_only.host.id, + tag: local_only.tag, }, + // Local knows of 4 records in a store that remote does not have Operation::Upload { - tail: local_known.id, - host: local_known.host, - tag: local_known.tag, + local: 3, + remote: None, + host: local_only_20.host.id, + tag: local_only_20.tag, + }, + // Local knows of 2 more records in a shared store that remote only has one of + Operation::Upload { + local: 2, + remote: Some(0), + host: third_shared.host.id, + tag: third_shared.tag, }, ]; result_ops.sort_by_key(|op| match op { - Operation::Upload { tail, host, .. } => ("upload", *host, *tail), - Operation::Download { tail, host, .. } => ("download", *host, *tail), + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. } => (2, *host, tag.clone()), }); - assert_eq!(operations, result_ops); + assert_eq!(result_ops, operations); } } diff --git a/atuin-client/src/settings.rs b/atuin-client/src/settings.rs index 7e251550..0798a890 100644 --- a/atuin-client/src/settings.rs +++ b/atuin-client/src/settings.rs @@ -173,6 +173,11 @@ impl Default for Stats { } } +#[derive(Clone, Debug, Deserialize, Default)] +pub struct Sync { + pub records: bool, +} + #[derive(Clone, Debug, Deserialize)] pub struct Settings { pub dialect: Dialect, @@ -217,6 +222,9 @@ pub struct Settings { #[serde(default)] pub stats: Stats, + #[serde(default)] + pub sync: Sync, + // This is automatically loaded when settings is created. Do not set in // config! Keep secrets and settings apart. #[serde(skip)] @@ -427,6 +435,7 @@ impl Settings { // muscle memory. // New users will get the new default, that is more similar to what they are used to. .set_default("enter_accept", false)? + .set_default("sync.records", false)? .add_source( Environment::with_prefix("atuin") .prefix_separator("_") |
