diff options
Diffstat (limited to 'crates/turtle/src/atuin_client/database.rs')
| -rw-r--r-- | crates/turtle/src/atuin_client/database.rs | 1526 |
1 files changed, 1526 insertions, 0 deletions
diff --git a/crates/turtle/src/atuin_client/database.rs b/crates/turtle/src/atuin_client/database.rs new file mode 100644 index 00000000..75b1200c --- /dev/null +++ b/crates/turtle/src/atuin_client/database.rs @@ -0,0 +1,1526 @@ +use std::{ + env, + path::{Path, PathBuf}, + str::FromStr, + time::Duration, +}; + +use crate::atuin_client::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS}; +use crate::atuin_common::utils; +use async_trait::async_trait; +use fs_err as fs; +use itertools::Itertools; +use rand::{Rng, distributions::Alphanumeric}; +use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote}; +use sqlx::{ + Result, Row, + sqlite::{ + SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, + SqliteSynchronous, + }, +}; +use time::OffsetDateTime; +use tracing::debug; +use uuid::Uuid; + +use crate::atuin_client::{ + history::{HistoryId, HistoryStats}, + utils::get_host_user, +}; + +use super::{ + history::History, + ordering, + settings::{FilterMode, SearchMode, Settings}, +}; + +#[derive(Clone)] +pub struct Context { + pub session: String, + pub cwd: String, + pub hostname: String, + pub host_id: String, + pub git_root: Option<PathBuf>, +} + +#[derive(Default, Clone)] +pub struct OptFilters { + pub exit: Option<i64>, + pub exclude_exit: Option<i64>, + pub cwd: Option<String>, + pub exclude_cwd: Option<String>, + pub before: Option<String>, + pub after: Option<String>, + pub limit: Option<i64>, + pub offset: Option<i64>, + pub reverse: bool, + pub include_duplicates: bool, + /// Author filter. Supports special values `$all-user` and `$all-agent`. + pub authors: Vec<String>, +} + +pub async fn current_context() -> eyre::Result<Context> { + let session = env::var("ATUIN_SESSION").map_err(|_| { + eyre::eyre!("Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell.") + })?; + let hostname = get_host_user(); + let cwd = utils::get_current_dir(); + let host_id = Settings::host_id().await?; + let git_root = utils::in_git_repo(cwd.as_str()); + + Ok(Context { + session, + hostname, + cwd, + git_root, + host_id: host_id.0.as_simple().to_string(), + }) +} + +impl Context { + pub fn from_history(entry: &History) -> Self { + Context { + session: entry.session.to_string(), + cwd: entry.cwd.to_string(), + hostname: entry.hostname.to_string(), + host_id: String::new(), + git_root: utils::in_git_repo(entry.cwd.as_str()), + } + } +} + +/// Each entry is OR'd: `$all-user` → NOT IN agents, `$all-agent` → IN agents, literal → exact match. +fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) { + let mut conditions: Vec<String> = Vec::new(); + let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", "); + let author_expr = "CASE \ + WHEN author IS NULL OR trim(author) = '' THEN \ + CASE \ + WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \ + ELSE hostname \ + END \ + ELSE author \ + END"; + + for author in authors { + match author.as_str() { + AUTHOR_FILTER_ALL_USER => { + conditions.push(format!("{author_expr} NOT IN ({agent_list})")); + } + AUTHOR_FILTER_ALL_AGENT => { + conditions.push(format!("{author_expr} IN ({agent_list})")); + } + literal => { + conditions.push(format!("{author_expr} = {}", quote(literal))); + } + } + } + + if !conditions.is_empty() { + sql.and_where(format!("({})", conditions.join(" OR "))); + } +} + +fn get_session_start_time(session_id: &str) -> Option<i64> { + if let Ok(uuid) = Uuid::parse_str(session_id) + && let Some(timestamp) = uuid.get_timestamp() + { + let (seconds, nanos) = timestamp.to_unix(); + return Some(seconds as i64 * 1_000_000_000 + nanos as i64); + } + None +} + +#[async_trait] +pub trait Database: Send + Sync + 'static { + async fn save(&self, h: &History) -> Result<()>; + async fn save_bulk(&self, h: &[History]) -> Result<()>; + + async fn load(&self, id: &str) -> Result<Option<History>>; + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option<usize>, + unique: bool, + include_deleted: bool, + ) -> Result<Vec<History>>; + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>; + + async fn update(&self, h: &History) -> Result<()>; + async fn history_count(&self, include_deleted: bool) -> Result<i64>; + + async fn last(&self) -> Result<Option<History>>; + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>; + + async fn delete(&self, h: History) -> Result<()>; + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>; + async fn deleted(&self) -> Result<Vec<History>>; + + // Yes I know, it's a lot. + // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. + // Been debating maybe a DSL for search? eg "before:time limit:1 the query" + #[expect(clippy::too_many_arguments)] + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result<Vec<History>>; + + async fn query_history(&self, query: &str) -> Result<Vec<History>>; + + async fn all_with_count(&self) -> Result<Vec<(History, i32)>>; + + fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged; + + async fn stats(&self, h: &History) -> Result<HistoryStats>; + + async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>; + + fn clone_boxed(&self) -> Box<dyn Database + 'static>; +} + +// Intended for use on a developer machine and not a sync server. +// TODO: implement IntoIterator +#[derive(Debug, Clone)] +pub struct Sqlite { + pub pool: SqlitePool, +} + +impl Sqlite { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + debug!("opening sqlite database at {path:?}"); + + if utils::broken_symlink(path) { + eprintln!( + "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + ); + std::process::exit(1); + } + + if !path.exists() + && let Some(dir) = path.parent() + { + fs::create_dir_all(dir)?; + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .optimize_on_close(true, None) + .synchronous(SqliteSynchronous::Normal) + .with_regexp() + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + Ok(Self { pool }) + } + + pub async fn sqlite_version(&self) -> Result<String> { + sqlx::query_scalar("SELECT sqlite_version()") + .fetch_one(&self.pool) + .await + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { + sqlx::query( + "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, author, intent, deleted_at) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.author.as_str()) + .bind(h.intent.as_deref()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + async fn delete_row_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + id: HistoryId, + ) -> Result<()> { + sqlx::query("delete from history where id = ?1") + .bind(id.0.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_history(row: SqliteRow) -> History { + let deleted_at: Option<i64> = row.get("deleted_at"); + let hostname: String = row.get("hostname"); + let author: Option<String> = row.try_get("author").ok().flatten(); + let author = author + .filter(|author| !author.trim().is_empty()) + .unwrap_or_else(|| History::author_from_hostname(hostname.as_str())); + let intent: Option<String> = row.try_get("intent").ok().flatten(); + let intent = intent.filter(|intent| !intent.trim().is_empty()); + + History::from_db() + .id(row.get("id")) + .timestamp( + OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128) + .unwrap(), + ) + .duration(row.get("duration")) + .exit(row.get("exit")) + .command(row.get("command")) + .cwd(row.get("cwd")) + .session(row.get("session")) + .hostname(hostname) + .author(author) + .intent(intent) + .deleted_at( + deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()), + ) + .build() + .into() + } +} + +#[async_trait] +impl Database for Sqlite { + async fn save(&self, h: &History) -> Result<()> { + debug!("saving history to sqlite"); + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, h).await?; + tx.commit().await?; + + Ok(()) + } + + async fn save_bulk(&self, h: &[History]) -> Result<()> { + debug!("saving history to sqlite"); + + let mut tx = self.pool.begin().await?; + + for i in h { + Self::save_raw(&mut tx, i).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn load(&self, id: &str) -> Result<Option<History>> { + debug!("loading history item {}", id); + + let res = sqlx::query("select * from history where id = ?1") + .bind(id) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn update(&self, h: &History) -> Result<()> { + debug!("updating sqlite history"); + + sqlx::query( + "update history + set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, author = ?9, intent = ?10, deleted_at = ?11 + where id = ?1", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.author.as_str()) + .bind(h.intent.as_deref()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&self.pool) + .await?; + + Ok(()) + } + + // make a unique list, that only shows the *newest* version of things + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option<usize>, + unique: bool, + include_deleted: bool, + ) -> Result<Vec<History>> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + query.field("*").order_desc("timestamp"); + if !include_deleted { + query.and_where_is_null("deleted_at"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + let session_start = get_session_start_time(&context.session); + + for filter in filters { + match filter { + FilterMode::Global => &mut query, + FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)), + FilterMode::Session => query.and_where_eq("session", quote(&context.session)), + FilterMode::SessionPreload => { + query.and_where_eq("session", quote(&context.session)); + if let Some(session_start) = session_start { + query.or_where_lt("timestamp", session_start); + } + &mut query + } + FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => query.and_where_like_left("cwd", &git_root), + }; + } + + if unique { + query.group_by("command").having("max(timestamp)"); + } + + if let Some(max) = max { + query.limit(max); + } + + let query = query.sql().expect("bug in list query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> { + debug!("listing history from {:?} to {:?}", from, to); + + let res = sqlx::query( + "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", + ) + .bind(from.unix_timestamp_nanos() as i64) + .bind(to.unix_timestamp_nanos() as i64) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn last(&self) -> Result<Option<History>> { + let res = sqlx::query( + "select * from history where duration >= 0 order by timestamp desc limit 1", + ) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> { + let res = sqlx::query( + "select * from history where timestamp < ?1 order by timestamp desc limit ?2", + ) + .bind(timestamp.unix_timestamp_nanos() as i64) + .bind(count) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn deleted(&self) -> Result<Vec<History>> { + let res = sqlx::query("select * from history where deleted_at is not null") + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn history_count(&self, include_deleted: bool) -> Result<i64> { + let query = if include_deleted { + "select count(1) from history" + } else { + "select count(1) from history where deleted_at is null" + }; + + let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?; + Ok(res.0) + } + + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result<Vec<History>> { + let mut sql = SqlBuilder::select_from("history"); + + if !filter_options.include_duplicates { + sql.group_by("command").having("max(timestamp)"); + } + + if let Some(limit) = filter_options.limit { + sql.limit(limit); + } + + if let Some(offset) = filter_options.offset { + sql.offset(offset); + } + + if filter_options.reverse { + sql.order_asc("timestamp"); + } else { + sql.order_desc("timestamp"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + let session_start = get_session_start_time(&context.session); + + match filter { + FilterMode::Global => &mut sql, + FilterMode::Host => { + sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase())) + } + FilterMode::Session => sql.and_where_eq("session", quote(&context.session)), + FilterMode::SessionPreload => { + sql.and_where_eq("session", quote(&context.session)); + if let Some(session_start) = session_start { + sql.or_where_lt("timestamp", session_start); + } + &mut sql + } + FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => sql.and_where_like_left("cwd", git_root), + }; + + let orig_query = query; + + let mut regexes = Vec::new(); + match search_mode { + SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")), + _ => { + let mut is_or = false; + for token in QueryTokenizer::new(query) { + // TODO smart case mode could be made configurable like in fzf + let (is_glob, glob) = if token.has_uppercase() { + (true, "*") + } else { + (false, "%") + }; + let param = match token { + QueryToken::Regex(r) => { + regexes.push(String::from(r)); + continue; + } + QueryToken::Or => { + if !is_or { + is_or = true; + continue; + } else { + format!("{glob}|{glob}") + } + } + QueryToken::MatchStart(term, _) => { + format!("{term}{glob}") + } + QueryToken::MatchEnd(term, _) => { + format!("{glob}{term}") + } + QueryToken::MatchFull(term, _) => { + format!("{glob}{term}{glob}") + } + QueryToken::Match(term, _) => { + if search_mode == SearchMode::FullText { + format!("{glob}{term}{glob}") + } else { + term.split("").join(glob) + } + } + }; + + sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or); + is_or = false; + } + + &mut sql + } + }; + + for regex in regexes { + sql.and_where("command regexp ?".bind(®ex)); + } + + filter_options + .exit + .map(|exit| sql.and_where_eq("exit", exit)); + + filter_options + .exclude_exit + .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit)); + + filter_options + .cwd + .map(|cwd| sql.and_where_eq("cwd", quote(cwd))); + + filter_options + .exclude_cwd + .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd))); + + filter_options.before.map(|before| { + interim::parse_date_string( + before.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + ) + .map(|before| { + sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64)) + }) + }); + + filter_options.after.map(|after| { + interim::parse_date_string( + after.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + ) + .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) + }); + + if !filter_options.authors.is_empty() { + apply_author_filter(&mut sql, &filter_options.authors); + } + + sql.and_where_is_null("deleted_at"); + + let query = sql.sql().expect("bug in search query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(ordering::reorder_fuzzy(search_mode, orig_query, res)) + } + + async fn query_history(&self, query: &str) -> Result<Vec<History>> { + let res = sqlx::query(query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn all_with_count(&self) -> Result<Vec<(History, i32)>> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + + query + .fields(&[ + "id", + "max(timestamp) as timestamp", + "max(duration) as duration", + "exit", + "command", + "deleted_at", + "null as author", + "null as intent", + "group_concat(cwd, ':') as cwd", + "group_concat(session) as session", + "group_concat(hostname, ',') as hostname", + "count(*) as count", + ]) + .group_by("command") + .group_by("exit") + .and_where("deleted_at is null") + .order_desc("timestamp"); + + let query = query.sql().expect("bug in list query. please report"); + + let res = sqlx::query(&query) + .map(|row: SqliteRow| { + let count: i32 = row.get("count"); + (Self::query_history(row), count) + }) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { + Paged::new(Box::new(self.clone()), page_size, include_deleted, unique) + } + + // deleted_at doesn't mean the actual time that the user deleted it, + // but the time that the system marks it as deleted + async fn delete(&self, mut h: History) -> Result<()> { + let now = OffsetDateTime::now_utc(); + h.command = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(); // overwrite with random string + h.deleted_at = Some(now); // delete it + + self.update(&h).await?; // save it + + Ok(()) + } + + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for id in ids { + Self::delete_row_raw(&mut tx, id.clone()).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn stats(&self, h: &History) -> Result<HistoryStats> { + // We select the previous in the session by time + let mut prev = SqlBuilder::select_from("history"); + prev.field("*") + .and_where("timestamp < ?1") + .and_where("session = ?2") + .order_by("timestamp", true) + .limit(1); + + let mut next = SqlBuilder::select_from("history"); + next.field("*") + .and_where("timestamp > ?1") + .and_where("session = ?2") + .order_by("timestamp", false) + .limit(1); + + let mut total = SqlBuilder::select_from("history"); + total.field("count(1)").and_where("command = ?1"); + + let mut average = SqlBuilder::select_from("history"); + average.field("avg(duration)").and_where("command = ?1"); + + let mut exits = SqlBuilder::select_from("history"); + exits + .fields(&["exit", "count(1) as count"]) + .and_where("command = ?1") + .group_by("exit"); + + // rewrite the following with sqlbuilder + let mut day_of_week = SqlBuilder::select_from("history"); + day_of_week + .fields(&[ + "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week", + "count(1) as count", + ]) + .and_where("command = ?1") + .group_by("day_of_week"); + + // Intentionally format the string with 01 hardcoded. We want the average runtime for the + // _entire month_, but will later parse it as a datetime for sorting + // Sqlite has no datetime so we cannot do it there, and otherwise sorting will just be a + // string sort, which won't be correct. + let mut duration_over_time = SqlBuilder::select_from("history"); + duration_over_time + .fields(&[ + "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year", + "avg(duration) as duration", + ]) + .and_where("command = ?1") + .group_by("month_year") + .having("duration > 0"); + + let prev = prev.sql().expect("issue in stats previous query"); + let next = next.sql().expect("issue in stats next query"); + let total = total.sql().expect("issue in stats average query"); + let average = average.sql().expect("issue in stats previous query"); + let exits = exits.sql().expect("issue in stats exits query"); + let day_of_week = day_of_week.sql().expect("issue in stats day of week query"); + let duration_over_time = duration_over_time + .sql() + .expect("issue in stats duration over time query"); + + let prev = sqlx::query(&prev) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(&h.session) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + let next = sqlx::query(&next) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(&h.session) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + let total: (i64,) = sqlx::query_as(&total) + .bind(&h.command) + .fetch_one(&self.pool) + .await?; + + let average: (f64,) = sqlx::query_as(&average) + .bind(&h.command) + .fetch_one(&self.pool) + .await?; + + let exits: Vec<(i64, i64)> = sqlx::query_as(&exits) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let duration_over_time = duration_over_time + .iter() + .map(|f| (f.0.clone(), f.1.round() as i64)) + .collect(); + + Ok(HistoryStats { + next, + previous: prev, + total: total.0 as u64, + average_duration: average.0 as u64, + exits, + day_of_week, + duration_over_time, + }) + } + + async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> { + let res = sqlx::query( + "SELECT * FROM ( + SELECT *, ROW_NUMBER() + OVER (PARTITION BY command, cwd, hostname ORDER BY timestamp DESC) + AS rn + FROM history + ) sub + WHERE rn > ?1 and timestamp < ?2; + ", + ) + .bind(dupkeep) + .bind(before) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + fn clone_boxed(&self) -> Box<dyn Database + 'static> { + Box::new(self.clone()) + } +} + +pub struct Paged { + database: Box<dyn Database + 'static>, + page_size: usize, + last_id: Option<String>, + include_deleted: bool, + unique: bool, +} + +impl Paged { + pub fn new( + database: Box<dyn Database + 'static>, + page_size: usize, + include_deleted: bool, + unique: bool, + ) -> Self { + Self { + database, + page_size, + last_id: None, + include_deleted, + unique, + } + } + + pub async fn next(&mut self) -> Result<Option<Vec<History>>> { + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + + query.field("*").order_desc("id"); + + if !self.include_deleted { + query.and_where_is_null("deleted_at"); + } + + if self.unique { + // We want to deduplicate on command, but the user can search via cwd, hostname, and session. + // Without those fields, filter modes won't work right. With those fields, we get duplicates. + // This must be handled upstream. + query + .group_by("command, cwd, hostname, session") + .having("max(timestamp)"); + } + + query.limit(self.page_size); + + if let Some(last_id) = &self.last_id { + query.and_where_lt("id", quote(last_id)); + } + + let query = query.sql().expect("bug in list query. please report"); + let res = self.database.query_history(&query).await?; + + if res.is_empty() { + Ok(None) + } else { + self.last_id = Some(res.last().unwrap().id.0.clone()); + Ok(Some(res)) + } + } +} + +trait SqlBuilderExt { + fn fuzzy_condition<S: ToString, T: ToString>( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self; +} + +impl SqlBuilderExt for SqlBuilder { + /// adapted from the sql-builder *like functions + fn fuzzy_condition<S: ToString, T: ToString>( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self { + let mut cond = field.to_string(); + if inverse { + cond.push_str(" NOT"); + } + if glob { + cond.push_str(" GLOB '"); + } else { + cond.push_str(" LIKE '"); + } + cond.push_str(&esc(mask.to_string())); + cond.push('\''); + if is_or { + self.or_where(cond) + } else { + self.and_where(cond) + } + } +} + +#[cfg(test)] +mod test { + use crate::atuin_client::settings::test_local_timeout; + + use super::*; + use std::time::{Duration, Instant}; + + async fn assert_search_eq( + db: &impl Database, + mode: SearchMode, + filter_mode: FilterMode, + query: &str, + expected: usize, + ) -> Result<Vec<History>> { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), + git_root: None, + }; + + let results = db + .search( + mode, + filter_mode, + &context, + query, + OptFilters { + ..Default::default() + }, + ) + .await?; + + assert_eq!( + results.len(), + expected, + "query \"{}\", commands: {:?}", + query, + results.iter().map(|a| &a.command).collect::<Vec<&String>>() + ); + Ok(results) + } + + async fn assert_search_commands( + db: &impl Database, + mode: SearchMode, + filter_mode: FilterMode, + query: &str, + expected_commands: Vec<&str>, + ) { + let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len()) + .await + .unwrap(); + let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect(); + assert_eq!(commands, expected_commands); + } + + async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> { + let mut captured: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(cmd) + .cwd("/home/ellie") + .build() + .into(); + + captured.exit = 0; + captured.duration = 1; + captured.session = "beep boop".to_string(); + captured.hostname = "booop".to_string(); + + db.save(&captured).await + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_prefix() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_fulltext() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0) + .await + .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / ie$", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / !ie", + 0, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "meow r/ls/", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home//", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home///", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/home.*e", + 1, + ) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_fuzzy() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + new_history_item(&mut db, "ls /home/frank").await.unwrap(); + new_history_item(&mut db, "cd /home/Ellie").await.unwrap(); + new_history_item(&mut db, "/home/ellie/.bin/rustup") + .await + .unwrap(); + + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4) + .await + .unwrap(); + + // single term operators + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2) + .await + .unwrap(); + + // multiple terms + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup", + 2, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup 'ls", + 1, + ) + .await + .unwrap(); + + // case matching + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1) + .await + .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_reordered_fuzzy() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + // test ordering of results: we should choose the first, even though it happened longer ago. + + new_history_item(&mut db, "curl").await.unwrap(); + new_history_item(&mut db, "corburl").await.unwrap(); + + // if fuzzy reordering is on, it should come back in a more sensible order + assert_search_commands( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "curl", + vec!["curl", "corburl"], + ) + .await; + + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_basic() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add 5 history items + for i in 0..5 { + new_history_item(&mut db, &format!("command{}", i)) + .await + .unwrap(); + } + + // Create a paged iterator with page_size of 2 + let mut paged = db.all_paged(2, false, false); + + // First page should have 2 items + let page1 = paged.next().await.unwrap(); + assert!(page1.is_some()); + assert_eq!(page1.unwrap().len(), 2); + + // Second page should have 2 items + let page2 = paged.next().await.unwrap(); + assert!(page2.is_some()); + assert_eq!(page2.unwrap().len(), 2); + + // Third page should have 1 item + let page3 = paged.next().await.unwrap(); + assert!(page3.is_some()); + assert_eq!(page3.unwrap().len(), 1); + + // Fourth page should be None (exhausted) + let page4 = paged.next().await.unwrap(); + assert!(page4.is_none()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_empty() { + let db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Create a paged iterator on empty database + let mut paged = db.all_paged(10, false, false); + + // Should return None immediately + let page = paged.next().await.unwrap(); + assert!(page.is_none()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_unique() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add duplicate commands + new_history_item(&mut db, "duplicate").await.unwrap(); + new_history_item(&mut db, "duplicate").await.unwrap(); + new_history_item(&mut db, "unique1").await.unwrap(); + new_history_item(&mut db, "unique2").await.unwrap(); + + // Without unique flag - should get all 4 + let mut paged = db.all_paged(10, false, false); + let page = paged.next().await.unwrap().unwrap(); + assert_eq!(page.len(), 4); + + // With unique flag - should get 3 (duplicates collapsed) + let mut paged_unique = db.all_paged(10, false, true); + let page_unique = paged_unique.next().await.unwrap().unwrap(); + assert_eq!(page_unique.len(), 3); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_include_deleted() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add items + new_history_item(&mut db, "keep1").await.unwrap(); + new_history_item(&mut db, "keep2").await.unwrap(); + new_history_item(&mut db, "delete_me").await.unwrap(); + + // Delete one item + let all = db + .list( + &[], + &Context { + hostname: "".to_string(), + session: "".to_string(), + cwd: "".to_string(), + host_id: "".to_string(), + git_root: None, + }, + None, + false, + false, + ) + .await + .unwrap(); + + let to_delete = all + .iter() + .find(|h| h.command == "delete_me") + .unwrap() + .clone(); + db.delete(to_delete).await.unwrap(); + + // Without include_deleted - should get 2 + let mut paged = db.all_paged(10, false, false); + let page = paged.next().await.unwrap().unwrap(); + assert_eq!(page.len(), 2); + + // With include_deleted - should get 3 + let mut paged_deleted = db.all_paged(10, true, false); + let page_deleted = paged_deleted.next().await.unwrap().unwrap(); + assert_eq!(page_deleted.len(), 3); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_bench_dupes() { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), + git_root: None, + }; + + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + for _i in 1..10000 { + new_history_item(&mut db, "i am a duplicated command") + .await + .unwrap(); + } + let start = Instant::now(); + let _results = db + .search( + SearchMode::Fuzzy, + FilterMode::Global, + &context, + "", + OptFilters { + ..Default::default() + }, + ) + .await + .unwrap(); + let duration = start.elapsed(); + + assert!(duration < Duration::from_secs(15)); + } +} + +pub struct QueryTokenizer<'a> { + query: &'a str, + last_pos: usize, +} + +pub enum QueryToken<'a> { + Match(&'a str, bool), + MatchStart(&'a str, bool), + MatchEnd(&'a str, bool), + MatchFull(&'a str, bool), + Or, + Regex(&'a str), +} + +impl<'a> QueryToken<'a> { + pub fn has_uppercase(&self) -> bool { + match self { + Self::Match(term, _) + | Self::MatchStart(term, _) + | Self::MatchEnd(term, _) + | Self::MatchFull(term, _) => term.contains(char::is_uppercase), + _ => false, + } + } + + pub fn is_inverse(&self) -> bool { + match self { + Self::Match(_, inv) + | Self::MatchStart(_, inv) + | Self::MatchEnd(_, inv) + | Self::MatchFull(_, inv) => *inv, + _ => false, + } + } +} + +impl<'a> QueryTokenizer<'a> { + pub fn new(query: &'a str) -> Self { + Self { query, last_pos: 0 } + } +} + +impl<'a> Iterator for QueryTokenizer<'a> { + type Item = QueryToken<'a>; + fn next(&mut self) -> Option<Self::Item> { + let remaining = &self.query[self.last_pos..]; + if remaining.is_empty() { + return None; + } + + if let Some(remaining) = remaining.strip_prefix("r/") { + let (regex, next_pos) = if let Some(end) = remaining.find("/ ") { + (&remaining[..end], self.last_pos + 2 + end + 2) + } else if let Some(remaining) = remaining.strip_suffix('/') { + (remaining, self.query.len()) + } else { + (remaining, self.query.len()) + }; + self.last_pos = next_pos; + Some(QueryToken::Regex(regex)) + } else { + let (mut part, next_pos) = if let Some(sp) = remaining.find(' ') { + (&remaining[..sp], self.last_pos + sp + 1) + } else { + (remaining, self.query.len()) + }; + self.last_pos = next_pos; + + if part == "|" { + return Some(QueryToken::Or); + } + + let mut is_inverse = false; + if let Some(s) = part.strip_prefix('!') { + part = s; + is_inverse = true; + } + let token = if let Some(s) = part.strip_prefix('^') { + QueryToken::MatchStart(s, is_inverse) + } else if let Some(s) = part.strip_suffix('$') { + QueryToken::MatchEnd(s, is_inverse) + } else if let Some(s) = part.strip_prefix('\'') { + QueryToken::MatchFull(s, is_inverse) + } else { + QueryToken::Match(part, is_inverse) + }; + Some(token) + } + } +} |
