aboutsummaryrefslogtreecommitdiffstats
path: root/atuin-client
diff options
context:
space:
mode:
authorEllie Huxtable <ellie@elliehuxtable.com>2023-06-14 21:18:24 +0100
committerGitHub <noreply@github.com>2023-06-14 21:18:24 +0100
commitae1709dafd22ac3c64441472e90df8799253292e (patch)
tree88d1cb17af6af9948d44ffb7242d69be5743785d /atuin-client
parentBump debian from bullseye-20230502-slim to bullseye-20230612-slim (#1047) (diff)
downloadatuin-ae1709dafd22ac3c64441472e90df8799253292e.zip
Key values (#1038)
* wip * Start testing * Store host IDs, not hostnames Why? Hostnames can change a lot, and therefore host filtering can be funky. Really, all we want is a unique ID per machine + do not care what it might be. * Mostly just write a fuckload of tests * Add a v0 kv store I can push to * Appending works * Add next() and iterate, test the pointer chain * Fix sig * Make clippy happy and thaw the ICE * Fix tests' * Fix tests * typed builder and cleaner db trait --------- Co-authored-by: Conrad Ludgate <conrad.ludgate@truelayer.com>
Diffstat (limited to 'atuin-client')
-rw-r--r--atuin-client/Cargo.toml3
-rw-r--r--atuin-client/record-migrations/20230531212437_create-records.sql15
-rw-r--r--atuin-client/src/database.rs7
-rw-r--r--atuin-client/src/kv.rs103
-rw-r--r--atuin-client/src/lib.rs2
-rw-r--r--atuin-client/src/record/mod.rs2
-rw-r--r--atuin-client/src/record/sqlite_store.rs331
-rw-r--r--atuin-client/src/record/store.rs30
-rw-r--r--atuin-client/src/settings.rs20
9 files changed, 510 insertions, 3 deletions
diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml
index 42e3cf6b..7b85bf76 100644
--- a/atuin-client/Cargo.toml
+++ b/atuin-client/Cargo.toml
@@ -18,7 +18,6 @@ sync = [
"reqwest",
"sha2",
"hex",
- "rmp-serde",
"base64",
"generic-array",
"xsalsa20poly1305",
@@ -51,13 +50,13 @@ fs-err = { workspace = true }
sql-builder = "3"
lazy_static = "1"
memchr = "2.5"
+rmp-serde = { version = "1.1.1" }
# sync
urlencoding = { version = "2.1.0", optional = true }
reqwest = { workspace = true, optional = true }
hex = { version = "0.4", optional = true }
sha2 = { version = "0.10", optional = true }
-rmp-serde = { version = "1.1.1", optional = true }
base64 = { workspace = true, optional = true }
tokio = { workspace = true }
semver = { workspace = true }
diff --git a/atuin-client/record-migrations/20230531212437_create-records.sql b/atuin-client/record-migrations/20230531212437_create-records.sql
new file mode 100644
index 00000000..46963358
--- /dev/null
+++ b/atuin-client/record-migrations/20230531212437_create-records.sql
@@ -0,0 +1,15 @@
+-- Add migration script here
+create table if not exists records (
+ id text primary key,
+ parent text unique, -- null if this is the first one
+ host text not null,
+
+ timestamp integer not null,
+ tag text not null,
+ version text not null,
+ data blob not null
+);
+
+create index host_idx on records (host);
+create index tag_idx on records (tag);
+create index host_tag_idx on records (host, tag);
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs
index 22bd5886..a2d8c533 100644
--- a/atuin-client/src/database.rs
+++ b/atuin-client/src/database.rs
@@ -17,13 +17,14 @@ use sqlx::{
use super::{
history::History,
ordering,
- settings::{FilterMode, SearchMode},
+ settings::{FilterMode, SearchMode, Settings},
};
pub struct Context {
pub session: String,
pub cwd: String,
pub hostname: String,
+ pub host_id: String,
}
#[derive(Default, Clone)]
@@ -50,11 +51,13 @@ pub fn current_context() -> Context {
env::var("ATUIN_HOST_USER").unwrap_or_else(|_| whoami::username())
);
let cwd = utils::get_current_dir();
+ let host_id = Settings::host_id().expect("failed to load host ID");
Context {
session,
hostname,
cwd,
+ host_id,
}
}
@@ -551,6 +554,7 @@ mod test {
hostname: "test:host".to_string(),
session: "beepboopiamasession".to_string(),
cwd: "/home/ellie".to_string(),
+ host_id: "test-host".to_string(),
};
let results = db
@@ -757,6 +761,7 @@ mod test {
hostname: "test:host".to_string(),
session: "beepboopiamasession".to_string(),
cwd: "/home/ellie".to_string(),
+ host_id: "test-host".to_string(),
};
let mut db = Sqlite::new("sqlite::memory:").await.unwrap();
diff --git a/atuin-client/src/kv.rs b/atuin-client/src/kv.rs
new file mode 100644
index 00000000..87149275
--- /dev/null
+++ b/atuin-client/src/kv.rs
@@ -0,0 +1,103 @@
+use eyre::Result;
+use serde::{Deserialize, Serialize};
+
+use crate::record::store::Store;
+use crate::settings::Settings;
+
+const KV_VERSION: &str = "v0";
+const KV_TAG: &str = "kv";
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct KvRecord {
+ pub key: String,
+ pub value: String,
+}
+
+impl KvRecord {
+ pub fn serialize(&self) -> Result<Vec<u8>> {
+ let buf = rmp_serde::to_vec(self)?;
+
+ Ok(buf)
+ }
+}
+
+pub struct KvStore;
+
+impl Default for KvStore {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl KvStore {
+ // will want to init the actual kv store when that is done
+ pub fn new() -> KvStore {
+ KvStore {}
+ }
+
+ pub async fn set(
+ &self,
+ store: &mut (impl Store + Send + Sync),
+ key: &str,
+ value: &str,
+ ) -> Result<()> {
+ let host_id = Settings::host_id().expect("failed to get host_id");
+
+ let record = KvRecord {
+ key: key.to_string(),
+ value: value.to_string(),
+ };
+
+ let bytes = record.serialize()?;
+
+ let parent = store
+ .last(host_id.as_str(), KV_TAG)
+ .await?
+ .map(|entry| entry.id);
+
+ let record = atuin_common::record::Record::builder()
+ .host(host_id)
+ .version(KV_VERSION.to_string())
+ .tag(KV_TAG.to_string())
+ .parent(parent)
+ .data(bytes)
+ .build();
+
+ store.push(&record).await?;
+
+ Ok(())
+ }
+
+ // TODO: setup an actual kv store, rebuild func, and do not pass the main store in here as
+ // well.
+ pub async fn get(&self, store: &impl Store, key: &str) -> Result<Option<KvRecord>> {
+ // TODO: don't load this from disk so much
+ let host_id = Settings::host_id().expect("failed to get host_id");
+
+ // Currently, this is O(n). When we have an actual KV store, it can be better
+ // Just a poc for now!
+
+ // iterate records to find the value we want
+ // start at the end, so we get the most recent version
+ let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await? else {
+ return Ok(None);
+ };
+ let kv: KvRecord = rmp_serde::from_slice(&record.data)?;
+
+ if kv.key == key {
+ return Ok(Some(kv));
+ }
+
+ while let Some(parent) = record.parent {
+ record = store.get(parent.as_str()).await?;
+ let kv: KvRecord = rmp_serde::from_slice(&record.data)?;
+
+ if kv.key == key {
+ return Ok(Some(kv));
+ }
+ }
+
+ // if we get here, then... we didn't find the record with that key :(
+ Ok(None)
+ }
+}
diff --git a/atuin-client/src/lib.rs b/atuin-client/src/lib.rs
index 497c5e74..3f12153a 100644
--- a/atuin-client/src/lib.rs
+++ b/atuin-client/src/lib.rs
@@ -13,5 +13,7 @@ pub mod sync;
pub mod database;
pub mod history;
pub mod import;
+pub mod kv;
pub mod ordering;
+pub mod record;
pub mod settings;
diff --git a/atuin-client/src/record/mod.rs b/atuin-client/src/record/mod.rs
new file mode 100644
index 00000000..72c1f889
--- /dev/null
+++ b/atuin-client/src/record/mod.rs
@@ -0,0 +1,2 @@
+pub mod sqlite_store;
+pub mod store;
diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs
new file mode 100644
index 00000000..f116b6e5
--- /dev/null
+++ b/atuin-client/src/record/sqlite_store.rs
@@ -0,0 +1,331 @@
+// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries.
+// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index
+// by tag/host
+
+use std::path::Path;
+use std::str::FromStr;
+
+use async_trait::async_trait;
+use eyre::{eyre, Result};
+use fs_err as fs;
+use sqlx::{
+ sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
+ Row,
+};
+
+use atuin_common::record::Record;
+
+use super::store::Store;
+
+pub struct SqliteStore {
+ pool: SqlitePool,
+}
+
+impl SqliteStore {
+ pub async fn new(path: impl AsRef<Path>) -> Result<Self> {
+ let path = path.as_ref();
+
+ debug!("opening sqlite database at {:?}", path);
+
+ let create = !path.exists();
+ if create {
+ if let Some(dir) = path.parent() {
+ fs::create_dir_all(dir)?;
+ }
+ }
+
+ let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
+ .journal_mode(SqliteJournalMode::Wal)
+ .create_if_missing(true);
+
+ let pool = SqlitePoolOptions::new().connect_with(opts).await?;
+
+ Self::setup_db(&pool).await?;
+
+ Ok(Self { pool })
+ }
+
+ async fn setup_db(pool: &SqlitePool) -> Result<()> {
+ debug!("running sqlite database setup");
+
+ sqlx::migrate!("./record-migrations").run(pool).await?;
+
+ Ok(())
+ }
+
+ async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, r: &Record) -> Result<()> {
+ // In sqlite, we are "limited" to i64. But that is still fine, until 2262.
+ sqlx::query(
+ "insert or ignore into records(id, host, tag, timestamp, parent, version, data)
+ values(?1, ?2, ?3, ?4, ?5, ?6, ?7)",
+ )
+ .bind(r.id.as_str())
+ .bind(r.host.as_str())
+ .bind(r.tag.as_str())
+ .bind(r.timestamp as i64)
+ .bind(r.parent.as_ref())
+ .bind(r.version.as_str())
+ .bind(r.data.as_slice())
+ .execute(tx)
+ .await?;
+
+ Ok(())
+ }
+
+ fn query_row(row: SqliteRow) -> Record {
+ let timestamp: i64 = row.get("timestamp");
+
+ Record {
+ id: row.get("id"),
+ host: row.get("host"),
+ parent: row.get("parent"),
+ timestamp: timestamp as u64,
+ tag: row.get("tag"),
+ version: row.get("version"),
+ data: row.get("data"),
+ }
+ }
+}
+
+#[async_trait]
+impl Store for SqliteStore {
+ async fn push_batch(&self, records: impl Iterator<Item = &Record> + Send + Sync) -> Result<()> {
+ let mut tx = self.pool.begin().await?;
+
+ for record in records {
+ Self::save_raw(&mut tx, record).await?;
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ async fn get(&self, id: &str) -> Result<Record> {
+ let res = sqlx::query("select * from records where id = ?1")
+ .bind(id)
+ .map(Self::query_row)
+ .fetch_one(&self.pool)
+ .await?;
+
+ Ok(res)
+ }
+
+ async fn len(&self, host: &str, tag: &str) -> Result<u64> {
+ let res: (i64,) =
+ sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
+ .bind(host)
+ .bind(tag)
+ .fetch_one(&self.pool)
+ .await?;
+
+ Ok(res.0 as u64)
+ }
+
+ async fn next(&self, record: &Record) -> Result<Option<Record>> {
+ let res = sqlx::query("select * from records where parent = ?1")
+ .bind(record.id.clone())
+ .map(Self::query_row)
+ .fetch_one(&self.pool)
+ .await;
+
+ match res {
+ Err(sqlx::Error::RowNotFound) => Ok(None),
+ Err(e) => Err(eyre!("an error occured: {}", e)),
+ Ok(v) => Ok(Some(v)),
+ }
+ }
+
+ async fn first(&self, host: &str, tag: &str) -> Result<Option<Record>> {
+ let res = sqlx::query(
+ "select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
+ )
+ .bind(host)
+ .bind(tag)
+ .map(Self::query_row)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ Ok(res)
+ }
+
+ async fn last(&self, host: &str, tag: &str) -> Result<Option<Record>> {
+ let res = sqlx::query(
+ "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
+ )
+ .bind(tag)
+ .bind(host)
+ .map(Self::query_row)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ Ok(res)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use atuin_common::record::Record;
+
+ use crate::record::store::Store;
+
+ use super::SqliteStore;
+
+ fn test_record() -> Record {
+ Record::builder()
+ .host(atuin_common::utils::uuid_v7().simple().to_string())
+ .version("v1".into())
+ .tag(atuin_common::utils::uuid_v7().simple().to_string())
+ .data(vec![0, 1, 2, 3])
+ .build()
+ }
+
+ #[tokio::test]
+ async fn create_db() {
+ let db = SqliteStore::new(":memory:").await;
+
+ assert!(
+ db.is_ok(),
+ "db could not be created, {:?}",
+ db.err().unwrap()
+ );
+ }
+
+ #[tokio::test]
+ async fn push_record() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+ let record = test_record();
+
+ db.push(&record).await.expect("failed to insert record");
+ }
+
+ #[tokio::test]
+ async fn get_record() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+ let record = test_record();
+ db.push(&record).await.unwrap();
+
+ let new_record = db
+ .get(record.id.as_str())
+ .await
+ .expect("failed to fetch record");
+
+ assert_eq!(record, new_record, "records are not equal");
+ }
+
+ #[tokio::test]
+ async fn len() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+ let record = test_record();
+ db.push(&record).await.unwrap();
+
+ let len = db
+ .len(record.host.as_str(), record.tag.as_str())
+ .await
+ .expect("failed to get store len");
+
+ assert_eq!(len, 1, "expected length of 1 after insert");
+ }
+
+ #[tokio::test]
+ async fn len_different_tags() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+
+ // these have different tags, so the len should be the same
+ // we model multiple stores within one database
+ // new store = new tag = independent length
+ let first = test_record();
+ let second = test_record();
+
+ db.push(&first).await.unwrap();
+ db.push(&second).await.unwrap();
+
+ let first_len = db
+ .len(first.host.as_str(), first.tag.as_str())
+ .await
+ .unwrap();
+ let second_len = db
+ .len(second.host.as_str(), second.tag.as_str())
+ .await
+ .unwrap();
+
+ assert_eq!(first_len, 1, "expected length of 1 after insert");
+ assert_eq!(second_len, 1, "expected length of 1 after insert");
+ }
+
+ #[tokio::test]
+ async fn append_a_bunch() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+
+ let mut tail = test_record();
+ db.push(&tail).await.expect("failed to push record");
+
+ for _ in 1..100 {
+ tail = tail.new_child(vec![1, 2, 3, 4]);
+ db.push(&tail).await.unwrap();
+ }
+
+ assert_eq!(
+ db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
+ 100,
+ "failed to insert 100 records"
+ );
+ }
+
+ #[tokio::test]
+ async fn append_a_big_bunch() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+
+ let mut records: Vec<Record> = Vec::with_capacity(10000);
+
+ let mut tail = test_record();
+ records.push(tail.clone());
+
+ for _ in 1..10000 {
+ tail = tail.new_child(vec![1, 2, 3]);
+ records.push(tail.clone());
+ }
+
+ db.push_batch(records.iter()).await.unwrap();
+
+ assert_eq!(
+ db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
+ 10000,
+ "failed to insert 10k records"
+ );
+ }
+
+ #[tokio::test]
+ async fn test_chain() {
+ let db = SqliteStore::new(":memory:").await.unwrap();
+
+ let mut records: Vec<Record> = Vec::with_capacity(1000);
+
+ let mut tail = test_record();
+ records.push(tail.clone());
+
+ for _ in 1..1000 {
+ tail = tail.new_child(vec![1, 2, 3]);
+ records.push(tail.clone());
+ }
+
+ db.push_batch(records.iter()).await.unwrap();
+
+ let mut record = db
+ .first(tail.host.as_str(), tail.tag.as_str())
+ .await
+ .expect("in memory sqlite should not fail")
+ .expect("entry exists");
+
+ let mut count = 1;
+
+ while let Some(next) = db.next(&record).await.unwrap() {
+ assert_eq!(record.id, next.clone().parent.unwrap());
+ record = next;
+
+ count += 1;
+ }
+
+ assert_eq!(count, 1000);
+ }
+}
diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs
new file mode 100644
index 00000000..75d79fb5
--- /dev/null
+++ b/atuin-client/src/record/store.rs
@@ -0,0 +1,30 @@
+use async_trait::async_trait;
+use eyre::Result;
+
+use atuin_common::record::Record;
+
+/// A record store stores records
+/// In more detail - we tend to need to process this into _another_ format to actually query it.
+/// As is, the record store is intended as the source of truth for arbitratry data, which could
+/// be shell history, kvs, etc.
+#[async_trait]
+pub trait Store {
+ // Push a record
+ async fn push(&self, record: &Record) -> Result<()> {
+ self.push_batch(std::iter::once(record)).await
+ }
+
+ // Push a batch of records, all in one transaction
+ async fn push_batch(&self, records: impl Iterator<Item = &Record> + Send + Sync) -> Result<()>;
+
+ async fn get(&self, id: &str) -> Result<Record>;
+ async fn len(&self, host: &str, tag: &str) -> Result<u64>;
+
+ /// Get the record that follows this record
+ async fn next(&self, record: &Record) -> Result<Option<Record>>;
+
+ /// Get the first record for a given host and tag
+ async fn first(&self, host: &str, tag: &str) -> Result<Option<Record>>;
+ /// Get the last record for a given host and tag
+ async fn last(&self, host: &str, tag: &str) -> Result<Option<Record>>;
+}
diff --git a/atuin-client/src/settings.rs b/atuin-client/src/settings.rs
index 524b2fd7..dd072451 100644
--- a/atuin-client/src/settings.rs
+++ b/atuin-client/src/settings.rs
@@ -17,6 +17,7 @@ pub const HISTORY_PAGE_SIZE: i64 = 100;
pub const LAST_SYNC_FILENAME: &str = "last_sync_time";
pub const LAST_VERSION_CHECK_FILENAME: &str = "last_version_check_time";
pub const LATEST_VERSION_FILENAME: &str = "latest_version";
+pub const HOST_ID_FILENAME: &str = "host_id";
#[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq)]
pub enum SearchMode {
@@ -140,6 +141,7 @@ pub struct Settings {
pub sync_address: String,
pub sync_frequency: String,
pub db_path: String,
+ pub record_store_path: String,
pub key_path: String,
pub session_path: String,
pub search_mode: SearchMode,
@@ -226,6 +228,21 @@ impl Settings {
Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME)
}
+ pub fn host_id() -> Option<String> {
+ let id = Settings::read_from_data_dir(HOST_ID_FILENAME);
+
+ if id.is_some() {
+ return id;
+ }
+
+ let uuid = atuin_common::utils::uuid_v7();
+
+ Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref())
+ .expect("Could not write host ID to data dir");
+
+ Some(uuid.as_simple().to_string())
+ }
+
pub fn should_sync(&self) -> Result<bool> {
if !self.auto_sync || !PathBuf::from(self.session_path.as_str()).exists() {
return Ok(false);
@@ -321,11 +338,14 @@ impl Settings {
config_file.push("config.toml");
let db_path = data_dir.join("history.db");
+ let record_store_path = data_dir.join("records.db");
+
let key_path = data_dir.join("key");
let session_path = data_dir.join("session");
let mut config_builder = Config::builder()
.set_default("db_path", db_path.to_str())?
+ .set_default("record_store_path", record_store_path.to_str())?
.set_default("key_path", key_path.to_str())?
.set_default("session_path", session_path.to_str())?
.set_default("dialect", "us")?