diff options
Diffstat (limited to 'crates')
| -rw-r--r-- | crates/atuin-client/src/kv.rs | 361 | ||||
| -rw-r--r-- | crates/atuin-client/src/lib.rs | 1 | ||||
| -rw-r--r-- | crates/atuin-client/src/settings.rs | 6 | ||||
| -rw-r--r-- | crates/atuin-client/src/settings/kv.rs | 17 | ||||
| -rw-r--r-- | crates/atuin-kv/Cargo.toml | 27 | ||||
| -rw-r--r-- | crates/atuin-kv/migrations/20250501160746_create_kv_db.down.sql | 2 | ||||
| -rw-r--r-- | crates/atuin-kv/migrations/20250501160746_create_kv_db.up.sql | 12 | ||||
| -rw-r--r-- | crates/atuin-kv/src/database.rs | 229 | ||||
| -rw-r--r-- | crates/atuin-kv/src/lib.rs | 2 | ||||
| -rw-r--r-- | crates/atuin-kv/src/store.rs | 214 | ||||
| -rw-r--r-- | crates/atuin-kv/src/store/entry.rs | 8 | ||||
| -rw-r--r-- | crates/atuin-kv/src/store/record.rs | 159 | ||||
| -rw-r--r-- | crates/atuin/Cargo.toml | 5 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/kv.rs | 93 | ||||
| -rw-r--r-- | crates/atuin/src/sync.rs | 6 |
15 files changed, 734 insertions, 408 deletions
diff --git a/crates/atuin-client/src/kv.rs b/crates/atuin-client/src/kv.rs deleted file mode 100644 index 4915100b..00000000 --- a/crates/atuin-client/src/kv.rs +++ /dev/null @@ -1,361 +0,0 @@ -use std::collections::BTreeMap; - -use atuin_common::record::{DecryptedData, Host, HostId}; -use eyre::{Result, bail, ensure, eyre}; -use serde::Deserialize; - -use crate::record::encryption::PASETO_V4; -use crate::record::store::Store; - -const KV_VERSION: &str = "v1"; -const KV_TAG: &str = "kv"; -const KV_VAL_MAX_LEN: usize = 100 * 1024; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct KvRecord { - pub namespace: String, - pub key: String, - pub value: Option<String>, -} - -impl KvRecord { - pub fn serialize(&self) -> Result<DecryptedData> { - use rmp::encode; - - let mut output = vec![]; - - // INFO: ensure this is updated when adding new fields - encode::write_array_len(&mut output, 4)?; - - encode::write_str(&mut output, &self.namespace)?; - encode::write_str(&mut output, &self.key)?; - encode::write_bool(&mut output, self.value.is_some())?; - - if let Some(value) = &self.value { - encode::write_str(&mut output, value)?; - } - - Ok(DecryptedData(output)) - } - - pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> { - use rmp::decode; - - fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - match version { - "v0" => { - let mut bytes = decode::Bytes::new(&data.0); - - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - ensure!(nfields == 3, "too many entries in v0 kv record"); - - let bytes = bytes.remaining_slice(); - - let (namespace, bytes) = - decode::read_str_from_slice(bytes).map_err(error_report)?; - let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - if !bytes.is_empty() { - bail!("trailing bytes in encoded kvrecord. malformed") - } - - Ok(KvRecord { - namespace: namespace.to_owned(), - key: key.to_owned(), - value: Some(value.to_owned()), - }) - } - KV_VERSION => { - let mut bytes = decode::Bytes::new(&data.0); - - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - ensure!(nfields == 4, "too many entries in v1 kv record"); - - let bytes = bytes.remaining_slice(); - - let (namespace, bytes) = - decode::read_str_from_slice(bytes).map_err(error_report)?; - let (key, mut bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let has_value = decode::read_bool(&mut bytes).map_err(error_report)?; - - let (value, bytes) = if has_value { - let (value, bytes) = - decode::read_str_from_slice(bytes).map_err(error_report)?; - (Some(value.to_owned()), bytes) - } else { - (None, bytes) - }; - - if !bytes.is_empty() { - bail!("trailing bytes in encoded kvrecord. malformed") - } - - Ok(KvRecord { - namespace: namespace.to_owned(), - key: key.to_owned(), - value, - }) - } - _ => { - bail!("unknown version {version:?}") - } - } - } -} - -#[derive(Debug, Clone, Deserialize)] -pub struct KvStore; - -impl Default for KvStore { - fn default() -> Self { - Self::new() - } -} - -impl KvStore { - // will want to init the actual kv store when that is done - pub fn new() -> KvStore { - KvStore {} - } - - pub async fn set( - &self, - store: &(impl Store + Send + Sync), - encryption_key: &[u8; 32], - host_id: HostId, - namespace: &str, - key: &str, - value: Option<&str>, - ) -> Result<()> { - if value.is_some() && value.unwrap().len() > KV_VAL_MAX_LEN { - return Err(eyre!( - "kv value too large: max len {} bytes", - KV_VAL_MAX_LEN - )); - } - - let record = KvRecord { - namespace: namespace.to_string(), - key: key.to_string(), - value: value.map(|v| v.to_string()), - }; - - let bytes = record.serialize()?; - - let idx = store - .last(host_id, KV_TAG) - .await? - .map_or(0, |entry| entry.idx + 1); - - let record = atuin_common::record::Record::builder() - .host(Host::new(host_id)) - .version(KV_VERSION.to_string()) - .tag(KV_TAG.to_string()) - .idx(idx) - .data(bytes) - .build(); - - store - .push(&record.encrypt::<PASETO_V4>(encryption_key)) - .await?; - - Ok(()) - } - - // TODO: setup an actual kv store, rebuild func, and do not pass the main store in here as - // well. - pub async fn get( - &self, - store: &impl Store, - encryption_key: &[u8; 32], - namespace: &str, - key: &str, - ) -> Result<Option<KvRecord>> { - // TODO: don't rebuild every time... - let map = self.build_kv(store, encryption_key).await?; - - let res = map.get(namespace); - - if let Some(ns) = res { - let value = ns.get(key); - - Ok(value.cloned()) - } else { - Ok(None) - } - } - - // Build a kv map out of the linked list kv store - // Map is Namespace -> Key -> Value - // TODO(ellie): "cache" this into a real kv structure, which we can - // use as a write-through cache to avoid constant rebuilds. - pub async fn build_kv( - &self, - store: &impl Store, - encryption_key: &[u8; 32], - ) -> Result<BTreeMap<String, BTreeMap<String, KvRecord>>> { - let mut map = BTreeMap::new(); - - // TODO: maybe don't load the entire tag into memory to build the kv - // we can be smart about it and only load values since the last build - // or, iterate/paginate - let tagged = store.all_tagged(KV_TAG).await?; - - // iterate through all tags and play each KV record at a time - // this is "last write wins" - // probably good enough for now, but revisit in future - for record in tagged { - let decrypted = match record.version.as_str() { - "v0" | KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?, - version => bail!("unknown version {version:?}"), - }; - - let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?; - - let ns = map - .entry(kv.namespace.clone()) - .or_insert_with(BTreeMap::new); - - ns.insert(kv.key.clone(), kv); - } - - Ok(map) - } -} - -#[cfg(test)] -mod tests { - use crypto_secretbox::{KeyInit, XSalsa20Poly1305}; - use rand::rngs::OsRng; - - use crate::record::sqlite_store::SqliteStore; - use crate::settings::test_local_timeout; - - use super::{DecryptedData, KV_VERSION, KvRecord, KvStore}; - - #[test] - fn encode_decode_some() { - let kv = KvRecord { - namespace: "foo".to_owned(), - key: "bar".to_owned(), - value: Some("baz".to_owned()), - }; - let snapshot = [ - 0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc3, 0xa3, b'b', b'a', b'z', - ]; - - let encoded = kv.serialize().unwrap(); - let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap(); - - assert_eq!(encoded.0, &snapshot); - assert_eq!(decoded, kv); - } - - #[test] - fn encode_decode_none() { - let kv = KvRecord { - namespace: "foo".to_owned(), - key: "bar".to_owned(), - value: None, - }; - let snapshot = [0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc2]; - - let encoded = kv.serialize().unwrap(); - let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap(); - - assert_eq!(encoded.0, &snapshot); - assert_eq!(decoded, kv); - } - - #[test] - fn decode_v0() { - let kv = KvRecord { - namespace: "foo".to_owned(), - key: "bar".to_owned(), - value: Some("baz".to_owned()), - }; - - let snapshot = vec![ - 0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z', - ]; - - let decoded = KvRecord::deserialize(&DecryptedData(snapshot), "v0").unwrap(); - - assert_eq!(decoded, kv); - } - - #[tokio::test] - async fn build_kv() { - let mut store = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let kv = KvStore::new(); - let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into(); - let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7()); - - kv.set(&mut store, &key, host_id, "test-kv", "foo", Some("bar")) - .await - .unwrap(); - - kv.set(&mut store, &key, host_id, "test-kv", "1", Some("2")) - .await - .unwrap(); - - kv.set( - &mut store, - &key, - host_id, - "test-kv", - "deleted", - Some("hello"), - ) - .await - .unwrap(); - - kv.set(&mut store, &key, host_id, "test-kv", "deleted", None) - .await - .unwrap(); - - let map = kv.build_kv(&store, &key).await.unwrap(); - - assert_eq!( - *map.get("test-kv") - .expect("map namespace not set") - .get("foo") - .expect("map key not set"), - KvRecord { - namespace: String::from("test-kv"), - key: String::from("foo"), - value: Some(String::from("bar")) - } - ); - - assert_eq!( - *map.get("test-kv") - .expect("map namespace not set") - .get("1") - .expect("map key not set"), - KvRecord { - namespace: String::from("test-kv"), - key: String::from("1"), - value: Some(String::from("2")) - } - ); - - assert_eq!( - *map.get("test-kv") - .expect("map namespace not set") - .get("deleted") - .expect("map key not set"), - KvRecord { - namespace: String::from("test-kv"), - key: String::from("deleted"), - value: None - } - ); - } -} diff --git a/crates/atuin-client/src/lib.rs b/crates/atuin-client/src/lib.rs index 99640682..443ff3f8 100644 --- a/crates/atuin-client/src/lib.rs +++ b/crates/atuin-client/src/lib.rs @@ -12,7 +12,6 @@ pub mod database; pub mod encryption; pub mod history; pub mod import; -pub mod kv; pub mod login; pub mod logout; pub mod ordering; diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs index 6c50f90c..48803a49 100644 --- a/crates/atuin-client/src/settings.rs +++ b/crates/atuin-client/src/settings.rs @@ -30,6 +30,7 @@ pub const HOST_ID_FILENAME: &str = "host_id"; static EXAMPLE_CONFIG: &str = include_str!("../config.toml"); mod dotfiles; +mod kv; mod scripts; #[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq, Serialize)] @@ -520,6 +521,9 @@ pub struct Settings { #[serde(default)] pub scripts: scripts::Settings, + + #[serde(default)] + pub kv: kv::Settings, } impl Settings { @@ -736,6 +740,7 @@ impl Settings { let data_dir = atuin_common::utils::data_dir(); let db_path = data_dir.join("history.db"); let record_store_path = data_dir.join("records.db"); + let kv_path = data_dir.join("kv.db"); let socket_path = atuin_common::utils::runtime_dir().join("atuin.sock"); let key_path = data_dir.join("key"); @@ -799,6 +804,7 @@ impl Settings { .set_default("daemon.socket_path", socket_path.to_str())? .set_default("daemon.systemd_socket", false)? .set_default("daemon.tcp_port", 8889)? + .set_default("kv.db_path", kv_path.to_str())? .set_default( "search.filters", vec!["global", "host", "session", "workspace", "directory"], diff --git a/crates/atuin-client/src/settings/kv.rs b/crates/atuin-client/src/settings/kv.rs new file mode 100644 index 00000000..afc24a35 --- /dev/null +++ b/crates/atuin-client/src/settings/kv.rs @@ -0,0 +1,17 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Settings { + pub db_path: String, +} + +impl Default for Settings { + fn default() -> Self { + let dir = atuin_common::utils::data_dir(); + let path = dir.join("kv.db"); + + Self { + db_path: path.to_string_lossy().to_string(), + } + } +} diff --git a/crates/atuin-kv/Cargo.toml b/crates/atuin-kv/Cargo.toml new file mode 100644 index 00000000..ec4f9228 --- /dev/null +++ b/crates/atuin-kv/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "atuin-kv" +edition = "2024" +version = { workspace = true } +description = "The kv crate for Atuin" + +authors.workspace = true +rust-version.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true +readme.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +atuin-client = { path = "../atuin-client", version = "18.6.0-beta.1" } +atuin-common = { path = "../atuin-common", version = "18.6.0-beta.1" } + +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +rmp = { version = "0.8.14" } +eyre = { workspace = true } +tokio = { workspace = true } +typed-builder = { workspace = true } +pretty_assertions = { workspace = true } +sqlx = { workspace = true } diff --git a/crates/atuin-kv/migrations/20250501160746_create_kv_db.down.sql b/crates/atuin-kv/migrations/20250501160746_create_kv_db.down.sql new file mode 100644 index 00000000..bce8dfd3 --- /dev/null +++ b/crates/atuin-kv/migrations/20250501160746_create_kv_db.down.sql @@ -0,0 +1,2 @@ +-- Add down migration script here +DROP TABLE kv; diff --git a/crates/atuin-kv/migrations/20250501160746_create_kv_db.up.sql b/crates/atuin-kv/migrations/20250501160746_create_kv_db.up.sql new file mode 100644 index 00000000..77384044 --- /dev/null +++ b/crates/atuin-kv/migrations/20250501160746_create_kv_db.up.sql @@ -0,0 +1,12 @@ +-- Add up migration script here +CREATE TABLE + kv ( + namespace TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + inserted_at INTEGER NOT NULL DEFAULT (strftime ('%s', 'now')) + ); + +CREATE INDEX idx_kv_namespace ON kv (namespace); + +CREATE UNIQUE INDEX idx_kv ON kv (namespace, key); diff --git a/crates/atuin-kv/src/database.rs b/crates/atuin-kv/src/database.rs new file mode 100644 index 00000000..ad9226de --- /dev/null +++ b/crates/atuin-kv/src/database.rs @@ -0,0 +1,229 @@ +use std::{path::Path, str::FromStr, time::Duration}; + +use atuin_common::utils; +use sqlx::{ + Result, Row, + sqlite::{ + SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, + SqliteSynchronous, + }, +}; +use tokio::fs; +use tracing::debug; + +use crate::store::entry::KvEntry; + +#[derive(Debug, Clone)] +pub struct Database { + pub pool: SqlitePool, +} + +impl Database { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + debug!("opening KV sqlite database at {:?}", path); + + if utils::broken_symlink(path) { + eprintln!( + "Atuin: KV sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + ); + std::process::exit(1); + } + + if !path.exists() { + if let Some(dir) = path.parent() { + fs::create_dir_all(dir).await?; + } + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .optimize_on_close(true, None) + .synchronous(SqliteSynchronous::Normal) + .with_regexp() + .foreign_keys(true) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + Ok(Self { pool }) + } + + pub async fn sqlite_version(&self) -> Result<String> { + sqlx::query_scalar("SELECT sqlite_version()") + .fetch_one(&self.pool) + .await + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, e: &KvEntry) -> Result<()> { + sqlx::query( + "insert into kv(namespace, key, value) + values(?1, ?2, ?3) + on conflict(namespace, key) do update set + namespace = excluded.namespace, + key = excluded.key, + value = excluded.value", + ) + .bind(e.namespace.as_str()) + .bind(e.key.as_str()) + .bind(e.value.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + async fn delete_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + namespace: &str, + key: &str, + ) -> Result<()> { + sqlx::query("delete from kv where namespace = ?1 and key = ?2") + .bind(namespace) + .bind(key) + .execute(&mut **tx) + .await?; + Ok(()) + } + + pub async fn save(&self, e: &KvEntry) -> Result<()> { + debug!("saving kv entry to sqlite"); + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, e).await?; + tx.commit().await?; + + Ok(()) + } + + pub async fn delete(&self, namespace: &str, key: &str) -> Result<()> { + debug!("deleting kv entry {namespace}/{key}"); + + let mut tx = self.pool.begin().await?; + Self::delete_raw(&mut tx, namespace, key).await?; + tx.commit().await?; + + Ok(()) + } + + fn query_kv_entry(row: SqliteRow) -> KvEntry { + let namespace = row.get("namespace"); + let key = row.get("key"); + let value = row.get("value"); + + KvEntry::builder() + .namespace(namespace) + .key(key) + .value(value) + .build() + } + + pub async fn load(&self, namespace: &str, key: &str) -> Result<Option<KvEntry>> { + debug!("loading kv entry {namespace}.{key}"); + + let res = sqlx::query("select * from kv where namespace = ?1 and key = ?2") + .bind(namespace) + .bind(key) + .map(Self::query_kv_entry) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + pub async fn list(&self, namespace: Option<&str>) -> Result<Vec<KvEntry>> { + debug!("listing kv entries"); + + let res = if let Some(namespace) = namespace { + sqlx::query("select * from kv where namespace = ?1 order by key asc") + .bind(namespace) + .map(Self::query_kv_entry) + .fetch_all(&self.pool) + .await? + } else { + sqlx::query("select * from kv order by namespace, key asc") + .map(Self::query_kv_entry) + .fetch_all(&self.pool) + .await? + }; + + Ok(res) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn test_list() { + let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + let scripts = db.list(None).await.unwrap(); + assert_eq!(scripts.len(), 0); + + let entry = KvEntry::builder() + .namespace("test".to_string()) + .key("test".to_string()) + .value("test".to_string()) + .build(); + + db.save(&entry).await.unwrap(); + + let entries = db.list(None).await.unwrap(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].namespace, "test"); + assert_eq!(entries[0].key, "test"); + assert_eq!(entries[0].value, "test"); + } + + #[tokio::test] + async fn test_save_load() { + let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + + let entry = KvEntry::builder() + .namespace("test".to_string()) + .key("test".to_string()) + .value("test".to_string()) + .build(); + + db.save(&entry).await.unwrap(); + + let loaded = db + .load(&entry.namespace, &entry.key) + .await + .unwrap() + .unwrap(); + + assert_eq!(loaded, entry); + } + + #[tokio::test] + async fn test_delete() { + let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + + let entry = KvEntry::builder() + .namespace("test".to_string()) + .key("test".to_string()) + .value("test".to_string()) + .build(); + + db.save(&entry).await.unwrap(); + + assert_eq!(db.list(None).await.unwrap().len(), 1); + db.delete(&entry.namespace, &entry.key).await.unwrap(); + + let loaded = db.list(None).await.unwrap(); + assert_eq!(loaded.len(), 0); + } +} diff --git a/crates/atuin-kv/src/lib.rs b/crates/atuin-kv/src/lib.rs new file mode 100644 index 00000000..ad57b6ac --- /dev/null +++ b/crates/atuin-kv/src/lib.rs @@ -0,0 +1,2 @@ +pub mod database; +pub mod store; diff --git a/crates/atuin-kv/src/store.rs b/crates/atuin-kv/src/store.rs new file mode 100644 index 00000000..3394b8c0 --- /dev/null +++ b/crates/atuin-kv/src/store.rs @@ -0,0 +1,214 @@ +use std::collections::HashSet; + +use eyre::{Result, bail}; + +use atuin_client::record::sqlite_store::SqliteStore; +use atuin_client::record::{encryption::PASETO_V4, store::Store}; +use atuin_common::record::{Host, HostId, Record, RecordId, RecordIdx}; +use entry::KvEntry; +use record::{KV_TAG, KV_VERSION, KvRecord}; + +use crate::database::Database; + +pub mod entry; +pub mod record; + +#[derive(Debug, Clone)] +pub struct KvStore { + pub record_store: SqliteStore, + pub kv_db: Database, + pub host_id: HostId, + pub encryption_key: [u8; 32], +} + +impl KvStore { + pub fn new( + record_store: SqliteStore, + kv_db: Database, + host_id: HostId, + encryption_key: [u8; 32], + ) -> Self { + KvStore { + record_store, + kv_db, + host_id, + encryption_key, + } + } + + pub async fn set(&self, namespace: &str, key: &str, value: &str) -> Result<()> { + let kv_record = KvRecord::builder() + .namespace(namespace.to_string()) + .key(key.to_string()) + .value(Some(value.to_string())) + .build(); + + self.push_record(kv_record).await?; + + let kv = KvEntry::builder() + .namespace(namespace.to_string()) + .key(key.to_string()) + .value(value.to_string()) + .build(); + + self.kv_db.save(&kv).await?; + + Ok(()) + } + + pub async fn get(&self, namespace: &str, key: &str) -> Result<Option<String>> { + let kv = self.kv_db.load(namespace, key).await?; + Ok(kv.map(|kv| kv.value)) + } + + pub async fn delete(&self, namespace: &str, keys: &[String]) -> Result<()> { + for key in keys { + let record = KvRecord::builder() + .namespace(namespace.to_string()) + .key(key.to_string()) + .value(None) + .build(); + + self.push_record(record).await?; + self.kv_db.delete(namespace, key).await?; + } + + Ok(()) + } + + pub async fn list(&self, namespace: Option<&str>) -> Result<Vec<KvEntry>> { + let entries = self.kv_db.list(namespace).await?; + + Ok(entries) + } + + async fn push_record(&self, record: KvRecord) -> Result<(RecordId, RecordIdx)> { + let bytes = record.serialize()?; + let idx = self + .record_store + .last(self.host_id, KV_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(KV_VERSION.to_string()) + .tag(KV_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + let id = record.id; + + self.record_store + .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) + .await?; + + Ok((id, idx)) + } + + pub async fn build(&self) -> Result<()> { + let mut tagged = self.record_store.all_tagged(KV_TAG).await?; + tagged.reverse(); + + let cached = self.kv_db.list(None).await?; + + let mut visited = HashSet::new(); + + // Iterate through all KV records from newest to oldest; + // only visit each KV once, inserting or deleting based on the first time we see it + for record in tagged { + let decrypted = match record.version.as_str() { + "v0" | KV_VERSION => record.decrypt::<PASETO_V4>(&self.encryption_key)?, + version => bail!("unknown version {version:?}"), + }; + + let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?; + let uniq_id = format!("{}.{}", kv.namespace, kv.key); + + if visited.insert(uniq_id) { + match kv.value { + Some(value) => { + self.kv_db + .save( + &KvEntry::builder() + .namespace(kv.namespace.clone()) + .key(kv.key.clone()) + .value(value) + .build(), + ) + .await?; + } + None => { + self.kv_db + .delete(kv.namespace.as_str(), kv.key.as_str()) + .await?; + } + } + } + } + + // Any KVs that were in the cache but not in the tagged list should be deleted; + // this should never happen in practice since the cache is always built from the tagged list, + // but just in case because ** S O F T W A R E ** + for kv in cached { + if !visited.contains(&format!("{}.{}", kv.namespace, kv.key)) { + self.kv_db + .delete(kv.namespace.as_str(), kv.key.as_str()) + .await?; + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + async fn setup() -> Result<KvStore> { + let record_store = SqliteStore::new("sqlite::memory:", 1.0).await.unwrap(); + let kv_db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7()); + let encryption_key = [0; 32]; + Ok(KvStore::new(record_store, kv_db, host_id, encryption_key)) + } + + #[tokio::test] + async fn test_kv_store() -> Result<()> { + let store = setup().await?; + + store.set("test", "key", "value").await.unwrap(); + let value = store.get("test", "key").await.unwrap(); + assert_eq!(value, Some("value".to_string())); + + let records = store.record_store.all_tagged(KV_TAG).await?; + assert_eq!(records.len(), 1); + + let list = store.list(Some("test")).await.unwrap(); + let expected = vec![ + KvEntry::builder() + .namespace("test".to_string()) + .key("key".to_string()) + .value("value".to_string()) + .build(), + ]; + assert_eq!(list, expected); + + let ns_list = store.list(None).await.unwrap(); + assert_eq!(ns_list, expected); + + store + .delete("test", &vec!["key".to_string()]) + .await + .unwrap(); + let value = store.get("test", "key").await.unwrap(); + assert_eq!(value, None); + + let records = store.record_store.all_tagged(KV_TAG).await?; + assert_eq!(records.len(), 2); + + Ok(()) + } +} diff --git a/crates/atuin-kv/src/store/entry.rs b/crates/atuin-kv/src/store/entry.rs new file mode 100644 index 00000000..1d6a1ef8 --- /dev/null +++ b/crates/atuin-kv/src/store/entry.rs @@ -0,0 +1,8 @@ +use typed_builder::TypedBuilder; + +#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)] +pub struct KvEntry { + pub namespace: String, + pub key: String, + pub value: String, +} diff --git a/crates/atuin-kv/src/store/record.rs b/crates/atuin-kv/src/store/record.rs new file mode 100644 index 00000000..37254176 --- /dev/null +++ b/crates/atuin-kv/src/store/record.rs @@ -0,0 +1,159 @@ +use atuin_common::record::DecryptedData; +use eyre::{Result, bail, ensure, eyre}; +use typed_builder::TypedBuilder; + +pub const KV_VERSION: &str = "v1"; +pub const KV_TAG: &str = "kv"; +pub const KV_VAL_MAX_LEN: usize = 100 * 1024; + +#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)] +pub struct KvRecord { + pub namespace: String, + pub key: String, + pub value: Option<String>, +} + +impl KvRecord { + pub fn serialize(&self) -> Result<DecryptedData> { + use rmp::encode; + + let mut output = vec![]; + + // INFO: ensure this is updated when adding new fields + encode::write_array_len(&mut output, 4)?; + + encode::write_str(&mut output, &self.namespace)?; + encode::write_str(&mut output, &self.key)?; + encode::write_bool(&mut output, self.value.is_some())?; + + if let Some(value) = &self.value { + encode::write_str(&mut output, value)?; + } + + Ok(DecryptedData(output)) + } + + pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + match version { + "v0" => { + let mut bytes = decode::Bytes::new(&data.0); + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + ensure!(nfields == 3, "too many entries in v0 kv record"); + + let bytes = bytes.remaining_slice(); + + let (namespace, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded kvrecord. malformed") + } + + Ok(KvRecord { + namespace: namespace.to_owned(), + key: key.to_owned(), + value: Some(value.to_owned()), + }) + } + KV_VERSION => { + let mut bytes = decode::Bytes::new(&data.0); + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + ensure!(nfields == 4, "too many entries in v1 kv record"); + + let bytes = bytes.remaining_slice(); + + let (namespace, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + let (key, mut bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let has_value = decode::read_bool(&mut bytes).map_err(error_report)?; + + let (value, bytes) = if has_value { + let (value, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + (Some(value.to_owned()), bytes) + } else { + (None, bytes) + }; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded kvrecord. malformed") + } + + Ok(KvRecord { + namespace: namespace.to_owned(), + key: key.to_owned(), + value, + }) + } + _ => { + bail!("unknown version {version:?}") + } + } + } +} + +#[cfg(test)] +mod tests { + use super::{DecryptedData, KV_VERSION, KvRecord}; + + #[test] + fn encode_decode_some() { + let kv = KvRecord { + namespace: "foo".to_owned(), + key: "bar".to_owned(), + value: Some("baz".to_owned()), + }; + let snapshot = [ + 0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc3, 0xa3, b'b', b'a', b'z', + ]; + + let encoded = kv.serialize().unwrap(); + let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap(); + + assert_eq!(encoded.0, &snapshot); + assert_eq!(decoded, kv); + } + + #[test] + fn encode_decode_none() { + let kv = KvRecord { + namespace: "foo".to_owned(), + key: "bar".to_owned(), + value: None, + }; + let snapshot = [0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc2]; + + let encoded = kv.serialize().unwrap(); + let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap(); + + assert_eq!(encoded.0, &snapshot); + assert_eq!(decoded, kv); + } + + #[test] + fn decode_v0() { + let kv = KvRecord { + namespace: "foo".to_owned(), + key: "bar".to_owned(), + value: Some("baz".to_owned()), + }; + + let snapshot = vec![ + 0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z', + ]; + + let decoded = KvRecord::deserialize(&DecryptedData(snapshot), "v0").unwrap(); + + assert_eq!(decoded, kv); + } +} diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml index 78f921c4..5285ebde 100644 --- a/crates/atuin/Cargo.toml +++ b/crates/atuin/Cargo.toml @@ -50,6 +50,7 @@ atuin-dotfiles = { path = "../atuin-dotfiles", version = "18.6.0-beta.1" } atuin-history = { path = "../atuin-history", version = "18.6.0-beta.1" } atuin-daemon = { path = "../atuin-daemon", version = "18.6.0-beta.1", optional = true, default-features = false } atuin-scripts = { path = "../atuin-scripts", version = "18.6.0-beta.1" } +atuin-kv = { path = "../atuin-kv", version = "18.6.0-beta.1" } log = { workspace = true } time = { workspace = true } @@ -87,7 +88,9 @@ tempfile = { workspace = true } arboard = { version = "3.4", optional = true } [target.'cfg(target_os = "linux")'.dependencies] -arboard = { version = "3.4", optional = true, features = ["wayland-data-control"] } +arboard = { version = "3.4", optional = true, features = [ + "wayland-data-control", +] } [dev-dependencies] tracing-tree = "0.4" diff --git a/crates/atuin/src/command/client/kv.rs b/crates/atuin/src/command/client/kv.rs index 02704da6..b4db9c17 100644 --- a/crates/atuin/src/command/client/kv.rs +++ b/crates/atuin/src/command/client/kv.rs @@ -1,82 +1,95 @@ use clap::Subcommand; -use eyre::{Context, Result}; +use eyre::{Context, Result, eyre}; -use atuin_client::{encryption, kv::KvStore, record::store::Store, settings::Settings}; +use atuin_client::{encryption, record::sqlite_store::SqliteStore, settings::Settings}; +use atuin_kv::store::KvStore; #[derive(Subcommand, Debug)] #[command(infer_subcommands = true)] pub enum Cmd { - // atuin kv set foo bar bar + /// Set a key-value pair Set { + /// Key to set #[arg(long, short)] key: String, + /// Value to store + value: String, + + /// Namespace for the key-value pair #[arg(long, short, default_value = "default")] namespace: String, - - value: String, }, + /// Delete one or more key-value pairs #[command(alias = "rm")] Delete { - key: String, + /// Keys to delete + #[arg(required = true)] + keys: Vec<String>, + /// Namespace for the key-value pair #[arg(long, short, default_value = "default")] namespace: String, }, - // atuin kv get foo => bar baz + /// Retrieve a saved value Get { + /// Key to retrieve key: String, + /// Namespace for the key-value pair #[arg(long, short, default_value = "default")] namespace: String, }, + /// List all keys in a namespace, or in all namespaces + #[command(alias = "ls")] List { + /// Namespace to list keys from #[arg(long, short, default_value = "default")] namespace: String, - #[arg(long, short)] + /// List all keys in all namespaces + #[arg(long, short, alias = "all")] all_namespaces: bool, }, + + /// Rebuild the KV store + Rebuild, } impl Cmd { - pub async fn run(&self, settings: &Settings, store: &(impl Store + Send + Sync)) -> Result<()> { - let kv_store = KvStore::new(); - + pub async fn run(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { let encryption_key: [u8; 32] = encryption::load_key(settings) .context("could not load encryption key")? .into(); let host_id = Settings::host_id().expect("failed to get host_id"); + let kv_db = atuin_kv::database::Database::new(settings.kv.db_path.clone(), 1.0).await?; + let kv_store = KvStore::new(store.clone(), kv_db, host_id, encryption_key); + match self { Self::Set { key, value, namespace, } => { - kv_store - .set(store, &encryption_key, host_id, namespace, key, Some(value)) - .await - } + if namespace.is_empty() { + return Err(eyre!("namespace cannot be empty")); + } - Self::Delete { key, namespace } => { - kv_store - .set(store, &encryption_key, host_id, namespace, key, None) - .await + kv_store.set(namespace, key, value).await } + Self::Delete { keys, namespace } => kv_store.delete(namespace, keys).await, + Self::Get { key, namespace } => { - let val = kv_store.get(store, &encryption_key, namespace, key).await?; + let kv = kv_store.get(namespace, key).await?; - if let Some(kv) = val { - // a `None` for kv.value means the key was deleted - if let Some(value) = kv.value { - println!("{value}"); - } + if let Some(val) = kv { + println!("{val}"); } Ok(()) @@ -86,32 +99,24 @@ impl Cmd { namespace, all_namespaces, } => { - // TODO: don't rebuild this every time lol - let map = kv_store.build_kv(store, &encryption_key).await?; - - // slower, but sorting is probably useful - if *all_namespaces { - for (ns, kv) in &map { - for (k, v) in kv { - if v.value.is_some() { - println!("{ns}.{k}"); - } - } - } + let entries = if *all_namespaces { + kv_store.list(None).await? } else { - let ns = map.get(namespace); + kv_store.list(Some(namespace)).await? + }; - if let Some(ns) = ns { - for (k, v) in ns { - if v.value.is_some() { - println!("{k}"); - } - } + for entry in entries { + if *all_namespaces { + println!("{}.{}", entry.namespace, entry.key); + } else { + println!("{}", entry.key); } } Ok(()) } + + Self::Rebuild {} => kv_store.build().await, } } } diff --git a/crates/atuin/src/sync.rs b/crates/atuin/src/sync.rs index 2d7502e9..ad1e9764 100644 --- a/crates/atuin/src/sync.rs +++ b/crates/atuin/src/sync.rs @@ -7,6 +7,7 @@ use atuin_client::{ settings::Settings, }; use atuin_common::record::RecordId; +use atuin_kv::store::KvStore; // This is the only crate that ties together all other crates. // Therefore, it's the only crate where functions tying together all stores can live @@ -28,19 +29,22 @@ pub async fn build( let downloaded = downloaded.unwrap_or(&[]); + let kv_db = atuin_kv::database::Database::new(settings.kv.db_path.clone(), 1.0).await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); let alias_store = AliasStore::new(store.clone(), host_id, encryption_key); let var_store = VarStore::new(store.clone(), host_id, encryption_key); + let kv_store = KvStore::new(store.clone(), kv_db, host_id, encryption_key); let script_store = ScriptStore::new(store.clone(), host_id, encryption_key); history_store.incremental_build(db, downloaded).await?; alias_store.build().await?; var_store.build().await?; + kv_store.build().await?; let script_db = atuin_scripts::database::Database::new(settings.scripts.database_path.clone(), 1.0).await?; script_store.build(script_db).await?; - Ok(()) } |
