diff options
Diffstat (limited to 'atuin-client/src/kv.rs')
| -rw-r--r-- | atuin-client/src/kv.rs | 265 |
1 files changed, 0 insertions, 265 deletions
diff --git a/atuin-client/src/kv.rs b/atuin-client/src/kv.rs deleted file mode 100644 index fb26cadc..00000000 --- a/atuin-client/src/kv.rs +++ /dev/null @@ -1,265 +0,0 @@ -use std::collections::BTreeMap; - -use atuin_common::record::{DecryptedData, Host, HostId}; -use eyre::{bail, ensure, eyre, Result}; -use serde::Deserialize; - -use crate::record::encryption::PASETO_V4; -use crate::record::store::Store; - -const KV_VERSION: &str = "v0"; -const KV_TAG: &str = "kv"; -const KV_VAL_MAX_LEN: usize = 100 * 1024; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct KvRecord { - pub namespace: String, - pub key: String, - pub value: String, -} - -impl KvRecord { - pub fn serialize(&self) -> Result<DecryptedData> { - use rmp::encode; - - let mut output = vec![]; - - // INFO: ensure this is updated when adding new fields - encode::write_array_len(&mut output, 3)?; - - encode::write_str(&mut output, &self.namespace)?; - encode::write_str(&mut output, &self.key)?; - encode::write_str(&mut output, &self.value)?; - - Ok(DecryptedData(output)) - } - - pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> { - use rmp::decode; - - fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - match version { - KV_VERSION => { - let mut bytes = decode::Bytes::new(&data.0); - - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - ensure!(nfields == 3, "too many entries in v0 kv record"); - - let bytes = bytes.remaining_slice(); - - let (namespace, bytes) = - decode::read_str_from_slice(bytes).map_err(error_report)?; - let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - if !bytes.is_empty() { - bail!("trailing bytes in encoded kvrecord. malformed") - } - - Ok(KvRecord { - namespace: namespace.to_owned(), - key: key.to_owned(), - value: value.to_owned(), - }) - } - _ => { - bail!("unknown version {version:?}") - } - } - } -} - -#[derive(Debug, Clone, Deserialize)] -pub struct KvStore; - -impl Default for KvStore { - fn default() -> Self { - Self::new() - } -} - -impl KvStore { - // will want to init the actual kv store when that is done - pub fn new() -> KvStore { - KvStore {} - } - - pub async fn set( - &self, - store: &(impl Store + Send + Sync), - encryption_key: &[u8; 32], - host_id: HostId, - namespace: &str, - key: &str, - value: &str, - ) -> Result<()> { - if value.len() > KV_VAL_MAX_LEN { - return Err(eyre!( - "kv value too large: max len {} bytes", - KV_VAL_MAX_LEN - )); - } - - let record = KvRecord { - namespace: namespace.to_string(), - key: key.to_string(), - value: value.to_string(), - }; - - let bytes = record.serialize()?; - - let idx = store - .last(host_id, KV_TAG) - .await? - .map_or(0, |entry| entry.idx + 1); - - let record = atuin_common::record::Record::builder() - .host(Host::new(host_id)) - .version(KV_VERSION.to_string()) - .tag(KV_TAG.to_string()) - .idx(idx) - .data(bytes) - .build(); - - store - .push(&record.encrypt::<PASETO_V4>(encryption_key)) - .await?; - - Ok(()) - } - - // TODO: setup an actual kv store, rebuild func, and do not pass the main store in here as - // well. - pub async fn get( - &self, - store: &impl Store, - encryption_key: &[u8; 32], - namespace: &str, - key: &str, - ) -> Result<Option<KvRecord>> { - // TODO: don't rebuild every time... - let map = self.build_kv(store, encryption_key).await?; - - let res = map.get(namespace); - - if let Some(ns) = res { - let value = ns.get(key); - - Ok(value.cloned()) - } else { - Ok(None) - } - } - - // Build a kv map out of the linked list kv store - // Map is Namespace -> Key -> Value - // TODO(ellie): "cache" this into a real kv structure, which we can - // use as a write-through cache to avoid constant rebuilds. - pub async fn build_kv( - &self, - store: &impl Store, - encryption_key: &[u8; 32], - ) -> Result<BTreeMap<String, BTreeMap<String, KvRecord>>> { - let mut map = BTreeMap::new(); - - // TODO: maybe don't load the entire tag into memory to build the kv - // we can be smart about it and only load values since the last build - // or, iterate/paginate - let tagged = store.all_tagged(KV_TAG).await?; - - // iterate through all tags and play each KV record at a time - // this is "last write wins" - // probably good enough for now, but revisit in future - for record in tagged { - let decrypted = match record.version.as_str() { - KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?, - version => bail!("unknown version {version:?}"), - }; - - let kv = KvRecord::deserialize(&decrypted.data, KV_VERSION)?; - - let ns = map - .entry(kv.namespace.clone()) - .or_insert_with(BTreeMap::new); - - ns.insert(kv.key.clone(), kv); - } - - Ok(map) - } -} - -#[cfg(test)] -mod tests { - use crypto_secretbox::{KeyInit, XSalsa20Poly1305}; - use rand::rngs::OsRng; - - use crate::record::sqlite_store::{test_sqlite_store_timeout, SqliteStore}; - - use super::{KvRecord, KvStore, KV_VERSION}; - - #[test] - fn encode_decode() { - let kv = KvRecord { - namespace: "foo".to_owned(), - key: "bar".to_owned(), - value: "baz".to_owned(), - }; - let snapshot = [ - 0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z', - ]; - - let encoded = kv.serialize().unwrap(); - let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap(); - - assert_eq!(encoded.0, &snapshot); - assert_eq!(decoded, kv); - } - - #[tokio::test] - async fn build_kv() { - let mut store = SqliteStore::new(":memory:", test_sqlite_store_timeout()) - .await - .unwrap(); - let kv = KvStore::new(); - let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into(); - let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7()); - - kv.set(&mut store, &key, host_id, "test-kv", "foo", "bar") - .await - .unwrap(); - - kv.set(&mut store, &key, host_id, "test-kv", "1", "2") - .await - .unwrap(); - - let map = kv.build_kv(&store, &key).await.unwrap(); - - assert_eq!( - *map.get("test-kv") - .expect("map namespace not set") - .get("foo") - .expect("map key not set"), - KvRecord { - namespace: String::from("test-kv"), - key: String::from("foo"), - value: String::from("bar") - } - ); - - assert_eq!( - *map.get("test-kv") - .expect("map namespace not set") - .get("1") - .expect("map key not set"), - KvRecord { - namespace: String::from("test-kv"), - key: String::from("1"), - value: String::from("2") - } - ); - } -} |
