diff options
| author | Ellie Huxtable <ellie@elliehuxtable.com> | 2023-06-14 21:18:24 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-06-14 21:18:24 +0100 |
| commit | ae1709dafd22ac3c64441472e90df8799253292e (patch) | |
| tree | 88d1cb17af6af9948d44ffb7242d69be5743785d /atuin-client/src/record | |
| parent | Bump debian from bullseye-20230502-slim to bullseye-20230612-slim (#1047) (diff) | |
| download | atuin-ae1709dafd22ac3c64441472e90df8799253292e.zip | |
Key values (#1038)
* wip
* Start testing
* Store host IDs, not hostnames
Why? Hostnames can change a lot, and therefore host filtering can be
funky. Really, all we want is a unique ID per machine + do not care what
it might be.
* Mostly just write a fuckload of tests
* Add a v0 kv store I can push to
* Appending works
* Add next() and iterate, test the pointer chain
* Fix sig
* Make clippy happy and thaw the ICE
* Fix tests'
* Fix tests
* typed builder and cleaner db trait
---------
Co-authored-by: Conrad Ludgate <conrad.ludgate@truelayer.com>
Diffstat (limited to 'atuin-client/src/record')
| -rw-r--r-- | atuin-client/src/record/mod.rs | 2 | ||||
| -rw-r--r-- | atuin-client/src/record/sqlite_store.rs | 331 | ||||
| -rw-r--r-- | atuin-client/src/record/store.rs | 30 |
3 files changed, 363 insertions, 0 deletions
diff --git a/atuin-client/src/record/mod.rs b/atuin-client/src/record/mod.rs new file mode 100644 index 00000000..72c1f889 --- /dev/null +++ b/atuin-client/src/record/mod.rs @@ -0,0 +1,2 @@ +pub mod sqlite_store; +pub mod store; diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs new file mode 100644 index 00000000..f116b6e5 --- /dev/null +++ b/atuin-client/src/record/sqlite_store.rs @@ -0,0 +1,331 @@ +// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries. +// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index +// by tag/host + +use std::path::Path; +use std::str::FromStr; + +use async_trait::async_trait; +use eyre::{eyre, Result}; +use fs_err as fs; +use sqlx::{ + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, + Row, +}; + +use atuin_common::record::Record; + +use super::store::Store; + +pub struct SqliteStore { + pool: SqlitePool, +} + +impl SqliteStore { + pub async fn new(path: impl AsRef<Path>) -> Result<Self> { + let path = path.as_ref(); + + debug!("opening sqlite database at {:?}", path); + + let create = !path.exists(); + if create { + if let Some(dir) = path.parent() { + fs::create_dir_all(dir)?; + } + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new().connect_with(opts).await?; + + Self::setup_db(&pool).await?; + + Ok(Self { pool }) + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./record-migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, r: &Record) -> Result<()> { + // In sqlite, we are "limited" to i64. But that is still fine, until 2262. + sqlx::query( + "insert or ignore into records(id, host, tag, timestamp, parent, version, data) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7)", + ) + .bind(r.id.as_str()) + .bind(r.host.as_str()) + .bind(r.tag.as_str()) + .bind(r.timestamp as i64) + .bind(r.parent.as_ref()) + .bind(r.version.as_str()) + .bind(r.data.as_slice()) + .execute(tx) + .await?; + + Ok(()) + } + + fn query_row(row: SqliteRow) -> Record { + let timestamp: i64 = row.get("timestamp"); + + Record { + id: row.get("id"), + host: row.get("host"), + parent: row.get("parent"), + timestamp: timestamp as u64, + tag: row.get("tag"), + version: row.get("version"), + data: row.get("data"), + } + } +} + +#[async_trait] +impl Store for SqliteStore { + async fn push_batch(&self, records: impl Iterator<Item = &Record> + Send + Sync) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for record in records { + Self::save_raw(&mut tx, record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn get(&self, id: &str) -> Result<Record> { + let res = sqlx::query("select * from records where id = ?1") + .bind(id) + .map(Self::query_row) + .fetch_one(&self.pool) + .await?; + + Ok(res) + } + + async fn len(&self, host: &str, tag: &str) -> Result<u64> { + let res: (i64,) = + sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2") + .bind(host) + .bind(tag) + .fetch_one(&self.pool) + .await?; + + Ok(res.0 as u64) + } + + async fn next(&self, record: &Record) -> Result<Option<Record>> { + let res = sqlx::query("select * from records where parent = ?1") + .bind(record.id.clone()) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occured: {}", e)), + Ok(v) => Ok(Some(v)), + } + } + + async fn first(&self, host: &str, tag: &str) -> Result<Option<Record>> { + let res = sqlx::query( + "select * from records where host = ?1 and tag = ?2 and parent is null limit 1", + ) + .bind(host) + .bind(tag) + .map(Self::query_row) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn last(&self, host: &str, tag: &str) -> Result<Option<Record>> { + let res = sqlx::query( + "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;", + ) + .bind(tag) + .bind(host) + .map(Self::query_row) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } +} + +#[cfg(test)] +mod tests { + use atuin_common::record::Record; + + use crate::record::store::Store; + + use super::SqliteStore; + + fn test_record() -> Record { + Record::builder() + .host(atuin_common::utils::uuid_v7().simple().to_string()) + .version("v1".into()) + .tag(atuin_common::utils::uuid_v7().simple().to_string()) + .data(vec![0, 1, 2, 3]) + .build() + } + + #[tokio::test] + async fn create_db() { + let db = SqliteStore::new(":memory:").await; + + assert!( + db.is_ok(), + "db could not be created, {:?}", + db.err().unwrap() + ); + } + + #[tokio::test] + async fn push_record() { + let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + + db.push(&record).await.expect("failed to insert record"); + } + + #[tokio::test] + async fn get_record() { + let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let new_record = db + .get(record.id.as_str()) + .await + .expect("failed to fetch record"); + + assert_eq!(record, new_record, "records are not equal"); + } + + #[tokio::test] + async fn len() { + let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len(record.host.as_str(), record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_different_tags() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + // these have different tags, so the len should be the same + // we model multiple stores within one database + // new store = new tag = independent length + let first = test_record(); + let second = test_record(); + + db.push(&first).await.unwrap(); + db.push(&second).await.unwrap(); + + let first_len = db + .len(first.host.as_str(), first.tag.as_str()) + .await + .unwrap(); + let second_len = db + .len(second.host.as_str(), second.tag.as_str()) + .await + .unwrap(); + + assert_eq!(first_len, 1, "expected length of 1 after insert"); + assert_eq!(second_len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn append_a_bunch() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + let mut tail = test_record(); + db.push(&tail).await.expect("failed to push record"); + + for _ in 1..100 { + tail = tail.new_child(vec![1, 2, 3, 4]); + db.push(&tail).await.unwrap(); + } + + assert_eq!( + db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + } + + #[tokio::test] + async fn append_a_big_bunch() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + let mut records: Vec<Record> = Vec::with_capacity(10000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..10000 { + tail = tail.new_child(vec![1, 2, 3]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + assert_eq!( + db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(), + 10000, + "failed to insert 10k records" + ); + } + + #[tokio::test] + async fn test_chain() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + let mut records: Vec<Record> = Vec::with_capacity(1000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..1000 { + tail = tail.new_child(vec![1, 2, 3]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + let mut record = db + .first(tail.host.as_str(), tail.tag.as_str()) + .await + .expect("in memory sqlite should not fail") + .expect("entry exists"); + + let mut count = 1; + + while let Some(next) = db.next(&record).await.unwrap() { + assert_eq!(record.id, next.clone().parent.unwrap()); + record = next; + + count += 1; + } + + assert_eq!(count, 1000); + } +} diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs new file mode 100644 index 00000000..75d79fb5 --- /dev/null +++ b/atuin-client/src/record/store.rs @@ -0,0 +1,30 @@ +use async_trait::async_trait; +use eyre::Result; + +use atuin_common::record::Record; + +/// A record store stores records +/// In more detail - we tend to need to process this into _another_ format to actually query it. +/// As is, the record store is intended as the source of truth for arbitratry data, which could +/// be shell history, kvs, etc. +#[async_trait] +pub trait Store { + // Push a record + async fn push(&self, record: &Record) -> Result<()> { + self.push_batch(std::iter::once(record)).await + } + + // Push a batch of records, all in one transaction + async fn push_batch(&self, records: impl Iterator<Item = &Record> + Send + Sync) -> Result<()>; + + async fn get(&self, id: &str) -> Result<Record>; + async fn len(&self, host: &str, tag: &str) -> Result<u64>; + + /// Get the record that follows this record + async fn next(&self, record: &Record) -> Result<Option<Record>>; + + /// Get the first record for a given host and tag + async fn first(&self, host: &str, tag: &str) -> Result<Option<Record>>; + /// Get the last record for a given host and tag + async fn last(&self, host: &str, tag: &str) -> Result<Option<Record>>; +} |
