diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-04-10 13:24:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-04-10 20:24:57 +0000 |
| commit | 09279a428659cf41824737d3e0c97bcc19a8885a (patch) | |
| tree | 64731502c065df2483e8dd680d46c5559f3094f2 /crates/atuin-ai/src | |
| parent | feat: add strip_trailing_whitespace, on by default (#3390) (diff) | |
| download | atuin-09279a428659cf41824737d3e0c97bcc19a8885a.zip | |
feat: Client-tool execution + permission system (#3370)
Adds client-side tool execution to Atuin AI, starting with
`atuin_history`. The server can request tool calls, which are executed
locally with a permission system, and results are sent back to continue
the conversation.
Diffstat (limited to 'crates/atuin-ai/src')
26 files changed, 4993 insertions, 848 deletions
diff --git a/crates/atuin-ai/src/commands.rs b/crates/atuin-ai/src/commands.rs index 6e79da61..cdbc8f2d 100644 --- a/crates/atuin-ai/src/commands.rs +++ b/crates/atuin-ai/src/commands.rs @@ -9,7 +9,7 @@ use eyre::Result; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{EnvFilter, Layer, fmt, layer::SubscriberExt, util::SubscriberInitExt}; pub mod init; -pub mod inline; +pub(crate) mod inline; #[derive(Args, Debug)] pub struct AiArgs { @@ -71,7 +71,7 @@ pub async fn run( } } -pub fn detect_shell() -> Option<String> { +pub(crate) fn detect_shell() -> Option<String> { Some(Shell::current().to_string()) } diff --git a/crates/atuin-ai/src/commands/init.rs b/crates/atuin-ai/src/commands/init.rs index 77abc4f4..f693d892 100644 --- a/crates/atuin-ai/src/commands/init.rs +++ b/crates/atuin-ai/src/commands/init.rs @@ -1,6 +1,6 @@ use crate::commands::detect_shell; -pub async fn run(shell: String) -> eyre::Result<()> { +pub(crate) async fn run(shell: String) -> eyre::Result<()> { let integration = match shell.as_str() { "zsh" => generate_zsh_integration(), "bash" => generate_bash_integration(), diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index aeb414fb..b37bb72f 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -1,38 +1,44 @@ use std::path::PathBuf; use std::sync::mpsc; -use crate::commands::detect_shell; +use crate::context::{AppContext, ClientContext}; +use crate::tui::dispatch; use crate::tui::events::AiTuiEvent; -use crate::tui::state::{AppState, ExitAction}; +use crate::tui::state::{ExitAction, Session}; use crate::tui::view::ai_view; use atuin_client::database::{Database, Sqlite}; -use atuin_client::distro::detect_linux_distribution; -use atuin_common::tls::ensure_crypto_provider; -use eventsource_stream::Eventsource; -use eye_declare::{Application, CtrlCBehavior, Handle}; +use eye_declare::{Application, CtrlCBehavior}; use eyre::{Context as _, Result, bail}; -use futures::StreamExt; -use reqwest::Url; -use tracing::{debug, error, info, trace}; +use tracing::{debug, info}; -pub async fn run( +pub(crate) async fn run( initial_command: Option<String>, api_endpoint: Option<String>, api_token: Option<String>, settings: &atuin_client::settings::Settings, output_for_hook: bool, ) -> Result<()> { - if !settings.ai.enabled.unwrap_or(false) { - emit_shell_result( - Action::Print( - "Atuin AI is not enabled. Please enable it in your settings or run `atuin setup`." - .to_string(), - ), - output_for_hook, - ); + if settings.ai.enabled == Some(false) { return Ok(()); } + if settings.ai.enabled.is_none() { + match prompt_ai_setup()? { + SetupChoice::EnableAi => { + set_ai_enabled(true).await?; + } + SetupChoice::DisableKeybind => { + set_ai_enabled(false).await?; + emit_shell_result(Action::Cancel, output_for_hook); + return Ok(()); + } + SetupChoice::Cancel => { + emit_shell_result(Action::Cancel, output_for_hook); + return Ok(()); + } + } + } + let endpoint = api_endpoint.as_deref().unwrap_or( settings .ai @@ -48,7 +54,36 @@ pub async fn run( ensure_hub_session(settings).await? }; - let action = run_inline_tui(endpoint.to_string(), token, initial_command, settings).await?; + let history_db_path = PathBuf::from(settings.db_path.as_str()); + let history_db = Sqlite::new(history_db_path, settings.local_timeout) + .await + .context("failed to open history database for AI")?; + + // Support both legacy [ai] send_cwd and new [ai.opening] send_cwd + let send_cwd = + settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); + + let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { + history_db.last().await.ok().flatten().map(|h| h.command) + } else { + None + }; + + let git_root = std::env::current_dir() + .ok() + .and_then(|cwd| atuin_common::utils::in_git_repo(cwd.to_str()?)); + + let ctx = AppContext { + endpoint: endpoint.to_string(), + token, + send_cwd, + last_command, + history_db: std::sync::Arc::new(history_db), + git_root, + capabilities: settings.ai.capabilities.clone(), + }; + + let action = run_inline_tui(ctx, initial_command).await?; emit_shell_result(action, output_for_hook); Ok(()) @@ -69,7 +104,7 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu if will_sync { println!( "Once logged in, your shell history will be synchronized via Atuin Hub if auto_sync is enabled or when manually syncing." - ) + ); } println!( "If you have an existing Atuin sync account, you can log in with your existing credentials." @@ -111,279 +146,16 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu } // ─────────────────────────────────────────────────────────────────── -// SSE streaming -// ─────────────────────────────────────────────────────────────────── -#[derive(Debug, Clone)] -enum ChatStreamEvent { - TextChunk(String), - ToolCall { - id: String, - name: String, - input: serde_json::Value, - }, - ToolResult { - tool_use_id: String, - content: String, - is_error: bool, - }, - Status(String), - Done { - session_id: String, - }, - Error(String), -} - -fn create_chat_stream( - hub_address: String, - token: String, - session_id: Option<String>, - messages: Vec<serde_json::Value>, - send_cwd: bool, - last_command: Option<String>, -) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<ChatStreamEvent>> + Send>> { - Box::pin(async_stream::stream! { - ensure_crypto_provider(); - let endpoint = match hub_url(&hub_address, "/api/cli/chat") { - Ok(url) => url, - Err(e) => { - yield Err(e); - return; - } - }; - - debug!("Sending SSE request to {endpoint}"); - - let os = detect_os(); - let shell = detect_shell(); - - let mut context = serde_json::json!({ - "os": os, - "shell": shell, - "pwd": if send_cwd { std::env::current_dir() - .ok() - .map(|path| path.to_string_lossy().into_owned()) } else { None }, - "last_command": last_command, - }); - - if os == "linux" { - context["distro"] = serde_json::json!(detect_linux_distribution()); - } - - let mut request_body = serde_json::json!({ - "messages": messages, - "context": context, - }); - - if let Some(ref sid) = session_id { - trace!("Including session_id in request: {sid}"); - request_body["session_id"] = serde_json::json!(sid); - } - - let client = reqwest::Client::new(); - let response = match client - .post(endpoint.clone()) - .header("Accept", "text/event-stream") - .bearer_auth(&token) - .json(&request_body) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); - return; - } - }; +async fn run_inline_tui(ctx: AppContext, initial_prompt: Option<String>) -> Result<Action> { + let client_ctx = ClientContext::detect(); - let status = response.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - error!("SSE request failed with status: {status}, clearing session"); - let _ = atuin_client::hub::delete_session().await; - yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); - return; - } - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - error!("SSE request failed ({}): {}", status, body); - yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); - return; - } - - let byte_stream = response.bytes_stream(); - let mut stream = byte_stream.eventsource(); - - while let Some(event) = stream.next().await { - match event { - Ok(sse_event) => { - let event_type = sse_event.event.as_str(); - let data = sse_event.data.clone(); - - debug!(event_type = %event_type, "SSE event received"); - - match event_type { - "text" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) - && let Some(content) = json.get("content").and_then(|v| v.as_str()) - { - yield Ok(ChatStreamEvent::TextChunk(content.to_string())); - } - } - "tool_call" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); - yield Ok(ChatStreamEvent::ToolCall { id, name, input }); - } - } - "tool_result" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); - yield Ok(ChatStreamEvent::ToolResult { tool_use_id, content, is_error }); - } - } - "status" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) - && let Some(state) = json.get("state").and_then(|v| v.as_str()) - { - yield Ok(ChatStreamEvent::Status(state.to_string())); - } - } - "done" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let session_id = json.get("session_id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - yield Ok(ChatStreamEvent::Done { session_id }); - } else { - yield Ok(ChatStreamEvent::Done { session_id: String::new() }); - } - break; - } - "error" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); - error!("SSE error: {}", message); - yield Ok(ChatStreamEvent::Error(message)); - } else { - error!("SSE error: {}", data); - yield Ok(ChatStreamEvent::Error(data)); - } - break; - } - _ => {} - } - } - Err(e) => { - yield Err(eyre::eyre!("SSE error: {}", e)); - break; - } - } - } - }) -} - -// ─────────────────────────────────────────────────────────────────── -// Async streaming task — pushes updates to app state via Handle -// ─────────────────────────────────────────────────────────────────── - -async fn run_chat_stream( - handle: Handle<AppState>, - endpoint: String, - token: String, - session_id: Option<String>, - messages: Vec<serde_json::Value>, - send_cwd: bool, - last_command: Option<String>, -) { - let stream = create_chat_stream( - endpoint, - token, - session_id, - messages, - send_cwd, - last_command, - ); - futures::pin_mut!(stream); - - while let Some(event) = stream.next().await { - match event { - Ok(ChatStreamEvent::TextChunk(text)) => { - trace!(text = %text, "Processing TextChunk"); - handle.update(move |state| { - state.append_streaming_text(&text); - }); - } - Ok(ChatStreamEvent::ToolCall { id, name, input }) => { - trace!(id = %id, name = %name, "Processing ToolCall"); - handle.update(move |state| { - state.add_tool_call(id, name, input); - }); - } - Ok(ChatStreamEvent::ToolResult { - tool_use_id, - content, - is_error, - }) => { - trace!(tool_use_id = %tool_use_id, "Processing ToolResult"); - handle.update(move |state| { - state.add_tool_result(tool_use_id, content, is_error); - }); - } - Ok(ChatStreamEvent::Status(status)) => { - trace!(status = %status, "Processing Status"); - handle.update(move |state| { - state.update_streaming_status(&status); - }); - } - Ok(ChatStreamEvent::Done { session_id }) => { - trace!(session_id = %session_id, "Processing Done"); - handle.update(move |state| { - if !session_id.is_empty() { - state.store_session_id(session_id); - } - state.finalize_streaming(); - }); - break; - } - Ok(ChatStreamEvent::Error(msg)) => { - trace!(error = %msg, "Processing Error"); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - Err(e) => { - let msg = e.to_string(); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - } - } -} - -// ─────────────────────────────────────────────────────────────────── -// Main TUI entry point -// ─────────────────────────────────────────────────────────────────── + let (tx, rx) = mpsc::channel::<AiTuiEvent>(); -async fn run_inline_tui( - endpoint: String, - token: String, - initial_prompt: Option<String>, - settings: &atuin_client::settings::Settings, -) -> Result<Action> { - let initial_state = AppState::new(); + let initial_state = Session::new(ctx.git_root.is_some()); println!(); - let (tx, rx) = mpsc::channel::<AiTuiEvent>(); - // If there's an initial prompt, send it as a SubmitInput event // so it flows through the same path as user-typed input. if let Some(prompt) = initial_prompt { @@ -396,164 +168,17 @@ async fn run_inline_tui( .ctrl_c(CtrlCBehavior::Deliver) .keyboard_protocol(eye_declare::KeyboardProtocol::Enhanced) .bracketed_paste(true) - .with_context(tx) + .with_context(tx.clone()) .extra_newlines_at_exit(1) .build()?; - // Support both legacy [ai] send_cwd and new [ai.opening] send_cwd - let send_cwd = - settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); - - let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { - let db_path = PathBuf::from(settings.db_path.as_str()); - match Sqlite::new(db_path, settings.local_timeout).await { - Ok(db) => db.last().await.ok().flatten().map(|h| h.command), - Err(e) => { - debug!("Failed to open history database for read_history: {e}"); - None - } - } - } else { - None - }; - // Event loop: receives AiTuiEvent from components, mutates state via Handle. let h = handle.clone(); - let ep = endpoint.clone(); - let tk = token.clone(); tokio::task::spawn_blocking(move || { + let tx = tx.clone(); + let client_ctx = client_ctx; while let Ok(event) = rx.recv() { - match event { - AiTuiEvent::InputUpdated(input) => { - let input_blank = input.trim().is_empty(); - - h.update(move |state| { - state.is_input_blank = input_blank; - }); - } - AiTuiEvent::SubmitInput(input) => { - let input = input.trim().to_string(); - if input.is_empty() { - let h2 = h.clone(); - h.update(move |state| { - if state.has_any_command() { - state.exit_action = Some(ExitAction::Execute( - state.current_command().unwrap().to_string(), - )); - } else { - state.exit_action = Some(ExitAction::Cancel); - } - h2.exit(); - }); - continue; - } - - if input.starts_with('/') { - let input_clone = input.clone(); - h.update(move |state| { - state.handle_slash_command(&input_clone); - }); - continue; - } - - // Start generation and spawn streaming task - let ep = ep.clone(); - let tk = tk.clone(); - let h2 = h.clone(); - let lc = last_command.clone(); - h.update(move |state| { - state.start_generating(input); - state.start_streaming(); - state.is_input_blank = true; - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd, lc).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::SlashCommand(command) => { - h.update(move |state| { - state.handle_slash_command(&command); - }); - } - - AiTuiEvent::CancelGeneration => { - h.update(|state| match state.mode { - crate::tui::state::AppMode::Generating => { - state.cancel_generation(); - } - crate::tui::state::AppMode::Streaming => { - state.cancel_streaming(); - } - _ => {} - }); - } - - AiTuiEvent::ExecuteCommand => { - let h2 = h.clone(); - h.update(move |state| { - let cmd = state.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - if state.is_current_command_dangerous() && !state.confirmation_pending { - state.confirmation_pending = true; - } else { - state.confirmation_pending = false; - state.exit_action = Some(ExitAction::Execute(cmd)); - h2.exit(); - } - } - }); - } - - AiTuiEvent::CancelConfirmation => { - h.update(move |state| { - state.confirmation_pending = false; - }); - } - - AiTuiEvent::InsertCommand => { - let h2 = h.clone(); - h.update(move |state| { - let cmd = state.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - state.confirmation_pending = false; - state.exit_action = Some(ExitAction::Insert(cmd)); - h2.exit(); - } - }); - } - - AiTuiEvent::Retry => { - let ep = ep.clone(); - let tk = tk.clone(); - let h2 = h.clone(); - let lc = last_command.clone(); - h.update(move |state| { - state.retry(); - state.start_streaming(); - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd, lc).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::Exit => { - let h2 = h.clone(); - h.update(move |state| { - if let Some(abort) = state.stream_abort.take() { - abort.abort(); - } - state.exit_action = Some(ExitAction::Cancel); - h2.exit(); - }); - } - } + dispatch::dispatch(&h, event, &tx, &ctx, &client_ctx); } }); @@ -573,51 +198,125 @@ async fn run_inline_tui( // Helpers // ─────────────────────────────────────────────────────────────────── -fn hub_url(base: &str, path: &str) -> Result<Url> { - let base_with_slash = if base.ends_with('/') { - base.to_string() - } else { - format!("{base}/") - }; - let stripped = path.strip_prefix('/').unwrap_or(path); - Url::parse(&base_with_slash)? - .join(stripped) - .context("failed to build hub URL") +enum SetupChoice { + EnableAi, + DisableKeybind, + Cancel, } -fn detect_os() -> String { - match std::env::consts::OS { - "macos" => "macos".to_string(), - "linux" => "linux".to_string(), - "windows" => "windows".to_string(), - other => format!("Other: {other}"), +fn prompt_ai_setup() -> Result<SetupChoice> { + use crossterm::{ + cursor, + event::{self, Event, KeyCode}, + terminal, + }; + + let options = ["Enable Atuin AI", "Disable ? Keybind", "Cancel"]; + let mut selected: usize = 0; + let mut stdout = std::io::stdout(); + + // Print header before raw mode so newlines render correctly. + // Use stdout because the shell hook swaps stdout/stderr — stdout goes + // to the terminal in both hook and non-hook modes. + println!(); + println!(" Atuin AI is not yet configured."); + println!(); + + terminal::enable_raw_mode().context("failed to enable raw mode")?; + struct Guard; + impl Drop for Guard { + fn drop(&mut self) { + let _ = terminal::disable_raw_mode(); + } } -} + let _guard = Guard; -#[derive(Clone)] -enum Action { - Execute(String), - Insert(String), - Print(String), - Cancel, -} + crossterm::execute!(stdout, cursor::Hide)?; -fn emit_shell_result(action: Action, output_for_hook: bool) { - if output_for_hook { - match action { - Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"), - Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"), - Action::Print(output) => eprintln!("__atuin_ai_print__:{output}"), - Action::Cancel => eprintln!("__atuin_ai_cancel__"), + loop { + render_setup_options(&mut stdout, &options, selected)?; + + let ev = event::read().context("failed to read key event")?; + + crossterm::execute!(stdout, cursor::MoveUp(options.len() as u16))?; + + if let Event::Key(key) = ev { + match key.code { + KeyCode::Up | KeyCode::Char('k') => { + selected = selected.saturating_sub(1); + } + KeyCode::Down | KeyCode::Char('j') => { + if selected < options.len() - 1 { + selected += 1; + } + } + KeyCode::Enter => break, + KeyCode::Esc => { + selected = 2; + break; + } + _ => {} + } } - } else { - match action { - Action::Execute(output) => eprintln!("{output}"), - Action::Insert(output) => eprintln!("{output}"), - Action::Print(output) => eprintln!("{output}"), - Action::Cancel => eprintln!(), + } + + // Final render with selection visible + render_setup_options(&mut stdout, &options, selected)?; + crossterm::execute!(stdout, cursor::Show)?; + + Ok(match selected { + 0 => SetupChoice::EnableAi, + 1 => SetupChoice::DisableKeybind, + _ => SetupChoice::Cancel, + }) +} + +fn render_setup_options( + w: &mut impl std::io::Write, + options: &[&str], + selected: usize, +) -> Result<()> { + use crossterm::{ + style::Stylize, + terminal::{Clear, ClearType}, + }; + + for (i, option) in options.iter().enumerate() { + if i == selected { + write!(w, "\r {}", format!("> {option}").bold().cyan())?; + } else { + write!(w, "\r {option}")?; } + crossterm::execute!(w, Clear(ClearType::UntilNewLine))?; + write!(w, "\r\n")?; + } + w.flush()?; + Ok(()) +} + +async fn set_ai_enabled(enabled: bool) -> Result<()> { + let config_file = atuin_client::settings::Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let mut doc = config_str.parse::<toml_edit::DocumentMut>()?; + + if !doc.contains_key("ai") { + doc["ai"] = toml_edit::table(); + } + doc["ai"]["enabled"] = toml_edit::value(enabled); + + tokio::fs::write(&config_file, doc.to_string()).await?; + + if !enabled { + println!( + "Atuin AI keybind disabled. You can re-enable with `atuin config set ai.enabled true`.", + ); + println!("Restart your shell for changes to take effect."); + // Two printlns to ensure the message is visible above the shell prompt after program ends. + println!(); + println!(); } + + Ok(()) } fn wait_for_login_confirmation() -> Result<bool> { @@ -646,3 +345,27 @@ fn wait_for_login_confirmation() -> Result<bool> { } } } + +#[derive(Clone)] +enum Action { + Execute(String), + Insert(String), + Cancel, +} + +fn emit_shell_result(action: Action, output_for_hook: bool) { + if output_for_hook { + match action { + Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"), + Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"), + Action::Cancel => eprintln!("__atuin_ai_cancel__"), + } + } else { + match action { + Action::Execute(output) | Action::Insert(output) => { + println!("{output}"); + } + Action::Cancel => {} + } + } +} diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs new file mode 100644 index 00000000..dabb5c5e --- /dev/null +++ b/crates/atuin-ai/src/context.rs @@ -0,0 +1,73 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use atuin_client::distro::detect_linux_distribution; +use atuin_client::settings::AiCapabilities; + +/// Session-scoped context for the AI chat session. +/// Holds the API configuration and client settings needed by the event loop and stream task. +#[derive(Clone, Debug)] +pub(crate) struct AppContext { + pub endpoint: String, + pub token: String, + pub send_cwd: bool, + pub last_command: Option<String>, + pub history_db: Arc<atuin_client::database::Sqlite>, + /// Git root of the current working directory, if inside a git repo. + /// Resolves through worktrees to the main repo root. + pub git_root: Option<PathBuf>, + pub capabilities: AiCapabilities, +} + +/// Machine identity — computed once per session. +#[derive(Clone, Debug)] +pub(crate) struct ClientContext { + pub os: String, + pub shell: Option<String>, + pub distro: Option<String>, +} + +impl ClientContext { + pub(crate) fn detect() -> Self { + let os = detect_os(); + let shell = crate::commands::detect_shell(); + let distro = if os == "linux" { + Some(detect_linux_distribution()) + } else { + None + }; + Self { os, shell, distro } + } + + /// Serialize to the JSON format the API expects for the "context" field. + /// The `pwd` field is always dynamic (current working directory), so it's + /// computed fresh on each call if `send_cwd` is true. + pub(crate) fn to_json(&self, send_cwd: bool, last_command: Option<&str>) -> serde_json::Value { + let mut ctx = serde_json::json!({ + "os": self.os, + "shell": self.shell, + "pwd": if send_cwd { + std::env::current_dir().ok().map(|p| p.to_string_lossy().into_owned()) + } else { + None + }, + "last_command": last_command, + }); + + if let Some(ref distro) = self.distro { + ctx["distro"] = serde_json::json!(distro); + } + + ctx + } +} + +/// Move the `detect_os` function here since it's about client identity. +fn detect_os() -> String { + match std::env::consts::OS { + "macos" => "macos".to_string(), + "linux" => "linux".to_string(), + "windows" => "windows".to_string(), + other => format!("Other: {other}"), + } +} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index 2d86271d..6f431179 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -1,2 +1,6 @@ pub mod commands; -pub mod tui; +pub(crate) mod context; +pub(crate) mod permissions; +pub(crate) mod stream; +pub(crate) mod tools; +pub(crate) mod tui; diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs new file mode 100644 index 00000000..6b908b93 --- /dev/null +++ b/crates/atuin-ai/src/permissions/check.rs @@ -0,0 +1,74 @@ +use eyre::Result; + +use crate::{permissions::file::RuleFile, tools::PermissableToolCall}; + +pub(crate) struct PermissionRequest<'t> { + call: &'t (dyn PermissableToolCall + Send + Sync), +} + +impl<'t> PermissionRequest<'t> { + pub fn new(call: &'t (dyn PermissableToolCall + Send + Sync)) -> Self { + Self { call } + } +} + +pub(crate) enum PermissionResponse { + Allowed, + Denied, + Ask, +} + +pub(crate) struct PermissionChecker { + files: Vec<RuleFile>, +} + +impl PermissionChecker { + pub fn new(files: Vec<RuleFile>) -> Self { + Self { files } + } + + pub async fn check<'t>( + &self, + request: &'t PermissionRequest<'t>, + ) -> Result<PermissionResponse> { + // Files are in order from deepest to shallowest, so we can stop at the first match. + // Within a file, the priority is ask -> deny -> allow + // The first rule type that matches is the one that applies, even if a later rule would contradict it. + for file in &self.files { + for rule in &file.content.permissions.ask { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'ASK' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Ask); + } + } + + for rule in &file.content.permissions.deny { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'DENY' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Denied); + } + } + + for rule in &file.content.permissions.allow { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'ALLOW' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Allowed); + } + } + } + + Ok(PermissionResponse::Ask) + } +} diff --git a/crates/atuin-ai/src/permissions/file.rs b/crates/atuin-ai/src/permissions/file.rs new file mode 100644 index 00000000..c973f55b --- /dev/null +++ b/crates/atuin-ai/src/permissions/file.rs @@ -0,0 +1,26 @@ +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; + +use crate::permissions::rule::Rule; + +#[derive(Debug, Clone)] +pub(crate) struct RuleFile { + pub path: PathBuf, + pub content: RuleFileContent, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct RuleFileContent { + pub permissions: RuleFilePermissions, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct RuleFilePermissions { + #[serde(default)] + pub allow: Vec<Rule>, + #[serde(default)] + pub deny: Vec<Rule>, + #[serde(default)] + pub ask: Vec<Rule>, +} diff --git a/crates/atuin-ai/src/permissions/mod.rs b/crates/atuin-ai/src/permissions/mod.rs new file mode 100644 index 00000000..fce64a51 --- /dev/null +++ b/crates/atuin-ai/src/permissions/mod.rs @@ -0,0 +1,7 @@ +pub(crate) mod check; +pub(crate) mod file; +pub(crate) mod resolver; +pub(crate) mod rule; +pub(crate) mod shell; +pub(crate) mod walker; +pub(crate) mod writer; diff --git a/crates/atuin-ai/src/permissions/resolver.rs b/crates/atuin-ai/src/permissions/resolver.rs new file mode 100644 index 00000000..dc4f83bf --- /dev/null +++ b/crates/atuin-ai/src/permissions/resolver.rs @@ -0,0 +1,31 @@ +use std::path::PathBuf; + +use eyre::Result; + +use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; +use crate::permissions::walker::PermissionWalker; +use crate::permissions::writer; +use crate::tools::ClientToolCall; + +/// Resolves permissions for client tool calls by walking the filesystem to find permission files, +pub(crate) struct PermissionResolver { + checker: PermissionChecker, +} + +impl PermissionResolver { + /// Create a new resolver that walks from `working_dir` to root for project + /// permissions, and also checks the global permissions file. + pub async fn new(working_dir: PathBuf) -> Result<Self> { + let global_file = writer::global_permissions_path(); + let mut walker = PermissionWalker::new(working_dir, Some(global_file)); + walker.walk().await?; + let checker = PermissionChecker::new(walker.rules().to_owned()); + Ok(Self { checker }) + } + + /// Check whether `tool` is allowed, denied, or needs user confirmation. + pub async fn check(&self, tool: &ClientToolCall) -> Result<PermissionResponse> { + let request = PermissionRequest::new(tool); + self.checker.check(&request).await + } +} diff --git a/crates/atuin-ai/src/permissions/rule.rs b/crates/atuin-ai/src/permissions/rule.rs new file mode 100644 index 00000000..8fa3fa4a --- /dev/null +++ b/crates/atuin-ai/src/permissions/rule.rs @@ -0,0 +1,106 @@ +use std::sync::OnceLock; + +use regex::Regex; +use serde::{Deserialize, Serialize}; + +static RULE_RE: OnceLock<Regex> = OnceLock::new(); + +#[derive(Debug, thiserror::Error)] +pub(crate) enum RuleError { + #[error("invalid rule format: {0}")] + InvalidRule(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct Rule { + pub tool: String, + pub scope: Option<String>, +} + +impl std::fmt::Display for Rule { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.scope.as_ref() { + Some(scope) => write!(f, "{}({})", self.tool, scope), + None => write!(f, "{}", self.tool), + } + } +} + +impl Serialize for Rule { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for Rule { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Self::try_from(s.as_str()).map_err(serde::de::Error::custom) + } +} +impl TryFrom<&str> for Rule { + type Error = RuleError; + + fn try_from(value: &str) -> Result<Self, Self::Error> { + let value = value.trim(); + let re = RULE_RE.get_or_init(|| Regex::new(r"^(\w+)(?:\((.*)\))?$").unwrap()); + let caps = re + .captures(value) + .ok_or(RuleError::InvalidRule(value.to_string()))?; + let tool = caps.get(1).unwrap().as_str().to_string(); + let scope = caps.get(2).map(|m| m.as_str().to_string()); + Ok(Rule { tool, scope }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rule_try_from() { + assert_eq!( + Rule::try_from("Read").unwrap(), + Rule { + tool: "Read".to_string(), + scope: None + } + ); + assert_eq!( + Rule::try_from("Read(*)").unwrap(), + Rule { + tool: "Read".to_string(), + scope: Some("*".to_string()) + } + ); + assert_eq!( + Rule::try_from("Write(*.md)").unwrap(), + Rule { + tool: "Write".to_string(), + scope: Some("*.md".to_string()) + } + ); + assert_eq!( + Rule::try_from("Shell(git commit *)").unwrap(), + Rule { + tool: "Shell".to_string(), + scope: Some("git commit *".to_string()) + } + ); + assert_eq!( + Rule::try_from("Shell(echo ())").unwrap(), + Rule { + tool: "Shell".to_string(), + scope: Some("echo ()".to_string()) + } + ); + assert!(Rule::try_from("Shell(git commit *").is_err()); + assert!(Rule::try_from("Shell(git commit *)!").is_err()); + } +} diff --git a/crates/atuin-ai/src/permissions/shell.rs b/crates/atuin-ai/src/permissions/shell.rs new file mode 100644 index 00000000..7a2eee2e --- /dev/null +++ b/crates/atuin-ai/src/permissions/shell.rs @@ -0,0 +1,1297 @@ +/// Extracted command info from a shell command string. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ShellCommand { + /// The command name (first word), e.g. "git" + pub name: String, + /// The full invocation including arguments, e.g. "git commit -m msg" + pub full: String, +} + +/// A parsed shell command with all subcommands extracted. +#[derive(Debug)] +pub(crate) struct ParsedShellCommand { + pub subcommands: Vec<ShellCommand>, +} + +/// Supported shell families for parsing. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ShellKind { + /// POSIX sh, bash, zsh — all share similar syntax + Posix, + /// fish shell + Fish, + /// nushell or unknown — fallback to word-level extraction + Other, +} + +impl ShellKind { + pub(crate) fn from_shell_name(name: &str) -> Self { + match name { + "bash" | "sh" | "zsh" | "dash" | "ksh" => Self::Posix, + "fish" => Self::Fish, + _ => Self::Other, + } + } +} + +/// Parse a shell command string and extract all subcommands. +pub(crate) fn parse_shell_command(code: &str, shell: ShellKind) -> ParsedShellCommand { + #[cfg(feature = "tree-sitter")] + match shell { + ShellKind::Posix => ts::parse_posix(code), + ShellKind::Fish => ts::parse_fish(code), + ShellKind::Other => parse_fallback(code), + } + + #[cfg(not(feature = "tree-sitter"))] + { + let _ = shell; + parse_fallback(code) + } +} + +// ──────────────────────────────────────────────────────────────── +// Tree-sitter parsers (POSIX + Fish) +// Disabled on platforms where tree-sitter doesn't cross-compile +// (e.g. Windows); falls back to word-level extraction. +// ──────────────────────────────────────────────────────────────── + +#[cfg(feature = "tree-sitter")] +mod ts { + use super::{ParsedShellCommand, ShellCommand, parse_fallback}; + use tree_sitter_lib::{Parser, Tree}; + + fn bash_parser() -> Parser { + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_bash::LANGUAGE.into()) + .expect("failed to set bash language"); + parser + } + + pub(super) fn parse_posix(code: &str) -> ParsedShellCommand { + let mut parser = bash_parser(); + let Some(tree) = parser.parse(code, None) else { + return parse_fallback(code); + }; + + let mut commands = Vec::new(); + walk_bash_tree(&tree, code.as_bytes(), &mut commands); + ParsedShellCommand { + subcommands: commands, + } + } + + /// Leaf node kinds that never contain nested commands. + const BASH_LEAVES: &[&str] = &[ + "command_name", + "word", + "number", + "simple_expansion", + "expansion", + "arithmetic_expansion", + "ansi_c_string", + "special_variable_name", + "variable_name", + "file_descriptor", + "heredoc_body", + "heredoc_start", + "regex", + "heredoc_redirect", + ]; + + fn walk_bash_tree(tree: &Tree, source: &[u8], commands: &mut Vec<ShellCommand>) { + walk_bash_node(tree.root_node(), source, commands); + } + + fn walk_bash_node( + node: tree_sitter_lib::Node, + source: &[u8], + commands: &mut Vec<ShellCommand>, + ) { + match node.kind() { + "command" => { + if let Some(cmd) = extract_bash_command(node, source) { + commands.push(cmd); + } + // Descend into all non-leaf children to find nested commands + // (e.g. command_substitution inside a string inside a command) + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if !BASH_LEAVES.contains(&child.kind()) { + walk_bash_node(child, source, commands); + } + } + } + // Other nodes: descend into all children + _ => { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + walk_bash_node(child, source, commands); + } + } + } + } + + /// Extract the full command string and name from a bash `command` node. + fn extract_bash_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option<ShellCommand> { + // A `command` node has children like: + // variable_assignment* command_name argument* redirect* + // We want the command_name and all arguments (skipping assignments and redirects). + let mut name = None; + let mut name_start = None; + let mut arg_end = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "command_name" => { + name = child.utf8_text(source).ok().map(|s| s.to_string()); + name_start = Some(child.start_byte()); + } + "word" + | "string" + | "raw_string" + | "concatenation" + | "number" + | "simple_expansion" + | "expansion" + | "arithmetic_expansion" + | "ansi_c_string" + | "process_substitution" => { + arg_end = Some(child.end_byte()); + } + _ => {} + } + } + + let name = name?; + let full = if let (Some(start), Some(end)) = (name_start, arg_end) { + std::str::from_utf8(&source[start..end]).ok()?.to_string() + } else { + name.clone() + }; + + Some(ShellCommand { name, full }) + } + + // ──────────────────────────────────────────────────────────────── + // Fish parser + // ──────────────────────────────────────────────────────────────── + + fn fish_parser() -> Parser { + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_fish::language()) + .expect("failed to set fish language"); + parser + } + + pub(super) fn parse_fish(code: &str) -> ParsedShellCommand { + let mut parser = fish_parser(); + let Some(tree) = parser.parse(code, None) else { + return parse_fallback(code); + }; + + let mut commands = Vec::new(); + walk_fish_tree(&tree, code.as_bytes(), &mut commands); + ParsedShellCommand { + subcommands: commands, + } + } + + const FISH_COMPOUND: &[&str] = &[ + "conditional_execution", + "pipe", + "job", + "command_substitution", + "block", + "for_statement", + "while_statement", + "if_statement", + "switch_statement", + "function_definition", + "begin_statement", + "redirected_statement", + ]; + + fn walk_fish_tree(tree: &Tree, source: &[u8], commands: &mut Vec<ShellCommand>) { + walk_fish_node(tree.root_node(), source, commands); + } + + fn walk_fish_node( + node: tree_sitter_lib::Node, + source: &[u8], + commands: &mut Vec<ShellCommand>, + ) { + match node.kind() { + "command" => { + if let Some(cmd) = extract_fish_command(node, source) { + commands.push(cmd); + } + // Still descend into compound children (e.g. command_substitution inside a command) + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if FISH_COMPOUND.contains(&child.kind()) { + walk_fish_node(child, source, commands); + } + } + } + // Other nodes: descend into all children + _ => { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + walk_fish_node(child, source, commands); + } + } + } + } + + fn extract_fish_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option<ShellCommand> { + // In fish, a `command` node has: + // name (command_name or word) followed by arguments (word, string, etc.) + let mut name = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "command_name" | "word" => { + let text = child.utf8_text(source).ok()?.to_string(); + if name.is_none() { + name = Some(text); + } + } + "string" + | "concatenation" + | "command_substitution" + | "escape_sequence" + | "double_quote_string" + | "single_quote_string" => {} + _ => {} + } + } + + let name = name?; + // Get the full text of the command node + let full = node.utf8_text(source).ok()?.trim().to_string(); + + Some(ShellCommand { name, full }) + } +} // mod ts + +// ──────────────────────────────────────────────────────────────── +// Fallback (word-level extraction for nushell / unknown shells) +// ──────────────────────────────────────────────────────────────── + +fn parse_fallback(code: &str) -> ParsedShellCommand { + // Simple heuristic: split by &&, ||, ;, | and take the first word of each segment. + // This is intentionally simple — for unknown shells we can't do better. + let mut commands = Vec::new(); + let mut segment = String::new(); + let mut chars = code.chars().peekable(); + + while let Some(c) = chars.next() { + match c { + ';' => { + push_segment(&mut segment, &mut commands); + } + '|' => { + if chars.peek() == Some(&'|') { + chars.next(); + } + push_segment(&mut segment, &mut commands); + } + '&' if chars.peek() == Some(&'&') => { + chars.next(); + push_segment(&mut segment, &mut commands); + } + _ => segment.push(c), + } + } + push_segment(&mut segment, &mut commands); + + ParsedShellCommand { + subcommands: commands, + } +} + +fn push_segment(segment: &mut String, commands: &mut Vec<ShellCommand>) { + let trimmed = segment.trim(); + if !trimmed.is_empty() + && let Some(name) = trimmed.split_whitespace().next() + { + commands.push(ShellCommand { + name: name.to_string(), + full: trimmed.to_string(), + }); + } + segment.clear(); +} + +// ──────────────────────────────────────────────────────────────── +// Scope matching +// ──────────────────────────────────────────────────────────────── + +/// Check if any of the extracted subcommands match the given scope pattern. +/// +/// Matching semantics depend on where the `*` wildcard appears: +/// - `*` alone — matches everything +/// - `ls *` (space before `*`) — matches `ls` and `ls -a` but not `lsof` +/// - `git commit *` — matches `git commit -m "msg"` (word boundary) +/// - `ls*` (no space before `*`) — matches `lsof`, `ls`, `ls -a` (prefix/glob) +/// - `rm` (no wildcard) — matches exactly `rm` +/// - `git * amend` — matches `git commit amend` (middle wildcard matches zero+ words) +pub(crate) fn any_subcommand_matches(subcommands: &[ShellCommand], scope: &str) -> bool { + let scope = scope.trim(); + + if scope == "*" { + return true; + } + + if let Some(prefix) = scope.strip_suffix(" *") { + // Word-boundary matching: `ls *` matches `ls` and `ls -a` but not `lsof` + return subcommands.iter().any(|cmd| { + if prefix.is_empty() { + return true; + } + let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); + let prefix_words: Vec<&str> = prefix.split_whitespace().collect(); + cmd_words.len() >= prefix_words.len() + && cmd_words[..prefix_words.len()] == prefix_words[..] + }); + } + + if let Some(prefix) = scope.strip_suffix('*') { + // Prefix/glob matching: `ls*` matches `lsof`, `ls`, etc. + return subcommands.iter().any(|cmd| cmd.full.starts_with(prefix)); + } + + if scope.contains('*') { + // Middle wildcard: `git * amend` — each `*` matches zero or more words + return subcommands + .iter() + .any(|cmd| scope_matches_words(scope, cmd.full.split_whitespace().collect())); + } + + // No wildcard: word-boundary prefix match + let scope_words: Vec<&str> = scope.split_whitespace().collect(); + subcommands.iter().any(|cmd| { + let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); + cmd_words.len() >= scope_words.len() && cmd_words[..scope_words.len()] == scope_words[..] + }) +} + +/// Match a scope pattern containing `*` wildcards against a sequence of words. +/// Each `*` matches zero or more words. Consecutive `*` collapse into one. +fn scope_matches_words(scope: &str, words: Vec<&str>) -> bool { + let parts: Vec<&str> = scope.split('*').collect(); + if parts.len() == 1 { + // No wildcard (shouldn't reach here, but handle it) + let scope_words: Vec<&str> = scope.split_whitespace().collect(); + return words.len() >= scope_words.len() && words[..scope_words.len()] == scope_words[..]; + } + + // Each segment between * is a sequence of literal words that must appear in order. + // Walk through `words` consuming segments left to right. + let mut word_idx = 0; + + for (i, part) in parts.iter().enumerate() { + let segment_words: Vec<&str> = part.split_whitespace().collect(); + if segment_words.is_empty() { + continue; + } + + // Find the segment words starting from word_idx + if i == 0 { + // First segment must match at the start + if words.len() < segment_words.len() + || words[..segment_words.len()] != segment_words[..] + { + return false; + } + word_idx = segment_words.len(); + } else if i == parts.len() - 1 { + // Last segment must match at the end + if words.len() - word_idx < segment_words.len() { + return false; + } + let start = words.len() - segment_words.len(); + return words[start..] == segment_words[..]; + } else { + // Middle segment: find it anywhere after word_idx + let found = find_subslice(&words[word_idx..], &segment_words); + match found { + Some(idx) => word_idx += idx + segment_words.len(), + None => return false, + } + } + } + + true +} + +/// Find the first occurrence of `needle` as a contiguous subsequence in `haystack`. +fn find_subslice(haystack: &[&str], needle: &[&str]) -> Option<usize> { + if needle.is_empty() { + return Some(0); + } + if haystack.len() < needle.len() { + return None; + } + (0..=haystack.len() - needle.len()).find(|&i| haystack[i..i + needle.len()] == needle[..]) +} + +// ──────────────────────────────────────────────────────────────── +// Tests +// ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn names(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.name.as_str()).collect() + } + + fn fulls(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.full.as_str()).collect() + } + + #[test] + fn simple_command() { + let result = parse_shell_command("ls -la /tmp", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["ls"]); + assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); + } + + #[test] + fn pipeline() { + let result = parse_shell_command("cat file.txt | grep foo | wc -l", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["cat", "grep", "wc"]); + } + + #[test] + fn command_chaining() { + let result = parse_shell_command("git add . && git commit -m 'hi'", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["git", "git"]); + assert_eq!( + fulls(&result.subcommands), + vec!["git add .", "git commit -m 'hi'"] + ); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn command_substitution() { + let result = parse_shell_command("echo $(git rev-parse HEAD)", ShellKind::Posix); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn backtick_substitution() { + let result = parse_shell_command("echo `date`", ShellKind::Posix); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"date"), "should contain date: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn subshell() { + let result = parse_shell_command("(cd /tmp && ls)", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["cd", "ls"]); + } + + #[test] + fn semicolon_separated() { + let result = parse_shell_command("echo hello; echo world", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["echo", "echo"]); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn for_loop() { + let result = parse_shell_command("for f in *.txt; do cat $f; done", ShellKind::Posix); + assert!(names(&result.subcommands).contains(&"cat")); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn if_statement() { + let result = parse_shell_command( + "if [ -f foo ]; then cat foo; else echo nope; fi", + ShellKind::Posix, + ); + let n = names(&result.subcommands); + assert!(n.contains(&"cat"), "should contain cat: {n:?}"); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + } + + #[test] + fn scope_matching_wildcard() { + let commands = vec![ + ShellCommand { + name: "git".into(), + full: "git commit -m msg".into(), + }, + ShellCommand { + name: "npm".into(), + full: "npm test".into(), + }, + ]; + assert!(any_subcommand_matches(&commands, "*")); + } + + #[test] + fn scope_matching_prefix() { + let commands = vec![ + ShellCommand { + name: "git".into(), + full: "git commit -m msg".into(), + }, + ShellCommand { + name: "npm".into(), + full: "npm test".into(), + }, + ]; + assert!(any_subcommand_matches(&commands, "git commit *")); + assert!(any_subcommand_matches(&commands, "git commit")); + assert!(!any_subcommand_matches(&commands, "git push *")); + assert!(!any_subcommand_matches(&commands, "git push")); + assert!(any_subcommand_matches(&commands, "npm *")); + } + + #[test] + fn scope_word_boundary_vs_glob() { + let commands = vec![ + ShellCommand { + name: "ls".into(), + full: "ls -a".into(), + }, + ShellCommand { + name: "lsof".into(), + full: "lsof -i :3000".into(), + }, + ]; + // `ls *` — word boundary: matches `ls -a` but not `lsof` + assert!(any_subcommand_matches(&commands, "ls *")); + assert!(!any_subcommand_matches(&commands, "cat *")); + assert!(any_subcommand_matches(&commands, "lsof *")); + + // `ls*` — glob/prefix: matches both `ls -a` and `lsof` + assert!(any_subcommand_matches(&commands, "ls*")); + } + + #[test] + fn scope_exact_match() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + assert!(any_subcommand_matches(&commands, "ls")); + assert!(!any_subcommand_matches(&commands, "cat")); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn nested_substitution() { + let result = parse_shell_command( + "echo \"Result: $(git log --oneline | head -1)\"", + ShellKind::Posix, + ); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + assert!(n.contains(&"head"), "should contain head: {n:?}"); + } + + #[test] + fn fallback_splits_correctly() { + let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); + let n = names(&result.subcommands); + assert!(n.contains(&"ls"), "should contain ls: {n:?}"); + assert!(n.contains(&"cat"), "should contain cat: {n:?}"); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + } + + #[test] + fn fish_simple_command() { + let result = parse_shell_command("ls -la /tmp", ShellKind::Fish); + assert_eq!(names(&result.subcommands), vec!["ls"]); + } + + #[test] + fn fish_conditional() { + let result = parse_shell_command("git add .; and git commit -m hi", ShellKind::Fish); + let n = names(&result.subcommands); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn fish_command_substitution() { + let result = parse_shell_command("echo (date)", ShellKind::Fish); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"date"), "should contain date: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn variable_assignment_excluded() { + let result = parse_shell_command("FOO=bar ls -la /tmp", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["ls"]); + assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn variable_assignment_multiple() { + let result = parse_shell_command("A=1 B=2 git status", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["git"]); + assert_eq!(fulls(&result.subcommands), vec!["git status"]); + } + + #[test] + fn fallback_double_ampersand_and_pipe_pipe() { + let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); + assert_eq!(names(&result.subcommands), vec!["ls", "cat", "echo"]); + assert_eq!( + fulls(&result.subcommands), + vec!["ls", "cat foo", "echo fail"] + ); + } + + #[test] + fn fallback_pipe_without_double() { + let result = parse_shell_command("ls | grep foo", ShellKind::Other); + assert_eq!(names(&result.subcommands), vec!["ls", "grep"]); + assert_eq!(fulls(&result.subcommands), vec!["ls", "grep foo"]); + } + + #[test] + fn scope_middle_wildcard() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit -m amend".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * amend")); + assert!(any_subcommand_matches(&commands, "git commit * amend")); + assert!(!any_subcommand_matches(&commands, "git push * amend")); + } + + #[test] + fn scope_middle_wildcard_zero_words() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + // `*` matches zero words, so `git * commit` should match `git commit` + assert!(any_subcommand_matches(&commands, "git * commit")); + } + + #[test] + fn scope_leading_wildcard() { + let commands = vec![ShellCommand { + name: "docker".into(), + full: "docker run --rm alpine".into(), + }]; + assert!(any_subcommand_matches(&commands, "* alpine")); + assert!(!any_subcommand_matches(&commands, "* ubuntu")); + } + + #[test] + fn scope_multiple_wildcards() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git rebase -i HEAD~5".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * -i * HEAD~5")); + assert!(!any_subcommand_matches(&commands, "git * -i * HEAD~10")); + } +} + +#[cfg(all(test, feature = "tree-sitter"))] +mod adversarial { + use super::*; + + fn cmd_names(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.name.as_str()).collect() + } + + /// Helper: assert that parsing POSIX extracts all expected command names + fn assert_posix(code: &str, expected: &[&str]) { + let result = parse_shell_command(code, ShellKind::Posix); + let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); + got.sort(); + let mut want: Vec<&str> = expected.to_vec(); + want.sort(); + assert_eq!( + got, want, + "POSIX parse of {:?}:\n got: {:?}\n want: {:?}", + code, got, want + ); + } + + fn assert_fish(code: &str, expected: &[&str]) { + let result = parse_shell_command(code, ShellKind::Fish); + let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); + got.sort(); + let mut want: Vec<&str> = expected.to_vec(); + want.sort(); + assert_eq!( + got, want, + "Fish parse of {:?}:\n got: {:?}\n want: {:?}", + code, got, want + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 1: Basic compounds + // ──────────────────────────────────────────────────────────── + + #[test] + fn a01_triple_chain() { + assert_posix("a && b && c", &["a", "b", "c"]); + } + + #[test] + fn a02_or_chain() { + assert_posix("a || b || c", &["a", "b", "c"]); + } + + #[test] + fn a03_mixed_chain() { + assert_posix("a && b || c && d", &["a", "b", "c", "d"]); + } + + #[test] + fn a04_long_pipeline() { + assert_posix( + "cat foo | grep bar | awk '{print $1}' | sort | uniq -c", + &["cat", "grep", "awk", "sort", "uniq"], + ); + } + + #[test] + fn a05_semicolons() { + assert_posix("a; b; c; d", &["a", "b", "c", "d"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 2: Nested substitution + // ──────────────────────────────────────────────────────────── + + #[test] + fn a06_nested_dollar() { + assert_posix( + "echo $(basename $(dirname /foo/bar))", + &["echo", "basename", "dirname"], + ); + } + + #[test] + fn a07_deeply_nested() { + // 4 nested echos, all should be extracted + assert_posix( + "echo $(echo $(echo $(echo deep)))", + &["echo", "echo", "echo", "echo"], + ); + } + + #[test] + fn a08_backtick_in_echo() { + assert_posix("echo `hostname`", &["echo", "hostname"]); + } + + #[test] + fn a09_mixed_substitutions() { + assert_posix("echo $(date) `uname`", &["echo", "date", "uname"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 3: Subshells and grouping + // ──────────────────────────────────────────────────────────── + + #[test] + fn a10_subshell_chain() { + assert_posix("(cd /tmp && ls -la)", &["cd", "ls"]); + } + + #[test] + fn a11_nested_subshells() { + assert_posix("( (inner_cmd) )", &["inner_cmd"]); + } + + #[test] + fn a12_brace_group() { + assert_posix("{ cd /tmp; ls; }", &["cd", "ls"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 4: Variable assignments + // ──────────────────────────────────────────────────────────── + + #[test] + fn a13_single_var_assignment() { + let result = parse_shell_command("FOO=bar ls", ShellKind::Posix); + assert_eq!(cmd_names(&result.subcommands), &["ls"]); + assert_eq!(result.subcommands[0].full, "ls"); + } + + #[test] + fn a14_multiple_var_assignments() { + let result = parse_shell_command("A=1 B=2 C=3 git status", ShellKind::Posix); + assert_eq!(cmd_names(&result.subcommands), &["git"]); + assert_eq!(result.subcommands[0].full, "git status"); + } + + #[test] + fn a15_var_assignment_no_command() { + // Variable assignment only — no command to extract + assert_posix("FOO=bar", &[]); + } + + #[test] + fn a16_var_assignment_in_pipeline() { + assert_posix("FOO=bar ls | BAZ=qux grep foo", &["ls", "grep"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 5: Control flow + // ──────────────────────────────────────────────────────────── + + #[test] + fn a17_if_then_else() { + assert_posix( + "if [ -f foo ]; then cat foo; else echo missing; fi", + &["cat", "echo"], + ); + } + + #[test] + fn a18_elif_chain() { + // Two cat commands (then + elif branch), one echo (else branch). + // [ is part of the test_condition, not extracted as a command. + assert_posix( + "if [ -f a ]; then cat a; elif [ -f b ]; then cat b; else echo none; fi", + &["cat", "cat", "echo"], + ); + } + + #[test] + fn a19_for_loop() { + assert_posix("for f in *.txt; do cat \"$f\"; done", &["cat"]); + } + + #[test] + fn a20_while_loop() { + // read in the condition is a real command + assert_posix( + "while read line; do echo \"$line\"; done < input.txt", + &["echo", "read"], + ); + } + + #[test] + fn f07_if_statement() { + // test in if-condition is a real command + assert_fish( + "if test -f foo; cat foo; else; echo missing; end", + &["cat", "echo", "test"], + ); + } + + #[test] + fn f09_while_loop() { + // `true` in the condition is a real command + assert_fish( + "while true; echo tick; sleep 1; end", + &["echo", "sleep", "true"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 6: Redirections + // ──────────────────────────────────────────────────────────── + + #[test] + fn a23_redirect_out() { + assert_posix("ls > output.txt", &["ls"]); + } + + #[test] + fn a24_redirect_append() { + assert_posix("ls >> output.txt 2>&1", &["ls"]); + } + + #[test] + fn a25_here_string() { + assert_posix("grep foo <<< \"hello world\"", &["grep"]); + } + + #[test] + fn a26_redirect_in_pipeline() { + assert_posix("cat < input.txt | sort | uniq", &["cat", "sort", "uniq"]); + } + + #[test] + fn a27_process_substitution() { + assert_posix( + "diff <(sort a.txt) <(sort b.txt)", + &["diff", "sort", "sort"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 7: Function definitions + // ──────────────────────────────────────────────────────────── + + #[test] + fn a28_function_def() { + assert_posix("foo() { echo hello; }", &["echo"]); + } + + #[test] + fn a29_function_with_subshell() { + assert_posix( + "build() { cargo build && cargo test; }", + &["cargo", "cargo"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 8: Edge cases — empties, weird quoting + // ──────────────────────────────────────────────────────────── + + #[test] + fn a30_empty_string() { + let result = parse_shell_command("", ShellKind::Posix); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn a31_whitespace_only() { + let result = parse_shell_command(" \t \n ", ShellKind::Posix); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn a32_single_command_no_args() { + assert_posix("ls", &["ls"]); + } + + #[test] + fn a33_command_with_single_quotes() { + assert_posix("echo 'hello world'", &["echo"]); + } + + #[test] + fn a34_command_with_double_quotes() { + assert_posix("echo \"hello world\"", &["echo"]); + } + + #[test] + fn a35_escaped_spaces() { + // ls\ -la is a single word in bash, not "ls" with flag "-la" + assert_posix("ls\\ -la", &["ls\\ -la"]); + } + + #[test] + fn a36_command_with_dollar_var() { + assert_posix("echo $HOME/.bashrc", &["echo"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 9: Background jobs and coproc + // ──────────────────────────────────────────────────────────── + + #[test] + fn a37_background_job() { + assert_posix("sleep 10 &", &["sleep"]); + } + + #[test] + fn a38_background_chain() { + assert_posix("sleep 10 && echo done &", &["sleep", "echo"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 10: Real-world complex commands + // ──────────────────────────────────────────────────────────── + + #[test] + fn a39_docker_build_and_run() { + assert_posix( + "docker build -t app . && docker run --rm app npm test", + &["docker", "docker"], + ); + } + + #[test] + fn a40_git_rebase_interactive() { + assert_posix( + "GIT_SEQUENCE_EDITOR=\"sed -i 's/pick/reword/'\" git rebase -i HEAD~5", + &["git"], + ); + } + + #[test] + fn a41_find_with_exec() { + // tree-sitter-bash does not parse -exec body as commands — only `find` is extracted. + // This is a known limitation: args to -exec/-execdir are opaque to the parser. + assert_posix("find . -name '*.rs' -exec grep -l 'unsafe' {} +", &["find"]); + } + + #[test] + fn a42_curl_pipe_sh() { + assert_posix( + "curl -sSL https://example.com/install.sh | bash", + &["curl", "bash"], + ); + } + + #[test] + fn a43_xargs() { + assert_posix("find . -name '*.tmp' | xargs rm -f", &["find", "xargs"]); + } + + #[test] + fn a44_npm_script_chain() { + assert_posix( + "npm run build && npm run test && npm run lint", + &["npm", "npm", "npm"], + ); + } + + #[test] + fn a45_make_with_redirect() { + assert_posix( + "make -j$(nproc) 2>&1 | tee build.log", + &["make", "nproc", "tee"], + ); + } + + #[test] + fn a46_sudo_chain() { + assert_posix("sudo apt update && sudo apt upgrade -y", &["sudo", "sudo"]); + } + + #[test] + fn a47_here_doc_with_subcommand() { + assert_posix("cat <<EOF\nhello $(whoami)\nEOF", &["cat", "whoami"]); + } + + #[test] + fn a48_eval_with_command() { + assert_posix("eval \"echo hello\"", &["eval"]); + } + + #[test] + fn a49_exec_replace() { + assert_posix("exec ls", &["exec"]); + } + + #[test] + fn a50_source_script() { + assert_posix("source ~/.bashrc", &["source"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 11: Fish-specific tests + // ──────────────────────────────────────────────────────────── + + #[test] + fn f01_simple() { + assert_fish("ls -la /tmp", &["ls"]); + } + + #[test] + fn f02_pipe() { + assert_fish("cat foo | grep bar | sort", &["cat", "grep", "sort"]); + } + + #[test] + fn f03_and() { + assert_fish("git add .; and git commit -m hi", &["git", "git"]); + } + + #[test] + fn f04_or() { + assert_fish("test -f foo; or echo missing", &["test", "echo"]); + } + + #[test] + fn f04_not() { + // fish parses `not test -f foo` — `not` is a modifier, `test` is the command + assert_fish("not test -f foo", &["test"]); + } + + #[test] + fn f05_command_substitution() { + assert_fish("echo (date)", &["echo", "date"]); + } + + #[test] + fn f06_nested_substitution() { + assert_fish( + "echo (basename (dirname /foo/bar))", + &["echo", "basename", "dirname"], + ); + } + + #[test] + fn f06_begin_end() { + assert_fish("begin; ls; echo done; end", &["ls", "echo"]); + } + + #[test] + fn f10_switch() { + // Two echo commands, one per case branch + assert_fish( + "switch $x; case foo; echo foo; case bar; echo bar; end", + &["echo", "echo"], + ); + } + + #[test] + fn f08_for_loop() { + assert_fish("for f in *.txt; cat $f; end", &["cat"]); + } + + #[test] + fn a21_case_statement() { + // Two echo branches + assert_posix( + "case $x in foo) echo foo;; bar) echo bar;; esac", + &["echo", "echo"], + ); + } + + #[test] + fn f11_function_def() { + assert_fish("function greet; echo hello $argv; end", &["echo"]); + } + + #[test] + fn f12_redirect() { + assert_fish("ls > output.txt", &["ls"]); + } + + #[test] + fn f13_redirect_append() { + assert_fish("ls >> output.txt", &["ls"]); + } + + #[test] + fn f14_here_string() { + assert_fish("grep foo <<< \"hello\"", &["grep"]); + } + + #[test] + fn f15_curl_pipe() { + assert_fish( + "curl -sSL https://example.com/install.sh | bash", + &["curl", "bash"], + ); + } + + #[test] + fn f16_double_ampersand() { + assert_fish("git add . && git commit -m hi", &["git", "git"]); + } + + #[test] + fn f17_double_pipe() { + assert_fish("test -f foo || echo missing", &["test", "echo"]); + } + + #[test] + fn f18_empty() { + let result = parse_shell_command("", ShellKind::Fish); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn f19_whitespace() { + let result = parse_shell_command(" ", ShellKind::Fish); + assert!(result.subcommands.is_empty()); + } + + // ──────────────────────────────────────────────────────────── + // Level 12: Scope matching adversarial + // ──────────────────────────────────────────────────────────── + + #[test] + fn s01_empty_scope() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // Empty scope matches everything (nothing to constrain) + assert!(any_subcommand_matches(&commands, "")); + } + + #[test] + fn s03_only_wildcard_space_star() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // " *" with empty prefix = match anything + assert!(any_subcommand_matches(&commands, " *")); + } + + #[test] + fn s04_glob_matches_empty() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // `ls*` matches `ls` (prefix match with nothing after) + assert!(any_subcommand_matches(&commands, "ls*")); + } + + #[test] + fn s05_middle_wildcard_empty_match() { + // `git * commit` matches `git commit` (* = zero words) + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * commit")); + } + + #[test] + fn s06_consecutive_wildcards() { + // `git ** commit` should behave like `git * commit` + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + assert!(any_subcommand_matches(&commands, "git ** commit")); + } + + #[test] + fn s07_case_sensitivity() { + let commands = vec![ShellCommand { + name: "LS".into(), + full: "LS -la".into(), + }]; + assert!(!any_subcommand_matches(&commands, "ls")); + assert!(any_subcommand_matches(&commands, "LS")); + } + + #[test] + fn s08_multi_word_exact_no_subcommand() { + // `git commit` should not match `git commit-amend` + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit-amend".into(), + }]; + assert!(!any_subcommand_matches(&commands, "git commit")); + } +} diff --git a/crates/atuin-ai/src/permissions/walker.rs b/crates/atuin-ai/src/permissions/walker.rs new file mode 100644 index 00000000..3bda01c3 --- /dev/null +++ b/crates/atuin-ai/src/permissions/walker.rs @@ -0,0 +1,121 @@ +use std::path::{Path, PathBuf}; + +use eyre::Result; +use tokio::task::JoinSet; + +use crate::permissions::file::{RuleFile, RuleFileContent}; + +#[derive(Debug)] +struct FoundRuleFile { + depth: usize, + file: RuleFile, +} + +pub(crate) struct PermissionWalker { + start: PathBuf, + /// Direct path to the global permissions file (e.g. `~/.config/atuin/permissions.ai.toml`). + global_permissions_file: Option<PathBuf>, + rules: Vec<RuleFile>, +} + +impl PermissionWalker { + pub fn new(start: PathBuf, global_permissions_file: Option<PathBuf>) -> Self { + Self { + start, + global_permissions_file, + rules: Vec::new(), + } + } + + pub fn rules(&self) -> &[RuleFile] { + &self.rules + } + + /// Walks the filesystem starting from the start path and collecting permission files along the way. + /// Walks to the root, then checks the global permissions file, if any. + pub async fn walk(&mut self) -> Result<()> { + let dirs_to_check: Vec<PathBuf> = self.start.ancestors().map(PathBuf::from).collect(); + let dir_count = dirs_to_check.len(); + + let mut set: JoinSet<Result<Option<FoundRuleFile>>> = JoinSet::new(); + + for (index, path) in dirs_to_check.into_iter().enumerate() { + set.spawn(async move { + match check_dir_for_permissions(&path).await { + Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { + depth: index, + file: rule_file, + })), + Ok(None) => Ok(None), + Err(e) => Err(e), + } + }); + } + + // Check the global file separately (it's a direct file path, not a dir/.atuin/ pattern) + if let Some(global_path) = self.global_permissions_file.clone() { + let depth = dir_count; // sorts after all directory-walk entries + set.spawn(async move { + match load_permissions_file(&global_path).await { + Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { + depth, + file: rule_file, + })), + Ok(None) => Ok(None), + Err(e) => Err(e), + } + }); + } + + let capacity = dir_count + usize::from(self.global_permissions_file.is_some()); + let mut found = Vec::with_capacity(capacity); + while let Some(result) = set.join_next().await { + let result = result?; // JoinErrors result in failure to walk the filesystem + + match result { + Ok(Some(FoundRuleFile { depth, file })) => { + found.push((depth, file)); + } + Ok(None) => { + continue; + } + Err(e) => { + tracing::error!( + "Error while walking filesystem for permissions check; skipping: {}", + e + ); + continue; + } + } + } + // join_next() returns in order of completion, not order of spawn + found.sort_by_key(|(depth, _)| *depth); + self.rules = found.into_iter().map(|(_, file)| file).collect(); + + Ok(()) + } +} + +/// Checks a directory for `.atuin/permissions.ai.toml` and returns the RuleFile if found. +async fn check_dir_for_permissions(path: &Path) -> Result<Option<RuleFile>> { + let file_path = path.join(".atuin").join("permissions.ai.toml"); + load_permissions_file(&file_path).await +} + +/// Load a permissions file from an exact path. Returns None if the file doesn't exist. +async fn load_permissions_file(file_path: &Path) -> Result<Option<RuleFile>> { + if !tokio::fs::try_exists(file_path).await? { + return Ok(None); + } + + let raw = tokio::fs::read_to_string(file_path).await?; + let content: RuleFileContent = toml::from_str(&raw)?; + + // Use the file's parent as the rule file path (for logging/debugging) + let path = file_path + .parent() + .map(Path::to_path_buf) + .unwrap_or_else(|| file_path.to_path_buf()); + + Ok(Some(RuleFile { path, content })) +} diff --git a/crates/atuin-ai/src/permissions/writer.rs b/crates/atuin-ai/src/permissions/writer.rs new file mode 100644 index 00000000..b2bd9482 --- /dev/null +++ b/crates/atuin-ai/src/permissions/writer.rs @@ -0,0 +1,198 @@ +use std::path::Path; + +use eyre::Result; + +use crate::permissions::rule::Rule; + +/// Whether a rule should be added to the allow or deny list. +#[allow(dead_code)] +pub(crate) enum RuleDisposition { + Allow, + Deny, +} + +/// Write a permission rule to a `permissions.ai.toml` file. +/// +/// If the file doesn't exist it is created (along with parent directories). +/// If it does exist, `toml_edit` is used to append the rule while preserving +/// existing formatting and comments. +/// +/// **Not concurrent-safe.** The read-modify-write cycle is not atomic. In the +/// current UI this is fine — the Select widget serializes permission decisions — +/// but callers should not invoke this concurrently for the same file. +pub(crate) async fn write_rule( + file_path: &Path, + rule: &Rule, + disposition: RuleDisposition, +) -> Result<()> { + let content = if tokio::fs::try_exists(file_path).await.unwrap_or(false) { + tokio::fs::read_to_string(file_path).await? + } else { + String::new() + }; + + let mut doc: toml_edit::DocumentMut = content.parse()?; + + // Ensure [permissions] table exists + if !doc.contains_key("permissions") { + doc["permissions"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + + let key = match disposition { + RuleDisposition::Allow => "allow", + RuleDisposition::Deny => "deny", + }; + + // Use as_table_like_mut so both standard and inline tables work. + let permissions = doc["permissions"] + .as_table_like_mut() + .ok_or_else(|| eyre::eyre!("[permissions] is not a table"))?; + + // Get or create the array + if !permissions.contains_key(key) { + permissions.insert(key, toml_edit::Item::Value(toml_edit::Array::new().into())); + } + + let array = permissions + .get_mut(key) + .and_then(|item| item.as_value_mut()) + .and_then(|v| v.as_array_mut()) + .ok_or_else(|| eyre::eyre!("permissions.{key} is not an array"))?; + + // Don't add duplicates + let rule_str = rule.to_string(); + let already_present = array.iter().any(|v| v.as_str() == Some(&rule_str)); + if !already_present { + array.push(rule_str); + } + + // Write back, creating parent directories as needed + if let Some(parent) = file_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + tokio::fs::write(file_path, doc.to_string()).await?; + + Ok(()) +} + +/// Build the path to the project-level permissions file. +/// `project_root` is typically a git root or the current working directory. +pub(crate) fn project_permissions_path(project_root: &Path) -> std::path::PathBuf { + project_root.join(".atuin").join("permissions.ai.toml") +} + +/// Build the path to the global permissions file (sibling of atuin config). +pub(crate) fn global_permissions_path() -> std::path::PathBuf { + atuin_common::utils::config_dir().join("permissions.ai.toml") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn creates_new_file_with_allow_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains("[permissions]")); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn appends_to_existing_file() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let existing = r#"# My permissions +[permissions] +allow = ["Read"] +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + // Comment preserved + assert!(content.contains("# My permissions")); + // Both rules present + assert!(content.contains(r#""Read""#)); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn does_not_duplicate_existing_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let existing = r#"[permissions] +allow = ["AtuinHistory"] +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + // Should appear exactly once + assert_eq!(content.matches("AtuinHistory").count(), 1); + } + + #[tokio::test] + async fn handles_inline_table_permissions() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + // Inline table style — as_table_mut() would return None for this + let existing = r#"permissions = { allow = ["Read"] } +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains(r#""Read""#)); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn writes_deny_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let rule = Rule { + tool: "Shell".to_string(), + scope: None, + }; + + write_rule(&file, &rule, RuleDisposition::Deny) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains("deny")); + assert!(content.contains(r#""Shell""#)); + } +} diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs new file mode 100644 index 00000000..4673f2cd --- /dev/null +++ b/crates/atuin-ai/src/stream.rs @@ -0,0 +1,372 @@ +// ─────────────────────────────────────────────────────────────────── +// SSE streaming +// ─────────────────────────────────────────────────────────────────── + +use std::sync::mpsc; + +use atuin_client::settings::AiCapabilities; +use atuin_common::tls::ensure_crypto_provider; + +use eventsource_stream::Eventsource; +use eye_declare::Handle; +use eyre::{Context, Result}; +use futures::StreamExt; +use reqwest::Url; + +use crate::{ + context::{AppContext, ClientContext}, + tools::ClientToolCall, + tui::{Session, events::AiTuiEvent}, +}; + +/// Frames that alter the stream lifecycle — terminal or state-changing. +#[derive(Debug, Clone)] +pub(crate) enum StreamControl { + Done { session_id: String }, + Error(String), + StatusChanged(String), +} + +/// Frames that carry conversation content — they mutate the event log. +#[derive(Debug, Clone)] +pub(crate) enum StreamContent { + TextChunk(String), + ToolCall { + id: String, + name: String, + input: serde_json::Value, + }, + ToolResult { + tool_use_id: String, + content: String, + is_error: bool, + }, +} + +/// A frame from the SSE stream, classified as control or content. +#[derive(Debug, Clone)] +pub(crate) enum StreamFrame { + Content(StreamContent), + Control(StreamControl), +} + +/// Per-turn request payload for the chat API. +pub(crate) struct ChatRequest { + pub messages: Vec<serde_json::Value>, + pub session_id: Option<String>, + pub capabilities: Vec<String>, +} + +impl ChatRequest { + pub(crate) fn new( + messages: Vec<serde_json::Value>, + session_id: Option<String>, + capabilities: &AiCapabilities, + ) -> Self { + let mut caps = vec![]; + if capabilities.enable_history_search.unwrap_or(true) { + caps.push("client_v1_atuin_history".to_string()); + } + if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { + caps.extend( + extra + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()), + ); + } + + Self { + messages, + session_id, + capabilities: caps, + } + } +} + +fn create_chat_stream( + hub_address: String, + token: String, + request: ChatRequest, + client_ctx: ClientContext, + send_cwd: bool, + last_command: Option<String>, +) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<StreamFrame>> + Send>> { + Box::pin(async_stream::stream! { + ensure_crypto_provider(); + let endpoint = match hub_url(&hub_address, "/api/cli/chat") { + Ok(url) => url, + Err(e) => { + yield Err(e); + return; + } + }; + + tracing::debug!("Sending SSE request to {endpoint}"); + + let context = client_ctx.to_json(send_cwd, last_command.as_deref()); + + let mut request_body = serde_json::json!({ + "messages": request.messages, + "context": context, + "capabilities": request.capabilities, + }); + + if let Some(ref sid) = request.session_id { + tracing::trace!("Including session_id in request: {sid}"); + request_body["session_id"] = serde_json::json!(sid); + } + + let client = reqwest::Client::new(); + let response = match client + .post(endpoint.clone()) + .header("Accept", "text/event-stream") + .bearer_auth(&token) + .json(&request_body) + .send() + .await + { + Ok(resp) => resp, + Err(e) => { + yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); + return; + } + }; + + let status = response.status(); + if status == reqwest::StatusCode::UNAUTHORIZED { + tracing::error!("SSE request failed with status: {status}, clearing session"); + let _ = atuin_client::hub::delete_session().await; + yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); + return; + } + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + tracing::error!("SSE request failed ({}): {}", status, body); + yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); + return; + } + + let byte_stream = response.bytes_stream(); + let mut stream = byte_stream.eventsource(); + + while let Some(event) = stream.next().await { + match event { + Ok(sse_event) => { + let event_type = sse_event.event.as_str(); + let data = sse_event.data.clone(); + + tracing::debug!(event_type = %event_type, "SSE event received"); + + match event_type { + "text" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) + && let Some(content) = json.get("content").and_then(|v| v.as_str()) + { + yield Ok(StreamFrame::Content(StreamContent::TextChunk(content.to_string()))); + } + } + "tool_call" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); + yield Ok(StreamFrame::Content(StreamContent::ToolCall { id, name, input })); + } + } + "tool_result" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); + yield Ok(StreamFrame::Content(StreamContent::ToolResult { tool_use_id, content, is_error })); + } + } + "status" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) + && let Some(state) = json.get("state").and_then(|v| v.as_str()) + { + yield Ok(StreamFrame::Control(StreamControl::StatusChanged(state.to_string()))); + } + } + "done" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let session_id = json.get("session_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + yield Ok(StreamFrame::Control(StreamControl::Done { session_id })); + } else { + yield Ok(StreamFrame::Control(StreamControl::Done { session_id: String::new() })); + } + break; + } + "error" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); + tracing::error!("SSE error: {}", message); + yield Ok(StreamFrame::Control(StreamControl::Error(message))); + } else { + tracing::error!("SSE error: {}", data); + yield Ok(StreamFrame::Control(StreamControl::Error(data))); + } + break; + } + _ => {} + } + } + Err(e) => { + yield Err(eyre::eyre!("SSE error: {}", e)); + break; + } + } + } + }) +} + +// ─────────────────────────────────────────────────────────────────── +// Async streaming task — pushes updates to app state via Handle +// ─────────────────────────────────────────────────────────────────── + +pub(crate) async fn run_chat_stream( + handle: Handle<Session>, + tx: mpsc::Sender<AiTuiEvent>, + app_ctx: AppContext, + client_ctx: ClientContext, + request: ChatRequest, +) { + let capabilities = request.capabilities.clone(); + let stream = create_chat_stream( + app_ctx.endpoint.clone(), + app_ctx.token.clone(), + request, + client_ctx, + app_ctx.send_cwd, + app_ctx.last_command.clone(), + ); + futures::pin_mut!(stream); + + while let Some(event) = stream.next().await { + match event { + Ok(StreamFrame::Content(content)) => { + apply_content_frame(&handle, &tx, &capabilities, content); + } + Ok(StreamFrame::Control(control)) => { + let terminal = apply_control_frame(&handle, control); + if terminal { + break; + } + } + Err(e) => { + let msg = e.to_string(); + handle.update(move |state| { + state.streaming_error(msg); + }); + break; + } + } + } +} + +/// Apply a content frame to session state. +/// Control flow: always continues the stream. +fn apply_content_frame( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + capabilities: &[String], + content: StreamContent, +) { + match content { + StreamContent::TextChunk(text) => { + handle.update(move |state| { + state.conversation.append_streaming_text(&text); + }); + } + StreamContent::ToolCall { id, name, input } => { + if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) { + // Enforce capability gating: reject tool calls the client didn't advertise. + if let Some(required_cap) = tool.descriptor().capability + && !capabilities.iter().any(|c| c == required_cap) + { + tracing::warn!( + tool = name, + capability = required_cap, + "Rejecting tool call: capability not advertised" + ); + handle.update(move |state| { + state.add_tool_call(id.clone(), name, input.clone()); + state.conversation.add_tool_result( + id, + format!("Tool not enabled: capability '{required_cap}' was not advertised by this client"), + true, + ); + }); + return; + } + + // Client-side tool — add to tracker and conversation, queue permission check + let id_for_event = id.clone(); + handle.update(move |state| { + state.handle_client_tool_call(id_for_event, tool, input); + }); + let _ = tx.send(AiTuiEvent::CheckToolCallPermission(id)); + } else { + // Server-side tool — just add to conversation events + handle.update(move |state| { + state.add_tool_call(id, name, input); + }); + } + } + StreamContent::ToolResult { + tool_use_id, + content, + is_error, + } => { + handle.update(move |state| { + state + .conversation + .add_tool_result(tool_use_id, content, is_error); + }); + } + } +} + +/// Apply a control frame to session state. +/// Returns true if the stream should terminate. +fn apply_control_frame(handle: &Handle<Session>, control: StreamControl) -> bool { + match control { + StreamControl::StatusChanged(status) => { + handle.update(move |state| { + state.update_streaming_status(&status); + }); + false + } + StreamControl::Done { session_id } => { + handle.update(move |state| { + if !session_id.is_empty() { + state.conversation.store_session_id(session_id); + } + state.finalize_streaming(); + }); + true + } + StreamControl::Error(msg) => { + handle.update(move |state| { + state.streaming_error(msg); + }); + true + } + } +} + +fn hub_url(base: &str, path: &str) -> Result<Url> { + let base_with_slash = if base.ends_with('/') { + base.to_string() + } else { + format!("{base}/") + }; + let stripped = path.strip_prefix('/').unwrap_or(path); + Url::parse(&base_with_slash)? + .join(stripped) + .context("failed to build hub URL") +} diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs new file mode 100644 index 00000000..3b2b7ebf --- /dev/null +++ b/crates/atuin-ai/src/tools/descriptor.rs @@ -0,0 +1,98 @@ +/// Centralized metadata for a tool type. +/// +/// Covers both client-side tools (ones the CLI executes locally) and +/// server-side tools (ones the API executes remotely). This is the single +/// source of truth for display text and classification. +pub(crate) struct ToolDescriptor { + /// Canonical wire names for this tool (the names the server sends). + pub canonical_names: &'static [&'static str], + /// The capability string the client must advertise for this tool to be + /// accepted. `None` for server-side tools (always accepted). + pub capability: Option<&'static str>, + /// Imperative verb for permission prompts (e.g. "read", "run"). + pub display_verb: &'static str, + /// Present-tense progressive verb for spinners (e.g. "Reading file..."). + pub progressive_verb: &'static str, + /// Past-tense verb for summaries (e.g. "Read file"). + pub past_verb: &'static str, + /// Whether this tool is executed client-side (by the CLI). + pub is_client: bool, +} + +// ── Client-side tool descriptors ── + +pub(crate) const READ: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["read_file"], + capability: Some("client_v1_read"), + display_verb: "read", + progressive_verb: "Reading file...", + past_verb: "Read file", + is_client: true, +}; + +pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["str_replace", "file_create", "file_insert"], + capability: Some("client_v1_write"), + display_verb: "write to", + progressive_verb: "Writing file...", + past_verb: "Wrote file", + is_client: true, +}; + +pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["execute_shell_command"], + capability: Some("client_v1_shell"), + display_verb: "run", + progressive_verb: "Running command...", + past_verb: "Ran command", + is_client: true, +}; + +pub(crate) const ATUIN_HISTORY: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["atuin_history"], + capability: Some("client_v1_atuin_history"), + display_verb: "search your Atuin history for", + progressive_verb: "Searching...", + past_verb: "Searched", + is_client: true, +}; + +// ── Server-side tool descriptors ── +// These appear in tool summaries but aren't client-side tools. + +pub(crate) const SERVER_SEARCH: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["web_search"], + capability: None, + display_verb: "search", + progressive_verb: "Searching...", + past_verb: "Searched", + is_client: false, +}; + +pub(crate) const SERVER_SCRAPE: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["web_scrape"], + capability: None, + display_verb: "scrape", + progressive_verb: "Scraping...", + past_verb: "Scraped", + is_client: false, +}; + +/// All known tool descriptors, for lookup by name. +const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[ + READ, + WRITE, + SHELL, + ATUIN_HISTORY, + SERVER_SEARCH, + SERVER_SCRAPE, +]; + +/// Look up a tool descriptor by its canonical wire name. +/// Returns None for unknown tool names. +pub(crate) fn by_name(name: &str) -> Option<&'static ToolDescriptor> { + ALL_DESCRIPTORS + .iter() + .find(|d| d.canonical_names.contains(&name)) + .copied() +} diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs new file mode 100644 index 00000000..8f2183b7 --- /dev/null +++ b/crates/atuin-ai/src/tools/mod.rs @@ -0,0 +1,1111 @@ +use std::{ + io::BufRead, + path::{Path, PathBuf}, + time::Duration, +}; + +use eyre::Result; + +const DEFAULT_FILE_READ_LINES: u64 = 100; +const MAX_FILE_READ_LINES: u64 = 1000; + +pub(crate) mod descriptor; + +use crate::permissions::rule::Rule; + +/// Check whether a file path matches a scope glob pattern. +/// +/// Resolves relative paths against the current directory before matching so +/// that `./foo.md` and `/cwd/foo.md` match the same glob. Supports `*`, `**`, +/// `?`, and `[...]` via `glob_match`. +fn path_matches_scope(path: &Path, scope: &str) -> bool { + let path = if path.is_relative() { + std::env::current_dir() + .map(|cwd| cwd.join(path)) + .unwrap_or_else(|_| path.to_path_buf()) + } else { + path.to_path_buf() + }; + // Normalize to forward slashes so globs work on Windows too. + let path_str = path.to_string_lossy().replace('\\', "/"); + + // If the scope is also relative, try matching against both the absolute + // path and just the filename/relative portion. + if !scope.starts_with('/') { + // Match against filename (e.g. "*.md" matches any .md file) + if let Some(name) = path.file_name().and_then(|n| n.to_str()) + && glob_match::glob_match(scope, name) + { + return true; + } + // Also try matching against the full absolute path in case the scope + // is a relative multi-segment pattern like "crates/**/*.rs" + if glob_match::glob_match(scope, &path_str) { + return true; + } + // And match relative to cwd (so "src/*.rs" works from project root) + if let Ok(cwd) = std::env::current_dir() + && let Ok(rel) = path.strip_prefix(&cwd) + { + let rel_str = rel.to_string_lossy().replace('\\', "/"); + return glob_match::glob_match(scope, &rel_str); + } + return false; + } + + // Absolute scope — match against absolute path + glob_match::glob_match(scope, &path_str) +} + +/// Result of executing a client-side tool. +pub(crate) enum ToolOutcome { + /// Simple success with a text result (used by Read, AtuinHistory). + Success(String), + /// Error with a message. + Error(String), + /// Structured shell result with separated stdout, stderr, exit code, and duration. + Structured { + stdout: String, + stderr: String, + exit_code: Option<i32>, + duration_ms: u64, + interrupted: bool, + }, +} + +impl ToolOutcome { + /// Format this outcome as a string for the tool result sent to the LLM. + pub fn format_for_llm(&self) -> String { + match self { + ToolOutcome::Success(s) => s.clone(), + ToolOutcome::Error(e) => e.clone(), + ToolOutcome::Structured { + stdout, + stderr, + exit_code, + duration_ms, + interrupted, + } => { + let mut parts = Vec::new(); + + if let Some(code) = exit_code { + parts.push(format!("Exit code: {code}")); + } + + parts.push(format!("Duration: {duration_ms}ms")); + + if !stdout.is_empty() { + parts.push(format!("stdout:\n{stdout}")); + } else { + parts.push("stdout: (empty)".to_string()); + } + + if !stderr.is_empty() { + parts.push(format!("stderr:\n{stderr}")); + } else { + parts.push("stderr: (empty)".to_string()); + } + + if *interrupted { + parts.push("[Interrupted by user]".to_string()); + } + + parts.join("\n\n") + } + } + } + + /// Whether this outcome represents an error. + pub fn is_error(&self) -> bool { + match self { + ToolOutcome::Error(_) => true, + ToolOutcome::Structured { + exit_code: Some(code), + .. + } if *code != 0 => true, + _ => false, + } + } +} + +/// Cached VT100 preview data for a shell tool call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ToolPreview { + pub lines: Vec<String>, + pub exit_code: Option<i32>, + pub interrupted: bool, +} + +/// Lifecycle phase of a tracked tool call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ToolPhase { + CheckingPermissions, + AskingForPermission, + #[expect(dead_code)] + Denied(String), + #[expect(dead_code)] + Executing, + /// Shell command is executing with live preview output. + ExecutingWithPreview { + command: String, + /// Current VT100 screen lines (plain text, viewport-sized). + output_lines: Vec<String>, + /// Exit code once the process completes. + exit_code: Option<i32>, + /// Whether the command was interrupted by the user. + interrupted: bool, + }, + /// Tool execution has completed. Preview is cached for rendering history. + Completed { + preview: Option<ToolPreview>, + }, +} + +/// A tracked tool call through its full lifecycle. +#[derive(Debug)] +pub(crate) struct TrackedTool { + pub id: String, + pub tool: ClientToolCall, + pub phase: ToolPhase, + /// Sender to interrupt a running shell command (only set during ExecutingWithPreview). + pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>, +} + +impl TrackedTool { + pub(crate) fn target_dir(&self) -> Option<&Path> { + self.tool.target_dir() + } + + pub fn mark_asking(&mut self) { + self.phase = ToolPhase::AskingForPermission; + } + + pub fn mark_executing_preview(&mut self, command: String) { + self.phase = ToolPhase::ExecutingWithPreview { + command, + output_lines: Vec::new(), + exit_code: None, + interrupted: false, + }; + } + + pub fn complete(&mut self, preview: Option<ToolPreview>) { + self.phase = ToolPhase::Completed { preview }; + self.abort_tx = None; + } + + /// Extract the current preview, whether live or completed. + pub fn preview(&self) -> Option<ToolPreview> { + match &self.phase { + ToolPhase::ExecutingWithPreview { + output_lines, + exit_code, + interrupted, + .. + } => Some(ToolPreview { + lines: output_lines.clone(), + exit_code: *exit_code, + interrupted: *interrupted, + }), + ToolPhase::Completed { preview } => preview.clone(), + _ => None, + } + } +} + +/// Tracks all tool calls through their full lifecycle. +/// +/// Single source of truth for tool execution state. Entries persist after +/// completion so cached previews remain available for rendering history. +#[derive(Debug)] +pub(crate) struct ToolTracker { + tools: Vec<TrackedTool>, +} + +impl ToolTracker { + pub fn new() -> Self { + Self { tools: Vec::new() } + } + + /// Insert a new tool call in CheckingPermissions phase. + pub fn insert(&mut self, id: String, tool: ClientToolCall) { + self.tools.push(TrackedTool { + id, + tool, + phase: ToolPhase::CheckingPermissions, + abort_tx: None, + }); + } + + pub fn get(&self, id: &str) -> Option<&TrackedTool> { + self.tools.iter().find(|t| t.id == id) + } + + pub fn get_mut(&mut self, id: &str) -> Option<&mut TrackedTool> { + self.tools.iter_mut().find(|t| t.id == id) + } + + /// Remove a tool by ID and return it. + #[expect(dead_code)] + pub fn remove(&mut self, id: &str) -> Option<TrackedTool> { + let pos = self.tools.iter().position(|t| t.id == id)?; + Some(self.tools.remove(pos)) + } + + /// True if any tool is still awaiting a permission decision. + #[expect(dead_code)] + pub fn has_unresolved(&self) -> bool { + self.tools.iter().any(|t| { + matches!( + t.phase, + ToolPhase::CheckingPermissions | ToolPhase::AskingForPermission + ) + }) + } + + /// True if any tool has not yet reached the Completed phase. + /// Use this to gate `ContinueAfterTools` — we must wait for all tools + /// (including those still executing) before resuming the conversation. + pub fn has_pending(&self) -> bool { + self.tools + .iter() + .any(|t| !matches!(t.phase, ToolPhase::Completed { .. })) + } + + /// True if any tool is currently executing with a preview. + pub fn has_executing_preview(&self) -> bool { + self.tools + .iter() + .any(|t| matches!(t.phase, ToolPhase::ExecutingWithPreview { .. })) + } + + /// Find the first tool that is asking for permission. + pub fn asking_for_permission(&self) -> Option<&TrackedTool> { + self.tools + .iter() + .find(|t| t.phase == ToolPhase::AskingForPermission) + } + + /// Find the first tool that is asking for permission (mutable). + #[expect(dead_code)] + pub fn asking_for_permission_mut(&mut self) -> Option<&mut TrackedTool> { + self.tools + .iter_mut() + .find(|t| t.phase == ToolPhase::AskingForPermission) + } + + /// Get the preview for a tool by ID (live or cached). + pub fn preview_for(&self, id: &str) -> Option<ToolPreview> { + self.get(id)?.preview() + } + + /// Iterate mutably over all tracked tools. + pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut TrackedTool> { + self.tools.iter_mut() + } +} + +/// A tool call from the server, with parsed input parameters. +#[derive(Debug, Clone)] +pub(crate) enum ClientToolCall { + Read(ReadToolCall), + Write(WriteToolCall), + Shell(ShellToolCall), + AtuinHistory(AtuinHistoryToolCall), +} + +impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { + type Error = eyre::Error; + + fn try_from((name, input): (&str, &serde_json::Value)) -> Result<Self, Self::Error> { + match name { + "read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)), + "create_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), + // "append_to_file" => Ok(ClientToolCall::Append(AppendToolCall::try_from(input)?)), + // "str_replace" => Ok(ClientToolCall::StrReplace(StrReplaceToolCall::try_from(input)?)), + "execute_shell_command" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), + "atuin_history" => Ok(ClientToolCall::AtuinHistory( + AtuinHistoryToolCall::try_from(input)?, + )), + _ => Err(eyre::eyre!("Unknown tool call: {name}")), + } + } +} + +impl ClientToolCall { + pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor { + match self { + ClientToolCall::Read(_) => descriptor::READ, + ClientToolCall::Write(_) => descriptor::WRITE, + ClientToolCall::Shell(_) => descriptor::SHELL, + ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, + } + } + + /// The permission rule name for this tool category (e.g. "Write" covers + /// str_replace, file_create, file_insert). + pub(crate) fn rule_name(&self) -> &'static str { + match self { + ClientToolCall::Read(_) => "Read", + ClientToolCall::Write(_) => "Write", + ClientToolCall::Shell(_) => "Shell", + ClientToolCall::AtuinHistory(_) => "AtuinHistory", + } + } + + pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { + match self { + ClientToolCall::Read(tool) => tool.matches_rule(rule), + ClientToolCall::Write(tool) => tool.matches_rule(rule), + ClientToolCall::Shell(tool) => tool.matches_rule(rule), + ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), + } + } + + pub(crate) fn target_dir(&self) -> Option<&Path> { + match self { + ClientToolCall::Read(tool) => tool.target_dir(), + ClientToolCall::Write(tool) => tool.target_dir(), + ClientToolCall::Shell(tool) => tool.target_dir(), + ClientToolCall::AtuinHistory(tool) => tool.target_dir(), + } + } + + /// Execute this client-side tool and return the result. + pub async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { + match self { + ClientToolCall::Read(tool) => tool.execute(), + ClientToolCall::AtuinHistory(tool) => tool.execute(db).await, + _ => ToolOutcome::Error("Client-side tool execution not yet implemented".to_string()), + } + } +} + +/// A trait for tool calls that can be checked against permission rules. +pub(crate) trait PermissableToolCall { + /// Checks if this tool call matches the given permission rule. + fn matches_rule(&self, rule: &Rule) -> bool; + /// Returns the target directory of this tool call, if applicable, for checking against directory-based rules. + fn target_dir(&self) -> Option<&Path> { + None + } +} + +impl PermissableToolCall for ClientToolCall { + fn matches_rule(&self, rule: &Rule) -> bool { + self.matches_rule(rule) + } + + fn target_dir(&self) -> Option<&Path> { + self.target_dir() + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ReadToolCall { + pub path: PathBuf, + pub offset: u64, + pub limit: u64, +} + +impl TryFrom<&serde_json::Value> for ReadToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let path = value + .get("file_path") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing path"))?; + + let offset = value.get("offset").and_then(|v| v.as_u64()).unwrap_or(0); + let limit = value + .get("limit") + .and_then(|v| v.as_u64()) + .unwrap_or(DEFAULT_FILE_READ_LINES) + .min(MAX_FILE_READ_LINES); + + Ok(ReadToolCall { + path: PathBuf::from(path), + offset, + limit, + }) + } +} + +impl ReadToolCall { + fn execute(&self) -> ToolOutcome { + let mut path = self.path.clone(); + + if path.is_relative() + && let Ok(current_dir) = std::env::current_dir() + { + path = current_dir.join(path); + } + + if !path.exists() { + return ToolOutcome::Error(format!("Error: file does not exist: {}", path.display())); + } + + if path.is_dir() { + let Some(files) = std::fs::read_dir(&path).ok().and_then(|entries| { + entries + .filter_map(|entry| entry.ok()) + .map(|entry| entry.file_name().to_string_lossy().to_string()) + .collect::<Vec<_>>() + .into() + }) else { + return ToolOutcome::Error(format!( + "Error: could not read directory: {}", + path.display() + )); + }; + + return ToolOutcome::Success(format!("Directory contents:\n{}", files.join("\n"))); + } + + let file = match std::fs::File::open(&path) { + Ok(file) => file, + Err(e) => return ToolOutcome::Error(format!("Error opening file: {e}")), + }; + let reader = std::io::BufReader::new(file); + + let relevent_lines = reader + .lines() + .skip(self.offset as usize) + .take(self.limit as usize) + .collect::<Result<Vec<_>, _>>(); + + match relevent_lines { + Ok(lines) => { + let joined = lines.join("\n"); + if joined.len() > 100_000 { + ToolOutcome::Error(format!( + "Error: file is too large to read ({} bytes in {} lines); use view_range to read a subset of the file", + joined.len(), + lines.len() + )) + } else { + ToolOutcome::Success(joined) + } + } + Err(e) => ToolOutcome::Error(format!("Error reading file: {e}")), + } + } +} + +impl PermissableToolCall for ReadToolCall { + fn target_dir(&self) -> Option<&Path> { + Some(&self.path) + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Read" { + return false; + } + + match rule.scope.as_deref() { + None | Some("*") => true, + Some(scope) => path_matches_scope(&self.path, scope), + } + } +} + +#[derive(Debug, Clone)] +#[expect(dead_code)] +pub(crate) struct WriteToolCall { + pub path: PathBuf, + pub content: String, +} + +impl TryFrom<&serde_json::Value> for WriteToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let path = value + .get("path") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing path"))?; + + let content = value + .get("content") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing content"))?; + + Ok(WriteToolCall { + path: PathBuf::from(path), + content: content.to_string(), + }) + } +} + +impl PermissableToolCall for WriteToolCall { + fn target_dir(&self) -> Option<&Path> { + Some(&self.path) + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Write" { + return false; + } + + match rule.scope.as_deref() { + None | Some("*") => true, + Some(scope) => path_matches_scope(&self.path, scope), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ShellToolCall { + pub dir: Option<PathBuf>, + pub command: String, + pub shell: String, +} + +impl TryFrom<&serde_json::Value> for ShellToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let dir = value.get("dir").and_then(|v| v.as_str()); + + let command = value + .get("command") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing command"))?; + + let shell = value + .get("shell") + .and_then(|v| v.as_str()) + .unwrap_or("bash") + .to_string(); + + Ok(ShellToolCall { + dir: dir.map(PathBuf::from), + command: command.to_string(), + shell, + }) + } +} + +impl PermissableToolCall for ShellToolCall { + fn target_dir(&self) -> Option<&Path> { + self.dir.as_deref() + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Shell" { + return false; + } + + let Some(scope) = rule.scope.as_ref() else { + // Shell without scope matches all shell commands + return true; + }; + + let shell_kind = crate::permissions::shell::ShellKind::from_shell_name(&self.shell); + let parsed = crate::permissions::shell::parse_shell_command(&self.command, shell_kind); + crate::permissions::shell::any_subcommand_matches(&parsed.subcommands, scope) + } +} + +/// Preview viewport height for VT100 emulation. +const PREVIEW_HEIGHT: u16 = 10; + +/// Default terminal width for VT100 emulation. +const PREVIEW_WIDTH: u16 = 120; + +/// Extract plain text lines from a VT100 screen buffer. +fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> { + let (rows, cols) = screen.size(); + let mut lines = Vec::with_capacity(rows as usize); + for row in 0..rows { + let mut line = String::with_capacity(cols as usize); + for col in 0..cols { + if let Some(cell) = screen.cell(row, col) { + line.push_str(cell.contents()); + } + } + // Trim trailing whitespace for cleaner display + lines.push(line.trim_end().to_string()); + } + lines +} + +/// Strip ANSI escape sequences from raw bytes using a VT100 parser. +/// +/// Uses a large virtual screen so scrollback is preserved, then extracts +/// the plain text contents. This handles all escape sequences (colors, +/// cursor movement, progress bars, etc.) not just simple SGR codes. +fn strip_ansi_via_vt100(raw: &[u8]) -> String { + if raw.is_empty() { + return String::new(); + } + // Use the contents_formatted → screen approach: feed bytes into a parser + // with enough rows to hold everything, then read back the plain text. + // Estimate rows: one row per ~PREVIEW_WIDTH bytes, plus generous padding. + let estimated_rows = (raw.len() / PREVIEW_WIDTH as usize + 1).min(10_000) as u16; + let mut parser = vt100::Parser::new(estimated_rows, PREVIEW_WIDTH, 0); + parser.process(raw); + let screen = parser.screen(); + // screen.contents() returns the full plain-text content with trailing + // whitespace trimmed per line and trailing blank lines removed. + screen.contents() +} + +/// Execute a shell command with VT100 emulation and streaming output. +/// +/// Feeds stdout+stderr into a `vt100::Parser` so that ANSI escape sequences, +/// progress bars (`\r`), and cursor movement are handled correctly. Periodically +/// sends the current screen state as `Vec<String>` through `output_tx` for the +/// live preview. +/// +/// Captures the FULL stdout and stderr separately for the tool result sent to the LLM. +/// Returns a `ToolOutcome::Structured` with full output, exit code, and duration. +pub(crate) async fn execute_shell_command_streaming( + shell_call: &ShellToolCall, + output_tx: tokio::sync::mpsc::Sender<Vec<String>>, + mut interrupt_rx: tokio::sync::oneshot::Receiver<()>, +) -> ToolOutcome { + use tokio::io::AsyncReadExt; + + let start = std::time::Instant::now(); + + // TODO: check if this is proper for all shells we support + let mut cmd = tokio::process::Command::new(&shell_call.shell); + cmd.arg("-c").arg(&shell_call.command); + cmd.stdout(std::process::Stdio::piped()); + cmd.stderr(std::process::Stdio::piped()); + + if let Some(ref dir) = shell_call.dir { + cmd.current_dir(dir); + } + + let mut child = match cmd.spawn() { + Ok(child) => child, + Err(e) => return ToolOutcome::Error(format!("Failed to spawn command: {e}")), + }; + + let stdout = child.stdout.take().expect("stdout was piped"); + let stderr = child.stderr.take().expect("stderr was piped"); + + // VT100 emulator for the live preview (viewport-sized) + let mut parser = vt100::Parser::new(PREVIEW_HEIGHT, PREVIEW_WIDTH, 0); + + let mut stdout_reader = tokio::io::BufReader::new(stdout); + let mut stderr_reader = tokio::io::BufReader::new(stderr); + + let mut stdout_buf = [0u8; 4096]; + let mut stderr_buf = [0u8; 4096]; + let mut stdout_done = false; + let mut stderr_done = false; + + // Full output buffers (for the LLM, not the preview) + let mut full_stdout = Vec::<u8>::new(); + let mut full_stderr = Vec::<u8>::new(); + + let mut interval = tokio::time::interval(Duration::from_millis(50)); + + // Send initial empty screen + let initial_lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(initial_lines).await; + + let mut interrupted = false; + + loop { + tokio::select! { + biased; + + // Check for interrupt signal + _ = &mut interrupt_rx, if !interrupted => { + interrupted = true; + let _ = child.start_kill(); + } + + // Read stdout + result = stdout_reader.read(&mut stdout_buf), if !stdout_done => { + match result { + Ok(0) => stdout_done = true, + Ok(n) => { + full_stdout.extend_from_slice(&stdout_buf[..n]); + parser.process(&stdout_buf[..n]); + } + Err(_) => stdout_done = true, + } + } + + // Read stderr + result = stderr_reader.read(&mut stderr_buf), if !stderr_done => { + match result { + Ok(0) => stderr_done = true, + Ok(n) => { + full_stderr.extend_from_slice(&stderr_buf[..n]); + // Feed stderr to the preview parser too, so it shows in the VT100 screen + parser.process(&stderr_buf[..n]); + } + Err(_) => stderr_done = true, + } + } + + // Periodic screen snapshot for preview + _ = interval.tick() => { + let lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(lines).await; + } + } + + // Exit when both streams are done + if stdout_done && stderr_done { + break; + } + } + + // Wait for process to finish + let exit_code = match child.wait().await { + Ok(status) => status.code(), + Err(e) => { + if interrupted { + None + } else { + return ToolOutcome::Error(format!("Failed to wait for command: {e}")); + } + } + }; + + let duration = start.elapsed(); + + // Send final screen state + let final_lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(final_lines).await; + + // Strip ANSI escape sequences for clean LLM output by running + // the raw bytes through a VT100 parser and extracting plain text. + let stdout_text = strip_ansi_via_vt100(&full_stdout); + let stderr_text = strip_ansi_via_vt100(&full_stderr); + + ToolOutcome::Structured { + stdout: stdout_text, + stderr: stderr_text, + exit_code, + duration_ms: duration.as_millis() as u64, + interrupted, + } +} + +#[derive(Debug, Clone)] +pub(crate) struct AtuinHistoryToolCall { + pub filter_modes: Vec<HistorySearchFilterMode>, + pub query: String, + pub limit: i64, +} + +#[derive(Debug, Clone)] +pub(crate) enum HistorySearchFilterMode { + Global, + Host, + Session, + Directory, + Workspace, +} + +impl From<&HistorySearchFilterMode> for atuin_client::settings::FilterMode { + fn from(mode: &HistorySearchFilterMode) -> Self { + match mode { + HistorySearchFilterMode::Global => Self::Global, + HistorySearchFilterMode::Host => Self::Host, + HistorySearchFilterMode::Session => Self::Session, + HistorySearchFilterMode::Directory => Self::Directory, + HistorySearchFilterMode::Workspace => Self::Workspace, + } + } +} + +impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let filter_modes = value + .get("filter_modes") + .and_then(|v| v.as_array()) + .ok_or(eyre::eyre!("Missing filter_modes"))?; + + let filter_modes = filter_modes + .iter() + .map(|v| { + let mode = v.as_str().ok_or(eyre::eyre!("Invalid filter mode"))?; + match mode { + "global" => Ok(HistorySearchFilterMode::Global), + "host" => Ok(HistorySearchFilterMode::Host), + "session" => Ok(HistorySearchFilterMode::Session), + "directory" => Ok(HistorySearchFilterMode::Directory), + "workspace" => Ok(HistorySearchFilterMode::Workspace), + _ => Err(eyre::eyre!("Invalid filter mode: {mode}")), + } + }) + .collect::<Result<Vec<HistorySearchFilterMode>>>()?; + + let query = value + .get("query") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing query"))?; + + let limit = value + .get("limit") + .and_then(|v| v.as_i64()) + .unwrap_or(10) + .clamp(1, 50); + + Ok(AtuinHistoryToolCall { + filter_modes, + query: query.to_string(), + limit, + }) + } +} + +impl PermissableToolCall for AtuinHistoryToolCall { + fn target_dir(&self) -> Option<&Path> { + None + } + + fn matches_rule(&self, rule: &Rule) -> bool { + rule.tool == "AtuinHistory" + } +} + +impl AtuinHistoryToolCall { + pub(crate) async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { + use atuin_client::database::{self, Database as _, OptFilters}; + use atuin_client::settings::SearchMode; + use time::UtcOffset; + + let context = match database::current_context().await { + Ok(ctx) => ctx, + Err(e) => return ToolOutcome::Error(format!("Failed to get history context: {e}")), + }; + + let filter_mode = self + .filter_modes + .first() + .map(atuin_client::settings::FilterMode::from) + .unwrap_or(atuin_client::settings::FilterMode::Global); + + let filter_options = OptFilters { + limit: Some(self.limit), + ..Default::default() + }; + + let results = match db + .search( + SearchMode::Fuzzy, + filter_mode, + &context, + &self.query, + filter_options, + ) + .await + { + Ok(results) => results, + Err(e) => return ToolOutcome::Error(format!("History search failed: {e}")), + }; + + if results.is_empty() { + return ToolOutcome::Success("No matching history entries found.".to_string()); + } + + let local_offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + + let formatted: Vec<String> = results + .iter() + .enumerate() + .map(|(i, h)| { + let ts = h.timestamp.to_offset(local_offset); + let time_str = format!( + "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", + ts.year(), + ts.month() as u8, + ts.day(), + ts.hour(), + ts.minute(), + ts.second(), + ); + + let duration_str = format_duration(h.duration); + + format!( + "{}. `{}` [{}] ({}, exit: {}){}", + i + 1, + h.command, + time_str, + h.cwd, + h.exit, + duration_str, + ) + }) + .collect(); + + ToolOutcome::Success(formatted.join("\n")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn read_rule(scope: Option<&str>) -> Rule { + Rule { + tool: "Read".to_string(), + scope: scope.map(String::from), + } + } + + fn write_rule(scope: Option<&str>) -> Rule { + Rule { + tool: "Write".to_string(), + scope: scope.map(String::from), + } + } + + fn read_tool(path: &str) -> ReadToolCall { + ReadToolCall { + path: PathBuf::from(path), + offset: 0, + limit: 100, + } + } + + fn write_tool(path: &str) -> WriteToolCall { + WriteToolCall { + path: PathBuf::from(path), + content: String::new(), + } + } + + // ── Cross-platform tests ── + + #[test] + fn no_scope_matches_everything() { + assert!(read_tool("any/path.txt").matches_rule(&read_rule(None))); + assert!(write_tool("any/path.txt").matches_rule(&write_rule(None))); + } + + #[test] + fn wildcard_star_matches_everything() { + assert!(read_tool("foo/bar.rs").matches_rule(&read_rule(Some("*")))); + } + + #[test] + fn wrong_tool_never_matches() { + assert!(!read_tool("foo.txt").matches_rule(&write_rule(None))); + assert!(!write_tool("foo.txt").matches_rule(&read_rule(None))); + } + + #[test] + fn extension_glob() { + assert!(read_tool("notes.md").matches_rule(&read_rule(Some("*.md")))); + assert!(!read_tool("notes.txt").matches_rule(&read_rule(Some("*.md")))); + } + + #[test] + fn relative_multi_segment_glob() { + // This matches against the path relative to cwd + let cwd = std::env::current_dir().unwrap(); + let abs = cwd + .join("crates") + .join("atuin-ai") + .join("src") + .join("lib.rs"); + let tool = read_tool(abs.to_str().unwrap()); + assert!(tool.matches_rule(&read_rule(Some("crates/**/*.rs")))); + assert!(!tool.matches_rule(&read_rule(Some("crates/**/*.py")))); + } + + // ── Unix-specific tests (absolute paths with forward slashes) ── + + #[cfg(unix)] + mod unix { + use super::*; + + #[test] + fn absolute_glob() { + assert!( + read_tool("/home/user/src/main.rs") + .matches_rule(&read_rule(Some("/home/user/src/*.rs"))) + ); + assert!( + !read_tool("/home/user/docs/readme.md") + .matches_rule(&read_rule(Some("/home/user/src/*.rs"))) + ); + } + + #[test] + fn double_star_glob() { + assert!( + read_tool("/project/crates/foo/src/lib.rs") + .matches_rule(&read_rule(Some("/project/crates/**/*.rs"))) + ); + assert!( + !read_tool("/project/crates/foo/src/lib.py") + .matches_rule(&read_rule(Some("/project/crates/**/*.rs"))) + ); + } + } + + // ── Windows-specific tests (absolute paths with drive letters) ── + + #[cfg(windows)] + mod windows { + use super::*; + + #[test] + fn absolute_glob() { + assert!( + read_tool(r"C:\Users\dev\src\main.rs") + .matches_rule(&read_rule(Some("C:/Users/dev/src/*.rs"))) + ); + assert!( + !read_tool(r"C:\Users\dev\docs\readme.md") + .matches_rule(&read_rule(Some("C:/Users/dev/src/*.rs"))) + ); + } + + #[test] + fn double_star_glob() { + assert!( + read_tool(r"C:\project\crates\foo\src\lib.rs") + .matches_rule(&read_rule(Some("C:/project/crates/**/*.rs"))) + ); + assert!( + !read_tool(r"C:\project\crates\foo\src\lib.py") + .matches_rule(&read_rule(Some("C:/project/crates/**/*.rs"))) + ); + } + } +} + +fn format_duration(nanos: i64) -> String { + if nanos <= 0 { + return String::new(); + } + + let total_secs = nanos / 1_000_000_000; + let millis = (nanos % 1_000_000_000) / 1_000_000; + + if total_secs >= 3600 { + let hours = total_secs / 3600; + let mins = (total_secs % 3600) / 60; + let secs = total_secs % 60; + format!(", {hours}h{mins}m{secs}s") + } else if total_secs >= 60 { + let mins = total_secs / 60; + let secs = total_secs % 60; + format!(", {mins}m{secs}s") + } else if total_secs > 0 { + if millis > 0 { + format!(", {total_secs}.{millis:03}s") + } else { + format!(", {total_secs}s") + } + } else { + format!(", {millis}ms") + } +} diff --git a/crates/atuin-ai/src/tui/components/atuin_ai.rs b/crates/atuin-ai/src/tui/components/atuin_ai.rs index fab29502..c04ac722 100644 --- a/crates/atuin-ai/src/tui/components/atuin_ai.rs +++ b/crates/atuin-ai/src/tui/components/atuin_ai.rs @@ -22,10 +22,11 @@ pub(crate) struct AtuinAi { pub has_command: bool, pub is_input_blank: bool, pub pending_confirmation: bool, + pub has_executing_preview: bool, } #[derive(Default)] -pub struct AtuinAiState { +pub(crate) struct AtuinAiState { tx: Option<mpsc::Sender<AiTuiEvent>>, } @@ -55,15 +56,24 @@ fn atuin_ai( return EventResult::Ignored; }; - // Ctrl+C always exits + // Ctrl+C — interrupt executing command or exit if modifiers.contains(KeyModifiers::CONTROL) && *code == KeyCode::Char('c') { - let _ = tx.send(AiTuiEvent::Exit); + if props.has_executing_preview { + let _ = tx.send(AiTuiEvent::InterruptToolExecution); + } else { + let _ = tx.send(AiTuiEvent::Exit); + } return EventResult::Consumed; } match props.mode { AppMode::Input => match code { KeyCode::Esc => { + if props.has_executing_preview { + let _ = tx.send(AiTuiEvent::InterruptToolExecution); + return EventResult::Consumed; + } + if props.pending_confirmation { let _ = tx.send(AiTuiEvent::CancelConfirmation); return EventResult::Consumed; diff --git a/crates/atuin-ai/src/tui/components/markdown.rs b/crates/atuin-ai/src/tui/components/markdown.rs index 1cd7dbcf..f164fdc5 100644 --- a/crates/atuin-ai/src/tui/components/markdown.rs +++ b/crates/atuin-ai/src/tui/components/markdown.rs @@ -16,20 +16,12 @@ use ratatui_widgets::paragraph::{Paragraph, Wrap}; /// A markdown rendering component backed by pulldown-cmark. #[props] -pub struct Markdown { +pub(crate) struct Markdown { pub source: String, } -impl Markdown { - pub fn new(source: impl Into<String>) -> Self { - Self { - source: source.into(), - } - } -} - /// Style configuration for markdown rendering. -pub struct MarkdownStyles { +pub(crate) struct MarkdownStyles { pub base: Style, pub code_inline: Style, pub code_block: Style, @@ -98,26 +90,22 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat let mut style_stack: Vec<Style> = vec![styles.base]; let mut in_code_block = false; + let mut in_list_item = false; + // True until the first paragraph inside a list item has been opened. + // The first paragraph should flow inline with the "- " prefix. + let mut list_item_first_para = false; for event in parser { match event { Event::Start(Tag::Strong) => { - let bold = style_stack - .last() - .copied() - .unwrap_or(styles.base) - .add_modifier(Modifier::BOLD); + let bold = style_stack.last().copied().unwrap_or(styles.bold); style_stack.push(bold); } Event::End(TagEnd::Strong) => { style_stack.pop(); } Event::Start(Tag::Emphasis) => { - let italic = style_stack - .last() - .copied() - .unwrap_or(styles.base) - .add_modifier(Modifier::ITALIC); + let italic = style_stack.last().copied().unwrap_or(styles.italic); style_stack.push(italic); } Event::End(TagEnd::Emphasis) => { @@ -170,12 +158,17 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat lines.push(Vec::new()); } Event::Start(Tag::Paragraph) => { - if current_line > 0 || !lines[0].is_empty() { - // Two line advances: one to end the current line, one for a blank separator. - current_line += 1; - lines.push(Vec::new()); + if in_list_item && list_item_first_para { + // First paragraph flows inline with the "- " prefix + list_item_first_para = false; + } else if current_line > 0 || !lines[0].is_empty() { current_line += 1; lines.push(Vec::new()); + if !in_list_item { + // Blank separator between paragraphs (but not inside list items) + current_line += 1; + lines.push(Vec::new()); + } } } Event::End(TagEnd::Paragraph) => {} @@ -197,8 +190,12 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat lines.push(Vec::new()); } lines[current_line].push(Span::styled("- ", Style::default().fg(Color::DarkGray))); + in_list_item = true; + list_item_first_para = true; + } + Event::End(TagEnd::Item) => { + in_list_item = false; } - Event::End(TagEnd::Item) => {} Event::Start(Tag::List(_)) => { if current_line > 0 || !lines[0].is_empty() { current_line += 1; diff --git a/crates/atuin-ai/src/tui/components/mod.rs b/crates/atuin-ai/src/tui/components/mod.rs index 2f684f5f..3458327d 100644 --- a/crates/atuin-ai/src/tui/components/mod.rs +++ b/crates/atuin-ai/src/tui/components/mod.rs @@ -1,3 +1,4 @@ -pub mod atuin_ai; -pub mod input_box; -pub mod markdown; +pub(crate) mod atuin_ai; +pub(crate) mod input_box; +pub(crate) mod markdown; +pub(crate) mod select; diff --git a/crates/atuin-ai/src/tui/components/select.rs b/crates/atuin-ai/src/tui/components/select.rs new file mode 100644 index 00000000..5abbe655 --- /dev/null +++ b/crates/atuin-ai/src/tui/components/select.rs @@ -0,0 +1,96 @@ +use std::sync::mpsc; + +use crossterm::event::KeyCode; +use eye_declare::{Elements, EventResult, Hooks, Span, Text, View, component, element, props}; +use ratatui::style::Style; +use typed_builder::TypedBuilder; + +use crate::tui::events::AiTuiEvent; + +type OnSelectFn = Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync + 'static>; + +#[derive(TypedBuilder)] +pub(crate) struct SelectOption { + #[builder(setter(into))] + pub label: String, + #[builder(setter(into))] + pub value: String, + #[builder(default = Style::default())] + pub label_style: Style, + #[builder(default = Style::default().reversed())] + pub selected_style: Style, +} + +#[derive(Default)] +pub(crate) struct PermissionSelectorState { + selected_option: usize, + tx: Option<mpsc::Sender<AiTuiEvent>>, +} + +#[props] +pub(crate) struct Select { + pub options: Vec<SelectOption>, + pub on_select: OnSelectFn, +} + +#[component(props = Select, state = PermissionSelectorState)] +pub(crate) fn permission_selector( + props: &Select, + state: &PermissionSelectorState, + hooks: &mut Hooks<Select, PermissionSelectorState>, +) -> Elements { + hooks.use_focusable(true); + hooks.use_autofocus(); + + hooks.use_context::<mpsc::Sender<AiTuiEvent>>(|tx, _, state| { + state.tx = tx.cloned(); + }); + + hooks.use_event(move |event, props, state| { + if !event.is_key_press() { + return EventResult::Ignored; + } + + if let crossterm::event::Event::Key(key) = event { + if key.kind != crossterm::event::KeyEventKind::Press { + return EventResult::Ignored; + } + + match key.code { + KeyCode::Up => { + state.selected_option = + (state.selected_option + props.options.len() - 1) % props.options.len(); + return EventResult::Consumed; + } + KeyCode::Down => { + state.selected_option = (state.selected_option + 1) % props.options.len(); + return EventResult::Consumed; + } + KeyCode::Enter => { + let option = &props.options[state.selected_option]; + if let Some(event) = (props.on_select)(option) + && let Some(ref tx) = state.tx + { + let _ = tx.send(event); + } + return EventResult::Consumed; + } + _ => {} + } + } + + EventResult::Ignored + }); + + element!( + View { + #(for (index, option) in props.options.iter().enumerate() { + Text { Span(text: &option.label, style: if index == state.selected_option { + option.selected_style + } else { + option.label_style + }) } + }) + } + ) +} diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs new file mode 100644 index 00000000..b3e84757 --- /dev/null +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -0,0 +1,571 @@ +use std::path::PathBuf; +use std::sync::mpsc; + +use crate::context::{AppContext, ClientContext}; +use crate::permissions::check::PermissionResponse; +use crate::permissions::resolver::PermissionResolver; +use crate::permissions::rule::Rule; +use crate::permissions::writer::{self, RuleDisposition}; +use crate::stream::{ChatRequest, run_chat_stream}; +use crate::tools::{ClientToolCall, ToolPhase}; +use crate::tui::events::{AiTuiEvent, PermissionResult}; +use crate::tui::state::{ExitAction, Session}; +use eye_declare::Handle; +use tokio::task::JoinHandle; + +pub(crate) fn dispatch( + handle: &Handle<Session>, + event: AiTuiEvent, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + client_ctx: &ClientContext, +) { + match event { + AiTuiEvent::ContinueAfterTools => { + on_continue_after_tools(handle, tx, app_ctx, client_ctx); + } + AiTuiEvent::InputUpdated(input) => { + on_input_updated(handle, input); + } + AiTuiEvent::SubmitInput(input) => { + on_submit_input(handle, tx, app_ctx, client_ctx, input); + } + AiTuiEvent::SlashCommand(cmd) => { + on_slash_command(handle, cmd); + } + AiTuiEvent::CheckToolCallPermission(id) => { + on_check_tool_permission(handle, tx, app_ctx, id); + } + AiTuiEvent::SelectPermission(result) => { + on_select_permission(handle, tx, app_ctx, result); + } + AiTuiEvent::CancelGeneration => { + on_cancel_generation(handle); + } + AiTuiEvent::ExecuteCommand => { + on_execute_command(handle); + } + AiTuiEvent::CancelConfirmation => { + on_cancel_confirmation(handle); + } + AiTuiEvent::InterruptToolExecution => { + on_interrupt_tool_execution(handle); + } + AiTuiEvent::InsertCommand => { + on_insert_command(handle); + } + AiTuiEvent::Retry => { + on_retry(handle, tx, app_ctx, client_ctx); + } + AiTuiEvent::Exit => { + on_exit(handle); + } + } +} + +fn launch_stream( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + client_ctx: &ClientContext, + setup: impl FnOnce(&mut Session) + Send + 'static, +) { + let h2 = handle.clone(); + let tx2 = tx.clone(); + let app = app_ctx.clone(); + let cc = client_ctx.clone(); + let caps = app_ctx.capabilities.clone(); + handle.update(move |state| { + (setup)(state); + state.start_streaming(); + let messages = state.conversation.events_to_messages(); + let sid = state.conversation.session_id.clone(); + let request = ChatRequest::new(messages, sid, &caps); + let task: JoinHandle<()> = tokio::spawn(async move { + run_chat_stream(h2, tx2, app, cc, request).await; + }); + state.stream_abort = Some(task.abort_handle()); + }); +} + +fn on_continue_after_tools( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + client_ctx: &ClientContext, +) { + launch_stream(handle, tx, app_ctx, client_ctx, |_state| {}); +} + +fn on_input_updated(handle: &Handle<Session>, input: String) { + let input_blank = input.trim().is_empty(); + + handle.update(move |state| { + state.interaction.is_input_blank = input_blank; + }); +} + +fn on_submit_input( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + client_ctx: &ClientContext, + input: String, +) { + let input = input.trim().to_string(); + if input.is_empty() { + let h2 = handle.clone(); + handle.update(move |state| { + if state.conversation.has_any_command() { + state.exit_action = Some(ExitAction::Execute( + state.conversation.current_command().unwrap().to_string(), + )); + } else { + state.exit_action = Some(ExitAction::Cancel); + } + h2.exit(); + }); + return; + } + + if input.starts_with('/') { + handle.update(move |state| { + state.conversation.handle_slash_command(&input); + }); + return; + } + + // Start generation and spawn streaming task + launch_stream(handle, tx, app_ctx, client_ctx, |state| { + state.start_generating(input); + state.interaction.is_input_blank = true; + }); +} + +fn on_slash_command(handle: &Handle<Session>, command: String) { + handle.update(move |state| { + state.conversation.handle_slash_command(&command); + }); +} + +// ─────────────────────────────────────────────────────────────────── +// Tool execution dispatch +// ─────────────────────────────────────────────────────────────────── + +/// Execute a tool call. Handles Shell tools (streaming with preview) and +/// non-shell tools (synchronous) uniformly. +fn execute_tool( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + tool_id: String, + tool: ClientToolCall, + db: &std::sync::Arc<atuin_client::database::Sqlite>, +) { + match &tool { + ClientToolCall::Shell(shell_call) => { + let shell_call = shell_call.clone(); + execute_shell_tool(handle, tx, &tool_id, &shell_call); + } + _ => { + execute_simple_tool(handle, tx, tool_id, tool, db); + } + } +} + +/// Execute a non-shell tool and finish the tool call. +/// The ToolCall event is already in the conversation (added by handle_client_tool_call). +fn execute_simple_tool( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + tool_id: String, + tool: ClientToolCall, + db: &std::sync::Arc<atuin_client::database::Sqlite>, +) { + let h = handle.clone(); + let tx = tx.clone(); + let db = db.clone(); + + tokio::spawn(async move { + let outcome = tool.execute(&db).await; + h.update(move |state| { + state.finish_tool_call(&tool_id, outcome); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + }); +} + +/// Execute a shell tool with streaming VT100 preview. +fn execute_shell_tool( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + tool_id: &str, + shell_call: &crate::tools::ShellToolCall, +) { + let h = handle.clone(); + let tx = tx.clone(); + let shell_call = shell_call.clone(); + let command = shell_call.command.clone(); + let tc_id = tool_id.to_string(); + + // 1. Set up channels for streaming output and interruption + let (output_tx, mut output_rx) = tokio::sync::mpsc::channel::<Vec<String>>(32); + let (abort_tx, abort_rx) = tokio::sync::oneshot::channel::<()>(); + + // 2. Mark as executing with preview and store the abort sender on the tracker entry + let tc_id_setup = tc_id.clone(); + h.update(move |state| { + if let Some(tracked) = state.tool_tracker.get_mut(&tc_id_setup) { + tracked.mark_executing_preview(command); + tracked.abort_tx = Some(abort_tx); + } + }); + + // 3. Spawn a task to consume output updates and feed them to state + let h_output = h.clone(); + let preview_id = tc_id.clone(); + let output_task = tokio::spawn(async move { + while let Some(lines) = output_rx.recv().await { + let id = preview_id.clone(); + h_output.update(move |state| { + if let Some(tracked) = state.tool_tracker.get_mut(&id) + && let ToolPhase::ExecutingWithPreview { + ref mut output_lines, + .. + } = tracked.phase + { + *output_lines = lines; + } + }); + } + }); + + // 4. Spawn the streaming execution task + let tc_id_finish = tc_id; + tokio::spawn(async move { + let outcome = + crate::tools::execute_shell_command_streaming(&shell_call, output_tx, abort_rx).await; + + // Wait for the output task to finish so the final preview lines are captured + let _ = output_task.await; + + h.update(move |state| { + state.finish_tool_call(&tc_id_finish, outcome); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + }); +} + +// ─────────────────────────────────────────────────────────────────── +// Permission handlers +// ─────────────────────────────────────────────────────────────────── + +fn on_check_tool_permission( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + id: String, +) { + let h2 = handle.clone(); + let tx_for_task = tx.clone(); + let db = app_ctx.history_db.clone(); + + tokio::spawn(async move { + let id_for_error = id.clone(); + let result = check_tool_permission_inner(&h2, &tx_for_task, &db, id).await; + + // If the inner function didn't handle the tool (returned an error message), + // finish the tool call with that error so the conversation doesn't stall. + if let Err(error_msg) = result { + let tx = tx_for_task.clone(); + h2.update(move |state| { + state.finish_tool_call(&id_for_error, crate::tools::ToolOutcome::Error(error_msg)); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + } + }); +} + +/// Inner permission check that returns Err(message) if the tool call should be +/// finished with an error. Returns Ok(()) if the tool was handled (executed, +/// denied, or sent to the permission UI). +async fn check_tool_permission_inner( + h2: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + db: &std::sync::Arc<atuin_client::database::Sqlite>, + id: String, +) -> Result<(), String> { + // 1. Fetch the tracked tool's data + let id_for_fetch = id.clone(); + let (tool, target_dir) = h2 + .fetch(move |state| { + state + .tool_tracker + .get(&id_for_fetch) + .map(|t| (t.tool.clone(), t.target_dir().map(PathBuf::from))) + }) + .await + .map_err(|e| format!("Internal error fetching tool state: {e}"))? + .ok_or_else(|| "Internal error: tool not found in tracker".to_string())?; + + // 2. Resolve working directory + let working_dir = target_dir + .or_else(|| std::env::current_dir().ok()) + .ok_or_else(|| "Could not determine working directory".to_string())?; + + // 3. Create permission resolver and check + let resolver = PermissionResolver::new(working_dir) + .await + .map_err(|e| format!("Permission check failed: {e}"))?; + + let response = resolver + .check(&tool) + .await + .map_err(|e| format!("Permission check failed: {e}"))?; + + // 4. Handle response — all paths here handle the tool, so return Ok + let id_clone = id.clone(); + match response { + PermissionResponse::Allowed => { + execute_tool(h2, tx, id, tool, db); + } + PermissionResponse::Denied => { + let tx = tx.clone(); + h2.update(move |state| { + state.finish_tool_call( + &id_clone, + crate::tools::ToolOutcome::Error( + "Permission denied on the user's system".to_string(), + ), + ); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + } + PermissionResponse::Ask => { + h2.update(move |state| { + if let Some(tracked) = state.tool_tracker.get_mut(&id_clone) { + tracked.mark_asking(); + } + }); + } + } + + Ok(()) +} + +fn on_select_permission( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + permission: PermissionResult, +) { + let tx = tx.clone(); + let h2 = handle.clone(); + + match permission { + PermissionResult::Allow => { + // Fetch the tool that's asking for permission, then execute it + let db = app_ctx.history_db.clone(); + tokio::spawn(async move { + let Ok(Some((tool_id, tool))) = h2 + .fetch(move |state| { + state + .tool_tracker + .asking_for_permission() + .map(|t| (t.id.clone(), t.tool.clone())) + }) + .await + else { + return; + }; + + execute_tool(&h2, &tx, tool_id, tool, &db); + }); + } + PermissionResult::AlwaysAllowInDir => { + let db = app_ctx.history_db.clone(); + let git_root = app_ctx.git_root.clone(); + tokio::spawn(async move { + let Ok(Some((tool_id, tool))) = h2 + .fetch(move |state| { + state + .tool_tracker + .asking_for_permission() + .map(|t| (t.id.clone(), t.tool.clone())) + }) + .await + else { + return; + }; + + // Write the rule to the project (git root) or cwd permissions file + let project_root = git_root + .or_else(|| std::env::current_dir().ok()) + .unwrap_or_else(|| PathBuf::from(".")); + let file_path = writer::project_permissions_path(&project_root); + let rule = Rule { + tool: tool.rule_name().to_string(), + scope: None, + }; + if let Err(e) = writer::write_rule(&file_path, &rule, RuleDisposition::Allow).await + { + tracing::error!("Failed to write project permission rule: {e}"); + } + + execute_tool(&h2, &tx, tool_id, tool, &db); + }); + } + PermissionResult::AlwaysAllow => { + let db = app_ctx.history_db.clone(); + tokio::spawn(async move { + let Ok(Some((tool_id, tool))) = h2 + .fetch(move |state| { + state + .tool_tracker + .asking_for_permission() + .map(|t| (t.id.clone(), t.tool.clone())) + }) + .await + else { + return; + }; + + // Write the rule to the global permissions file + let file_path = writer::global_permissions_path(); + let rule = Rule { + tool: tool.rule_name().to_string(), + scope: None, + }; + if let Err(e) = writer::write_rule(&file_path, &rule, RuleDisposition::Allow).await + { + tracing::error!("Failed to write global permission rule: {e}"); + } + + execute_tool(&h2, &tx, tool_id, tool, &db); + }); + } + PermissionResult::Deny => { + h2.update(move |state| { + let Some(tracked) = state.tool_tracker.asking_for_permission() else { + return; + }; + let tool_id = tracked.id.clone(); + + state.finish_tool_call( + &tool_id, + crate::tools::ToolOutcome::Error("Permission denied by the user".to_string()), + ); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + } + } +} + +// ─────────────────────────────────────────────────────────────────── +// Other handlers +// ─────────────────────────────────────────────────────────────────── + +fn on_cancel_generation(handle: &Handle<Session>) { + handle.update(|state| match state.interaction.mode { + crate::tui::state::AppMode::Generating => { + state.cancel_generation(); + } + crate::tui::state::AppMode::Streaming => { + state.cancel_streaming(); + } + _ => {} + }); +} + +fn on_execute_command(handle: &Handle<Session>) { + let h2 = handle.clone(); + handle.update(move |state| { + let cmd = state.conversation.current_command().map(|c| c.to_string()); + if let Some(cmd) = cmd { + if state.conversation.is_current_command_dangerous() + && !state.interaction.confirmation_pending + { + state.interaction.confirmation_pending = true; + } else { + state.interaction.confirmation_pending = false; + state.exit_action = Some(ExitAction::Execute(cmd)); + h2.exit(); + } + } + }); +} + +fn on_cancel_confirmation(handle: &Handle<Session>) { + handle.update(move |state| { + state.interaction.confirmation_pending = false; + }); +} + +fn on_insert_command(handle: &Handle<Session>) { + let h2 = handle.clone(); + handle.update(move |state| { + let cmd = state.conversation.current_command().map(|c| c.to_string()); + if let Some(cmd) = cmd { + state.interaction.confirmation_pending = false; + state.exit_action = Some(ExitAction::Insert(cmd)); + h2.exit(); + } + }); +} + +fn on_retry( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + client_ctx: &ClientContext, +) { + launch_stream(handle, tx, app_ctx, client_ctx, |state| { + state.retry(); + }); +} + +fn on_exit(handle: &Handle<Session>) { + let h2 = handle.clone(); + handle.update(move |state| { + if let Some(abort) = state.stream_abort.take() { + abort.abort(); + } + state.exit_action = Some(ExitAction::Cancel); + h2.exit(); + }); +} + +fn on_interrupt_tool_execution(handle: &Handle<Session>) { + handle.update(move |state| { + // Find executing previews, send interrupt, and mark as interrupted + for tracked in state.tool_tracker.iter_mut() { + if let ToolPhase::ExecutingWithPreview { + ref mut interrupted, + ref mut exit_code, + .. + } = tracked.phase + { + *interrupted = true; + if exit_code.is_none() { + *exit_code = Some(-1); + } + // Send interrupt signal via the tracker entry's abort channel + if let Some(abort_tx) = tracked.abort_tx.take() { + let _ = abort_tx.send(()); + } + } + } + + // The spawned execution task will handle finalizing and sending + // ContinueAfterTools when the process exits. Input mode is already active. + }); +} diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs index a791bb80..1a422fef 100644 --- a/crates/atuin-ai/src/tui/events.rs +++ b/crates/atuin-ai/src/tui/events.rs @@ -5,13 +5,20 @@ /// eye-declare's context system. The main event loop in `inline.rs` /// receives them and mutates `AppState` accordingly. #[derive(Debug)] -pub enum AiTuiEvent { +pub(crate) enum AiTuiEvent { /// User updated the input text InputUpdated(String), /// User submitted text input (Enter in Input mode) SubmitInput(String), /// User entered a slash command (e.g. "/help") + #[allow(unused)] SlashCommand(String), + /// Check the permission for a tool call + CheckToolCallPermission(String), + /// User selected a permission + SelectPermission(PermissionResult), + /// Continue after client tools have completed + ContinueAfterTools, /// Cancel active generation or streaming (Esc during Generating/Streaming) CancelGeneration, /// Execute the suggested command @@ -20,8 +27,18 @@ pub enum AiTuiEvent { InsertCommand, /// Cancel confirmation of dangerous command CancelConfirmation, + /// Interrupt a running tool execution (Ctrl+C during ExecutingPreview) + InterruptToolExecution, /// Retry after error Retry, /// Exit the application Exit, } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum PermissionResult { + Allow, + AlwaysAllowInDir, + AlwaysAllow, + Deny, +} diff --git a/crates/atuin-ai/src/tui/mod.rs b/crates/atuin-ai/src/tui/mod.rs index acb251a7..afd63312 100644 --- a/crates/atuin-ai/src/tui/mod.rs +++ b/crates/atuin-ai/src/tui/mod.rs @@ -1,6 +1,7 @@ -pub mod components; -pub mod events; -pub mod state; -pub mod view; +pub(crate) mod components; +pub(crate) mod dispatch; +pub(crate) mod events; +pub(crate) mod state; +pub(crate) mod view; -pub use state::{AppMode, AppState, ConversationEvent, ExitAction}; +pub(crate) use state::{ConversationEvent, Session}; diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index 4c5c2a1e..69b35909 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -5,9 +5,11 @@ use tokio::task::AbortHandle; +use crate::tools::{ClientToolCall, ToolOutcome, ToolTracker}; + /// Streaming status indicators from server #[derive(Debug, Clone, PartialEq, Eq)] -pub enum StreamingStatus { +pub(crate) enum StreamingStatus { Processing, Searching, Thinking, @@ -15,7 +17,7 @@ pub enum StreamingStatus { } impl StreamingStatus { - pub fn from_status_str(s: &str) -> Self { + pub(crate) fn from_status_str(s: &str) -> Self { match s { "processing" => Self::Processing, "searching" => Self::Searching, @@ -23,20 +25,11 @@ impl StreamingStatus { _ => Self::Thinking, } } - - pub fn display_text(&self) -> &'static str { - match self { - Self::Processing => "Processing...", - Self::Searching => "Searching...", - Self::Thinking => "Thinking...", - Self::WaitingForTools => "Waiting for tools...", - } - } } /// Conversation event types matching the API protocol #[derive(Debug, Clone)] -pub enum ConversationEvent { +pub(crate) enum ConversationEvent { /// User message (what the user typed) UserMessage { content: String }, /// Text content from assistant (streamed or complete) @@ -62,48 +55,8 @@ pub enum ConversationEvent { } impl ConversationEvent { - /// Convert to JSON for API calls - pub fn to_json(&self) -> serde_json::Value { - match self { - ConversationEvent::UserMessage { content } => serde_json::json!({ - "type": "user_message", - "content": content - }), - ConversationEvent::Text { content } => serde_json::json!({ - "type": "text", - "content": content - }), - ConversationEvent::ToolCall { id, name, input } => serde_json::json!({ - "type": "tool_call", - "id": id, - "name": name, - "input": input - }), - ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - } => serde_json::json!({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": content, - "is_error": is_error - }), - ConversationEvent::OutOfBandOutput { - name, - command, - content, - } => serde_json::json!({ - "type": "out_of_band_output", - "name": name, - "command": command, - "content": content - }), - } - } - /// Extract command from a suggest_command tool call - pub fn as_command(&self) -> Option<&str> { + pub(crate) fn as_command(&self) -> Option<&str> { if let ConversationEvent::ToolCall { name, input, .. } = self && name == "suggest_command" { @@ -113,8 +66,9 @@ impl ConversationEvent { } } +/// Application mode for key handling and footer text. #[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum AppMode { +pub(crate) enum AppMode { /// User is typing input Input, /// Waiting for generation (showing spinner) @@ -126,7 +80,7 @@ pub enum AppMode { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum ExitAction { +pub(crate) enum ExitAction { /// Run the command Execute(String), /// Insert command without running @@ -135,47 +89,20 @@ pub enum ExitAction { Cancel, } -/// Application state — the domain model -/// -/// Conversation is stored as a sequence of events matching the API protocol. -/// The view function derives the UI from this state. +/// Owned event log and session ID #[derive(Debug)] -pub struct AppState { - /// Current application mode - pub mode: AppMode, +pub(crate) struct Conversation { /// Conversation events (source of truth, matches API protocol) pub events: Vec<ConversationEvent>, - /// Current error message - pub error: Option<String>, - /// Exit action (set when exiting) - pub exit_action: Option<ExitAction>, /// Session ID from server pub session_id: Option<String>, - /// Current streaming status - pub streaming_status: Option<StreamingStatus>, - /// Whether the input is blank - pub is_input_blank: bool, - /// Whether current turn was interrupted by user - pub was_interrupted: bool, - /// True when user has pressed Enter once on a dangerous command - pub confirmation_pending: bool, - /// Abort handle for the active streaming task, if any - pub stream_abort: Option<AbortHandle>, } -impl AppState { +impl Conversation { pub fn new() -> Self { Self { - mode: AppMode::Input, events: Vec::new(), - error: None, - exit_action: None, session_id: None, - streaming_status: None, - is_input_blank: false, - was_interrupted: false, - confirmation_pending: false, - stream_abort: None, } } @@ -195,16 +122,57 @@ impl AppState { i += 1; } ConversationEvent::Text { content } => { - messages.push(serde_json::json!({ - "role": "assistant", - "content": content - })); - i += 1; + // Check if the next event(s) are ToolCalls — if so, combine + // into a single assistant message with mixed content blocks. + let next_is_tool_call = events + .get(i + 1) + .is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. })); + + if next_is_tool_call { + let mut content_blocks = Vec::new(); + + if !content.is_empty() { + content_blocks.push(serde_json::json!({ + "type": "text", + "text": content + })); + } + + while let Some(ConversationEvent::ToolCall { + id, name, input, .. + }) = events.get(i + 1) + { + content_blocks.push(serde_json::json!({ + "type": "tool_use", + "id": id, + "name": name, + "input": input + })); + i += 1; + } + + messages.push(serde_json::json!({ + "role": "assistant", + "content": content_blocks + })); + i += 1; + } else { + messages.push(serde_json::json!({ + "role": "assistant", + "content": content + })); + i += 1; + } } ConversationEvent::ToolCall { .. } => { + // ToolCalls without preceding Text (shouldn't normally happen, + // but handle defensively) let mut tool_uses = Vec::new(); while i < events.len() { - if let ConversationEvent::ToolCall { id, name, input } = &events[i] { + if let ConversationEvent::ToolCall { + id, name, input, .. + } = &events[i] + { tool_uses.push(serde_json::json!({ "type": "tool_use", "id": id, @@ -247,53 +215,42 @@ impl AppState { messages } - // ===== Generation lifecycle methods ===== - - /// Start generating from submitted input - pub fn start_generating(&mut self, input: String) { - self.events - .push(ConversationEvent::UserMessage { content: input }); - self.mode = AppMode::Generating; - } - - /// Generation error occurred - pub fn generation_error(&mut self, error: String) { - self.error = Some(error); - self.mode = AppMode::Error; - } - - /// Cancel during generation - pub fn cancel_generation(&mut self) { - if let Some(abort) = self.stream_abort.take() { - abort.abort(); - } - if let Some(ConversationEvent::UserMessage { .. }) = self.events.last() { - self.events.pop(); - } - self.mode = AppMode::Input; - } - - // ===== Streaming lifecycle methods ===== - - /// Start streaming response. - /// Pushes an empty Text event that will be mutated in-place as chunks arrive. - pub fn start_streaming(&mut self) { - self.events.push(ConversationEvent::Text { - content: String::new(), - }); - self.streaming_status = None; - self.was_interrupted = false; - self.mode = AppMode::Streaming; + /// Get the most recent command from events + pub fn current_command(&self) -> Option<&str> { + self.events.iter().rev().find_map(|e| e.as_command()) } - /// Store session ID from server response - pub fn store_session_id(&mut self, session_id: String) { - self.session_id = Some(session_id); + /// Check if any turn in the conversation has a command + pub fn has_any_command(&self) -> bool { + self.events.iter().any(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e { + name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() + } else { + false + } + }) } - /// Update streaming status from SSE event - pub fn update_streaming_status(&mut self, status: &str) { - self.streaming_status = Some(StreamingStatus::from_status_str(status)); + /// Check if the most recent command is marked dangerous + pub fn is_current_command_dangerous(&self) -> bool { + self.events + .iter() + .rev() + .find_map(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e + && name == "suggest_command" + { + let danger_level = input + .get("danger") + .and_then(|v| v.as_str()) + .unwrap_or("low"); + return Some( + danger_level == "high" || danger_level == "medium" || danger_level == "med", + ); + } + None + }) + .unwrap_or(false) } /// Get a mutable reference to the last Text event's content (the streaming buffer). @@ -307,28 +264,15 @@ impl AppState { }) } - /// Cancel streaming with context preservation - pub fn cancel_streaming(&mut self) { - if let Some(abort) = self.stream_abort.take() { - abort.abort(); - } - self.was_interrupted = true; - - if let Some(content) = self.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - if trimmed.is_empty() { - // Remove the empty text event - *content = String::new(); + /// Remove trailing empty Text events from the events list + fn remove_empty_trailing_text(&mut self) { + while let Some(ConversationEvent::Text { content }) = self.events.last() { + if content.is_empty() { + self.events.pop(); } else { - *content = format!("{trimmed}\n\n[User cancelled this generation]"); + break; } } - // Remove trailing empty Text events - self.remove_empty_trailing_text(); - - self.streaming_status = None; - self.confirmation_pending = false; - self.mode = AppMode::Input; } /// Append text chunk during streaming (mutates the last Text event in-place) @@ -354,26 +298,6 @@ impl AppState { } } - /// Add a tool call event during streaming. - /// The current streaming text is already in events, so we just push the tool call. - pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) { - // Trim the streaming text event - if let Some(content) = self.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - *content = trimmed; - } - self.remove_empty_trailing_text(); - - let is_suggest_command = name == "suggest_command"; - self.events - .push(ConversationEvent::ToolCall { id, name, input }); - - if is_suggest_command { - self.streaming_status = None; - self.mode = AppMode::Input; - } - } - /// Add a tool result event during streaming pub fn add_tool_result(&mut self, tool_use_id: String, content: String, is_error: bool) { self.events.push(ConversationEvent::ToolResult { @@ -383,47 +307,9 @@ impl AppState { }); } - /// Finalize streaming — trim the accumulated text and change mode - pub fn finalize_streaming(&mut self) { - if let Some(content) = self.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - *content = trimmed; - } - self.remove_empty_trailing_text(); - self.streaming_status = None; - self.mode = AppMode::Input; - } - - /// Streaming error — remove the partial text event - pub fn streaming_error(&mut self, error: String) { - self.remove_empty_trailing_text(); - self.error = Some(error); - self.mode = AppMode::Error; - } - - /// Remove trailing empty Text events from the events list - fn remove_empty_trailing_text(&mut self) { - while let Some(ConversationEvent::Text { content }) = self.events.last() { - if content.is_empty() { - self.events.pop(); - } else { - break; - } - } - } - - // ===== Edit mode and exit methods ===== - - /// Start edit mode for refinement - pub fn start_edit_mode(&mut self) { - self.confirmation_pending = false; - self.mode = AppMode::Input; - } - - /// Retry after error - pub fn retry(&mut self) { - self.error = None; - self.mode = AppMode::Generating; + /// Store session ID from server response + pub fn store_session_id(&mut self, session_id: String) { + self.session_id = Some(session_id); } /// Handle a slash command @@ -445,85 +331,247 @@ impl AppState { }), } } +} - // ===== Query methods ===== +/// Ephemeral UI/presentation state +#[derive(Debug)] +pub(crate) struct Interaction { + /// Current application mode + pub mode: AppMode, + /// Whether the input is blank + pub is_input_blank: bool, + /// True when user has pressed Enter once on a dangerous command + pub confirmation_pending: bool, + /// Current streaming status + pub streaming_status: Option<StreamingStatus>, + /// Whether current turn was interrupted by user + pub was_interrupted: bool, + /// Current error message + pub error: Option<String>, +} - /// Get the most recent command from events - pub fn current_command(&self) -> Option<&str> { - self.events.iter().rev().find_map(|e| e.as_command()) +impl Interaction { + pub fn new() -> Self { + Self { + mode: AppMode::Input, + is_input_blank: false, + confirmation_pending: false, + streaming_status: None, + was_interrupted: false, + error: None, + } } +} - /// Check if the most recent command is marked dangerous - pub fn is_current_command_dangerous(&self) -> bool { - self.events - .iter() - .rev() - .find_map(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e - && name == "suggest_command" - { - let danger_level = input - .get("danger") - .and_then(|v| v.as_str()) - .unwrap_or("low"); - return Some( - danger_level == "high" || danger_level == "medium" || danger_level == "med", - ); - } - None - }) - .unwrap_or(false) +/// Top-level session state +/// +/// Decomposed into `Conversation` (event log + session ID) and +/// `Interaction` (ephemeral UI state). Session methods that cross +/// both sub-structs live here. +#[derive(Debug)] +pub(crate) struct Session { + pub conversation: Conversation, + pub interaction: Interaction, + /// Tracks all tool calls through their full lifecycle. + pub tool_tracker: ToolTracker, + /// Whether the session is running inside a git project (for permission UI labels). + pub in_git_project: bool, + /// Exit action (set when exiting) + pub exit_action: Option<ExitAction>, + /// Abort handle for the active streaming task, if any + pub stream_abort: Option<AbortHandle>, +} + +impl Session { + pub fn new(in_git_project: bool) -> Self { + Self { + conversation: Conversation::new(), + interaction: Interaction::new(), + tool_tracker: ToolTracker::new(), + in_git_project, + exit_action: None, + stream_abort: None, + } } - /// Count non-suggest_command tool calls since the last user message - pub fn tool_count_since_last_user(&self) -> usize { - let last_user_idx = self + // ===== Generation lifecycle methods ===== + + /// Start generating from submitted input + pub fn start_generating(&mut self, input: String) { + self.conversation .events - .iter() - .rposition(|e| matches!(e, ConversationEvent::UserMessage { .. })) - .unwrap_or(0); + .push(ConversationEvent::UserMessage { content: input }); + self.interaction.mode = AppMode::Generating; + } - let mut completed = 0; - let mut in_flight = false; + /// Generation error occurred + #[expect(dead_code)] + pub fn generation_error(&mut self, error: String) { + self.interaction.error = Some(error); + self.interaction.mode = AppMode::Error; + } - for event in &self.events[last_user_idx..] { - match event { - ConversationEvent::ToolCall { name, .. } if name != "suggest_command" => { - if in_flight { - completed += 1; - } - in_flight = true; - } - ConversationEvent::ToolResult { .. } => { - if in_flight { - completed += 1; - in_flight = false; - } - } - _ => {} - } + /// Cancel during generation + pub fn cancel_generation(&mut self) { + if let Some(abort) = self.stream_abort.take() { + abort.abort(); + } + if let Some(ConversationEvent::UserMessage { .. }) = self.conversation.events.last() { + self.conversation.events.pop(); } + self.interaction.mode = AppMode::Input; + } - completed + // ===== Streaming lifecycle methods ===== + + /// Start streaming response. + /// Pushes an empty Text event that will be mutated in-place as chunks arrive. + pub fn start_streaming(&mut self) { + self.conversation.events.push(ConversationEvent::Text { + content: String::new(), + }); + self.interaction.streaming_status = None; + self.interaction.was_interrupted = false; + self.interaction.mode = AppMode::Streaming; } - /// Check if any turn in the conversation has a command - pub fn has_any_command(&self) -> bool { - self.events.iter().any(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e { - name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() + /// Update streaming status from SSE event + pub fn update_streaming_status(&mut self, status: &str) { + self.interaction.streaming_status = Some(StreamingStatus::from_status_str(status)); + } + + /// Cancel streaming with context preservation + pub fn cancel_streaming(&mut self) { + if let Some(abort) = self.stream_abort.take() { + abort.abort(); + } + self.interaction.was_interrupted = true; + + if let Some(content) = self.conversation.streaming_content_mut() { + let trimmed = content.trim_start().to_string(); + if trimmed.is_empty() { + // Remove the empty text event + *content = String::new(); } else { - false + *content = format!("{trimmed}\n\n[User cancelled this generation]"); } - }) + } + // Remove trailing empty Text events + self.conversation.remove_empty_trailing_text(); + + self.interaction.streaming_status = None; + self.interaction.confirmation_pending = false; + self.interaction.mode = AppMode::Input; + } + + /// Add a tool call event during streaming. + /// The current streaming text is already in events, so we just push the tool call. + pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) { + // Trim the streaming text event + if let Some(content) = self.conversation.streaming_content_mut() { + let trimmed = content.trim_start().to_string(); + *content = trimmed; + } + self.conversation.remove_empty_trailing_text(); + + let is_suggest_command = name == "suggest_command"; + self.conversation + .events + .push(ConversationEvent::ToolCall { id, name, input }); + + if is_suggest_command { + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; + } + } + + /// Finalize streaming — trim the accumulated text and change mode + pub fn finalize_streaming(&mut self) { + if let Some(content) = self.conversation.streaming_content_mut() { + let trimmed = content.trim_start().to_string(); + *content = trimmed; + } + self.conversation.remove_empty_trailing_text(); + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; + } + + /// Streaming error — remove the partial text event + pub fn streaming_error(&mut self, error: String) { + self.conversation.remove_empty_trailing_text(); + self.interaction.error = Some(error); + self.interaction.mode = AppMode::Error; + } + + pub(crate) fn handle_client_tool_call( + &mut self, + id: String, + tool: ClientToolCall, + input: serde_json::Value, + ) { + let desc = tool.descriptor(); + let name = desc.canonical_names[0].to_string(); + + self.tool_tracker.insert(id.clone(), tool); + + // Add the ToolCall event to the conversation immediately so it appears + // in the view. Preview data is sourced from tool_tracker. + self.conversation + .events + .push(ConversationEvent::ToolCall { id, name, input }); + + // Client tool calls can only happen at the last part of a turn + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; + } + + /// Retry after error + pub fn retry(&mut self) { + self.interaction.error = None; + self.interaction.mode = AppMode::Generating; + } + + // ===== Tool lifecycle methods ===== + + /// Finish a tool call: transition tracker to Completed, push ToolResult to conversation. + /// + /// For shell commands, captures the final preview from the ExecutingWithPreview phase + /// and patches exit_code/interrupted from the authoritative ToolOutcome. + pub fn finish_tool_call(&mut self, tool_id: &str, outcome: ToolOutcome) { + let mut preview = self.tool_tracker.get(tool_id).and_then(|t| t.preview()); + + // Patch preview with authoritative outcome data (handles race where + // final VT100 update hasn't been applied yet). + if let Some(ref mut p) = preview + && let ToolOutcome::Structured { + exit_code, + interrupted, + .. + } = &outcome + { + p.interrupted = *interrupted; + if p.exit_code.is_none() { + p.exit_code = *exit_code; + } + } + + // Transition tracker entry to Completed + if let Some(tracked) = self.tool_tracker.get_mut(tool_id) { + tracked.complete(preview); + } + + let content = outcome.format_for_llm(); + let is_error = outcome.is_error(); + self.conversation + .add_tool_result(tool_id.to_string(), content, is_error); } /// Get the footer text for current mode pub fn footer_text(&self) -> &'static str { - match self.mode { + match self.interaction.mode { AppMode::Input => { - if self.has_any_command() && self.is_input_blank { - if self.confirmation_pending { + if self.conversation.has_any_command() && self.interaction.is_input_blank { + if self.interaction.confirmation_pending { "[Enter] Confirm dangerous command [Esc] Cancel" } else { "[Enter] Execute suggested command [Tab] Insert Command" @@ -542,9 +590,3 @@ impl AppState { self.exit_action.is_some() } } - -impl Default for AppState { - fn default() -> Self { - Self::new() - } -} diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 0cd51dfa..ee5483d8 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -1,14 +1,20 @@ //! View function that builds the eye-declare element tree from app state. use eye_declare::{ - Cells, Column, Elements, HStack, Span, Spinner, Text, View, WidthConstraint, element, + BorderType, Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport, + WidthConstraint, element, }; use ratatui_core::style::{Color, Modifier, Style}; +use crate::tools::{ClientToolCall, TrackedTool}; +use crate::tui::components::select::SelectOption; +use crate::tui::events::{AiTuiEvent, PermissionResult}; + use super::components::atuin_ai::AtuinAi; use super::components::input_box::InputBox; use super::components::markdown::Markdown; -use super::state::{AppMode, AppState}; +use super::components::select::Select; +use super::state::{AppMode, Session}; mod turn; @@ -20,23 +26,25 @@ mod turn; /// - Error display (if in error state) /// - Spacer /// - Input box (bordered, with contextual keybindings) -pub fn ai_view(state: &AppState) -> Elements { - let mut turn_builder = turn::TurnBuilder::new(); +pub(crate) fn ai_view(state: &Session) -> Elements { + let mut turn_builder = turn::TurnBuilder::new(&state.tool_tracker); - for event in &state.events { + for event in &state.conversation.events { turn_builder.add_event(event); } let turns = turn_builder.build(); - let busy = state.mode == AppMode::Streaming || state.mode == AppMode::Generating; + let busy = state.interaction.mode == AppMode::Streaming + || state.interaction.mode == AppMode::Generating; let last_index = turns.len().saturating_sub(1); element! { AtuinAi( - mode: state.mode, - has_command: state.has_any_command(), - is_input_blank: state.is_input_blank, - pending_confirmation: state.confirmation_pending, + mode: state.interaction.mode, + has_command: state.conversation.has_any_command(), + is_input_blank: state.interaction.is_input_blank, + pending_confirmation: state.interaction.confirmation_pending, + has_executing_preview: state.tool_tracker.has_executing_preview(), ) { #(for (index, turn) in turns.iter().enumerate() { #(match turn { @@ -53,25 +61,94 @@ pub fn ai_view(state: &AppState) -> Elements { }) #(if !state.is_exiting() { - View(key: "input-box", padding_top: Cells::from(1)) { - InputBox( - key: "input", - title: "Generate a command or ask a question", - title_right: "Atuin AI", - footer: state.footer_text(), - active: state.mode == AppMode::Input && !state.confirmation_pending, - ) + #(input_view(state)) + }) + } + } +} - #(if state.is_input_blank && state.has_any_command() && state.mode == AppMode::Input { - #(if state.confirmation_pending { - Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) } - } else { - Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) } - }) +fn input_view(state: &Session) -> Elements { + let asking_tool = state.tool_tracker.asking_for_permission(); + let in_git_project = state.in_git_project; + + element! { + #(if let Some(tc) = asking_tool { + #(tool_call_view(tc, in_git_project)) + }) + + #(if asking_tool.is_none() { + View(key: "input-box", padding_top: Cells::from(1)) { + InputBox( + key: "input", + title: "Generate a command or ask a question", + title_right: "Atuin AI", + footer: state.footer_text(), + active: state.interaction.mode == AppMode::Input && !state.interaction.confirmation_pending, + ) + + #(if state.interaction.is_input_blank && state.conversation.has_any_command() && state.interaction.mode == AppMode::Input { + #(if state.interaction.confirmation_pending { + Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) } + } else { + Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) } }) + }) + } + }) + } +} - } - }) +fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements { + let verb = tool_call.tool.descriptor().display_verb; + let tool_desc = match &tool_call.tool { + ClientToolCall::Read(tool) => tool.path.display().to_string(), + ClientToolCall::Write(tool) => tool.path.display().to_string(), + ClientToolCall::Shell(tool) => tool.command.clone(), + ClientToolCall::AtuinHistory(tool) => tool.query.clone(), + }; + + let dir_label = if in_git_project { + "Always allow in this workspace" + } else { + "Always allow in this directory" + }; + + element! { + View(key: format!("tool-call-{}", tool_call.id), padding_left: Cells::from(2), padding_top: Cells::from(1)) { + Text { + Span(text: format!("Atuin AI would like to {}: ", verb), style: Style::default()) + Span(text: &tool_desc, style: Style::default().fg(Color::Yellow)) + } + View(padding_left: Cells::from(2)) { + Select(options: [ + SelectOption::builder() + .label("Allow") + .value("allow") + .build(), + SelectOption::builder() + .label(dir_label) + .value("always-allow-in-dir") + .build(), + SelectOption::builder() + .label("Always allow") + .value("always-allow") + .build(), + SelectOption::builder() + .label("Deny") + .value("deny") + .build(), + ], on_select: Box::new(move |option: &SelectOption| { + let value = match option.value.as_str() { + "allow" => PermissionResult::Allow, + "always-allow-in-dir" => PermissionResult::AlwaysAllowInDir, + "always-allow" => PermissionResult::AlwaysAllow, + "deny" => PermissionResult::Deny, + _ => unreachable!(), + }; + + Some(AiTuiEvent::SelectPermission(value)) + }) as Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync>) + } } } } @@ -86,7 +163,7 @@ fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements { element! { View(padding_top: Cells::from(padding)) { Text { - Span(text: "You", style: label_style) + Span(text: " You ", style: label_style.reversed()) } #(for event in events { #(match event { @@ -114,9 +191,9 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { element! { View { Spinner( - label: "Atuin AI", - label_style: label_style, - done_label_style: label_style, + label: " Atuin AI ", + label_style: label_style.reversed(), + done_label_style: label_style.reversed(), hide_checkmark: true, label_first: true, done: !busy, @@ -136,6 +213,52 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { turn::UiEvent::SuggestedCommand(details) => { suggested_command_view(details) }, + turn::UiEvent::ToolCall(details) => { + let preview_done = details.preview.as_ref().is_some_and(|p| p.exit_code.is_some() || p.interrupted); + let tool_key = details.tool_use_id.clone(); + + element! { + View(key: format!("tool-output-{tool_key}"), padding_left: Cells::from(2)) { + #(if let Some(ref preview) = details.preview { + View(key: format!("preview-{tool_key}")) { + #(preview_spinner_view(&details.name, preview_done)) + Viewport( + key: format!("viewport-{tool_key}"), + lines: preview.lines.clone(), + height: 10, + border: BorderType::Plain, + border_style: Style::default().fg(Color::DarkGray), + style: Style::default().fg(Color::White), + wrap: false, + ) + #(if let Some(code) = preview.exit_code { + #(if code == 0 { + Text { + Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Green)) + } + } else { + Text { + Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Red)) + } + }) + }) + #(if preview.interrupted { + Text { + Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) + } + }) + #(if !preview_done { + Text { + Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray)) + } + }) + } + } else { + #(tool_status_view(&details.name, &details.status)) + }) + } + } + } _ => element!{} }) }) @@ -180,6 +303,48 @@ fn tool_summary_view(summary: &turn::ToolSummary) -> Elements { } } +/// Render a status indicator for a non-preview tool call (e.g. atuin_history, read_file). +fn tool_status_view(name: &str, status: &turn::ToolResultStatus) -> Elements { + match status { + turn::ToolResultStatus::Pending => { + element! { + Spinner( + label: format!("Running: {name}"), + label_style: Style::default().fg(Color::Yellow), + done: false, + ) + } + } + turn::ToolResultStatus::Success => { + element! { + Spinner( + label: format!("Ran: {name}"), + done: true, + ) + } + } + turn::ToolResultStatus::Error => { + element! { + Text { + Span(text: "✗ ", style: Style::default().fg(Color::Red)) + Span(text: format!("{name}: denied"), style: Style::default().fg(Color::Red)) + } + } + } + } +} + +/// Render a spinner/status line for a command preview (shell tools). +fn preview_spinner_view(name: &str, done: bool) -> Elements { + element! { + Spinner( + label: if done { format!("Ran: {name}") } else { format!("Running: {name}") }, + label_style: Style::default().fg(Color::Yellow), + done: done, + ) + } +} + fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements { let is_dangerous = matches!( details.danger_level, diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index 861da64c..6949236c 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -1,5 +1,8 @@ +use crate::tools::descriptor; +use crate::tools::{ToolPreview, ToolTracker}; use crate::tui::ConversationEvent; +/// Server-sent danger level for a suggested command #[derive(Debug)] pub(crate) enum DangerLevel { Low(Option<String>), @@ -37,6 +40,7 @@ impl From<(&String, &String)> for DangerLevel { } } +/// Server-sent confidence level for a suggested command #[derive(Debug)] pub(crate) enum ConfidenceLevel { Low(Option<String>), @@ -85,9 +89,11 @@ pub(crate) enum UiEvent { #[derive(Debug)] pub(crate) struct ToolCallDetails { - tool_use_id: String, - name: String, - status: ToolResultStatus, + pub(crate) tool_use_id: String, + pub(crate) name: String, + pub(crate) status: ToolResultStatus, + pub(crate) is_client: bool, + pub(crate) preview: Option<ToolPreview>, } #[derive(Debug)] @@ -118,16 +124,19 @@ pub(crate) enum UiTurn { OutOfBand { events: Vec<UiEvent> }, } -pub(crate) struct TurnBuilder { +pub(crate) struct TurnBuilder<'a> { turns: Vec<UiTurn>, current_turn: Option<UiTurn>, + tracker: &'a ToolTracker, } -impl TurnBuilder { - pub(crate) fn new() -> Self { +/// A struct to iteratively build [UiTurn] events from [ConversationEvent]s. +impl<'a> TurnBuilder<'a> { + pub(crate) fn new(tracker: &'a ToolTracker) -> Self { Self { turns: Vec::new(), current_turn: None, + tracker, } } @@ -174,7 +183,7 @@ impl TurnBuilder { for event in events.drain(..) { match event { - UiEvent::ToolCall(details) => { + UiEvent::ToolCall(details) if !details.is_client => { pending_tools.push(details); } other => { @@ -306,12 +315,17 @@ impl TurnBuilder { } fn add_tool_call(&mut self, id: &str, name: &str, _input: &serde_json::Value) { + let is_client = descriptor::by_name(name).is_some_and(|d| d.is_client); + let preview = self.tracker.preview_for(id); + self.start_agent_turn(); if let UiTurn::Agent { events } = self.turn_mut_unsafe() { events.push(UiEvent::ToolCall(ToolCallDetails { tool_use_id: id.to_string(), name: name.to_string(), status: ToolResultStatus::Pending, + is_client, + preview, })); } } @@ -385,25 +399,15 @@ impl ToolSummary { /// Present-tense progressive verb for a tool name (e.g. "Searching...") fn progressive_verb(name: &str) -> String { - match name { - "search" => "Searching...".into(), - "read" | "read_file" => "Reading file...".into(), - "write" | "write_file" => "Writing file...".into(), - "execute" | "run" | "bash" => "Running command...".into(), - "list" | "list_files" => "Listing files...".into(), - _ => format!("Running {}...", name.replace('_', " ")), - } + descriptor::by_name(name) + .map(|d| d.progressive_verb.to_string()) + .unwrap_or_else(|| format!("Running {}...", name.replace('_', " "))) } /// Past-tense verb for a tool name (e.g. "Searched") fn past_verb(name: &str) -> String { - match name { - "search" => "Searched".into(), - "read" | "read_file" => "Read file".into(), - "write" | "write_file" => "Wrote file".into(), - "execute" | "run" | "bash" => "Ran command".into(), - "list" | "list_files" => "Listed files".into(), - _ => format!("Ran {}", name.replace('_', " ")), - } + descriptor::by_name(name) + .map(|d| d.past_verb.to_string()) + .unwrap_or_else(|| format!("Ran {}", name.replace('_', " "))) } } |
