From 09279a428659cf41824737d3e0c97bcc19a8885a Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Fri, 10 Apr 2026 13:24:57 -0700 Subject: feat: Client-tool execution + permission system (#3370) Adds client-side tool execution to Atuin AI, starting with `atuin_history`. The server can request tool calls, which are executed locally with a permission system, and results are sent back to continue the conversation. --- crates/atuin-ai/Cargo.toml | 18 +- crates/atuin-ai/src/commands.rs | 4 +- crates/atuin-ai/src/commands/init.rs | 2 +- crates/atuin-ai/src/commands/inline.rs | 673 ++++-------- crates/atuin-ai/src/context.rs | 73 ++ crates/atuin-ai/src/lib.rs | 6 +- crates/atuin-ai/src/permissions/check.rs | 74 ++ crates/atuin-ai/src/permissions/file.rs | 26 + crates/atuin-ai/src/permissions/mod.rs | 7 + crates/atuin-ai/src/permissions/resolver.rs | 31 + crates/atuin-ai/src/permissions/rule.rs | 106 ++ crates/atuin-ai/src/permissions/shell.rs | 1297 ++++++++++++++++++++++++ crates/atuin-ai/src/permissions/walker.rs | 121 +++ crates/atuin-ai/src/permissions/writer.rs | 198 ++++ crates/atuin-ai/src/stream.rs | 372 +++++++ crates/atuin-ai/src/tools/descriptor.rs | 98 ++ crates/atuin-ai/src/tools/mod.rs | 1111 ++++++++++++++++++++ crates/atuin-ai/src/tui/components/atuin_ai.rs | 16 +- crates/atuin-ai/src/tui/components/markdown.rs | 47 +- crates/atuin-ai/src/tui/components/mod.rs | 7 +- crates/atuin-ai/src/tui/components/select.rs | 96 ++ crates/atuin-ai/src/tui/dispatch.rs | 571 +++++++++++ crates/atuin-ai/src/tui/events.rs | 19 +- crates/atuin-ai/src/tui/mod.rs | 11 +- crates/atuin-ai/src/tui/state.rs | 600 ++++++----- crates/atuin-ai/src/tui/view/mod.rs | 227 ++++- crates/atuin-ai/src/tui/view/turn.rs | 50 +- crates/atuin-client/Cargo.toml | 2 +- crates/atuin-client/src/settings.rs | 10 + crates/atuin-hex/Cargo.toml | 2 +- crates/atuin-hex/src/lib.rs | 2 +- crates/atuin/Cargo.toml | 8 +- crates/atuin/src/command/client.rs | 18 +- 33 files changed, 5045 insertions(+), 858 deletions(-) create mode 100644 crates/atuin-ai/src/context.rs create mode 100644 crates/atuin-ai/src/permissions/check.rs create mode 100644 crates/atuin-ai/src/permissions/file.rs create mode 100644 crates/atuin-ai/src/permissions/mod.rs create mode 100644 crates/atuin-ai/src/permissions/resolver.rs create mode 100644 crates/atuin-ai/src/permissions/rule.rs create mode 100644 crates/atuin-ai/src/permissions/shell.rs create mode 100644 crates/atuin-ai/src/permissions/walker.rs create mode 100644 crates/atuin-ai/src/permissions/writer.rs create mode 100644 crates/atuin-ai/src/stream.rs create mode 100644 crates/atuin-ai/src/tools/descriptor.rs create mode 100644 crates/atuin-ai/src/tools/mod.rs create mode 100644 crates/atuin-ai/src/tui/components/select.rs create mode 100644 crates/atuin-ai/src/tui/dispatch.rs (limited to 'crates') diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index 6e7315cd..c5f66695 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -12,6 +12,10 @@ repository = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = [] +tree-sitter = ["dep:tree-sitter-lib", "dep:tree-sitter-bash", "dep:tree-sitter-fish"] + [dependencies] atuin-client = { workspace = true } atuin-common = { workspace = true } @@ -39,9 +43,21 @@ async-stream = "0.3" uuid = { workspace = true } tui-textarea-2 = "0.10.2" unicode-width = "0.2" -eye_declare = "0.3" +eye_declare = "0.4" ratatui-core = "0.1" ratatui-widgets = "0.3" +thiserror = { workspace = true } +glob-match = { workspace = true } +regex = { workspace = true } +time = { workspace = true } +toml = "1.1" +toml_edit = { workspace = true } +tree-sitter-lib = { package = "tree-sitter", version = "0.26.8", optional = true } +tree-sitter-bash = { version = "0.25.1", optional = true } +tree-sitter-fish = { version = "3.6.0", optional = true } +typed-builder = { workspace = true } +vt100 = { workspace = true } [dev-dependencies] pretty_assertions = { workspace = true } +tempfile = { workspace = true } diff --git a/crates/atuin-ai/src/commands.rs b/crates/atuin-ai/src/commands.rs index 6e79da61..cdbc8f2d 100644 --- a/crates/atuin-ai/src/commands.rs +++ b/crates/atuin-ai/src/commands.rs @@ -9,7 +9,7 @@ use eyre::Result; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{EnvFilter, Layer, fmt, layer::SubscriberExt, util::SubscriberInitExt}; pub mod init; -pub mod inline; +pub(crate) mod inline; #[derive(Args, Debug)] pub struct AiArgs { @@ -71,7 +71,7 @@ pub async fn run( } } -pub fn detect_shell() -> Option { +pub(crate) fn detect_shell() -> Option { Some(Shell::current().to_string()) } diff --git a/crates/atuin-ai/src/commands/init.rs b/crates/atuin-ai/src/commands/init.rs index 77abc4f4..f693d892 100644 --- a/crates/atuin-ai/src/commands/init.rs +++ b/crates/atuin-ai/src/commands/init.rs @@ -1,6 +1,6 @@ use crate::commands::detect_shell; -pub async fn run(shell: String) -> eyre::Result<()> { +pub(crate) async fn run(shell: String) -> eyre::Result<()> { let integration = match shell.as_str() { "zsh" => generate_zsh_integration(), "bash" => generate_bash_integration(), diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index aeb414fb..b37bb72f 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -1,38 +1,44 @@ use std::path::PathBuf; use std::sync::mpsc; -use crate::commands::detect_shell; +use crate::context::{AppContext, ClientContext}; +use crate::tui::dispatch; use crate::tui::events::AiTuiEvent; -use crate::tui::state::{AppState, ExitAction}; +use crate::tui::state::{ExitAction, Session}; use crate::tui::view::ai_view; use atuin_client::database::{Database, Sqlite}; -use atuin_client::distro::detect_linux_distribution; -use atuin_common::tls::ensure_crypto_provider; -use eventsource_stream::Eventsource; -use eye_declare::{Application, CtrlCBehavior, Handle}; +use eye_declare::{Application, CtrlCBehavior}; use eyre::{Context as _, Result, bail}; -use futures::StreamExt; -use reqwest::Url; -use tracing::{debug, error, info, trace}; +use tracing::{debug, info}; -pub async fn run( +pub(crate) async fn run( initial_command: Option, api_endpoint: Option, api_token: Option, settings: &atuin_client::settings::Settings, output_for_hook: bool, ) -> Result<()> { - if !settings.ai.enabled.unwrap_or(false) { - emit_shell_result( - Action::Print( - "Atuin AI is not enabled. Please enable it in your settings or run `atuin setup`." - .to_string(), - ), - output_for_hook, - ); + if settings.ai.enabled == Some(false) { return Ok(()); } + if settings.ai.enabled.is_none() { + match prompt_ai_setup()? { + SetupChoice::EnableAi => { + set_ai_enabled(true).await?; + } + SetupChoice::DisableKeybind => { + set_ai_enabled(false).await?; + emit_shell_result(Action::Cancel, output_for_hook); + return Ok(()); + } + SetupChoice::Cancel => { + emit_shell_result(Action::Cancel, output_for_hook); + return Ok(()); + } + } + } + let endpoint = api_endpoint.as_deref().unwrap_or( settings .ai @@ -48,7 +54,36 @@ pub async fn run( ensure_hub_session(settings).await? }; - let action = run_inline_tui(endpoint.to_string(), token, initial_command, settings).await?; + let history_db_path = PathBuf::from(settings.db_path.as_str()); + let history_db = Sqlite::new(history_db_path, settings.local_timeout) + .await + .context("failed to open history database for AI")?; + + // Support both legacy [ai] send_cwd and new [ai.opening] send_cwd + let send_cwd = + settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); + + let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { + history_db.last().await.ok().flatten().map(|h| h.command) + } else { + None + }; + + let git_root = std::env::current_dir() + .ok() + .and_then(|cwd| atuin_common::utils::in_git_repo(cwd.to_str()?)); + + let ctx = AppContext { + endpoint: endpoint.to_string(), + token, + send_cwd, + last_command, + history_db: std::sync::Arc::new(history_db), + git_root, + capabilities: settings.ai.capabilities.clone(), + }; + + let action = run_inline_tui(ctx, initial_command).await?; emit_shell_result(action, output_for_hook); Ok(()) @@ -69,7 +104,7 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu if will_sync { println!( "Once logged in, your shell history will be synchronized via Atuin Hub if auto_sync is enabled or when manually syncing." - ) + ); } println!( "If you have an existing Atuin sync account, you can log in with your existing credentials." @@ -110,280 +145,17 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu Ok(token) } -// ─────────────────────────────────────────────────────────────────── -// SSE streaming -// ─────────────────────────────────────────────────────────────────── - -#[derive(Debug, Clone)] -enum ChatStreamEvent { - TextChunk(String), - ToolCall { - id: String, - name: String, - input: serde_json::Value, - }, - ToolResult { - tool_use_id: String, - content: String, - is_error: bool, - }, - Status(String), - Done { - session_id: String, - }, - Error(String), -} - -fn create_chat_stream( - hub_address: String, - token: String, - session_id: Option, - messages: Vec, - send_cwd: bool, - last_command: Option, -) -> std::pin::Pin> + Send>> { - Box::pin(async_stream::stream! { - ensure_crypto_provider(); - let endpoint = match hub_url(&hub_address, "/api/cli/chat") { - Ok(url) => url, - Err(e) => { - yield Err(e); - return; - } - }; - - debug!("Sending SSE request to {endpoint}"); - - let os = detect_os(); - let shell = detect_shell(); - - let mut context = serde_json::json!({ - "os": os, - "shell": shell, - "pwd": if send_cwd { std::env::current_dir() - .ok() - .map(|path| path.to_string_lossy().into_owned()) } else { None }, - "last_command": last_command, - }); - - if os == "linux" { - context["distro"] = serde_json::json!(detect_linux_distribution()); - } - - let mut request_body = serde_json::json!({ - "messages": messages, - "context": context, - }); - - if let Some(ref sid) = session_id { - trace!("Including session_id in request: {sid}"); - request_body["session_id"] = serde_json::json!(sid); - } - - let client = reqwest::Client::new(); - let response = match client - .post(endpoint.clone()) - .header("Accept", "text/event-stream") - .bearer_auth(&token) - .json(&request_body) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); - return; - } - }; - - let status = response.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - error!("SSE request failed with status: {status}, clearing session"); - let _ = atuin_client::hub::delete_session().await; - yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); - return; - } - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - error!("SSE request failed ({}): {}", status, body); - yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); - return; - } - - let byte_stream = response.bytes_stream(); - let mut stream = byte_stream.eventsource(); - - while let Some(event) = stream.next().await { - match event { - Ok(sse_event) => { - let event_type = sse_event.event.as_str(); - let data = sse_event.data.clone(); - - debug!(event_type = %event_type, "SSE event received"); - - match event_type { - "text" => { - if let Ok(json) = serde_json::from_str::(&data) - && let Some(content) = json.get("content").and_then(|v| v.as_str()) - { - yield Ok(ChatStreamEvent::TextChunk(content.to_string())); - } - } - "tool_call" => { - if let Ok(json) = serde_json::from_str::(&data) { - let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); - yield Ok(ChatStreamEvent::ToolCall { id, name, input }); - } - } - "tool_result" => { - if let Ok(json) = serde_json::from_str::(&data) { - let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); - yield Ok(ChatStreamEvent::ToolResult { tool_use_id, content, is_error }); - } - } - "status" => { - if let Ok(json) = serde_json::from_str::(&data) - && let Some(state) = json.get("state").and_then(|v| v.as_str()) - { - yield Ok(ChatStreamEvent::Status(state.to_string())); - } - } - "done" => { - if let Ok(json) = serde_json::from_str::(&data) { - let session_id = json.get("session_id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - yield Ok(ChatStreamEvent::Done { session_id }); - } else { - yield Ok(ChatStreamEvent::Done { session_id: String::new() }); - } - break; - } - "error" => { - if let Ok(json) = serde_json::from_str::(&data) { - let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); - error!("SSE error: {}", message); - yield Ok(ChatStreamEvent::Error(message)); - } else { - error!("SSE error: {}", data); - yield Ok(ChatStreamEvent::Error(data)); - } - break; - } - _ => {} - } - } - Err(e) => { - yield Err(eyre::eyre!("SSE error: {}", e)); - break; - } - } - } - }) -} - -// ─────────────────────────────────────────────────────────────────── -// Async streaming task — pushes updates to app state via Handle // ─────────────────────────────────────────────────────────────────── -async fn run_chat_stream( - handle: Handle, - endpoint: String, - token: String, - session_id: Option, - messages: Vec, - send_cwd: bool, - last_command: Option, -) { - let stream = create_chat_stream( - endpoint, - token, - session_id, - messages, - send_cwd, - last_command, - ); - futures::pin_mut!(stream); - - while let Some(event) = stream.next().await { - match event { - Ok(ChatStreamEvent::TextChunk(text)) => { - trace!(text = %text, "Processing TextChunk"); - handle.update(move |state| { - state.append_streaming_text(&text); - }); - } - Ok(ChatStreamEvent::ToolCall { id, name, input }) => { - trace!(id = %id, name = %name, "Processing ToolCall"); - handle.update(move |state| { - state.add_tool_call(id, name, input); - }); - } - Ok(ChatStreamEvent::ToolResult { - tool_use_id, - content, - is_error, - }) => { - trace!(tool_use_id = %tool_use_id, "Processing ToolResult"); - handle.update(move |state| { - state.add_tool_result(tool_use_id, content, is_error); - }); - } - Ok(ChatStreamEvent::Status(status)) => { - trace!(status = %status, "Processing Status"); - handle.update(move |state| { - state.update_streaming_status(&status); - }); - } - Ok(ChatStreamEvent::Done { session_id }) => { - trace!(session_id = %session_id, "Processing Done"); - handle.update(move |state| { - if !session_id.is_empty() { - state.store_session_id(session_id); - } - state.finalize_streaming(); - }); - break; - } - Ok(ChatStreamEvent::Error(msg)) => { - trace!(error = %msg, "Processing Error"); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - Err(e) => { - let msg = e.to_string(); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - } - } -} +async fn run_inline_tui(ctx: AppContext, initial_prompt: Option) -> Result { + let client_ctx = ClientContext::detect(); -// ─────────────────────────────────────────────────────────────────── -// Main TUI entry point -// ─────────────────────────────────────────────────────────────────── + let (tx, rx) = mpsc::channel::(); -async fn run_inline_tui( - endpoint: String, - token: String, - initial_prompt: Option, - settings: &atuin_client::settings::Settings, -) -> Result { - let initial_state = AppState::new(); + let initial_state = Session::new(ctx.git_root.is_some()); println!(); - let (tx, rx) = mpsc::channel::(); - // If there's an initial prompt, send it as a SubmitInput event // so it flows through the same path as user-typed input. if let Some(prompt) = initial_prompt { @@ -396,164 +168,17 @@ async fn run_inline_tui( .ctrl_c(CtrlCBehavior::Deliver) .keyboard_protocol(eye_declare::KeyboardProtocol::Enhanced) .bracketed_paste(true) - .with_context(tx) + .with_context(tx.clone()) .extra_newlines_at_exit(1) .build()?; - // Support both legacy [ai] send_cwd and new [ai.opening] send_cwd - let send_cwd = - settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); - - let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { - let db_path = PathBuf::from(settings.db_path.as_str()); - match Sqlite::new(db_path, settings.local_timeout).await { - Ok(db) => db.last().await.ok().flatten().map(|h| h.command), - Err(e) => { - debug!("Failed to open history database for read_history: {e}"); - None - } - } - } else { - None - }; - // Event loop: receives AiTuiEvent from components, mutates state via Handle. let h = handle.clone(); - let ep = endpoint.clone(); - let tk = token.clone(); tokio::task::spawn_blocking(move || { + let tx = tx.clone(); + let client_ctx = client_ctx; while let Ok(event) = rx.recv() { - match event { - AiTuiEvent::InputUpdated(input) => { - let input_blank = input.trim().is_empty(); - - h.update(move |state| { - state.is_input_blank = input_blank; - }); - } - AiTuiEvent::SubmitInput(input) => { - let input = input.trim().to_string(); - if input.is_empty() { - let h2 = h.clone(); - h.update(move |state| { - if state.has_any_command() { - state.exit_action = Some(ExitAction::Execute( - state.current_command().unwrap().to_string(), - )); - } else { - state.exit_action = Some(ExitAction::Cancel); - } - h2.exit(); - }); - continue; - } - - if input.starts_with('/') { - let input_clone = input.clone(); - h.update(move |state| { - state.handle_slash_command(&input_clone); - }); - continue; - } - - // Start generation and spawn streaming task - let ep = ep.clone(); - let tk = tk.clone(); - let h2 = h.clone(); - let lc = last_command.clone(); - h.update(move |state| { - state.start_generating(input); - state.start_streaming(); - state.is_input_blank = true; - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd, lc).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::SlashCommand(command) => { - h.update(move |state| { - state.handle_slash_command(&command); - }); - } - - AiTuiEvent::CancelGeneration => { - h.update(|state| match state.mode { - crate::tui::state::AppMode::Generating => { - state.cancel_generation(); - } - crate::tui::state::AppMode::Streaming => { - state.cancel_streaming(); - } - _ => {} - }); - } - - AiTuiEvent::ExecuteCommand => { - let h2 = h.clone(); - h.update(move |state| { - let cmd = state.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - if state.is_current_command_dangerous() && !state.confirmation_pending { - state.confirmation_pending = true; - } else { - state.confirmation_pending = false; - state.exit_action = Some(ExitAction::Execute(cmd)); - h2.exit(); - } - } - }); - } - - AiTuiEvent::CancelConfirmation => { - h.update(move |state| { - state.confirmation_pending = false; - }); - } - - AiTuiEvent::InsertCommand => { - let h2 = h.clone(); - h.update(move |state| { - let cmd = state.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - state.confirmation_pending = false; - state.exit_action = Some(ExitAction::Insert(cmd)); - h2.exit(); - } - }); - } - - AiTuiEvent::Retry => { - let ep = ep.clone(); - let tk = tk.clone(); - let h2 = h.clone(); - let lc = last_command.clone(); - h.update(move |state| { - state.retry(); - state.start_streaming(); - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd, lc).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::Exit => { - let h2 = h.clone(); - h.update(move |state| { - if let Some(abort) = state.stream_abort.take() { - abort.abort(); - } - state.exit_action = Some(ExitAction::Cancel); - h2.exit(); - }); - } - } + dispatch::dispatch(&h, event, &tx, &ctx, &client_ctx); } }); @@ -573,51 +198,125 @@ async fn run_inline_tui( // Helpers // ─────────────────────────────────────────────────────────────────── -fn hub_url(base: &str, path: &str) -> Result { - let base_with_slash = if base.ends_with('/') { - base.to_string() - } else { - format!("{base}/") - }; - let stripped = path.strip_prefix('/').unwrap_or(path); - Url::parse(&base_with_slash)? - .join(stripped) - .context("failed to build hub URL") +enum SetupChoice { + EnableAi, + DisableKeybind, + Cancel, } -fn detect_os() -> String { - match std::env::consts::OS { - "macos" => "macos".to_string(), - "linux" => "linux".to_string(), - "windows" => "windows".to_string(), - other => format!("Other: {other}"), +fn prompt_ai_setup() -> Result { + use crossterm::{ + cursor, + event::{self, Event, KeyCode}, + terminal, + }; + + let options = ["Enable Atuin AI", "Disable ? Keybind", "Cancel"]; + let mut selected: usize = 0; + let mut stdout = std::io::stdout(); + + // Print header before raw mode so newlines render correctly. + // Use stdout because the shell hook swaps stdout/stderr — stdout goes + // to the terminal in both hook and non-hook modes. + println!(); + println!(" Atuin AI is not yet configured."); + println!(); + + terminal::enable_raw_mode().context("failed to enable raw mode")?; + struct Guard; + impl Drop for Guard { + fn drop(&mut self) { + let _ = terminal::disable_raw_mode(); + } } -} + let _guard = Guard; -#[derive(Clone)] -enum Action { - Execute(String), - Insert(String), - Print(String), - Cancel, -} + crossterm::execute!(stdout, cursor::Hide)?; -fn emit_shell_result(action: Action, output_for_hook: bool) { - if output_for_hook { - match action { - Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"), - Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"), - Action::Print(output) => eprintln!("__atuin_ai_print__:{output}"), - Action::Cancel => eprintln!("__atuin_ai_cancel__"), + loop { + render_setup_options(&mut stdout, &options, selected)?; + + let ev = event::read().context("failed to read key event")?; + + crossterm::execute!(stdout, cursor::MoveUp(options.len() as u16))?; + + if let Event::Key(key) = ev { + match key.code { + KeyCode::Up | KeyCode::Char('k') => { + selected = selected.saturating_sub(1); + } + KeyCode::Down | KeyCode::Char('j') => { + if selected < options.len() - 1 { + selected += 1; + } + } + KeyCode::Enter => break, + KeyCode::Esc => { + selected = 2; + break; + } + _ => {} + } } - } else { - match action { - Action::Execute(output) => eprintln!("{output}"), - Action::Insert(output) => eprintln!("{output}"), - Action::Print(output) => eprintln!("{output}"), - Action::Cancel => eprintln!(), + } + + // Final render with selection visible + render_setup_options(&mut stdout, &options, selected)?; + crossterm::execute!(stdout, cursor::Show)?; + + Ok(match selected { + 0 => SetupChoice::EnableAi, + 1 => SetupChoice::DisableKeybind, + _ => SetupChoice::Cancel, + }) +} + +fn render_setup_options( + w: &mut impl std::io::Write, + options: &[&str], + selected: usize, +) -> Result<()> { + use crossterm::{ + style::Stylize, + terminal::{Clear, ClearType}, + }; + + for (i, option) in options.iter().enumerate() { + if i == selected { + write!(w, "\r {}", format!("> {option}").bold().cyan())?; + } else { + write!(w, "\r {option}")?; } + crossterm::execute!(w, Clear(ClearType::UntilNewLine))?; + write!(w, "\r\n")?; } + w.flush()?; + Ok(()) +} + +async fn set_ai_enabled(enabled: bool) -> Result<()> { + let config_file = atuin_client::settings::Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let mut doc = config_str.parse::()?; + + if !doc.contains_key("ai") { + doc["ai"] = toml_edit::table(); + } + doc["ai"]["enabled"] = toml_edit::value(enabled); + + tokio::fs::write(&config_file, doc.to_string()).await?; + + if !enabled { + println!( + "Atuin AI keybind disabled. You can re-enable with `atuin config set ai.enabled true`.", + ); + println!("Restart your shell for changes to take effect."); + // Two printlns to ensure the message is visible above the shell prompt after program ends. + println!(); + println!(); + } + + Ok(()) } fn wait_for_login_confirmation() -> Result { @@ -646,3 +345,27 @@ fn wait_for_login_confirmation() -> Result { } } } + +#[derive(Clone)] +enum Action { + Execute(String), + Insert(String), + Cancel, +} + +fn emit_shell_result(action: Action, output_for_hook: bool) { + if output_for_hook { + match action { + Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"), + Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"), + Action::Cancel => eprintln!("__atuin_ai_cancel__"), + } + } else { + match action { + Action::Execute(output) | Action::Insert(output) => { + println!("{output}"); + } + Action::Cancel => {} + } + } +} diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs new file mode 100644 index 00000000..dabb5c5e --- /dev/null +++ b/crates/atuin-ai/src/context.rs @@ -0,0 +1,73 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use atuin_client::distro::detect_linux_distribution; +use atuin_client::settings::AiCapabilities; + +/// Session-scoped context for the AI chat session. +/// Holds the API configuration and client settings needed by the event loop and stream task. +#[derive(Clone, Debug)] +pub(crate) struct AppContext { + pub endpoint: String, + pub token: String, + pub send_cwd: bool, + pub last_command: Option, + pub history_db: Arc, + /// Git root of the current working directory, if inside a git repo. + /// Resolves through worktrees to the main repo root. + pub git_root: Option, + pub capabilities: AiCapabilities, +} + +/// Machine identity — computed once per session. +#[derive(Clone, Debug)] +pub(crate) struct ClientContext { + pub os: String, + pub shell: Option, + pub distro: Option, +} + +impl ClientContext { + pub(crate) fn detect() -> Self { + let os = detect_os(); + let shell = crate::commands::detect_shell(); + let distro = if os == "linux" { + Some(detect_linux_distribution()) + } else { + None + }; + Self { os, shell, distro } + } + + /// Serialize to the JSON format the API expects for the "context" field. + /// The `pwd` field is always dynamic (current working directory), so it's + /// computed fresh on each call if `send_cwd` is true. + pub(crate) fn to_json(&self, send_cwd: bool, last_command: Option<&str>) -> serde_json::Value { + let mut ctx = serde_json::json!({ + "os": self.os, + "shell": self.shell, + "pwd": if send_cwd { + std::env::current_dir().ok().map(|p| p.to_string_lossy().into_owned()) + } else { + None + }, + "last_command": last_command, + }); + + if let Some(ref distro) = self.distro { + ctx["distro"] = serde_json::json!(distro); + } + + ctx + } +} + +/// Move the `detect_os` function here since it's about client identity. +fn detect_os() -> String { + match std::env::consts::OS { + "macos" => "macos".to_string(), + "linux" => "linux".to_string(), + "windows" => "windows".to_string(), + other => format!("Other: {other}"), + } +} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index 2d86271d..6f431179 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -1,2 +1,6 @@ pub mod commands; -pub mod tui; +pub(crate) mod context; +pub(crate) mod permissions; +pub(crate) mod stream; +pub(crate) mod tools; +pub(crate) mod tui; diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs new file mode 100644 index 00000000..6b908b93 --- /dev/null +++ b/crates/atuin-ai/src/permissions/check.rs @@ -0,0 +1,74 @@ +use eyre::Result; + +use crate::{permissions::file::RuleFile, tools::PermissableToolCall}; + +pub(crate) struct PermissionRequest<'t> { + call: &'t (dyn PermissableToolCall + Send + Sync), +} + +impl<'t> PermissionRequest<'t> { + pub fn new(call: &'t (dyn PermissableToolCall + Send + Sync)) -> Self { + Self { call } + } +} + +pub(crate) enum PermissionResponse { + Allowed, + Denied, + Ask, +} + +pub(crate) struct PermissionChecker { + files: Vec, +} + +impl PermissionChecker { + pub fn new(files: Vec) -> Self { + Self { files } + } + + pub async fn check<'t>( + &self, + request: &'t PermissionRequest<'t>, + ) -> Result { + // Files are in order from deepest to shallowest, so we can stop at the first match. + // Within a file, the priority is ask -> deny -> allow + // The first rule type that matches is the one that applies, even if a later rule would contradict it. + for file in &self.files { + for rule in &file.content.permissions.ask { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'ASK' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Ask); + } + } + + for rule in &file.content.permissions.deny { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'DENY' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Denied); + } + } + + for rule in &file.content.permissions.allow { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'ALLOW' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Allowed); + } + } + } + + Ok(PermissionResponse::Ask) + } +} diff --git a/crates/atuin-ai/src/permissions/file.rs b/crates/atuin-ai/src/permissions/file.rs new file mode 100644 index 00000000..c973f55b --- /dev/null +++ b/crates/atuin-ai/src/permissions/file.rs @@ -0,0 +1,26 @@ +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; + +use crate::permissions::rule::Rule; + +#[derive(Debug, Clone)] +pub(crate) struct RuleFile { + pub path: PathBuf, + pub content: RuleFileContent, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct RuleFileContent { + pub permissions: RuleFilePermissions, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct RuleFilePermissions { + #[serde(default)] + pub allow: Vec, + #[serde(default)] + pub deny: Vec, + #[serde(default)] + pub ask: Vec, +} diff --git a/crates/atuin-ai/src/permissions/mod.rs b/crates/atuin-ai/src/permissions/mod.rs new file mode 100644 index 00000000..fce64a51 --- /dev/null +++ b/crates/atuin-ai/src/permissions/mod.rs @@ -0,0 +1,7 @@ +pub(crate) mod check; +pub(crate) mod file; +pub(crate) mod resolver; +pub(crate) mod rule; +pub(crate) mod shell; +pub(crate) mod walker; +pub(crate) mod writer; diff --git a/crates/atuin-ai/src/permissions/resolver.rs b/crates/atuin-ai/src/permissions/resolver.rs new file mode 100644 index 00000000..dc4f83bf --- /dev/null +++ b/crates/atuin-ai/src/permissions/resolver.rs @@ -0,0 +1,31 @@ +use std::path::PathBuf; + +use eyre::Result; + +use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; +use crate::permissions::walker::PermissionWalker; +use crate::permissions::writer; +use crate::tools::ClientToolCall; + +/// Resolves permissions for client tool calls by walking the filesystem to find permission files, +pub(crate) struct PermissionResolver { + checker: PermissionChecker, +} + +impl PermissionResolver { + /// Create a new resolver that walks from `working_dir` to root for project + /// permissions, and also checks the global permissions file. + pub async fn new(working_dir: PathBuf) -> Result { + let global_file = writer::global_permissions_path(); + let mut walker = PermissionWalker::new(working_dir, Some(global_file)); + walker.walk().await?; + let checker = PermissionChecker::new(walker.rules().to_owned()); + Ok(Self { checker }) + } + + /// Check whether `tool` is allowed, denied, or needs user confirmation. + pub async fn check(&self, tool: &ClientToolCall) -> Result { + let request = PermissionRequest::new(tool); + self.checker.check(&request).await + } +} diff --git a/crates/atuin-ai/src/permissions/rule.rs b/crates/atuin-ai/src/permissions/rule.rs new file mode 100644 index 00000000..8fa3fa4a --- /dev/null +++ b/crates/atuin-ai/src/permissions/rule.rs @@ -0,0 +1,106 @@ +use std::sync::OnceLock; + +use regex::Regex; +use serde::{Deserialize, Serialize}; + +static RULE_RE: OnceLock = OnceLock::new(); + +#[derive(Debug, thiserror::Error)] +pub(crate) enum RuleError { + #[error("invalid rule format: {0}")] + InvalidRule(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct Rule { + pub tool: String, + pub scope: Option, +} + +impl std::fmt::Display for Rule { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.scope.as_ref() { + Some(scope) => write!(f, "{}({})", self.tool, scope), + None => write!(f, "{}", self.tool), + } + } +} + +impl Serialize for Rule { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for Rule { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Self::try_from(s.as_str()).map_err(serde::de::Error::custom) + } +} +impl TryFrom<&str> for Rule { + type Error = RuleError; + + fn try_from(value: &str) -> Result { + let value = value.trim(); + let re = RULE_RE.get_or_init(|| Regex::new(r"^(\w+)(?:\((.*)\))?$").unwrap()); + let caps = re + .captures(value) + .ok_or(RuleError::InvalidRule(value.to_string()))?; + let tool = caps.get(1).unwrap().as_str().to_string(); + let scope = caps.get(2).map(|m| m.as_str().to_string()); + Ok(Rule { tool, scope }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rule_try_from() { + assert_eq!( + Rule::try_from("Read").unwrap(), + Rule { + tool: "Read".to_string(), + scope: None + } + ); + assert_eq!( + Rule::try_from("Read(*)").unwrap(), + Rule { + tool: "Read".to_string(), + scope: Some("*".to_string()) + } + ); + assert_eq!( + Rule::try_from("Write(*.md)").unwrap(), + Rule { + tool: "Write".to_string(), + scope: Some("*.md".to_string()) + } + ); + assert_eq!( + Rule::try_from("Shell(git commit *)").unwrap(), + Rule { + tool: "Shell".to_string(), + scope: Some("git commit *".to_string()) + } + ); + assert_eq!( + Rule::try_from("Shell(echo ())").unwrap(), + Rule { + tool: "Shell".to_string(), + scope: Some("echo ()".to_string()) + } + ); + assert!(Rule::try_from("Shell(git commit *").is_err()); + assert!(Rule::try_from("Shell(git commit *)!").is_err()); + } +} diff --git a/crates/atuin-ai/src/permissions/shell.rs b/crates/atuin-ai/src/permissions/shell.rs new file mode 100644 index 00000000..7a2eee2e --- /dev/null +++ b/crates/atuin-ai/src/permissions/shell.rs @@ -0,0 +1,1297 @@ +/// Extracted command info from a shell command string. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ShellCommand { + /// The command name (first word), e.g. "git" + pub name: String, + /// The full invocation including arguments, e.g. "git commit -m msg" + pub full: String, +} + +/// A parsed shell command with all subcommands extracted. +#[derive(Debug)] +pub(crate) struct ParsedShellCommand { + pub subcommands: Vec, +} + +/// Supported shell families for parsing. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ShellKind { + /// POSIX sh, bash, zsh — all share similar syntax + Posix, + /// fish shell + Fish, + /// nushell or unknown — fallback to word-level extraction + Other, +} + +impl ShellKind { + pub(crate) fn from_shell_name(name: &str) -> Self { + match name { + "bash" | "sh" | "zsh" | "dash" | "ksh" => Self::Posix, + "fish" => Self::Fish, + _ => Self::Other, + } + } +} + +/// Parse a shell command string and extract all subcommands. +pub(crate) fn parse_shell_command(code: &str, shell: ShellKind) -> ParsedShellCommand { + #[cfg(feature = "tree-sitter")] + match shell { + ShellKind::Posix => ts::parse_posix(code), + ShellKind::Fish => ts::parse_fish(code), + ShellKind::Other => parse_fallback(code), + } + + #[cfg(not(feature = "tree-sitter"))] + { + let _ = shell; + parse_fallback(code) + } +} + +// ──────────────────────────────────────────────────────────────── +// Tree-sitter parsers (POSIX + Fish) +// Disabled on platforms where tree-sitter doesn't cross-compile +// (e.g. Windows); falls back to word-level extraction. +// ──────────────────────────────────────────────────────────────── + +#[cfg(feature = "tree-sitter")] +mod ts { + use super::{ParsedShellCommand, ShellCommand, parse_fallback}; + use tree_sitter_lib::{Parser, Tree}; + + fn bash_parser() -> Parser { + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_bash::LANGUAGE.into()) + .expect("failed to set bash language"); + parser + } + + pub(super) fn parse_posix(code: &str) -> ParsedShellCommand { + let mut parser = bash_parser(); + let Some(tree) = parser.parse(code, None) else { + return parse_fallback(code); + }; + + let mut commands = Vec::new(); + walk_bash_tree(&tree, code.as_bytes(), &mut commands); + ParsedShellCommand { + subcommands: commands, + } + } + + /// Leaf node kinds that never contain nested commands. + const BASH_LEAVES: &[&str] = &[ + "command_name", + "word", + "number", + "simple_expansion", + "expansion", + "arithmetic_expansion", + "ansi_c_string", + "special_variable_name", + "variable_name", + "file_descriptor", + "heredoc_body", + "heredoc_start", + "regex", + "heredoc_redirect", + ]; + + fn walk_bash_tree(tree: &Tree, source: &[u8], commands: &mut Vec) { + walk_bash_node(tree.root_node(), source, commands); + } + + fn walk_bash_node( + node: tree_sitter_lib::Node, + source: &[u8], + commands: &mut Vec, + ) { + match node.kind() { + "command" => { + if let Some(cmd) = extract_bash_command(node, source) { + commands.push(cmd); + } + // Descend into all non-leaf children to find nested commands + // (e.g. command_substitution inside a string inside a command) + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if !BASH_LEAVES.contains(&child.kind()) { + walk_bash_node(child, source, commands); + } + } + } + // Other nodes: descend into all children + _ => { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + walk_bash_node(child, source, commands); + } + } + } + } + + /// Extract the full command string and name from a bash `command` node. + fn extract_bash_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option { + // A `command` node has children like: + // variable_assignment* command_name argument* redirect* + // We want the command_name and all arguments (skipping assignments and redirects). + let mut name = None; + let mut name_start = None; + let mut arg_end = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "command_name" => { + name = child.utf8_text(source).ok().map(|s| s.to_string()); + name_start = Some(child.start_byte()); + } + "word" + | "string" + | "raw_string" + | "concatenation" + | "number" + | "simple_expansion" + | "expansion" + | "arithmetic_expansion" + | "ansi_c_string" + | "process_substitution" => { + arg_end = Some(child.end_byte()); + } + _ => {} + } + } + + let name = name?; + let full = if let (Some(start), Some(end)) = (name_start, arg_end) { + std::str::from_utf8(&source[start..end]).ok()?.to_string() + } else { + name.clone() + }; + + Some(ShellCommand { name, full }) + } + + // ──────────────────────────────────────────────────────────────── + // Fish parser + // ──────────────────────────────────────────────────────────────── + + fn fish_parser() -> Parser { + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_fish::language()) + .expect("failed to set fish language"); + parser + } + + pub(super) fn parse_fish(code: &str) -> ParsedShellCommand { + let mut parser = fish_parser(); + let Some(tree) = parser.parse(code, None) else { + return parse_fallback(code); + }; + + let mut commands = Vec::new(); + walk_fish_tree(&tree, code.as_bytes(), &mut commands); + ParsedShellCommand { + subcommands: commands, + } + } + + const FISH_COMPOUND: &[&str] = &[ + "conditional_execution", + "pipe", + "job", + "command_substitution", + "block", + "for_statement", + "while_statement", + "if_statement", + "switch_statement", + "function_definition", + "begin_statement", + "redirected_statement", + ]; + + fn walk_fish_tree(tree: &Tree, source: &[u8], commands: &mut Vec) { + walk_fish_node(tree.root_node(), source, commands); + } + + fn walk_fish_node( + node: tree_sitter_lib::Node, + source: &[u8], + commands: &mut Vec, + ) { + match node.kind() { + "command" => { + if let Some(cmd) = extract_fish_command(node, source) { + commands.push(cmd); + } + // Still descend into compound children (e.g. command_substitution inside a command) + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if FISH_COMPOUND.contains(&child.kind()) { + walk_fish_node(child, source, commands); + } + } + } + // Other nodes: descend into all children + _ => { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + walk_fish_node(child, source, commands); + } + } + } + } + + fn extract_fish_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option { + // In fish, a `command` node has: + // name (command_name or word) followed by arguments (word, string, etc.) + let mut name = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "command_name" | "word" => { + let text = child.utf8_text(source).ok()?.to_string(); + if name.is_none() { + name = Some(text); + } + } + "string" + | "concatenation" + | "command_substitution" + | "escape_sequence" + | "double_quote_string" + | "single_quote_string" => {} + _ => {} + } + } + + let name = name?; + // Get the full text of the command node + let full = node.utf8_text(source).ok()?.trim().to_string(); + + Some(ShellCommand { name, full }) + } +} // mod ts + +// ──────────────────────────────────────────────────────────────── +// Fallback (word-level extraction for nushell / unknown shells) +// ──────────────────────────────────────────────────────────────── + +fn parse_fallback(code: &str) -> ParsedShellCommand { + // Simple heuristic: split by &&, ||, ;, | and take the first word of each segment. + // This is intentionally simple — for unknown shells we can't do better. + let mut commands = Vec::new(); + let mut segment = String::new(); + let mut chars = code.chars().peekable(); + + while let Some(c) = chars.next() { + match c { + ';' => { + push_segment(&mut segment, &mut commands); + } + '|' => { + if chars.peek() == Some(&'|') { + chars.next(); + } + push_segment(&mut segment, &mut commands); + } + '&' if chars.peek() == Some(&'&') => { + chars.next(); + push_segment(&mut segment, &mut commands); + } + _ => segment.push(c), + } + } + push_segment(&mut segment, &mut commands); + + ParsedShellCommand { + subcommands: commands, + } +} + +fn push_segment(segment: &mut String, commands: &mut Vec) { + let trimmed = segment.trim(); + if !trimmed.is_empty() + && let Some(name) = trimmed.split_whitespace().next() + { + commands.push(ShellCommand { + name: name.to_string(), + full: trimmed.to_string(), + }); + } + segment.clear(); +} + +// ──────────────────────────────────────────────────────────────── +// Scope matching +// ──────────────────────────────────────────────────────────────── + +/// Check if any of the extracted subcommands match the given scope pattern. +/// +/// Matching semantics depend on where the `*` wildcard appears: +/// - `*` alone — matches everything +/// - `ls *` (space before `*`) — matches `ls` and `ls -a` but not `lsof` +/// - `git commit *` — matches `git commit -m "msg"` (word boundary) +/// - `ls*` (no space before `*`) — matches `lsof`, `ls`, `ls -a` (prefix/glob) +/// - `rm` (no wildcard) — matches exactly `rm` +/// - `git * amend` — matches `git commit amend` (middle wildcard matches zero+ words) +pub(crate) fn any_subcommand_matches(subcommands: &[ShellCommand], scope: &str) -> bool { + let scope = scope.trim(); + + if scope == "*" { + return true; + } + + if let Some(prefix) = scope.strip_suffix(" *") { + // Word-boundary matching: `ls *` matches `ls` and `ls -a` but not `lsof` + return subcommands.iter().any(|cmd| { + if prefix.is_empty() { + return true; + } + let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); + let prefix_words: Vec<&str> = prefix.split_whitespace().collect(); + cmd_words.len() >= prefix_words.len() + && cmd_words[..prefix_words.len()] == prefix_words[..] + }); + } + + if let Some(prefix) = scope.strip_suffix('*') { + // Prefix/glob matching: `ls*` matches `lsof`, `ls`, etc. + return subcommands.iter().any(|cmd| cmd.full.starts_with(prefix)); + } + + if scope.contains('*') { + // Middle wildcard: `git * amend` — each `*` matches zero or more words + return subcommands + .iter() + .any(|cmd| scope_matches_words(scope, cmd.full.split_whitespace().collect())); + } + + // No wildcard: word-boundary prefix match + let scope_words: Vec<&str> = scope.split_whitespace().collect(); + subcommands.iter().any(|cmd| { + let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); + cmd_words.len() >= scope_words.len() && cmd_words[..scope_words.len()] == scope_words[..] + }) +} + +/// Match a scope pattern containing `*` wildcards against a sequence of words. +/// Each `*` matches zero or more words. Consecutive `*` collapse into one. +fn scope_matches_words(scope: &str, words: Vec<&str>) -> bool { + let parts: Vec<&str> = scope.split('*').collect(); + if parts.len() == 1 { + // No wildcard (shouldn't reach here, but handle it) + let scope_words: Vec<&str> = scope.split_whitespace().collect(); + return words.len() >= scope_words.len() && words[..scope_words.len()] == scope_words[..]; + } + + // Each segment between * is a sequence of literal words that must appear in order. + // Walk through `words` consuming segments left to right. + let mut word_idx = 0; + + for (i, part) in parts.iter().enumerate() { + let segment_words: Vec<&str> = part.split_whitespace().collect(); + if segment_words.is_empty() { + continue; + } + + // Find the segment words starting from word_idx + if i == 0 { + // First segment must match at the start + if words.len() < segment_words.len() + || words[..segment_words.len()] != segment_words[..] + { + return false; + } + word_idx = segment_words.len(); + } else if i == parts.len() - 1 { + // Last segment must match at the end + if words.len() - word_idx < segment_words.len() { + return false; + } + let start = words.len() - segment_words.len(); + return words[start..] == segment_words[..]; + } else { + // Middle segment: find it anywhere after word_idx + let found = find_subslice(&words[word_idx..], &segment_words); + match found { + Some(idx) => word_idx += idx + segment_words.len(), + None => return false, + } + } + } + + true +} + +/// Find the first occurrence of `needle` as a contiguous subsequence in `haystack`. +fn find_subslice(haystack: &[&str], needle: &[&str]) -> Option { + if needle.is_empty() { + return Some(0); + } + if haystack.len() < needle.len() { + return None; + } + (0..=haystack.len() - needle.len()).find(|&i| haystack[i..i + needle.len()] == needle[..]) +} + +// ──────────────────────────────────────────────────────────────── +// Tests +// ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn names(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.name.as_str()).collect() + } + + fn fulls(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.full.as_str()).collect() + } + + #[test] + fn simple_command() { + let result = parse_shell_command("ls -la /tmp", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["ls"]); + assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); + } + + #[test] + fn pipeline() { + let result = parse_shell_command("cat file.txt | grep foo | wc -l", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["cat", "grep", "wc"]); + } + + #[test] + fn command_chaining() { + let result = parse_shell_command("git add . && git commit -m 'hi'", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["git", "git"]); + assert_eq!( + fulls(&result.subcommands), + vec!["git add .", "git commit -m 'hi'"] + ); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn command_substitution() { + let result = parse_shell_command("echo $(git rev-parse HEAD)", ShellKind::Posix); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn backtick_substitution() { + let result = parse_shell_command("echo `date`", ShellKind::Posix); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"date"), "should contain date: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn subshell() { + let result = parse_shell_command("(cd /tmp && ls)", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["cd", "ls"]); + } + + #[test] + fn semicolon_separated() { + let result = parse_shell_command("echo hello; echo world", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["echo", "echo"]); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn for_loop() { + let result = parse_shell_command("for f in *.txt; do cat $f; done", ShellKind::Posix); + assert!(names(&result.subcommands).contains(&"cat")); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn if_statement() { + let result = parse_shell_command( + "if [ -f foo ]; then cat foo; else echo nope; fi", + ShellKind::Posix, + ); + let n = names(&result.subcommands); + assert!(n.contains(&"cat"), "should contain cat: {n:?}"); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + } + + #[test] + fn scope_matching_wildcard() { + let commands = vec![ + ShellCommand { + name: "git".into(), + full: "git commit -m msg".into(), + }, + ShellCommand { + name: "npm".into(), + full: "npm test".into(), + }, + ]; + assert!(any_subcommand_matches(&commands, "*")); + } + + #[test] + fn scope_matching_prefix() { + let commands = vec![ + ShellCommand { + name: "git".into(), + full: "git commit -m msg".into(), + }, + ShellCommand { + name: "npm".into(), + full: "npm test".into(), + }, + ]; + assert!(any_subcommand_matches(&commands, "git commit *")); + assert!(any_subcommand_matches(&commands, "git commit")); + assert!(!any_subcommand_matches(&commands, "git push *")); + assert!(!any_subcommand_matches(&commands, "git push")); + assert!(any_subcommand_matches(&commands, "npm *")); + } + + #[test] + fn scope_word_boundary_vs_glob() { + let commands = vec![ + ShellCommand { + name: "ls".into(), + full: "ls -a".into(), + }, + ShellCommand { + name: "lsof".into(), + full: "lsof -i :3000".into(), + }, + ]; + // `ls *` — word boundary: matches `ls -a` but not `lsof` + assert!(any_subcommand_matches(&commands, "ls *")); + assert!(!any_subcommand_matches(&commands, "cat *")); + assert!(any_subcommand_matches(&commands, "lsof *")); + + // `ls*` — glob/prefix: matches both `ls -a` and `lsof` + assert!(any_subcommand_matches(&commands, "ls*")); + } + + #[test] + fn scope_exact_match() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + assert!(any_subcommand_matches(&commands, "ls")); + assert!(!any_subcommand_matches(&commands, "cat")); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn nested_substitution() { + let result = parse_shell_command( + "echo \"Result: $(git log --oneline | head -1)\"", + ShellKind::Posix, + ); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + assert!(n.contains(&"head"), "should contain head: {n:?}"); + } + + #[test] + fn fallback_splits_correctly() { + let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); + let n = names(&result.subcommands); + assert!(n.contains(&"ls"), "should contain ls: {n:?}"); + assert!(n.contains(&"cat"), "should contain cat: {n:?}"); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + } + + #[test] + fn fish_simple_command() { + let result = parse_shell_command("ls -la /tmp", ShellKind::Fish); + assert_eq!(names(&result.subcommands), vec!["ls"]); + } + + #[test] + fn fish_conditional() { + let result = parse_shell_command("git add .; and git commit -m hi", ShellKind::Fish); + let n = names(&result.subcommands); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn fish_command_substitution() { + let result = parse_shell_command("echo (date)", ShellKind::Fish); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"date"), "should contain date: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn variable_assignment_excluded() { + let result = parse_shell_command("FOO=bar ls -la /tmp", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["ls"]); + assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn variable_assignment_multiple() { + let result = parse_shell_command("A=1 B=2 git status", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["git"]); + assert_eq!(fulls(&result.subcommands), vec!["git status"]); + } + + #[test] + fn fallback_double_ampersand_and_pipe_pipe() { + let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); + assert_eq!(names(&result.subcommands), vec!["ls", "cat", "echo"]); + assert_eq!( + fulls(&result.subcommands), + vec!["ls", "cat foo", "echo fail"] + ); + } + + #[test] + fn fallback_pipe_without_double() { + let result = parse_shell_command("ls | grep foo", ShellKind::Other); + assert_eq!(names(&result.subcommands), vec!["ls", "grep"]); + assert_eq!(fulls(&result.subcommands), vec!["ls", "grep foo"]); + } + + #[test] + fn scope_middle_wildcard() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit -m amend".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * amend")); + assert!(any_subcommand_matches(&commands, "git commit * amend")); + assert!(!any_subcommand_matches(&commands, "git push * amend")); + } + + #[test] + fn scope_middle_wildcard_zero_words() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + // `*` matches zero words, so `git * commit` should match `git commit` + assert!(any_subcommand_matches(&commands, "git * commit")); + } + + #[test] + fn scope_leading_wildcard() { + let commands = vec![ShellCommand { + name: "docker".into(), + full: "docker run --rm alpine".into(), + }]; + assert!(any_subcommand_matches(&commands, "* alpine")); + assert!(!any_subcommand_matches(&commands, "* ubuntu")); + } + + #[test] + fn scope_multiple_wildcards() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git rebase -i HEAD~5".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * -i * HEAD~5")); + assert!(!any_subcommand_matches(&commands, "git * -i * HEAD~10")); + } +} + +#[cfg(all(test, feature = "tree-sitter"))] +mod adversarial { + use super::*; + + fn cmd_names(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.name.as_str()).collect() + } + + /// Helper: assert that parsing POSIX extracts all expected command names + fn assert_posix(code: &str, expected: &[&str]) { + let result = parse_shell_command(code, ShellKind::Posix); + let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); + got.sort(); + let mut want: Vec<&str> = expected.to_vec(); + want.sort(); + assert_eq!( + got, want, + "POSIX parse of {:?}:\n got: {:?}\n want: {:?}", + code, got, want + ); + } + + fn assert_fish(code: &str, expected: &[&str]) { + let result = parse_shell_command(code, ShellKind::Fish); + let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); + got.sort(); + let mut want: Vec<&str> = expected.to_vec(); + want.sort(); + assert_eq!( + got, want, + "Fish parse of {:?}:\n got: {:?}\n want: {:?}", + code, got, want + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 1: Basic compounds + // ──────────────────────────────────────────────────────────── + + #[test] + fn a01_triple_chain() { + assert_posix("a && b && c", &["a", "b", "c"]); + } + + #[test] + fn a02_or_chain() { + assert_posix("a || b || c", &["a", "b", "c"]); + } + + #[test] + fn a03_mixed_chain() { + assert_posix("a && b || c && d", &["a", "b", "c", "d"]); + } + + #[test] + fn a04_long_pipeline() { + assert_posix( + "cat foo | grep bar | awk '{print $1}' | sort | uniq -c", + &["cat", "grep", "awk", "sort", "uniq"], + ); + } + + #[test] + fn a05_semicolons() { + assert_posix("a; b; c; d", &["a", "b", "c", "d"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 2: Nested substitution + // ──────────────────────────────────────────────────────────── + + #[test] + fn a06_nested_dollar() { + assert_posix( + "echo $(basename $(dirname /foo/bar))", + &["echo", "basename", "dirname"], + ); + } + + #[test] + fn a07_deeply_nested() { + // 4 nested echos, all should be extracted + assert_posix( + "echo $(echo $(echo $(echo deep)))", + &["echo", "echo", "echo", "echo"], + ); + } + + #[test] + fn a08_backtick_in_echo() { + assert_posix("echo `hostname`", &["echo", "hostname"]); + } + + #[test] + fn a09_mixed_substitutions() { + assert_posix("echo $(date) `uname`", &["echo", "date", "uname"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 3: Subshells and grouping + // ──────────────────────────────────────────────────────────── + + #[test] + fn a10_subshell_chain() { + assert_posix("(cd /tmp && ls -la)", &["cd", "ls"]); + } + + #[test] + fn a11_nested_subshells() { + assert_posix("( (inner_cmd) )", &["inner_cmd"]); + } + + #[test] + fn a12_brace_group() { + assert_posix("{ cd /tmp; ls; }", &["cd", "ls"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 4: Variable assignments + // ──────────────────────────────────────────────────────────── + + #[test] + fn a13_single_var_assignment() { + let result = parse_shell_command("FOO=bar ls", ShellKind::Posix); + assert_eq!(cmd_names(&result.subcommands), &["ls"]); + assert_eq!(result.subcommands[0].full, "ls"); + } + + #[test] + fn a14_multiple_var_assignments() { + let result = parse_shell_command("A=1 B=2 C=3 git status", ShellKind::Posix); + assert_eq!(cmd_names(&result.subcommands), &["git"]); + assert_eq!(result.subcommands[0].full, "git status"); + } + + #[test] + fn a15_var_assignment_no_command() { + // Variable assignment only — no command to extract + assert_posix("FOO=bar", &[]); + } + + #[test] + fn a16_var_assignment_in_pipeline() { + assert_posix("FOO=bar ls | BAZ=qux grep foo", &["ls", "grep"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 5: Control flow + // ──────────────────────────────────────────────────────────── + + #[test] + fn a17_if_then_else() { + assert_posix( + "if [ -f foo ]; then cat foo; else echo missing; fi", + &["cat", "echo"], + ); + } + + #[test] + fn a18_elif_chain() { + // Two cat commands (then + elif branch), one echo (else branch). + // [ is part of the test_condition, not extracted as a command. + assert_posix( + "if [ -f a ]; then cat a; elif [ -f b ]; then cat b; else echo none; fi", + &["cat", "cat", "echo"], + ); + } + + #[test] + fn a19_for_loop() { + assert_posix("for f in *.txt; do cat \"$f\"; done", &["cat"]); + } + + #[test] + fn a20_while_loop() { + // read in the condition is a real command + assert_posix( + "while read line; do echo \"$line\"; done < input.txt", + &["echo", "read"], + ); + } + + #[test] + fn f07_if_statement() { + // test in if-condition is a real command + assert_fish( + "if test -f foo; cat foo; else; echo missing; end", + &["cat", "echo", "test"], + ); + } + + #[test] + fn f09_while_loop() { + // `true` in the condition is a real command + assert_fish( + "while true; echo tick; sleep 1; end", + &["echo", "sleep", "true"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 6: Redirections + // ──────────────────────────────────────────────────────────── + + #[test] + fn a23_redirect_out() { + assert_posix("ls > output.txt", &["ls"]); + } + + #[test] + fn a24_redirect_append() { + assert_posix("ls >> output.txt 2>&1", &["ls"]); + } + + #[test] + fn a25_here_string() { + assert_posix("grep foo <<< \"hello world\"", &["grep"]); + } + + #[test] + fn a26_redirect_in_pipeline() { + assert_posix("cat < input.txt | sort | uniq", &["cat", "sort", "uniq"]); + } + + #[test] + fn a27_process_substitution() { + assert_posix( + "diff <(sort a.txt) <(sort b.txt)", + &["diff", "sort", "sort"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 7: Function definitions + // ──────────────────────────────────────────────────────────── + + #[test] + fn a28_function_def() { + assert_posix("foo() { echo hello; }", &["echo"]); + } + + #[test] + fn a29_function_with_subshell() { + assert_posix( + "build() { cargo build && cargo test; }", + &["cargo", "cargo"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 8: Edge cases — empties, weird quoting + // ──────────────────────────────────────────────────────────── + + #[test] + fn a30_empty_string() { + let result = parse_shell_command("", ShellKind::Posix); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn a31_whitespace_only() { + let result = parse_shell_command(" \t \n ", ShellKind::Posix); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn a32_single_command_no_args() { + assert_posix("ls", &["ls"]); + } + + #[test] + fn a33_command_with_single_quotes() { + assert_posix("echo 'hello world'", &["echo"]); + } + + #[test] + fn a34_command_with_double_quotes() { + assert_posix("echo \"hello world\"", &["echo"]); + } + + #[test] + fn a35_escaped_spaces() { + // ls\ -la is a single word in bash, not "ls" with flag "-la" + assert_posix("ls\\ -la", &["ls\\ -la"]); + } + + #[test] + fn a36_command_with_dollar_var() { + assert_posix("echo $HOME/.bashrc", &["echo"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 9: Background jobs and coproc + // ──────────────────────────────────────────────────────────── + + #[test] + fn a37_background_job() { + assert_posix("sleep 10 &", &["sleep"]); + } + + #[test] + fn a38_background_chain() { + assert_posix("sleep 10 && echo done &", &["sleep", "echo"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 10: Real-world complex commands + // ──────────────────────────────────────────────────────────── + + #[test] + fn a39_docker_build_and_run() { + assert_posix( + "docker build -t app . && docker run --rm app npm test", + &["docker", "docker"], + ); + } + + #[test] + fn a40_git_rebase_interactive() { + assert_posix( + "GIT_SEQUENCE_EDITOR=\"sed -i 's/pick/reword/'\" git rebase -i HEAD~5", + &["git"], + ); + } + + #[test] + fn a41_find_with_exec() { + // tree-sitter-bash does not parse -exec body as commands — only `find` is extracted. + // This is a known limitation: args to -exec/-execdir are opaque to the parser. + assert_posix("find . -name '*.rs' -exec grep -l 'unsafe' {} +", &["find"]); + } + + #[test] + fn a42_curl_pipe_sh() { + assert_posix( + "curl -sSL https://example.com/install.sh | bash", + &["curl", "bash"], + ); + } + + #[test] + fn a43_xargs() { + assert_posix("find . -name '*.tmp' | xargs rm -f", &["find", "xargs"]); + } + + #[test] + fn a44_npm_script_chain() { + assert_posix( + "npm run build && npm run test && npm run lint", + &["npm", "npm", "npm"], + ); + } + + #[test] + fn a45_make_with_redirect() { + assert_posix( + "make -j$(nproc) 2>&1 | tee build.log", + &["make", "nproc", "tee"], + ); + } + + #[test] + fn a46_sudo_chain() { + assert_posix("sudo apt update && sudo apt upgrade -y", &["sudo", "sudo"]); + } + + #[test] + fn a47_here_doc_with_subcommand() { + assert_posix("cat < output.txt", &["ls"]); + } + + #[test] + fn f13_redirect_append() { + assert_fish("ls >> output.txt", &["ls"]); + } + + #[test] + fn f14_here_string() { + assert_fish("grep foo <<< \"hello\"", &["grep"]); + } + + #[test] + fn f15_curl_pipe() { + assert_fish( + "curl -sSL https://example.com/install.sh | bash", + &["curl", "bash"], + ); + } + + #[test] + fn f16_double_ampersand() { + assert_fish("git add . && git commit -m hi", &["git", "git"]); + } + + #[test] + fn f17_double_pipe() { + assert_fish("test -f foo || echo missing", &["test", "echo"]); + } + + #[test] + fn f18_empty() { + let result = parse_shell_command("", ShellKind::Fish); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn f19_whitespace() { + let result = parse_shell_command(" ", ShellKind::Fish); + assert!(result.subcommands.is_empty()); + } + + // ──────────────────────────────────────────────────────────── + // Level 12: Scope matching adversarial + // ──────────────────────────────────────────────────────────── + + #[test] + fn s01_empty_scope() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // Empty scope matches everything (nothing to constrain) + assert!(any_subcommand_matches(&commands, "")); + } + + #[test] + fn s03_only_wildcard_space_star() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // " *" with empty prefix = match anything + assert!(any_subcommand_matches(&commands, " *")); + } + + #[test] + fn s04_glob_matches_empty() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // `ls*` matches `ls` (prefix match with nothing after) + assert!(any_subcommand_matches(&commands, "ls*")); + } + + #[test] + fn s05_middle_wildcard_empty_match() { + // `git * commit` matches `git commit` (* = zero words) + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * commit")); + } + + #[test] + fn s06_consecutive_wildcards() { + // `git ** commit` should behave like `git * commit` + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + assert!(any_subcommand_matches(&commands, "git ** commit")); + } + + #[test] + fn s07_case_sensitivity() { + let commands = vec![ShellCommand { + name: "LS".into(), + full: "LS -la".into(), + }]; + assert!(!any_subcommand_matches(&commands, "ls")); + assert!(any_subcommand_matches(&commands, "LS")); + } + + #[test] + fn s08_multi_word_exact_no_subcommand() { + // `git commit` should not match `git commit-amend` + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit-amend".into(), + }]; + assert!(!any_subcommand_matches(&commands, "git commit")); + } +} diff --git a/crates/atuin-ai/src/permissions/walker.rs b/crates/atuin-ai/src/permissions/walker.rs new file mode 100644 index 00000000..3bda01c3 --- /dev/null +++ b/crates/atuin-ai/src/permissions/walker.rs @@ -0,0 +1,121 @@ +use std::path::{Path, PathBuf}; + +use eyre::Result; +use tokio::task::JoinSet; + +use crate::permissions::file::{RuleFile, RuleFileContent}; + +#[derive(Debug)] +struct FoundRuleFile { + depth: usize, + file: RuleFile, +} + +pub(crate) struct PermissionWalker { + start: PathBuf, + /// Direct path to the global permissions file (e.g. `~/.config/atuin/permissions.ai.toml`). + global_permissions_file: Option, + rules: Vec, +} + +impl PermissionWalker { + pub fn new(start: PathBuf, global_permissions_file: Option) -> Self { + Self { + start, + global_permissions_file, + rules: Vec::new(), + } + } + + pub fn rules(&self) -> &[RuleFile] { + &self.rules + } + + /// Walks the filesystem starting from the start path and collecting permission files along the way. + /// Walks to the root, then checks the global permissions file, if any. + pub async fn walk(&mut self) -> Result<()> { + let dirs_to_check: Vec = self.start.ancestors().map(PathBuf::from).collect(); + let dir_count = dirs_to_check.len(); + + let mut set: JoinSet>> = JoinSet::new(); + + for (index, path) in dirs_to_check.into_iter().enumerate() { + set.spawn(async move { + match check_dir_for_permissions(&path).await { + Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { + depth: index, + file: rule_file, + })), + Ok(None) => Ok(None), + Err(e) => Err(e), + } + }); + } + + // Check the global file separately (it's a direct file path, not a dir/.atuin/ pattern) + if let Some(global_path) = self.global_permissions_file.clone() { + let depth = dir_count; // sorts after all directory-walk entries + set.spawn(async move { + match load_permissions_file(&global_path).await { + Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { + depth, + file: rule_file, + })), + Ok(None) => Ok(None), + Err(e) => Err(e), + } + }); + } + + let capacity = dir_count + usize::from(self.global_permissions_file.is_some()); + let mut found = Vec::with_capacity(capacity); + while let Some(result) = set.join_next().await { + let result = result?; // JoinErrors result in failure to walk the filesystem + + match result { + Ok(Some(FoundRuleFile { depth, file })) => { + found.push((depth, file)); + } + Ok(None) => { + continue; + } + Err(e) => { + tracing::error!( + "Error while walking filesystem for permissions check; skipping: {}", + e + ); + continue; + } + } + } + // join_next() returns in order of completion, not order of spawn + found.sort_by_key(|(depth, _)| *depth); + self.rules = found.into_iter().map(|(_, file)| file).collect(); + + Ok(()) + } +} + +/// Checks a directory for `.atuin/permissions.ai.toml` and returns the RuleFile if found. +async fn check_dir_for_permissions(path: &Path) -> Result> { + let file_path = path.join(".atuin").join("permissions.ai.toml"); + load_permissions_file(&file_path).await +} + +/// Load a permissions file from an exact path. Returns None if the file doesn't exist. +async fn load_permissions_file(file_path: &Path) -> Result> { + if !tokio::fs::try_exists(file_path).await? { + return Ok(None); + } + + let raw = tokio::fs::read_to_string(file_path).await?; + let content: RuleFileContent = toml::from_str(&raw)?; + + // Use the file's parent as the rule file path (for logging/debugging) + let path = file_path + .parent() + .map(Path::to_path_buf) + .unwrap_or_else(|| file_path.to_path_buf()); + + Ok(Some(RuleFile { path, content })) +} diff --git a/crates/atuin-ai/src/permissions/writer.rs b/crates/atuin-ai/src/permissions/writer.rs new file mode 100644 index 00000000..b2bd9482 --- /dev/null +++ b/crates/atuin-ai/src/permissions/writer.rs @@ -0,0 +1,198 @@ +use std::path::Path; + +use eyre::Result; + +use crate::permissions::rule::Rule; + +/// Whether a rule should be added to the allow or deny list. +#[allow(dead_code)] +pub(crate) enum RuleDisposition { + Allow, + Deny, +} + +/// Write a permission rule to a `permissions.ai.toml` file. +/// +/// If the file doesn't exist it is created (along with parent directories). +/// If it does exist, `toml_edit` is used to append the rule while preserving +/// existing formatting and comments. +/// +/// **Not concurrent-safe.** The read-modify-write cycle is not atomic. In the +/// current UI this is fine — the Select widget serializes permission decisions — +/// but callers should not invoke this concurrently for the same file. +pub(crate) async fn write_rule( + file_path: &Path, + rule: &Rule, + disposition: RuleDisposition, +) -> Result<()> { + let content = if tokio::fs::try_exists(file_path).await.unwrap_or(false) { + tokio::fs::read_to_string(file_path).await? + } else { + String::new() + }; + + let mut doc: toml_edit::DocumentMut = content.parse()?; + + // Ensure [permissions] table exists + if !doc.contains_key("permissions") { + doc["permissions"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + + let key = match disposition { + RuleDisposition::Allow => "allow", + RuleDisposition::Deny => "deny", + }; + + // Use as_table_like_mut so both standard and inline tables work. + let permissions = doc["permissions"] + .as_table_like_mut() + .ok_or_else(|| eyre::eyre!("[permissions] is not a table"))?; + + // Get or create the array + if !permissions.contains_key(key) { + permissions.insert(key, toml_edit::Item::Value(toml_edit::Array::new().into())); + } + + let array = permissions + .get_mut(key) + .and_then(|item| item.as_value_mut()) + .and_then(|v| v.as_array_mut()) + .ok_or_else(|| eyre::eyre!("permissions.{key} is not an array"))?; + + // Don't add duplicates + let rule_str = rule.to_string(); + let already_present = array.iter().any(|v| v.as_str() == Some(&rule_str)); + if !already_present { + array.push(rule_str); + } + + // Write back, creating parent directories as needed + if let Some(parent) = file_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + tokio::fs::write(file_path, doc.to_string()).await?; + + Ok(()) +} + +/// Build the path to the project-level permissions file. +/// `project_root` is typically a git root or the current working directory. +pub(crate) fn project_permissions_path(project_root: &Path) -> std::path::PathBuf { + project_root.join(".atuin").join("permissions.ai.toml") +} + +/// Build the path to the global permissions file (sibling of atuin config). +pub(crate) fn global_permissions_path() -> std::path::PathBuf { + atuin_common::utils::config_dir().join("permissions.ai.toml") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn creates_new_file_with_allow_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains("[permissions]")); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn appends_to_existing_file() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let existing = r#"# My permissions +[permissions] +allow = ["Read"] +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + // Comment preserved + assert!(content.contains("# My permissions")); + // Both rules present + assert!(content.contains(r#""Read""#)); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn does_not_duplicate_existing_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let existing = r#"[permissions] +allow = ["AtuinHistory"] +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + // Should appear exactly once + assert_eq!(content.matches("AtuinHistory").count(), 1); + } + + #[tokio::test] + async fn handles_inline_table_permissions() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + // Inline table style — as_table_mut() would return None for this + let existing = r#"permissions = { allow = ["Read"] } +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains(r#""Read""#)); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn writes_deny_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let rule = Rule { + tool: "Shell".to_string(), + scope: None, + }; + + write_rule(&file, &rule, RuleDisposition::Deny) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains("deny")); + assert!(content.contains(r#""Shell""#)); + } +} diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs new file mode 100644 index 00000000..4673f2cd --- /dev/null +++ b/crates/atuin-ai/src/stream.rs @@ -0,0 +1,372 @@ +// ─────────────────────────────────────────────────────────────────── +// SSE streaming +// ─────────────────────────────────────────────────────────────────── + +use std::sync::mpsc; + +use atuin_client::settings::AiCapabilities; +use atuin_common::tls::ensure_crypto_provider; + +use eventsource_stream::Eventsource; +use eye_declare::Handle; +use eyre::{Context, Result}; +use futures::StreamExt; +use reqwest::Url; + +use crate::{ + context::{AppContext, ClientContext}, + tools::ClientToolCall, + tui::{Session, events::AiTuiEvent}, +}; + +/// Frames that alter the stream lifecycle — terminal or state-changing. +#[derive(Debug, Clone)] +pub(crate) enum StreamControl { + Done { session_id: String }, + Error(String), + StatusChanged(String), +} + +/// Frames that carry conversation content — they mutate the event log. +#[derive(Debug, Clone)] +pub(crate) enum StreamContent { + TextChunk(String), + ToolCall { + id: String, + name: String, + input: serde_json::Value, + }, + ToolResult { + tool_use_id: String, + content: String, + is_error: bool, + }, +} + +/// A frame from the SSE stream, classified as control or content. +#[derive(Debug, Clone)] +pub(crate) enum StreamFrame { + Content(StreamContent), + Control(StreamControl), +} + +/// Per-turn request payload for the chat API. +pub(crate) struct ChatRequest { + pub messages: Vec, + pub session_id: Option, + pub capabilities: Vec, +} + +impl ChatRequest { + pub(crate) fn new( + messages: Vec, + session_id: Option, + capabilities: &AiCapabilities, + ) -> Self { + let mut caps = vec![]; + if capabilities.enable_history_search.unwrap_or(true) { + caps.push("client_v1_atuin_history".to_string()); + } + if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { + caps.extend( + extra + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()), + ); + } + + Self { + messages, + session_id, + capabilities: caps, + } + } +} + +fn create_chat_stream( + hub_address: String, + token: String, + request: ChatRequest, + client_ctx: ClientContext, + send_cwd: bool, + last_command: Option, +) -> std::pin::Pin> + Send>> { + Box::pin(async_stream::stream! { + ensure_crypto_provider(); + let endpoint = match hub_url(&hub_address, "/api/cli/chat") { + Ok(url) => url, + Err(e) => { + yield Err(e); + return; + } + }; + + tracing::debug!("Sending SSE request to {endpoint}"); + + let context = client_ctx.to_json(send_cwd, last_command.as_deref()); + + let mut request_body = serde_json::json!({ + "messages": request.messages, + "context": context, + "capabilities": request.capabilities, + }); + + if let Some(ref sid) = request.session_id { + tracing::trace!("Including session_id in request: {sid}"); + request_body["session_id"] = serde_json::json!(sid); + } + + let client = reqwest::Client::new(); + let response = match client + .post(endpoint.clone()) + .header("Accept", "text/event-stream") + .bearer_auth(&token) + .json(&request_body) + .send() + .await + { + Ok(resp) => resp, + Err(e) => { + yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); + return; + } + }; + + let status = response.status(); + if status == reqwest::StatusCode::UNAUTHORIZED { + tracing::error!("SSE request failed with status: {status}, clearing session"); + let _ = atuin_client::hub::delete_session().await; + yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); + return; + } + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + tracing::error!("SSE request failed ({}): {}", status, body); + yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); + return; + } + + let byte_stream = response.bytes_stream(); + let mut stream = byte_stream.eventsource(); + + while let Some(event) = stream.next().await { + match event { + Ok(sse_event) => { + let event_type = sse_event.event.as_str(); + let data = sse_event.data.clone(); + + tracing::debug!(event_type = %event_type, "SSE event received"); + + match event_type { + "text" => { + if let Ok(json) = serde_json::from_str::(&data) + && let Some(content) = json.get("content").and_then(|v| v.as_str()) + { + yield Ok(StreamFrame::Content(StreamContent::TextChunk(content.to_string()))); + } + } + "tool_call" => { + if let Ok(json) = serde_json::from_str::(&data) { + let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); + yield Ok(StreamFrame::Content(StreamContent::ToolCall { id, name, input })); + } + } + "tool_result" => { + if let Ok(json) = serde_json::from_str::(&data) { + let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); + yield Ok(StreamFrame::Content(StreamContent::ToolResult { tool_use_id, content, is_error })); + } + } + "status" => { + if let Ok(json) = serde_json::from_str::(&data) + && let Some(state) = json.get("state").and_then(|v| v.as_str()) + { + yield Ok(StreamFrame::Control(StreamControl::StatusChanged(state.to_string()))); + } + } + "done" => { + if let Ok(json) = serde_json::from_str::(&data) { + let session_id = json.get("session_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + yield Ok(StreamFrame::Control(StreamControl::Done { session_id })); + } else { + yield Ok(StreamFrame::Control(StreamControl::Done { session_id: String::new() })); + } + break; + } + "error" => { + if let Ok(json) = serde_json::from_str::(&data) { + let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); + tracing::error!("SSE error: {}", message); + yield Ok(StreamFrame::Control(StreamControl::Error(message))); + } else { + tracing::error!("SSE error: {}", data); + yield Ok(StreamFrame::Control(StreamControl::Error(data))); + } + break; + } + _ => {} + } + } + Err(e) => { + yield Err(eyre::eyre!("SSE error: {}", e)); + break; + } + } + } + }) +} + +// ─────────────────────────────────────────────────────────────────── +// Async streaming task — pushes updates to app state via Handle +// ─────────────────────────────────────────────────────────────────── + +pub(crate) async fn run_chat_stream( + handle: Handle, + tx: mpsc::Sender, + app_ctx: AppContext, + client_ctx: ClientContext, + request: ChatRequest, +) { + let capabilities = request.capabilities.clone(); + let stream = create_chat_stream( + app_ctx.endpoint.clone(), + app_ctx.token.clone(), + request, + client_ctx, + app_ctx.send_cwd, + app_ctx.last_command.clone(), + ); + futures::pin_mut!(stream); + + while let Some(event) = stream.next().await { + match event { + Ok(StreamFrame::Content(content)) => { + apply_content_frame(&handle, &tx, &capabilities, content); + } + Ok(StreamFrame::Control(control)) => { + let terminal = apply_control_frame(&handle, control); + if terminal { + break; + } + } + Err(e) => { + let msg = e.to_string(); + handle.update(move |state| { + state.streaming_error(msg); + }); + break; + } + } + } +} + +/// Apply a content frame to session state. +/// Control flow: always continues the stream. +fn apply_content_frame( + handle: &Handle, + tx: &mpsc::Sender, + capabilities: &[String], + content: StreamContent, +) { + match content { + StreamContent::TextChunk(text) => { + handle.update(move |state| { + state.conversation.append_streaming_text(&text); + }); + } + StreamContent::ToolCall { id, name, input } => { + if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) { + // Enforce capability gating: reject tool calls the client didn't advertise. + if let Some(required_cap) = tool.descriptor().capability + && !capabilities.iter().any(|c| c == required_cap) + { + tracing::warn!( + tool = name, + capability = required_cap, + "Rejecting tool call: capability not advertised" + ); + handle.update(move |state| { + state.add_tool_call(id.clone(), name, input.clone()); + state.conversation.add_tool_result( + id, + format!("Tool not enabled: capability '{required_cap}' was not advertised by this client"), + true, + ); + }); + return; + } + + // Client-side tool — add to tracker and conversation, queue permission check + let id_for_event = id.clone(); + handle.update(move |state| { + state.handle_client_tool_call(id_for_event, tool, input); + }); + let _ = tx.send(AiTuiEvent::CheckToolCallPermission(id)); + } else { + // Server-side tool — just add to conversation events + handle.update(move |state| { + state.add_tool_call(id, name, input); + }); + } + } + StreamContent::ToolResult { + tool_use_id, + content, + is_error, + } => { + handle.update(move |state| { + state + .conversation + .add_tool_result(tool_use_id, content, is_error); + }); + } + } +} + +/// Apply a control frame to session state. +/// Returns true if the stream should terminate. +fn apply_control_frame(handle: &Handle, control: StreamControl) -> bool { + match control { + StreamControl::StatusChanged(status) => { + handle.update(move |state| { + state.update_streaming_status(&status); + }); + false + } + StreamControl::Done { session_id } => { + handle.update(move |state| { + if !session_id.is_empty() { + state.conversation.store_session_id(session_id); + } + state.finalize_streaming(); + }); + true + } + StreamControl::Error(msg) => { + handle.update(move |state| { + state.streaming_error(msg); + }); + true + } + } +} + +fn hub_url(base: &str, path: &str) -> Result { + let base_with_slash = if base.ends_with('/') { + base.to_string() + } else { + format!("{base}/") + }; + let stripped = path.strip_prefix('/').unwrap_or(path); + Url::parse(&base_with_slash)? + .join(stripped) + .context("failed to build hub URL") +} diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs new file mode 100644 index 00000000..3b2b7ebf --- /dev/null +++ b/crates/atuin-ai/src/tools/descriptor.rs @@ -0,0 +1,98 @@ +/// Centralized metadata for a tool type. +/// +/// Covers both client-side tools (ones the CLI executes locally) and +/// server-side tools (ones the API executes remotely). This is the single +/// source of truth for display text and classification. +pub(crate) struct ToolDescriptor { + /// Canonical wire names for this tool (the names the server sends). + pub canonical_names: &'static [&'static str], + /// The capability string the client must advertise for this tool to be + /// accepted. `None` for server-side tools (always accepted). + pub capability: Option<&'static str>, + /// Imperative verb for permission prompts (e.g. "read", "run"). + pub display_verb: &'static str, + /// Present-tense progressive verb for spinners (e.g. "Reading file..."). + pub progressive_verb: &'static str, + /// Past-tense verb for summaries (e.g. "Read file"). + pub past_verb: &'static str, + /// Whether this tool is executed client-side (by the CLI). + pub is_client: bool, +} + +// ── Client-side tool descriptors ── + +pub(crate) const READ: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["read_file"], + capability: Some("client_v1_read"), + display_verb: "read", + progressive_verb: "Reading file...", + past_verb: "Read file", + is_client: true, +}; + +pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["str_replace", "file_create", "file_insert"], + capability: Some("client_v1_write"), + display_verb: "write to", + progressive_verb: "Writing file...", + past_verb: "Wrote file", + is_client: true, +}; + +pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["execute_shell_command"], + capability: Some("client_v1_shell"), + display_verb: "run", + progressive_verb: "Running command...", + past_verb: "Ran command", + is_client: true, +}; + +pub(crate) const ATUIN_HISTORY: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["atuin_history"], + capability: Some("client_v1_atuin_history"), + display_verb: "search your Atuin history for", + progressive_verb: "Searching...", + past_verb: "Searched", + is_client: true, +}; + +// ── Server-side tool descriptors ── +// These appear in tool summaries but aren't client-side tools. + +pub(crate) const SERVER_SEARCH: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["web_search"], + capability: None, + display_verb: "search", + progressive_verb: "Searching...", + past_verb: "Searched", + is_client: false, +}; + +pub(crate) const SERVER_SCRAPE: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["web_scrape"], + capability: None, + display_verb: "scrape", + progressive_verb: "Scraping...", + past_verb: "Scraped", + is_client: false, +}; + +/// All known tool descriptors, for lookup by name. +const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[ + READ, + WRITE, + SHELL, + ATUIN_HISTORY, + SERVER_SEARCH, + SERVER_SCRAPE, +]; + +/// Look up a tool descriptor by its canonical wire name. +/// Returns None for unknown tool names. +pub(crate) fn by_name(name: &str) -> Option<&'static ToolDescriptor> { + ALL_DESCRIPTORS + .iter() + .find(|d| d.canonical_names.contains(&name)) + .copied() +} diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs new file mode 100644 index 00000000..8f2183b7 --- /dev/null +++ b/crates/atuin-ai/src/tools/mod.rs @@ -0,0 +1,1111 @@ +use std::{ + io::BufRead, + path::{Path, PathBuf}, + time::Duration, +}; + +use eyre::Result; + +const DEFAULT_FILE_READ_LINES: u64 = 100; +const MAX_FILE_READ_LINES: u64 = 1000; + +pub(crate) mod descriptor; + +use crate::permissions::rule::Rule; + +/// Check whether a file path matches a scope glob pattern. +/// +/// Resolves relative paths against the current directory before matching so +/// that `./foo.md` and `/cwd/foo.md` match the same glob. Supports `*`, `**`, +/// `?`, and `[...]` via `glob_match`. +fn path_matches_scope(path: &Path, scope: &str) -> bool { + let path = if path.is_relative() { + std::env::current_dir() + .map(|cwd| cwd.join(path)) + .unwrap_or_else(|_| path.to_path_buf()) + } else { + path.to_path_buf() + }; + // Normalize to forward slashes so globs work on Windows too. + let path_str = path.to_string_lossy().replace('\\', "/"); + + // If the scope is also relative, try matching against both the absolute + // path and just the filename/relative portion. + if !scope.starts_with('/') { + // Match against filename (e.g. "*.md" matches any .md file) + if let Some(name) = path.file_name().and_then(|n| n.to_str()) + && glob_match::glob_match(scope, name) + { + return true; + } + // Also try matching against the full absolute path in case the scope + // is a relative multi-segment pattern like "crates/**/*.rs" + if glob_match::glob_match(scope, &path_str) { + return true; + } + // And match relative to cwd (so "src/*.rs" works from project root) + if let Ok(cwd) = std::env::current_dir() + && let Ok(rel) = path.strip_prefix(&cwd) + { + let rel_str = rel.to_string_lossy().replace('\\', "/"); + return glob_match::glob_match(scope, &rel_str); + } + return false; + } + + // Absolute scope — match against absolute path + glob_match::glob_match(scope, &path_str) +} + +/// Result of executing a client-side tool. +pub(crate) enum ToolOutcome { + /// Simple success with a text result (used by Read, AtuinHistory). + Success(String), + /// Error with a message. + Error(String), + /// Structured shell result with separated stdout, stderr, exit code, and duration. + Structured { + stdout: String, + stderr: String, + exit_code: Option, + duration_ms: u64, + interrupted: bool, + }, +} + +impl ToolOutcome { + /// Format this outcome as a string for the tool result sent to the LLM. + pub fn format_for_llm(&self) -> String { + match self { + ToolOutcome::Success(s) => s.clone(), + ToolOutcome::Error(e) => e.clone(), + ToolOutcome::Structured { + stdout, + stderr, + exit_code, + duration_ms, + interrupted, + } => { + let mut parts = Vec::new(); + + if let Some(code) = exit_code { + parts.push(format!("Exit code: {code}")); + } + + parts.push(format!("Duration: {duration_ms}ms")); + + if !stdout.is_empty() { + parts.push(format!("stdout:\n{stdout}")); + } else { + parts.push("stdout: (empty)".to_string()); + } + + if !stderr.is_empty() { + parts.push(format!("stderr:\n{stderr}")); + } else { + parts.push("stderr: (empty)".to_string()); + } + + if *interrupted { + parts.push("[Interrupted by user]".to_string()); + } + + parts.join("\n\n") + } + } + } + + /// Whether this outcome represents an error. + pub fn is_error(&self) -> bool { + match self { + ToolOutcome::Error(_) => true, + ToolOutcome::Structured { + exit_code: Some(code), + .. + } if *code != 0 => true, + _ => false, + } + } +} + +/// Cached VT100 preview data for a shell tool call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ToolPreview { + pub lines: Vec, + pub exit_code: Option, + pub interrupted: bool, +} + +/// Lifecycle phase of a tracked tool call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ToolPhase { + CheckingPermissions, + AskingForPermission, + #[expect(dead_code)] + Denied(String), + #[expect(dead_code)] + Executing, + /// Shell command is executing with live preview output. + ExecutingWithPreview { + command: String, + /// Current VT100 screen lines (plain text, viewport-sized). + output_lines: Vec, + /// Exit code once the process completes. + exit_code: Option, + /// Whether the command was interrupted by the user. + interrupted: bool, + }, + /// Tool execution has completed. Preview is cached for rendering history. + Completed { + preview: Option, + }, +} + +/// A tracked tool call through its full lifecycle. +#[derive(Debug)] +pub(crate) struct TrackedTool { + pub id: String, + pub tool: ClientToolCall, + pub phase: ToolPhase, + /// Sender to interrupt a running shell command (only set during ExecutingWithPreview). + pub abort_tx: Option>, +} + +impl TrackedTool { + pub(crate) fn target_dir(&self) -> Option<&Path> { + self.tool.target_dir() + } + + pub fn mark_asking(&mut self) { + self.phase = ToolPhase::AskingForPermission; + } + + pub fn mark_executing_preview(&mut self, command: String) { + self.phase = ToolPhase::ExecutingWithPreview { + command, + output_lines: Vec::new(), + exit_code: None, + interrupted: false, + }; + } + + pub fn complete(&mut self, preview: Option) { + self.phase = ToolPhase::Completed { preview }; + self.abort_tx = None; + } + + /// Extract the current preview, whether live or completed. + pub fn preview(&self) -> Option { + match &self.phase { + ToolPhase::ExecutingWithPreview { + output_lines, + exit_code, + interrupted, + .. + } => Some(ToolPreview { + lines: output_lines.clone(), + exit_code: *exit_code, + interrupted: *interrupted, + }), + ToolPhase::Completed { preview } => preview.clone(), + _ => None, + } + } +} + +/// Tracks all tool calls through their full lifecycle. +/// +/// Single source of truth for tool execution state. Entries persist after +/// completion so cached previews remain available for rendering history. +#[derive(Debug)] +pub(crate) struct ToolTracker { + tools: Vec, +} + +impl ToolTracker { + pub fn new() -> Self { + Self { tools: Vec::new() } + } + + /// Insert a new tool call in CheckingPermissions phase. + pub fn insert(&mut self, id: String, tool: ClientToolCall) { + self.tools.push(TrackedTool { + id, + tool, + phase: ToolPhase::CheckingPermissions, + abort_tx: None, + }); + } + + pub fn get(&self, id: &str) -> Option<&TrackedTool> { + self.tools.iter().find(|t| t.id == id) + } + + pub fn get_mut(&mut self, id: &str) -> Option<&mut TrackedTool> { + self.tools.iter_mut().find(|t| t.id == id) + } + + /// Remove a tool by ID and return it. + #[expect(dead_code)] + pub fn remove(&mut self, id: &str) -> Option { + let pos = self.tools.iter().position(|t| t.id == id)?; + Some(self.tools.remove(pos)) + } + + /// True if any tool is still awaiting a permission decision. + #[expect(dead_code)] + pub fn has_unresolved(&self) -> bool { + self.tools.iter().any(|t| { + matches!( + t.phase, + ToolPhase::CheckingPermissions | ToolPhase::AskingForPermission + ) + }) + } + + /// True if any tool has not yet reached the Completed phase. + /// Use this to gate `ContinueAfterTools` — we must wait for all tools + /// (including those still executing) before resuming the conversation. + pub fn has_pending(&self) -> bool { + self.tools + .iter() + .any(|t| !matches!(t.phase, ToolPhase::Completed { .. })) + } + + /// True if any tool is currently executing with a preview. + pub fn has_executing_preview(&self) -> bool { + self.tools + .iter() + .any(|t| matches!(t.phase, ToolPhase::ExecutingWithPreview { .. })) + } + + /// Find the first tool that is asking for permission. + pub fn asking_for_permission(&self) -> Option<&TrackedTool> { + self.tools + .iter() + .find(|t| t.phase == ToolPhase::AskingForPermission) + } + + /// Find the first tool that is asking for permission (mutable). + #[expect(dead_code)] + pub fn asking_for_permission_mut(&mut self) -> Option<&mut TrackedTool> { + self.tools + .iter_mut() + .find(|t| t.phase == ToolPhase::AskingForPermission) + } + + /// Get the preview for a tool by ID (live or cached). + pub fn preview_for(&self, id: &str) -> Option { + self.get(id)?.preview() + } + + /// Iterate mutably over all tracked tools. + pub fn iter_mut(&mut self) -> impl Iterator { + self.tools.iter_mut() + } +} + +/// A tool call from the server, with parsed input parameters. +#[derive(Debug, Clone)] +pub(crate) enum ClientToolCall { + Read(ReadToolCall), + Write(WriteToolCall), + Shell(ShellToolCall), + AtuinHistory(AtuinHistoryToolCall), +} + +impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { + type Error = eyre::Error; + + fn try_from((name, input): (&str, &serde_json::Value)) -> Result { + match name { + "read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)), + "create_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), + // "append_to_file" => Ok(ClientToolCall::Append(AppendToolCall::try_from(input)?)), + // "str_replace" => Ok(ClientToolCall::StrReplace(StrReplaceToolCall::try_from(input)?)), + "execute_shell_command" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), + "atuin_history" => Ok(ClientToolCall::AtuinHistory( + AtuinHistoryToolCall::try_from(input)?, + )), + _ => Err(eyre::eyre!("Unknown tool call: {name}")), + } + } +} + +impl ClientToolCall { + pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor { + match self { + ClientToolCall::Read(_) => descriptor::READ, + ClientToolCall::Write(_) => descriptor::WRITE, + ClientToolCall::Shell(_) => descriptor::SHELL, + ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, + } + } + + /// The permission rule name for this tool category (e.g. "Write" covers + /// str_replace, file_create, file_insert). + pub(crate) fn rule_name(&self) -> &'static str { + match self { + ClientToolCall::Read(_) => "Read", + ClientToolCall::Write(_) => "Write", + ClientToolCall::Shell(_) => "Shell", + ClientToolCall::AtuinHistory(_) => "AtuinHistory", + } + } + + pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { + match self { + ClientToolCall::Read(tool) => tool.matches_rule(rule), + ClientToolCall::Write(tool) => tool.matches_rule(rule), + ClientToolCall::Shell(tool) => tool.matches_rule(rule), + ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), + } + } + + pub(crate) fn target_dir(&self) -> Option<&Path> { + match self { + ClientToolCall::Read(tool) => tool.target_dir(), + ClientToolCall::Write(tool) => tool.target_dir(), + ClientToolCall::Shell(tool) => tool.target_dir(), + ClientToolCall::AtuinHistory(tool) => tool.target_dir(), + } + } + + /// Execute this client-side tool and return the result. + pub async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { + match self { + ClientToolCall::Read(tool) => tool.execute(), + ClientToolCall::AtuinHistory(tool) => tool.execute(db).await, + _ => ToolOutcome::Error("Client-side tool execution not yet implemented".to_string()), + } + } +} + +/// A trait for tool calls that can be checked against permission rules. +pub(crate) trait PermissableToolCall { + /// Checks if this tool call matches the given permission rule. + fn matches_rule(&self, rule: &Rule) -> bool; + /// Returns the target directory of this tool call, if applicable, for checking against directory-based rules. + fn target_dir(&self) -> Option<&Path> { + None + } +} + +impl PermissableToolCall for ClientToolCall { + fn matches_rule(&self, rule: &Rule) -> bool { + self.matches_rule(rule) + } + + fn target_dir(&self) -> Option<&Path> { + self.target_dir() + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ReadToolCall { + pub path: PathBuf, + pub offset: u64, + pub limit: u64, +} + +impl TryFrom<&serde_json::Value> for ReadToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result { + let path = value + .get("file_path") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing path"))?; + + let offset = value.get("offset").and_then(|v| v.as_u64()).unwrap_or(0); + let limit = value + .get("limit") + .and_then(|v| v.as_u64()) + .unwrap_or(DEFAULT_FILE_READ_LINES) + .min(MAX_FILE_READ_LINES); + + Ok(ReadToolCall { + path: PathBuf::from(path), + offset, + limit, + }) + } +} + +impl ReadToolCall { + fn execute(&self) -> ToolOutcome { + let mut path = self.path.clone(); + + if path.is_relative() + && let Ok(current_dir) = std::env::current_dir() + { + path = current_dir.join(path); + } + + if !path.exists() { + return ToolOutcome::Error(format!("Error: file does not exist: {}", path.display())); + } + + if path.is_dir() { + let Some(files) = std::fs::read_dir(&path).ok().and_then(|entries| { + entries + .filter_map(|entry| entry.ok()) + .map(|entry| entry.file_name().to_string_lossy().to_string()) + .collect::>() + .into() + }) else { + return ToolOutcome::Error(format!( + "Error: could not read directory: {}", + path.display() + )); + }; + + return ToolOutcome::Success(format!("Directory contents:\n{}", files.join("\n"))); + } + + let file = match std::fs::File::open(&path) { + Ok(file) => file, + Err(e) => return ToolOutcome::Error(format!("Error opening file: {e}")), + }; + let reader = std::io::BufReader::new(file); + + let relevent_lines = reader + .lines() + .skip(self.offset as usize) + .take(self.limit as usize) + .collect::, _>>(); + + match relevent_lines { + Ok(lines) => { + let joined = lines.join("\n"); + if joined.len() > 100_000 { + ToolOutcome::Error(format!( + "Error: file is too large to read ({} bytes in {} lines); use view_range to read a subset of the file", + joined.len(), + lines.len() + )) + } else { + ToolOutcome::Success(joined) + } + } + Err(e) => ToolOutcome::Error(format!("Error reading file: {e}")), + } + } +} + +impl PermissableToolCall for ReadToolCall { + fn target_dir(&self) -> Option<&Path> { + Some(&self.path) + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Read" { + return false; + } + + match rule.scope.as_deref() { + None | Some("*") => true, + Some(scope) => path_matches_scope(&self.path, scope), + } + } +} + +#[derive(Debug, Clone)] +#[expect(dead_code)] +pub(crate) struct WriteToolCall { + pub path: PathBuf, + pub content: String, +} + +impl TryFrom<&serde_json::Value> for WriteToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result { + let path = value + .get("path") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing path"))?; + + let content = value + .get("content") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing content"))?; + + Ok(WriteToolCall { + path: PathBuf::from(path), + content: content.to_string(), + }) + } +} + +impl PermissableToolCall for WriteToolCall { + fn target_dir(&self) -> Option<&Path> { + Some(&self.path) + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Write" { + return false; + } + + match rule.scope.as_deref() { + None | Some("*") => true, + Some(scope) => path_matches_scope(&self.path, scope), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ShellToolCall { + pub dir: Option, + pub command: String, + pub shell: String, +} + +impl TryFrom<&serde_json::Value> for ShellToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result { + let dir = value.get("dir").and_then(|v| v.as_str()); + + let command = value + .get("command") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing command"))?; + + let shell = value + .get("shell") + .and_then(|v| v.as_str()) + .unwrap_or("bash") + .to_string(); + + Ok(ShellToolCall { + dir: dir.map(PathBuf::from), + command: command.to_string(), + shell, + }) + } +} + +impl PermissableToolCall for ShellToolCall { + fn target_dir(&self) -> Option<&Path> { + self.dir.as_deref() + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Shell" { + return false; + } + + let Some(scope) = rule.scope.as_ref() else { + // Shell without scope matches all shell commands + return true; + }; + + let shell_kind = crate::permissions::shell::ShellKind::from_shell_name(&self.shell); + let parsed = crate::permissions::shell::parse_shell_command(&self.command, shell_kind); + crate::permissions::shell::any_subcommand_matches(&parsed.subcommands, scope) + } +} + +/// Preview viewport height for VT100 emulation. +const PREVIEW_HEIGHT: u16 = 10; + +/// Default terminal width for VT100 emulation. +const PREVIEW_WIDTH: u16 = 120; + +/// Extract plain text lines from a VT100 screen buffer. +fn vt100_screen_lines(screen: &vt100::Screen) -> Vec { + let (rows, cols) = screen.size(); + let mut lines = Vec::with_capacity(rows as usize); + for row in 0..rows { + let mut line = String::with_capacity(cols as usize); + for col in 0..cols { + if let Some(cell) = screen.cell(row, col) { + line.push_str(cell.contents()); + } + } + // Trim trailing whitespace for cleaner display + lines.push(line.trim_end().to_string()); + } + lines +} + +/// Strip ANSI escape sequences from raw bytes using a VT100 parser. +/// +/// Uses a large virtual screen so scrollback is preserved, then extracts +/// the plain text contents. This handles all escape sequences (colors, +/// cursor movement, progress bars, etc.) not just simple SGR codes. +fn strip_ansi_via_vt100(raw: &[u8]) -> String { + if raw.is_empty() { + return String::new(); + } + // Use the contents_formatted → screen approach: feed bytes into a parser + // with enough rows to hold everything, then read back the plain text. + // Estimate rows: one row per ~PREVIEW_WIDTH bytes, plus generous padding. + let estimated_rows = (raw.len() / PREVIEW_WIDTH as usize + 1).min(10_000) as u16; + let mut parser = vt100::Parser::new(estimated_rows, PREVIEW_WIDTH, 0); + parser.process(raw); + let screen = parser.screen(); + // screen.contents() returns the full plain-text content with trailing + // whitespace trimmed per line and trailing blank lines removed. + screen.contents() +} + +/// Execute a shell command with VT100 emulation and streaming output. +/// +/// Feeds stdout+stderr into a `vt100::Parser` so that ANSI escape sequences, +/// progress bars (`\r`), and cursor movement are handled correctly. Periodically +/// sends the current screen state as `Vec` through `output_tx` for the +/// live preview. +/// +/// Captures the FULL stdout and stderr separately for the tool result sent to the LLM. +/// Returns a `ToolOutcome::Structured` with full output, exit code, and duration. +pub(crate) async fn execute_shell_command_streaming( + shell_call: &ShellToolCall, + output_tx: tokio::sync::mpsc::Sender>, + mut interrupt_rx: tokio::sync::oneshot::Receiver<()>, +) -> ToolOutcome { + use tokio::io::AsyncReadExt; + + let start = std::time::Instant::now(); + + // TODO: check if this is proper for all shells we support + let mut cmd = tokio::process::Command::new(&shell_call.shell); + cmd.arg("-c").arg(&shell_call.command); + cmd.stdout(std::process::Stdio::piped()); + cmd.stderr(std::process::Stdio::piped()); + + if let Some(ref dir) = shell_call.dir { + cmd.current_dir(dir); + } + + let mut child = match cmd.spawn() { + Ok(child) => child, + Err(e) => return ToolOutcome::Error(format!("Failed to spawn command: {e}")), + }; + + let stdout = child.stdout.take().expect("stdout was piped"); + let stderr = child.stderr.take().expect("stderr was piped"); + + // VT100 emulator for the live preview (viewport-sized) + let mut parser = vt100::Parser::new(PREVIEW_HEIGHT, PREVIEW_WIDTH, 0); + + let mut stdout_reader = tokio::io::BufReader::new(stdout); + let mut stderr_reader = tokio::io::BufReader::new(stderr); + + let mut stdout_buf = [0u8; 4096]; + let mut stderr_buf = [0u8; 4096]; + let mut stdout_done = false; + let mut stderr_done = false; + + // Full output buffers (for the LLM, not the preview) + let mut full_stdout = Vec::::new(); + let mut full_stderr = Vec::::new(); + + let mut interval = tokio::time::interval(Duration::from_millis(50)); + + // Send initial empty screen + let initial_lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(initial_lines).await; + + let mut interrupted = false; + + loop { + tokio::select! { + biased; + + // Check for interrupt signal + _ = &mut interrupt_rx, if !interrupted => { + interrupted = true; + let _ = child.start_kill(); + } + + // Read stdout + result = stdout_reader.read(&mut stdout_buf), if !stdout_done => { + match result { + Ok(0) => stdout_done = true, + Ok(n) => { + full_stdout.extend_from_slice(&stdout_buf[..n]); + parser.process(&stdout_buf[..n]); + } + Err(_) => stdout_done = true, + } + } + + // Read stderr + result = stderr_reader.read(&mut stderr_buf), if !stderr_done => { + match result { + Ok(0) => stderr_done = true, + Ok(n) => { + full_stderr.extend_from_slice(&stderr_buf[..n]); + // Feed stderr to the preview parser too, so it shows in the VT100 screen + parser.process(&stderr_buf[..n]); + } + Err(_) => stderr_done = true, + } + } + + // Periodic screen snapshot for preview + _ = interval.tick() => { + let lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(lines).await; + } + } + + // Exit when both streams are done + if stdout_done && stderr_done { + break; + } + } + + // Wait for process to finish + let exit_code = match child.wait().await { + Ok(status) => status.code(), + Err(e) => { + if interrupted { + None + } else { + return ToolOutcome::Error(format!("Failed to wait for command: {e}")); + } + } + }; + + let duration = start.elapsed(); + + // Send final screen state + let final_lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(final_lines).await; + + // Strip ANSI escape sequences for clean LLM output by running + // the raw bytes through a VT100 parser and extracting plain text. + let stdout_text = strip_ansi_via_vt100(&full_stdout); + let stderr_text = strip_ansi_via_vt100(&full_stderr); + + ToolOutcome::Structured { + stdout: stdout_text, + stderr: stderr_text, + exit_code, + duration_ms: duration.as_millis() as u64, + interrupted, + } +} + +#[derive(Debug, Clone)] +pub(crate) struct AtuinHistoryToolCall { + pub filter_modes: Vec, + pub query: String, + pub limit: i64, +} + +#[derive(Debug, Clone)] +pub(crate) enum HistorySearchFilterMode { + Global, + Host, + Session, + Directory, + Workspace, +} + +impl From<&HistorySearchFilterMode> for atuin_client::settings::FilterMode { + fn from(mode: &HistorySearchFilterMode) -> Self { + match mode { + HistorySearchFilterMode::Global => Self::Global, + HistorySearchFilterMode::Host => Self::Host, + HistorySearchFilterMode::Session => Self::Session, + HistorySearchFilterMode::Directory => Self::Directory, + HistorySearchFilterMode::Workspace => Self::Workspace, + } + } +} + +impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result { + let filter_modes = value + .get("filter_modes") + .and_then(|v| v.as_array()) + .ok_or(eyre::eyre!("Missing filter_modes"))?; + + let filter_modes = filter_modes + .iter() + .map(|v| { + let mode = v.as_str().ok_or(eyre::eyre!("Invalid filter mode"))?; + match mode { + "global" => Ok(HistorySearchFilterMode::Global), + "host" => Ok(HistorySearchFilterMode::Host), + "session" => Ok(HistorySearchFilterMode::Session), + "directory" => Ok(HistorySearchFilterMode::Directory), + "workspace" => Ok(HistorySearchFilterMode::Workspace), + _ => Err(eyre::eyre!("Invalid filter mode: {mode}")), + } + }) + .collect::>>()?; + + let query = value + .get("query") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing query"))?; + + let limit = value + .get("limit") + .and_then(|v| v.as_i64()) + .unwrap_or(10) + .clamp(1, 50); + + Ok(AtuinHistoryToolCall { + filter_modes, + query: query.to_string(), + limit, + }) + } +} + +impl PermissableToolCall for AtuinHistoryToolCall { + fn target_dir(&self) -> Option<&Path> { + None + } + + fn matches_rule(&self, rule: &Rule) -> bool { + rule.tool == "AtuinHistory" + } +} + +impl AtuinHistoryToolCall { + pub(crate) async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { + use atuin_client::database::{self, Database as _, OptFilters}; + use atuin_client::settings::SearchMode; + use time::UtcOffset; + + let context = match database::current_context().await { + Ok(ctx) => ctx, + Err(e) => return ToolOutcome::Error(format!("Failed to get history context: {e}")), + }; + + let filter_mode = self + .filter_modes + .first() + .map(atuin_client::settings::FilterMode::from) + .unwrap_or(atuin_client::settings::FilterMode::Global); + + let filter_options = OptFilters { + limit: Some(self.limit), + ..Default::default() + }; + + let results = match db + .search( + SearchMode::Fuzzy, + filter_mode, + &context, + &self.query, + filter_options, + ) + .await + { + Ok(results) => results, + Err(e) => return ToolOutcome::Error(format!("History search failed: {e}")), + }; + + if results.is_empty() { + return ToolOutcome::Success("No matching history entries found.".to_string()); + } + + let local_offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + + let formatted: Vec = results + .iter() + .enumerate() + .map(|(i, h)| { + let ts = h.timestamp.to_offset(local_offset); + let time_str = format!( + "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", + ts.year(), + ts.month() as u8, + ts.day(), + ts.hour(), + ts.minute(), + ts.second(), + ); + + let duration_str = format_duration(h.duration); + + format!( + "{}. `{}` [{}] ({}, exit: {}){}", + i + 1, + h.command, + time_str, + h.cwd, + h.exit, + duration_str, + ) + }) + .collect(); + + ToolOutcome::Success(formatted.join("\n")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn read_rule(scope: Option<&str>) -> Rule { + Rule { + tool: "Read".to_string(), + scope: scope.map(String::from), + } + } + + fn write_rule(scope: Option<&str>) -> Rule { + Rule { + tool: "Write".to_string(), + scope: scope.map(String::from), + } + } + + fn read_tool(path: &str) -> ReadToolCall { + ReadToolCall { + path: PathBuf::from(path), + offset: 0, + limit: 100, + } + } + + fn write_tool(path: &str) -> WriteToolCall { + WriteToolCall { + path: PathBuf::from(path), + content: String::new(), + } + } + + // ── Cross-platform tests ── + + #[test] + fn no_scope_matches_everything() { + assert!(read_tool("any/path.txt").matches_rule(&read_rule(None))); + assert!(write_tool("any/path.txt").matches_rule(&write_rule(None))); + } + + #[test] + fn wildcard_star_matches_everything() { + assert!(read_tool("foo/bar.rs").matches_rule(&read_rule(Some("*")))); + } + + #[test] + fn wrong_tool_never_matches() { + assert!(!read_tool("foo.txt").matches_rule(&write_rule(None))); + assert!(!write_tool("foo.txt").matches_rule(&read_rule(None))); + } + + #[test] + fn extension_glob() { + assert!(read_tool("notes.md").matches_rule(&read_rule(Some("*.md")))); + assert!(!read_tool("notes.txt").matches_rule(&read_rule(Some("*.md")))); + } + + #[test] + fn relative_multi_segment_glob() { + // This matches against the path relative to cwd + let cwd = std::env::current_dir().unwrap(); + let abs = cwd + .join("crates") + .join("atuin-ai") + .join("src") + .join("lib.rs"); + let tool = read_tool(abs.to_str().unwrap()); + assert!(tool.matches_rule(&read_rule(Some("crates/**/*.rs")))); + assert!(!tool.matches_rule(&read_rule(Some("crates/**/*.py")))); + } + + // ── Unix-specific tests (absolute paths with forward slashes) ── + + #[cfg(unix)] + mod unix { + use super::*; + + #[test] + fn absolute_glob() { + assert!( + read_tool("/home/user/src/main.rs") + .matches_rule(&read_rule(Some("/home/user/src/*.rs"))) + ); + assert!( + !read_tool("/home/user/docs/readme.md") + .matches_rule(&read_rule(Some("/home/user/src/*.rs"))) + ); + } + + #[test] + fn double_star_glob() { + assert!( + read_tool("/project/crates/foo/src/lib.rs") + .matches_rule(&read_rule(Some("/project/crates/**/*.rs"))) + ); + assert!( + !read_tool("/project/crates/foo/src/lib.py") + .matches_rule(&read_rule(Some("/project/crates/**/*.rs"))) + ); + } + } + + // ── Windows-specific tests (absolute paths with drive letters) ── + + #[cfg(windows)] + mod windows { + use super::*; + + #[test] + fn absolute_glob() { + assert!( + read_tool(r"C:\Users\dev\src\main.rs") + .matches_rule(&read_rule(Some("C:/Users/dev/src/*.rs"))) + ); + assert!( + !read_tool(r"C:\Users\dev\docs\readme.md") + .matches_rule(&read_rule(Some("C:/Users/dev/src/*.rs"))) + ); + } + + #[test] + fn double_star_glob() { + assert!( + read_tool(r"C:\project\crates\foo\src\lib.rs") + .matches_rule(&read_rule(Some("C:/project/crates/**/*.rs"))) + ); + assert!( + !read_tool(r"C:\project\crates\foo\src\lib.py") + .matches_rule(&read_rule(Some("C:/project/crates/**/*.rs"))) + ); + } + } +} + +fn format_duration(nanos: i64) -> String { + if nanos <= 0 { + return String::new(); + } + + let total_secs = nanos / 1_000_000_000; + let millis = (nanos % 1_000_000_000) / 1_000_000; + + if total_secs >= 3600 { + let hours = total_secs / 3600; + let mins = (total_secs % 3600) / 60; + let secs = total_secs % 60; + format!(", {hours}h{mins}m{secs}s") + } else if total_secs >= 60 { + let mins = total_secs / 60; + let secs = total_secs % 60; + format!(", {mins}m{secs}s") + } else if total_secs > 0 { + if millis > 0 { + format!(", {total_secs}.{millis:03}s") + } else { + format!(", {total_secs}s") + } + } else { + format!(", {millis}ms") + } +} diff --git a/crates/atuin-ai/src/tui/components/atuin_ai.rs b/crates/atuin-ai/src/tui/components/atuin_ai.rs index fab29502..c04ac722 100644 --- a/crates/atuin-ai/src/tui/components/atuin_ai.rs +++ b/crates/atuin-ai/src/tui/components/atuin_ai.rs @@ -22,10 +22,11 @@ pub(crate) struct AtuinAi { pub has_command: bool, pub is_input_blank: bool, pub pending_confirmation: bool, + pub has_executing_preview: bool, } #[derive(Default)] -pub struct AtuinAiState { +pub(crate) struct AtuinAiState { tx: Option>, } @@ -55,15 +56,24 @@ fn atuin_ai( return EventResult::Ignored; }; - // Ctrl+C always exits + // Ctrl+C — interrupt executing command or exit if modifiers.contains(KeyModifiers::CONTROL) && *code == KeyCode::Char('c') { - let _ = tx.send(AiTuiEvent::Exit); + if props.has_executing_preview { + let _ = tx.send(AiTuiEvent::InterruptToolExecution); + } else { + let _ = tx.send(AiTuiEvent::Exit); + } return EventResult::Consumed; } match props.mode { AppMode::Input => match code { KeyCode::Esc => { + if props.has_executing_preview { + let _ = tx.send(AiTuiEvent::InterruptToolExecution); + return EventResult::Consumed; + } + if props.pending_confirmation { let _ = tx.send(AiTuiEvent::CancelConfirmation); return EventResult::Consumed; diff --git a/crates/atuin-ai/src/tui/components/markdown.rs b/crates/atuin-ai/src/tui/components/markdown.rs index 1cd7dbcf..f164fdc5 100644 --- a/crates/atuin-ai/src/tui/components/markdown.rs +++ b/crates/atuin-ai/src/tui/components/markdown.rs @@ -16,20 +16,12 @@ use ratatui_widgets::paragraph::{Paragraph, Wrap}; /// A markdown rendering component backed by pulldown-cmark. #[props] -pub struct Markdown { +pub(crate) struct Markdown { pub source: String, } -impl Markdown { - pub fn new(source: impl Into) -> Self { - Self { - source: source.into(), - } - } -} - /// Style configuration for markdown rendering. -pub struct MarkdownStyles { +pub(crate) struct MarkdownStyles { pub base: Style, pub code_inline: Style, pub code_block: Style, @@ -98,26 +90,22 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat let mut style_stack: Vec