From 366b8ea97bbe36ad5e3dd8d45f1e787ee2a7f223 Mon Sep 17 00:00:00 2001 From: Ellie Huxtable Date: Mon, 29 Jan 2024 16:38:24 +0000 Subject: feat: automatically init history store when record sync is enabled (#1634) * add support for getting the total length of a store * tidy up sync * auto call init if history is ahead * fix import order, key regen * fix import order, key regen * do not delete key when user deletes account * message output * remote init store command; this is now automatic * should probs make that function return u64 at some point --- atuin-client/src/record/sqlite_store.rs | 32 +++++++++++++++++++++ atuin-client/src/record/store.rs | 2 ++ atuin-client/src/record/sync.rs | 49 +++++++++++++++++++++++++-------- 3 files changed, 71 insertions(+), 12 deletions(-) (limited to 'atuin-client/src/record') diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs index 50f30d76..e9d7ff59 100644 --- a/atuin-client/src/record/sqlite_store.rs +++ b/atuin-client/src/record/sqlite_store.rs @@ -155,6 +155,18 @@ impl Store for SqliteStore { self.idx(host, tag, 0).await } + async fn len_tag(&self, tag: &str) -> Result { + let res: Result<(i64,), sqlx::Error> = + sqlx::query_as("select count(*) from store where tag=?1") + .bind(tag) + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + async fn len(&self, host: HostId, tag: &str) -> Result { let last = self.last(host, tag).await?; @@ -342,6 +354,20 @@ mod tests { assert_eq!(len, 1, "expected length of 1 after insert"); } + #[tokio::test] + async fn len_tag() { + let db = SqliteStore::new(":memory:", 0.1).await.unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len_tag(record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + #[tokio::test] async fn len_different_tags() { let db = SqliteStore::new(":memory:", 0.1).await.unwrap(); @@ -379,6 +405,12 @@ mod tests { 100, "failed to insert 100 records" ); + + assert_eq!( + db.len_tag(tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); } #[tokio::test] diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs index efe2eb4a..40c1224b 100644 --- a/atuin-client/src/record/store.rs +++ b/atuin-client/src/record/store.rs @@ -21,7 +21,9 @@ pub trait Store { ) -> Result<()>; async fn get(&self, id: RecordId) -> Result>; + async fn len(&self, host: HostId, tag: &str) -> Result; + async fn len_tag(&self, tag: &str) -> Result; async fn last(&self, host: HostId, tag: &str) -> Result>>; async fn first(&self, host: HostId, tag: &str) -> Result>>; diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs index 97152f79..eca0c930 100644 --- a/atuin-client/src/record/sync.rs +++ b/atuin-client/src/record/sync.rs @@ -14,14 +14,17 @@ pub enum SyncError { #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] LocalAheadOtherHost, - #[error("an issue with the local database occured")] - LocalStoreError, + #[error("an issue with the local database occured: {msg:?}")] + LocalStoreError { msg: String }, #[error("something has gone wrong with the sync logic: {msg:?}")] SyncLogicError { msg: String }, - #[error("a request to the sync server failed")] - RemoteRequestError, + #[error("operational error: {msg:?}")] + OperationalError { msg: String }, + + #[error("a request to the sync server failed: {msg:?}")] + RemoteRequestError { msg: String }, } #[derive(Debug, Eq, PartialEq)] @@ -45,16 +48,27 @@ pub enum Operation { }, } -pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec, RecordStatus)> { +pub async fn diff( + settings: &Settings, + store: &impl Store, +) -> Result<(Vec, RecordStatus), SyncError> { let client = Client::new( &settings.sync_address, &settings.session_token, settings.network_connect_timeout, settings.network_timeout, - )?; + ) + .map_err(|e| SyncError::OperationalError { msg: e.to_string() })?; + + let local_index = store + .status() + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; - let local_index = store.status().await?; - let remote_index = client.record_status().await?; + let remote_index = client + .record_status() + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; let diff = local_index.diff(&remote_index); @@ -166,13 +180,13 @@ async fn sync_upload( .map_err(|e| { error!("failed to read upload page: {e:?}"); - SyncError::LocalStoreError + SyncError::LocalStoreError { msg: e.to_string() } })?; client.post_records(&page).await.map_err(|e| { error!("failed to post records: {e:?}"); - SyncError::RemoteRequestError + SyncError::RemoteRequestError { msg: e.to_string() } })?; println!( @@ -217,12 +231,12 @@ async fn sync_download( let page = client .next_records(host, tag.clone(), local + progress, download_page_size) .await - .map_err(|_| SyncError::RemoteRequestError)?; + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; store .push_batch(page.iter()) .await - .map_err(|_| SyncError::LocalStoreError)?; + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; println!( "downloaded {} records from remote, progress {}/{}", @@ -283,6 +297,17 @@ pub async fn sync_remote( Ok((uploaded, downloaded)) } +pub async fn sync( + settings: &Settings, + store: &impl Store, +) -> Result<(i64, Vec), SyncError> { + let (diff, _) = diff(settings, store).await?; + let operations = operations(diff, store).await?; + let (uploaded, downloaded) = sync_remote(operations, store, settings).await?; + + Ok((uploaded, downloaded)) +} + #[cfg(test)] mod tests { use atuin_common::record::{Diff, EncryptedData, HostId, Record}; -- cgit v1.3.1