diff options
| author | Ellie Huxtable <ellie@elliehuxtable.com> | 2024-04-18 16:41:28 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-18 16:41:28 +0100 |
| commit | 95cc472037fcb3207b510e67f1a44af4e2a2cae9 (patch) | |
| tree | fc1d3e71d8e0bdb806370e4144fd6f373bcc9c5e /crates/atuin-client/src/record | |
| parent | feat: show preview auto (#1804) (diff) | |
| download | atuin-95cc472037fcb3207b510e67f1a44af4e2a2cae9.zip | |
chore: move crates into crates/ dir (#1958)
I'd like to tidy up the root a little, and it's nice to have all the
rust crates in one place
Diffstat (limited to 'crates/atuin-client/src/record')
| -rw-r--r-- | crates/atuin-client/src/record/encryption.rs | 373 | ||||
| -rw-r--r-- | crates/atuin-client/src/record/mod.rs | 6 | ||||
| -rw-r--r-- | crates/atuin-client/src/record/sqlite_store.rs | 641 | ||||
| -rw-r--r-- | crates/atuin-client/src/record/store.rs | 60 | ||||
| -rw-r--r-- | crates/atuin-client/src/record/sync.rs | 607 |
5 files changed, 1687 insertions, 0 deletions
diff --git a/crates/atuin-client/src/record/encryption.rs b/crates/atuin-client/src/record/encryption.rs new file mode 100644 index 00000000..3ad3be66 --- /dev/null +++ b/crates/atuin-client/src/record/encryption.rs @@ -0,0 +1,373 @@ +use atuin_common::record::{ + AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx, +}; +use base64::{engine::general_purpose, Engine}; +use eyre::{ensure, Context, Result}; +use rusty_paserk::{Key, KeyId, Local, PieWrappedKey}; +use rusty_paseto::core::{ + ImplicitAssertion, Key as DataKey, Local as LocalPurpose, Paseto, PasetoNonce, Payload, V4, +}; +use serde::{Deserialize, Serialize}; + +/// Use PASETO V4 Local encryption using the additional data as an implicit assertion. +#[allow(non_camel_case_types)] +pub struct PASETO_V4; + +/* +Why do we use a random content-encryption key? +Originally I was planning on using a derived key for encryption based on additional data. +This would be a lot more secure than using the master key directly. + +However, there's an established norm of using a random key. This scheme might be otherwise known as +- client-side encryption +- envelope encryption +- key wrapping + +A HSM (Hardware Security Module) provider, eg: AWS, Azure, GCP, or even a physical device like a YubiKey +will have some keys that they keep to themselves. These keys never leave their physical hardware. +If they never leave the hardware, then encrypting large amounts of data means giving them the data and waiting. +This is not a practical solution. Instead, generate a unique key for your data, encrypt that using your HSM +and then store that with your data. + +See + - <https://docs.aws.amazon.com/wellarchitected/latest/financial-services-industry-lens/use-envelope-encryption-with-customer-master-keys.html> + - <https://cloud.google.com/kms/docs/envelope-encryption> + - <https://learn.microsoft.com/en-us/azure/storage/blobs/client-side-encryption?tabs=dotnet#encryption-and-decryption-via-the-envelope-technique> + - <https://www.yubico.com/gb/product/yubihsm-2-fips/> + - <https://cheatsheetseries.owasp.org/cheatsheets/Cryptographic_Storage_Cheat_Sheet.html#encrypting-stored-keys> + +Why would we care? In the past we have received some requests for company solutions. If in future we can configure a +KMS service with little effort, then that would solve a lot of issues for their security team. + +Even for personal use, if a user is not comfortable with sharing keys between hosts, +GCP HSM costs $1/month and $0.03 per 10,000 key operations. Assuming an active user runs +1000 atuin records a day, that would only cost them $1 and 10 cent a month. + +Additionally, key rotations are much simpler using this scheme. Rotating a key is as simple as re-encrypting the CEK, and not the message contents. +This makes it very fast to rotate a key in bulk. + +For future reference, with asymmetric encryption, you can encrypt the CEK without the HSM's involvement, but decrypting +will need the HSM. This allows the encryption path to still be extremely fast (no network calls) but downloads/decryption +that happens in the background can make the network calls to the HSM +*/ + +impl Encryption for PASETO_V4 { + fn re_encrypt( + mut data: EncryptedData, + _ad: AdditionalData, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result<EncryptedData> { + let cek = Self::decrypt_cek(data.content_encryption_key, old_key)?; + data.content_encryption_key = Self::encrypt_cek(cek, new_key); + Ok(data) + } + + fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData { + // generate a random key for this entry + // aka content-encryption-key (CEK) + let random_key = Key::<V4, Local>::new_os_random(); + + // encode the implicit assertions + let assertions = Assertions::from(ad).encode(); + + // build the payload and encrypt the token + let payload = serde_json::to_string(&AtuinPayload { + data: general_purpose::URL_SAFE_NO_PAD.encode(data.0), + }) + .expect("json encoding can't fail"); + let nonce = DataKey::<32>::try_new_random().expect("could not source from random"); + let nonce = PasetoNonce::<V4, LocalPurpose>::from(&nonce); + + let token = Paseto::<V4, LocalPurpose>::builder() + .set_payload(Payload::from(payload.as_str())) + .set_implicit_assertion(ImplicitAssertion::from(assertions.as_str())) + .try_encrypt(&random_key.into(), &nonce) + .expect("error encrypting atuin data"); + + EncryptedData { + data: token, + content_encryption_key: Self::encrypt_cek(random_key, key), + } + } + + fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result<DecryptedData> { + let token = data.data; + let cek = Self::decrypt_cek(data.content_encryption_key, key)?; + + // encode the implicit assertions + let assertions = Assertions::from(ad).encode(); + + // decrypt the payload with the footer and implicit assertions + let payload = Paseto::<V4, LocalPurpose>::try_decrypt( + &token, + &cek.into(), + None, + ImplicitAssertion::from(&*assertions), + ) + .context("could not decrypt entry")?; + + let payload: AtuinPayload = serde_json::from_str(&payload)?; + let data = general_purpose::URL_SAFE_NO_PAD.decode(payload.data)?; + Ok(DecryptedData(data)) + } +} + +impl PASETO_V4 { + fn decrypt_cek(wrapped_cek: String, key: &[u8; 32]) -> Result<Key<V4, Local>> { + let wrapping_key = Key::<V4, Local>::from_bytes(*key); + + // let wrapping_key = PasetoSymmetricKey::from(Key::from(key)); + + let AtuinFooter { kid, wpk } = serde_json::from_str(&wrapped_cek) + .context("wrapped cek did not contain the correct contents")?; + + // check that the wrapping key matches the required key to decrypt. + // In future, we could support multiple keys and use this key to + // look up the key rather than only allow one key. + // For now though we will only support the one key and key rotation will + // have to be a hard reset + let current_kid = wrapping_key.to_id(); + + ensure!( + current_kid == kid, + "attempting to decrypt with incorrect key. currently using {current_kid}, expecting {kid}" + ); + + // decrypt the random key + Ok(wpk.unwrap_key(&wrapping_key)?) + } + + fn encrypt_cek(cek: Key<V4, Local>, key: &[u8; 32]) -> String { + // aka key-encryption-key (KEK) + let wrapping_key = Key::<V4, Local>::from_bytes(*key); + + // wrap the random key so we can decrypt it later + let wrapped_cek = AtuinFooter { + wpk: cek.wrap_pie(&wrapping_key), + kid: wrapping_key.to_id(), + }; + serde_json::to_string(&wrapped_cek).expect("could not serialize wrapped cek") + } +} + +#[derive(Serialize, Deserialize)] +struct AtuinPayload { + data: String, +} + +#[derive(Serialize, Deserialize)] +/// Well-known footer claims for decrypting. This is not encrypted but is stored in the record. +/// <https://github.com/paseto-standard/paseto-spec/blob/master/docs/02-Implementation-Guide/04-Claims.md#optional-footer-claims> +struct AtuinFooter { + /// Wrapped key + wpk: PieWrappedKey<V4, Local>, + /// ID of the key which was used to wrap + kid: KeyId<V4, Local>, +} + +/// Used in the implicit assertions. This is not encrypted and not stored in the data blob. +// This cannot be changed, otherwise it breaks the authenticated encryption. +#[derive(Debug, Copy, Clone, Serialize)] +struct Assertions<'a> { + id: &'a RecordId, + idx: &'a RecordIdx, + version: &'a str, + tag: &'a str, + host: &'a HostId, +} + +impl<'a> From<AdditionalData<'a>> for Assertions<'a> { + fn from(ad: AdditionalData<'a>) -> Self { + Self { + id: ad.id, + version: ad.version, + tag: ad.tag, + host: ad.host, + idx: ad.idx, + } + } +} + +impl Assertions<'_> { + fn encode(&self) -> String { + serde_json::to_string(self).expect("could not serialize implicit assertions") + } +} + +#[cfg(test)] +mod tests { + use atuin_common::{ + record::{Host, Record}, + utils::uuid_v7, + }; + + use super::*; + + #[test] + fn round_trip() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); + let decrypted = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap(); + assert_eq!(decrypted, data); + } + + #[test] + fn same_entry_different_output() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); + let encrypted2 = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + + assert_ne!( + encrypted.data, encrypted2.data, + "re-encrypting the same contents should have different output due to key randomization" + ); + } + + #[test] + fn cannot_decrypt_different_key() { + let key = Key::<V4, Local>::new_os_random(); + let fake_key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + let _ = PASETO_V4::decrypt(encrypted, ad, &fake_key.to_bytes()).unwrap_err(); + } + + #[test] + fn cannot_decrypt_different_id() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + ..ad + }; + let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err(); + } + + #[test] + fn re_encrypt_round_trip() { + let key1 = Key::<V4, Local>::new_os_random(); + let key2 = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted1 = PASETO_V4::encrypt(data.clone(), ad, &key1.to_bytes()); + let encrypted2 = + PASETO_V4::re_encrypt(encrypted1.clone(), ad, &key1.to_bytes(), &key2.to_bytes()) + .unwrap(); + + // we only re-encrypt the content keys + assert_eq!(encrypted1.data, encrypted2.data); + assert_ne!( + encrypted1.content_encryption_key, + encrypted2.content_encryption_key + ); + + let decrypted = PASETO_V4::decrypt(encrypted2, ad, &key2.to_bytes()).unwrap(); + + assert_eq!(decrypted, data); + } + + #[test] + fn full_record_round_trip() { + let key = [0x55; 32]; + let record = Record::builder() + .id(RecordId(uuid_v7())) + .version("v0".to_owned()) + .tag("kv".to_owned()) + .host(Host::new(HostId(uuid_v7()))) + .timestamp(1687244806000000) + .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) + .build(); + + let encrypted = record.encrypt::<PASETO_V4>(&key); + + assert!(!encrypted.data.data.is_empty()); + assert!(!encrypted.data.content_encryption_key.is_empty()); + + let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap(); + + assert_eq!(decrypted.data.0, [1, 2, 3, 4]); + } + + #[test] + fn full_record_round_trip_fail() { + let key = [0x55; 32]; + let record = Record::builder() + .id(RecordId(uuid_v7())) + .version("v0".to_owned()) + .tag("kv".to_owned()) + .host(Host::new(HostId(uuid_v7()))) + .timestamp(1687244806000000) + .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) + .build(); + + let encrypted = record.encrypt::<PASETO_V4>(&key); + + let mut enc1 = encrypted.clone(); + enc1.host = Host::new(HostId(uuid_v7())); + let _ = enc1 + .decrypt::<PASETO_V4>(&key) + .expect_err("tampering with the host should result in auth failure"); + + let mut enc2 = encrypted; + enc2.id = RecordId(uuid_v7()); + let _ = enc2 + .decrypt::<PASETO_V4>(&key) + .expect_err("tampering with the id should result in auth failure"); + } +} diff --git a/crates/atuin-client/src/record/mod.rs b/crates/atuin-client/src/record/mod.rs new file mode 100644 index 00000000..c40fd395 --- /dev/null +++ b/crates/atuin-client/src/record/mod.rs @@ -0,0 +1,6 @@ +pub mod encryption; +pub mod sqlite_store; +pub mod store; + +#[cfg(feature = "sync")] +pub mod sync; diff --git a/crates/atuin-client/src/record/sqlite_store.rs b/crates/atuin-client/src/record/sqlite_store.rs new file mode 100644 index 00000000..31de311b --- /dev/null +++ b/crates/atuin-client/src/record/sqlite_store.rs @@ -0,0 +1,641 @@ +// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries. +// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index +// by tag/host + +use std::str::FromStr; +use std::{path::Path, time::Duration}; + +use async_trait::async_trait; +use eyre::{eyre, Result}; +use fs_err as fs; + +use sqlx::{ + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, + Row, +}; + +use atuin_common::record::{ + EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus, +}; +use uuid::Uuid; + +use super::encryption::PASETO_V4; +use super::store::Store; + +#[derive(Debug, Clone)] +pub struct SqliteStore { + pool: SqlitePool, +} + +impl SqliteStore { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + + debug!("opening sqlite database at {:?}", path); + + let create = !path.exists(); + if create { + if let Some(dir) = path.parent() { + fs::create_dir_all(dir)?; + } + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .foreign_keys(true) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + + Ok(Self { pool }) + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./record-migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + r: &Record<EncryptedData>, + ) -> Result<()> { + // In sqlite, we are "limited" to i64. But that is still fine, until 2262. + sqlx::query( + "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + ) + .bind(r.id.0.as_hyphenated().to_string()) + .bind(r.idx as i64) + .bind(r.host.id.0.as_hyphenated().to_string()) + .bind(r.tag.as_str()) + .bind(r.timestamp as i64) + .bind(r.version.as_str()) + .bind(r.data.data.as_str()) + .bind(r.data.content_encryption_key.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_row(row: SqliteRow) -> Record<EncryptedData> { + let idx: i64 = row.get("idx"); + let timestamp: i64 = row.get("timestamp"); + + // tbh at this point things are pretty fucked so just panic + let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); + let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); + + Record { + id: RecordId(id), + idx: idx as u64, + host: Host::new(HostId(host)), + timestamp: timestamp as u64, + tag: row.get("tag"), + version: row.get("version"), + data: EncryptedData { + data: row.get("data"), + content_encryption_key: row.get("cek"), + }, + } + } + + async fn load_all(&self) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query("select * from store ") + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } +} + +#[async_trait] +impl Store for SqliteStore { + async fn push_batch( + &self, + records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, + ) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for record in records { + Self::save_raw(&mut tx, record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> { + let res = sqlx::query("select * from store where store.id = ?1") + .bind(id.0.as_hyphenated().to_string()) + .map(Self::query_row) + .fetch_one(&self.pool) + .await?; + + Ok(res) + } + + async fn delete(&self, id: RecordId) -> Result<()> { + sqlx::query("delete from store where id = ?1") + .bind(id.0.as_hyphenated().to_string()) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn delete_all(&self) -> Result<()> { + sqlx::query("delete from store").execute(&self.pool).await?; + + Ok(()) + } + + async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { + let res = + sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1") + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occurred: {}", e)), + Ok(record) => Ok(Some(record)), + } + } + + async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { + self.idx(host, tag, 0).await + } + + async fn len_all(&self) -> Result<u64> { + let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store") + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + + async fn len_tag(&self, tag: &str) -> Result<u64> { + let res: Result<(i64,), sqlx::Error> = + sqlx::query_as("select count(*) from store where tag=?1") + .bind(tag) + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + + async fn len(&self, host: HostId, tag: &str) -> Result<u64> { + let last = self.last(host, tag).await?; + + if let Some(last) = last { + return Ok(last.idx + 1); + } + + return Ok(0); + } + + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result<Vec<Record<EncryptedData>>> { + let res = + sqlx::query("select * from store where idx >= ?1 and host = ?2 and tag = ?3 limit ?4") + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .bind(limit as i64) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result<Option<Record<EncryptedData>>> { + let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3") + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occurred: {}", e)), + Ok(v) => Ok(Some(v)), + } + } + + async fn status(&self) -> Result<RecordStatus> { + let mut status = RecordStatus::new(); + + let res: Result<Vec<(String, String, i64)>, sqlx::Error> = + sqlx::query_as("select host, tag, max(idx) from store group by host, tag") + .fetch_all(&self.pool) + .await; + + let res = match res { + Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)), + Ok(v) => v, + }; + + for i in res { + let host = HostId( + Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"), + ); + + status.set_raw(host, i.1, i.2 as u64); + } + + Ok(status) + } + + async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc") + .bind(tag) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + /// Reencrypt every single item in this store with a new key + /// Be careful - this may mess with sync. + async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> { + // Load all the records + // In memory like some of the other code here + // This will never be called in a hot loop, and only under the following circumstances + // 1. The user has logged into a new account, with a new key. They are unlikely to have a + // lot of data + // 2. The user has encountered some sort of issue, and runs a maintenance command that + // invokes this + let all = self.load_all().await?; + + let re_encrypted = all + .into_iter() + .map(|record| record.re_encrypt::<PASETO_V4>(old_key, new_key)) + .collect::<Result<Vec<_>>>()?; + + // next up, we delete all the old data and reinsert the new stuff + // do it in one transaction, so if anything fails we rollback OK + + let mut tx = self.pool.begin().await?; + + let res = sqlx::query("delete from store").execute(&mut *tx).await?; + + let rows = res.rows_affected(); + debug!("deleted {rows} rows"); + + // don't call push_batch, as it will start its own transaction + // call the underlying save_raw + + for record in re_encrypted { + Self::save_raw(&mut tx, &record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + /// Verify that every record in this store can be decrypted with the current key + /// Someday maybe also check each tag/record can be deserialized, but not for now. + async fn verify(&self, key: &[u8; 32]) -> Result<()> { + let all = self.load_all().await?; + + all.into_iter() + .map(|record| record.decrypt::<PASETO_V4>(key)) + .collect::<Result<Vec<_>>>()?; + + Ok(()) + } + + /// Verify that every record in this store can be decrypted with the current key + /// Someday maybe also check each tag/record can be deserialized, but not for now. + async fn purge(&self, key: &[u8; 32]) -> Result<()> { + let all = self.load_all().await?; + + for record in all.iter() { + match record.clone().decrypt::<PASETO_V4>(key) { + Ok(_) => continue, + Err(_) => { + println!( + "Failed to decrypt {}, deleting", + record.id.0.as_hyphenated() + ); + + self.delete(record.id).await?; + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +pub(crate) fn test_sqlite_store_timeout() -> f64 { + std::env::var("ATUIN_TEST_SQLITE_STORE_TIMEOUT") + .ok() + .and_then(|x| x.parse().ok()) + .unwrap_or(0.1) +} + +#[cfg(test)] +mod tests { + use atuin_common::{ + record::{DecryptedData, EncryptedData, Host, HostId, Record}, + utils::uuid_v7, + }; + + use crate::{ + encryption::generate_encoded_key, + record::{encryption::PASETO_V4, store::Store}, + }; + + use super::{test_sqlite_store_timeout, SqliteStore}; + + fn test_record() -> Record<EncryptedData> { + Record::builder() + .host(Host::new(HostId(atuin_common::utils::uuid_v7()))) + .version("v1".into()) + .tag(atuin_common::utils::uuid_v7().simple().to_string()) + .data(EncryptedData { + data: "1234".into(), + content_encryption_key: "1234".into(), + }) + .idx(0) + .build() + } + + #[tokio::test] + async fn create_db() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()).await; + + assert!( + db.is_ok(), + "db could not be created, {:?}", + db.err().unwrap() + ); + } + + #[tokio::test] + async fn push_record() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + + db.push(&record).await.expect("failed to insert record"); + } + + #[tokio::test] + async fn get_record() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let new_record = db.get(record.id).await.expect("failed to fetch record"); + + assert_eq!(record, new_record, "records are not equal"); + } + + #[tokio::test] + async fn last() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let last = db + .last(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + last.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn first() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let first = db + .first(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + first.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn len() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_tag() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len_tag(record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_different_tags() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + + // these have different tags, so the len should be the same + // we model multiple stores within one database + // new store = new tag = independent length + let first = test_record(); + let second = test_record(); + + db.push(&first).await.unwrap(); + db.push(&second).await.unwrap(); + + let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap(); + let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap(); + + assert_eq!(first_len, 1, "expected length of 1 after insert"); + assert_eq!(second_len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn append_a_bunch() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + + let mut tail = test_record(); + db.push(&tail).await.expect("failed to push record"); + + for _ in 1..100 { + tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]); + db.push(&tail).await.unwrap(); + } + + assert_eq!( + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + + assert_eq!( + db.len_tag(tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + } + + #[tokio::test] + async fn append_a_big_bunch() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + + let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(10000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..10000 { + tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + assert_eq!( + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), + 10000, + "failed to insert 10k records" + ); + } + + #[tokio::test] + async fn re_encrypt() { + let store = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let (key, _) = generate_encoded_key().unwrap(); + let data = vec![0u8, 1u8, 2u8, 3u8]; + let host_id = HostId(uuid_v7()); + + for i in 0..10 { + let record = Record::builder() + .host(Host::new(host_id)) + .version(String::from("test")) + .tag(String::from("test")) + .idx(i) + .data(DecryptedData(data.clone())) + .build(); + + let record = record.encrypt::<PASETO_V4>(&key.into()); + store + .push(&record) + .await + .expect("failed to push encrypted record"); + } + + // first, check that we can decrypt the data with the current key + let all = store.all_tagged("test").await.unwrap(); + + assert_eq!(all.len(), 10, "failed to fetch all records"); + + for record in all { + let decrypted = record.decrypt::<PASETO_V4>(&key.into()).unwrap(); + assert_eq!(decrypted.data.0, data); + } + + // reencrypt the store, then check if + // 1) it cannot be decrypted with the old key + // 2) it can be decrypted with the new key + + let (new_key, _) = generate_encoded_key().unwrap(); + store + .re_encrypt(&key.into(), &new_key.into()) + .await + .expect("failed to re-encrypt store"); + + let all = store.all_tagged("test").await.unwrap(); + + for record in all.iter() { + let decrypted = record.clone().decrypt::<PASETO_V4>(&key.into()); + assert!( + decrypted.is_err(), + "did not get error decrypting with old key after re-encrypt" + ) + } + + for record in all { + let decrypted = record.decrypt::<PASETO_V4>(&new_key.into()).unwrap(); + assert_eq!(decrypted.data.0, data); + } + + assert_eq!(store.len(host_id, "test").await.unwrap(), 10); + } +} diff --git a/crates/atuin-client/src/record/store.rs b/crates/atuin-client/src/record/store.rs new file mode 100644 index 00000000..49ca4968 --- /dev/null +++ b/crates/atuin-client/src/record/store.rs @@ -0,0 +1,60 @@ +use async_trait::async_trait; +use eyre::Result; + +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; + +/// A record store stores records +/// In more detail - we tend to need to process this into _another_ format to actually query it. +/// As is, the record store is intended as the source of truth for arbitrary data, which could +/// be shell history, kvs, etc. +#[async_trait] +pub trait Store { + // Push a record + async fn push(&self, record: &Record<EncryptedData>) -> Result<()> { + self.push_batch(std::iter::once(record)).await + } + + // Push a batch of records, all in one transaction + async fn push_batch( + &self, + records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, + ) -> Result<()>; + + async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>; + + async fn delete(&self, id: RecordId) -> Result<()>; + async fn delete_all(&self) -> Result<()>; + + async fn len_all(&self) -> Result<u64>; + async fn len(&self, host: HostId, tag: &str) -> Result<u64>; + async fn len_tag(&self, tag: &str) -> Result<u64>; + + async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + + async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()>; + async fn verify(&self, key: &[u8; 32]) -> Result<()>; + async fn purge(&self, key: &[u8; 32]) -> Result<()>; + + /// Get the next `limit` records, after and including the given index + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result<Vec<Record<EncryptedData>>>; + + /// Get the first record for a given host and tag + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result<Option<Record<EncryptedData>>>; + + async fn status(&self) -> Result<RecordStatus>; + + /// Get all records for a given tag + async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>; +} diff --git a/crates/atuin-client/src/record/sync.rs b/crates/atuin-client/src/record/sync.rs new file mode 100644 index 00000000..234c6442 --- /dev/null +++ b/crates/atuin-client/src/record/sync.rs @@ -0,0 +1,607 @@ +// do a sync :O +use std::{cmp::Ordering, fmt::Write}; + +use eyre::Result; +use thiserror::Error; + +use super::store::Store; +use crate::{api_client::Client, settings::Settings}; + +use atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus}; +use indicatif::{ProgressBar, ProgressState, ProgressStyle}; + +#[derive(Error, Debug)] +pub enum SyncError { + #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] + LocalAheadOtherHost, + + #[error("an issue with the local database occurred: {msg:?}")] + LocalStoreError { msg: String }, + + #[error("something has gone wrong with the sync logic: {msg:?}")] + SyncLogicError { msg: String }, + + #[error("operational error: {msg:?}")] + OperationalError { msg: String }, + + #[error("a request to the sync server failed: {msg:?}")] + RemoteRequestError { msg: String }, +} + +#[derive(Debug, Eq, PartialEq)] +pub enum Operation { + // Either upload or download until the states matches the below + Upload { + local: RecordIdx, + remote: Option<RecordIdx>, + host: HostId, + tag: String, + }, + Download { + local: Option<RecordIdx>, + remote: RecordIdx, + host: HostId, + tag: String, + }, + Noop { + host: HostId, + tag: String, + }, +} + +pub async fn diff( + settings: &Settings, + store: &impl Store, +) -> Result<(Vec<Diff>, RecordStatus), SyncError> { + let client = Client::new( + &settings.sync_address, + &settings.session_token, + settings.network_connect_timeout, + settings.network_timeout, + ) + .map_err(|e| SyncError::OperationalError { msg: e.to_string() })?; + + let local_index = store + .status() + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; + + let remote_index = client + .record_status() + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + let diff = local_index.diff(&remote_index); + + Ok((diff, remote_index)) +} + +// Take a diff, along with a local store, and resolve it into a set of operations. +// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download. +// In theory this could be done as a part of the diffing stage, but it's easier to reason +// about and test this way +pub async fn operations( + diffs: Vec<Diff>, + _store: &impl Store, +) -> Result<Vec<Operation>, SyncError> { + let mut operations = Vec::with_capacity(diffs.len()); + + for diff in diffs { + let op = match (diff.local, diff.remote) { + // We both have it! Could be either. Compare. + (Some(local), Some(remote)) => match local.cmp(&remote) { + Ordering::Equal => Operation::Noop { + host: diff.host, + tag: diff.tag, + }, + Ordering::Greater => Operation::Upload { + local, + remote: Some(remote), + host: diff.host, + tag: diff.tag, + }, + Ordering::Less => Operation::Download { + local: Some(local), + remote, + host: diff.host, + tag: diff.tag, + }, + }, + + // Remote has it, we don't. Gotta be download + (None, Some(remote)) => Operation::Download { + local: None, + remote, + host: diff.host, + tag: diff.tag, + }, + + // We have it, remote doesn't. Gotta be upload. + (Some(local), None) => Operation::Upload { + local, + remote: None, + host: diff.host, + tag: diff.tag, + }, + + // something is pretty fucked. + (None, None) => { + return Err(SyncError::SyncLogicError { + msg: String::from( + "diff has nothing for local or remote - (host, tag) does not exist", + ), + }) + } + }; + + operations.push(op); + } + + // sort them - purely so we have a stable testing order, and can rely on + // same input = same output + // We can sort by ID so long as we continue to use UUIDv7 or something + // with the same properties + + operations.sort_by_key(|op| match op { + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. } => (2, *host, tag.clone()), + }); + + Ok(operations) +} + +async fn sync_upload( + store: &impl Store, + client: &Client<'_>, + host: HostId, + tag: String, + local: RecordIdx, + remote: Option<RecordIdx>, +) -> Result<i64, SyncError> { + let remote = remote.unwrap_or(0); + let expected = local - remote; + let upload_page_size = 100; + let mut progress = 0; + + let pb = ProgressBar::new(expected); + pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) + .progress_chars("#>-")); + + println!( + "Uploading {} records to {}/{}", + expected, + host.0.as_simple(), + tag + ); + + // preload with the first entry if remote does not know of this store + loop { + let page = store + .next(host, tag.as_str(), remote + progress, upload_page_size) + .await + .map_err(|e| { + error!("failed to read upload page: {e:?}"); + + SyncError::LocalStoreError { msg: e.to_string() } + })?; + + client.post_records(&page).await.map_err(|e| { + error!("failed to post records: {e:?}"); + + SyncError::RemoteRequestError { msg: e.to_string() } + })?; + + pb.set_position(progress); + progress += page.len() as u64; + + if progress >= expected { + break; + } + } + + pb.finish_with_message("Uploaded records"); + + Ok(progress as i64) +} + +async fn sync_download( + store: &impl Store, + client: &Client<'_>, + host: HostId, + tag: String, + local: Option<RecordIdx>, + remote: RecordIdx, +) -> Result<Vec<RecordId>, SyncError> { + let local = local.unwrap_or(0); + let expected = remote - local; + let download_page_size = 100; + let mut progress = 0; + let mut ret = Vec::new(); + + println!( + "Downloading {} records from {}/{}", + expected, + host.0.as_simple(), + tag + ); + + let pb = ProgressBar::new(expected); + pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) + .progress_chars("#>-")); + + // preload with the first entry if remote does not know of this store + loop { + let page = client + .next_records(host, tag.clone(), local + progress, download_page_size) + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + store + .push_batch(page.iter()) + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; + + ret.extend(page.iter().map(|f| f.id)); + + pb.set_position(progress); + progress += page.len() as u64; + + if progress >= expected { + break; + } + } + + pb.finish_with_message("Downloaded records"); + + Ok(ret) +} + +pub async fn sync_remote( + operations: Vec<Operation>, + local_store: &impl Store, + settings: &Settings, +) -> Result<(i64, Vec<RecordId>), SyncError> { + let client = Client::new( + &settings.sync_address, + &settings.session_token, + settings.network_connect_timeout, + settings.network_timeout, + ) + .expect("failed to create client"); + + let mut uploaded = 0; + let mut downloaded = Vec::new(); + + // this can totally run in parallel, but lets get it working first + for i in operations { + match i { + Operation::Upload { + host, + tag, + local, + remote, + } => uploaded += sync_upload(local_store, &client, host, tag, local, remote).await?, + + Operation::Download { + host, + tag, + local, + remote, + } => { + let mut d = sync_download(local_store, &client, host, tag, local, remote).await?; + downloaded.append(&mut d) + } + + Operation::Noop { .. } => continue, + } + } + + Ok((uploaded, downloaded)) +} + +pub async fn sync( + settings: &Settings, + store: &impl Store, +) -> Result<(i64, Vec<RecordId>), SyncError> { + let (diff, _) = diff(settings, store).await?; + let operations = operations(diff, store).await?; + let (uploaded, downloaded) = sync_remote(operations, store, settings).await?; + + Ok((uploaded, downloaded)) +} + +#[cfg(test)] +mod tests { + use atuin_common::record::{Diff, EncryptedData, HostId, Record}; + use pretty_assertions::assert_eq; + + use crate::record::{ + encryption::PASETO_V4, + sqlite_store::{test_sqlite_store_timeout, SqliteStore}, + store::Store, + sync::{self, Operation}, + }; + + fn test_record() -> Record<EncryptedData> { + Record::builder() + .host(atuin_common::record::Host::new(HostId( + atuin_common::utils::uuid_v7(), + ))) + .version("v1".into()) + .tag(atuin_common::utils::uuid_v7().simple().to_string()) + .data(EncryptedData { + data: String::new(), + content_encryption_key: String::new(), + }) + .idx(0) + .build() + } + + // Take a list of local records, and a list of remote records. + // Return the local database, and a diff of local/remote, ready to build + // ops + async fn build_test_diff( + local_records: Vec<Record<EncryptedData>>, + remote_records: Vec<Record<EncryptedData>>, + ) -> (SqliteStore, Vec<Diff>) { + let local_store = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .expect("failed to open in memory sqlite"); + let remote_store = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .expect("failed to open in memory sqlite"); // "remote" + + for i in local_records { + local_store.push(&i).await.unwrap(); + } + + for i in remote_records { + remote_store.push(&i).await.unwrap(); + } + + let local_index = local_store.status().await.unwrap(); + let remote_index = remote_store.status().await.unwrap(); + + let diff = local_index.diff(&remote_index); + + (local_store, diff) + } + + #[tokio::test] + async fn test_basic_diff() { + // a diff where local is ahead of remote. nothing else. + + let record = test_record(); + let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await; + + assert_eq!(diff.len(), 1); + + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 1); + + assert_eq!( + operations[0], + Operation::Upload { + host: record.host.id, + tag: record.tag, + local: record.idx, + remote: None, + } + ); + } + + #[tokio::test] + async fn build_two_way_diff() { + // a diff where local is ahead of remote for one, and remote for + // another. One upload, one download + + let shared_record = test_record(); + let remote_ahead = test_record(); + + let local_ahead = shared_record + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + assert_eq!(local_ahead.idx, 1); + + let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store + let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store + + let (store, diff) = build_test_diff(local, remote).await; + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 2); + + assert_eq!( + operations, + vec![ + // Or in otherwords, local is ahead by one + Operation::Upload { + host: local_ahead.host.id, + tag: local_ahead.tag, + local: 1, + remote: Some(0), + }, + // Or in other words, remote knows of a record in an entirely new store (tag) + Operation::Download { + host: remote_ahead.host.id, + tag: remote_ahead.tag, + local: None, + remote: 0, + }, + ] + ); + } + + #[tokio::test] + async fn build_complex_diff() { + // One shared, ahead but known only by remote + // One known only by local + // One known only by remote + + let shared_record = test_record(); + let local_only = test_record(); + + let local_only_20 = test_record(); + let local_only_21 = local_only_20 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let local_only_22 = local_only_21 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let local_only_23 = local_only_22 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let remote_only = test_record(); + + let remote_only_20 = test_record(); + let remote_only_21 = remote_only_20 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_22 = remote_only_21 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_23 = remote_only_22 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_24 = remote_only_23 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let second_shared = test_record(); + let second_shared_remote_ahead = second_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let second_shared_remote_ahead2 = second_shared_remote_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let third_shared = test_record(); + let third_shared_local_ahead = third_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let third_shared_local_ahead2 = third_shared_local_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let fourth_shared = test_record(); + let fourth_shared_remote_ahead = fourth_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let local = vec![ + shared_record.clone(), + second_shared.clone(), + third_shared.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + // single store, only local has it + local_only.clone(), + // bigger store, also only known by local + local_only_20.clone(), + local_only_21.clone(), + local_only_22.clone(), + local_only_23.clone(), + // another shared store, but local is ahead on this one + third_shared_local_ahead.clone(), + third_shared_local_ahead2.clone(), + ]; + + let remote = vec![ + remote_only.clone(), + remote_only_20.clone(), + remote_only_21.clone(), + remote_only_22.clone(), + remote_only_23.clone(), + remote_only_24.clone(), + shared_record.clone(), + second_shared.clone(), + third_shared.clone(), + second_shared_remote_ahead.clone(), + second_shared_remote_ahead2.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + fourth_shared_remote_ahead2.clone(), + ]; // remote knows about the already-synced, and one new record in a new store + + let (store, diff) = build_test_diff(local, remote).await; + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 7); + + let mut result_ops = vec![ + // We started with a shared record, but the remote knows of two newer records in the + // same store + Operation::Download { + local: Some(0), + remote: 2, + host: second_shared_remote_ahead.host.id, + tag: second_shared_remote_ahead.tag, + }, + // We have a shared record, local knows of the first two but not the last + Operation::Download { + local: Some(1), + remote: 2, + host: fourth_shared_remote_ahead2.host.id, + tag: fourth_shared_remote_ahead2.tag, + }, + // Remote knows of a store with a single record that local does not have + Operation::Download { + local: None, + remote: 0, + host: remote_only.host.id, + tag: remote_only.tag, + }, + // Remote knows of a store with a bunch of records that local does not have + Operation::Download { + local: None, + remote: 4, + host: remote_only_20.host.id, + tag: remote_only_20.tag, + }, + // Local knows of a record in a store that remote does not have + Operation::Upload { + local: 0, + remote: None, + host: local_only.host.id, + tag: local_only.tag, + }, + // Local knows of 4 records in a store that remote does not have + Operation::Upload { + local: 3, + remote: None, + host: local_only_20.host.id, + tag: local_only_20.tag, + }, + // Local knows of 2 more records in a shared store that remote only has one of + Operation::Upload { + local: 2, + remote: Some(0), + host: third_shared.host.id, + tag: third_shared.tag, + }, + ]; + + result_ops.sort_by_key(|op| match op { + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. } => (2, *host, tag.clone()), + }); + + assert_eq!(result_ops, operations); + } +} |
