aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-kv/src/store.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-kv/src/store.rs')
-rw-r--r--crates/atuin-kv/src/store.rs214
1 files changed, 214 insertions, 0 deletions
diff --git a/crates/atuin-kv/src/store.rs b/crates/atuin-kv/src/store.rs
new file mode 100644
index 00000000..3394b8c0
--- /dev/null
+++ b/crates/atuin-kv/src/store.rs
@@ -0,0 +1,214 @@
+use std::collections::HashSet;
+
+use eyre::{Result, bail};
+
+use atuin_client::record::sqlite_store::SqliteStore;
+use atuin_client::record::{encryption::PASETO_V4, store::Store};
+use atuin_common::record::{Host, HostId, Record, RecordId, RecordIdx};
+use entry::KvEntry;
+use record::{KV_TAG, KV_VERSION, KvRecord};
+
+use crate::database::Database;
+
+pub mod entry;
+pub mod record;
+
+#[derive(Debug, Clone)]
+pub struct KvStore {
+ pub record_store: SqliteStore,
+ pub kv_db: Database,
+ pub host_id: HostId,
+ pub encryption_key: [u8; 32],
+}
+
+impl KvStore {
+ pub fn new(
+ record_store: SqliteStore,
+ kv_db: Database,
+ host_id: HostId,
+ encryption_key: [u8; 32],
+ ) -> Self {
+ KvStore {
+ record_store,
+ kv_db,
+ host_id,
+ encryption_key,
+ }
+ }
+
+ pub async fn set(&self, namespace: &str, key: &str, value: &str) -> Result<()> {
+ let kv_record = KvRecord::builder()
+ .namespace(namespace.to_string())
+ .key(key.to_string())
+ .value(Some(value.to_string()))
+ .build();
+
+ self.push_record(kv_record).await?;
+
+ let kv = KvEntry::builder()
+ .namespace(namespace.to_string())
+ .key(key.to_string())
+ .value(value.to_string())
+ .build();
+
+ self.kv_db.save(&kv).await?;
+
+ Ok(())
+ }
+
+ pub async fn get(&self, namespace: &str, key: &str) -> Result<Option<String>> {
+ let kv = self.kv_db.load(namespace, key).await?;
+ Ok(kv.map(|kv| kv.value))
+ }
+
+ pub async fn delete(&self, namespace: &str, keys: &[String]) -> Result<()> {
+ for key in keys {
+ let record = KvRecord::builder()
+ .namespace(namespace.to_string())
+ .key(key.to_string())
+ .value(None)
+ .build();
+
+ self.push_record(record).await?;
+ self.kv_db.delete(namespace, key).await?;
+ }
+
+ Ok(())
+ }
+
+ pub async fn list(&self, namespace: Option<&str>) -> Result<Vec<KvEntry>> {
+ let entries = self.kv_db.list(namespace).await?;
+
+ Ok(entries)
+ }
+
+ async fn push_record(&self, record: KvRecord) -> Result<(RecordId, RecordIdx)> {
+ let bytes = record.serialize()?;
+ let idx = self
+ .record_store
+ .last(self.host_id, KV_TAG)
+ .await?
+ .map_or(0, |p| p.idx + 1);
+
+ let record = Record::builder()
+ .host(Host::new(self.host_id))
+ .version(KV_VERSION.to_string())
+ .tag(KV_TAG.to_string())
+ .idx(idx)
+ .data(bytes)
+ .build();
+
+ let id = record.id;
+
+ self.record_store
+ .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
+ .await?;
+
+ Ok((id, idx))
+ }
+
+ pub async fn build(&self) -> Result<()> {
+ let mut tagged = self.record_store.all_tagged(KV_TAG).await?;
+ tagged.reverse();
+
+ let cached = self.kv_db.list(None).await?;
+
+ let mut visited = HashSet::new();
+
+ // Iterate through all KV records from newest to oldest;
+ // only visit each KV once, inserting or deleting based on the first time we see it
+ for record in tagged {
+ let decrypted = match record.version.as_str() {
+ "v0" | KV_VERSION => record.decrypt::<PASETO_V4>(&self.encryption_key)?,
+ version => bail!("unknown version {version:?}"),
+ };
+
+ let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?;
+ let uniq_id = format!("{}.{}", kv.namespace, kv.key);
+
+ if visited.insert(uniq_id) {
+ match kv.value {
+ Some(value) => {
+ self.kv_db
+ .save(
+ &KvEntry::builder()
+ .namespace(kv.namespace.clone())
+ .key(kv.key.clone())
+ .value(value)
+ .build(),
+ )
+ .await?;
+ }
+ None => {
+ self.kv_db
+ .delete(kv.namespace.as_str(), kv.key.as_str())
+ .await?;
+ }
+ }
+ }
+ }
+
+ // Any KVs that were in the cache but not in the tagged list should be deleted;
+ // this should never happen in practice since the cache is always built from the tagged list,
+ // but just in case because ** S O F T W A R E **
+ for kv in cached {
+ if !visited.contains(&format!("{}.{}", kv.namespace, kv.key)) {
+ self.kv_db
+ .delete(kv.namespace.as_str(), kv.key.as_str())
+ .await?;
+ }
+ }
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ async fn setup() -> Result<KvStore> {
+ let record_store = SqliteStore::new("sqlite::memory:", 1.0).await.unwrap();
+ let kv_db = Database::new("sqlite::memory:", 1.0).await.unwrap();
+ let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7());
+ let encryption_key = [0; 32];
+ Ok(KvStore::new(record_store, kv_db, host_id, encryption_key))
+ }
+
+ #[tokio::test]
+ async fn test_kv_store() -> Result<()> {
+ let store = setup().await?;
+
+ store.set("test", "key", "value").await.unwrap();
+ let value = store.get("test", "key").await.unwrap();
+ assert_eq!(value, Some("value".to_string()));
+
+ let records = store.record_store.all_tagged(KV_TAG).await?;
+ assert_eq!(records.len(), 1);
+
+ let list = store.list(Some("test")).await.unwrap();
+ let expected = vec![
+ KvEntry::builder()
+ .namespace("test".to_string())
+ .key("key".to_string())
+ .value("value".to_string())
+ .build(),
+ ];
+ assert_eq!(list, expected);
+
+ let ns_list = store.list(None).await.unwrap();
+ assert_eq!(ns_list, expected);
+
+ store
+ .delete("test", &vec!["key".to_string()])
+ .await
+ .unwrap();
+ let value = store.get("test", "key").await.unwrap();
+ assert_eq!(value, None);
+
+ let records = store.record_store.all_tagged(KV_TAG).await?;
+ assert_eq!(records.len(), 2);
+
+ Ok(())
+ }
+}