aboutsummaryrefslogtreecommitdiffstats
path: root/atuin-client/src/record
diff options
context:
space:
mode:
Diffstat (limited to 'atuin-client/src/record')
-rw-r--r--atuin-client/src/record/encryption.rs29
-rw-r--r--atuin-client/src/record/sqlite_store.rs250
-rw-r--r--atuin-client/src/record/store.rs36
-rw-r--r--atuin-client/src/record/sync.rs486
4 files changed, 489 insertions, 312 deletions
diff --git a/atuin-client/src/record/encryption.rs b/atuin-client/src/record/encryption.rs
index 3074a9c2..c2cdaa6a 100644
--- a/atuin-client/src/record/encryption.rs
+++ b/atuin-client/src/record/encryption.rs
@@ -1,5 +1,5 @@
use atuin_common::record::{
- AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId,
+ AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx,
};
use base64::{engine::general_purpose, Engine};
use eyre::{ensure, Context, Result};
@@ -170,10 +170,10 @@ struct AtuinFooter {
#[derive(Debug, Copy, Clone, Serialize)]
struct Assertions<'a> {
id: &'a RecordId,
+ idx: &'a RecordIdx,
version: &'a str,
tag: &'a str,
host: &'a HostId,
- parent: Option<&'a RecordId>,
}
impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
@@ -183,7 +183,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
version: ad.version,
tag: ad.tag,
host: ad.host,
- parent: ad.parent,
+ idx: ad.idx,
}
}
}
@@ -196,7 +196,10 @@ impl Assertions<'_> {
#[cfg(test)]
mod tests {
- use atuin_common::{record::Record, utils::uuid_v7};
+ use atuin_common::{
+ record::{Host, Record},
+ utils::uuid_v7,
+ };
use super::*;
@@ -209,7 +212,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -228,7 +231,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -252,7 +255,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -270,7 +273,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -294,7 +297,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
- parent: None,
+ idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -323,9 +326,10 @@ mod tests {
.id(RecordId(uuid_v7()))
.version("v0".to_owned())
.tag("kv".to_owned())
- .host(HostId(uuid_v7()))
+ .host(Host::new(HostId(uuid_v7())))
.timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4]))
+ .idx(0)
.build();
let encrypted = record.encrypt::<PASETO_V4>(&key);
@@ -345,15 +349,16 @@ mod tests {
.id(RecordId(uuid_v7()))
.version("v0".to_owned())
.tag("kv".to_owned())
- .host(HostId(uuid_v7()))
+ .host(Host::new(HostId(uuid_v7())))
.timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4]))
+ .idx(0)
.build();
let encrypted = record.encrypt::<PASETO_V4>(&key);
let mut enc1 = encrypted.clone();
- enc1.host = HostId(uuid_v7());
+ enc1.host = Host::new(HostId(uuid_v7()));
let _ = enc1
.decrypt::<PASETO_V4>(&key)
.expect_err("tampering with the host should result in auth failure");
diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs
index db709f20..8112aa96 100644
--- a/atuin-client/src/record/sqlite_store.rs
+++ b/atuin-client/src/record/sqlite_store.rs
@@ -8,17 +8,20 @@ use std::str::FromStr;
use async_trait::async_trait;
use eyre::{eyre, Result};
use fs_err as fs;
-use futures::TryStreamExt;
+
use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
Row,
};
-use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
+use atuin_common::record::{
+ EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus,
+};
use uuid::Uuid;
use super::store::Store;
+#[derive(Debug, Clone)]
pub struct SqliteStore {
pool: SqlitePool,
}
@@ -38,6 +41,7 @@ impl SqliteStore {
let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
.journal_mode(SqliteJournalMode::Wal)
+ .foreign_keys(true)
.create_if_missing(true);
let pool = SqlitePoolOptions::new().connect_with(opts).await?;
@@ -61,14 +65,14 @@ impl SqliteStore {
) -> Result<()> {
// In sqlite, we are "limited" to i64. But that is still fine, until 2262.
sqlx::query(
- "insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
+ "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek)
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)
- .bind(r.id.0.as_simple().to_string())
- .bind(r.host.0.as_simple().to_string())
+ .bind(r.id.0.as_hyphenated().to_string())
+ .bind(r.idx as i64)
+ .bind(r.host.id.0.as_hyphenated().to_string())
.bind(r.tag.as_str())
.bind(r.timestamp as i64)
- .bind(r.parent.map(|p| p.0.as_simple().to_string()))
.bind(r.version.as_str())
.bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str())
@@ -79,20 +83,17 @@ impl SqliteStore {
}
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
+ let idx: i64 = row.get("idx");
let timestamp: i64 = row.get("timestamp");
// tbh at this point things are pretty fucked so just panic
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
- let parent: Option<&str> = row.get("parent");
-
- let parent = parent
- .map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
Record {
id: RecordId(id),
- host: HostId(host),
- parent: parent.map(RecordId),
+ idx: idx as u64,
+ host: Host::new(HostId(host)),
timestamp: timestamp as u64,
tag: row.get("tag"),
version: row.get("version"),
@@ -122,8 +123,8 @@ impl Store for SqliteStore {
}
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
- let res = sqlx::query("select * from records where id = ?1")
- .bind(id.0.as_simple().to_string())
+ let res = sqlx::query("select * from store where store.id = ?1")
+ .bind(id.0.as_hyphenated().to_string())
.map(Self::query_row)
.fetch_one(&self.pool)
.await?;
@@ -131,20 +132,66 @@ impl Store for SqliteStore {
Ok(res)
}
- async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
- let res: (i64,) =
- sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
- .bind(host.0.as_simple().to_string())
+ async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+ let res =
+ sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1")
+ .bind(host.0.as_hyphenated().to_string())
.bind(tag)
+ .map(Self::query_row)
.fetch_one(&self.pool)
+ .await;
+
+ match res {
+ Err(sqlx::Error::RowNotFound) => Ok(None),
+ Err(e) => Err(eyre!("an error occured: {}", e)),
+ Ok(record) => Ok(Some(record)),
+ }
+ }
+
+ async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+ self.idx(host, tag, 0).await
+ }
+
+ async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
+ let last = self.last(host, tag).await?;
+
+ if let Some(last) = last {
+ return Ok(last.idx + 1);
+ }
+
+ return Ok(0);
+ }
+
+ async fn next(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ limit: u64,
+ ) -> Result<Vec<Record<EncryptedData>>> {
+ let res =
+ sqlx::query("select * from store where idx >= ?1 and host = ?2 and tag = ?3 limit ?4")
+ .bind(idx as i64)
+ .bind(host.0.as_hyphenated().to_string())
+ .bind(tag)
+ .bind(limit as i64)
+ .map(Self::query_row)
+ .fetch_all(&self.pool)
.await?;
- Ok(res.0 as u64)
+ Ok(res)
}
- async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
- let res = sqlx::query("select * from records where parent = ?1")
- .bind(record.id.0.as_simple().to_string())
+ async fn idx(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ ) -> Result<Option<Record<EncryptedData>>> {
+ let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3")
+ .bind(idx as i64)
+ .bind(host.0.as_hyphenated().to_string())
+ .bind(tag)
.map(Self::query_row)
.fetch_one(&self.pool)
.await;
@@ -156,58 +203,36 @@ impl Store for SqliteStore {
}
}
- async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
- let res = sqlx::query(
- "select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
- )
- .bind(host.0.as_simple().to_string())
- .bind(tag)
- .map(Self::query_row)
- .fetch_optional(&self.pool)
- .await?;
+ async fn status(&self) -> Result<RecordStatus> {
+ let mut status = RecordStatus::new();
- Ok(res)
- }
+ let res: Result<Vec<(String, String, i64)>, sqlx::Error> =
+ sqlx::query_as("select host, tag, max(idx) from store group by host, tag")
+ .fetch_all(&self.pool)
+ .await;
- async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
- let res = sqlx::query(
- "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
- )
- .bind(tag)
- .bind(host.0.as_simple().to_string())
- .map(Self::query_row)
- .fetch_optional(&self.pool)
- .await?;
+ let res = match res {
+ Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)),
+ Ok(v) => v,
+ };
- Ok(res)
- }
+ for i in res {
+ let host = HostId(
+ Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"),
+ );
- async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
- let res = sqlx::query(
- "select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;",
- )
- .bind(tag)
- .map(Self::query_row)
- .fetch_all(&self.pool)
- .await?;
+ status.set_raw(host, i.1, i.2 as u64);
+ }
- Ok(res)
+ Ok(status)
}
- async fn tail_records(&self) -> Result<RecordIndex> {
- let res = sqlx::query(
- "select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
- )
- .map(|row: SqliteRow| {
- let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
- let tag: String= row.get("tag");
- let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
-
- (HostId(host), tag, RecordId(id))
- })
- .fetch(&self.pool)
- .try_collect()
- .await?;
+ async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
+ let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc")
+ .bind(tag)
+ .map(Self::query_row)
+ .fetch_all(&self.pool)
+ .await?;
Ok(res)
}
@@ -215,7 +240,7 @@ impl Store for SqliteStore {
#[cfg(test)]
mod tests {
- use atuin_common::record::{EncryptedData, HostId, Record};
+ use atuin_common::record::{EncryptedData, Host, HostId, Record};
use crate::record::{encryption::PASETO_V4, store::Store};
@@ -223,13 +248,14 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
- .host(HostId(atuin_common::utils::uuid_v7()))
+ .host(Host::new(HostId(atuin_common::utils::uuid_v7())))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
data: "1234".into(),
content_encryption_key: "1234".into(),
})
+ .idx(0)
.build()
}
@@ -264,13 +290,49 @@ mod tests {
}
#[tokio::test]
+ async fn last() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+ let record = test_record();
+ db.push(&record).await.unwrap();
+
+ let last = db
+ .last(record.host.id, record.tag.as_str())
+ .await
+ .expect("failed to get store len");
+
+ assert_eq!(
+ last.unwrap().id,
+ record.id,
+ "expected to get back the same record that was inserted"
+ );
+ }
+
+ #[tokio::test]
+ async fn first() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+ let record = test_record();
+ db.push(&record).await.unwrap();
+
+ let first = db
+ .first(record.host.id, record.tag.as_str())
+ .await
+ .expect("failed to get store len");
+
+ assert_eq!(
+ first.unwrap().id,
+ record.id,
+ "expected to get back the same record that was inserted"
+ );
+ }
+
+ #[tokio::test]
async fn len() {
let db = SqliteStore::new(":memory:").await.unwrap();
let record = test_record();
db.push(&record).await.unwrap();
let len = db
- .len(record.host, record.tag.as_str())
+ .len(record.host.id, record.tag.as_str())
.await
.expect("failed to get store len");
@@ -290,8 +352,8 @@ mod tests {
db.push(&first).await.unwrap();
db.push(&second).await.unwrap();
- let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
- let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
+ let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap();
+ let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap();
assert_eq!(first_len, 1, "expected length of 1 after insert");
assert_eq!(second_len, 1, "expected length of 1 after insert");
@@ -305,14 +367,12 @@ mod tests {
db.push(&tail).await.expect("failed to push record");
for _ in 1..100 {
- tail = tail
- .new_child(vec![1, 2, 3, 4])
- .encrypt::<PASETO_V4>(&[0; 32]);
+ tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]);
db.push(&tail).await.unwrap();
}
assert_eq!(
- db.len(tail.host, tail.tag.as_str()).await.unwrap(),
+ db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
100,
"failed to insert 100 records"
);
@@ -328,50 +388,16 @@ mod tests {
records.push(tail.clone());
for _ in 1..10000 {
- tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
+ tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
records.push(tail.clone());
}
db.push_batch(records.iter()).await.unwrap();
assert_eq!(
- db.len(tail.host, tail.tag.as_str()).await.unwrap(),
+ db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
10000,
"failed to insert 10k records"
);
}
-
- #[tokio::test]
- async fn test_chain() {
- let db = SqliteStore::new(":memory:").await.unwrap();
-
- let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(1000);
-
- let mut tail = test_record();
- records.push(tail.clone());
-
- for _ in 1..1000 {
- tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
- records.push(tail.clone());
- }
-
- db.push_batch(records.iter()).await.unwrap();
-
- let mut record = db
- .head(tail.host, tail.tag.as_str())
- .await
- .expect("in memory sqlite should not fail")
- .expect("entry exists");
-
- let mut count = 1;
-
- while let Some(next) = db.next(&record).await.unwrap() {
- assert_eq!(record.id, next.clone().parent.unwrap());
- record = next;
-
- count += 1;
- }
-
- assert_eq!(count, 1000);
- }
}
diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs
index 45d554ef..a5c156d6 100644
--- a/atuin-client/src/record/store.rs
+++ b/atuin-client/src/record/store.rs
@@ -1,8 +1,7 @@
use async_trait::async_trait;
use eyre::Result;
-use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
-
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
/// A record store stores records
/// In more detail - we tend to need to process this into _another_ format to actually query it.
/// As is, the record store is intended as the source of truth for arbitratry data, which could
@@ -23,19 +22,30 @@ pub trait Store {
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
- /// Get the record that follows this record
- async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
+ async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
- /// Get the first record for a given host and tag
- async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ /// Get the next `limit` records, after and including the given index
+ async fn next(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ limit: u64,
+ ) -> Result<Vec<Record<EncryptedData>>>;
- /// Get the last record for a given host and tag
- async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ /// Get the first record for a given host and tag
+ async fn idx(
+ &self,
+ host: HostId,
+ tag: &str,
+ idx: RecordIdx,
+ ) -> Result<Option<Record<EncryptedData>>>;
- // Get the last record for all hosts for a given tag, useful for the read path of apps.
- async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
+ async fn status(&self) -> Result<RecordStatus>;
- // Get the latest host/tag/record tuple for every set in the store. useful for building an
- // index
- async fn tail_records(&self) -> Result<RecordIndex>;
+ /// Get every start record for a given tag, regardless of host.
+ /// Useful when actually operating on synchronized data, and will often have conflict
+ /// resolution applied.
+ async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
}
diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs
index 56be0638..2694e0ff 100644
--- a/atuin-client/src/record/sync.rs
+++ b/atuin-client/src/record/sync.rs
@@ -1,27 +1,51 @@
// do a sync :O
+use std::cmp::Ordering;
+
use eyre::Result;
+use thiserror::Error;
use super::store::Store;
use crate::{api_client::Client, settings::Settings};
-use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
+use atuin_common::record::{Diff, HostId, RecordIdx, RecordStatus};
+
+#[derive(Error, Debug)]
+pub enum SyncError {
+ #[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
+ LocalAheadOtherHost,
+
+ #[error("an issue with the local database occured")]
+ LocalStoreError,
+
+ #[error("something has gone wrong with the sync logic: {msg:?}")]
+ SyncLogicError { msg: String },
+
+ #[error("a request to the sync server failed")]
+ RemoteRequestError,
+}
#[derive(Debug, Eq, PartialEq)]
pub enum Operation {
- // Either upload or download until the tail matches the below
+ // Either upload or download until the states matches the below
Upload {
- tail: RecordId,
+ local: RecordIdx,
+ remote: Option<RecordIdx>,
host: HostId,
tag: String,
},
Download {
- tail: RecordId,
+ local: Option<RecordIdx>,
+ remote: RecordIdx,
+ host: HostId,
+ tag: String,
+ },
+ Noop {
host: HostId,
tag: String,
},
}
-pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
+pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec<Diff>, RecordStatus)> {
let client = Client::new(
&settings.sync_address,
&settings.session_token,
@@ -29,8 +53,8 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di
settings.network_timeout,
)?;
- let local_index = store.tail_records().await?;
- let remote_index = client.record_index().await?;
+ let local_index = store.status().await?;
+ let remote_index = client.record_status().await?;
let diff = local_index.diff(&remote_index);
@@ -41,39 +65,57 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di
// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
// In theory this could be done as a part of the diffing stage, but it's easier to reason
// about and test this way
-pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
+pub async fn operations(
+ diffs: Vec<Diff>,
+ _store: &impl Store,
+) -> Result<Vec<Operation>, SyncError> {
let mut operations = Vec::with_capacity(diffs.len());
for diff in diffs {
- // First, try to fetch the tail
- // If it exists locally, then that means we need to update the remote
- // host until it has the same tail. Ie, upload.
- // If it does not exist locally, that means remote is ahead of us.
- // Therefore, we need to download until our local tail matches
- let record = store.get(diff.tail).await;
-
- let op = if record.is_ok() {
- // if local has the ID, then we should find the actual tail of this
- // store, so we know what we need to update the remote to.
- let tail = store
- .tail(diff.host, diff.tag.as_str())
- .await?
- .expect("failed to fetch last record, expected tag/host to exist");
-
- // TODO(ellie) update the diffing so that it stores the context of the current tail
- // that way, we can determine how much we need to upload.
- // For now just keep uploading until tails match
+ let op = match (diff.local, diff.remote) {
+ // We both have it! Could be either. Compare.
+ (Some(local), Some(remote)) => match local.cmp(&remote) {
+ Ordering::Equal => Operation::Noop {
+ host: diff.host,
+ tag: diff.tag,
+ },
+ Ordering::Greater => Operation::Upload {
+ local,
+ remote: Some(remote),
+ host: diff.host,
+ tag: diff.tag,
+ },
+ Ordering::Less => Operation::Download {
+ local: Some(local),
+ remote,
+ host: diff.host,
+ tag: diff.tag,
+ },
+ },
- Operation::Upload {
- tail: tail.id,
+ // Remote has it, we don't. Gotta be download
+ (None, Some(remote)) => Operation::Download {
+ local: None,
+ remote,
host: diff.host,
tag: diff.tag,
- }
- } else {
- Operation::Download {
- tail: diff.tail,
+ },
+
+ // We have it, remote doesn't. Gotta be upload.
+ (Some(local), None) => Operation::Upload {
+ local,
+ remote: None,
host: diff.host,
tag: diff.tag,
+ },
+
+ // something is pretty fucked.
+ (None, None) => {
+ return Err(SyncError::SyncLogicError {
+ msg: String::from(
+ "diff has nothing for local or remote - (host, tag) does not exist",
+ ),
+ })
}
};
@@ -86,149 +128,130 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
// with the same properties
operations.sort_by_key(|op| match op {
- Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
- Operation::Download { tail, host, .. } => ("download", *host, *tail),
+ Operation::Noop { host, tag } => (0, *host, tag.clone()),
+
+ Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
+
+ Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
});
Ok(operations)
}
async fn sync_upload(
- store: &mut impl Store,
- remote_index: &RecordIndex,
+ store: &impl Store,
client: &Client<'_>,
- op: (HostId, String, RecordId),
-) -> Result<i64> {
+ host: HostId,
+ tag: String,
+ local: RecordIdx,
+ remote: Option<RecordIdx>,
+) -> Result<i64, SyncError> {
+ let remote = remote.unwrap_or(0);
+ let expected = local - remote;
let upload_page_size = 100;
- let mut total = 0;
-
- // so. we have an upload operation, with the tail representing the state
- // we want to get the remote to
- let current_tail = remote_index.get(op.0, op.1.clone());
+ let mut progress = 0;
println!(
- "Syncing local {:?}/{}/{:?}, remote has {:?}",
- op.0, op.1, op.2, current_tail
+ "Uploading {} records to {}/{}",
+ expected,
+ host.0.as_simple(),
+ tag
);
- let start = if let Some(current_tail) = current_tail {
- current_tail
- } else {
- store
- .head(op.0, op.1.as_str())
+ // preload with the first entry if remote does not know of this store
+ loop {
+ let page = store
+ .next(host, tag.as_str(), remote + progress, upload_page_size)
.await
- .expect("failed to fetch host/tag head")
- .expect("host/tag not in current index")
- .id
- };
+ .map_err(|e| {
+ error!("failed to read upload page: {e:?}");
- debug!("starting push to remote from: {:?}", start);
+ SyncError::LocalStoreError
+ })?;
- // we have the start point for sync. it is either the head of the store if
- // the remote has no data for it, or the tail that the remote has
- // we need to iterate from the remote tail, and keep going until
- // remote tail = current local tail
+ client.post_records(&page).await.map_err(|e| {
+ error!("failed to post records: {e:?}");
- let mut record = if current_tail.is_some() {
- let r = store.get(start).await.unwrap();
- store.next(&r).await?
- } else {
- Some(store.get(start).await.unwrap())
- };
+ SyncError::RemoteRequestError
+ })?;
- let mut buf = Vec::with_capacity(upload_page_size);
-
- while let Some(r) = record {
- if buf.len() < upload_page_size {
- buf.push(r.clone());
- } else {
- client.post_records(&buf).await?;
+ println!(
+ "uploaded {} to remote, progress {}/{}",
+ page.len(),
+ progress,
+ expected
+ );
+ progress += page.len() as u64;
- // can we reset what we have? len = 0 but keep capacity
- buf = Vec::with_capacity(upload_page_size);
+ if progress >= expected {
+ break;
}
- record = store.next(&r).await?;
-
- total += 1;
}
- if !buf.is_empty() {
- client.post_records(&buf).await?;
- }
-
- Ok(total)
+ Ok(progress as i64)
}
async fn sync_download(
- store: &mut impl Store,
- remote_index: &RecordIndex,
+ store: &impl Store,
client: &Client<'_>,
- op: (HostId, String, RecordId),
-) -> Result<i64> {
- // TODO(ellie): implement variable page sizing like on history sync
- let download_page_size = 1000;
+ host: HostId,
+ tag: String,
+ local: Option<RecordIdx>,
+ remote: RecordIdx,
+) -> Result<i64, SyncError> {
+ let local = local.unwrap_or(0);
+ let expected = remote - local;
+ let download_page_size = 100;
+ let mut progress = 0;
- let mut total = 0;
+ println!(
+ "Downloading {} records from {}/{}",
+ expected,
+ host.0.as_simple(),
+ tag
+ );
- // We know that the remote is ahead of us, so let's keep downloading until both
- // 1) The remote stops returning full pages
- // 2) The tail equals what we expect
- //
- // If (1) occurs without (2), then something is wrong with our index calculation
- // and we should bail.
- let remote_tail = remote_index
- .get(op.0, op.1.clone())
- .expect("remote index does not contain expected tail during download");
- let local_tail = store.tail(op.0, op.1.as_str()).await?;
- //
- // We expect that the operations diff will represent the desired state
- // In this case, that contains the remote tail.
- assert_eq!(remote_tail, op.2);
+ // preload with the first entry if remote does not know of this store
+ loop {
+ let page = client
+ .next_records(host, tag.clone(), local + progress, download_page_size)
+ .await
+ .map_err(|_| SyncError::RemoteRequestError)?;
- println!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2);
+ store
+ .push_batch(page.iter())
+ .await
+ .map_err(|_| SyncError::LocalStoreError)?;
- let mut records = client
- .next_records(
- op.0,
- op.1.clone(),
- local_tail.map(|r| r.id),
- download_page_size,
- )
- .await?;
+ println!(
+ "downloaded {} records from remote, progress {}/{}",
+ page.len(),
+ progress,
+ expected
+ );
- while !records.is_empty() {
- total += std::cmp::min(download_page_size, records.len() as u64);
- store.push_batch(records.iter()).await?;
+ progress += page.len() as u64;
- if records.last().unwrap().id == remote_tail {
+ if progress >= expected {
break;
}
-
- records = client
- .next_records(
- op.0,
- op.1.clone(),
- records.last().map(|r| r.id),
- download_page_size,
- )
- .await?;
}
- Ok(total as i64)
+ Ok(progress as i64)
}
pub async fn sync_remote(
operations: Vec<Operation>,
- remote_index: &RecordIndex,
- local_store: &mut impl Store,
+ local_store: &impl Store,
settings: &Settings,
-) -> Result<(i64, i64)> {
+) -> Result<(i64, i64), SyncError> {
let client = Client::new(
&settings.sync_address,
&settings.session_token,
settings.network_connect_timeout,
settings.network_timeout,
- )?;
+ )
+ .expect("failed to create client");
let mut uploaded = 0;
let mut downloaded = 0;
@@ -236,14 +259,23 @@ pub async fn sync_remote(
// this can totally run in parallel, but lets get it working first
for i in operations {
match i {
- Operation::Upload { tail, host, tag } => {
- uploaded +=
- sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
- }
- Operation::Download { tail, host, tag } => {
- downloaded +=
- sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
+ Operation::Upload {
+ host,
+ tag,
+ local,
+ remote,
+ } => uploaded += sync_upload(local_store, &client, host, tag, local, remote).await?,
+
+ Operation::Download {
+ host,
+ tag,
+ local,
+ remote,
+ } => {
+ downloaded += sync_download(local_store, &client, host, tag, local, remote).await?
}
+
+ Operation::Noop { .. } => continue,
}
}
@@ -264,13 +296,16 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
- .host(HostId(atuin_common::utils::uuid_v7()))
+ .host(atuin_common::record::Host::new(HostId(
+ atuin_common::utils::uuid_v7(),
+ )))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
data: String::new(),
content_encryption_key: String::new(),
})
+ .idx(0)
.build()
}
@@ -296,8 +331,8 @@ mod tests {
remote_store.push(&i).await.unwrap();
}
- let local_index = local_store.tail_records().await.unwrap();
- let remote_index = remote_store.tail_records().await.unwrap();
+ let local_index = local_store.status().await.unwrap();
+ let remote_index = remote_store.status().await.unwrap();
let diff = local_index.diff(&remote_index);
@@ -320,9 +355,10 @@ mod tests {
assert_eq!(
operations[0],
Operation::Upload {
- host: record.host,
+ host: record.host.id,
tag: record.tag,
- tail: record.id
+ local: record.idx,
+ remote: None,
}
);
}
@@ -333,12 +369,14 @@ mod tests {
// another. One upload, one download
let shared_record = test_record();
-
let remote_ahead = test_record();
+
let local_ahead = shared_record
- .new_child(vec![1, 2, 3])
+ .append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
+ assert_eq!(local_ahead.idx, 1);
+
let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store
let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store
@@ -350,15 +388,19 @@ mod tests {
assert_eq!(
operations,
vec![
- Operation::Download {
- tail: remote_ahead.id,
- host: remote_ahead.host,
- tag: remote_ahead.tag,
- },
+ // Or in otherwords, local is ahead by one
Operation::Upload {
- tail: local_ahead.id,
- host: local_ahead.host,
+ host: local_ahead.host.id,
tag: local_ahead.tag,
+ local: 1,
+ remote: Some(0),
+ },
+ // Or in other words, remote knows of a record in an entirely new store (tag)
+ Operation::Download {
+ host: remote_ahead.host.id,
+ tag: remote_ahead.tag,
+ local: None,
+ remote: 0,
},
]
);
@@ -371,66 +413,160 @@ mod tests {
// One known only by remote
let shared_record = test_record();
+ let local_only = test_record();
+
+ let local_only_20 = test_record();
+ let local_only_21 = local_only_20
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let local_only_22 = local_only_21
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let local_only_23 = local_only_22
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
- let remote_known = test_record();
- let local_known = test_record();
+ let remote_only = test_record();
+
+ let remote_only_20 = test_record();
+ let remote_only_21 = remote_only_20
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let remote_only_22 = remote_only_21
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let remote_only_23 = remote_only_22
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let remote_only_24 = remote_only_23
+ .append(vec![2, 3, 2])
+ .encrypt::<PASETO_V4>(&[0; 32]);
let second_shared = test_record();
let second_shared_remote_ahead = second_shared
- .new_child(vec![1, 2, 3])
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let second_shared_remote_ahead2 = second_shared_remote_ahead
+ .append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
- let local_ahead = shared_record
- .new_child(vec![1, 2, 3])
+ let third_shared = test_record();
+ let third_shared_local_ahead = third_shared
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let third_shared_local_ahead2 = third_shared_local_ahead
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+
+ let fourth_shared = test_record();
+ let fourth_shared_remote_ahead = fourth_shared
+ .append(vec![1, 2, 3])
+ .encrypt::<PASETO_V4>(&[0; 32]);
+ let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead
+ .append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local = vec![
shared_record.clone(),
second_shared.clone(),
- local_known.clone(),
- local_ahead.clone(),
+ third_shared.clone(),
+ fourth_shared.clone(),
+ fourth_shared_remote_ahead.clone(),
+ // single store, only local has it
+ local_only.clone(),
+ // bigger store, also only known by local
+ local_only_20.clone(),
+ local_only_21.clone(),
+ local_only_22.clone(),
+ local_only_23.clone(),
+ // another shared store, but local is ahead on this one
+ third_shared_local_ahead.clone(),
+ third_shared_local_ahead2.clone(),
];
let remote = vec![
+ remote_only.clone(),
+ remote_only_20.clone(),
+ remote_only_21.clone(),
+ remote_only_22.clone(),
+ remote_only_23.clone(),
+ remote_only_24.clone(),
shared_record.clone(),
second_shared.clone(),
+ third_shared.clone(),
second_shared_remote_ahead.clone(),
- remote_known.clone(),
+ second_shared_remote_ahead2.clone(),
+ fourth_shared.clone(),
+ fourth_shared_remote_ahead.clone(),
+ fourth_shared_remote_ahead2.clone(),
]; // remote knows about the already-synced, and one new record in a new store
let (store, diff) = build_test_diff(local, remote).await;
let operations = sync::operations(diff, &store).await.unwrap();
- assert_eq!(operations.len(), 4);
+ assert_eq!(operations.len(), 7);
let mut result_ops = vec![
+ // We started with a shared record, but the remote knows of two newer records in the
+ // same store
+ Operation::Download {
+ local: Some(0),
+ remote: 2,
+ host: second_shared_remote_ahead.host.id,
+ tag: second_shared_remote_ahead.tag,
+ },
+ // We have a shared record, local knows of the first two but not the last
+ Operation::Download {
+ local: Some(1),
+ remote: 2,
+ host: fourth_shared_remote_ahead2.host.id,
+ tag: fourth_shared_remote_ahead2.tag,
+ },
+ // Remote knows of a store with a single record that local does not have
Operation::Download {
- tail: remote_known.id,
- host: remote_known.host,
- tag: remote_known.tag,
+ local: None,
+ remote: 0,
+ host: remote_only.host.id,
+ tag: remote_only.tag,
},
+ // Remote knows of a store with a bunch of records that local does not have
Operation::Download {
- tail: second_shared_remote_ahead.id,
- host: second_shared.host,
- tag: second_shared.tag,
+ local: None,
+ remote: 4,
+ host: remote_only_20.host.id,
+ tag: remote_only_20.tag,
},
+ // Local knows of a record in a store that remote does not have
Operation::Upload {
- tail: local_ahead.id,
- host: local_ahead.host,
- tag: local_ahead.tag,
+ local: 0,
+ remote: None,
+ host: local_only.host.id,
+ tag: local_only.tag,
},
+ // Local knows of 4 records in a store that remote does not have
Operation::Upload {
- tail: local_known.id,
- host: local_known.host,
- tag: local_known.tag,
+ local: 3,
+ remote: None,
+ host: local_only_20.host.id,
+ tag: local_only_20.tag,
+ },
+ // Local knows of 2 more records in a shared store that remote only has one of
+ Operation::Upload {
+ local: 2,
+ remote: Some(0),
+ host: third_shared.host.id,
+ tag: third_shared.tag,
},
];
result_ops.sort_by_key(|op| match op {
- Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
- Operation::Download { tail, host, .. } => ("download", *host, *tail),
+ Operation::Noop { host, tag } => (0, *host, tag.clone()),
+
+ Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
+
+ Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
});
- assert_eq!(operations, result_ops);
+ assert_eq!(result_ops, operations);
}
}