diff options
19 files changed, 1502 insertions, 374 deletions
diff --git a/crates/atuin-client/src/database.rs b/crates/atuin-client/src/database.rs index 7c63368d..75ef51c3 100644 --- a/crates/atuin-client/src/database.rs +++ b/crates/atuin-client/src/database.rs @@ -5,6 +5,7 @@ use std::{ time::Duration, }; +use crate::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS}; use async_trait::async_trait; use atuin_common::utils; use fs_err as fs; @@ -53,6 +54,8 @@ pub struct OptFilters { pub offset: Option<i64>, pub reverse: bool, pub include_duplicates: bool, + /// Author filter. Supports special values `$all-user` and `$all-agent`. + pub authors: Vec<String>, } pub async fn current_context() -> eyre::Result<Context> { @@ -85,6 +88,38 @@ impl Context { } } +/// Each entry is OR'd: `$all-user` → NOT IN agents, `$all-agent` → IN agents, literal → exact match. +fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) { + let mut conditions: Vec<String> = Vec::new(); + let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", "); + let author_expr = "CASE \ + WHEN author IS NULL OR trim(author) = '' THEN \ + CASE \ + WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \ + ELSE hostname \ + END \ + ELSE author \ + END"; + + for author in authors { + match author.as_str() { + AUTHOR_FILTER_ALL_USER => { + conditions.push(format!("{author_expr} NOT IN ({agent_list})")); + } + AUTHOR_FILTER_ALL_AGENT => { + conditions.push(format!("{author_expr} IN ({agent_list})")); + } + literal => { + conditions.push(format!("{author_expr} = {}", quote(literal))); + } + } + } + + if !conditions.is_empty() { + sql.and_where(format!("({})", conditions.join(" OR "))); + } +} + fn get_session_start_time(session_id: &str) -> Option<i64> { if let Ok(uuid) = Uuid::parse_str(session_id) && let Some(timestamp) = uuid.get_timestamp() @@ -595,6 +630,10 @@ impl Database for Sqlite { .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) }); + if !filter_options.authors.is_empty() { + apply_author_filter(&mut sql, &filter_options.authors); + } + sql.and_where_is_null("deleted_at"); let query = sql.sql().expect("bug in search query. please report"); diff --git a/crates/atuin-client/src/history.rs b/crates/atuin-client/src/history.rs index a5adc233..996208d9 100644 --- a/crates/atuin-client/src/history.rs +++ b/crates/atuin-client/src/history.rs @@ -18,6 +18,24 @@ use time::OffsetDateTime; mod builder; pub mod store; +/// Known AI agent author values. Used to expand `$all-agent` and `$all-user` filters. +pub const KNOWN_AGENTS: &[&str] = &["claude-code", "codex", "copilot"]; +pub const AUTHOR_FILTER_ALL_USER: &str = "$all-user"; +pub const AUTHOR_FILTER_ALL_AGENT: &str = "$all-agent"; + +pub fn is_known_agent(author: &str) -> bool { + KNOWN_AGENTS.contains(&author) +} + +pub fn author_matches_filters(author: &str, filters: &[String]) -> bool { + filters.is_empty() + || filters.iter().any(|filter| match filter.as_str() { + AUTHOR_FILTER_ALL_USER => !is_known_agent(author), + AUTHOR_FILTER_ALL_AGENT => is_known_agent(author), + literal => author == literal, + }) +} + pub(crate) const HISTORY_VERSION_V0: &str = "v0"; pub(crate) const HISTORY_VERSION_V1: &str = "v1"; const HISTORY_RECORD_VERSION_V0: u16 = 0; diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs index 25c3bd65..5944de59 100644 --- a/crates/atuin-client/src/settings.rs +++ b/crates/atuin-client/src/settings.rs @@ -565,6 +565,13 @@ pub struct Search { /// The overall frecency score multiplier for the search index (default: 1.0). /// Applied after combining recency and frequency scores. pub frecency_score_multiplier: f64, + + /// Filter history by author. Special values: + /// - `$all-user`: any author that is NOT a known AI agent (default) + /// - `$all-agent`: any known AI agent author + /// - literal strings like "ellie", "claude-code" + #[serde(default = "Search::default_authors")] + pub authors: Vec<String>, } #[derive(Clone, Debug, Deserialize, Serialize)] @@ -844,10 +851,17 @@ impl Default for Search { recency_score_multiplier: 1.0, frequency_score_multiplier: 1.0, frecency_score_multiplier: 1.0, + authors: Self::default_authors(), } } } +impl Search { + fn default_authors() -> Vec<String> { + vec![crate::history::AUTHOR_FILTER_ALL_USER.to_string()] + } +} + impl Default for Tmux { fn default() -> Self { Self { diff --git a/crates/atuin-daemon/proto/search.proto b/crates/atuin-daemon/proto/search.proto index 6b84acbd..5eea2b62 100644 --- a/crates/atuin-daemon/proto/search.proto +++ b/crates/atuin-daemon/proto/search.proto @@ -23,6 +23,7 @@ message SearchRequest { uint64 query_id = 2; // Incrementing ID to match responses to queries FilterMode filter_mode = 3; SearchContext context = 4; + repeated string authors = 5; // Author filter ($all-user, $all-agent, or literals) } message SearchResponse { diff --git a/crates/atuin-daemon/src/client.rs b/crates/atuin-daemon/src/client.rs index 5f4ce20f..51334ee1 100644 --- a/crates/atuin-daemon/src/client.rs +++ b/crates/atuin-daemon/src/client.rs @@ -213,12 +213,14 @@ impl SearchClient { query_id: u64, filter_mode: FilterMode, context: Option<Context>, + authors: Vec<String>, ) -> Result<tonic::Streaming<SearchResponse>> { let request = SearchRequest { query, query_id, filter_mode: RpcFilterMode::from(filter_mode).into(), context: context.map(RpcSearchContext::from), + authors, }; let request_stream = tokio_stream::once(request); let response = span!(Level::TRACE, "daemon_client_search.request") diff --git a/crates/atuin-daemon/src/components/search.rs b/crates/atuin-daemon/src/components/search.rs index 9fc87fae..a2e74aa5 100644 --- a/crates/atuin-daemon/src/components/search.rs +++ b/crates/atuin-daemon/src/components/search.rs @@ -304,6 +304,7 @@ impl SearchSvc for SearchGrpcService { .try_into() .unwrap_or(FilterMode::Global); let proto_context = search_req.context; + let authors = search_req.authors; debug!( "search request: query = {}, query_id = {}, filter_mode = {}, context = {:?}", @@ -332,7 +333,13 @@ impl SearchSvc for SearchGrpcService { .in_scope(|| async { let index = index.read().await; index - .search(&query, index_filter, &query_context, RESULTS_LIMIT) + .search( + &query, + index_filter, + &query_context, + &authors, + RESULTS_LIMIT, + ) .await }) .await; diff --git a/crates/atuin-daemon/src/search/index.rs b/crates/atuin-daemon/src/search/index.rs index 3328c5b5..90751155 100644 --- a/crates/atuin-daemon/src/search/index.rs +++ b/crates/atuin-daemon/src/search/index.rs @@ -12,7 +12,9 @@ use std::{ sync::Arc, }; -use atuin_client::history::History; +use atuin_client::history::{ + AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, History, KNOWN_AGENTS, +}; use atuin_client::settings::Search; use atuin_nucleo::{Injector, Nucleo, pattern}; use dashmap::DashMap; @@ -114,6 +116,100 @@ pub struct CommandData { hosts: HashSet<Spur>, /// All sessions where this command has been run (as 16-byte UUIDs). sessions: HashSet<[u8; 16]>, + /// All authors who have run this command (interned keys). + authors: HashSet<Spur>, + /// Per-invocation metadata for returning the newest row that matches the active filters. + invocations: Vec<InvocationData>, +} + +struct InvocationData { + id: [u8; 16], + timestamp: i64, + directory: Spur, + host: Spur, + session: [u8; 16], + author: Spur, +} + +#[derive(Default)] +struct CompiledAuthorFilter { + include_all_users: bool, + include_all_agents: bool, + literal_authors: HashSet<Spur>, +} + +impl CompiledAuthorFilter { + fn new(filters: &[String], interner: &ThreadedRodeo) -> Self { + let mut compiled = Self::default(); + + for filter in filters { + match filter.as_str() { + AUTHOR_FILTER_ALL_USER => compiled.include_all_users = true, + AUTHOR_FILTER_ALL_AGENT => compiled.include_all_agents = true, + literal => { + if let Some(author) = interner.get(literal) { + compiled.literal_authors.insert(author); + } + } + } + } + + compiled + } + + fn is_empty(&self) -> bool { + !self.include_all_users && !self.include_all_agents && self.literal_authors.is_empty() + } + + fn matches_author(&self, author: Spur, agent_authors: &HashSet<Spur>) -> bool { + self.is_empty() + || self.literal_authors.contains(&author) + || (self.include_all_users && !agent_authors.contains(&author)) + || (self.include_all_agents && agent_authors.contains(&author)) + } +} + +impl InvocationData { + fn new(history: &History, interner: &ThreadedRodeo) -> Option<Self> { + Some(Self { + id: parse_uuid_bytes(&history.id.0)?, + timestamp: history.timestamp.unix_timestamp(), + directory: interner.get_or_intern(with_trailing_slash(&history.cwd)), + host: interner.get_or_intern(&history.hostname), + session: parse_uuid_bytes(&history.session)?, + author: interner.get_or_intern(&history.author), + }) + } + + fn matches_mode(&self, mode: &IndexFilterMode, interner: &ThreadedRodeo) -> bool { + match mode { + IndexFilterMode::Global => true, + IndexFilterMode::Directory(dir) => { + interner.get(dir).is_some_and(|spur| self.directory == spur) + } + IndexFilterMode::Workspace(prefix) => { + interner.resolve(&self.directory).starts_with(prefix) + } + IndexFilterMode::Host(hostname) => { + interner.get(hostname).is_some_and(|spur| self.host == spur) + } + IndexFilterMode::Session(session) => { + parse_uuid_bytes(session).is_some_and(|bytes| self.session == bytes) + } + } + } + + fn matches_authors( + &self, + filter: &CompiledAuthorFilter, + agent_authors: &HashSet<Spur>, + ) -> bool { + filter.matches_author(self.author, agent_authors) + } + + fn id_string(&self) -> String { + format_uuid_bytes(&self.id) + } } impl CommandData { @@ -126,6 +222,7 @@ impl CommandData { let dir_key = interner.get_or_intern(with_trailing_slash(&history.cwd)); let host_key = interner.get_or_intern(&history.hostname); + let invocation = InvocationData::new(history, interner)?; let mut directories = HashSet::new(); directories.insert(dir_key); @@ -136,6 +233,9 @@ impl CommandData { let mut sessions = HashSet::new(); sessions.insert(session); + let mut authors = HashSet::new(); + authors.insert(interner.get_or_intern(&history.author)); + let mut global_frecency = FrecencyData::default(); global_frecency.record_use(timestamp); @@ -146,6 +246,8 @@ impl CommandData { directories, hosts, sessions, + authors, + invocations: vec![invocation], }) } @@ -160,6 +262,9 @@ impl CommandData { }; let timestamp = history.timestamp.unix_timestamp(); + let Some(invocation) = InvocationData::new(history, interner) else { + return false; + }; // Update global frecency self.global_frecency.record_use(timestamp); @@ -169,6 +274,8 @@ impl CommandData { self.directories.insert(dir_key); self.hosts.insert(interner.get_or_intern(&history.hostname)); self.sessions.insert(session); + self.authors.insert(interner.get_or_intern(&history.author)); + self.invocations.push(invocation); // Update most recent if this invocation is newer if timestamp > self.most_recent_timestamp { @@ -184,6 +291,27 @@ impl CommandData { format_uuid_bytes(&self.most_recent_id) } + fn most_recent_matching_id( + &self, + mode: &IndexFilterMode, + authors: &CompiledAuthorFilter, + interner: &ThreadedRodeo, + agent_authors: &HashSet<Spur>, + ) -> Option<String> { + if matches!(mode, IndexFilterMode::Global) && authors.is_empty() { + return Some(self.most_recent_id()); + } + + self.invocations + .iter() + .filter(|invocation| { + invocation.matches_mode(mode, interner) + && invocation.matches_authors(authors, agent_authors) + }) + .max_by_key(|invocation| invocation.timestamp) + .map(InvocationData::id_string) + } + /// Check if any invocation matches a directory filter (exact match). /// O(1) lookup using pre-computed index. pub fn has_invocation_in_dir(&self, dir: &str, interner: &ThreadedRodeo) -> bool { @@ -213,6 +341,16 @@ impl CommandData { pub fn has_invocation_in_session(&self, session: &str) -> bool { parse_uuid_bytes(session).is_some_and(|bytes| self.sessions.contains(&bytes)) } + + fn matches_authors( + &self, + filter: &CompiledAuthorFilter, + agent_authors: &HashSet<Spur>, + ) -> bool { + self.authors + .iter() + .any(|author| filter.matches_author(*author, agent_authors)) + } } /// Filter mode for search queries. @@ -336,6 +474,7 @@ impl SearchIndex { query: &str, filter_mode: IndexFilterMode, _context: &QueryContext, + authors: &[String], limit: u32, ) -> Vec<String> { let mut nucleo = self.nucleo.write().await; @@ -343,8 +482,21 @@ impl SearchIndex { // Get precomputed frecency map (may be None if not yet computed) let frecency_map = self.frecency_map.read().await.clone(); - // Build filter based on mode - let filter = self.build_filter(&filter_mode); + // Build filter based on mode + authors + let mode_filter = self.build_filter(&filter_mode); + let compiled_authors = CompiledAuthorFilter::new(authors, &self.interner); + let author_filter = self.build_author_filter(&compiled_authors); + let agent_authors: HashSet<Spur> = KNOWN_AGENTS + .iter() + .filter_map(|author| self.interner.get(author)) + .collect(); + let filter = + match (mode_filter, author_filter) { + (Some(mf), Some(af)) => Some(Arc::new(move |cmd: &String| mf(cmd) && af(cmd)) + as atuin_nucleo::Filter<String>), + (Some(f), None) | (None, Some(f)) => Some(f), + (None, None) => None, + }; nucleo.set_filter(filter); // Build scorer from precomputed frecency (or None if not available) @@ -375,9 +527,14 @@ impl SearchIndex { .filter_map(|item| { let cmd = item.data; // DashMap<Arc<str>, _>::get accepts &str via Borrow trait - self.commands - .get(cmd.as_str()) - .map(|data| data.most_recent_id()) + self.commands.get(cmd.as_str()).and_then(|data| { + data.most_recent_matching_id( + &filter_mode, + &compiled_authors, + &self.interner, + &agent_authors, + ) + }) }) .collect() }) @@ -452,6 +609,36 @@ impl SearchIndex { Some(Arc::new(move |cmd: &String| passing_commands.contains(cmd))) } + /// Build author filter predicate. + /// Same pre-computation approach as `build_filter`: find all commands with + /// a matching author, collect into a HashSet, wrap as a Nucleo Filter closure. + fn build_author_filter( + &self, + authors: &CompiledAuthorFilter, + ) -> Option<atuin_nucleo::Filter<String>> { + if authors.is_empty() { + return None; + } + + let agent_authors: HashSet<Spur> = KNOWN_AGENTS + .iter() + .filter_map(|a| self.interner.get(a)) + .collect(); + + let passing_commands: Arc<HashSet<String>> = { + let mut set = HashSet::new(); + for entry in self.commands.iter() { + let passes = entry.matches_authors(authors, &agent_authors); + if passes { + set.insert(entry.key().to_string()); + } + } + Arc::new(set) + }; + + Some(Arc::new(move |cmd: &String| passing_commands.contains(cmd))) + } + /// Build scorer from precomputed frecency map. /// /// Returns None if frecency map is not available (search still works, just without frecency ranking). @@ -634,6 +821,86 @@ mod tests { } #[tokio::test] + async fn search_index_returns_latest_matching_author_invocation() { + let index = SearchIndex::new(); + + let user_history: History = History::import() + .timestamp(datetime!(2024-01-01 10:00 UTC)) + .command("git status") + .cwd("/home/user/project") + .author("ellie") + .build() + .into(); + let agent_history: History = History::import() + .timestamp(datetime!(2024-01-01 11:00 UTC)) + .command("git status") + .cwd("/home/user/project") + .author("codex") + .build() + .into(); + + index.add_history(&user_history); + index.add_history(&agent_history); + + let user_results = index + .search( + "git", + IndexFilterMode::Global, + &QueryContext::default(), + &[AUTHOR_FILTER_ALL_USER.to_string()], + 10, + ) + .await; + assert_eq!( + user_results, + vec![Uuid::parse_str(&user_history.id.0).unwrap().to_string()] + ); + + let agent_results = index + .search( + "git", + IndexFilterMode::Global, + &QueryContext::default(), + &[AUTHOR_FILTER_ALL_AGENT.to_string()], + 10, + ) + .await; + assert_eq!( + agent_results, + vec![Uuid::parse_str(&agent_history.id.0).unwrap().to_string()] + ); + } + + #[tokio::test] + async fn search_index_returns_latest_matching_directory_invocation() { + let index = SearchIndex::new(); + + let dir1 = "/home/user/project"; + let dir2 = "/home/user/other"; + + let project_history = make_history("git status", dir1, datetime!(2024-01-01 10:00 UTC)); + let other_history = make_history("git status", dir2, datetime!(2024-01-01 11:00 UTC)); + + index.add_history(&project_history); + index.add_history(&other_history); + + let results = index + .search( + "git", + IndexFilterMode::Directory(with_trailing_slash(dir1)), + &QueryContext::default(), + &[], + 10, + ) + .await; + + assert_eq!( + results, + vec![Uuid::parse_str(&project_history.id.0).unwrap().to_string()] + ); + } + + #[tokio::test] async fn search_index_add_and_search() { let index = SearchIndex::new(); @@ -661,7 +928,13 @@ mod tests { // Search for "git" - should match 2 commands let results = index - .search("git", IndexFilterMode::Global, &QueryContext::default(), 10) + .search( + "git", + IndexFilterMode::Global, + &QueryContext::default(), + &[], + 10, + ) .await; assert_eq!(results.len(), 2); @@ -671,6 +944,7 @@ mod tests { "", IndexFilterMode::Directory(with_trailing_slash("/home/user/project")), &QueryContext::default(), + &[], 10, ) .await; diff --git a/crates/atuin/src/command/client.rs b/crates/atuin/src/command/client.rs index 917bee1a..e3ec01d9 100644 --- a/crates/atuin/src/command/client.rs +++ b/crates/atuin/src/command/client.rs @@ -54,6 +54,7 @@ mod default_config; mod doctor; mod dotfiles; mod history; +mod hook; mod import; mod info; mod init; @@ -76,6 +77,9 @@ pub enum Cmd { #[command(subcommand)] History(history::Cmd), + /// Manage AI-agent shell hooks + Hook(hook::Cmd), + /// Import shell history from file #[command(subcommand)] Import(import::Cmd), @@ -337,6 +341,7 @@ impl Cmd { // runs match self { Self::History(history) => return history.run(&settings).await, + Self::Hook(hook) => return hook.run(&settings).await, Self::Init(init) => return init.run(&settings).await, Self::Doctor => return doctor::run(&settings).await, Self::Config(config) => return config.run(&settings).await, @@ -387,7 +392,9 @@ impl Cmd { #[cfg(feature = "daemon")] Self::Daemon(cmd) => cmd.run(settings, sqlite_store, db).await, - Self::History(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => unreachable!(), + Self::History(_) | Self::Hook(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => { + unreachable!() + } #[cfg(feature = "ai")] Self::Ai(cli) => atuin_ai::commands::run(cli, &settings).await, diff --git a/crates/atuin/src/command/client/history.rs b/crates/atuin/src/command/client/history.rs index 67e0a5db..836556b4 100644 --- a/crates/atuin/src/command/client/history.rs +++ b/crates/atuin/src/command/client/history.rs @@ -374,6 +374,235 @@ fn parse_fmt(format: &str) -> ParsedFmt<'_> { } } +fn apply_start_metadata(history: &mut History, author: Option<&str>, intent: Option<&str>) { + if let Some(author) = author.map(str::trim).filter(|author| !author.is_empty()) { + author.clone_into(&mut history.author); + } + + if let Some(intent) = intent.map(str::trim).filter(|intent| !intent.is_empty()) { + history.intent = Some(intent.to_owned()); + } else if intent.is_some() { + history.intent = None; + } +} + +fn normalize_command_for_storage<'a>(command: &'a str, settings: &Settings) -> &'a str { + if !settings.strip_trailing_whitespace { + return command; + } + + let trimmed = command.trim_end_matches([' ', '\t']); + if trimmed.len() == command.len() { + return command; + } + + let trailing_backslashes = trimmed + .as_bytes() + .iter() + .rev() + .take_while(|&&byte| byte == b'\\') + .count(); + + if trailing_backslashes % 2 == 1 { + command + } else { + trimmed + } +} + +async fn handle_start( + db: &impl Database, + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + let command = normalize_command_for_storage(command, settings); + + let mut h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + apply_start_metadata(&mut h, author, intent); + + if !h.should_save(settings) { + return Ok(None); + } + + let id = h.id.0.clone(); + + // Silently ignore database errors to avoid breaking the shell + // This is important when disk is full or database is locked + if let Err(e) = db.save(&h).await { + debug!("failed to save history: {e}"); + } + + Ok(Some(id)) +} + +#[cfg(feature = "daemon")] +async fn handle_daemon_start( + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + let command = normalize_command_for_storage(command, settings); + + let mut h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + apply_start_metadata(&mut h, author, intent); + + if !h.should_save(settings) { + return Ok(None); + } + + // Attempt to start history via daemon, but silently ignore errors + // to avoid breaking the shell when the daemon is unavailable or disk is full + let resp = match daemon::start_history(settings, h.clone()).await { + Ok(id) => id, + Err(e) => { + debug!("failed to start history via daemon: {e}"); + h.id.0.clone() + } + }; + + Ok(Some(resp)) +} + +#[allow(unused_variables)] +async fn handle_end( + db: &impl Database, + store: SqliteStore, + history_store: HistoryStore, + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + if id.trim() == "" { + return Ok(()); + } + + let Some(mut h) = db.load(id).await? else { + warn!("history entry is missing"); + return Ok(()); + }; + + if h.duration > 0 { + debug!("cannot end history - already has duration"); + + // returning OK as this can occur if someone Ctrl-c a prompt + return Ok(()); + } + + if !settings.store_failed && exit > 0 { + debug!("history has non-zero exit code, and store_failed is false"); + + // the history has already been inserted half complete. remove it + db.delete(h).await?; + + return Ok(()); + } + + h.exit = exit; + h.duration = match duration { + Some(value) => i64::try_from(value).context("command took over 292 years")?, + None => i64::try_from((OffsetDateTime::now_utc() - h.timestamp).whole_nanoseconds()) + .context("command took over 292 years")?, + }; + + db.update(&h).await?; + history_store.push(h).await?; + + if settings.should_sync().await? { + #[cfg(feature = "sync")] + { + if settings.sync.records { + let (_, downloaded) = record::sync::sync(settings, &store).await?; + Settings::save_sync_time().await?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + } else { + debug!("running periodic background sync"); + sync::sync(settings, false, db).await?; + } + } + #[cfg(not(feature = "sync"))] + debug!("not compiled with sync support"); + } else { + debug!("sync disabled! not syncing"); + } + + Ok(()) +} + +#[cfg(feature = "daemon")] +async fn handle_daemon_end( + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + daemon::end_history(settings, id.to_string(), duration.unwrap_or(0), exit).await?; + + Ok(()) +} + +pub(super) async fn start_history_entry( + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + #[cfg(feature = "daemon")] + if settings.daemon.enabled { + return handle_daemon_start(settings, command, author, intent).await; + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let db = Sqlite::new(db_path, settings.local_timeout).await?; + handle_start(&db, settings, command, author, intent).await +} + +pub(super) async fn end_history_entry( + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + #[cfg(feature = "daemon")] + if settings.daemon.enabled { + return handle_daemon_end(settings, id, exit, duration).await; + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + handle_end(&db, store, history_store, settings, id, exit, duration).await +} + #[cfg(feature = "daemon")] #[derive(Clone, Copy, Debug, Eq, PartialEq)] enum TailKind { @@ -676,200 +905,6 @@ fn normalize_optional_field(value: &str) -> Option<String> { } impl Cmd { - fn normalize_command_for_storage<'a>(command: &'a str, settings: &Settings) -> &'a str { - if !settings.strip_trailing_whitespace { - return command; - } - - let trimmed = command.trim_end_matches([' ', '\t']); - if trimmed.len() == command.len() { - return command; - } - - let trailing_backslashes = trimmed - .as_bytes() - .iter() - .rev() - .take_while(|&&byte| byte == b'\\') - .count(); - - if trailing_backslashes % 2 == 1 { - command - } else { - trimmed - } - } - - fn apply_start_metadata(history: &mut History, author: Option<&str>, intent: Option<&str>) { - if let Some(author) = author.map(str::trim).filter(|author| !author.is_empty()) { - author.clone_into(&mut history.author); - } - - if let Some(intent) = intent.map(str::trim).filter(|intent| !intent.is_empty()) { - history.intent = Some(intent.to_owned()); - } else if intent.is_some() { - history.intent = None; - } - } - - #[allow(clippy::too_many_lines, clippy::cast_possible_truncation)] - async fn handle_start( - db: &impl Database, - settings: &Settings, - command: &str, - author: Option<&str>, - intent: Option<&str>, - ) -> Result<()> { - // It's better for atuin to silently fail here and attempt to - // store whatever is ran, than to throw an error to the terminal - let cwd = utils::get_current_dir(); - let command = Self::normalize_command_for_storage(command, settings); - - let mut h: History = History::capture() - .timestamp(OffsetDateTime::now_utc()) - .command(command) - .cwd(cwd) - .build() - .into(); - Self::apply_start_metadata(&mut h, author, intent); - - if !h.should_save(settings) { - return Ok(()); - } - - // print the ID - // we use this as the key for calling end - println!("{}", h.id); - - // Silently ignore database errors to avoid breaking the shell - // This is important when disk is full or database is locked - if let Err(e) = db.save(&h).await { - debug!("failed to save history: {e}"); - } - - Ok(()) - } - - #[cfg(feature = "daemon")] - async fn handle_daemon_start( - settings: &Settings, - command: &str, - author: Option<&str>, - intent: Option<&str>, - ) -> Result<()> { - // It's better for atuin to silently fail here and attempt to - // store whatever is ran, than to throw an error to the terminal - let cwd = utils::get_current_dir(); - let command = Self::normalize_command_for_storage(command, settings); - - let mut h: History = History::capture() - .timestamp(OffsetDateTime::now_utc()) - .command(command) - .cwd(cwd) - .build() - .into(); - Self::apply_start_metadata(&mut h, author, intent); - - if !h.should_save(settings) { - return Ok(()); - } - - // Attempt to start history via daemon, but silently ignore errors - // to avoid breaking the shell when the daemon is unavailable or disk is full - let resp = match daemon::start_history(settings, h.clone()).await { - Ok(id) => id, - Err(e) => { - debug!("failed to start history via daemon: {e}"); - h.id.0.clone() - } - }; - - // print the ID - // we use this as the key for calling end - println!("{resp}"); - - Ok(()) - } - - #[allow(unused_variables)] - async fn handle_end( - db: &impl Database, - store: SqliteStore, - history_store: HistoryStore, - settings: &Settings, - id: &str, - exit: i64, - duration: Option<u64>, - ) -> Result<()> { - if id.trim() == "" { - return Ok(()); - } - - let Some(mut h) = db.load(id).await? else { - warn!("history entry is missing"); - return Ok(()); - }; - - if h.duration > 0 { - debug!("cannot end history - already has duration"); - - // returning OK as this can occur if someone Ctrl-c a prompt - return Ok(()); - } - - if !settings.store_failed && exit > 0 { - debug!("history has non-zero exit code, and store_failed is false"); - - // the history has already been inserted half complete. remove it - db.delete(h).await?; - - return Ok(()); - } - - h.exit = exit; - h.duration = match duration { - Some(value) => i64::try_from(value).context("command took over 292 years")?, - None => i64::try_from((OffsetDateTime::now_utc() - h.timestamp).whole_nanoseconds()) - .context("command took over 292 years")?, - }; - - db.update(&h).await?; - history_store.push(h).await?; - - if settings.should_sync().await? { - #[cfg(feature = "sync")] - { - if settings.sync.records { - let (_, downloaded) = record::sync::sync(settings, &store).await?; - Settings::save_sync_time().await?; - - crate::sync::build(settings, &store, db, Some(&downloaded)).await?; - } else { - debug!("running periodic background sync"); - sync::sync(settings, false, db).await?; - } - } - #[cfg(not(feature = "sync"))] - debug!("not compiled with sync support"); - } else { - debug!("sync disabled! not syncing"); - } - - Ok(()) - } - - #[cfg(feature = "daemon")] - async fn handle_daemon_end( - settings: &Settings, - id: &str, - exit: i64, - duration: Option<u64>, - ) -> Result<()> { - daemon::end_history(settings, id.to_string(), duration.unwrap_or(0), exit).await?; - - Ok(()) - } - #[cfg(feature = "daemon")] async fn handle_tail(settings: &Settings) -> Result<()> { let tty = std::io::stdout().is_terminal(); @@ -892,6 +927,7 @@ impl Cmd { Ok(()) } + #[allow(clippy::too_many_lines, clippy::cast_possible_truncation)] #[allow(clippy::too_many_arguments)] #[allow(clippy::fn_params_excessive_bools)] async fn handle_list( @@ -1057,145 +1093,125 @@ impl Cmd { #[allow(clippy::too_many_lines)] pub async fn run(self, settings: &Settings) -> Result<()> { - let context = current_context().await?; - - #[cfg(feature = "daemon")] - // Skip initializing any databases for start/end, if the daemon is enabled - if settings.daemon.enabled { - match self { - Self::Start { .. } => { - let command = self.get_start_command().unwrap_or_default(); - let (author, intent) = self.get_start_metadata().unwrap_or_default(); - return Self::handle_daemon_start(settings, &command, author, intent).await; - } + match self { + Self::Start { + cmd_env, + author, + intent, + command, + } => { + let command = if cmd_env { + std::env::var("ATUIN_COMMAND_LINE").unwrap_or_default() + } else { + command.join(" ") + }; - Self::End { id, exit, duration } => { - return Self::handle_daemon_end(settings, &id, exit, duration).await; + if let Some(id) = + start_history_entry(settings, &command, author.as_deref(), intent.as_deref()) + .await? + { + println!("{id}"); } - Self::Tail => { + Ok(()) + } + Self::End { id, exit, duration } => { + end_history_entry(settings, &id, exit, duration).await + } + Self::Tail => { + #[cfg(feature = "daemon")] + { return Self::handle_tail(settings).await; } - _ => {} + #[cfg(not(feature = "daemon"))] + bail!("`atuin history tail` requires Atuin to be built with the `daemon` feature"); } - } - - if matches!(self, Self::Tail) { - #[cfg(feature = "daemon")] - bail!("`atuin history tail` requires `daemon.enabled = true`"); - - #[cfg(not(feature = "daemon"))] - bail!("`atuin history tail` requires Atuin to be built with the `daemon` feature"); - } + cmd => { + let context = current_context().await?; - let db_path = PathBuf::from(settings.db_path.as_str()); - let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); - let db = Sqlite::new(db_path, settings.local_timeout).await?; - let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; - let encryption_key: [u8; 32] = encryption::load_key(settings) - .context("could not load encryption key")? - .into(); + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); - let host_id = Settings::host_id().await?; - let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - match self { - Self::Start { .. } => { - let command = self.get_start_command().unwrap_or_default(); - let (author, intent) = self.get_start_metadata().unwrap_or_default(); - Self::handle_start(&db, settings, &command, author, intent).await - } - Self::End { id, exit, duration } => { - Self::handle_end(&db, store, history_store, settings, &id, exit, duration).await - } - Self::Tail => unreachable!("tail handled before database initialization"), - Self::List { - session, - cwd, - human, - cmd_only, - print0, - reverse, - timezone, - format, - } => { - let mode = ListMode::from_flags(human, cmd_only); - let tz = timezone.unwrap_or(settings.timezone); - Self::handle_list( - &db, settings, context, session, cwd, mode, format, false, print0, reverse, tz, - ) - .await - } + match cmd { + Self::List { + session, + cwd, + human, + cmd_only, + print0, + reverse, + timezone, + format, + } => { + let mode = ListMode::from_flags(human, cmd_only); + let tz = timezone.unwrap_or(settings.timezone); + Self::handle_list( + &db, settings, context, session, cwd, mode, format, false, print0, + reverse, tz, + ) + .await + } - Self::Last { - human, - cmd_only, - timezone, - format, - } => { - let last = db.last().await?; - let last = last.as_slice(); - let tz = timezone.unwrap_or(settings.timezone); - print_list( - last, - ListMode::from_flags(human, cmd_only), - match format { - None => Some(settings.history_format.as_str()), - _ => format.as_deref(), - }, - false, - true, - tz, - ); + Self::Last { + human, + cmd_only, + timezone, + format, + } => { + let last = db.last().await?; + let last = last.as_slice(); + let tz = timezone.unwrap_or(settings.timezone); + print_list( + last, + ListMode::from_flags(human, cmd_only), + match format { + None => Some(settings.history_format.as_str()), + _ => format.as_deref(), + }, + false, + true, + tz, + ); - Ok(()) - } + Ok(()) + } - Self::InitStore => history_store.init_store(&db).await, + Self::InitStore => history_store.init_store(&db).await, - Self::Prune { dry_run } => { - Self::handle_prune(&db, settings, store, context, dry_run).await - } + Self::Prune { dry_run } => { + Self::handle_prune(&db, settings, store, context, dry_run).await + } - Self::Dedup { - dry_run, - before, - dupkeep, - } => { - let before = i64::try_from( - interim::parse_date_string( - before.as_str(), - OffsetDateTime::now_utc(), - interim::Dialect::Uk, - )? - .unix_timestamp_nanos(), - )?; - Self::handle_dedup(&db, settings, store, before, dupkeep, dry_run).await - } - } - } + Self::Dedup { + dry_run, + before, + dupkeep, + } => { + let before = i64::try_from( + interim::parse_date_string( + before.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + )? + .unix_timestamp_nanos(), + )?; + Self::handle_dedup(&db, settings, store, before, dupkeep, dry_run).await + } - /// Returns the command line to use for the `Start` variant. - /// Returns `None` for any other variant. - fn get_start_command(&self) -> Option<String> { - match self { - Self::Start { cmd_env: true, .. } => { - Some(std::env::var("ATUIN_COMMAND_LINE").unwrap_or_default()) + Self::Start { .. } | Self::End { .. } | Self::Tail => unreachable!(), + } } - Self::Start { command, .. } => Some(command.join(" ")), - _ => None, - } - } - - /// Returns `(author, intent)` for the `Start` variant. - /// Returns `None` for any other variant. - fn get_start_metadata(&self) -> Option<(Option<&str>, Option<&str>)> { - match self { - Self::Start { author, intent, .. } => Some((author.as_deref(), intent.as_deref())), - _ => None, } } } @@ -1212,10 +1228,7 @@ mod tests { let settings = Settings::utc(); assert!(settings.strip_trailing_whitespace); - assert_eq!( - Cmd::normalize_command_for_storage("ls \t", &settings), - "ls" - ); + assert_eq!(normalize_command_for_storage("ls \t", &settings), "ls"); } #[test] @@ -1223,11 +1236,11 @@ mod tests { let settings = Settings::utc(); assert_eq!( - Cmd::normalize_command_for_storage("printf foo\\ ", &settings), + normalize_command_for_storage("printf foo\\ ", &settings), "printf foo\\ " ); assert_eq!( - Cmd::normalize_command_for_storage("printf foo\\\\ ", &settings), + normalize_command_for_storage("printf foo\\\\ ", &settings), "printf foo\\\\" ); } @@ -1237,7 +1250,7 @@ mod tests { let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); let settings = Settings::utc(); - Cmd::handle_start(&db, &settings, "ls \t", None, None) + handle_start(&db, &settings, "ls \t", None, None) .await .unwrap(); @@ -1258,7 +1271,7 @@ mod tests { ..Settings::utc() }; - Cmd::handle_start(&db, &settings, "ls \t", None, None) + handle_start(&db, &settings, "ls \t", None, None) .await .unwrap(); diff --git a/crates/atuin/src/command/client/hook.rs b/crates/atuin/src/command/client/hook.rs new file mode 100644 index 00000000..bb333c5f --- /dev/null +++ b/crates/atuin/src/command/client/hook.rs @@ -0,0 +1,401 @@ +use std::io::Read; +use std::path::PathBuf; + +use atuin_client::settings::Settings; +use atuin_common::utils::home_dir; +use clap::{Parser, Subcommand}; +use eyre::{Result, bail}; +use serde_json::Value; + +use super::history; + +const HOOK_EVENT_TYPES: &[&str] = &["PreToolUse", "PostToolUse", "PostToolUseFailure"]; + +struct AgentSpec { + aliases: &'static [&'static str], + actor_name: &'static str, + config_path: &'static [&'static str], + hook_command: &'static str, + matcher: &'static str, +} + +const CLAUDE_CODE: AgentSpec = AgentSpec { + aliases: &["claude-code", "claude"], + actor_name: "claude-code", + config_path: &[".claude", "settings.json"], + hook_command: "atuin hook claude-code", + matcher: "Bash", +}; + +const CODEX: AgentSpec = AgentSpec { + aliases: &["codex"], + actor_name: "codex", + config_path: &[".codex", "hooks.json"], + hook_command: "atuin hook codex", + matcher: "^Bash$", +}; + +const AGENTS: &[&AgentSpec] = &[&CLAUDE_CODE, &CODEX]; + +struct Agent(&'static AgentSpec); + +impl Agent { + fn from_name(name: &str) -> Result<Self> { + AGENTS + .iter() + .copied() + .find(|spec| spec.aliases.contains(&name)) + .map(Self) + .ok_or_else(|| { + eyre::eyre!("unknown agent: {name}. Supported agents: claude-code, codex") + }) + } + + fn actor_name(&self) -> &'static str { + self.0.actor_name + } + + fn config_path(&self) -> PathBuf { + self.0 + .config_path + .iter() + .fold(home_dir(), |path, segment| path.join(segment)) + } + + fn hook_command(&self) -> &'static str { + self.0.hook_command + } + + fn matcher(&self) -> &'static str { + self.0.matcher + } +} + +#[derive(Subcommand, Debug)] +enum Action { + /// Install hooks for an AI agent to capture commands in atuin history + Install { + /// Agent to install hooks for (e.g., "claude-code") + #[arg(value_name = "AGENT")] + agent: String, + }, +} + +#[derive(Parser, Debug)] +#[command(infer_subcommands = true, args_conflicts_with_subcommands = true)] +pub struct Cmd { + #[command(subcommand)] + action: Option<Action>, + + /// Which agent's hook format to parse (e.g., "claude-code") + #[arg(value_name = "AGENT", hide = true)] + agent: Option<String>, +} + +impl Cmd { + pub async fn run(self, settings: &Settings) -> Result<()> { + match (self.action, self.agent) { + (Some(Action::Install { agent }), None) => install(&agent), + (None, Some(agent)) => handle(&agent, settings).await, + (None, None) => bail!("expected `atuin hook <agent>` or `atuin hook install <agent>`"), + (Some(_), Some(_)) => bail!("hook action cannot be combined with a positional agent"), + } + } +} + +#[derive(Debug)] +enum HookEvent { + Start { + command: String, + intent: Option<String>, + tool_use_id: String, + }, + End { + tool_use_id: String, + exit: i64, + }, + Skip, +} + +fn parse_hook_stdin(input: &str) -> Result<HookEvent> { + let v: Value = serde_json::from_str(input)?; + + if v.get("tool_name").and_then(|t| t.as_str()) != Some("Bash") { + return Ok(HookEvent::Skip); + } + + let tool_use_id = match v.get("tool_use_id").and_then(|t| t.as_str()) { + Some(id) if !id.is_empty() => id.to_string(), + _ => return Ok(HookEvent::Skip), + }; + + match v.get("hook_event_name").and_then(|e| e.as_str()) { + Some("PreToolUse") => { + let tool_input = v.get("tool_input"); + let command = tool_input + .and_then(|ti| ti.get("command")) + .and_then(|c| c.as_str()) + .unwrap_or(""); + + if command.is_empty() { + return Ok(HookEvent::Skip); + } + + let intent = tool_input + .and_then(|ti| ti.get("description")) + .and_then(|d| d.as_str()) + .map(String::from); + + Ok(HookEvent::Start { + command: command.to_string(), + intent, + tool_use_id, + }) + } + Some(event @ ("PostToolUse" | "PostToolUseFailure")) => { + let exit = if event == "PostToolUseFailure" { + 1 + } else { + v.get("tool_response") + .and_then(|tr| tr.get("exitCode")) + .and_then(Value::as_i64) + .unwrap_or(0) + }; + + Ok(HookEvent::End { tool_use_id, exit }) + } + _ => Ok(HookEvent::Skip), + } +} + +fn id_file_path(tool_use_id: &str) -> PathBuf { + std::env::temp_dir().join(format!("atuin-hook-{tool_use_id}")) +} + +async fn handle(agent_name: &str, settings: &Settings) -> Result<()> { + let agent = Agent::from_name(agent_name)?; + + let mut input = String::new(); + std::io::stdin().read_to_string(&mut input)?; + + if input.trim().is_empty() { + return Ok(()); + } + + match parse_hook_stdin(&input)? { + HookEvent::Start { + command, + intent, + tool_use_id, + } => { + if let Some(history_id) = history::start_history_entry( + settings, + &command, + Some(agent.actor_name()), + intent.as_deref(), + ) + .await? + { + std::fs::write(id_file_path(&tool_use_id), &history_id)?; + } + } + HookEvent::End { tool_use_id, exit } => { + let id_path = id_file_path(&tool_use_id); + + if let Ok(history_id) = std::fs::read_to_string(&id_path) { + let history_id = history_id.trim(); + if !history_id.is_empty() { + let _ = history::end_history_entry(settings, history_id, exit, None).await; + } + let _ = std::fs::remove_file(&id_path); + } + } + HookEvent::Skip => {} + } + + Ok(()) +} + +fn install(agent_name: &str) -> Result<()> { + let agent = Agent::from_name(agent_name)?; + let config_path = agent.config_path(); + + if let Some(parent) = config_path.parent() { + std::fs::create_dir_all(parent)?; + } + + let mut root: Value = if config_path.exists() { + let content = std::fs::read_to_string(&config_path)?; + serde_json::from_str(&content)? + } else { + Value::Object(serde_json::Map::new()) + }; + + let hooks = root + .as_object_mut() + .ok_or_else(|| eyre::eyre!("config is not a JSON object"))? + .entry("hooks") + .or_insert_with(|| Value::Object(serde_json::Map::new())); + + add_hook_entries(hooks, &agent)?; + + let content = serde_json::to_string_pretty(&root)?; + std::fs::write(&config_path, content)?; + + eprintln!( + "\nAtuin hooks installed for {}. Config: {}", + agent.actor_name(), + config_path.display() + ); + + Ok(()) +} + +fn add_hook_entries(hooks: &mut Value, agent: &Agent) -> Result<()> { + let hook_command = agent.hook_command(); + let matcher = agent.matcher(); + + for event_type in HOOK_EVENT_TYPES { + let event_hooks = hooks + .as_object_mut() + .ok_or_else(|| eyre::eyre!("hooks is not a JSON object"))? + .entry(*event_type) + .or_insert_with(|| Value::Array(Vec::new())); + + let arr = event_hooks + .as_array_mut() + .ok_or_else(|| eyre::eyre!("hooks.{event_type} is not an array"))?; + + let already_installed = arr.iter().any(|entry| { + entry["hooks"].as_array().is_some_and(|h| { + h.iter() + .any(|hook| hook["command"].as_str() == Some(hook_command)) + }) + }); + + if already_installed { + eprintln!("hooks.{event_type}: already installed, skipping"); + continue; + } + + arr.push(serde_json::json!({ + "matcher": matcher, + "hooks": [{"type": "command", "command": hook_command}] + })); + eprintln!("hooks.{event_type}: installed atuin hook"); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + Atuin, + command::{AtuinCmd, client}, + }; + use clap::Parser; + + #[test] + fn parse_hook_agent_command() { + let cmd = Cmd::try_parse_from(["hook", "codex"]).unwrap(); + + assert!(matches!( + (cmd.action, cmd.agent.as_deref()), + (None, Some("codex")) + )); + } + + #[test] + fn parse_hook_install_command() { + let cmd = Cmd::try_parse_from(["hook", "install", "codex"]).unwrap(); + + match (cmd.action, cmd.agent) { + (Some(Action::Install { agent }), None) => assert_eq!(agent, "codex"), + other => panic!("unexpected parsed command: {other:?}"), + } + } + + #[test] + fn parse_top_level_hook_command() { + let cmd = Atuin::try_parse_from(["atuin", "hook", "codex"]).unwrap(); + + assert!(matches!( + cmd.atuin, + AtuinCmd::Client(client::Cmd::Hook(Cmd { action: None, agent: Some(agent) })) + if agent == "codex" + )); + } + + #[test] + fn test_parse_pre_tool_use() { + let input = r#"{ + "hook_event_name": "PreToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo hello", "description": "Test greeting"}, + "tool_use_id": "toolu_abc123", + "session_id": "sess1", + "cwd": "/tmp" + }"#; + + match parse_hook_stdin(input).unwrap() { + HookEvent::Start { + command, + intent, + tool_use_id, + } => { + assert_eq!(command, "echo hello"); + assert_eq!(intent.as_deref(), Some("Test greeting")); + assert_eq!(tool_use_id, "toolu_abc123"); + } + _ => panic!("expected Start event"), + } + } + + #[test] + fn test_parse_post_tool_use() { + let input = r#"{ + "hook_event_name": "PostToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo hello"}, + "tool_response": {"exitCode": 0}, + "tool_use_id": "toolu_abc123" + }"#; + + match parse_hook_stdin(input).unwrap() { + HookEvent::End { tool_use_id, exit } => { + assert_eq!(tool_use_id, "toolu_abc123"); + assert_eq!(exit, 0); + } + _ => panic!("expected End event"), + } + } + + #[test] + fn test_parse_non_bash_tool_skipped() { + let input = r#"{ + "hook_event_name": "PreToolUse", + "tool_name": "Write", + "tool_input": {"file_path": "/tmp/test.txt", "content": "hello"}, + "tool_use_id": "toolu_abc123" + }"#; + + assert!(matches!(parse_hook_stdin(input).unwrap(), HookEvent::Skip)); + } + + #[test] + fn test_parse_failure_event() { + let input = r#"{ + "hook_event_name": "PostToolUseFailure", + "tool_name": "Bash", + "tool_input": {"command": "false"}, + "tool_use_id": "toolu_abc123" + }"#; + + match parse_hook_stdin(input).unwrap() { + HookEvent::End { exit, .. } => assert_eq!(exit, 1), + _ => panic!("expected End event"), + } + } +} diff --git a/crates/atuin/src/command/client/search.rs b/crates/atuin/src/command/client/search.rs index 3d348473..19045867 100644 --- a/crates/atuin/src/command/client/search.rs +++ b/crates/atuin/src/command/client/search.rs @@ -131,6 +131,11 @@ pub struct Cmd { #[arg(long = "inline-height")] inline_height: Option<u16>, + /// Filter by author. Supports $all-user (non-agents), $all-agent, or literal names. + /// Can be specified multiple times. + #[arg(long)] + author: Option<Vec<String>>, + /// Include duplicate commands in the output (non-interactive only) #[arg(long)] include_duplicates: bool, @@ -141,6 +146,12 @@ pub struct Cmd { } impl Cmd { + fn resolved_authors(&self, settings: &Settings) -> Vec<String> { + self.author + .clone() + .unwrap_or_else(|| settings.search.authors.clone()) + } + /// Returns true if this search command will run in interactive (TUI) mode pub fn is_interactive(&self) -> bool { self.interactive @@ -157,6 +168,7 @@ impl Cmd { store: SqliteStore, theme: &Theme, ) -> Result<()> { + let authors = self.resolved_authors(settings); let query = self.query.unwrap_or_else(|| { std::env::var("ATUIN_QUERY").map_or_else( |_| vec![], @@ -221,7 +233,8 @@ impl Cmd { let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); if self.interactive { - let item = interactive::history(&query, settings, db, &history_store, theme).await?; + let item = + interactive::history(&query, authors, settings, db, &history_store, theme).await?; if let Some(result_file) = self.result_file { let mut file = File::create(result_file)?; @@ -248,6 +261,7 @@ impl Cmd { offset: self.offset, reverse: self.reverse, include_duplicates: self.include_duplicates, + authors, }; let mut entries = @@ -337,6 +351,7 @@ async fn run_non_interactive( #[cfg(test)] mod tests { use super::Cmd; + use atuin_client::settings::Settings; use clap::Parser; #[test] @@ -356,4 +371,24 @@ mod tests { let cmd = cmd.unwrap(); assert_eq!(cmd.query, Some(vec!["--foo".to_string()])); } + + #[test] + fn search_authors_default_to_settings() { + let cmd = Cmd::try_parse_from(["search"]).unwrap(); + let settings = Settings::default(); + + assert_eq!(cmd.resolved_authors(&settings), settings.search.authors); + } + + #[test] + fn search_authors_cli_override_config() { + let cmd = + Cmd::try_parse_from(["search", "--author", "codex", "--author", "ellie"]).unwrap(); + let settings = Settings::default(); + + assert_eq!( + cmd.resolved_authors(&settings), + vec!["codex".to_string(), "ellie".to_string()] + ); + } } diff --git a/crates/atuin/src/command/client/search/engines.rs b/crates/atuin/src/command/client/search/engines.rs index 8cbee0c3..98692828 100644 --- a/crates/atuin/src/command/client/search/engines.rs +++ b/crates/atuin/src/command/client/search/engines.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use atuin_client::{ - database::{Context, Database}, + database::{Context, Database, OptFilters}, history::{History, HistoryId}, settings::{FilterMode, SearchMode, Settings}, }; @@ -33,6 +33,30 @@ pub struct SearchState { pub filter_mode: FilterMode, pub context: Context, pub custom_context: Option<HistoryId>, + pub authors: Vec<String>, +} + +async fn search_db( + state: &SearchState, + db: &dyn Database, + mode: SearchMode, + query: &str, +) -> Result<Vec<History>> { + Ok(db + .search( + mode, + state.filter_mode, + &state.context, + query, + OptFilters { + limit: Some(200), + authors: state.authors.clone(), + ..Default::default() + }, + ) + .await? + .into_iter() + .collect()) } impl SearchState { @@ -70,13 +94,17 @@ pub trait SearchEngine: Send + Sync + 'static { db: &mut dyn Database, ) -> Result<Vec<History>>; + async fn empty_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>> { + search_db(state, db, SearchMode::FullText, "").await + } + async fn query(&mut self, state: &SearchState, db: &mut dyn Database) -> Result<Vec<History>> { if state.input.as_str().is_empty() { - Ok(db - .list(&[state.filter_mode], &state.context, Some(200), true, false) - .await? - .into_iter() - .collect::<Vec<_>>()) + self.empty_query(state, db).await } else { self.full_query(state, db).await } diff --git a/crates/atuin/src/command/client/search/engines/daemon.rs b/crates/atuin/src/command/client/search/engines/daemon.rs index 50471898..4fb3e2ea 100644 --- a/crates/atuin/src/command/client/search/engines/daemon.rs +++ b/crates/atuin/src/command/client/search/engines/daemon.rs @@ -1,6 +1,8 @@ +use std::collections::HashMap; + use async_trait::async_trait; use atuin_client::{ - database::{Database, OptFilters}, + database::Database, history::History, settings::{SearchMode, Settings}, }; @@ -13,7 +15,7 @@ use eyre::Result; use tracing::{Level, debug, instrument, span}; use uuid::Uuid; -use super::{SearchEngine, SearchState}; +use super::{SearchEngine, SearchState, search_db}; use crate::command::client::daemon; pub struct Search { @@ -84,20 +86,9 @@ impl Search { state: &SearchState, db: &dyn Database, ) -> Result<Vec<History>> { - let results = db - .search( - SearchMode::FullText, - state.filter_mode, - &state.context, - state.input.as_str(), - OptFilters { - limit: Some(200), - ..Default::default() - }, - ) + search_db(state, db, SearchMode::FullText, state.input.as_str()) .await - .map_or(Vec::new(), |r| r.into_iter().collect()); - Ok(results) + .map_or_else(|_| Ok(Vec::new()), Ok) } #[instrument(skip_all, level = Level::TRACE, name = "hydrate_from_db", fields(count = ids.len()))] @@ -142,6 +133,7 @@ impl SearchEngine for Search { query_id, state.filter_mode, Some(state.context.clone()), + state.authors.clone(), ) .await } @@ -162,6 +154,7 @@ impl SearchEngine for Search { query_id, state.filter_mode, Some(state.context.clone()), + state.authors.clone(), ) .await? } @@ -206,12 +199,14 @@ impl SearchEngine for Search { // // Hydrate from local database let results = self.hydrate_from_db(db, &ids).await?; - // // Reorder results to match the order from the daemon (which is ranked by relevance) + // Reorder results to match the order from the daemon (which is ranked by relevance) let ordered_results = span!(Level::TRACE, "reorder_results").in_scope(|| { + let results_by_id: HashMap<&str, &History> = + results.iter().map(|h| (h.id.0.as_str(), h)).collect(); let mut ordered_results = Vec::with_capacity(results.len()); for id in &ids { - if let Some(history) = results.iter().find(|h| h.id.0 == *id) { - ordered_results.push(history.clone()); + if let Some(history) = results_by_id.get(id.as_str()) { + ordered_results.push((*history).clone()); } } ordered_results diff --git a/crates/atuin/src/command/client/search/engines/db.rs b/crates/atuin/src/command/client/search/engines/db.rs index 476462f5..29daaaa1 100644 --- a/crates/atuin/src/command/client/search/engines/db.rs +++ b/crates/atuin/src/command/client/search/engines/db.rs @@ -1,8 +1,7 @@ -use super::{SearchEngine, SearchState}; +use super::{SearchEngine, SearchState, search_db}; use async_trait::async_trait; use atuin_client::{ database::Database, - database::OptFilters, database::{QueryToken, QueryTokenizer}, history::History, settings::SearchMode, @@ -23,21 +22,10 @@ impl SearchEngine for Search { state: &SearchState, db: &mut dyn Database, ) -> Result<Vec<History>> { - let results = db - .search( - self.0, - state.filter_mode, - &state.context, - state.input.as_str(), - OptFilters { - limit: Some(200), - ..Default::default() - }, - ) + search_db(state, db, self.0, state.input.as_str()) .await // ignore errors as it may be caused by incomplete regex - .map_or(Vec::new(), |r| r.into_iter().collect()); - Ok(results) + .map_or_else(|_| Ok(Vec::new()), Ok) } #[instrument(skip_all, level = Level::TRACE, name = "db_highlight")] @@ -107,3 +95,62 @@ pub fn get_highlight_indices_fulltext(command: &str, search_input: &str) -> Vec< ret.dedup(); ret } + +#[cfg(test)] +mod tests { + use super::*; + use crate::command::client::search::cursor::Cursor; + use atuin_client::{ + database::{Context, Database, Sqlite}, + history::{AUTHOR_FILTER_ALL_USER, History}, + settings::FilterMode, + }; + use time::macros::datetime; + + fn context() -> Context { + Context { + session: uuid::Uuid::now_v7().as_simple().to_string(), + cwd: "/tmp".to_string(), + hostname: "host:user".to_string(), + host_id: String::new(), + git_root: None, + } + } + + #[tokio::test] + async fn empty_query_uses_author_filters() { + let mut db = Sqlite::new(":memory:", 0.1).await.unwrap(); + + let user_history: History = History::import() + .timestamp(datetime!(2024-01-01 10:00 UTC)) + .command("git status") + .cwd("/tmp") + .author("ellie") + .build() + .into(); + let agent_history: History = History::import() + .timestamp(datetime!(2024-01-01 11:00 UTC)) + .command("git diff") + .cwd("/tmp") + .author("codex") + .build() + .into(); + + db.save_bulk(&[user_history.clone(), agent_history]) + .await + .unwrap(); + + let mut engine = Search(SearchMode::Fuzzy); + let state = SearchState { + input: Cursor::from(String::new()), + filter_mode: FilterMode::Global, + context: context(), + custom_context: None, + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + }; + + let results = engine.query(&state, &mut db).await.unwrap(); + + assert_eq!(results, vec![user_history]); + } +} diff --git a/crates/atuin/src/command/client/search/engines/skim.rs b/crates/atuin/src/command/client/search/engines/skim.rs index 7d9feb40..4075f148 100644 --- a/crates/atuin/src/command/client/search/engines/skim.rs +++ b/crates/atuin/src/command/client/search/engines/skim.rs @@ -1,7 +1,11 @@ -use std::path::Path; +use std::{collections::HashMap, path::Path}; use async_trait::async_trait; -use atuin_client::{database::Database, history::History, settings::FilterMode}; +use atuin_client::{ + database::Database, + history::{History, author_matches_filters}, + settings::FilterMode, +}; use eyre::Result; use fuzzy_matcher::{FuzzyMatcher, skim::SkimMatcherV2}; use itertools::Itertools; @@ -53,7 +57,23 @@ impl SearchEngine for Search { #[instrument(skip_all, level = Level::TRACE, name = "load_all_history")] async fn load_all_history(db: &dyn Database) -> Vec<(History, i32)> { - db.all_with_count().await.unwrap() + let histories = db + .query_history("SELECT * FROM history WHERE deleted_at IS NULL ORDER BY timestamp DESC") + .await + .unwrap_or_default(); + + let mut counts = HashMap::new(); + for history in &histories { + *counts.entry(history.command.clone()).or_insert(0) += 1; + } + + histories + .into_iter() + .map(|history| { + let count = counts.get(&history.command).copied().unwrap_or(1); + (history, count) + }) + .collect() } #[allow(clippy::too_many_lines)] @@ -72,6 +92,9 @@ async fn fuzzy_search( if i % 256 == 0 { yield_now().await; } + if !author_matches_filters(&history.author, &state.authors) { + continue; + } let context = &state.context; let git_root = context .git_root @@ -220,3 +243,62 @@ fn path_dist(a: &Path, b: &Path) -> usize { b.len() - a.len() + dist } + +#[cfg(test)] +mod tests { + use super::*; + use crate::command::client::search::cursor::Cursor; + use atuin_client::{ + database::{Context, Database, Sqlite}, + history::{AUTHOR_FILTER_ALL_USER, History}, + settings::FilterMode, + }; + use time::macros::datetime; + + fn context() -> Context { + Context { + session: uuid::Uuid::now_v7().as_simple().to_string(), + cwd: "/tmp".to_string(), + hostname: "host:user".to_string(), + host_id: String::new(), + git_root: None, + } + } + + #[tokio::test] + async fn skim_search_uses_author_filters() { + let mut db = Sqlite::new(":memory:", 0.1).await.unwrap(); + + let user_history: History = History::import() + .timestamp(datetime!(2024-01-01 10:00 UTC)) + .command("git status") + .cwd("/tmp") + .author("ellie") + .build() + .into(); + let agent_history: History = History::import() + .timestamp(datetime!(2024-01-01 11:00 UTC)) + .command("git stash") + .cwd("/tmp") + .author("codex") + .build() + .into(); + + db.save_bulk(&[user_history.clone(), agent_history]) + .await + .unwrap(); + + let mut engine = Search::new(); + let state = SearchState { + input: Cursor::from("git st".to_owned()), + filter_mode: FilterMode::Global, + context: context(), + custom_context: None, + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + }; + + let results = engine.query(&state, &mut db).await.unwrap(); + + assert_eq!(results, vec![user_history]); + } +} diff --git a/crates/atuin/src/command/client/search/interactive.rs b/crates/atuin/src/command/client/search/interactive.rs index ee38ddaa..f572ed7d 100644 --- a/crates/atuin/src/command/client/search/interactive.rs +++ b/crates/atuin/src/command/client/search/interactive.rs @@ -1600,6 +1600,7 @@ fn compute_popup_placement( )] pub async fn history( query: &[String], + authors: Vec<String>, settings: &Settings, mut db: impl Database, history_store: &HistoryStore, @@ -1768,6 +1769,7 @@ pub async fn history( filter_mode: default_filter_mode, context: initial_context.clone(), custom_context: None, + authors, }, engine: engines::engine(search_mode, settings), results_len: 0, @@ -2259,6 +2261,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2314,6 +2317,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2433,6 +2437,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2492,6 +2497,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2547,6 +2553,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2598,6 +2605,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2658,6 +2666,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -2719,6 +2728,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), @@ -3098,6 +3108,7 @@ mod tests { git_root: None, }, custom_context: None, + authors: vec![], }, engine: engines::engine(SearchMode::Fuzzy, &settings), now: Box::new(OffsetDateTime::now_utc), diff --git a/docs/docs/configuration/config.md b/docs/docs/configuration/config.md index f453f903..3d4f6b25 100644 --- a/docs/docs/configuration/config.md +++ b/docs/docs/configuration/config.md @@ -576,6 +576,33 @@ frequency_score_multiplier = 0.8 frecency_score_multiplier = 2.0 ``` +#### `authors` + +Default: `["$all-user"]` + +Filter search results by command author. This controls which commands appear in interactive search based on who (or what) ran them. Useful when AI coding agents are recording commands via [agent hooks](../guide/agent-hooks.md). + +Special values: + +| Value | Meaning | +|-------|---------| +| `$all-user` | Commands from any author that is **not** a known AI agent | +| `$all-agent` | Commands from any known AI agent | + +You can also use literal author names like `"claude-code"` or `"codex"`. + +```toml +[search] +# Default: only show human-authored commands +authors = ["$all-user"] + +# Show everything (no author filtering) +# authors = [] + +# Show commands from you and Claude Code +# authors = ["$all-user", "claude-code"] +``` + ## Stats This section of client config is specifically for configuring Atuin stats calculations diff --git a/docs/docs/guide/agent-hooks.md b/docs/docs/guide/agent-hooks.md new file mode 100644 index 00000000..6e02b794 --- /dev/null +++ b/docs/docs/guide/agent-hooks.md @@ -0,0 +1,126 @@ +# AI Agent Hooks + +Atuin can capture commands run by AI coding agents (like Claude Code and Codex) alongside your regular shell history. Each command is tagged with the agent that ran it, so you can filter your history by author. + +## Quick Start + +Install hooks for your agent, then restart the agent: + +```shell +# Claude Code +atuin hook install claude-code + +# Codex +atuin hook install codex +``` + +That's it. Commands the agent runs will now appear in your Atuin history, tagged with the agent's name. + +## How It Works + +AI coding agents support hook systems that notify external tools when they're about to run a shell command and when the command finishes. Atuin uses these hooks to record each command as a history entry, just like commands you type yourself. + +When `atuin hook install` runs, it writes the agent's config file to register Atuin as a hook handler: + +| Agent | Config file | +|-------|-------------| +| Claude Code | `~/.claude/settings.json` | +| Codex | `~/.codex/hooks.json` | + +The hook lifecycle: + +1. **PreToolUse** -- the agent is about to run a Bash command. Atuin records the command, working directory, and timestamp (same as `history start`). +2. **PostToolUse / PostToolUseFailure** -- the command finished. Atuin records the exit code and duration (same as `history end`). + +Only `Bash` tool invocations are captured. Other tool types (file writes, web fetches, etc.) are ignored. + +## Filtering by Author + +By default, Atuin's interactive search shows only your own commands. Agent-run commands are hidden so they don't clutter your history. + +This is controlled by the `search.authors` setting in `~/.config/atuin/config.toml`: + +```toml +[search] +# Default: only show commands from human users +authors = ["$all-user"] +``` + +### Special filter values + +| Value | Meaning | +|-------|---------| +| `$all-user` | Any author that is **not** a known AI agent | +| `$all-agent` | Any known AI agent author | + +You can also use literal author names: + +```toml +[search] +# Show only your own commands and Claude Code commands +authors = ["$all-user", "claude-code"] +``` + +```toml +[search] +# Show everything (no filtering) +authors = [] +``` + +```toml +[search] +# Show only agent commands +authors = ["$all-agent"] +``` + +Currently recognized agent names are: `claude-code`, `codex`, and `copilot`. + +## Supported Agents + +### Claude Code + +```shell +atuin hook install claude-code +``` + +This adds hook entries to `~/.claude/settings.json`. Claude Code calls `atuin hook claude-code` on each `Bash` tool use, passing the event as JSON on stdin. + +### Codex + +```shell +atuin hook install codex +``` + +This adds hook entries to `~/.codex/hooks.json`. Codex calls `atuin hook codex` on each Bash tool use matching `^Bash$`. + +## Verifying Installation + +After installing hooks and restarting your agent, run a command through the agent and then check your history: + +```shell +# Show all history including agent commands +atuin search --authors '' -- '' + +# Show only agent commands +atuin search --authors '$all-agent' -- '' +``` + +You can also check the agent's config file directly to confirm the hooks are registered: + +```shell +# Claude Code +cat ~/.claude/settings.json | grep atuin + +# Codex +cat ~/.codex/hooks.json | grep atuin +``` + +## Re-installing + +Running `atuin hook install` again is safe. If hooks are already installed, the command will skip them and print a message: + +``` +hooks.PreToolUse: already installed, skipping +hooks.PostToolUse: already installed, skipping +hooks.PostToolUseFailure: already installed, skipping +``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 00adcfb2..a25dfa8f 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -70,6 +70,7 @@ nav: - Basic usage: guide/basic-usage.md - Advanced usage: guide/advanced-usage.md - Shell Integration: guide/shell-integration.md + - AI Agent Hooks: guide/agent-hooks.md - Deleting history: guide/delete-history.md - Syncing dotfiles: guide/dotfiles.md - Theming: guide/theming.md |
