diff options
Diffstat (limited to 'crates')
| -rw-r--r-- | crates/atuin-client/src/database.rs | 39 | ||||
| -rw-r--r-- | crates/atuin-client/src/history.rs | 18 | ||||
| -rw-r--r-- | crates/atuin-client/src/settings.rs | 14 | ||||
| -rw-r--r-- | crates/atuin-daemon/proto/search.proto | 1 | ||||
| -rw-r--r-- | crates/atuin-daemon/src/client.rs | 2 | ||||
| -rw-r--r-- | crates/atuin-daemon/src/components/search.rs | 9 | ||||
| -rw-r--r-- | crates/atuin-daemon/src/search/index.rs | 288 | ||||
| -rw-r--r-- | crates/atuin/src/command/client.rs | 9 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/history.rs | 657 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/hook.rs | 401 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/search.rs | 37 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/search/engines.rs | 40 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/search/engines/daemon.rs | 31 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/search/engines/db.rs | 77 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/search/engines/skim.rs | 88 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/search/interactive.rs | 11 |
16 files changed, 1348 insertions, 374 deletions
diff --git a/crates/atuin-client/src/database.rs b/crates/atuin-client/src/database.rs index 7c63368d..75ef51c3 100644 --- a/crates/atuin-client/src/database.rs +++ b/crates/atuin-client/src/database.rs @@ -5,6 +5,7 @@ use std::{ time::Duration, }; +use crate::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS}; use async_trait::async_trait; use atuin_common::utils; use fs_err as fs; @@ -53,6 +54,8 @@ pub struct OptFilters { pub offset: Option<i64>, pub reverse: bool, pub include_duplicates: bool, + /// Author filter. Supports special values `$all-user` and `$all-agent`. + pub authors: Vec<String>, } pub async fn current_context() -> eyre::Result<Context> { @@ -85,6 +88,38 @@ impl Context { } } +/// Each entry is OR'd: `$all-user` → NOT IN agents, `$all-agent` → IN agents, literal → exact match. +fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) { + let mut conditions: Vec<String> = Vec::new(); + let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", "); + let author_expr = "CASE \ + WHEN author IS NULL OR trim(author) = '' THEN \ + CASE \ + WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \ + ELSE hostname \ + END \ + ELSE author \ + END"; + + for author in authors { + match author.as_str() { + AUTHOR_FILTER_ALL_USER => { + conditions.push(format!("{author_expr} NOT IN ({agent_list})")); + } + AUTHOR_FILTER_ALL_AGENT => { + conditions.push(format!("{author_expr} IN ({agent_list})")); + } + literal => { + conditions.push(format!("{author_expr} = {}", quote(literal))); + } + } + } + + if !conditions.is_empty() { + sql.and_where(format!("({})", conditions.join(" OR "))); + } +} + fn get_session_start_time(session_id: &str) -> Option<i64> { if let Ok(uuid) = Uuid::parse_str(session_id) && let Some(timestamp) = uuid.get_timestamp() @@ -595,6 +630,10 @@ impl Database for Sqlite { .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) }); + if !filter_options.authors.is_empty() { + apply_author_filter(&mut sql, &filter_options.authors); + } + sql.and_where_is_null("deleted_at"); let query = sql.sql().expect("bug in search query. please report"); diff --git a/crates/atuin-client/src/history.rs b/crates/atuin-client/src/history.rs index a5adc233..996208d9 100644 --- a/crates/atuin-client/src/history.rs +++ b/crates/atuin-client/src/history.rs @@ -18,6 +18,24 @@ use time::OffsetDateTime; mod builder; pub mod store; +/// Known AI agent author values. Used to expand `$all-agent` and `$all-user` filters. +pub const KNOWN_AGENTS: &[&str] = &["claude-code", "codex", "copilot"]; +pub const AUTHOR_FILTER_ALL_USER: &str = "$all-user"; +pub const AUTHOR_FILTER_ALL_AGENT: &str = "$all-agent"; + +pub fn is_known_agent(author: &str) -> bool { + KNOWN_AGENTS.contains(&author) +} + +pub fn author_matches_filters(author: &str, filters: &[String]) -> bool { + filters.is_empty() + || filters.iter().any(|filter| match filter.as_str() { + AUTHOR_FILTER_ALL_USER => !is_known_agent(author), + AUTHOR_FILTER_ALL_AGENT => is_known_agent(author), + literal => author == literal, + }) +} + pub(crate) const HISTORY_VERSION_V0: &str = "v0"; pub(crate) const HISTORY_VERSION_V1: &str = "v1"; const HISTORY_RECORD_VERSION_V0: u16 = 0; diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs index 25c3bd65..5944de59 100644 --- a/crates/atuin-client/src/settings.rs +++ b/crates/atuin-client/src/settings.rs @@ -565,6 +565,13 @@ pub struct Search { /// The overall frecency score multiplier for the search index (default: 1.0). /// Applied after combining recency and frequency scores. pub frecency_score_multiplier: f64, + + /// Filter history by author. Special values: + /// - `$all-user`: any author that is NOT a known AI agent (default) + /// - `$all-agent`: any known AI agent author + /// - literal strings like "ellie", "claude-code" + #[serde(default = "Search::default_authors")] + pub authors: Vec<String>, } #[derive(Clone, Debug, Deserialize, Serialize)] @@ -844,10 +851,17 @@ impl Default for Search { recency_score_multiplier: 1.0, frequency_score_multiplier: 1.0, frecency_score_multiplier: 1.0, + authors: Self::default_authors(), } } } +impl Search { + fn default_authors() -> Vec<String> { + vec![crate::history::AUTHOR_FILTER_ALL_USER.to_string()] + } +} + impl Default for Tmux { fn default() -> Self { Self { diff --git a/crates/atuin-daemon/proto/search.proto b/crates/atuin-daemon/proto/search.proto index 6b84acbd..5eea2b62 100644 --- a/crates/atuin-daemon/proto/search.proto +++ b/crates/atuin-daemon/proto/search.proto @@ -23,6 +23,7 @@ message SearchRequest { uint64 query_id = 2; // Incrementing ID to match responses to queries FilterMode filter_mode = 3; SearchContext context = 4; + repeated string authors = 5; // Author filter ($all-user, $all-agent, or literals) } message SearchResponse { diff --git a/crates/atuin-daemon/src/client.rs b/crates/atuin-daemon/src/client.rs index 5f4ce20f..51334ee1 100644 --- a/crates/atuin-daemon/src/client.rs +++ b/crates/atuin-daemon/src/client.rs @@ -213,12 +213,14 @@ impl SearchClient { query_id: u64, filter_mode: FilterMode, context: Option<Context>, + authors: Vec<String>, ) -> Result<tonic::Streaming<SearchResponse>> { let request = SearchRequest { query, query_id, filter_mode: RpcFilterMode::from(filter_mode).into(), context: context.map(RpcSearchContext::from), + authors, }; let request_stream = tokio_stream::once(request); let response = span!(Level::TRACE, "daemon_client_search.request") diff --git a/crates/atuin-daemon/src/components/search.rs b/crates/atuin-daemon/src/components/search.rs index 9fc87fae..a2e74aa5 100644 --- a/crates/atuin-daemon/src/components/search.rs +++ b/crates/atuin-daemon/src/components/search.rs @@ -304,6 +304,7 @@ impl SearchSvc for SearchGrpcService { .try_into() .unwrap_or(FilterMode::Global); let proto_context = search_req.context; + let authors = search_req.authors; debug!( "search request: query = {}, query_id = {}, filter_mode = {}, context = {:?}", @@ -332,7 +333,13 @@ impl SearchSvc for SearchGrpcService { .in_scope(|| async { let index = index.read().await; index - .search(&query, index_filter, &query_context, RESULTS_LIMIT) + .search( + &query, + index_filter, + &query_context, + &authors, + RESULTS_LIMIT, + ) .await }) .await; diff --git a/crates/atuin-daemon/src/search/index.rs b/crates/atuin-daemon/src/search/index.rs index 3328c5b5..90751155 100644 --- a/crates/atuin-daemon/src/search/index.rs +++ b/crates/atuin-daemon/src/search/index.rs @@ -12,7 +12,9 @@ use std::{ sync::Arc, }; -use atuin_client::history::History; +use atuin_client::history::{ + AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, History, KNOWN_AGENTS, +}; use atuin_client::settings::Search; use atuin_nucleo::{Injector, Nucleo, pattern}; use dashmap::DashMap; @@ -114,6 +116,100 @@ pub struct CommandData { hosts: HashSet<Spur>, /// All sessions where this command has been run (as 16-byte UUIDs). sessions: HashSet<[u8; 16]>, + /// All authors who have run this command (interned keys). + authors: HashSet<Spur>, + /// Per-invocation metadata for returning the newest row that matches the active filters. + invocations: Vec<InvocationData>, +} + +struct InvocationData { + id: [u8; 16], + timestamp: i64, + directory: Spur, + host: Spur, + session: [u8; 16], + author: Spur, +} + +#[derive(Default)] +struct CompiledAuthorFilter { + include_all_users: bool, + include_all_agents: bool, + literal_authors: HashSet<Spur>, +} + +impl CompiledAuthorFilter { + fn new(filters: &[String], interner: &ThreadedRodeo) -> Self { + let mut compiled = Self::default(); + + for filter in filters { + match filter.as_str() { + AUTHOR_FILTER_ALL_USER => compiled.include_all_users = true, + AUTHOR_FILTER_ALL_AGENT => compiled.include_all_agents = true, + literal => { + if let Some(author) = interner.get(literal) { + compiled.literal_authors.insert(author); + } + } + } + } + + compiled + } + + fn is_empty(&self) -> bool { + !self.include_all_users && !self.include_all_agents && self.literal_authors.is_empty() + } + + fn matches_author(&self, author: Spur, agent_authors: &HashSet<Spur>) -> bool { + self.is_empty() + || self.literal_authors.contains(&author) + || (self.include_all_users && !agent_authors.contains(&author)) + || (self.include_all_agents && agent_authors.contains(&author)) + } +} + +impl InvocationData { + fn new(history: &History, interner: &ThreadedRodeo) -> Option<Self> { + Some(Self { + id: parse_uuid_bytes(&history.id.0)?, + timestamp: history.timestamp.unix_timestamp(), + directory: interner.get_or_intern(with_trailing_slash(&history.cwd)), + host: interner.get_or_intern(&history.hostname), + session: parse_uuid_bytes(&history.session)?, + author: interner.get_or_intern(&history.author), + }) + } + + fn matches_mode(&self, mode: &IndexFilterMode, interner: &ThreadedRodeo) -> bool { + match mode { + IndexFilterMode::Global => true, + IndexFilterMode::Directory(dir) => { + interner.get(dir).is_some_and(|spur| self.directory == spur) + } + IndexFilterMode::Workspace(prefix) => { + interner.resolve(&self.directory).starts_with(prefix) + } + IndexFilterMode::Host(hostname) => { + interner.get(hostname).is_some_and(|spur| self.host == spur) + } + IndexFilterMode::Session(session) => { + parse_uuid_bytes(session).is_some_and(|bytes| self.session == bytes) + } + } + } + + fn matches_authors( + &self, + filter: &CompiledAuthorFilter, + agent_authors: &HashSet<Spur>, + ) -> bool { + filter.matches_author(self.author, agent_authors) + } + + fn id_string(&self) -> String { + format_uuid_bytes(&self.id) + } } impl CommandData { @@ -126,6 +222,7 @@ impl CommandData { let dir_key = interner.get_or_intern(with_trailing_slash(&history.cwd)); let host_key = interner.get_or_intern(&history.hostname); + let invocation = InvocationData::new(history, interner)?; let mut directories = HashSet::new(); directories.insert(dir_key); @@ -136,6 +233,9 @@ impl CommandData { let mut sessions = HashSet::new(); sessions.insert(session); + let mut authors = HashSet::new(); + authors.insert(interner.get_or_intern(&history.author)); + let mut global_frecency = FrecencyData::default(); global_frecency.record_use(timestamp); @@ -146,6 +246,8 @@ impl CommandData { directories, hosts, sessions, + authors, + invocations: vec![invocation], }) } @@ -160,6 +262,9 @@ impl CommandData { }; let timestamp = history.timestamp.unix_timestamp(); + let Some(invocation) = InvocationData::new(history, interner) else { + return false; + }; // Update global frecency self.global_frecency.record_use(timestamp); @@ -169,6 +274,8 @@ impl CommandData { self.directories.insert(dir_key); self.hosts.insert(interner.get_or_intern(&history.hostname)); self.sessions.insert(session); + self.authors.insert(interner.get_or_intern(&history.author)); + self.invocations.push(invocation); // Update most recent if this invocation is newer if timestamp > self.most_recent_timestamp { @@ -184,6 +291,27 @@ impl CommandData { format_uuid_bytes(&self.most_recent_id) } + fn most_recent_matching_id( + &self, + mode: &IndexFilterMode, + authors: &CompiledAuthorFilter, + interner: &ThreadedRodeo, + agent_authors: &HashSet<Spur>, + ) -> Option<String> { + if matches!(mode, IndexFilterMode::Global) && authors.is_empty() { + return Some(self.most_recent_id()); + } + + self.invocations + .iter() + .filter(|invocation| { + invocation.matches_mode(mode, interner) + && invocation.matches_authors(authors, agent_authors) + }) + .max_by_key(|invocation| invocation.timestamp) + .map(InvocationData::id_string) + } + /// Check if any invocation matches a directory filter (exact match). /// O(1) lookup using pre-computed index. pub fn has_invocation_in_dir(&self, dir: &str, interner: &ThreadedRodeo) -> bool { @@ -213,6 +341,16 @@ impl CommandData { pub fn has_invocation_in_session(&self, session: &str) -> bool { parse_uuid_bytes(session).is_some_and(|bytes| self.sessions.contains(&bytes)) } + + fn matches_authors( + &self, + filter: &CompiledAuthorFilter, + agent_authors: &HashSet<Spur>, + ) -> bool { + self.authors + .iter() + .any(|author| filter.matches_author(*author, agent_authors)) + } } /// Filter mode for search queries. @@ -336,6 +474,7 @@ impl SearchIndex { query: &str, filter_mode: IndexFilterMode, _context: &QueryContext, + authors: &[String], limit: u32, ) -> Vec<String> { let mut nucleo = self.nucleo.write().await; @@ -343,8 +482,21 @@ impl SearchIndex { // Get precomputed frecency map (may be None if not yet computed) let frecency_map = self.frecency_map.read().await.clone(); - // Build filter based on mode - let filter = self.build_filter(&filter_mode); + // Build filter based on mode + authors + let mode_filter = self.build_filter(&filter_mode); + let compiled_authors = CompiledAuthorFilter::new(authors, &self.interner); + let author_filter = self.build_author_filter(&compiled_authors); + let agent_authors: HashSet<Spur> = KNOWN_AGENTS + .iter() + .filter_map(|author| self.interner.get(author)) + .collect(); + let filter = + match (mode_filter, author_filter) { + (Some(mf), Some(af)) => Some(Arc::new(move |cmd: &String| mf(cmd) && af(cmd)) + as atuin_nucleo::Filter<String>), + (Some(f), None) | (None, Some(f)) => Some(f), + (None, None) => None, + }; nucleo.set_filter(filter); // Build scorer from precomputed frecency (or None if not available) @@ -375,9 +527,14 @@ impl SearchIndex { .filter_map(|item| { let cmd = item.data; // DashMap<Arc<str>, _>::get accepts &str via Borrow trait - self.commands - .get(cmd.as_str()) - .map(|data| data.most_recent_id()) + self.commands.get(cmd.as_str()).and_then(|data| { + data.most_recent_matching_id( + &filter_mode, + &compiled_authors, + &self.interner, + &agent_authors, + ) + }) }) .collect() }) @@ -452,6 +609,36 @@ impl SearchIndex { Some(Arc::new(move |cmd: &String| passing_commands.contains(cmd))) } + /// Build author filter predicate. + /// Same pre-computation approach as `build_filter`: find all commands with + /// a matching author, collect into a HashSet, wrap as a Nucleo Filter closure. + fn build_author_filter( + &self, + authors: &CompiledAuthorFilter, + ) -> Option<atuin_nucleo::Filter<String>> { + if authors.is_empty() { + return None; + } + + let agent_authors: HashSet<Spur> = KNOWN_AGENTS + .iter() + .filter_map(|a| self.interner.get(a)) + .collect(); + + let passing_commands: Arc<HashSet<String>> = { + let mut set = HashSet::new(); + for entry in self.commands.iter() { + let passes = entry.matches_authors(authors, &agent_authors); + if passes { + set.insert(entry.key().to_string()); + } + } + Arc::new(set) + }; + + Some(Arc::new(move |cmd: &String| passing_commands.contains(cmd))) + } + /// Build scorer from precomputed frecency map. /// /// Returns None if frecency map is not available (search still works, just without frecency ranking). @@ -634,6 +821,86 @@ mod tests { } #[tokio::test] + async fn search_index_returns_latest_matching_author_invocation() { + let index = SearchIndex::new(); + + let user_history: History = History::import() + .timestamp(datetime!(2024-01-01 10:00 UTC)) + .command("git status") + .cwd("/home/user/project") + .author("ellie") + .build() + .into(); + let agent_history: History = History::import() + .timestamp(datetime!(2024-01-01 11:00 UTC)) + .command("git status") + .cwd("/home/user/project") + .author("codex") + .build() + .into(); + + index.add_history(&user_history); + index.add_history(&agent_history); + + let user_results = index + .search( + "git", + IndexFilterMode::Global, + &QueryContext::default(), + &[AUTHOR_FILTER_ALL_USER.to_string()], + 10, + ) + .await; + assert_eq!( + user_results, + vec![Uuid::parse_str(&user_history.id.0).unwrap().to_string()] + ); + + let agent_results = index + .search( + "git", + IndexFilterMode::Global, + &QueryContext::default(), + &[AUTHOR_FILTER_ALL_AGENT.to_string()], + 10, + ) + .await; + assert_eq!( + agent_results, + vec![Uuid::parse_str(&agent_history.id.0).unwrap().to_string()] + ); + } + + #[tokio::test] + async fn search_index_returns_latest_matching_directory_invocation() { + let index = SearchIndex::new(); + + let dir1 = "/home/user/project"; + let dir2 = "/home/user/other"; + + let project_history = make_history("git status", dir1, datetime!(2024-01-01 10:00 UTC)); + let other_history = make_history("git status", dir2, datetime!(2024-01-01 11:00 UTC)); + + index.add_history(&project_history); + index.add_history(&other_history); + + let results = index + .search( + "git", + IndexFilterMode::Directory(with_trailing_slash(dir1)), + &QueryContext::default(), + &[], + 10, + ) + .await; + + assert_eq!( + results, + vec![Uuid::parse_str(&project_history.id.0).unwrap().to_string()] + ); + } + + #[tokio::test] async fn search_index_add_and_search() { let index = SearchIndex::new(); @@ -661,7 +928,13 @@ mod tests { // Search for "git" - should match 2 commands let results = index - .search("git", IndexFilterMode::Global, &QueryContext::default(), 10) + .search( + "git", + IndexFilterMode::Global, + &QueryContext::default(), + &[], + 10, + ) .await; assert_eq!(results.len(), 2); @@ -671,6 +944,7 @@ mod tests { "", IndexFilterMode::Directory(with_trailing_slash("/home/user/project")), &QueryContext::default(), + &[], 10, ) .await; diff --git a/crates/atuin/src/command/client.rs b/crates/atuin/src/command/client.rs index 917bee1a..e3ec01d9 100644 --- a/crates/atuin/src/command/client.rs +++ b/crates/atuin/src/command/client.rs @@ -54,6 +54,7 @@ mod default_config; mod doctor; mod dotfiles; mod history; +mod hook; mod import; mod info; mod init; @@ -76,6 +77,9 @@ pub enum Cmd { #[command(subcommand)] History(history::Cmd), + /// Manage AI-agent shell hooks + Hook(hook::Cmd), + /// Import shell history from file #[command(subcommand)] Import(import::Cmd), @@ -337,6 +341,7 @@ impl Cmd { // runs match self { Self::History(history) => return history.run(&settings).await, + Self::Hook(hook) => return hook.run(&settings).await, Self::Init(init) => return init.run(&settings).await, Self::Doctor => return doctor::run(&settings).await, Self::Config(config) => return config.run(&settings).await, @@ -387,7 +392,9 @@ impl Cmd { #[cfg(feature = "daemon")] Self::Daemon(cmd) => cmd.run(settings, sqlite_store, db).await, - Self::History(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => unreachable!(), + Self::History(_) | Self::Hook(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => { + unreachable!() + } #[cfg(feature = "ai")] Self::Ai(cli) => atuin_ai::commands::run(cli, &settings).await, diff --git a/crates/atuin/src/command/client/history.rs b/crates/atuin/src/command/client/history.rs index 67e0a5db..836556b4 100644 --- a/crates/atuin/src/command/client/history.rs +++ b/crates/atuin/src/command/client/history.rs @@ -374,6 +374,235 @@ fn parse_fmt(format: &str) -> ParsedFmt<'_> { } } +fn apply_start_metadata(history: &mut History, author: Option<&str>, intent: Option<&str>) { + if let Some(author) = author.map(str::trim).filter(|author| !author.is_empty()) { + author.clone_into(&mut history.author); + } + + if let Some(intent) = intent.map(str::trim).filter(|intent| !intent.is_empty()) { + history.intent = Some(intent.to_owned()); + } else if intent.is_some() { + history.intent = None; + } +} + +fn normalize_command_for_storage<'a>(command: &'a str, settings: &Settings) -> &'a str { + if !settings.strip_trailing_whitespace { + return command; + } + + let trimmed = command.trim_end_matches([' ', '\t']); + if trimmed.len() == command.len() { + return command; + } + + let trailing_backslashes = trimmed + .as_bytes() + .iter() + .rev() + .take_while(|&&byte| byte == b'\\') + .count(); + + if trailing_backslashes % 2 == 1 { + command + } else { + trimmed + } +} + +async fn handle_start( + db: &impl Database, + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + let command = normalize_command_for_storage(command, settings); + + let mut h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + apply_start_metadata(&mut h, author, intent); + + if !h.should_save(settings) { + return Ok(None); + } + + let id = h.id.0.clone(); + + // Silently ignore database errors to avoid breaking the shell + // This is important when disk is full or database is locked + if let Err(e) = db.save(&h).await { + debug!("failed to save history: {e}"); + } + + Ok(Some(id)) +} + +#[cfg(feature = "daemon")] +async fn handle_daemon_start( + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + let command = normalize_command_for_storage(command, settings); + + let mut h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + apply_start_metadata(&mut h, author, intent); + + if !h.should_save(settings) { + return Ok(None); + } + + // Attempt to start history via daemon, but silently ignore errors + // to avoid breaking the shell when the daemon is unavailable or disk is full + let resp = match daemon::start_history(settings, h.clone()).await { + Ok(id) => id, + Err(e) => { + debug!("failed to start history via daemon: {e}"); + h.id.0.clone() + } + }; + + Ok(Some(resp)) +} + +#[allow(unused_variables)] +async fn handle_end( + db: &impl Database, + store: SqliteStore, + history_store: HistoryStore, + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + if id.trim() == "" { + return Ok(()); + } + + let Some(mut h) = db.load(id).await? else { + warn!("history entry is missing"); + return Ok(()); + }; + + if h.duration > 0 { + debug!("cannot end history - already has duration"); + + // returning OK as this can occur if someone Ctrl-c a prompt + return Ok(()); + } + + if !settings.store_failed && exit > 0 { + debug!("history has non-zero exit code, and store_failed is false"); + + // the history has already been inserted half complete. remove it + db.delete(h).await?; + + return Ok(()); + } + + h.exit = exit; + h.duration = match duration { + Some(value) => i64::try_from(value).context("command took over 292 years")?, + None => i64::try_from((OffsetDateTime::now_utc() - h.timestamp).whole_nanoseconds()) + .context("command took over 292 years")?, + }; + + db.update(&h).await?; + history_store.push(h).await?; + + if settings.should_sync().await? { + #[cfg(feature = "sync")] + { + if settings.sync.records { + let (_, downloaded) = record::sync::sync(settings, &store).await?; + Settings::save_sync_time().await?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + } else { + debug!("running periodic background sync"); + sync::sync(settings, false, db).await?; + } + } + #[cfg(not(feature = "sync"))] + debug!("not compiled with sync support"); + } else { + debug!("sync disabled! not syncing"); + } + + Ok(()) +} + +#[cfg(feature = "daemon")] +async fn handle_daemon_end( + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + daemon::end_history(settings, id.to_string(), duration.unwrap_or(0), exit).await?; + + Ok(()) +} + +pub(super) async fn start_history_entry( + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + #[cfg(feature = "daemon")] + if settings.daemon.enabled { + return handle_daemon_start(settings, command, author, intent).await; + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let db = Sqlite::new(db_path, settings.local_timeout).await?; + handle_start(&db, settings, command, author, intent).await +} + +pub(super) async fn end_history_entry( + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + #[cfg(feature = "daemon")] + if settings.daemon.enabled { + return handle_daemon_end(settings, id, exit, duration).await; + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + handle_end(&db, store, history_store, settings, id, exit, duration).await +} + #[cfg(feature = "daemon")] #[derive(Clone, Copy, Debug, Eq, PartialEq)] enum TailKind { @@ -676,200 +905,6 @@ fn normalize_optional_field(value: &str) -> Option<String> { } impl Cmd { - fn normalize_command_for_storage<'a>(command: &'a str, settings: &Settings) -> &'a str { - if !settings.strip_trailing_whitespace { - return command; - } - - let trimmed = command.trim_end_matches([' ', '\t']); - if trimmed.len() == command.len() { - return command; - } - - let trailing_backslashes = trimmed - .as_bytes() - .iter() - .rev() - .take_while(|&&byte| byte == b'\\') - .count(); - - if trailing_backslashes % 2 == 1 { - command - } else { - trimmed - } - } - - fn apply_start_metadata(history: &mut History, author: Option<&str>, intent: Option<&str>) { - if let Some(author) = author.map(str::trim).filter(|author| !author.is_empty()) { - author.clone_into(&mut history.author); - } - - if let Some(intent) = intent.map(str::trim).filter(|intent| !intent.is_empty()) { - history.intent = Some(intent.to_owned()); - } else if intent.is_some() { - history.intent = None; - } - } - - #[allow(clippy::too_many_lines, clippy::cast_possible_truncation)] - async fn handle_start( - db: &impl Database, - settings: &Settings, - command: &str, - author: Option<&str>, - intent: Option<&str>, - ) -> Result<()> { - // It's better for atuin to silently fail here and attempt to - // store whatever is ran, than to throw an error to the terminal - let cwd = utils::get_current_dir(); - let command = Self::normalize_command_for_storage(command, settings); - - let mut h: History = History::capture() - .timestamp(OffsetDateTime::now_utc()) - .command(command) - .cwd(cwd) - .build() - .into(); - Self::apply_start_metadata(&mut h, author, intent); - - if !h.should_save(settings) { - return Ok(()); - } - - // print the ID - // we use this as the key for calling end - println!("{}", h.id); - - // Silently ignore database errors to avoid breaking the shell - // This is important when disk is full or database is locked - if let Err(e) = db.save(&h).await { - debug!("failed to save history: {e}"); - } - - Ok(()) - } - - #[cfg(feature = "daemon")] - async fn handle_daemon_start( - settings: &Settings, - command: &str, - author: Option<&str>, - intent: Option<&str>, - ) -> Result<()> { - // It's better for atuin to silently fail here and attempt to - // store whatever is ran, than to throw an error to the terminal - let cwd = utils::get_current_dir(); - let command = Self::normalize_command_for_storage(command, settings); - - let mut h: History = History::capture() - .timestamp(OffsetDateTime::now_utc()) - .command(command) - .cwd(cwd) - .build() - .into(); - Self::apply_start_metadata(&mut h, author, intent); - - if !h.should_save(settings) { - return Ok(()); - } - - // Attempt to start history via daemon, but silently ignore errors - // to avoid breaking the shell when the daemon is unavailable or disk is full - let resp = match daemon::start_history(settings, h.clone()).await { - Ok(id) => id, - Err(e) => { - debug!("failed to start history via daemon: {e}"); - h.id.0.clone() - } - }; - - // print the ID - // we use this as the key for calling end - println!("{resp}"); - - Ok(()) - } - - #[allow(unused_variables)] - async fn handle_end( - db: &impl Database, - store: SqliteStore, - history_store: HistoryStore, - settings: &Settings, - id: &str, - exit: i64, - duration: Option<u64>, - ) -> Result<()> { - if id.trim() == "" { - return Ok(()); - } - - let Some(mut h) = db.load(id).await? else { - warn!("history entry is missing"); - return Ok(()); - }; - - if h.duration > 0 { - debug!("cannot end history - already has duration"); - - // returning OK as this can occur if someone Ctrl-c a prompt - return Ok(()); - } - - if !settings.store_failed && exit > 0 { - debug!("history has non-zero exit code, and store_failed is false"); - - // the history has already been inserted half complete. remove it - db.delete(h).await?; - - return Ok(()); - } - - h.exit = exit; - h.duration = match duration { - Some(value) => i64::try_from(value).context("command took over 292 years")?, - None => i64::try_from((OffsetDateTime::now_utc() - h.timestamp).whole_nanoseconds()) - .context("command took over 292 years")?, - }; - - db.update(&h).await?; - history_store.push(h).await?; - - if settings.should_sync().await? { - #[cfg(feature = "sync")] - { - if settings.sync.records { - let (_, downloaded) = record::sync::sync(settings, &store).await?; - Settings::save_sync_time().await?; - - crate::sync::build(settings, &store, db, Some(&downloaded)).await?; - } else { - debug!("running periodic background sync"); - sync::sync(settings, false, db).await?; - } - } - #[cfg(not(feature = "sync"))] - debug!("not compiled with sync support"); - } else { - debug!("sync disabled! not syncing"); - } - - Ok(()) - } - - #[cfg(feature = "daemon")] - async fn handle_daemon_end( - settings: &Settings, - id: &str, - exit: i64, - duration: Option<u64>, - ) -> Result<()> { - daemon::end_history(settings, id.to_string(), duration.unwrap_or(0), exit).await?; - - Ok(()) - } - #[cfg(feature = "daemon")] async fn handle_tail(settings: &Settings) -> Result<()> { let tty = std::io::stdout().is_terminal(); @@ -892,6 +927,7 @@ impl Cmd { Ok(()) } + #[allow(clippy::too_many_lines, clippy::cast_possible_truncation)] #[allow(clippy::too_many_arguments)] #[allow(clippy::fn_params_excessive_bools)] async fn handle_list( @@ -1057,145 +1093,125 @@ impl Cmd { #[allow(clippy::too_many_lines)] pub async fn run(self, settings: &Settings) -> Result<()> { - let context = current_context().await?; - - #[cfg(feature = "daemon")] - // Skip initializing any databases for start/end, if the daemon is enabled - if settings.daemon.enabled { - match self { - Self::Start { .. } => { - let command = self.get_start_command().unwrap_or_default(); - let (author, intent) = self.get_start_metadata().unwrap_or_default(); - return Self::handle_daemon_start(settings, &command, author, intent).await; - } + match self { + Self::Start { + cmd_env, + author, + intent, + command, + } => { + let command = if cmd_env { + std::env::var("ATUIN_COMMAND_LINE").unwrap_or_default() + } else { + command.join(" ") + }; - Self::End { id, exit, duration } => { - return Self::handle_daemon_end(settings, &id, exit, duration).await; + if let Some(id) = + start_history_entry(settings, &command, author.as_deref(), intent.as_deref()) + .await? + { + println!("{id}"); } - Self::Tail => { + Ok(()) + } + Self::End { id, exit, duration } => { + end_history_entry(settings, &id, exit, duration).await + } + Self::Tail => { + #[cfg(feature = "daemon")] + { return Self::handle_tail(settings).await; } - _ => {} + #[cfg(not(feature = "daemon"))] + bail!("`atuin history tail` requires Atuin to be built with the `daemon` feature"); } - } - - if matches!(self, Self::Tail) { - #[cfg(feature = "daemon")] - bail!("`atuin history tail` requires `daemon.enabled = true`"); - - #[cfg(not(feature = "daemon"))] - bail!("`atuin history tail` requires Atuin to be built with the `daemon` feature"); - } + cmd => { + let context = current_context().await?; - let db_path = PathBuf::from(settings.db_path.as_str()); - let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); - let db = Sqlite::new(db_path, settings.local_timeout).await?; - let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; - let encryption_key: [u8; 32] = encryption::load_key(settings) - .context("could not load encryption key")? - .into(); + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); - let host_id = Settings::host_id().await?; - let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - match self { - Self::Start { .. } => { - let command = self.get_start_command().unwrap_or_default(); - let (author, intent) = self.get_start_metadata().unwrap_or_default(); - Self::handle_start(&db, settings, &command, author, intent).await - } - Self::End { id, exit, duration } => { - Self::handle_end(&db, store, history_store, settings, &id, exit, duration).await - } - Self::Tail => unreachable!("tail handled before database initialization"), - Self::List { - session, - cwd, - human, - cmd_only, - print0, - reverse, - timezone, - format, - } => { - let mode = ListMode::from_flags(human, cmd_only); - let tz = timezone.unwrap_or(settings.timezone); - Self::handle_list( - &db, settings, context, session, cwd, mode, format, false, print0, reverse, tz, - ) - .await - } + match cmd { + Self::List { + session, + cwd, + human, + cmd_only, + print0, + reverse, + timezone, + format, + } => { + let mode = ListMode::from_flags(human, cmd_only); + let tz = timezone.unwrap_or(settings.timezone); + Self::handle_list( + &db, settings, context, session, cwd, mode, format, false, print0, + reverse, tz, + ) + .await + } - Self::Last { - human, - cmd_only, - timezone, - format, - } => { - let last = db.last().await?; - let last = last.as_slice(); - let tz = timezone.unwrap_or(settings.timezone); - print_list( - last, - ListMode::from_flags(human, cmd_only), - match format { - None => Some(settings.history_format.as_str()), - _ => format.as_deref(), - }, - false, - true, - tz, - ); + Self::Last { + human, + cmd_only, + timezone, + format, + } => { + let last = db.last().await?; + let last = last.as_slice(); + let tz = timezone.unwrap_or(settings.timezone); + print_list( + last, + ListMode::from_flags(human, cmd_only), + match format { + None => Some(settings.history_format.as_str()), + _ => format.as_deref(), + }, + false, + true, + tz, + ); - Ok(()) - } + Ok(()) + } - Self::InitStore => history_store.init_store(&db).await, + Self::InitStore => history_store.init_store(&db).await, - Self::Prune { dry_run } => { - Self::handle_prune(&db, settings, store, context, dry_run).await - } + Self::Prune { dry_run } => { + Self::handle_prune(&db, settings, store, context, dry_run).await + } - Self::Dedup { - dry_run, - before, - dupkeep, - } => { - let before = i64::try_from( - interim::parse_date_string( - before.as_str(), - OffsetDateTime::now_utc(), - interim::Dialect::Uk, - )? - .unix_timestamp_nanos(), - )?; - Self::handle_dedup(&db, settings, store, before, dupkeep, dry_run).await - } - } - } + Self::Dedup { + dry_run, + before, + dupkeep, + } => { + let before = i64::try_from( + interim::parse_date_string( + before.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + )? + .unix_timestamp_nanos(), + )?; + Self::handle_dedup(&db, settings, store, before, dupkeep, dry_run).await + } - /// Returns the command line to use for the `Start` variant. - /// Returns `None` for any other variant. - fn get_start_command(&self) -> Option<String> { - match self { - Self::Start { cmd_env: true, .. } => { - Some(std::env::var("ATUIN_COMMAND_LINE").unwrap_or_default()) + Self::Start { .. } | Self::End { .. } | Self::Tail => unreachable!(), + } } - Self::Start { command, .. } => Some(command.join(" ")), - _ => None, - } - } - - /// Returns `(author, intent)` for the `Start` variant. - /// Returns `None` for any other variant. - fn get_start_metadata(&self) -> Option<(Option<&str>, Option<&str>)> { - match self { - Self::Start { author, intent, .. } => Some((author.as_deref(), intent.as_deref())), - _ => None, } } } @@ -1212,10 +1228,7 @@ mod tests { let settings = Settings::utc(); assert!(settings.strip_trailing_whitespace); - assert_eq!( - Cmd::normalize_command_for_storage("ls \t", &settings), - "ls" - ); + assert_eq!(normalize_command_for_storage("ls \t", &settings), "ls"); } #[test] @@ -1223,11 +1236,11 @@ mod tests { let settings = Settings::utc(); assert_eq!( - Cmd::normalize_command_for_storage("printf foo\\ ", &settings), + normalize_command_for_storage("printf foo\\ ", &settings), "printf foo\\ " ); assert_eq!( - Cmd::normalize_command_for_storage("printf foo\\\\ ", &settings), + normalize_command_for_storage("printf foo\\\\ ", &settings), "printf foo\\\\" ); } @@ -1237,7 +1250,7 @@ mod tests { let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); let settings = Settings::utc(); - Cmd::handle_start(&db, &settings, "ls \t", None, None) + handle_start(&db, &settings, "ls \t", None, None) .await .unwrap(); @@ -1258,7 +1271,7 @@ mod tests { ..Settings::utc() }; - Cmd::handle_start(&db, &settings, "ls \t", None, None) + handle_start(&db, &settings, "ls \t", None, None) .await .unwrap(); diff --git a/crates/atuin/src/command/client/hook.rs b/crates/atuin/src/command/client/hook.rs new file mode 100644 index 00000000..bb333c5f --- /dev/null +++ b/crates/atuin/src/command/client/hook.rs @@ -0,0 +1,401 @@ +use std::io::Read; +use std::path::PathBuf; + +use atuin_client::settings::Settings; +use atuin_common::utils::home_dir; +use clap::{Parser, Subcommand}; +use eyre::{Result, bail}; +use serde_json::Value; + +use super::history; + +const HOOK_EVENT_TYPES: &[&str] = &["PreToolUse", "PostToolUse", "PostToolUseFailure"]; + +struct AgentSpec { + aliases: &'static [&'static str], + actor_name: &'static str, + config_path: &'static [&'static str], + hook_command: &'static str, + matcher: &'static str, +} + +const CLAUDE_CODE: AgentSpec = AgentSpec { + aliases: &["claude-code", "claude"], + actor_name: "claude-code", + config_path: &[".claude", "settings.json"], + hook_command: "atuin hook claude-code", + matcher: "Bash", +}; + +const CODEX: AgentSpec = AgentSpec { + aliases: &["codex"], + actor_name: "codex", + config_path: &[".codex", "hooks.json"], + hook_command: "atuin hook codex", + matcher: "^Bash$", +}; + +const AGENTS: &[&AgentSpec] = &[&CLAUDE_CODE, &CODEX]; + +struct Agent(&'static AgentSpec); + +impl Agent { + fn from_name(name: &str) -> Result<Self> { + AGENTS + .iter() + .copied() + .find(|spec| spec.aliases.contains(&name)) + .map(Self) + .ok_or_else(|| { + eyre::eyre!("unknown agent: {name}. Supported agents: claude-code, codex") + }) + } + + fn actor_name(&self) -> &'static str { + self.0.actor_name + } + + fn config_path(&self) -> PathBuf { + self.0 + .config_path + .iter() + .fold(home_dir(), |path, segment| path.join(segment)) + } + + fn hook_command(&self) -> &'static str { + self.0.hook_command + } + + fn matcher(&self) -> &'static str { + self.0.matcher + } +} + +#[derive(Subcommand, Debug)] +enum Action { + /// Install hooks for an AI agent to capture commands in atuin history + Install { + /// Agent to install hooks for (e.g., "claude-code") + #[arg(value_name = "AGENT")] + agent: String, + }, +} + +#[derive(Parser, Debug)] +#[command(infer_subcommands = true, args_conflicts_with_subcommands = true)] +pub struct Cmd { + #[command(subcommand)] + action: Option<Action>, + + /// Which agent's hook format to parse (e.g., "claude-code") + #[arg(value_name = "AGENT", hide = true)] + agent: Option<String>, +} + +impl Cmd { + pub async fn run(self, settings: &Settings) -> Result<()> { + match (self.action, self.agent) { + (Some(Action::Install { agent }), None) => install(&agent), + (None, Some(agent)) => handle(&agent, settings).await, + (None, None) => bail!("expected `atuin hook <agent>` or `atuin hook install <agent>`"), + (Some(_), Some(_)) => bail!("hook action cannot be combined with a positional agent"), + } + } +} + +#[derive(Debug)] +enum HookEvent { + Start { + command: String, + intent: Option<String>, + tool_use_id: String, + }, + End { + tool_use_id: String, + exit: i64, + }, + Skip, +} + +fn parse_hook_stdin(input: &str) -> Result<HookEvent> { + let v: Value = serde_json::from_str(input)?; + + if v.get("tool_name").and_then(|t| t.as_str()) != Some("Bash") { + return Ok(HookEvent::Skip); + } + + let tool_use_id = match v.get("tool_use_id").and_then(|t| t.as_str()) { + Some(id) if !id.is_empty() => id.to_string(), + _ => return Ok(HookEvent::Skip), + }; + + match v.get("hook_event_name").and_then(|e| e.as_str()) { + Some("PreToolUse") => { + let tool_input = v.get("tool_input"); + let command = tool_input + .and_then(|ti| ti.get("command")) + .and_then(|c| c.as_str()) + .unwrap_or(""); + + if command.is_empty() { + return Ok(HookEvent::Skip); + } + + let intent = tool_input + .and_then(|ti| ti.get("description")) + .and_then(|d| d.as_str()) + .map(String::from); + + Ok(HookEvent::Start { + command: command.to_string(), + intent, + tool_use_id, + }) + } + Some(event @ ("PostToolUse" | "PostToolUseFailure")) => { + let exit = if event == "PostToolUseFailure" { + 1 + } else { + v.get("tool_response") + .and_then(|tr| tr.get("exitCode")) + .and_then(Value::as_i64) + .unwrap_or(0) + }; + + Ok(HookEvent::End { tool_use_id, exit }) + } + _ => Ok(HookEvent::Skip), + } +} + +fn id_file_path(tool_use_id: &str) -> PathBuf { + std::env::temp_dir().join(format!("atuin-hook-{tool_use_id}")) +} + +async fn handle(agent_name: &str, settings: &Settings) -> Result<()> { + let agent = Agent::from_name(agent_name)?; + + let mut input = String::new(); + std::io::stdin().read_to_string(&mut input)?; + + if input.trim().is_empty() { + return Ok(()); + } + + match parse_hook_stdin(&input)? { + HookEvent::Start { + command, + intent, + tool_use_id, + } => { + if let Some(history_id) = history::start_history_entry( + settings, + &command, + Some(agent.actor_name()), + intent.as_deref(), + ) + .await? + { + std::fs::write(id_file_path(&tool_use_id), &history_id)?; + } + } + HookEvent::End { tool_use_id, exit } => { + let id_path = id_file_path(&tool_use_id); + + if let Ok(history_id) = std::fs::read_to_string(&id_path) { + let history_id = history_id.trim(); + if !history_id.is_empty() { + let _ = history::end_history_entry(settings, history_id, exit, None).await; + } + let _ = std::fs::remove_file(&id_path); + } + } + HookEvent::Skip => {} + } + + Ok(()) +} + +fn install(agent_name: &str) -> Result<()> { + let agent = Agent::from_name(agent_name)?; + let config_path = agent.config_path(); + + if let Some(parent) = config_path.parent() { + std::fs::create_dir_all(parent)?; + } + + let mut root: Value = if config_path.exists() { + let content = std::fs::read_to_string(&config_path)?; + serde_json::from_str(&content)? + } else { + Value::Object(serde_json::Map::new()) + }; + + let hooks = root + .as_object_mut() + .ok_or_else(|| eyre::eyre!("config is not a JSON object"))? + .entry("hooks") + .or_insert_with(|| Value::Object(serde_json::Map::new())); + + add_hook_entries(hooks, &agent)?; + + let content = serde_json::to_string_pretty(&root)?; + std::fs::write(&config_path, content)?; + + eprintln!( + "\nAtuin hooks installed for {}. Config: {}", + agent.actor_name(), + config_path.display() + ); + + Ok(()) +} + +fn add_hook_entries(hooks: &mut Value, agent: &Agent) -> Result<()> { + let hook_command = agent.hook_command(); + let matcher = agent.matcher(); + + for event_type in HOOK_EVENT_TYPES { + let event_hooks = hooks + .as_object_mut() + .ok_or_else(|| eyre::eyre!("hooks is not a JSON object"))? + .entry(*event_type) + .or_insert_with(|| Value::Array(Vec::new())); + + let arr = event_hooks + .as_array_mut() + .ok_or_else(|| eyre::eyre!("hooks.{event_type} is not an array"))?; + + let already_installed = arr.iter().any(|entry| { + entry["hooks"].as_array().is_some_and(|h| { + h.iter() + .any(|hook| hook["command"].as_str() == Some(hook_command)) + }) + }); + + if already_installed { + eprintln!("hooks.{event_type}: already installed, skipping"); + continue; + } + + arr.push(serde_json::json!({ + "matcher": matcher, + "hooks": [{"type": "command", "command": hook_command}] + })); + eprintln!("hooks.{event_type}: installed atuin hook"); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + Atuin, + command::{AtuinCmd, client}, + }; + use clap::Parser; + + #[test] + fn parse_hook_agent_command() { + let cmd = Cmd::try_parse_from(["hook", "codex"]).unwrap(); + + assert!(matches!( + (cmd.action, cmd.agent.as_deref()), + (None, Some("codex")) + )); + } + + #[test] + fn parse_hook_install_command() { + let cmd = Cmd::try_parse_from(["hook", "install", "codex"]).unwrap(); + + match (cmd.action, cmd.agent) { + (Some(Action::Install { agent }), None) => assert_eq!(agent, "codex"), + other => panic!("unexpected parsed command: {other:?}"), + } + } + + #[test] + fn parse_top_level_hook_command() { + let cmd = Atuin::try_parse_from(["atuin", "hook", "codex"]).unwrap(); + + assert!(matches!( + cmd.atuin, + AtuinCmd::Client(client::Cmd::Hook(Cmd { action: None, agent: Some(agent) })) + if agent == "codex" + )); + } + + #[test] + fn test_parse_pre_tool_use() { + let input = r#"{ + "hook_event_name": "PreToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo hello", "description": "Test greeting"}, + "tool_use_id": "toolu_abc123", + "session_id": "sess1", + "cwd": "/tmp" + }"#; + + match parse_hook_stdin(input).unwrap() { + HookEvent::Start { + command, + intent, + tool_use_id, + } => { + assert_eq!(command, "echo hello"); + assert_eq!(intent.as_deref(), Some("Test greeting")); + assert_eq!(tool_use_id, "toolu_abc123"); + } + _ => panic!("expected Start event"), + } + } + + #[test] + fn test_parse_post_tool_use() { + let input = r#"{ + "hook_event_name": "PostToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo hello"}, + "tool_response": {"exitCode": 0}, + "tool_use_id": "toolu_abc123" + }"#; + + match parse_hook_stdin(input).unwrap() { + HookEvent::End { tool_use_id, exit } => { + assert_eq!(tool_use_id, "toolu_abc123"); + assert_eq!(exit, 0); + } + _ => panic!("expected End event"), + } + } + + #[test] + fn test_parse_non_bash_tool_skipped() { + let input = r#"{ + "hook_event_name": "PreToolUse", + "tool_name": "Write", + "tool_input": {"file_path": "/tmp/test.txt", "content": "hello"}, + "tool_use_id": "toolu_abc123" + }"#; + + assert!(matches!(parse_hook_stdin(input).unwrap(), HookEvent::Skip)); + } + + #[test] + fn test_parse_failure_event() { + let input = r#"{ + "hook_event_name": "PostToolUseFailure", + "tool_name": "Bash", + "tool_input": {"command": "false"}, + "tool_use_id": "toolu_abc123" + }"#; + + match parse_hook_stdin(input).unwrap() { + HookEvent::End { exit, .. } => assert_eq!(exit, 1), + _ => panic!("expected End event"), + } + } +} diff --git a/crates/atuin/src/command/client/search.rs b/crates/atuin/src/command/client/search.rs index 3d348473..19045867 100644 --- a/crates/atuin/src/command/client/search.rs +++ b/crates/atuin/src/command/client/search.rs @@ -131,6 +131,11 @@ pub struct Cmd { #[arg(long = "inline-height")] inline_height: Option<u16>, + /// Filter by author. Supports $all-user (non-agents), $all-agent, or literal names. + /// Can be specified multiple times. + #[arg(long)] + author: Option<Vec<String>>, + /// Include duplicate commands in the output (non-interactive only) #[arg(long)] include_duplicates: bool, @@ -141,6 +146,12 @@ pub struct Cmd { } impl Cmd { + fn resolved_authors(&self, settings: &Settings) -> Vec<String> { + self.author + .clone() + .unwrap_or_else(|| settings.search.authors.clone()) + } + /// Returns true if this search command will run in interactive (TUI) mode pub fn is_interactive(&self) -> bool { self.interactive @@ -157,6 +168,7 @@ impl Cmd { store: SqliteStore, theme: &Theme, ) -> Result<()> { + let authors = self.resolved_authors(settings); let query = self.query.unwrap_or_else(|| { std::env::var("ATUIN_QUERY").map_or_else( |_| vec![], @@ -221,7 +233,8 @@ impl Cmd { let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); if self.interactive { - let item = interactive::history(&query, settings, db, &history_store, theme).await?; + let item = + interactive::history(&query, authors, settings, db, &history_store, theme).await?; if let Some(result_file) = self.result_file { let mut file = File::create(result_file)?; @@ -248,6 +261,7 @@ impl Cmd { offset: self.offset, reverse: self.reverse, include_duplicates: self.include_duplicates, + authors, }; let mut entries = @@ -337,6 +351,7 @@ async fn run_non_interactive( #[cfg(test)] mod tests { use super::Cmd; + use atuin_client::settings::Settings; use clap::Parser; #[test] @@ -356,4 +371,24 @@ mod tests { let cmd = cmd.unwrap(); assert_eq!(cmd.query, Some(vec!["--foo".to_string()])); } + + #[test] + fn search_authors_default_to_settings() { + let cmd = Cmd::try_parse_from(["search"]).unwrap(); + let settings = Settings::default(); + + assert_eq!(cmd.resolved_authors(&settings), settings.search.authors); + } + + #[test] + fn search_authors_cli_override_config() { + let cmd = + Cmd::try_parse_from(["search", "--author", "codex", "--author", "ellie"]).unwrap(); + let settings = Settings::default(); + + assert_eq!( + cmd.resolved_authors(&settings), + vec!["codex".to_string(), "ellie".to_string()] + ); + } } diff --git a/crates/atuin/src/command/client/search/engines.rs b/crates/atuin/src/command/client/search/engines.rs index 8cbee0c3..98692828 100644 --- a/crates/atuin/src/command/client/search/engines.rs +++ b/crates/atuin/src/command/client/search/engines.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use atuin_client::{ - database::{Context, Database}, + database::{Context, Database, OptFilters}, history::{History, HistoryId}, settings::{FilterMode, SearchMode, Settings}, }; @@ -33,6 +33,30 @@ pub struct SearchState { pub filter_mode: FilterMode, pub context: Context, pub custom_context: Option<HistoryId>, + pub authors: Vec<String>, +} + +async fn search_db( + state: &SearchState, + db: &dyn Database, + mode: SearchMode, + query: &str, +) -> Result<Vec<History>> { + Ok(db + .search( + mode, + state.filter_mode, + &state.context, + query, + OptFilters { + limit: Some(200), + authors: state.authors.clone(), + ..Default::default() + }, + ) + .await? + .into_iter() + .collect()) } impl SearchState { @@ -70,13 +94,17 @@ pub trait SearchEngine: Send + Sync + 'static { db: &mut dyn Database, ) -> Result<Vec<History>>; + async fn empty_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>> { + search_db(state, db, SearchMode::FullText, "").await + } + async fn query(&mut self, state: &SearchState, db: &mut dyn Database) -> Result<Vec<History>> { if state.input.as_str().is_empty() { - Ok(db - .list(&[state.filter_mode], &state.context, Some(200), true, false) - .await? - .into_iter() - .collect::<Vec<_>>()) + self.empty_query(state, db).await } else { self.full_query(state, db).await } diff --git a/crates/atuin/src/command/client/search/engines/daemon.rs b/crates/atuin/src/command/client/search/engines/daemon.rs index 50471898..4fb3e2ea 100644 --- a/crates/atuin/src/command/client/search/engines/daemon.rs +++ b/crates/atuin/src/command/client/search/engines/daemon.rs @@ -1,6 +1,8 @@ +use std::collections::HashMap; + use async_trait::async_trait; use atuin_client::{ - database::{Database, OptFilters}, + database::Database, history::History, settings::{SearchMode, Settings}, }; @@ -13,7 +15,7 @@ use eyre::Result; use tracing::{Level, debug, instrument, span}; use uuid::Uuid; -use super::{SearchEngine, SearchState}; +use super::{SearchEngine, SearchState, search_db}; use crate::command::client::daemon; pub struct Search { @@ -84,20 +86,9 @@ impl Search { state: &SearchState, db: &dyn Database, ) -> Result<Vec<History>> { - let results = db - .search( - SearchMode::FullText, - state.filter_mode, - &state.context, - state.input.as_str(), - OptFilters { - limit: Some(200), - ..Default::default() - }, - ) + search_db(state, db, SearchMode::FullText, state.input.as_str()) .await - .map_or(Vec::new(), |r| r.into_iter().collect()); - Ok(results) + .map_or_else(|_| Ok(Vec::new()), Ok) } #[instrument(skip_all, level = Level::TRACE, name = "hydrate_from_db", fields(count = ids.len()))] @@ -142,6 +133,7 @@ impl SearchEngine for Search { query_id, state.filter_mode, Some(state.context.clone()), + state.authors.clone(), ) .await } @@ -162,6 +154,7 @@ impl SearchEngine for Search { query_id, state.filter_mode, Some(state.context.clone()), + state.authors.clone(), ) .await? } @@ -206,12 +199,14 @@ impl SearchEngine for Search { // // Hydrate from local database let results = self.hydrate_from_db(db, &ids).await?; - // // Reorder results to match the order from the daemon (which is ranked by relevance) + // Reorder results to match the order from the daemon (which is ranked by relevance) let ordered_results = span!(Level::TRACE, "reorder_results").in_scope(|| { + let results_by_id: HashMap<&str, &History> = + results.iter().map(|h| (h.id.0.as_str(), h)).collect(); let mut ordered_results = Vec::with_capacity(results.len()); for id in &ids { - if let Some(history) = results.iter().find(|h| h.id.0 == *id) { - ordered_results.push(history.clone()); + if let Some(history) = results_by_id.get(id.as_str()) { + ordered_results.push((*history).clone()); } } ordered_results diff --git a/crates/atuin/src/command/client/search/engines/db.rs b/crates/atuin/src/command/client/search/engines/db.rs index 476462f5..29daaaa1 100644 --- a/crates/atuin/src/command/client/search/engines/db.rs +++ b/crates/atuin/src/command/client/search/engines/db.rs @@ -1,8 +1,7 @@ -use super::{SearchEngine, SearchState}; +use super::{SearchEngine, SearchState, search_db}; use async_trait::async_trait; use atuin_client::{ database::Database, - database::OptFilters, database::{QueryToken, QueryTokenizer}, history::History, settings::SearchMode, @@ -23,21 +22,10 @@ impl SearchEngine for Search { state: &SearchState, db: &mut dyn Database, ) -> Result<Vec<History>> { - let results = db - .search( - self.0, - state.filter_mode, - &state.context, - state.input.as_str(), - OptFilters { - limit: Some(200), - ..Default::default() - }, - ) + search_db(state, db, self.0, state.input.as_str()) .await // ignore errors as it may be caused by incomplete regex - .map_or(Vec::new(), |r| r.into_iter().collect()); - Ok(results) + .map_or_else(|_| Ok(Vec::new()), Ok) } #[instrument(skip_all, level = Level::TRACE, name = "db_highlight")] @@ -107,3 +95,62 @@ pub fn get_highlight_indices_fulltext(command: &str, search_input: &str) -> Vec< ret.dedup(); ret } + +#[cfg(test)] +mod tests { + use super::*; + use crate::command::client::search::cursor::Cursor; + use atuin_client::{ + database::{Context, Database, Sqlite}, + history::{AUTHOR_FILTER_ALL_USER, History}, + settings::FilterMode, + }; + use time::macros::datetime; + + fn context() -> Context { + Context { + session: uuid::Uuid::now_v7().as_simple().to_string(), + cwd: "/tmp".to_string(), + hostname: "host:user".to_string(), + host_id: String::new(), + git_root: None, + } + } + + #[tokio::test] + async fn empty_query_uses_author_filters() { + let mut db = Sqlite::new(":memory:", 0.1).await.unwrap(); + + let user_history: History = History::import() + .timestamp(datetime!(2024-01-01 10:00 UTC)) + .command("git status") + .cwd("/tmp") + .author("ellie") + .build() + .into(); + let agent_history: History = History::import() + .timestamp(datetime!(2024-01-01 11:00 UTC)) + .command("git diff") + .cwd("/tmp") + .author("codex") + .build() + .into(); + + db.save_bulk(&[user_history.clone(), agent_history]) + .await + .unwrap(); + + let mut engine = Search(SearchMode::Fuzzy); + let state = SearchState { + input: Cursor::from(String::new()), + filter_mode: FilterMode::Global, + context: context(), + custom_context: None, + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + }; + + let results = engine.query(&state, &mut db).await.unwrap(); + + assert_eq!(results, vec![user_history]); + } +} diff --git a/crates/atuin/src/command/client/search/engines/skim.rs b/crates/atuin/src/command/client/search/engines/skim.rs index 7d9feb40..4075f148 100644 --- a/crates/atuin/src/command/client/search/engines/skim.rs +++ b/crates/atuin/src/command/client/search/engines/skim.rs @@ -1,7 +1,11 @@ -use std::path::Path; +use std::{collections::HashMap, path::Path}; use async_trait::async_trait; -use atuin_client::{database::Database, history::History, settings::FilterMode}; +use atuin_client::{ + database::Database, + history::{History, author_matches_filters}, + settings::FilterMode, +}; use eyre::Result; use fuzzy_matcher::{FuzzyMatcher, skim::SkimMatcherV2}; use itertools::Itertools; @@ -53,7 +57,23 @@ impl SearchEngine for Search { #[instrument(skip_all, level = Level::TRACE, name = "load_all_history")] async fn load_all_history(db: &dyn Database) -> Vec<(History, i32)> { - db.all_with_count().await.unwrap() + let histories = db + .query_history("SELECT * FROM history WHERE deleted_at IS NULL ORDER BY timestamp DESC") + .await + .unwrap_or_default(); + + let mut counts = HashMap::new(); + for history in &histories { + *counts.entry(history.command.clone()).or_insert(0) += 1; + } + + histories + .into_iter() + .map(|history| { + let count = counts.get(&history.command).copied().unwrap_or(1); + (history, count) + }) + .collect() } #[allow(clippy::too_many_lines)] @@ -72,6 +92,9 @@ async fn fuzzy_search( if i % 256 == 0 { yield_now().await; } + if !author_matches_filters(&history.author, &state.authors) { + continue; + } let context = &state.context; let git_root = context .git_root @@ -220,3 +243,62 @@ fn path_dist(a: &Path, b: &Path) -> usize { b.len() - a.len() + dist } + +#[cfg(test)] +mod tests { + use super::*; + use crate::command::client::search::cursor::Cursor; + use atuin_client::{ + database::{Context, Database, Sqlite}, + history::{AUTHOR_FILTER_ALL_USER, History}, + settings::FilterMode, + }; + use time::macros::datetime; + + fn context() -> Context { + Context { + session: uuid::Uuid::now_v7().as_simple().to_string(), + cwd: "/tmp".to_string(), + hostname: "host:user".to_string(), + host_id: String::new(), + git_root: None, + } + } + + #[tokio::test] + async fn skim_search_uses_author_filters() { + let mut db = Sqlite::new(":memory:", 0.1).await.unwrap(); + + let user_history: History = History::import() + .timestamp(datetime!(2024-01-01 10:00 UTC)) + .command("git status") + .cwd("/tmp") + .author("ellie") + .build() + .into(); + let agent_history: History = History::import() + .timestamp(datetime!(2024-01-01 11:00 UTC)) + .command("git stash") + .cwd("/tmp") + .author("codex") + .build() + .into(); + + db.save_bulk(&[user_history.clone(), agent_history]) + .await + .unwrap(); + + let mut engine = Search::new(); + let state = SearchState { + input: Cursor::from("git st".to_owned()), + filter_mode: FilterMode::Global, + context: context(), + custom_context: None, + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + }; + + let results = engine.query(&state, &mut db).await.unwrap(); + + assert_eq!(results, vec![user_history]); + } +} diff --git a/crates/atuin/src/command/client/search/interactive.rs b/crates/atuin/src/command/client/search/interactive.rs index ee38ddaa..f572ed7d 100644 --- a/crates/atuin/src/command/client/search/interactive.rs +++ b/crates/atuin/src/command/client/search/interactive.rs @@ -1600,6 +1600,7 @@ fn compute_popup_placement( )] pub async fn history( query: &[String], + authors: Vec<String>, settings: &Settings, mut db: impl Database, history_store: &HistoryStore, @@ -1768,6 +1769,7 @@ pub async fn history( filter_mode: default_filter_mode, context: initial_context.clone(), custom_context: None, + authors, }, engine: engines::engine(search_mode, settings), results_len: 0, @@ -2259,6 +2261,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2314,6 +2317,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2433,6 +2437,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2492,6 +2497,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2547,6 +2553,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2598,6 +2605,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2658,6 +2666,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2719,6 +2728,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -3098,6 +3108,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), |
