From ae1709dafd22ac3c64441472e90df8799253292e Mon Sep 17 00:00:00 2001 From: Ellie Huxtable Date: Wed, 14 Jun 2023 21:18:24 +0100 Subject: Key values (#1038) * wip * Start testing * Store host IDs, not hostnames Why? Hostnames can change a lot, and therefore host filtering can be funky. Really, all we want is a unique ID per machine + do not care what it might be. * Mostly just write a fuckload of tests * Add a v0 kv store I can push to * Appending works * Add next() and iterate, test the pointer chain * Fix sig * Make clippy happy and thaw the ICE * Fix tests' * Fix tests * typed builder and cleaner db trait --------- Co-authored-by: Conrad Ludgate --- atuin-client/src/database.rs | 7 +- atuin-client/src/kv.rs | 103 ++++++++++ atuin-client/src/lib.rs | 2 + atuin-client/src/record/mod.rs | 2 + atuin-client/src/record/sqlite_store.rs | 331 ++++++++++++++++++++++++++++++++ atuin-client/src/record/store.rs | 30 +++ atuin-client/src/settings.rs | 20 ++ 7 files changed, 494 insertions(+), 1 deletion(-) create mode 100644 atuin-client/src/kv.rs create mode 100644 atuin-client/src/record/mod.rs create mode 100644 atuin-client/src/record/sqlite_store.rs create mode 100644 atuin-client/src/record/store.rs (limited to 'atuin-client/src') diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs index 22bd5886..a2d8c533 100644 --- a/atuin-client/src/database.rs +++ b/atuin-client/src/database.rs @@ -17,13 +17,14 @@ use sqlx::{ use super::{ history::History, ordering, - settings::{FilterMode, SearchMode}, + settings::{FilterMode, SearchMode, Settings}, }; pub struct Context { pub session: String, pub cwd: String, pub hostname: String, + pub host_id: String, } #[derive(Default, Clone)] @@ -50,11 +51,13 @@ pub fn current_context() -> Context { env::var("ATUIN_HOST_USER").unwrap_or_else(|_| whoami::username()) ); let cwd = utils::get_current_dir(); + let host_id = Settings::host_id().expect("failed to load host ID"); Context { session, hostname, cwd, + host_id, } } 
@@ -551,6 +554,7 @@ mod test { hostname: "test:host".to_string(), session: "beepboopiamasession".to_string(), cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), }; let results = db @@ -757,6 +761,7 @@ mod test { hostname: "test:host".to_string(), session: "beepboopiamasession".to_string(), cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), }; let mut db = Sqlite::new("sqlite::memory:").await.unwrap(); diff --git a/atuin-client/src/kv.rs b/atuin-client/src/kv.rs new file mode 100644 index 00000000..87149275 --- /dev/null +++ b/atuin-client/src/kv.rs @@ -0,0 +1,103 @@ +use eyre::Result; +use serde::{Deserialize, Serialize}; + +use crate::record::store::Store; +use crate::settings::Settings; + +const KV_VERSION: &str = "v0"; +const KV_TAG: &str = "kv"; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct KvRecord { + pub key: String, + pub value: String, +} + +impl KvRecord { + pub fn serialize(&self) -> Result> { + let buf = rmp_serde::to_vec(self)?; + + Ok(buf) + } +} + +pub struct KvStore; + +impl Default for KvStore { + fn default() -> Self { + Self::new() + } +} + +impl KvStore { + // will want to init the actual kv store when that is done + pub fn new() -> KvStore { + KvStore {} + } + + pub async fn set( + &self, + store: &mut (impl Store + Send + Sync), + key: &str, + value: &str, + ) -> Result<()> { + let host_id = Settings::host_id().expect("failed to get host_id"); + + let record = KvRecord { + key: key.to_string(), + value: value.to_string(), + }; + + let bytes = record.serialize()?; + + let parent = store + .last(host_id.as_str(), KV_TAG) + .await? 
+ .map(|entry| entry.id); + + let record = atuin_common::record::Record::builder() + .host(host_id) + .version(KV_VERSION.to_string()) + .tag(KV_TAG.to_string()) + .parent(parent) + .data(bytes) + .build(); + + store.push(&record).await?; + + Ok(()) + } + + // TODO: setup an actual kv store, rebuild func, and do not pass the main store in here as + // well. + pub async fn get(&self, store: &impl Store, key: &str) -> Result> { + // TODO: don't load this from disk so much + let host_id = Settings::host_id().expect("failed to get host_id"); + + // Currently, this is O(n). When we have an actual KV store, it can be better + // Just a poc for now! + + // iterate records to find the value we want + // start at the end, so we get the most recent version + let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await? else { + return Ok(None); + }; + let kv: KvRecord = rmp_serde::from_slice(&record.data)?; + + if kv.key == key { + return Ok(Some(kv)); + } + + while let Some(parent) = record.parent { + record = store.get(parent.as_str()).await?; + let kv: KvRecord = rmp_serde::from_slice(&record.data)?; + + if kv.key == key { + return Ok(Some(kv)); + } + } + + // if we get here, then... 
we didn't find the record with that key :( + Ok(None) + } +} diff --git a/atuin-client/src/lib.rs b/atuin-client/src/lib.rs index 497c5e74..3f12153a 100644 --- a/atuin-client/src/lib.rs +++ b/atuin-client/src/lib.rs @@ -13,5 +13,7 @@ pub mod sync; pub mod database; pub mod history; pub mod import; +pub mod kv; pub mod ordering; +pub mod record; pub mod settings; diff --git a/atuin-client/src/record/mod.rs b/atuin-client/src/record/mod.rs new file mode 100644 index 00000000..72c1f889 --- /dev/null +++ b/atuin-client/src/record/mod.rs @@ -0,0 +1,2 @@ +pub mod sqlite_store; +pub mod store; diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs new file mode 100644 index 00000000..f116b6e5 --- /dev/null +++ b/atuin-client/src/record/sqlite_store.rs @@ -0,0 +1,331 @@ +// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries. +// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index +// by tag/host + +use std::path::Path; +use std::str::FromStr; + +use async_trait::async_trait; +use eyre::{eyre, Result}; +use fs_err as fs; +use sqlx::{ + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, + Row, +}; + +use atuin_common::record::Record; + +use super::store::Store; + +pub struct SqliteStore { + pool: SqlitePool, +} + +impl SqliteStore { + pub async fn new(path: impl AsRef) -> Result { + let path = path.as_ref(); + + debug!("opening sqlite database at {:?}", path); + + let create = !path.exists(); + if create { + if let Some(dir) = path.parent() { + fs::create_dir_all(dir)?; + } + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? 
+ .journal_mode(SqliteJournalMode::Wal) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new().connect_with(opts).await?; + + Self::setup_db(&pool).await?; + + Ok(Self { pool }) + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./record-migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, r: &Record) -> Result<()> { + // In sqlite, we are "limited" to i64. But that is still fine, until 2262. + sqlx::query( + "insert or ignore into records(id, host, tag, timestamp, parent, version, data) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7)", + ) + .bind(r.id.as_str()) + .bind(r.host.as_str()) + .bind(r.tag.as_str()) + .bind(r.timestamp as i64) + .bind(r.parent.as_ref()) + .bind(r.version.as_str()) + .bind(r.data.as_slice()) + .execute(tx) + .await?; + + Ok(()) + } + + fn query_row(row: SqliteRow) -> Record { + let timestamp: i64 = row.get("timestamp"); + + Record { + id: row.get("id"), + host: row.get("host"), + parent: row.get("parent"), + timestamp: timestamp as u64, + tag: row.get("tag"), + version: row.get("version"), + data: row.get("data"), + } + } +} + +#[async_trait] +impl Store for SqliteStore { + async fn push_batch(&self, records: impl Iterator + Send + Sync) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for record in records { + Self::save_raw(&mut tx, record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn get(&self, id: &str) -> Result { + let res = sqlx::query("select * from records where id = ?1") + .bind(id) + .map(Self::query_row) + .fetch_one(&self.pool) + .await?; + + Ok(res) + } + + async fn len(&self, host: &str, tag: &str) -> Result { + let res: (i64,) = + sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2") + .bind(host) + .bind(tag) + .fetch_one(&self.pool) + .await?; + + Ok(res.0 as u64) + } + + async fn next(&self, record: &Record) -> Result> { + let res = 
sqlx::query("select * from records where parent = ?1") + .bind(record.id.clone()) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occured: {}", e)), + Ok(v) => Ok(Some(v)), + } + } + + async fn first(&self, host: &str, tag: &str) -> Result> { + let res = sqlx::query( + "select * from records where host = ?1 and tag = ?2 and parent is null limit 1", + ) + .bind(host) + .bind(tag) + .map(Self::query_row) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn last(&self, host: &str, tag: &str) -> Result> { + let res = sqlx::query( + "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;", + ) + .bind(tag) + .bind(host) + .map(Self::query_row) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } +} + +#[cfg(test)] +mod tests { + use atuin_common::record::Record; + + use crate::record::store::Store; + + use super::SqliteStore; + + fn test_record() -> Record { + Record::builder() + .host(atuin_common::utils::uuid_v7().simple().to_string()) + .version("v1".into()) + .tag(atuin_common::utils::uuid_v7().simple().to_string()) + .data(vec![0, 1, 2, 3]) + .build() + } + + #[tokio::test] + async fn create_db() { + let db = SqliteStore::new(":memory:").await; + + assert!( + db.is_ok(), + "db could not be created, {:?}", + db.err().unwrap() + ); + } + + #[tokio::test] + async fn push_record() { + let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + + db.push(&record).await.expect("failed to insert record"); + } + + #[tokio::test] + async fn get_record() { + let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let new_record = db + .get(record.id.as_str()) + .await + .expect("failed to fetch record"); + + assert_eq!(record, new_record, "records are not equal"); + } + + #[tokio::test] + async fn len() { + 
let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len(record.host.as_str(), record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_different_tags() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + // these have different tags, so the len should be the same + // we model multiple stores within one database + // new store = new tag = independent length + let first = test_record(); + let second = test_record(); + + db.push(&first).await.unwrap(); + db.push(&second).await.unwrap(); + + let first_len = db + .len(first.host.as_str(), first.tag.as_str()) + .await + .unwrap(); + let second_len = db + .len(second.host.as_str(), second.tag.as_str()) + .await + .unwrap(); + + assert_eq!(first_len, 1, "expected length of 1 after insert"); + assert_eq!(second_len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn append_a_bunch() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + let mut tail = test_record(); + db.push(&tail).await.expect("failed to push record"); + + for _ in 1..100 { + tail = tail.new_child(vec![1, 2, 3, 4]); + db.push(&tail).await.unwrap(); + } + + assert_eq!( + db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + } + + #[tokio::test] + async fn append_a_big_bunch() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + let mut records: Vec = Vec::with_capacity(10000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..10000 { + tail = tail.new_child(vec![1, 2, 3]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + assert_eq!( + db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(), + 10000, + "failed to insert 10k records" + ); + } + + #[tokio::test] + async fn test_chain() { + 
let db = SqliteStore::new(":memory:").await.unwrap(); + + let mut records: Vec = Vec::with_capacity(1000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..1000 { + tail = tail.new_child(vec![1, 2, 3]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + let mut record = db + .first(tail.host.as_str(), tail.tag.as_str()) + .await + .expect("in memory sqlite should not fail") + .expect("entry exists"); + + let mut count = 1; + + while let Some(next) = db.next(&record).await.unwrap() { + assert_eq!(record.id, next.clone().parent.unwrap()); + record = next; + + count += 1; + } + + assert_eq!(count, 1000); + } +} diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs new file mode 100644 index 00000000..75d79fb5 --- /dev/null +++ b/atuin-client/src/record/store.rs @@ -0,0 +1,30 @@ +use async_trait::async_trait; +use eyre::Result; + +use atuin_common::record::Record; + +/// A record store stores records +/// In more detail - we tend to need to process this into _another_ format to actually query it. +/// As is, the record store is intended as the source of truth for arbitrary data, which could +/// be shell history, kvs, etc. 
+#[async_trait] +pub trait Store { + // Push a record + async fn push(&self, record: &Record) -> Result<()> { + self.push_batch(std::iter::once(record)).await + } + + // Push a batch of records, all in one transaction + async fn push_batch(&self, records: impl Iterator + Send + Sync) -> Result<()>; + + async fn get(&self, id: &str) -> Result; + async fn len(&self, host: &str, tag: &str) -> Result; + + /// Get the record that follows this record + async fn next(&self, record: &Record) -> Result>; + + /// Get the first record for a given host and tag + async fn first(&self, host: &str, tag: &str) -> Result>; + /// Get the last record for a given host and tag + async fn last(&self, host: &str, tag: &str) -> Result>; +} diff --git a/atuin-client/src/settings.rs b/atuin-client/src/settings.rs index 524b2fd7..dd072451 100644 --- a/atuin-client/src/settings.rs +++ b/atuin-client/src/settings.rs @@ -17,6 +17,7 @@ pub const HISTORY_PAGE_SIZE: i64 = 100; pub const LAST_SYNC_FILENAME: &str = "last_sync_time"; pub const LAST_VERSION_CHECK_FILENAME: &str = "last_version_check_time"; pub const LATEST_VERSION_FILENAME: &str = "latest_version"; +pub const HOST_ID_FILENAME: &str = "host_id"; #[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq)] pub enum SearchMode { @@ -140,6 +141,7 @@ pub struct Settings { pub sync_address: String, pub sync_frequency: String, pub db_path: String, + pub record_store_path: String, pub key_path: String, pub session_path: String, pub search_mode: SearchMode, @@ -226,6 +228,21 @@ impl Settings { Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME) } + pub fn host_id() -> Option { + let id = Settings::read_from_data_dir(HOST_ID_FILENAME); + + if id.is_some() { + return id; + } + + let uuid = atuin_common::utils::uuid_v7(); + + Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref()) + .expect("Could not write host ID to data dir"); + + Some(uuid.as_simple().to_string()) + } + pub fn should_sync(&self) -> 
Result { if !self.auto_sync || !PathBuf::from(self.session_path.as_str()).exists() { return Ok(false); @@ -321,11 +338,14 @@ impl Settings { config_file.push("config.toml"); let db_path = data_dir.join("history.db"); + let record_store_path = data_dir.join("records.db"); + let key_path = data_dir.join("key"); let session_path = data_dir.join("session"); let mut config_builder = Config::builder() .set_default("db_path", db_path.to_str())? + .set_default("record_store_path", record_store_path.to_str())? .set_default("key_path", key_path.to_str())? .set_default("session_path", session_path.to_str())? .set_default("dialect", "us")? -- cgit v1.3.1