diff options
Diffstat (limited to 'crates/turtle/src/atuin_client/database.rs')
| -rw-r--r-- | crates/turtle/src/atuin_client/database.rs | 316 |
1 files changed, 99 insertions, 217 deletions
diff --git a/crates/turtle/src/atuin_client/database.rs b/crates/turtle/src/atuin_client/database.rs index 1bfe93a7..f8b73809 100644 --- a/crates/turtle/src/atuin_client/database.rs +++ b/crates/turtle/src/atuin_client/database.rs @@ -5,9 +5,7 @@ use std::{ time::Duration, }; -use crate::atuin_client::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS}; use crate::atuin_common::utils; -use async_trait::async_trait; use fs_err as fs; use itertools::Itertools; use rand::{Rng, distributions::Alphanumeric}; @@ -55,8 +53,6 @@ pub(crate) struct OptFilters { pub(crate) offset: Option<i64>, pub(crate) reverse: bool, pub(crate) include_duplicates: bool, - /// Author filter. Supports special values `$all-user` and `$all-agent`. - pub(crate) authors: Vec<String>, } pub(crate) async fn current_context() -> eyre::Result<Context> { @@ -80,47 +76,15 @@ pub(crate) async fn current_context() -> eyre::Result<Context> { impl Context { pub(crate) fn from_history(entry: &History) -> Self { Context { - session: entry.session.to_string(), - cwd: entry.cwd.to_string(), - hostname: entry.hostname.to_string(), + session: entry.session.clone(), + cwd: entry.cwd.clone(), + hostname: entry.hostname.clone(), host_id: String::new(), git_root: utils::in_git_repo(entry.cwd.as_str()), } } } -/// Each entry is OR'd: `$all-user` → NOT IN agents, `$all-agent` → IN agents, literal → exact match. -fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) { - let mut conditions: Vec<String> = Vec::new(); - let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", "); - let author_expr = "CASE \ - WHEN author IS NULL OR trim(author) = '' THEN \ - CASE \ - WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \ - ELSE hostname \ - END \ - ELSE author \ - END"; - - for author in authors { - match author.as_str() { - AUTHOR_FILTER_ALL_USER => { - conditions.push(format!("{author_expr} NOT IN ({agent_list})")); - } - AUTHOR_FILTER_ALL_AGENT => { - conditions.push(format!("{author_expr} IN ({agent_list})")); - } - literal => { - conditions.push(format!("{author_expr} = {}", quote(literal))); - } - } - } - - if !conditions.is_empty() { - sql.and_where(format!("({})", conditions.join(" OR "))); - } -} - fn get_session_start_time(session_id: &str) -> Option<i64> { if let Ok(uuid) = Uuid::parse_str(session_id) && let Some(timestamp) = uuid.get_timestamp() @@ -131,73 +95,22 @@ fn get_session_start_time(session_id: &str) -> Option<i64> { None } -#[async_trait] -pub(crate) trait Database: Send + Sync + 'static { - async fn save(&self, h: &History) -> Result<()>; - async fn save_bulk(&self, h: &[History]) -> Result<()>; - - async fn load(&self, id: &str) -> Result<Option<History>>; - async fn list( - &self, - filters: &[FilterMode], - context: &Context, - max: Option<usize>, - unique: bool, - include_deleted: bool, - ) -> Result<Vec<History>>; - async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>; - - async fn update(&self, h: &History) -> Result<()>; - async fn history_count(&self, include_deleted: bool) -> Result<i64>; - - async fn last(&self) -> Result<Option<History>>; - async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>; - - async fn delete(&self, h: History) -> Result<()>; - async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>; - async fn deleted(&self) -> Result<Vec<History>>; - - // Yes I know, it's a lot. - // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. - // Been debating maybe a DSL for search? eg "before:time limit:1 the query" - #[expect(clippy::too_many_arguments)] - async fn search( - &self, - search_mode: SearchMode, - filter: FilterMode, - context: &Context, - query: &str, - filter_options: OptFilters, - ) -> Result<Vec<History>>; - - async fn query_history(&self, query: &str) -> Result<Vec<History>>; - - async fn all_with_count(&self) -> Result<Vec<(History, i32)>>; - - fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged; - - async fn stats(&self, h: &History) -> Result<HistoryStats>; - - async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>; - - fn clone_boxed(&self) -> Box<dyn Database + 'static>; -} - // Intended for use on a developer machine and not a sync server. // TODO: implement IntoIterator #[derive(Debug, Clone)] -pub(crate) struct Sqlite { +pub(crate) struct ClientSqlite { pub(crate) pool: SqlitePool, } -impl Sqlite { +impl ClientSqlite { pub(crate) async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { let path = path.as_ref(); debug!("opening sqlite database at {path:?}"); if utils::broken_symlink(path) { eprintln!( - "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + "Atuin: Sqlite db path ({}) is a broken symlink. Unable to read or create replacement.", + path.display() ); std::process::exit(1); } @@ -224,12 +137,6 @@ impl Sqlite { Ok(Self { pool }) } - pub(crate) async fn sqlite_version(&self) -> Result<String> { - sqlx::query_scalar("SELECT sqlite_version()") - .fetch_one(&self.pool) - .await - } - async fn setup_db(pool: &SqlitePool) -> Result<()> { debug!("running sqlite database setup"); @@ -272,7 +179,7 @@ impl Sqlite { Ok(()) } - fn query_history(row: SqliteRow) -> History { + fn query_history_inner(row: SqliteRow) -> History { let deleted_at: Option<i64> = row.get("deleted_at"); let hostname: String = row.get("hostname"); let author: Option<String> = row.try_get("author").ok().flatten(); @@ -304,9 +211,8 @@ impl Sqlite { } } -#[async_trait] -impl Database for Sqlite { - async fn save(&self, h: &History) -> Result<()> { +impl ClientSqlite { + pub(crate) async fn save(&self, h: &History) -> Result<()> { debug!("saving history to sqlite"); let mut tx = self.pool.begin().await?; Self::save_raw(&mut tx, h).await?; @@ -315,7 +221,7 @@ impl Database for Sqlite { Ok(()) } - async fn save_bulk(&self, h: &[History]) -> Result<()> { + pub(crate) async fn save_bulk(&self, h: &[History]) -> Result<()> { debug!("saving history to sqlite"); let mut tx = self.pool.begin().await?; @@ -329,19 +235,19 @@ impl Database for Sqlite { Ok(()) } - async fn load(&self, id: &str) -> Result<Option<History>> { + pub(crate) async fn load(&self, id: &str) -> Result<Option<History>> { debug!("loading history item {}", id); let res = sqlx::query("select * from history where id = ?1") .bind(id) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_optional(&self.pool) .await?; Ok(res) } - async fn update(&self, h: &History) -> Result<()> { + pub(crate) async fn update(&self, h: &History) -> Result<()> { debug!("updating sqlite history"); sqlx::query( @@ -367,7 +273,7 @@ impl Database for Sqlite { } // make a unique list, that only shows the *newest* version of things - async fn list( + pub(crate) async fn list( &self, filters: &[FilterMode], context: &Context, @@ -419,14 +325,18 @@ impl Database for Sqlite { let query = query.sql().expect("bug in list query. please report"); let res = sqlx::query(&query) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_all(&self.pool) .await?; Ok(res) } - async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> { + pub(crate) async fn range( + &self, + from: OffsetDateTime, + to: OffsetDateTime, + ) -> Result<Vec<History>> { debug!("listing history from {:?} to {:?}", from, to); let res = sqlx::query( @@ -434,47 +344,25 @@ impl Database for Sqlite { ) .bind(from.unix_timestamp_nanos() as i64) .bind(to.unix_timestamp_nanos() as i64) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_all(&self.pool) .await?; Ok(res) } - async fn last(&self) -> Result<Option<History>> { + pub(crate) async fn last(&self) -> Result<Option<History>> { let res = sqlx::query( "select * from history where duration >= 0 order by timestamp desc limit 1", ) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_optional(&self.pool) .await?; Ok(res) } - async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> { - let res = sqlx::query( - "select * from history where timestamp < ?1 order by timestamp desc limit ?2", - ) - .bind(timestamp.unix_timestamp_nanos() as i64) - .bind(count) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn deleted(&self) -> Result<Vec<History>> { - let res = sqlx::query("select * from history where deleted_at is not null") - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn history_count(&self, include_deleted: bool) -> Result<i64> { + pub(crate) async fn history_count(&self, include_deleted: bool) -> Result<i64> { let query = if include_deleted { "select count(1) from history" } else { @@ -485,7 +373,10 @@ impl Database for Sqlite { Ok(res.0) } - async fn search( + // Yes I know, it's a lot. + // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. + // Been debating maybe a DSL for search? eg "before:time limit:1 the query" + pub(crate) async fn search( &self, search_mode: SearchMode, filter: FilterMode, @@ -541,54 +432,53 @@ impl Database for Sqlite { let orig_query = query; let mut regexes = Vec::new(); - match search_mode { - SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")), - _ => { - let mut is_or = false; - for token in QueryTokenizer::new(query) { - // TODO smart case mode could be made configurable like in fzf - let (is_glob, glob) = if token.has_uppercase() { - (true, "*") - } else { - (false, "%") - }; - let param = match token { - QueryToken::Regex(r) => { - regexes.push(String::from(r)); + if search_mode == SearchMode::Prefix { + sql.and_where_like_left("command", query.replace('*', "%")) + } else { + let mut is_or = false; + for token in QueryTokenizer::new(query) { + // TODO smart case mode could be made configurable like in fzf + let (is_glob, glob) = if token.has_uppercase() { + (true, "*") + } else { + (false, "%") + }; + let param = match token { + QueryToken::Regex(r) => { + regexes.push(String::from(r)); + continue; + } + QueryToken::Or => { + if !is_or { + is_or = true; continue; + } else { + format!("{glob}|{glob}") } - QueryToken::Or => { - if !is_or { - is_or = true; - continue; - } else { - format!("{glob}|{glob}") - } - } - QueryToken::MatchStart(term, _) => { - format!("{term}{glob}") - } - QueryToken::MatchEnd(term, _) => { - format!("{glob}{term}") - } - QueryToken::MatchFull(term, _) => { + } + QueryToken::MatchStart(term, _) => { + format!("{term}{glob}") + } + QueryToken::MatchEnd(term, _) => { + format!("{glob}{term}") + } + QueryToken::MatchFull(term, _) => { + format!("{glob}{term}{glob}") + } + QueryToken::Match(term, _) => { + if search_mode == SearchMode::FullText { format!("{glob}{term}{glob}") + } else { + term.split("").join(glob) } - QueryToken::Match(term, _) => { - if search_mode == SearchMode::FullText { - format!("{glob}{term}{glob}") - } else { - term.split("").join(glob) - } - } - }; - - sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or); - is_or = false; - } + } + }; - &mut sql + sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or); + is_or = false; } + + &mut sql }; for regex in regexes { @@ -631,32 +521,28 @@ impl Database for Sqlite { .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) }); - if !filter_options.authors.is_empty() { - apply_author_filter(&mut sql, &filter_options.authors); - } - sql.and_where_is_null("deleted_at"); let query = sql.sql().expect("bug in search query. please report"); let res = sqlx::query(&query) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_all(&self.pool) .await?; Ok(ordering::reorder_fuzzy(search_mode, orig_query, res)) } - async fn query_history(&self, query: &str) -> Result<Vec<History>> { + pub(crate) async fn query_history(&self, query: &str) -> Result<Vec<History>> { let res = sqlx::query(query) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_all(&self.pool) .await?; Ok(res) } - async fn all_with_count(&self) -> Result<Vec<(History, i32)>> { + pub(crate) async fn all_with_count(&self) -> Result<Vec<(History, i32)>> { debug!("listing history"); let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); @@ -686,7 +572,7 @@ impl Database for Sqlite { let res = sqlx::query(&query) .map(|row: SqliteRow| { let count: i32 = row.get("count"); - (Self::query_history(row), count) + (Self::query_history_inner(row), count) }) .fetch_all(&self.pool) .await?; @@ -694,13 +580,13 @@ impl Database for Sqlite { Ok(res) } - fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { - Paged::new(Box::new(self.clone()), page_size, include_deleted, unique) + pub(crate) fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { + Paged::new(self.clone(), page_size, include_deleted, unique) } // deleted_at doesn't mean the actual time that the user deleted it, // but the time that the system marks it as deleted - async fn delete(&self, mut h: History) -> Result<()> { + pub(crate) async fn delete(&self, mut h: History) -> Result<()> { let now = OffsetDateTime::now_utc(); h.command = rand::thread_rng() .sample_iter(&Alphanumeric) @@ -714,7 +600,7 @@ impl Database for Sqlite { Ok(()) } - async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { + pub(crate) async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { let mut tx = self.pool.begin().await?; for id in ids { @@ -726,7 +612,7 @@ impl Database for Sqlite { Ok(()) } - async fn stats(&self, h: &History) -> Result<HistoryStats> { + pub(crate) async fn stats(&self, h: &History) -> Result<HistoryStats> { // We select the previous in the session by time let mut prev = SqlBuilder::select_from("history"); prev.field("*") @@ -791,14 +677,14 @@ impl Database for Sqlite { let prev = sqlx::query(&prev) .bind(h.timestamp.unix_timestamp_nanos() as i64) .bind(&h.session) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_optional(&self.pool) .await?; let next = sqlx::query(&next) .bind(h.timestamp.unix_timestamp_nanos() as i64) .bind(&h.session) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_optional(&self.pool) .await?; @@ -843,7 +729,7 @@ impl Database for Sqlite { }) } - async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> { + pub(crate) async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> { let res = sqlx::query( "SELECT * FROM ( SELECT *, ROW_NUMBER() @@ -856,20 +742,16 @@ impl Database for Sqlite { ) .bind(dupkeep) .bind(before) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_all(&self.pool) .await?; Ok(res) } - - fn clone_boxed(&self) -> Box<dyn Database + 'static> { - Box::new(self.clone()) - } } pub(crate) struct Paged { - database: Box<dyn Database + 'static>, + database: ClientSqlite, page_size: usize, last_id: Option<String>, include_deleted: bool, @@ -878,7 +760,7 @@ pub(crate) struct Paged { impl Paged { pub(crate) fn new( - database: Box<dyn Database + 'static>, + database: ClientSqlite, page_size: usize, include_deleted: bool, unique: bool, @@ -976,7 +858,7 @@ mod test { use std::time::{Duration, Instant}; async fn assert_search_eq( - db: &impl Database, + db: &ClientSqlite, mode: SearchMode, filter_mode: FilterMode, query: &str, @@ -1013,7 +895,7 @@ mod test { } async fn assert_search_commands( - db: &impl Database, + db: &ClientSqlite, mode: SearchMode, filter_mode: FilterMode, query: &str, @@ -1026,7 +908,7 @@ mod test { assert_eq!(commands, expected_commands); } - async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> { + async fn new_history_item(db: &mut ClientSqlite, cmd: &str) -> Result<()> { let mut captured: History = History::capture() .timestamp(OffsetDateTime::now_utc()) .command(cmd) @@ -1044,7 +926,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_search_prefix() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); new_history_item(&mut db, "ls /home/ellie").await.unwrap(); @@ -1062,7 +944,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_search_fulltext() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); new_history_item(&mut db, "ls /home/ellie").await.unwrap(); @@ -1148,7 +1030,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_search_fuzzy() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); new_history_item(&mut db, "ls /home/ellie").await.unwrap(); @@ -1251,7 +1133,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_search_reordered_fuzzy() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); // test ordering of results: we should choose the first, even though it happened longer ago. @@ -1279,7 +1161,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_paged_basic() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); @@ -1315,7 +1197,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_paged_empty() { - let db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); @@ -1329,7 +1211,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_paged_unique() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); @@ -1352,7 +1234,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_paged_include_deleted() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); @@ -1407,7 +1289,7 @@ mod test { git_root: None, }; - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); for _i in 1..10000 { @@ -1448,7 +1330,7 @@ pub(crate) enum QueryToken<'a> { Regex(&'a str), } -impl<'a> QueryToken<'a> { +impl QueryToken<'_> { pub(crate) fn has_uppercase(&self) -> bool { match self { Self::Match(term, _) |
