diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2025-05-06 08:36:32 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-06 08:36:32 -0700 |
| commit | a1433e0cefe3ad001d5473faf4312c25bdeea968 (patch) | |
| tree | ee8bc10e1438641338b8ef7f5de00a52e6c7f074 /crates/atuin-kv/src | |
| parent | chore(deps): update minspan to 0.1.5 (#2729) (diff) | |
| download | atuin-a1433e0cefe3ad001d5473faf4312c25bdeea968.zip | |
feat: Implement KV as a write-through cache (#2732)
Diffstat (limited to 'crates/atuin-kv/src')
| -rw-r--r-- | crates/atuin-kv/src/database.rs | 229 | ||||
| -rw-r--r-- | crates/atuin-kv/src/lib.rs | 2 | ||||
| -rw-r--r-- | crates/atuin-kv/src/store.rs | 214 | ||||
| -rw-r--r-- | crates/atuin-kv/src/store/entry.rs | 8 | ||||
| -rw-r--r-- | crates/atuin-kv/src/store/record.rs | 159 |
5 files changed, 612 insertions, 0 deletions
diff --git a/crates/atuin-kv/src/database.rs b/crates/atuin-kv/src/database.rs new file mode 100644 index 00000000..ad9226de --- /dev/null +++ b/crates/atuin-kv/src/database.rs @@ -0,0 +1,229 @@ +use std::{path::Path, str::FromStr, time::Duration}; + +use atuin_common::utils; +use sqlx::{ + Result, Row, + sqlite::{ + SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, + SqliteSynchronous, + }, +}; +use tokio::fs; +use tracing::debug; + +use crate::store::entry::KvEntry; + +#[derive(Debug, Clone)] +pub struct Database { + pub pool: SqlitePool, +} + +impl Database { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + debug!("opening KV sqlite database at {:?}", path); + + if utils::broken_symlink(path) { + eprintln!( + "Atuin: KV sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + ); + std::process::exit(1); + } + + if !path.exists() { + if let Some(dir) = path.parent() { + fs::create_dir_all(dir).await?; + } + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .optimize_on_close(true, None) + .synchronous(SqliteSynchronous::Normal) + .with_regexp() + .foreign_keys(true) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + Ok(Self { pool }) + } + + pub async fn sqlite_version(&self) -> Result<String> { + sqlx::query_scalar("SELECT sqlite_version()") + .fetch_one(&self.pool) + .await + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, e: &KvEntry) -> Result<()> { + sqlx::query( + "insert into kv(namespace, key, value) + values(?1, ?2, ?3) + on conflict(namespace, key) do update set + namespace = excluded.namespace, + key = excluded.key, + value = excluded.value", + ) + .bind(e.namespace.as_str()) + .bind(e.key.as_str()) + .bind(e.value.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + async fn delete_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + namespace: &str, + key: &str, + ) -> Result<()> { + sqlx::query("delete from kv where namespace = ?1 and key = ?2") + .bind(namespace) + .bind(key) + .execute(&mut **tx) + .await?; + Ok(()) + } + + pub async fn save(&self, e: &KvEntry) -> Result<()> { + debug!("saving kv entry to sqlite"); + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, e).await?; + tx.commit().await?; + + Ok(()) + } + + pub async fn delete(&self, namespace: &str, key: &str) -> Result<()> { + debug!("deleting kv entry {namespace}/{key}"); + + let mut tx = self.pool.begin().await?; + Self::delete_raw(&mut tx, namespace, key).await?; + tx.commit().await?; + + Ok(()) + } + + fn query_kv_entry(row: SqliteRow) -> KvEntry { + let namespace = row.get("namespace"); + let key = row.get("key"); + let value = row.get("value"); + + KvEntry::builder() + .namespace(namespace) + .key(key) + .value(value) + .build() + } + + pub async fn load(&self, namespace: &str, key: &str) -> Result<Option<KvEntry>> { + debug!("loading kv entry {namespace}.{key}"); + + let res = sqlx::query("select * from kv where namespace = ?1 and key = ?2") + .bind(namespace) + .bind(key) + .map(Self::query_kv_entry) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + pub async fn list(&self, namespace: Option<&str>) -> Result<Vec<KvEntry>> { + debug!("listing kv entries"); + + let res = if let Some(namespace) = namespace { + sqlx::query("select * from kv where namespace = ?1 order by key asc") + .bind(namespace) + .map(Self::query_kv_entry) + .fetch_all(&self.pool) + .await? + } else { + sqlx::query("select * from kv order by namespace, key asc") + .map(Self::query_kv_entry) + .fetch_all(&self.pool) + .await? + }; + + Ok(res) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn test_list() { + let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + let scripts = db.list(None).await.unwrap(); + assert_eq!(scripts.len(), 0); + + let entry = KvEntry::builder() + .namespace("test".to_string()) + .key("test".to_string()) + .value("test".to_string()) + .build(); + + db.save(&entry).await.unwrap(); + + let entries = db.list(None).await.unwrap(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].namespace, "test"); + assert_eq!(entries[0].key, "test"); + assert_eq!(entries[0].value, "test"); + } + + #[tokio::test] + async fn test_save_load() { + let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + + let entry = KvEntry::builder() + .namespace("test".to_string()) + .key("test".to_string()) + .value("test".to_string()) + .build(); + + db.save(&entry).await.unwrap(); + + let loaded = db + .load(&entry.namespace, &entry.key) + .await + .unwrap() + .unwrap(); + + assert_eq!(loaded, entry); + } + + #[tokio::test] + async fn test_delete() { + let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + + let entry = KvEntry::builder() + .namespace("test".to_string()) + .key("test".to_string()) + .value("test".to_string()) + .build(); + + db.save(&entry).await.unwrap(); + + assert_eq!(db.list(None).await.unwrap().len(), 1); + db.delete(&entry.namespace, &entry.key).await.unwrap(); + + let loaded = db.list(None).await.unwrap(); + assert_eq!(loaded.len(), 0); + } +} diff --git a/crates/atuin-kv/src/lib.rs b/crates/atuin-kv/src/lib.rs new file mode 100644 index 00000000..ad57b6ac --- /dev/null +++ b/crates/atuin-kv/src/lib.rs @@ -0,0 +1,2 @@ +pub mod database; +pub mod store; diff --git a/crates/atuin-kv/src/store.rs b/crates/atuin-kv/src/store.rs new file mode 100644 index 00000000..3394b8c0 --- /dev/null +++ b/crates/atuin-kv/src/store.rs @@ -0,0 +1,214 @@ +use std::collections::HashSet; + +use eyre::{Result, bail}; + +use atuin_client::record::sqlite_store::SqliteStore; +use atuin_client::record::{encryption::PASETO_V4, store::Store}; +use atuin_common::record::{Host, HostId, Record, RecordId, RecordIdx}; +use entry::KvEntry; +use record::{KV_TAG, KV_VERSION, KvRecord}; + +use crate::database::Database; + +pub mod entry; +pub mod record; + +#[derive(Debug, Clone)] +pub struct KvStore { + pub record_store: SqliteStore, + pub kv_db: Database, + pub host_id: HostId, + pub encryption_key: [u8; 32], +} + +impl KvStore { + pub fn new( + record_store: SqliteStore, + kv_db: Database, + host_id: HostId, + encryption_key: [u8; 32], + ) -> Self { + KvStore { + record_store, + kv_db, + host_id, + encryption_key, + } + } + + pub async fn set(&self, namespace: &str, key: &str, value: &str) -> Result<()> { + let kv_record = KvRecord::builder() + .namespace(namespace.to_string()) + .key(key.to_string()) + .value(Some(value.to_string())) + .build(); + + self.push_record(kv_record).await?; + + let kv = KvEntry::builder() + .namespace(namespace.to_string()) + .key(key.to_string()) + .value(value.to_string()) + .build(); + + self.kv_db.save(&kv).await?; + + Ok(()) + } + + pub async fn get(&self, namespace: &str, key: &str) -> Result<Option<String>> { + let kv = self.kv_db.load(namespace, key).await?; + Ok(kv.map(|kv| kv.value)) + } + + pub async fn delete(&self, namespace: &str, keys: &[String]) -> Result<()> { + for key in keys { + let record = KvRecord::builder() + .namespace(namespace.to_string()) + .key(key.to_string()) + .value(None) + .build(); + + self.push_record(record).await?; + self.kv_db.delete(namespace, key).await?; + } + + Ok(()) + } + + pub async fn list(&self, namespace: Option<&str>) -> Result<Vec<KvEntry>> { + let entries = self.kv_db.list(namespace).await?; + + Ok(entries) + } + + async fn push_record(&self, record: KvRecord) -> Result<(RecordId, RecordIdx)> { + let bytes = record.serialize()?; + let idx = self + .record_store + .last(self.host_id, KV_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(KV_VERSION.to_string()) + .tag(KV_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + let id = record.id; + + self.record_store + .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) + .await?; + + Ok((id, idx)) + } + + pub async fn build(&self) -> Result<()> { + let mut tagged = self.record_store.all_tagged(KV_TAG).await?; + tagged.reverse(); + + let cached = self.kv_db.list(None).await?; + + let mut visited = HashSet::new(); + + // Iterate through all KV records from newest to oldest; + // only visit each KV once, inserting or deleting based on the first time we see it + for record in tagged { + let decrypted = match record.version.as_str() { + "v0" | KV_VERSION => record.decrypt::<PASETO_V4>(&self.encryption_key)?, + version => bail!("unknown version {version:?}"), + }; + + let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?; + let uniq_id = format!("{}.{}", kv.namespace, kv.key); + + if visited.insert(uniq_id) { + match kv.value { + Some(value) => { + self.kv_db + .save( + &KvEntry::builder() + .namespace(kv.namespace.clone()) + .key(kv.key.clone()) + .value(value) + .build(), + ) + .await?; + } + None => { + self.kv_db + .delete(kv.namespace.as_str(), kv.key.as_str()) + .await?; + } + } + } + } + + // Any KVs that were in the cache but not in the tagged list should be deleted; + // this should never happen in practice since the cache is always built from the tagged list, + // but just in case because ** S O F T W A R E ** + for kv in cached { + if !visited.contains(&format!("{}.{}", kv.namespace, kv.key)) { + self.kv_db + .delete(kv.namespace.as_str(), kv.key.as_str()) + .await?; + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + async fn setup() -> Result<KvStore> { + let record_store = SqliteStore::new("sqlite::memory:", 1.0).await.unwrap(); + let kv_db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7()); + let encryption_key = [0; 32]; + Ok(KvStore::new(record_store, kv_db, host_id, encryption_key)) + } + + #[tokio::test] + async fn test_kv_store() -> Result<()> { + let store = setup().await?; + + store.set("test", "key", "value").await.unwrap(); + let value = store.get("test", "key").await.unwrap(); + assert_eq!(value, Some("value".to_string())); + + let records = store.record_store.all_tagged(KV_TAG).await?; + assert_eq!(records.len(), 1); + + let list = store.list(Some("test")).await.unwrap(); + let expected = vec![ + KvEntry::builder() + .namespace("test".to_string()) + .key("key".to_string()) + .value("value".to_string()) + .build(), + ]; + assert_eq!(list, expected); + + let ns_list = store.list(None).await.unwrap(); + assert_eq!(ns_list, expected); + + store + .delete("test", &vec!["key".to_string()]) + .await + .unwrap(); + let value = store.get("test", "key").await.unwrap(); + assert_eq!(value, None); + + let records = store.record_store.all_tagged(KV_TAG).await?; + assert_eq!(records.len(), 2); + + Ok(()) + } +} diff --git a/crates/atuin-kv/src/store/entry.rs b/crates/atuin-kv/src/store/entry.rs new file mode 100644 index 00000000..1d6a1ef8 --- /dev/null +++ b/crates/atuin-kv/src/store/entry.rs @@ -0,0 +1,8 @@ +use typed_builder::TypedBuilder; + +#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)] +pub struct KvEntry { + pub namespace: String, + pub key: String, + pub value: String, +} diff --git a/crates/atuin-kv/src/store/record.rs b/crates/atuin-kv/src/store/record.rs new file mode 100644 index 00000000..37254176 --- /dev/null +++ b/crates/atuin-kv/src/store/record.rs @@ -0,0 +1,159 @@ +use atuin_common::record::DecryptedData; +use eyre::{Result, bail, ensure, eyre}; +use typed_builder::TypedBuilder; + +pub const KV_VERSION: &str = "v1"; +pub const KV_TAG: &str = "kv"; +pub const KV_VAL_MAX_LEN: usize = 100 * 1024; + +#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)] +pub struct KvRecord { + pub namespace: String, + pub key: String, + pub value: Option<String>, +} + +impl KvRecord { + pub fn serialize(&self) -> Result<DecryptedData> { + use rmp::encode; + + let mut output = vec![]; + + // INFO: ensure this is updated when adding new fields + encode::write_array_len(&mut output, 4)?; + + encode::write_str(&mut output, &self.namespace)?; + encode::write_str(&mut output, &self.key)?; + encode::write_bool(&mut output, self.value.is_some())?; + + if let Some(value) = &self.value { + encode::write_str(&mut output, value)?; + } + + Ok(DecryptedData(output)) + } + + pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + match version { + "v0" => { + let mut bytes = decode::Bytes::new(&data.0); + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + ensure!(nfields == 3, "too many entries in v0 kv record"); + + let bytes = bytes.remaining_slice(); + + let (namespace, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded kvrecord. malformed") + } + + Ok(KvRecord { + namespace: namespace.to_owned(), + key: key.to_owned(), + value: Some(value.to_owned()), + }) + } + KV_VERSION => { + let mut bytes = decode::Bytes::new(&data.0); + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + ensure!(nfields == 4, "too many entries in v1 kv record"); + + let bytes = bytes.remaining_slice(); + + let (namespace, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + let (key, mut bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let has_value = decode::read_bool(&mut bytes).map_err(error_report)?; + + let (value, bytes) = if has_value { + let (value, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + (Some(value.to_owned()), bytes) + } else { + (None, bytes) + }; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded kvrecord. malformed") + } + + Ok(KvRecord { + namespace: namespace.to_owned(), + key: key.to_owned(), + value, + }) + } + _ => { + bail!("unknown version {version:?}") + } + } + } +} + +#[cfg(test)] +mod tests { + use super::{DecryptedData, KV_VERSION, KvRecord}; + + #[test] + fn encode_decode_some() { + let kv = KvRecord { + namespace: "foo".to_owned(), + key: "bar".to_owned(), + value: Some("baz".to_owned()), + }; + let snapshot = [ + 0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc3, 0xa3, b'b', b'a', b'z', + ]; + + let encoded = kv.serialize().unwrap(); + let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap(); + + assert_eq!(encoded.0, &snapshot); + assert_eq!(decoded, kv); + } + + #[test] + fn encode_decode_none() { + let kv = KvRecord { + namespace: "foo".to_owned(), + key: "bar".to_owned(), + value: None, + }; + let snapshot = [0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc2]; + + let encoded = kv.serialize().unwrap(); + let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap(); + + assert_eq!(encoded.0, &snapshot); + assert_eq!(decoded, kv); + } + + #[test] + fn decode_v0() { + let kv = KvRecord { + namespace: "foo".to_owned(), + key: "bar".to_owned(), + value: Some("baz".to_owned()), + }; + + let snapshot = vec![ + 0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z', + ]; + + let decoded = KvRecord::deserialize(&DecryptedData(snapshot), "v0").unwrap(); + + assert_eq!(decoded, kv); + } +} |
