diff options
Diffstat (limited to 'atuin-client')
| -rw-r--r-- | atuin-client/Cargo.toml | 1 | ||||
| -rw-r--r-- | atuin-client/src/kv.rs | 95 |
2 files changed, 81 insertions, 15 deletions
diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml index e00dc910..8147ddca 100644 --- a/atuin-client/Cargo.toml +++ b/atuin-client/Cargo.toml @@ -51,7 +51,6 @@ sql-builder = "3" lazy_static = "1" memchr = "2.5" rmp = { version = "0.8.11" } -rmp-serde = { version = "1.1.1" } typed-builder = "0.14.0" # sync diff --git a/atuin-client/src/kv.rs b/atuin-client/src/kv.rs index 35e8852e..1fe90b6c 100644 --- a/atuin-client/src/kv.rs +++ b/atuin-client/src/kv.rs @@ -1,5 +1,4 @@ -use eyre::Result; -use serde::{Deserialize, Serialize}; +use eyre::{bail, ensure, eyre, Result}; use crate::record::store::Store; use crate::settings::Settings; @@ -7,7 +6,7 @@ use crate::settings::Settings; const KV_VERSION: &str = "v0"; const KV_TAG: &str = "kv"; -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct KvRecord { pub namespace: String, pub key: String, @@ -16,9 +15,55 @@ pub struct KvRecord { impl KvRecord { pub fn serialize(&self) -> Result<Vec<u8>> { - let buf = rmp_serde::to_vec(self)?; + use rmp::encode; - Ok(buf) + let mut output = vec![]; + + // INFO: ensure this is updated when adding new fields + encode::write_array_len(&mut output, 3)?; + + encode::write_str(&mut output, &self.namespace)?; + encode::write_str(&mut output, &self.key)?; + encode::write_str(&mut output, &self.value)?; + + Ok(output) + } + + pub fn deserialize(data: &[u8], version: &str) -> Result<Self> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + match version { + KV_VERSION => { + let mut bytes = decode::Bytes::new(data); + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + ensure!(nfields == 3, "too many entries in v0 kv record"); + + let bytes = bytes.remaining_slice(); + + let (namespace, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded kvrecord. malformed") + } + + Ok(KvRecord { + namespace: namespace.to_owned(), + key: key.to_owned(), + value: value.to_owned(), + }) + } + _ => { + bail!("unknown version {version:?}") + } + } } } @@ -90,22 +135,44 @@ impl KvStore { let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await? else { return Ok(None); }; - let kv: KvRecord = rmp_serde::from_slice(&record.data)?; - - if kv.key == key && kv.namespace == namespace { - return Ok(Some(kv)); - } - - while let Some(parent) = record.parent { - record = store.get(parent.as_str()).await?; - let kv: KvRecord = rmp_serde::from_slice(&record.data)?; + loop { + let kv = KvRecord::deserialize(&record.data, &record.version)?; if kv.key == key && kv.namespace == namespace { return Ok(Some(kv)); } + + if let Some(parent) = record.parent { + record = store.get(parent.as_str()).await?; + } else { + break; + } } // if we get here, then... we didn't find the record with that key :( Ok(None) } } + +#[cfg(test)] +mod tests { + use super::{KvRecord, KV_VERSION}; + + #[test] + fn encode_decode() { + let kv = KvRecord { + namespace: "foo".to_owned(), + key: "bar".to_owned(), + value: "baz".to_owned(), + }; + let snapshot = [ + 0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z', + ]; + + let encoded = kv.serialize().unwrap(); + let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap(); + + assert_eq!(encoded, &snapshot); + assert_eq!(decoded, kv); + } +} |
