aboutsummaryrefslogtreecommitdiffstats
path: root/atuin-client/src/record
diff options
context:
space:
mode:
authorEllie Huxtable <ellie@elliehuxtable.com>2024-01-05 17:57:49 +0000
committerGitHub <noreply@github.com>2024-01-05 17:57:49 +0000
commit7bc6ccdd70422f8fc763e2fd814a481bc79ce7b5 (patch)
treea1c064a7c7394d261711c6e046d4c60791e6cf6f /atuin-client/src/record
parentfix: Prevent input to be interpreted as options for zsh autosuggestions (#1506) (diff)
downloadatuin-7bc6ccdd70422f8fc763e2fd814a481bc79ce7b5.zip
feat: rework record sync for improved reliability (#1478)
* feat: rework record sync for improved reliability So, to tell a story 1. We introduced the record sync, intended to be the new algorithm to sync history. 2. On top of this, I added the KV store. This was intended as a simple test of the record sync, and to see if people wanted that sort of functionality 3. History remained syncing via the old means, as while it had issues it worked more-or-less OK. And we are aware of its flaws 4. If KV syncing worked ok, history would be moved across KV syncing ran ok for 6mo or so, so I started to move across history. For several weeks, I ran a local fork of Atuin + the server that synced via records instead. The record store maintained ordering via a linked list, which was a mistake. It performed well in testing, but was really difficult to debug and reason about. So when a few small sync issues occured, they took an extremely long time to debug. This PR is huge, which I regret. It involves replacing the "parent" relationship that records once had (pointing to the previous record) with a simple index (generally referred to as idx). This also means we had to change the recordindex, which referenced "tails". Tails were the last item in the chain. Now that we use an "array" vs linked list, that logic was also replaced. And is much simpler :D Same for the queries that act on this data. ---- This isn't final - we still need to add 1. Proper server/client error handling, which has been lacking for a while 2. The actual history implementation on top This exists in a branch, just without deletions. Won't be much to add that, I just don't want to make this any larger than it already is The _only_ caveat here is that we basically lose data synced via the old record store. This is the KV data from before. It hasn't been deleted or anything, just no longer hooked up. So it's totally possible to write a migration script. I just need to do that. * update .gitignore * use correct endpoint * fix for stores with length of 1 * use create/delete enum for history store * lint, remove unneeded host_id * remove prints * add command to import old history * add enable/disable switch for record sync * add record sync to auto sync * satisfy the almighty clippy * remove file that I did not mean to commit * feedback
Diffstat (limited to 'atuin-client/src/record')
-rw-r--r--atuin-client/src/record/encryption.rs29
-rw-r--r--atuin-client/src/record/sqlite_store.rs250
-rw-r--r--atuin-client/src/record/store.rs36
-rw-r--r--atuin-client/src/record/sync.rs486
4 files changed, 489 insertions, 312 deletions
diff --git a/atuin-client/src/record/encryption.rs b/atuin-client/src/record/encryption.rs
index 3074a9c2..c2cdaa6a 100644
--- a/atuin-client/src/record/encryption.rs
+++ b/atuin-client/src/record/encryption.rs
@@ -1,5 +1,5 @@
use atuin_common::record::{
- AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId,
+ AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx,
};
use base64::{engine::general_purpose, Engine};
use eyre::{ensure, Context, Result};
@@ -170,10 +170,10 @@ struct AtuinFooter {
#[derive(Debug, Copy, Clone, Serialize)]
struct Assertions<'a> {
id: &'a RecordId,
+ idx: &'a RecordIdx,
version: &'a str,
tag: &'a str,
host: &'a HostId,
- parent: Option<&'a RecordId>,
}
impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
@@ -183,7 +183,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
version: ad.version,
tag: ad.tag,
host: ad.host,
- parent: ad.parent,
+ idx: ad.idx,
}
}
}
@@ -196,7 +196,10 @@ impl Assertions<'_> {
#[cfg(test)]
mod tests {
- use atuin_common::{record::Record, utils::uuid_v7};
+ use atuin_common::{
+ record::{Host, Record},
+ utils::uuid_v7,
+ };
use super::*;
@@ -209,7 +212,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -228,7 +231,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -252,7 +255,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -270,7 +273,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -294,7 +297,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -323,9 +326,10 @@ mod tests {
.id(RecordId(uuid_v7()))
.version("v0".to_owned())
.tag("kv".to_owned())
- .host(HostId(uuid_v7()))
+ .host(Host::new(HostId(uuid_v7())))
.timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4]))
+ .idx(0)
.build();
let encrypted = record.encrypt::<PASETO_V4>(&key);
@@ -345,15 +349,16 @@ mod tests {
.id(RecordId(uuid_v7()))
.version("v0".to_owned())
.tag("kv".to_owned())
- .host(HostId(uuid_v7()))
+ .host(Host::new(HostId(uuid_v7())))
.timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4]))
+ .idx(0)
.build();
let encrypted = record.encrypt::<PASETO_V4>(&key);
let mut enc1 = encrypted.clone();
- enc1.host = HostId(uuid_v7());
+ enc1.host = Host::new(HostId(uuid_v7()));
let _ = enc1
.decrypt::<PASETO_V4>(&key)
.expect_err("tampering with the host should result in auth failure");
diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs
index db709f20..8112aa96 100644
--- a/atuin-client/src/record/sqlite_store.rs
+++ b/atuin-client/src/record/sqlite_store.rs
@@ -8,17 +8,20 @@ use std::str::FromStr;
use async_trait::async_trait;
use eyre::{eyre, Result};
use fs_err as fs;
-use futures::TryStreamExt;
+
use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
Row,
};
-use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
+use atuin_common::record::{
+ EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus,
+};
use uuid::Uuid;
use super::store::Store;
+#[derive(Debug, Clone)]
pub struct SqliteStore {
pool: SqlitePool,
}
@@ -38,6 +41,7 @@ impl SqliteStore {
let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
.journal_mode(SqliteJournalMode::Wal)
+ .foreign_keys(true)
.create_if_missing(true);
let pool = SqlitePoolOptions::new().connect_with(opts).await?;
@@ -61,14 +65,14 @@ impl SqliteStore {
) -> Result<()> {
// In sqlite, we are "limited" to i64. But that is still fine, until 2262.
sqlx::query(
- "insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
+ "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek)
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)
- .bind(r.id.0.as_simple().to_string())
- .bind(r.host.0.as_simple().to_string())
+ .bind(r.id.0.as_hyphenated().to_string())
+ .bind(r.idx as i64)
+ .bind(r.host.id.0.as_hyphenated().to_string())
.bind(r.tag.as_str())
.bind(r.timestamp as i64)
- .bind(r.parent.map(|p| p.0.as_simple().to_string()))
.bind(r.version.as_str())
.bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str())
@@ -79,20 +83,17 @@ impl SqliteStore {
}
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
+ let idx: i64 = row.get("idx");
let timestamp: i64 = row.get("timestamp");
// tbh at this point things are pretty fucked so just panic
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
- let parent: Option<&str> = row.get("parent");
-
- let parent = parent
- .map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
Record {
id: RecordId(id),
- host: HostId(host),
- parent: parent.map(RecordId),
+ idx: idx as u64,
+ host: Host::new(HostId(host)),
timestamp: timestamp as u64,
tag: row.get("tag"),
version: row.get("version"),
@@ -122,8 +123,8 @@ impl Store for SqliteStore {
}
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
- let res = sqlx::query("select * from records where id = ?1")
- .bind(id.0.as_simple().to_string())
+ let res = sqlx::query("select * from store where store.id = ?1")
+ .bind(id.0.as_hyphenated().to_string())
.map(Self::query_row)
.fetch_one(&self.pool)
.await?;
@@ -131,20 +132,66 @@ impl Store for SqliteStore {
Ok(res)
}
- async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
- let res: (i64,) =
- sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
- .bind(host.0.as_simple().to_string())
+ async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+ let res =
+ sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1")
+ .bind(host.0.as_hyphenated().to_string())
.bind(tag)
+ .map(Self::query_row)
.fetch_one(&self.pool)
+ .await;
+
+ match res {
+ Err(sqlx::Error::RowNotFound) => Ok(None),
+ Err(e) => Err(eyre!("an error occured: {}", e)),
+ Ok(record) => Ok(Some(record)),
+ }
+ }
+
+ async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+ self.idx(host, tag, 0).await
+ }
+
+ async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
+ let last = self.last(host, tag).await?;
+
+ if let Some(last) = last {
+ return Ok(last.idx + 1);
+ }
+
+ return Ok(0);
+ }
+
+ async fn next(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ limit: u64,
+ ) -> Result<Vec<Record<EncryptedData>>> {
+ let res =
+ sqlx::query("select * from store where idx >= ?1 and host = ?2 and tag = ?3 limit ?4")
+ .bind(idx as i64)
+ .bind(host.0.as_hyphenated().to_string())
+ .bind(tag)
+ .bind(limit as i64)
+ .map(Self::query_row)
+ .fetch_all(&self.pool)
.await?;
- Ok(res.0 as u64)
+ Ok(res)
}
- async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
- let res = sqlx::query("select * from records where parent = ?1")
- .bind(record.id.0.as_simple().to_string())
+ async fn idx(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ ) -> Result<Option<Record<EncryptedData>>> {
+ let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3")
+ .bind(idx as i64)
+ .bind(host.0.as_hyphenated().to_string())
+ .bind(tag)
.map(Self::query_row)
.fetch_one(&self.pool)
.await;
@@ -156,58 +203,36 @@ impl Store for SqliteStore {
}
}
- async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
- let res = sqlx::query(
- "select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
- )
- .bind(host.0.as_simple().to_string())
- .bind(tag)
- .map(Self::query_row)
- .fetch_optional(&self.pool)
- .await?;
+ async fn status(&self) -> Result<RecordStatus> {
+ let mut status = RecordStatus::new();
- Ok(res)
- }
+ let res: Result<Vec<(String, String, i64)>, sqlx::Error> =
+ sqlx::query_as("select host, tag, max(idx) from store group by host, tag")
+ .fetch_all(&self.pool)
+ .await;
- async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
- let res = sqlx::query(
- "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
- )
- .bind(tag)
- .bind(host.0.as_simple().to_string())
- .map(Self::query_row)
- .fetch_optional(&self.pool)
- .await?;
+ let res = match res {
+ Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)),
+ Ok(v) => v,
+ };
- Ok(res)
- }
+ for i in res {
+ let host = HostId(
+ Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"),
+ );
- async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
- let res = sqlx::query(
- "select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;",
- )
- .bind(tag)
- .map(Self::query_row)
- .fetch_all(&self.pool)
- .await?;
+ status.set_raw(host, i.1, i.2 as u64);
+ }
- Ok(res)
+ Ok(status)
}
- async fn tail_records(&self) -> Result<RecordIndex> {
- let res = sqlx::query(
- "select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
- )
- .map(|row: SqliteRow| {
- let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
- let tag: String= row.get("tag");
- let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
-
- (HostId(host), tag, RecordId(id))
- })
- .fetch(&self.pool)
- .try_collect()
- .await?;
+ async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
+ let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc")
+ .bind(tag)
+ .map(Self::query_row)
+ .fetch_all(&self.pool)
+ .await?;
Ok(res)
}
@@ -215,7 +240,7 @@ impl Store for SqliteStore {
#[cfg(test)]
mod tests {
- use atuin_common::record::{EncryptedData, HostId, Record};
+ use atuin_common::record::{EncryptedData, Host, HostId, Record};
use crate::record::{encryption::PASETO_V4, store::Store};
@@ -223,13 +248,14 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
- .host(HostId(atuin_common::utils::uuid_v7()))
+ .host(Host::new(HostId(atuin_common::utils::uuid_v7())))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
data: "1234".into(),
content_encryption_key: "1234".into(),
})
+ .idx(0)
.build()
}
@@ -264,13 +290,49 @@ mod tests {
}
#[tokio::test]
+ async fn last() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+ let record = test_record();
+ db.push(&record).await.unwrap();
+
+ let last = db
+ .last(record.host.id, record.tag.as_str())
+ .await
+ .expect("failed to get store len");
+
+ assert_eq!(
+ last.unwrap().id,
+ record.id,
+ "expected to get back the same record that was inserted"
+ );
+ }
+
+ #[tokio::test]
+ async fn first() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+ let record = test_record();
+ db.push(&record).await.unwrap();
+
+ let first = db
+ .first(record.host.id, record.tag.as_str())
+ .await
+ .expect("failed to get store len");
+
+ assert_eq!(
+ first.unwrap().id,
+ record.id,
+ "expected to get back the same record that was inserted"
+ );
+ }
+
+ #[tokio::test]
async fn len() {
let db = SqliteStore::new(":memory:").await.unwrap();
let record = test_record();
db.push(&record).await.unwrap();
let len = db
- .len(record.host, record.tag.as_str())
+ .len(record.host.id, record.tag.as_str())
.await
.expect("failed to get store len");
@@ -290,8 +352,8 @@ mod tests {
db.push(&first).await.unwrap();
db.push(&second).await.unwrap();
- let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
- let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
+ let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap();
+ let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap();
assert_eq!(first_len, 1, "expected length of 1 after insert");
assert_eq!(second_len, 1, "expected length of 1 after insert");
@@ -305,14 +367,12 @@ mod tests {
db.push(&tail).await.expect("failed to push record");
for _ in 1..100 {
- tail = tail
- .new_child(vec![1, 2, 3, 4])
- .encrypt::<PASETO_V4>(&[0; 32]);
+ tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]);
db.push(&tail).await.unwrap();
}
assert_eq!(
- db.len(tail.host, tail.tag.as_str()).await.unwrap(),
+ db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
100,
"failed to insert 100 records"
);
@@ -328,50 +388,16 @@ mod tests {
records.push(tail.clone());
for _ in 1..10000 {
- tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
+ tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
records.push(tail.clone());
}
db.push_batch(records.iter()).await.unwrap();
assert_eq!(
- db.len(tail.host, tail.tag.as_str()).await.unwrap(),
+ db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
10000,
"failed to insert 10k records"
);
}
-
- #[tokio::test]
- async fn test_chain() {
- let db = SqliteStore::new(":memory:").await.unwrap();
-
- let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(1000);
-
- let mut tail = test_record();
- records.push(tail.clone());
-
- for _ in 1..1000 {
- tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
- records.push(tail.clone());
- }
-
- db.push_batch(records.iter()).await.unwrap();
-
- let mut record = db
- .head(tail.host, tail.tag.as_str())
- .await
- .expect("in memory sqlite should not fail")
- .expect("entry exists");
-
- let mut count = 1;
-
- while let Some(next) = db.next(&record).await.unwrap() {
- assert_eq!(record.id, next.clone().parent.unwrap());
- record = next;
-
- count += 1;
- }
-
- assert_eq!(count, 1000);
- }
}
diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs
index 45d554ef..a5c156d6 100644
--- a/atuin-client/src/record/store.rs
+++ b/atuin-client/src/record/store.rs
@@ -1,8 +1,7 @@
use async_trait::async_trait;
use eyre::Result;
-use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
-
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
/// A record store stores records
/// In more detail - we tend to need to process this into _another_ format to actually query it.
/// As is, the record store is intended as the source of truth for arbitratry data, which could
@@ -23,19 +22,30 @@ pub trait Store {
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
- /// Get the record that follows this record
- async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
+ async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
- /// Get the first record for a given host and tag
- async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ /// Get the next `limit` records, after and including the given index
+ async fn next(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ limit: u64,
+ ) -> Result<Vec<Record<EncryptedData>>>;
- /// Get the last record for a given host and tag
- async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ /// Get the first record for a given host and tag
+ async fn idx(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ ) -> Result<Option<Record<EncryptedData>>>;
- // Get the last record for all hosts for a given tag, useful for the read path of apps.
- async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
+ async fn status(&self) -> Result<RecordStatus>;
- // Get the latest host/tag/record tuple for every set in the store. useful for building an
- // index
- async fn tail_records(&self) -> Result<RecordIndex>;
+ /// Get every start record for a given tag, regardless of host.
+ /// Useful when actually operating on synchronized data, and will often have conflict
+ /// resolution applied.
+ async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
}
diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs
index 56be0638..2694e0ff 100644
--- a/atuin-client/src/record/sync.rs
+++ b/atuin-client/src/record/sync.rs
@@ -1,27 +1,51 @@
// do a sync :O
+use std::cmp::Ordering;
+
use eyre::Result;
+use thiserror::Error;
use super::store::Store;
use crate::{api_client::Client, settings::Settings};
-use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
+use atuin_common::record::{Diff, HostId, RecordIdx, RecordStatus};
+
+#[derive(Error, Debug)]
+pub enum SyncError {
+ #[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
+ LocalAheadOtherHost,
+
+ #[error("an issue with the local database occured")]
+ LocalStoreError,
+
+ #[error("something has gone wrong with the sync logic: {msg:?}")]
+ SyncLogicError { msg: String },
+
+ #[error("a request to the sync server failed")]
+ RemoteRequestError,
+}
#[derive(Debug, Eq, PartialEq)]
pub enum Operation {
- // Either upload or download until the tail matches the below
+ // Either upload or download until the states matches the below
Upload {
- tail: RecordId,
+ local: RecordIdx,
+ remote: Option<RecordIdx>,
host: HostId,
tag: String,
},
Download {
- tail: RecordId,
+ local: Option<RecordIdx>,
+ remote: RecordIdx,
+ host: HostId,
+ tag: String,
+ },
+ Noop {
host: HostId,
tag: String,
},
}
-pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
+pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec<Diff>, RecordStatus)> {
let client = Client::new(
&settings.sync_address,
&settings.session_token,
@@ -29,8 +53,8 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di
settings.network_timeout,
)?;
- let local_index = store.tail_records().await?;
- let remote_index = client.record_index().await?;
+ let local_index = store.status().await?;
+ let remote_index = client.record_status().await?;
let diff = local_index.diff(&remote_index);
@@ -41,39 +65,57 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di
// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
// In theory this could be done as a part of the diffing stage, but it's easier to reason
// about and test this way
-pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
+pub async fn operations(
+ diffs: Vec<Diff>,
+ _store: &impl Store,
+) -> Result<Vec<Operation>, SyncError> {
let mut operations = Vec::with_capacity(diffs.len());
for diff in diffs {
- // First, try to fetch the tail
- // If it exists locally, then that means we need to update the remote
- // host until it has the same tail. Ie, upload.
- // If it does not exist locally, that means remote is ahead of us.
- // Therefore, we need to download until our local tail matches
- let record = store.get(diff.tail).await;
-
- let op = if record.is_ok() {
- // if local has the ID, then we should find the actual tail of this
- // store, so we know what we need to update the remote to.
- let tail = store
- .tail(diff.host, diff.tag.as_str())
- .await?
- .expect("failed to fetch last record, expected tag/host to exist");
-
- // TODO(ellie) update the diffing so that it stores the context of the current tail
- // that way, we can determine how much we need to upload.
- // For now just keep uploading until tails match
+ let op = match (diff.local, diff.remote) {
+ // We both have it! Could be either. Compare.
+ (Some(local), Some(remote)) => match local.cmp(&remote) {
+ Ordering::Equal => Operation::Noop {
+ host: diff.host,
+ tag: diff.tag,
+ },
+ Ordering::Greater => Operation::Upload {
+ local,
+ remote: Some(remote),
+ host: diff.host,
+ tag: diff.tag,
+ },
+ Ordering::Less => Operation::Download {
+ local: Some(local),
+ remote,
+ host: diff.host,
+ tag: diff.tag,
+ },
+ },
- Operation::Upload {
- tail: tail.id,
+ // Remote has it, we don't. Gotta be download
+ (None, Some(remote)) => Operation::Download {
+ local: None,
+ remote,
host: diff.host,
tag: diff.tag,
- }
- } else {
- Operation::Download {
- tail: diff.tail,
+ },
+
+ // We have it, remote doesn't. Gotta be upload.
+ (Some(local), None) => Operation::Upload {
+ local,
+ remote: None,
host: diff.host,
tag: diff.tag,
+ },
+
+ // something is pretty fucked.
+ (None, None) => {
+ return Err(SyncError::SyncLogicError {
+ msg: String::from(
+ "diff has nothing for local or remote - (host, tag) does not exist",
+ ),
+ })
}
};
@@ -86,149 +128,130 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
// with the same properties
operations.sort_by_key(|op| match op {
- Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
- Operation::Download { tail, host, .. } => ("download", *host, *tail),
+ Operation::Noop { host, tag } => (0, *host, tag.clone()),
+
+ Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
+
+ Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
});
Ok(operations)
}
async fn sync_upload(
- store: &mut impl Store,
- remote_index: &RecordIndex,
+ store: &impl Store,
client: &Client<'_>,
- op: (HostId, String, RecordId),
-) -> Result<i64> {
+ host: HostId,
+ tag: String,
+ local: RecordIdx,
+ remote: Option<RecordIdx>,
+) -> Result<i64, SyncError> {
+ let remote = remote.unwrap_or(0);
+ let expected = local - remote;
let upload_page_size = 100;
- let mut total = 0;
-
- // so. we have an upload operation, with the tail representing the state
- // we want to get the remote to
- let current_tail = remote_index.get(op.0, op.1.clone());
+ let mut progress = 0;
println!(
- "Syncing local {:?}/{}/{:?}, remote has {:?}",
- op.0, op.1, op.2, current_tail
+ "Uploading {} records to {}/{}",
+ expected,
+ host.0.as_simple(),
+ tag
);
- let start = if let Some(current_tail) = current_tail {
- current_tail
- } else {
- store
- .head(op.0, op.1.as_str())
+ // preload with the first entry if remote does not know of this store
+ loop {
+ let page = store
+ .next(host, tag.as_str(), remote + progress, upload_page_size)
.await
- .expect("failed to fetch host/tag head")
- .expect("host/tag not in current index")
- .id
- };
+ .map_err(|e| {
+ error!("failed to read upload page: {e:?}");
- debug!("starting push to remote from: {:?}", start);
+ SyncError::LocalStoreError
+ })?;
- // we have the start point for sync. it is either the head of the store if
- // the remote has no data for it, or the tail that the remote has
- // we need to iterate from the remote tail, and keep going until
- // remote tail = current local tail
+ client.post_records(&page).await.map_err(|e| {
+ error!("failed to post records: {e:?}");
- let mut record = if current_tail.is_some() {
- let r = store.get(start).await.unwrap();
- store.next(&r).await?
- } else {
- Some(store.get(start).await.unwrap())
- };
+ SyncError::RemoteRequestError
+ })?;
- let mut buf = Vec::with_capacity(upload_page_size);
-
- while let Some(r) = record {
- if buf.len() < upload_page_size {
- buf.push(r.clone());
- } else {
- client.post_records(&buf).await?;
+ println!(
+ "uploaded {} to remote, progress {}/{}",
+ page.len(),
+ progress,
+ expected
+ );
+ progress += page.len() as u64;
- // can we reset what we have? len = 0 but keep capacity
- buf = Vec::with_capacity(upload_page_size);
+ if progress >= expected {
+ break;
}
- record = store.next(&r).await?;
-
- total += 1;
}
- if !buf.is_empty() {
- client.post_records(&buf).await?;
- }
-
- Ok(total)
+ Ok(progress as i64)
}
async fn sync_download(
- store: &mut impl Store,
- remote_index: &RecordIndex,
+ store: &impl Store,
client: &Client<'_>,
- op: (HostId, String, RecordId),
-) -> Result<i64> {
- // TODO(ellie): implement variable page sizing like on history sync
- let download_page_size = 1000;
+ host: HostId,
+ tag: String,
+ local: Option<RecordIdx>,
+ remote: RecordIdx,
+) -> Result<i64, SyncError> {
+ let local = local.unwrap_or(0);
+ let expected = remote - local;
+ let download_page_size = 100;
+ let mut progress = 0;
- let mut total = 0;
+ println!(
+ "Downloading {} records from {}/{}",
+ expected,
+ host.0.as_simple(),
+ tag
+ );
- // We know that the remote is ahead of us, so let's keep downloading until both
- // 1) The remote stops returning full pages
- // 2) The tail equals what we expect
- //
- // If (1) occurs without (2), then something is wrong with our index calculation
- // and we should bail.
- let remote_tail = remote_index
- .get(op.0, op.1.clone())
- .expect("remote index does not contain expected tail during download");
- let local_tail = store.tail(op.0, op.1.as_str()).await?;
- //
- // We expect that the operations diff will represent the desired state
- // In this case, that contains the remote tail.
- assert_eq!(remote_tail, op.2);
+ // preload with the first entry if remote does not know of this store
+ loop {
+ let page = client
+ .next_records(host, tag.clone(), local + progress, download_page_size)
+ .await
+ .map_err(|_| SyncError::RemoteRequestError)?;
- println!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2);
+ store
+ .push_batch(page.iter())
+ .await
+ .map_err(|_| SyncError::LocalStoreError)?;
- let mut records = client
- .next_records(
- op.0,
- op.1.clone(),
- local_tail.map(|r| r.id),
- download_page_size,
- )
- .await?;
+ println!(
+ "downloaded {} records from remote, progress {}/{}",
+ page.len(),
+ progress,
+ expected
+ );
- while !records.is_empty() {
- total += std::cmp::min(download_page_size, records.len() as u64);
- store.push_batch(records.iter()).await?;
+ progress += page.len() as u64;
- if records.last().unwrap().id == remote_tail {
+ if progress >= expected {
break;
}
-
- records = client
- .next_records(
- op.0,
- op.1.clone(),
- records.last().map(|r| r.id),
- download_page_size,
- )
- .await?;
}
- Ok(total as i64)
+ Ok(progress as i64)
}
pub async fn sync_remote(
operations: Vec<Operation>,
- remote_index: &RecordIndex,
- local_store: &mut impl Store,
+ local_store: &impl Store,
settings: &Settings,
-) -> Result<(i64, i64)> {
+) -> Result<(i64, i64), SyncError> {
let client = Client::new(
&settings.sync_address,
&settings.session_token,
settings.network_connect_timeout,
settings.network_timeout,
- )?;
+ )
+ .expect("failed to create client");
let mut uploaded = 0;
let mut downloaded = 0;
@@ -236,14 +259,23 @@ pub async fn sync_remote(
// this can totally run in parallel, but lets get it working first
for i in operations {
match i {
- Operation::Upload { tail, host, tag } => {
- uploaded +=
- sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
- }
- Operation::Download { tail, host, tag } => {
- downloaded +=
- sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
+ Operation::Upload {
+ host,
+ tag,
+ local,
+ remote,
+ } => uploaded += sync_upload(local_store, &client, host, tag, local, remote).await?,
+
+ Operation::Download {
+ host,
+ tag,
+ local,
+ remote,
+ } => {
+ downloaded += sync_download(local_store, &client, host, tag, local, remote).await?
}
+
+ Operation::Noop { .. } => continue,
}
}
@@ -264,13 +296,16 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
- .host(HostId(atuin_common::utils::uuid_v7()))
+ .host(atuin_common::record::Host::new(HostId(
+ atuin_common::utils::uuid_v7(),
+ )))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
data: String::new(),
content_encryption_key: String::new(),
})
+ .idx(0)
.build()
}
@@ -296,8 +331,8 @@ mod tests {
remote_store.push(&i).await.unwrap();
}
- let local_index = local_store.tail_records().await.unwrap();
- let remote_index = remote_store.tail_records().await.unwrap();
+ let local_index = local_store.status().await.unwrap();
+ let remote_index = remote_store.status().await.unwrap();
let diff = local_index.diff(&remote_index);
@@ -320,9 +355,10 @@ mod tests {
assert_eq!(
operations[0],
Operation::Upload {
- host: record.host,
+ host: record.host.id,
tag: record.tag,
- tail: record.id
+ local: record.idx,
+ remote: None,
}
);
}
@@ -333,12 +369,14 @@ mod tests {
// another. One upload, one download
let shared_record = test_record();
-
let remote_ahead = test_record();
+
let local_ahead = shared_record
- .new_child(vec![1, 2, 3])
+ .append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
+ assert_eq!(local_ahead.idx, 1);
+
let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store
let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store
@@ -350,15 +388,19 @@ mod tests {
assert_eq!(
operations,
vec![
- Operation::Download {
- tail: remote_ahead.id,
- host: remote_ahead.host,
- tag: remote_ahead.tag,
- },
+ // Or in otherwords, local is ahead by one
Operation::Upload {
- tail: local_ahead.id,
- host: local_ahead.host,
+ host: local_ahead.host.id,
tag: local_ahead.tag,
+ local: 1,
+ remote: Some(0),
+ },
+ // Or in other words, remote knows of a record in an entirely new store (tag)
+ Operation::Download {
+ host: remote_ahead.host.id,
+ tag: remote_ahead.tag,
+ local: None,
+ remote: 0,
},
]
);
@@ -371,66 +413,160 @@ mod tests {
// One known only by remote
let shared_record = test_record();
+ let local_only = test_record();
+
+ let local_only_20 = test_record();
+ let local_only_21 = local_only_20
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let local_only_22 = local_only_21
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let local_only_23 = local_only_22
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
- let remote_known = test_record();
- let local_known = test_record();
+ let remote_only = test_record();
+
+ let remote_only_20 = test_record();
+ let remote_only_21 = remote_only_20
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let remote_only_22 = remote_only_21
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let remote_only_23 = remote_only_22
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let remote_only_24 = remote_only_23
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
let second_shared = test_record();
let second_shared_remote_ahead = second_shared
- .new_child(vec![1, 2, 3])
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let second_shared_remote_ahead2 = second_shared_remote_ahead
+ .append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
- let local_ahead = shared_record
- .new_child(vec![1, 2, 3])
+ let third_shared = test_record();
+ let third_shared_local_ahead = third_shared
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let third_shared_local_ahead2 = third_shared_local_ahead
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+
+ let fourth_shared = test_record();
+ let fourth_shared_remote_ahead = fourth_shared
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead
+ .append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local = vec![
shared_record.clone(),
second_shared.clone(),
- local_known.clone(),
- local_ahead.clone(),
+ third_shared.clone(),
+ fourth_shared.clone(),
+ fourth_shared_remote_ahead.clone(),
+ // single store, only local has it
+ local_only.clone(),
+ // bigger store, also only known by local
+ local_only_20.clone(),
+ local_only_21.clone(),
+ local_only_22.clone(),
+ local_only_23.clone(),
+ // another shared store, but local is ahead on this one
+ third_shared_local_ahead.clone(),
+ third_shared_local_ahead2.clone(),
];
let remote = vec![
+ remote_only.clone(),
+ remote_only_20.clone(),
+ remote_only_21.clone(),
+ remote_only_22.clone(),
+ remote_only_23.clone(),
+ remote_only_24.clone(),
shared_record.clone(),
second_shared.clone(),
+ third_shared.clone(),
second_shared_remote_ahead.clone(),
- remote_known.clone(),
+ second_shared_remote_ahead2.clone(),
+ fourth_shared.clone(),
+ fourth_shared_remote_ahead.clone(),
+ fourth_shared_remote_ahead2.clone(),
]; // remote knows about the already-synced, and one new record in a new store
let (store, diff) = build_test_diff(local, remote).await;
let operations = sync::operations(diff, &store).await.unwrap();
- assert_eq!(operations.len(), 4);
+ assert_eq!(operations.len(), 7);
let mut result_ops = vec![
+ // We started with a shared record, but the remote knows of two newer records in the
+ // same store
+ Operation::Download {
+ local: Some(0),
+ remote: 2,
+ host: second_shared_remote_ahead.host.id,
+ tag: second_shared_remote_ahead.tag,
+ },
+ // We have a shared record, local knows of the first two but not the last
+ Operation::Download {
+ local: Some(1),
+ remote: 2,
+ host: fourth_shared_remote_ahead2.host.id,
+ tag: fourth_shared_remote_ahead2.tag,
+ },
+ // Remote knows of a store with a single record that local does not have
Operation::Download {
- tail: remote_known.id,
- host: remote_known.host,
- tag: remote_known.tag,
+ local: None,
+ remote: 0,
+ host: remote_only.host.id,
+ tag: remote_only.tag,
},
+ // Remote knows of a store with a bunch of records that local does not have
Operation::Download {
- tail: second_shared_remote_ahead.id,
- host: second_shared.host,
- tag: second_shared.tag,
+ local: None,
+ remote: 4,
+ host: remote_only_20.host.id,
+ tag: remote_only_20.tag,
},
+ // Local knows of a record in a store that remote does not have
Operation::Upload {
- tail: local_ahead.id,
- host: local_ahead.host,
- tag: local_ahead.tag,
+ local: 0,
+ remote: None,
+ host: local_only.host.id,
+ tag: local_only.tag,
},
+ // Local knows of 4 records in a store that remote does not have
Operation::Upload {
- tail: local_known.id,
- host: local_known.host,
- tag: local_known.tag,
+ local: 3,
+ remote: None,
+ host: local_only_20.host.id,
+ tag: local_only_20.tag,
+ },
+ // Local knows of 2 more records in a shared store that remote only has one of
+ Operation::Upload {
+ local: 2,
+ remote: Some(0),
+ host: third_shared.host.id,
+ tag: third_shared.tag,
},
];
result_ops.sort_by_key(|op| match op {
- Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
- Operation::Download { tail, host, .. } => ("download", *host, *tail),
+ Operation::Noop { host, tag } => (0, *host, tag.clone()),
+
+ Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
+
+ Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
});
- assert_eq!(operations, result_ops);
+ assert_eq!(result_ops, operations);
}
}