aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-kv/src
diff options
context:
space:
mode:
authorMichelle Tilley <michelle@michelletilley.net>2025-05-06 08:36:32 -0700
committerGitHub <noreply@github.com>2025-05-06 08:36:32 -0700
commita1433e0cefe3ad001d5473faf4312c25bdeea968 (patch)
treeee8bc10e1438641338b8ef7f5de00a52e6c7f074 /crates/atuin-kv/src
parentchore(deps): update minspan to 0.1.5 (#2729) (diff)
downloadatuin-a1433e0cefe3ad001d5473faf4312c25bdeea968.zip
feat: Implement KV as a write-through cache (#2732)
Diffstat (limited to 'crates/atuin-kv/src')
-rw-r--r--crates/atuin-kv/src/database.rs229
-rw-r--r--crates/atuin-kv/src/lib.rs2
-rw-r--r--crates/atuin-kv/src/store.rs214
-rw-r--r--crates/atuin-kv/src/store/entry.rs8
-rw-r--r--crates/atuin-kv/src/store/record.rs159
5 files changed, 612 insertions, 0 deletions
diff --git a/crates/atuin-kv/src/database.rs b/crates/atuin-kv/src/database.rs
new file mode 100644
index 00000000..ad9226de
--- /dev/null
+++ b/crates/atuin-kv/src/database.rs
@@ -0,0 +1,229 @@
+use std::{path::Path, str::FromStr, time::Duration};
+
+use atuin_common::utils;
+use sqlx::{
+ Result, Row,
+ sqlite::{
+ SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
+ SqliteSynchronous,
+ },
+};
+use tokio::fs;
+use tracing::debug;
+
+use crate::store::entry::KvEntry;
+
+#[derive(Debug, Clone)]
+pub struct Database {
+ pub pool: SqlitePool,
+}
+
+impl Database {
+ pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
+ let path = path.as_ref();
+ debug!("opening KV sqlite database at {:?}", path);
+
+ if utils::broken_symlink(path) {
+ eprintln!(
+ "Atuin: KV sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
+ );
+ std::process::exit(1);
+ }
+
+ if !path.exists() {
+ if let Some(dir) = path.parent() {
+ fs::create_dir_all(dir).await?;
+ }
+ }
+
+ let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
+ .journal_mode(SqliteJournalMode::Wal)
+ .optimize_on_close(true, None)
+ .synchronous(SqliteSynchronous::Normal)
+ .with_regexp()
+ .foreign_keys(true)
+ .create_if_missing(true);
+
+ let pool = SqlitePoolOptions::new()
+ .acquire_timeout(Duration::from_secs_f64(timeout))
+ .connect_with(opts)
+ .await?;
+
+ Self::setup_db(&pool).await?;
+ Ok(Self { pool })
+ }
+
+ pub async fn sqlite_version(&self) -> Result<String> {
+ sqlx::query_scalar("SELECT sqlite_version()")
+ .fetch_one(&self.pool)
+ .await
+ }
+
+ async fn setup_db(pool: &SqlitePool) -> Result<()> {
+ debug!("running sqlite database setup");
+
+ sqlx::migrate!("./migrations").run(pool).await?;
+
+ Ok(())
+ }
+
+ async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, e: &KvEntry) -> Result<()> {
+ sqlx::query(
+ "insert into kv(namespace, key, value)
+ values(?1, ?2, ?3)
+ on conflict(namespace, key) do update set
+ namespace = excluded.namespace,
+ key = excluded.key,
+ value = excluded.value",
+ )
+ .bind(e.namespace.as_str())
+ .bind(e.key.as_str())
+ .bind(e.value.as_str())
+ .execute(&mut **tx)
+ .await?;
+
+ Ok(())
+ }
+
+ async fn delete_raw(
+ tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
+ namespace: &str,
+ key: &str,
+ ) -> Result<()> {
+ sqlx::query("delete from kv where namespace = ?1 and key = ?2")
+ .bind(namespace)
+ .bind(key)
+ .execute(&mut **tx)
+ .await?;
+ Ok(())
+ }
+
+ pub async fn save(&self, e: &KvEntry) -> Result<()> {
+ debug!("saving kv entry to sqlite");
+ let mut tx = self.pool.begin().await?;
+ Self::save_raw(&mut tx, e).await?;
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ pub async fn delete(&self, namespace: &str, key: &str) -> Result<()> {
+ debug!("deleting kv entry {namespace}/{key}");
+
+ let mut tx = self.pool.begin().await?;
+ Self::delete_raw(&mut tx, namespace, key).await?;
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ fn query_kv_entry(row: SqliteRow) -> KvEntry {
+ let namespace = row.get("namespace");
+ let key = row.get("key");
+ let value = row.get("value");
+
+ KvEntry::builder()
+ .namespace(namespace)
+ .key(key)
+ .value(value)
+ .build()
+ }
+
+ pub async fn load(&self, namespace: &str, key: &str) -> Result<Option<KvEntry>> {
+ debug!("loading kv entry {namespace}.{key}");
+
+ let res = sqlx::query("select * from kv where namespace = ?1 and key = ?2")
+ .bind(namespace)
+ .bind(key)
+ .map(Self::query_kv_entry)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ Ok(res)
+ }
+
+ pub async fn list(&self, namespace: Option<&str>) -> Result<Vec<KvEntry>> {
+ debug!("listing kv entries");
+
+ let res = if let Some(namespace) = namespace {
+ sqlx::query("select * from kv where namespace = ?1 order by key asc")
+ .bind(namespace)
+ .map(Self::query_kv_entry)
+ .fetch_all(&self.pool)
+ .await?
+ } else {
+ sqlx::query("select * from kv order by namespace, key asc")
+ .map(Self::query_kv_entry)
+ .fetch_all(&self.pool)
+ .await?
+ };
+
+ Ok(res)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[tokio::test]
+ async fn test_list() {
+ let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
+ let scripts = db.list(None).await.unwrap();
+ assert_eq!(scripts.len(), 0);
+
+ let entry = KvEntry::builder()
+ .namespace("test".to_string())
+ .key("test".to_string())
+ .value("test".to_string())
+ .build();
+
+ db.save(&entry).await.unwrap();
+
+ let entries = db.list(None).await.unwrap();
+ assert_eq!(entries.len(), 1);
+ assert_eq!(entries[0].namespace, "test");
+ assert_eq!(entries[0].key, "test");
+ assert_eq!(entries[0].value, "test");
+ }
+
+ #[tokio::test]
+ async fn test_save_load() {
+ let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
+
+ let entry = KvEntry::builder()
+ .namespace("test".to_string())
+ .key("test".to_string())
+ .value("test".to_string())
+ .build();
+
+ db.save(&entry).await.unwrap();
+
+ let loaded = db
+ .load(&entry.namespace, &entry.key)
+ .await
+ .unwrap()
+ .unwrap();
+
+ assert_eq!(loaded, entry);
+ }
+
+ #[tokio::test]
+ async fn test_delete() {
+ let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
+
+ let entry = KvEntry::builder()
+ .namespace("test".to_string())
+ .key("test".to_string())
+ .value("test".to_string())
+ .build();
+
+ db.save(&entry).await.unwrap();
+
+ assert_eq!(db.list(None).await.unwrap().len(), 1);
+ db.delete(&entry.namespace, &entry.key).await.unwrap();
+
+ let loaded = db.list(None).await.unwrap();
+ assert_eq!(loaded.len(), 0);
+ }
+}
diff --git a/crates/atuin-kv/src/lib.rs b/crates/atuin-kv/src/lib.rs
new file mode 100644
index 00000000..ad57b6ac
--- /dev/null
+++ b/crates/atuin-kv/src/lib.rs
@@ -0,0 +1,2 @@
+pub mod database;
+pub mod store;
diff --git a/crates/atuin-kv/src/store.rs b/crates/atuin-kv/src/store.rs
new file mode 100644
index 00000000..3394b8c0
--- /dev/null
+++ b/crates/atuin-kv/src/store.rs
@@ -0,0 +1,214 @@
+use std::collections::HashSet;
+
+use eyre::{Result, bail};
+
+use atuin_client::record::sqlite_store::SqliteStore;
+use atuin_client::record::{encryption::PASETO_V4, store::Store};
+use atuin_common::record::{Host, HostId, Record, RecordId, RecordIdx};
+use entry::KvEntry;
+use record::{KV_TAG, KV_VERSION, KvRecord};
+
+use crate::database::Database;
+
+pub mod entry;
+pub mod record;
+
+#[derive(Debug, Clone)]
+pub struct KvStore {
+ pub record_store: SqliteStore,
+ pub kv_db: Database,
+ pub host_id: HostId,
+ pub encryption_key: [u8; 32],
+}
+
+impl KvStore {
+ pub fn new(
+ record_store: SqliteStore,
+ kv_db: Database,
+ host_id: HostId,
+ encryption_key: [u8; 32],
+ ) -> Self {
+ KvStore {
+ record_store,
+ kv_db,
+ host_id,
+ encryption_key,
+ }
+ }
+
+ pub async fn set(&self, namespace: &str, key: &str, value: &str) -> Result<()> {
+ let kv_record = KvRecord::builder()
+ .namespace(namespace.to_string())
+ .key(key.to_string())
+ .value(Some(value.to_string()))
+ .build();
+
+ self.push_record(kv_record).await?;
+
+ let kv = KvEntry::builder()
+ .namespace(namespace.to_string())
+ .key(key.to_string())
+ .value(value.to_string())
+ .build();
+
+ self.kv_db.save(&kv).await?;
+
+ Ok(())
+ }
+
+ pub async fn get(&self, namespace: &str, key: &str) -> Result<Option<String>> {
+ let kv = self.kv_db.load(namespace, key).await?;
+ Ok(kv.map(|kv| kv.value))
+ }
+
+ pub async fn delete(&self, namespace: &str, keys: &[String]) -> Result<()> {
+ for key in keys {
+ let record = KvRecord::builder()
+ .namespace(namespace.to_string())
+ .key(key.to_string())
+ .value(None)
+ .build();
+
+ self.push_record(record).await?;
+ self.kv_db.delete(namespace, key).await?;
+ }
+
+ Ok(())
+ }
+
+ pub async fn list(&self, namespace: Option<&str>) -> Result<Vec<KvEntry>> {
+ let entries = self.kv_db.list(namespace).await?;
+
+ Ok(entries)
+ }
+
+ async fn push_record(&self, record: KvRecord) -> Result<(RecordId, RecordIdx)> {
+ let bytes = record.serialize()?;
+ let idx = self
+ .record_store
+ .last(self.host_id, KV_TAG)
+ .await?
+ .map_or(0, |p| p.idx + 1);
+
+ let record = Record::builder()
+ .host(Host::new(self.host_id))
+ .version(KV_VERSION.to_string())
+ .tag(KV_TAG.to_string())
+ .idx(idx)
+ .data(bytes)
+ .build();
+
+ let id = record.id;
+
+ self.record_store
+ .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
+ .await?;
+
+ Ok((id, idx))
+ }
+
+ pub async fn build(&self) -> Result<()> {
+ let mut tagged = self.record_store.all_tagged(KV_TAG).await?;
+ tagged.reverse();
+
+ let cached = self.kv_db.list(None).await?;
+
+ let mut visited = HashSet::new();
+
+ // Iterate through all KV records from newest to oldest;
+ // only visit each KV once, inserting or deleting based on the first time we see it
+ for record in tagged {
+ let decrypted = match record.version.as_str() {
+ "v0" | KV_VERSION => record.decrypt::<PASETO_V4>(&self.encryption_key)?,
+ version => bail!("unknown version {version:?}"),
+ };
+
+ let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?;
+ let uniq_id = format!("{}.{}", kv.namespace, kv.key);
+
+ if visited.insert(uniq_id) {
+ match kv.value {
+ Some(value) => {
+ self.kv_db
+ .save(
+ &KvEntry::builder()
+ .namespace(kv.namespace.clone())
+ .key(kv.key.clone())
+ .value(value)
+ .build(),
+ )
+ .await?;
+ }
+ None => {
+ self.kv_db
+ .delete(kv.namespace.as_str(), kv.key.as_str())
+ .await?;
+ }
+ }
+ }
+ }
+
+ // Any KVs that were in the cache but not in the tagged list should be deleted;
+ // this should never happen in practice since the cache is always built from the tagged list,
+ // but just in case because ** S O F T W A R E **
+ for kv in cached {
+ if !visited.contains(&format!("{}.{}", kv.namespace, kv.key)) {
+ self.kv_db
+ .delete(kv.namespace.as_str(), kv.key.as_str())
+ .await?;
+ }
+ }
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ async fn setup() -> Result<KvStore> {
+ let record_store = SqliteStore::new("sqlite::memory:", 1.0).await.unwrap();
+ let kv_db = Database::new("sqlite::memory:", 1.0).await.unwrap();
+ let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7());
+ let encryption_key = [0; 32];
+ Ok(KvStore::new(record_store, kv_db, host_id, encryption_key))
+ }
+
+ #[tokio::test]
+ async fn test_kv_store() -> Result<()> {
+ let store = setup().await?;
+
+ store.set("test", "key", "value").await.unwrap();
+ let value = store.get("test", "key").await.unwrap();
+ assert_eq!(value, Some("value".to_string()));
+
+ let records = store.record_store.all_tagged(KV_TAG).await?;
+ assert_eq!(records.len(), 1);
+
+ let list = store.list(Some("test")).await.unwrap();
+ let expected = vec![
+ KvEntry::builder()
+ .namespace("test".to_string())
+ .key("key".to_string())
+ .value("value".to_string())
+ .build(),
+ ];
+ assert_eq!(list, expected);
+
+ let ns_list = store.list(None).await.unwrap();
+ assert_eq!(ns_list, expected);
+
+ store
+ .delete("test", &vec!["key".to_string()])
+ .await
+ .unwrap();
+ let value = store.get("test", "key").await.unwrap();
+ assert_eq!(value, None);
+
+ let records = store.record_store.all_tagged(KV_TAG).await?;
+ assert_eq!(records.len(), 2);
+
+ Ok(())
+ }
+}
diff --git a/crates/atuin-kv/src/store/entry.rs b/crates/atuin-kv/src/store/entry.rs
new file mode 100644
index 00000000..1d6a1ef8
--- /dev/null
+++ b/crates/atuin-kv/src/store/entry.rs
@@ -0,0 +1,8 @@
+use typed_builder::TypedBuilder;
+
+#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)]
+pub struct KvEntry {
+ pub namespace: String,
+ pub key: String,
+ pub value: String,
+}
diff --git a/crates/atuin-kv/src/store/record.rs b/crates/atuin-kv/src/store/record.rs
new file mode 100644
index 00000000..37254176
--- /dev/null
+++ b/crates/atuin-kv/src/store/record.rs
@@ -0,0 +1,159 @@
+use atuin_common::record::DecryptedData;
+use eyre::{Result, bail, ensure, eyre};
+use typed_builder::TypedBuilder;
+
+pub const KV_VERSION: &str = "v1";
+pub const KV_TAG: &str = "kv";
+pub const KV_VAL_MAX_LEN: usize = 100 * 1024;
+
+#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)]
+pub struct KvRecord {
+ pub namespace: String,
+ pub key: String,
+ pub value: Option<String>,
+}
+
+impl KvRecord {
+ pub fn serialize(&self) -> Result<DecryptedData> {
+ use rmp::encode;
+
+ let mut output = vec![];
+
+ // INFO: ensure this is updated when adding new fields
+ encode::write_array_len(&mut output, 4)?;
+
+ encode::write_str(&mut output, &self.namespace)?;
+ encode::write_str(&mut output, &self.key)?;
+ encode::write_bool(&mut output, self.value.is_some())?;
+
+ if let Some(value) = &self.value {
+ encode::write_str(&mut output, value)?;
+ }
+
+ Ok(DecryptedData(output))
+ }
+
+ pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
+ use rmp::decode;
+
+ fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
+ eyre!("{err:?}")
+ }
+
+ match version {
+ "v0" => {
+ let mut bytes = decode::Bytes::new(&data.0);
+
+ let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
+ ensure!(nfields == 3, "too many entries in v0 kv record");
+
+ let bytes = bytes.remaining_slice();
+
+ let (namespace, bytes) =
+ decode::read_str_from_slice(bytes).map_err(error_report)?;
+ let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+ let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+
+ if !bytes.is_empty() {
+ bail!("trailing bytes in encoded kvrecord. malformed")
+ }
+
+ Ok(KvRecord {
+ namespace: namespace.to_owned(),
+ key: key.to_owned(),
+ value: Some(value.to_owned()),
+ })
+ }
+ KV_VERSION => {
+ let mut bytes = decode::Bytes::new(&data.0);
+
+ let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
+ ensure!(nfields == 4, "too many entries in v1 kv record");
+
+ let bytes = bytes.remaining_slice();
+
+ let (namespace, bytes) =
+ decode::read_str_from_slice(bytes).map_err(error_report)?;
+ let (key, mut bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+ let has_value = decode::read_bool(&mut bytes).map_err(error_report)?;
+
+ let (value, bytes) = if has_value {
+ let (value, bytes) =
+ decode::read_str_from_slice(bytes).map_err(error_report)?;
+ (Some(value.to_owned()), bytes)
+ } else {
+ (None, bytes)
+ };
+
+ if !bytes.is_empty() {
+ bail!("trailing bytes in encoded kvrecord. malformed")
+ }
+
+ Ok(KvRecord {
+ namespace: namespace.to_owned(),
+ key: key.to_owned(),
+ value,
+ })
+ }
+ _ => {
+ bail!("unknown version {version:?}")
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{DecryptedData, KV_VERSION, KvRecord};
+
+ #[test]
+ fn encode_decode_some() {
+ let kv = KvRecord {
+ namespace: "foo".to_owned(),
+ key: "bar".to_owned(),
+ value: Some("baz".to_owned()),
+ };
+ let snapshot = [
+ 0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc3, 0xa3, b'b', b'a', b'z',
+ ];
+
+ let encoded = kv.serialize().unwrap();
+ let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap();
+
+ assert_eq!(encoded.0, &snapshot);
+ assert_eq!(decoded, kv);
+ }
+
+ #[test]
+ fn encode_decode_none() {
+ let kv = KvRecord {
+ namespace: "foo".to_owned(),
+ key: "bar".to_owned(),
+ value: None,
+ };
+ let snapshot = [0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc2];
+
+ let encoded = kv.serialize().unwrap();
+ let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap();
+
+ assert_eq!(encoded.0, &snapshot);
+ assert_eq!(decoded, kv);
+ }
+
+ #[test]
+ fn decode_v0() {
+ let kv = KvRecord {
+ namespace: "foo".to_owned(),
+ key: "bar".to_owned(),
+ value: Some("baz".to_owned()),
+ };
+
+ let snapshot = vec![
+ 0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z',
+ ];
+
+ let decoded = KvRecord::deserialize(&DecryptedData(snapshot), "v0").unwrap();
+
+ assert_eq!(decoded, kv);
+ }
+}