diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-04-10 13:24:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-04-10 20:24:57 +0000 |
| commit | 09279a428659cf41824737d3e0c97bcc19a8885a (patch) | |
| tree | 64731502c065df2483e8dd680d46c5559f3094f2 | |
| parent | feat: add strip_trailing_whitespace, on by default (#3390) (diff) | |
| download | atuin-09279a428659cf41824737d3e0c97bcc19a8885a.zip | |
feat: Client-tool execution + permission system (#3370)
Adds client-side tool execution to Atuin AI, starting with
`atuin_history`. The server can request tool calls, which are executed
locally with a permission system, and results are sent back to continue
the conversation.
37 files changed, 5181 insertions, 906 deletions
@@ -13,3 +13,5 @@ ui/backend/target ui/backend/gen sqlite-server.db* + +.atuin/permissions.*.toml @@ -280,21 +280,33 @@ dependencies = [ "eye_declare", "eyre", "futures", + "glob-match", "pretty_assertions", "pulldown-cmark", "ratatui", "ratatui-core", "ratatui-widgets", + "regex", "reqwest", "serde", "serde_json", + "tempfile", + "thiserror 2.0.18", + "time", "tokio", + "toml", + "toml_edit", "tracing", "tracing-appender", "tracing-subscriber", + "tree-sitter", + "tree-sitter-bash", + "tree-sitter-fish", "tui-textarea-2", + "typed-builder 0.18.2", "unicode-width 0.2.2", "uuid", + "vt100", ] [[package]] @@ -954,7 +966,7 @@ dependencies = [ "pathdiff", "serde_core", "toml", - "winnow", + "winnow 0.7.15", ] [[package]] @@ -1498,9 +1510,9 @@ dependencies = [ [[package]] name = "eye_declare" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cac5c1a3b194e6674e9e44dbfb035f31c4df7a1ff6c8765181c50e8482bb393a" +checksum = "f9abe8051754adccf30ac4a0d54ce083a645fee7d4fc6c78d9d9770821bad45d" dependencies = [ "crossterm", "eye_declare_macros", @@ -1514,9 +1526,9 @@ dependencies = [ [[package]] name = "eye_declare_macros" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98595776b5e10c6ea519c09940fb7995b64da1e9a70cc94aa6c08b3bd404925a" +checksum = "39251ef16365f347032ab2344ad806f64d59f29fb171b4bafd05595fbda2604d" dependencies = [ "proc-macro2", "quote", @@ -1849,6 +1861,12 @@ dependencies = [ ] [[package]] +name = "glob-match" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985c9503b412198aa4197559e9a318524ebc4519c229bfa05a535828c950b9d" + +[[package]] name = "h2" version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4302,6 +4320,7 @@ version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ + "indexmap 2.13.0", "itoa", "memchr", "serde", @@ -4332,9 +4351,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "1.0.4" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" +checksum = "6662b5879511e06e8999a8a235d848113e942c9124f211511b16466ee2995f26" dependencies = [ "serde_core", ] @@ -4805,6 +4824,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + +[[package]] name = "stringprep" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5197,22 +5222,24 @@ dependencies = [ [[package]] name = "toml" -version = "1.0.6+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "399b1124a3c9e16766831c6bba21e50192572cdd98706ea114f9502509686ffc" +checksum = "994b95d9e7bae62b34bab0e2a4510b801fa466066a6a8b2b57361fa1eba068ee" dependencies = [ + "indexmap 2.13.0", "serde_core", "serde_spanned", "toml_datetime", "toml_parser", - "winnow", + "toml_writer", + "winnow 1.0.1", ] [[package]] name = "toml_datetime" -version = "1.0.0+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" dependencies = [ "serde_core", ] @@ -5227,23 +5254,23 @@ dependencies = [ "toml_datetime", "toml_parser", "toml_writer", - "winnow", + "winnow 0.7.15", ] [[package]] name = "toml_parser" -version = "1.0.9+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +checksum = "39ca317ebc49f06bd748bfba29533eac9485569dc9bf80b849024b025e814fb9" dependencies = [ - "winnow", + "winnow 1.0.1", ] [[package]] name = "toml_writer" -version = "1.0.6+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" +checksum = "756daf9b1013ebe47a8776667b466417e2d4c5679d441c26230efd9ef78692db" [[package]] name = "tonic" @@ -5474,6 +5501,46 @@ dependencies = [ ] [[package]] +name = "tree-sitter" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "887bd495d0582c5e3e0d8ece2233666169fa56a9644d172fc22ad179ab2d0538" +dependencies = [ + "cc", + "regex", + "regex-syntax", + "serde_json", + "streaming-iterator", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-bash" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5ec769279cc91b561d3df0d8a5deb26b0ad40d183127f409494d6d8fc53062" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-fish" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "014e3b299f251e9c2e372e3b5e1b0323ef21196e9aa2e90a5bc1f6130cbe8b18" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-language" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782" + +[[package]] name = "tree_magic_mini" version = "3.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5709,35 +5776,23 @@ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "vt100" -version = "0.15.2" +version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84cd863bf0db7e392ba3bd04994be3473491b31e66340672af5d11943c6274de" +checksum = "054ff75fb8fa83e609e685106df4faeffdf3a735d3c74ebce97ec557d5d36fd9" dependencies = [ "itoa", - "log", - "unicode-width 0.1.14", + "unicode-width 0.2.2", "vte", ] [[package]] name = "vte" -version = "0.11.1" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5022b5fbf9407086c180e9557be968742d839e68346af7792b8592489732197" +checksum = "a5924018406ce0063cd67f8e008104968b74b563ee1b85dde3ed1f7cb87d3dbd" dependencies = [ "arrayvec", - "utf8parse", - "vte_generate_state_changes", -] - -[[package]] -name = "vte_generate_state_changes" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e369bee1b05d510a7b4ed645f5faa90619e05437111783ea5848f28d97d3c2e" -dependencies = [ - "proc-macro2", - "quote", + "memchr", ] [[package]] @@ -6550,6 +6605,12 @@ dependencies = [ ] [[package]] +name = "winnow" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" + +[[package]] name = "winreg" version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1,5 +1,9 @@ [workspace] -members = ["crates/*", "crates/atuin-nucleo/matcher", "crates/atuin-nucleo/bench"] +members = [ + "crates/*", + "crates/atuin-nucleo/matcher", + "crates/atuin-nucleo/bench", +] resolver = "2" exclude = ["ui/backend", "crates/atuin-nucleo/matcher/fuzz"] @@ -65,6 +69,10 @@ rustls = { version = "0.23", default-features = false, features = [ "std", "tls12", ] } +glob-match = "0.2.1" +vt100 = "0.16" +regex = "1.10.5" +toml_edit = "0.25.4" [workspace.dependencies.tracing-subscriber] version = "0.3" diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index 6e7315cd..c5f66695 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -12,6 +12,10 @@ repository = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = [] +tree-sitter = ["dep:tree-sitter-lib", "dep:tree-sitter-bash", "dep:tree-sitter-fish"] + [dependencies] atuin-client = { workspace = true } atuin-common = { workspace = true } @@ -39,9 +43,21 @@ async-stream = "0.3" uuid = { workspace = true } tui-textarea-2 = "0.10.2" unicode-width = "0.2" -eye_declare = "0.3" +eye_declare = "0.4" ratatui-core = "0.1" ratatui-widgets = "0.3" +thiserror = { workspace = true } +glob-match = { workspace = true } +regex = { workspace = true } +time = { workspace = true } +toml = "1.1" +toml_edit = { workspace = true } +tree-sitter-lib = { package = "tree-sitter", version = "0.26.8", optional = true } +tree-sitter-bash = { version = "0.25.1", optional = true } +tree-sitter-fish = { version = "3.6.0", optional = true } +typed-builder = { workspace = true } +vt100 = { workspace = true } [dev-dependencies] pretty_assertions = { workspace = true } +tempfile = { workspace = true } diff --git a/crates/atuin-ai/src/commands.rs b/crates/atuin-ai/src/commands.rs index 6e79da61..cdbc8f2d 100644 --- a/crates/atuin-ai/src/commands.rs +++ b/crates/atuin-ai/src/commands.rs @@ -9,7 +9,7 @@ use eyre::Result; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{EnvFilter, Layer, fmt, layer::SubscriberExt, util::SubscriberInitExt}; pub mod init; -pub mod inline; +pub(crate) mod inline; #[derive(Args, Debug)] pub struct AiArgs { @@ -71,7 +71,7 @@ pub async fn run( } } -pub fn detect_shell() -> Option<String> { +pub(crate) fn detect_shell() -> Option<String> { Some(Shell::current().to_string()) } diff --git a/crates/atuin-ai/src/commands/init.rs b/crates/atuin-ai/src/commands/init.rs index 77abc4f4..f693d892 100644 --- a/crates/atuin-ai/src/commands/init.rs +++ b/crates/atuin-ai/src/commands/init.rs @@ -1,6 +1,6 @@ use crate::commands::detect_shell; -pub async fn run(shell: String) -> eyre::Result<()> { +pub(crate) async fn run(shell: String) -> eyre::Result<()> { let integration = match shell.as_str() { "zsh" => generate_zsh_integration(), "bash" => generate_bash_integration(), diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index aeb414fb..b37bb72f 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -1,38 +1,44 @@ use std::path::PathBuf; use std::sync::mpsc; -use crate::commands::detect_shell; +use crate::context::{AppContext, ClientContext}; +use crate::tui::dispatch; use crate::tui::events::AiTuiEvent; -use crate::tui::state::{AppState, ExitAction}; +use crate::tui::state::{ExitAction, Session}; use crate::tui::view::ai_view; use atuin_client::database::{Database, Sqlite}; -use atuin_client::distro::detect_linux_distribution; -use atuin_common::tls::ensure_crypto_provider; -use eventsource_stream::Eventsource; -use eye_declare::{Application, CtrlCBehavior, Handle}; +use eye_declare::{Application, CtrlCBehavior}; use eyre::{Context as _, Result, bail}; -use futures::StreamExt; -use reqwest::Url; -use tracing::{debug, error, info, trace}; +use tracing::{debug, info}; -pub async fn run( +pub(crate) async fn run( initial_command: Option<String>, api_endpoint: Option<String>, api_token: Option<String>, settings: &atuin_client::settings::Settings, output_for_hook: bool, ) -> Result<()> { - if !settings.ai.enabled.unwrap_or(false) { - emit_shell_result( - Action::Print( - "Atuin AI is not enabled. Please enable it in your settings or run `atuin setup`." - .to_string(), - ), - output_for_hook, - ); + if settings.ai.enabled == Some(false) { return Ok(()); } + if settings.ai.enabled.is_none() { + match prompt_ai_setup()? { + SetupChoice::EnableAi => { + set_ai_enabled(true).await?; + } + SetupChoice::DisableKeybind => { + set_ai_enabled(false).await?; + emit_shell_result(Action::Cancel, output_for_hook); + return Ok(()); + } + SetupChoice::Cancel => { + emit_shell_result(Action::Cancel, output_for_hook); + return Ok(()); + } + } + } + let endpoint = api_endpoint.as_deref().unwrap_or( settings .ai @@ -48,7 +54,36 @@ pub async fn run( ensure_hub_session(settings).await? }; - let action = run_inline_tui(endpoint.to_string(), token, initial_command, settings).await?; + let history_db_path = PathBuf::from(settings.db_path.as_str()); + let history_db = Sqlite::new(history_db_path, settings.local_timeout) + .await + .context("failed to open history database for AI")?; + + // Support both legacy [ai] send_cwd and new [ai.opening] send_cwd + let send_cwd = + settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); + + let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { + history_db.last().await.ok().flatten().map(|h| h.command) + } else { + None + }; + + let git_root = std::env::current_dir() + .ok() + .and_then(|cwd| atuin_common::utils::in_git_repo(cwd.to_str()?)); + + let ctx = AppContext { + endpoint: endpoint.to_string(), + token, + send_cwd, + last_command, + history_db: std::sync::Arc::new(history_db), + git_root, + capabilities: settings.ai.capabilities.clone(), + }; + + let action = run_inline_tui(ctx, initial_command).await?; emit_shell_result(action, output_for_hook); Ok(()) @@ -69,7 +104,7 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu if will_sync { println!( "Once logged in, your shell history will be synchronized via Atuin Hub if auto_sync is enabled or when manually syncing." - ) + ); } println!( "If you have an existing Atuin sync account, you can log in with your existing credentials." @@ -111,279 +146,16 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu } // ─────────────────────────────────────────────────────────────────── -// SSE streaming -// ─────────────────────────────────────────────────────────────────── -#[derive(Debug, Clone)] -enum ChatStreamEvent { - TextChunk(String), - ToolCall { - id: String, - name: String, - input: serde_json::Value, - }, - ToolResult { - tool_use_id: String, - content: String, - is_error: bool, - }, - Status(String), - Done { - session_id: String, - }, - Error(String), -} - -fn create_chat_stream( - hub_address: String, - token: String, - session_id: Option<String>, - messages: Vec<serde_json::Value>, - send_cwd: bool, - last_command: Option<String>, -) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<ChatStreamEvent>> + Send>> { - Box::pin(async_stream::stream! { - ensure_crypto_provider(); - let endpoint = match hub_url(&hub_address, "/api/cli/chat") { - Ok(url) => url, - Err(e) => { - yield Err(e); - return; - } - }; - - debug!("Sending SSE request to {endpoint}"); - - let os = detect_os(); - let shell = detect_shell(); - - let mut context = serde_json::json!({ - "os": os, - "shell": shell, - "pwd": if send_cwd { std::env::current_dir() - .ok() - .map(|path| path.to_string_lossy().into_owned()) } else { None }, - "last_command": last_command, - }); - - if os == "linux" { - context["distro"] = serde_json::json!(detect_linux_distribution()); - } - - let mut request_body = serde_json::json!({ - "messages": messages, - "context": context, - }); - - if let Some(ref sid) = session_id { - trace!("Including session_id in request: {sid}"); - request_body["session_id"] = serde_json::json!(sid); - } - - let client = reqwest::Client::new(); - let response = match client - .post(endpoint.clone()) - .header("Accept", "text/event-stream") - .bearer_auth(&token) - .json(&request_body) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); - return; - } - }; +async fn run_inline_tui(ctx: AppContext, initial_prompt: Option<String>) -> Result<Action> { + let client_ctx = ClientContext::detect(); - let status = response.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - error!("SSE request failed with status: {status}, clearing session"); - let _ = atuin_client::hub::delete_session().await; - yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); - return; - } - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - error!("SSE request failed ({}): {}", status, body); - yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); - return; - } - - let byte_stream = response.bytes_stream(); - let mut stream = byte_stream.eventsource(); - - while let Some(event) = stream.next().await { - match event { - Ok(sse_event) => { - let event_type = sse_event.event.as_str(); - let data = sse_event.data.clone(); - - debug!(event_type = %event_type, "SSE event received"); - - match event_type { - "text" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) - && let Some(content) = json.get("content").and_then(|v| v.as_str()) - { - yield Ok(ChatStreamEvent::TextChunk(content.to_string())); - } - } - "tool_call" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); - yield Ok(ChatStreamEvent::ToolCall { id, name, input }); - } - } - "tool_result" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); - yield Ok(ChatStreamEvent::ToolResult { tool_use_id, content, is_error }); - } - } - "status" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) - && let Some(state) = json.get("state").and_then(|v| v.as_str()) - { - yield Ok(ChatStreamEvent::Status(state.to_string())); - } - } - "done" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let session_id = json.get("session_id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - yield Ok(ChatStreamEvent::Done { session_id }); - } else { - yield Ok(ChatStreamEvent::Done { session_id: String::new() }); - } - break; - } - "error" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); - error!("SSE error: {}", message); - yield Ok(ChatStreamEvent::Error(message)); - } else { - error!("SSE error: {}", data); - yield Ok(ChatStreamEvent::Error(data)); - } - break; - } - _ => {} - } - } - Err(e) => { - yield Err(eyre::eyre!("SSE error: {}", e)); - break; - } - } - } - }) -} - -// ─────────────────────────────────────────────────────────────────── -// Async streaming task — pushes updates to app state via Handle -// ─────────────────────────────────────────────────────────────────── - -async fn run_chat_stream( - handle: Handle<AppState>, - endpoint: String, - token: String, - session_id: Option<String>, - messages: Vec<serde_json::Value>, - send_cwd: bool, - last_command: Option<String>, -) { - let stream = create_chat_stream( - endpoint, - token, - session_id, - messages, - send_cwd, - last_command, - ); - futures::pin_mut!(stream); - - while let Some(event) = stream.next().await { - match event { - Ok(ChatStreamEvent::TextChunk(text)) => { - trace!(text = %text, "Processing TextChunk"); - handle.update(move |state| { - state.append_streaming_text(&text); - }); - } - Ok(ChatStreamEvent::ToolCall { id, name, input }) => { - trace!(id = %id, name = %name, "Processing ToolCall"); - handle.update(move |state| { - state.add_tool_call(id, name, input); - }); - } - Ok(ChatStreamEvent::ToolResult { - tool_use_id, - content, - is_error, - }) => { - trace!(tool_use_id = %tool_use_id, "Processing ToolResult"); - handle.update(move |state| { - state.add_tool_result(tool_use_id, content, is_error); - }); - } - Ok(ChatStreamEvent::Status(status)) => { - trace!(status = %status, "Processing Status"); - handle.update(move |state| { - state.update_streaming_status(&status); - }); - } - Ok(ChatStreamEvent::Done { session_id }) => { - trace!(session_id = %session_id, "Processing Done"); - handle.update(move |state| { - if !session_id.is_empty() { - state.store_session_id(session_id); - } - state.finalize_streaming(); - }); - break; - } - Ok(ChatStreamEvent::Error(msg)) => { - trace!(error = %msg, "Processing Error"); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - Err(e) => { - let msg = e.to_string(); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - } - } -} - -// ─────────────────────────────────────────────────────────────────── -// Main TUI entry point -// ─────────────────────────────────────────────────────────────────── + let (tx, rx) = mpsc::channel::<AiTuiEvent>(); -async fn run_inline_tui( - endpoint: String, - token: String, - initial_prompt: Option<String>, - settings: &atuin_client::settings::Settings, -) -> Result<Action> { - let initial_state = AppState::new(); + let initial_state = Session::new(ctx.git_root.is_some()); println!(); - let (tx, rx) = mpsc::channel::<AiTuiEvent>(); - // If there's an initial prompt, send it as a SubmitInput event // so it flows through the same path as user-typed input. if let Some(prompt) = initial_prompt { @@ -396,164 +168,17 @@ async fn run_inline_tui( .ctrl_c(CtrlCBehavior::Deliver) .keyboard_protocol(eye_declare::KeyboardProtocol::Enhanced) .bracketed_paste(true) - .with_context(tx) + .with_context(tx.clone()) .extra_newlines_at_exit(1) .build()?; - // Support both legacy [ai] send_cwd and new [ai.opening] send_cwd - let send_cwd = - settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); - - let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { - let db_path = PathBuf::from(settings.db_path.as_str()); - match Sqlite::new(db_path, settings.local_timeout).await { - Ok(db) => db.last().await.ok().flatten().map(|h| h.command), - Err(e) => { - debug!("Failed to open history database for read_history: {e}"); - None - } - } - } else { - None - }; - // Event loop: receives AiTuiEvent from components, mutates state via Handle. let h = handle.clone(); - let ep = endpoint.clone(); - let tk = token.clone(); tokio::task::spawn_blocking(move || { + let tx = tx.clone(); + let client_ctx = client_ctx; while let Ok(event) = rx.recv() { - match event { - AiTuiEvent::InputUpdated(input) => { - let input_blank = input.trim().is_empty(); - - h.update(move |state| { - state.is_input_blank = input_blank; - }); - } - AiTuiEvent::SubmitInput(input) => { - let input = input.trim().to_string(); - if input.is_empty() { - let h2 = h.clone(); - h.update(move |state| { - if state.has_any_command() { - state.exit_action = Some(ExitAction::Execute( - state.current_command().unwrap().to_string(), - )); - } else { - state.exit_action = Some(ExitAction::Cancel); - } - h2.exit(); - }); - continue; - } - - if input.starts_with('/') { - let input_clone = input.clone(); - h.update(move |state| { - state.handle_slash_command(&input_clone); - }); - continue; - } - - // Start generation and spawn streaming task - let ep = ep.clone(); - let tk = tk.clone(); - let h2 = h.clone(); - let lc = last_command.clone(); - h.update(move |state| { - state.start_generating(input); - state.start_streaming(); - state.is_input_blank = true; - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd, lc).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::SlashCommand(command) => { - h.update(move |state| { - state.handle_slash_command(&command); - }); - } - - AiTuiEvent::CancelGeneration => { - h.update(|state| match state.mode { - crate::tui::state::AppMode::Generating => { - state.cancel_generation(); - } - crate::tui::state::AppMode::Streaming => { - state.cancel_streaming(); - } - _ => {} - }); - } - - AiTuiEvent::ExecuteCommand => { - let h2 = h.clone(); - h.update(move |state| { - let cmd = state.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - if state.is_current_command_dangerous() && !state.confirmation_pending { - state.confirmation_pending = true; - } else { - state.confirmation_pending = false; - state.exit_action = Some(ExitAction::Execute(cmd)); - h2.exit(); - } - } - }); - } - - AiTuiEvent::CancelConfirmation => { - h.update(move |state| { - state.confirmation_pending = false; - }); - } - - AiTuiEvent::InsertCommand => { - let h2 = h.clone(); - h.update(move |state| { - let cmd = state.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - state.confirmation_pending = false; - state.exit_action = Some(ExitAction::Insert(cmd)); - h2.exit(); - } - }); - } - - AiTuiEvent::Retry => { - let ep = ep.clone(); - let tk = tk.clone(); - let h2 = h.clone(); - let lc = last_command.clone(); - h.update(move |state| { - state.retry(); - state.start_streaming(); - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd, lc).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::Exit => { - let h2 = h.clone(); - h.update(move |state| { - if let Some(abort) = state.stream_abort.take() { - abort.abort(); - } - state.exit_action = Some(ExitAction::Cancel); - h2.exit(); - }); - } - } + dispatch::dispatch(&h, event, &tx, &ctx, &client_ctx); } }); @@ -573,51 +198,125 @@ async fn run_inline_tui( // Helpers // ─────────────────────────────────────────────────────────────────── -fn hub_url(base: &str, path: &str) -> Result<Url> { - let base_with_slash = if base.ends_with('/') { - base.to_string() - } else { - format!("{base}/") - }; - let stripped = path.strip_prefix('/').unwrap_or(path); - Url::parse(&base_with_slash)? - .join(stripped) - .context("failed to build hub URL") +enum SetupChoice { + EnableAi, + DisableKeybind, + Cancel, } -fn detect_os() -> String { - match std::env::consts::OS { - "macos" => "macos".to_string(), - "linux" => "linux".to_string(), - "windows" => "windows".to_string(), - other => format!("Other: {other}"), +fn prompt_ai_setup() -> Result<SetupChoice> { + use crossterm::{ + cursor, + event::{self, Event, KeyCode}, + terminal, + }; + + let options = ["Enable Atuin AI", "Disable ? Keybind", "Cancel"]; + let mut selected: usize = 0; + let mut stdout = std::io::stdout(); + + // Print header before raw mode so newlines render correctly. + // Use stdout because the shell hook swaps stdout/stderr — stdout goes + // to the terminal in both hook and non-hook modes. + println!(); + println!(" Atuin AI is not yet configured."); + println!(); + + terminal::enable_raw_mode().context("failed to enable raw mode")?; + struct Guard; + impl Drop for Guard { + fn drop(&mut self) { + let _ = terminal::disable_raw_mode(); + } } -} + let _guard = Guard; -#[derive(Clone)] -enum Action { - Execute(String), - Insert(String), - Print(String), - Cancel, -} + crossterm::execute!(stdout, cursor::Hide)?; -fn emit_shell_result(action: Action, output_for_hook: bool) { - if output_for_hook { - match action { - Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"), - Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"), - Action::Print(output) => eprintln!("__atuin_ai_print__:{output}"), - Action::Cancel => eprintln!("__atuin_ai_cancel__"), + loop { + render_setup_options(&mut stdout, &options, selected)?; + + let ev = event::read().context("failed to read key event")?; + + crossterm::execute!(stdout, cursor::MoveUp(options.len() as u16))?; + + if let Event::Key(key) = ev { + match key.code { + KeyCode::Up | KeyCode::Char('k') => { + selected = selected.saturating_sub(1); + } + KeyCode::Down | KeyCode::Char('j') => { + if selected < options.len() - 1 { + selected += 1; + } + } + KeyCode::Enter => break, + KeyCode::Esc => { + selected = 2; + break; + } + _ => {} + } } - } else { - match action { - Action::Execute(output) => eprintln!("{output}"), - Action::Insert(output) => eprintln!("{output}"), - Action::Print(output) => eprintln!("{output}"), - Action::Cancel => eprintln!(), + } + + // Final render with selection visible + render_setup_options(&mut stdout, &options, selected)?; + crossterm::execute!(stdout, cursor::Show)?; + + Ok(match selected { + 0 => SetupChoice::EnableAi, + 1 => SetupChoice::DisableKeybind, + _ => SetupChoice::Cancel, + }) +} + +fn render_setup_options( + w: &mut impl std::io::Write, + options: &[&str], + selected: usize, +) -> Result<()> { + use crossterm::{ + style::Stylize, + terminal::{Clear, ClearType}, + }; + + for (i, option) in options.iter().enumerate() { + if i == selected { + write!(w, "\r {}", format!("> {option}").bold().cyan())?; + } else { + write!(w, "\r {option}")?; } + crossterm::execute!(w, Clear(ClearType::UntilNewLine))?; + write!(w, "\r\n")?; + } + w.flush()?; + Ok(()) +} + +async fn set_ai_enabled(enabled: bool) -> Result<()> { + let config_file = atuin_client::settings::Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let mut doc = config_str.parse::<toml_edit::DocumentMut>()?; + + if !doc.contains_key("ai") { + doc["ai"] = toml_edit::table(); + } + doc["ai"]["enabled"] = toml_edit::value(enabled); + + tokio::fs::write(&config_file, doc.to_string()).await?; + + if !enabled { + println!( + "Atuin AI keybind disabled. You can re-enable with `atuin config set ai.enabled true`.", + ); + println!("Restart your shell for changes to take effect."); + // Two printlns to ensure the message is visible above the shell prompt after program ends. + println!(); + println!(); } + + Ok(()) } fn wait_for_login_confirmation() -> Result<bool> { @@ -646,3 +345,27 @@ fn wait_for_login_confirmation() -> Result<bool> { } } } + +#[derive(Clone)] +enum Action { + Execute(String), + Insert(String), + Cancel, +} + +fn emit_shell_result(action: Action, output_for_hook: bool) { + if output_for_hook { + match action { + Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"), + Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"), + Action::Cancel => eprintln!("__atuin_ai_cancel__"), + } + } else { + match action { + Action::Execute(output) | Action::Insert(output) => { + println!("{output}"); + } + Action::Cancel => {} + } + } +} diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs new file mode 100644 index 00000000..dabb5c5e --- /dev/null +++ b/crates/atuin-ai/src/context.rs @@ -0,0 +1,73 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use atuin_client::distro::detect_linux_distribution; +use atuin_client::settings::AiCapabilities; + +/// Session-scoped context for the AI chat session. +/// Holds the API configuration and client settings needed by the event loop and stream task. +#[derive(Clone, Debug)] +pub(crate) struct AppContext { + pub endpoint: String, + pub token: String, + pub send_cwd: bool, + pub last_command: Option<String>, + pub history_db: Arc<atuin_client::database::Sqlite>, + /// Git root of the current working directory, if inside a git repo. + /// Resolves through worktrees to the main repo root. + pub git_root: Option<PathBuf>, + pub capabilities: AiCapabilities, +} + +/// Machine identity — computed once per session. +#[derive(Clone, Debug)] +pub(crate) struct ClientContext { + pub os: String, + pub shell: Option<String>, + pub distro: Option<String>, +} + +impl ClientContext { + pub(crate) fn detect() -> Self { + let os = detect_os(); + let shell = crate::commands::detect_shell(); + let distro = if os == "linux" { + Some(detect_linux_distribution()) + } else { + None + }; + Self { os, shell, distro } + } + + /// Serialize to the JSON format the API expects for the "context" field. + /// The `pwd` field is always dynamic (current working directory), so it's + /// computed fresh on each call if `send_cwd` is true. + pub(crate) fn to_json(&self, send_cwd: bool, last_command: Option<&str>) -> serde_json::Value { + let mut ctx = serde_json::json!({ + "os": self.os, + "shell": self.shell, + "pwd": if send_cwd { + std::env::current_dir().ok().map(|p| p.to_string_lossy().into_owned()) + } else { + None + }, + "last_command": last_command, + }); + + if let Some(ref distro) = self.distro { + ctx["distro"] = serde_json::json!(distro); + } + + ctx + } +} + +/// Move the `detect_os` function here since it's about client identity. +fn detect_os() -> String { + match std::env::consts::OS { + "macos" => "macos".to_string(), + "linux" => "linux".to_string(), + "windows" => "windows".to_string(), + other => format!("Other: {other}"), + } +} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index 2d86271d..6f431179 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -1,2 +1,6 @@ pub mod commands; -pub mod tui; +pub(crate) mod context; +pub(crate) mod permissions; +pub(crate) mod stream; +pub(crate) mod tools; +pub(crate) mod tui; diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs new file mode 100644 index 00000000..6b908b93 --- /dev/null +++ b/crates/atuin-ai/src/permissions/check.rs @@ -0,0 +1,74 @@ +use eyre::Result; + +use crate::{permissions::file::RuleFile, tools::PermissableToolCall}; + +pub(crate) struct PermissionRequest<'t> { + call: &'t (dyn PermissableToolCall + Send + Sync), +} + +impl<'t> PermissionRequest<'t> { + pub fn new(call: &'t (dyn PermissableToolCall + Send + Sync)) -> Self { + Self { call } + } +} + +pub(crate) enum PermissionResponse { + Allowed, + Denied, + Ask, +} + +pub(crate) struct PermissionChecker { + files: Vec<RuleFile>, +} + +impl PermissionChecker { + pub fn new(files: Vec<RuleFile>) -> Self { + Self { files } + } + + pub async fn check<'t>( + &self, + request: &'t PermissionRequest<'t>, + ) -> Result<PermissionResponse> { + // Files are in order from deepest to shallowest, so we can stop at the first match. + // Within a file, the priority is ask -> deny -> allow + // The first rule type that matches is the one that applies, even if a later rule would contradict it. + for file in &self.files { + for rule in &file.content.permissions.ask { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'ASK' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Ask); + } + } + + for rule in &file.content.permissions.deny { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'DENY' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Denied); + } + } + + for rule in &file.content.permissions.allow { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'ALLOW' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Allowed); + } + } + } + + Ok(PermissionResponse::Ask) + } +} diff --git a/crates/atuin-ai/src/permissions/file.rs b/crates/atuin-ai/src/permissions/file.rs new file mode 100644 index 00000000..c973f55b --- /dev/null +++ b/crates/atuin-ai/src/permissions/file.rs @@ -0,0 +1,26 @@ +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; + +use crate::permissions::rule::Rule; + +#[derive(Debug, Clone)] +pub(crate) struct RuleFile { + pub path: PathBuf, + pub content: RuleFileContent, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct RuleFileContent { + pub permissions: RuleFilePermissions, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct RuleFilePermissions { + #[serde(default)] + pub allow: Vec<Rule>, + #[serde(default)] + pub deny: Vec<Rule>, + #[serde(default)] + pub ask: Vec<Rule>, +} diff --git a/crates/atuin-ai/src/permissions/mod.rs b/crates/atuin-ai/src/permissions/mod.rs new file mode 100644 index 00000000..fce64a51 --- /dev/null +++ b/crates/atuin-ai/src/permissions/mod.rs @@ -0,0 +1,7 @@ +pub(crate) mod check; +pub(crate) mod file; +pub(crate) mod resolver; +pub(crate) mod rule; +pub(crate) mod shell; +pub(crate) mod walker; +pub(crate) mod writer; diff --git a/crates/atuin-ai/src/permissions/resolver.rs b/crates/atuin-ai/src/permissions/resolver.rs new file mode 100644 index 00000000..dc4f83bf --- /dev/null +++ b/crates/atuin-ai/src/permissions/resolver.rs @@ -0,0 +1,31 @@ +use std::path::PathBuf; + +use eyre::Result; + +use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; +use crate::permissions::walker::PermissionWalker; +use crate::permissions::writer; +use crate::tools::ClientToolCall; + +/// Resolves permissions for client tool calls by walking the filesystem to find permission files, +pub(crate) struct PermissionResolver { + checker: PermissionChecker, +} + +impl PermissionResolver { + /// Create a new resolver that walks from `working_dir` to root for project + /// permissions, and also checks the global permissions file. + pub async fn new(working_dir: PathBuf) -> Result<Self> { + let global_file = writer::global_permissions_path(); + let mut walker = PermissionWalker::new(working_dir, Some(global_file)); + walker.walk().await?; + let checker = PermissionChecker::new(walker.rules().to_owned()); + Ok(Self { checker }) + } + + /// Check whether `tool` is allowed, denied, or needs user confirmation. + pub async fn check(&self, tool: &ClientToolCall) -> Result<PermissionResponse> { + let request = PermissionRequest::new(tool); + self.checker.check(&request).await + } +} diff --git a/crates/atuin-ai/src/permissions/rule.rs b/crates/atuin-ai/src/permissions/rule.rs new file mode 100644 index 00000000..8fa3fa4a --- /dev/null +++ b/crates/atuin-ai/src/permissions/rule.rs @@ -0,0 +1,106 @@ +use std::sync::OnceLock; + +use regex::Regex; +use serde::{Deserialize, Serialize}; + +static RULE_RE: OnceLock<Regex> = OnceLock::new(); + +#[derive(Debug, thiserror::Error)] +pub(crate) enum RuleError { + #[error("invalid rule format: {0}")] + InvalidRule(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct Rule { + pub tool: String, + pub scope: Option<String>, +} + +impl std::fmt::Display for Rule { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.scope.as_ref() { + Some(scope) => write!(f, "{}({})", self.tool, scope), + None => write!(f, "{}", self.tool), + } + } +} + +impl Serialize for Rule { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for Rule { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Self::try_from(s.as_str()).map_err(serde::de::Error::custom) + } +} +impl TryFrom<&str> for Rule { + type Error = RuleError; + + fn try_from(value: &str) -> Result<Self, Self::Error> { + let value = value.trim(); + let re = RULE_RE.get_or_init(|| Regex::new(r"^(\w+)(?:\((.*)\))?$").unwrap()); + let caps = re + .captures(value) + .ok_or(RuleError::InvalidRule(value.to_string()))?; + let tool = caps.get(1).unwrap().as_str().to_string(); + let scope = caps.get(2).map(|m| m.as_str().to_string()); + Ok(Rule { tool, scope }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rule_try_from() { + assert_eq!( + Rule::try_from("Read").unwrap(), + Rule { + tool: "Read".to_string(), + scope: None + } + ); + assert_eq!( + Rule::try_from("Read(*)").unwrap(), + Rule { + tool: "Read".to_string(), + scope: Some("*".to_string()) + } + ); + assert_eq!( + Rule::try_from("Write(*.md)").unwrap(), + Rule { + tool: "Write".to_string(), + scope: Some("*.md".to_string()) + } + ); + assert_eq!( + Rule::try_from("Shell(git commit *)").unwrap(), + Rule { + tool: "Shell".to_string(), + scope: Some("git commit *".to_string()) + } + ); + assert_eq!( + Rule::try_from("Shell(echo ())").unwrap(), + Rule { + tool: "Shell".to_string(), + scope: Some("echo ()".to_string()) + } + ); + assert!(Rule::try_from("Shell(git commit *").is_err()); + assert!(Rule::try_from("Shell(git commit *)!").is_err()); + } +} diff --git a/crates/atuin-ai/src/permissions/shell.rs b/crates/atuin-ai/src/permissions/shell.rs new file mode 100644 index 00000000..7a2eee2e --- /dev/null +++ b/crates/atuin-ai/src/permissions/shell.rs @@ -0,0 +1,1297 @@ +/// Extracted command info from a shell command string. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ShellCommand { + /// The command name (first word), e.g. "git" + pub name: String, + /// The full invocation including arguments, e.g. "git commit -m msg" + pub full: String, +} + +/// A parsed shell command with all subcommands extracted. +#[derive(Debug)] +pub(crate) struct ParsedShellCommand { + pub subcommands: Vec<ShellCommand>, +} + +/// Supported shell families for parsing. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ShellKind { + /// POSIX sh, bash, zsh — all share similar syntax + Posix, + /// fish shell + Fish, + /// nushell or unknown — fallback to word-level extraction + Other, +} + +impl ShellKind { + pub(crate) fn from_shell_name(name: &str) -> Self { + match name { + "bash" | "sh" | "zsh" | "dash" | "ksh" => Self::Posix, + "fish" => Self::Fish, + _ => Self::Other, + } + } +} + +/// Parse a shell command string and extract all subcommands. +pub(crate) fn parse_shell_command(code: &str, shell: ShellKind) -> ParsedShellCommand { + #[cfg(feature = "tree-sitter")] + match shell { + ShellKind::Posix => ts::parse_posix(code), + ShellKind::Fish => ts::parse_fish(code), + ShellKind::Other => parse_fallback(code), + } + + #[cfg(not(feature = "tree-sitter"))] + { + let _ = shell; + parse_fallback(code) + } +} + +// ──────────────────────────────────────────────────────────────── +// Tree-sitter parsers (POSIX + Fish) +// Disabled on platforms where tree-sitter doesn't cross-compile +// (e.g. Windows); falls back to word-level extraction. +// ──────────────────────────────────────────────────────────────── + +#[cfg(feature = "tree-sitter")] +mod ts { + use super::{ParsedShellCommand, ShellCommand, parse_fallback}; + use tree_sitter_lib::{Parser, Tree}; + + fn bash_parser() -> Parser { + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_bash::LANGUAGE.into()) + .expect("failed to set bash language"); + parser + } + + pub(super) fn parse_posix(code: &str) -> ParsedShellCommand { + let mut parser = bash_parser(); + let Some(tree) = parser.parse(code, None) else { + return parse_fallback(code); + }; + + let mut commands = Vec::new(); + walk_bash_tree(&tree, code.as_bytes(), &mut commands); + ParsedShellCommand { + subcommands: commands, + } + } + + /// Leaf node kinds that never contain nested commands. + const BASH_LEAVES: &[&str] = &[ + "command_name", + "word", + "number", + "simple_expansion", + "expansion", + "arithmetic_expansion", + "ansi_c_string", + "special_variable_name", + "variable_name", + "file_descriptor", + "heredoc_body", + "heredoc_start", + "regex", + "heredoc_redirect", + ]; + + fn walk_bash_tree(tree: &Tree, source: &[u8], commands: &mut Vec<ShellCommand>) { + walk_bash_node(tree.root_node(), source, commands); + } + + fn walk_bash_node( + node: tree_sitter_lib::Node, + source: &[u8], + commands: &mut Vec<ShellCommand>, + ) { + match node.kind() { + "command" => { + if let Some(cmd) = extract_bash_command(node, source) { + commands.push(cmd); + } + // Descend into all non-leaf children to find nested commands + // (e.g. command_substitution inside a string inside a command) + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if !BASH_LEAVES.contains(&child.kind()) { + walk_bash_node(child, source, commands); + } + } + } + // Other nodes: descend into all children + _ => { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + walk_bash_node(child, source, commands); + } + } + } + } + + /// Extract the full command string and name from a bash `command` node. + fn extract_bash_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option<ShellCommand> { + // A `command` node has children like: + // variable_assignment* command_name argument* redirect* + // We want the command_name and all arguments (skipping assignments and redirects). + let mut name = None; + let mut name_start = None; + let mut arg_end = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "command_name" => { + name = child.utf8_text(source).ok().map(|s| s.to_string()); + name_start = Some(child.start_byte()); + } + "word" + | "string" + | "raw_string" + | "concatenation" + | "number" + | "simple_expansion" + | "expansion" + | "arithmetic_expansion" + | "ansi_c_string" + | "process_substitution" => { + arg_end = Some(child.end_byte()); + } + _ => {} + } + } + + let name = name?; + let full = if let (Some(start), Some(end)) = (name_start, arg_end) { + std::str::from_utf8(&source[start..end]).ok()?.to_string() + } else { + name.clone() + }; + + Some(ShellCommand { name, full }) + } + + // ──────────────────────────────────────────────────────────────── + // Fish parser + // ──────────────────────────────────────────────────────────────── + + fn fish_parser() -> Parser { + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_fish::language()) + .expect("failed to set fish language"); + parser + } + + pub(super) fn parse_fish(code: &str) -> ParsedShellCommand { + let mut parser = fish_parser(); + let Some(tree) = parser.parse(code, None) else { + return parse_fallback(code); + }; + + let mut commands = Vec::new(); + walk_fish_tree(&tree, code.as_bytes(), &mut commands); + ParsedShellCommand { + subcommands: commands, + } + } + + const FISH_COMPOUND: &[&str] = &[ + "conditional_execution", + "pipe", + "job", + "command_substitution", + "block", + "for_statement", + "while_statement", + "if_statement", + "switch_statement", + "function_definition", + "begin_statement", + "redirected_statement", + ]; + + fn walk_fish_tree(tree: &Tree, source: &[u8], commands: &mut Vec<ShellCommand>) { + walk_fish_node(tree.root_node(), source, commands); + } + + fn walk_fish_node( + node: tree_sitter_lib::Node, + source: &[u8], + commands: &mut Vec<ShellCommand>, + ) { + match node.kind() { + "command" => { + if let Some(cmd) = extract_fish_command(node, source) { + commands.push(cmd); + } + // Still descend into compound children (e.g. command_substitution inside a command) + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if FISH_COMPOUND.contains(&child.kind()) { + walk_fish_node(child, source, commands); + } + } + } + // Other nodes: descend into all children + _ => { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + walk_fish_node(child, source, commands); + } + } + } + } + + fn extract_fish_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option<ShellCommand> { + // In fish, a `command` node has: + // name (command_name or word) followed by arguments (word, string, etc.) + let mut name = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "command_name" | "word" => { + let text = child.utf8_text(source).ok()?.to_string(); + if name.is_none() { + name = Some(text); + } + } + "string" + | "concatenation" + | "command_substitution" + | "escape_sequence" + | "double_quote_string" + | "single_quote_string" => {} + _ => {} + } + } + + let name = name?; + // Get the full text of the command node + let full = node.utf8_text(source).ok()?.trim().to_string(); + + Some(ShellCommand { name, full }) + } +} // mod ts + +// ──────────────────────────────────────────────────────────────── +// Fallback (word-level extraction for nushell / unknown shells) +// ──────────────────────────────────────────────────────────────── + +fn parse_fallback(code: &str) -> ParsedShellCommand { + // Simple heuristic: split by &&, ||, ;, | and take the first word of each segment. + // This is intentionally simple — for unknown shells we can't do better. + let mut commands = Vec::new(); + let mut segment = String::new(); + let mut chars = code.chars().peekable(); + + while let Some(c) = chars.next() { + match c { + ';' => { + push_segment(&mut segment, &mut commands); + } + '|' => { + if chars.peek() == Some(&'|') { + chars.next(); + } + push_segment(&mut segment, &mut commands); + } + '&' if chars.peek() == Some(&'&') => { + chars.next(); + push_segment(&mut segment, &mut commands); + } + _ => segment.push(c), + } + } + push_segment(&mut segment, &mut commands); + + ParsedShellCommand { + subcommands: commands, + } +} + +fn push_segment(segment: &mut String, commands: &mut Vec<ShellCommand>) { + let trimmed = segment.trim(); + if !trimmed.is_empty() + && let Some(name) = trimmed.split_whitespace().next() + { + commands.push(ShellCommand { + name: name.to_string(), + full: trimmed.to_string(), + }); + } + segment.clear(); +} + +// ──────────────────────────────────────────────────────────────── +// Scope matching +// ──────────────────────────────────────────────────────────────── + +/// Check if any of the extracted subcommands match the given scope pattern. +/// +/// Matching semantics depend on where the `*` wildcard appears: +/// - `*` alone — matches everything +/// - `ls *` (space before `*`) — matches `ls` and `ls -a` but not `lsof` +/// - `git commit *` — matches `git commit -m "msg"` (word boundary) +/// - `ls*` (no space before `*`) — matches `lsof`, `ls`, `ls -a` (prefix/glob) +/// - `rm` (no wildcard) — matches exactly `rm` +/// - `git * amend` — matches `git commit amend` (middle wildcard matches zero+ words) +pub(crate) fn any_subcommand_matches(subcommands: &[ShellCommand], scope: &str) -> bool { + let scope = scope.trim(); + + if scope == "*" { + return true; + } + + if let Some(prefix) = scope.strip_suffix(" *") { + // Word-boundary matching: `ls *` matches `ls` and `ls -a` but not `lsof` + return subcommands.iter().any(|cmd| { + if prefix.is_empty() { + return true; + } + let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); + let prefix_words: Vec<&str> = prefix.split_whitespace().collect(); + cmd_words.len() >= prefix_words.len() + && cmd_words[..prefix_words.len()] == prefix_words[..] + }); + } + + if let Some(prefix) = scope.strip_suffix('*') { + // Prefix/glob matching: `ls*` matches `lsof`, `ls`, etc. + return subcommands.iter().any(|cmd| cmd.full.starts_with(prefix)); + } + + if scope.contains('*') { + // Middle wildcard: `git * amend` — each `*` matches zero or more words + return subcommands + .iter() + .any(|cmd| scope_matches_words(scope, cmd.full.split_whitespace().collect())); + } + + // No wildcard: word-boundary prefix match + let scope_words: Vec<&str> = scope.split_whitespace().collect(); + subcommands.iter().any(|cmd| { + let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); + cmd_words.len() >= scope_words.len() && cmd_words[..scope_words.len()] == scope_words[..] + }) +} + +/// Match a scope pattern containing `*` wildcards against a sequence of words. +/// Each `*` matches zero or more words. Consecutive `*` collapse into one. +fn scope_matches_words(scope: &str, words: Vec<&str>) -> bool { + let parts: Vec<&str> = scope.split('*').collect(); + if parts.len() == 1 { + // No wildcard (shouldn't reach here, but handle it) + let scope_words: Vec<&str> = scope.split_whitespace().collect(); + return words.len() >= scope_words.len() && words[..scope_words.len()] == scope_words[..]; + } + + // Each segment between * is a sequence of literal words that must appear in order. + // Walk through `words` consuming segments left to right. + let mut word_idx = 0; + + for (i, part) in parts.iter().enumerate() { + let segment_words: Vec<&str> = part.split_whitespace().collect(); + if segment_words.is_empty() { + continue; + } + + // Find the segment words starting from word_idx + if i == 0 { + // First segment must match at the start + if words.len() < segment_words.len() + || words[..segment_words.len()] != segment_words[..] + { + return false; + } + word_idx = segment_words.len(); + } else if i == parts.len() - 1 { + // Last segment must match at the end + if words.len() - word_idx < segment_words.len() { + return false; + } + let start = words.len() - segment_words.len(); + return words[start..] == segment_words[..]; + } else { + // Middle segment: find it anywhere after word_idx + let found = find_subslice(&words[word_idx..], &segment_words); + match found { + Some(idx) => word_idx += idx + segment_words.len(), + None => return false, + } + } + } + + true +} + +/// Find the first occurrence of `needle` as a contiguous subsequence in `haystack`. +fn find_subslice(haystack: &[&str], needle: &[&str]) -> Option<usize> { + if needle.is_empty() { + return Some(0); + } + if haystack.len() < needle.len() { + return None; + } + (0..=haystack.len() - needle.len()).find(|&i| haystack[i..i + needle.len()] == needle[..]) +} + +// ──────────────────────────────────────────────────────────────── +// Tests +// ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn names(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.name.as_str()).collect() + } + + fn fulls(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.full.as_str()).collect() + } + + #[test] + fn simple_command() { + let result = parse_shell_command("ls -la /tmp", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["ls"]); + assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); + } + + #[test] + fn pipeline() { + let result = parse_shell_command("cat file.txt | grep foo | wc -l", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["cat", "grep", "wc"]); + } + + #[test] + fn command_chaining() { + let result = parse_shell_command("git add . && git commit -m 'hi'", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["git", "git"]); + assert_eq!( + fulls(&result.subcommands), + vec!["git add .", "git commit -m 'hi'"] + ); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn command_substitution() { + let result = parse_shell_command("echo $(git rev-parse HEAD)", ShellKind::Posix); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn backtick_substitution() { + let result = parse_shell_command("echo `date`", ShellKind::Posix); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"date"), "should contain date: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn subshell() { + let result = parse_shell_command("(cd /tmp && ls)", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["cd", "ls"]); + } + + #[test] + fn semicolon_separated() { + let result = parse_shell_command("echo hello; echo world", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["echo", "echo"]); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn for_loop() { + let result = parse_shell_command("for f in *.txt; do cat $f; done", ShellKind::Posix); + assert!(names(&result.subcommands).contains(&"cat")); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn if_statement() { + let result = parse_shell_command( + "if [ -f foo ]; then cat foo; else echo nope; fi", + ShellKind::Posix, + ); + let n = names(&result.subcommands); + assert!(n.contains(&"cat"), "should contain cat: {n:?}"); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + } + + #[test] + fn scope_matching_wildcard() { + let commands = vec![ + ShellCommand { + name: "git".into(), + full: "git commit -m msg".into(), + }, + ShellCommand { + name: "npm".into(), + full: "npm test".into(), + }, + ]; + assert!(any_subcommand_matches(&commands, "*")); + } + + #[test] + fn scope_matching_prefix() { + let commands = vec![ + ShellCommand { + name: "git".into(), + full: "git commit -m msg".into(), + }, + ShellCommand { + name: "npm".into(), + full: "npm test".into(), + }, + ]; + assert!(any_subcommand_matches(&commands, "git commit *")); + assert!(any_subcommand_matches(&commands, "git commit")); + assert!(!any_subcommand_matches(&commands, "git push *")); + assert!(!any_subcommand_matches(&commands, "git push")); + assert!(any_subcommand_matches(&commands, "npm *")); + } + + #[test] + fn scope_word_boundary_vs_glob() { + let commands = vec![ + ShellCommand { + name: "ls".into(), + full: "ls -a".into(), + }, + ShellCommand { + name: "lsof".into(), + full: "lsof -i :3000".into(), + }, + ]; + // `ls *` — word boundary: matches `ls -a` but not `lsof` + assert!(any_subcommand_matches(&commands, "ls *")); + assert!(!any_subcommand_matches(&commands, "cat *")); + assert!(any_subcommand_matches(&commands, "lsof *")); + + // `ls*` — glob/prefix: matches both `ls -a` and `lsof` + assert!(any_subcommand_matches(&commands, "ls*")); + } + + #[test] + fn scope_exact_match() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + assert!(any_subcommand_matches(&commands, "ls")); + assert!(!any_subcommand_matches(&commands, "cat")); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn nested_substitution() { + let result = parse_shell_command( + "echo \"Result: $(git log --oneline | head -1)\"", + ShellKind::Posix, + ); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + assert!(n.contains(&"head"), "should contain head: {n:?}"); + } + + #[test] + fn fallback_splits_correctly() { + let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); + let n = names(&result.subcommands); + assert!(n.contains(&"ls"), "should contain ls: {n:?}"); + assert!(n.contains(&"cat"), "should contain cat: {n:?}"); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + } + + #[test] + fn fish_simple_command() { + let result = parse_shell_command("ls -la /tmp", ShellKind::Fish); + assert_eq!(names(&result.subcommands), vec!["ls"]); + } + + #[test] + fn fish_conditional() { + let result = parse_shell_command("git add .; and git commit -m hi", ShellKind::Fish); + let n = names(&result.subcommands); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn fish_command_substitution() { + let result = parse_shell_command("echo (date)", ShellKind::Fish); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"date"), "should contain date: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn variable_assignment_excluded() { + let result = parse_shell_command("FOO=bar ls -la /tmp", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["ls"]); + assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn variable_assignment_multiple() { + let result = parse_shell_command("A=1 B=2 git status", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["git"]); + assert_eq!(fulls(&result.subcommands), vec!["git status"]); + } + + #[test] + fn fallback_double_ampersand_and_pipe_pipe() { + let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); + assert_eq!(names(&result.subcommands), vec!["ls", "cat", "echo"]); + assert_eq!( + fulls(&result.subcommands), + vec!["ls", "cat foo", "echo fail"] + ); + } + + #[test] + fn fallback_pipe_without_double() { + let result = parse_shell_command("ls | grep foo", ShellKind::Other); + assert_eq!(names(&result.subcommands), vec!["ls", "grep"]); + assert_eq!(fulls(&result.subcommands), vec!["ls", "grep foo"]); + } + + #[test] + fn scope_middle_wildcard() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit -m amend".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * amend")); + assert!(any_subcommand_matches(&commands, "git commit * amend")); + assert!(!any_subcommand_matches(&commands, "git push * amend")); + } + + #[test] + fn scope_middle_wildcard_zero_words() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + // `*` matches zero words, so `git * commit` should match `git commit` + assert!(any_subcommand_matches(&commands, "git * commit")); + } + + #[test] + fn scope_leading_wildcard() { + let commands = vec![ShellCommand { + name: "docker".into(), + full: "docker run --rm alpine".into(), + }]; + assert!(any_subcommand_matches(&commands, "* alpine")); + assert!(!any_subcommand_matches(&commands, "* ubuntu")); + } + + #[test] + fn scope_multiple_wildcards() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git rebase -i HEAD~5".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * -i * HEAD~5")); + assert!(!any_subcommand_matches(&commands, "git * -i * HEAD~10")); + } +} + +#[cfg(all(test, feature = "tree-sitter"))] +mod adversarial { + use super::*; + + fn cmd_names(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.name.as_str()).collect() + } + + /// Helper: assert that parsing POSIX extracts all expected command names + fn assert_posix(code: &str, expected: &[&str]) { + let result = parse_shell_command(code, ShellKind::Posix); + let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); + got.sort(); + let mut want: Vec<&str> = expected.to_vec(); + want.sort(); + assert_eq!( + got, want, + "POSIX parse of {:?}:\n got: {:?}\n want: {:?}", + code, got, want + ); + } + + fn assert_fish(code: &str, expected: &[&str]) { + let result = parse_shell_command(code, ShellKind::Fish); + let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); + got.sort(); + let mut want: Vec<&str> = expected.to_vec(); + want.sort(); + assert_eq!( + got, want, + "Fish parse of {:?}:\n got: {:?}\n want: {:?}", + code, got, want + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 1: Basic compounds + // ──────────────────────────────────────────────────────────── + + #[test] + fn a01_triple_chain() { + assert_posix("a && b && c", &["a", "b", "c"]); + } + + #[test] + fn a02_or_chain() { + assert_posix("a || b || c", &["a", "b", "c"]); + } + + #[test] + fn a03_mixed_chain() { + assert_posix("a && b || c && d", &["a", "b", "c", "d"]); + } + + #[test] + fn a04_long_pipeline() { + assert_posix( + "cat foo | grep bar | awk '{print $1}' | sort | uniq -c", + &["cat", "grep", "awk", "sort", "uniq"], + ); + } + + #[test] + fn a05_semicolons() { + assert_posix("a; b; c; d", &["a", "b", "c", "d"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 2: Nested substitution + // ──────────────────────────────────────────────────────────── + + #[test] + fn a06_nested_dollar() { + assert_posix( + "echo $(basename $(dirname /foo/bar))", + &["echo", "basename", "dirname"], + ); + } + + #[test] + fn a07_deeply_nested() { + // 4 nested echos, all should be extracted + assert_posix( + "echo $(echo $(echo $(echo deep)))", + &["echo", "echo", "echo", "echo"], + ); + } + + #[test] + fn a08_backtick_in_echo() { + assert_posix("echo `hostname`", &["echo", "hostname"]); + } + + #[test] + fn a09_mixed_substitutions() { + assert_posix("echo $(date) `uname`", &["echo", "date", "uname"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 3: Subshells and grouping + // ──────────────────────────────────────────────────────────── + + #[test] + fn a10_subshell_chain() { + assert_posix("(cd /tmp && ls -la)", &["cd", "ls"]); + } + + #[test] + fn a11_nested_subshells() { + assert_posix("( (inner_cmd) )", &["inner_cmd"]); + } + + #[test] + fn a12_brace_group() { + assert_posix("{ cd /tmp; ls; }", &["cd", "ls"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 4: Variable assignments + // ──────────────────────────────────────────────────────────── + + #[test] + fn a13_single_var_assignment() { + let result = parse_shell_command("FOO=bar ls", ShellKind::Posix); + assert_eq!(cmd_names(&result.subcommands), &["ls"]); + assert_eq!(result.subcommands[0].full, "ls"); + } + + #[test] + fn a14_multiple_var_assignments() { + let result = parse_shell_command("A=1 B=2 C=3 git status", ShellKind::Posix); + assert_eq!(cmd_names(&result.subcommands), &["git"]); + assert_eq!(result.subcommands[0].full, "git status"); + } + + #[test] + fn a15_var_assignment_no_command() { + // Variable assignment only — no command to extract + assert_posix("FOO=bar", &[]); + } + + #[test] + fn a16_var_assignment_in_pipeline() { + assert_posix("FOO=bar ls | BAZ=qux grep foo", &["ls", "grep"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 5: Control flow + // ──────────────────────────────────────────────────────────── + + #[test] + fn a17_if_then_else() { + assert_posix( + "if [ -f foo ]; then cat foo; else echo missing; fi", + &["cat", "echo"], + ); + } + + #[test] + fn a18_elif_chain() { + // Two cat commands (then + elif branch), one echo (else branch). + // [ is part of the test_condition, not extracted as a command. + assert_posix( + "if [ -f a ]; then cat a; elif [ -f b ]; then cat b; else echo none; fi", + &["cat", "cat", "echo"], + ); + } + + #[test] + fn a19_for_loop() { + assert_posix("for f in *.txt; do cat \"$f\"; done", &["cat"]); + } + + #[test] + fn a20_while_loop() { + // read in the condition is a real command + assert_posix( + "while read line; do echo \"$line\"; done < input.txt", + &["echo", "read"], + ); + } + + #[test] + fn f07_if_statement() { + // test in if-condition is a real command + assert_fish( + "if test -f foo; cat foo; else; echo missing; end", + &["cat", "echo", "test"], + ); + } + + #[test] + fn f09_while_loop() { + // `true` in the condition is a real command + assert_fish( + "while true; echo tick; sleep 1; end", + &["echo", "sleep", "true"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 6: Redirections + // ──────────────────────────────────────────────────────────── + + #[test] + fn a23_redirect_out() { + assert_posix("ls > output.txt", &["ls"]); + } + + #[test] + fn a24_redirect_append() { + assert_posix("ls >> output.txt 2>&1", &["ls"]); + } + + #[test] + fn a25_here_string() { + assert_posix("grep foo <<< \"hello world\"", &["grep"]); + } + + #[test] + fn a26_redirect_in_pipeline() { + assert_posix("cat < input.txt | sort | uniq", &["cat", "sort", "uniq"]); + } + + #[test] + fn a27_process_substitution() { + assert_posix( + "diff <(sort a.txt) <(sort b.txt)", + &["diff", "sort", "sort"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 7: Function definitions + // ──────────────────────────────────────────────────────────── + + #[test] + fn a28_function_def() { + assert_posix("foo() { echo hello; }", &["echo"]); + } + + #[test] + fn a29_function_with_subshell() { + assert_posix( + "build() { cargo build && cargo test; }", + &["cargo", "cargo"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 8: Edge cases — empties, weird quoting + // ──────────────────────────────────────────────────────────── + + #[test] + fn a30_empty_string() { + let result = parse_shell_command("", ShellKind::Posix); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn a31_whitespace_only() { + let result = parse_shell_command(" \t \n ", ShellKind::Posix); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn a32_single_command_no_args() { + assert_posix("ls", &["ls"]); + } + + #[test] + fn a33_command_with_single_quotes() { + assert_posix("echo 'hello world'", &["echo"]); + } + + #[test] + fn a34_command_with_double_quotes() { + assert_posix("echo \"hello world\"", &["echo"]); + } + + #[test] + fn a35_escaped_spaces() { + // ls\ -la is a single word in bash, not "ls" with flag "-la" + assert_posix("ls\\ -la", &["ls\\ -la"]); + } + + #[test] + fn a36_command_with_dollar_var() { + assert_posix("echo $HOME/.bashrc", &["echo"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 9: Background jobs and coproc + // ──────────────────────────────────────────────────────────── + + #[test] + fn a37_background_job() { + assert_posix("sleep 10 &", &["sleep"]); + } + + #[test] + fn a38_background_chain() { + assert_posix("sleep 10 && echo done &", &["sleep", "echo"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 10: Real-world complex commands + // ──────────────────────────────────────────────────────────── + + #[test] + fn a39_docker_build_and_run() { + assert_posix( + "docker build -t app . && docker run --rm app npm test", + &["docker", "docker"], + ); + } + + #[test] + fn a40_git_rebase_interactive() { + assert_posix( + "GIT_SEQUENCE_EDITOR=\"sed -i 's/pick/reword/'\" git rebase -i HEAD~5", + &["git"], + ); + } + + #[test] + fn a41_find_with_exec() { + // tree-sitter-bash does not parse -exec body as commands — only `find` is extracted. + // This is a known limitation: args to -exec/-execdir are opaque to the parser. + assert_posix("find . -name '*.rs' -exec grep -l 'unsafe' {} +", &["find"]); + } + + #[test] + fn a42_curl_pipe_sh() { + assert_posix( + "curl -sSL https://example.com/install.sh | bash", + &["curl", "bash"], + ); + } + + #[test] + fn a43_xargs() { + assert_posix("find . -name '*.tmp' | xargs rm -f", &["find", "xargs"]); + } + + #[test] + fn a44_npm_script_chain() { + assert_posix( + "npm run build && npm run test && npm run lint", + &["npm", "npm", "npm"], + ); + } + + #[test] + fn a45_make_with_redirect() { + assert_posix( + "make -j$(nproc) 2>&1 | tee build.log", + &["make", "nproc", "tee"], + ); + } + + #[test] + fn a46_sudo_chain() { + assert_posix("sudo apt update && sudo apt upgrade -y", &["sudo", "sudo"]); + } + + #[test] + fn a47_here_doc_with_subcommand() { + assert_posix("cat <<EOF\nhello $(whoami)\nEOF", &["cat", "whoami"]); + } + + #[test] + fn a48_eval_with_command() { + assert_posix("eval \"echo hello\"", &["eval"]); + } + + #[test] + fn a49_exec_replace() { + assert_posix("exec ls", &["exec"]); + } + + #[test] + fn a50_source_script() { + assert_posix("source ~/.bashrc", &["source"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 11: Fish-specific tests + // ──────────────────────────────────────────────────────────── + + #[test] + fn f01_simple() { + assert_fish("ls -la /tmp", &["ls"]); + } + + #[test] + fn f02_pipe() { + assert_fish("cat foo | grep bar | sort", &["cat", "grep", "sort"]); + } + + #[test] + fn f03_and() { + assert_fish("git add .; and git commit -m hi", &["git", "git"]); + } + + #[test] + fn f04_or() { + assert_fish("test -f foo; or echo missing", &["test", "echo"]); + } + + #[test] + fn f04_not() { + // fish parses `not test -f foo` — `not` is a modifier, `test` is the command + assert_fish("not test -f foo", &["test"]); + } + + #[test] + fn f05_command_substitution() { + assert_fish("echo (date)", &["echo", "date"]); + } + + #[test] + fn f06_nested_substitution() { + assert_fish( + "echo (basename (dirname /foo/bar))", + &["echo", "basename", "dirname"], + ); + } + + #[test] + fn f06_begin_end() { + assert_fish("begin; ls; echo done; end", &["ls", "echo"]); + } + + #[test] + fn f10_switch() { + // Two echo commands, one per case branch + assert_fish( + "switch $x; case foo; echo foo; case bar; echo bar; end", + &["echo", "echo"], + ); + } + + #[test] + fn f08_for_loop() { + assert_fish("for f in *.txt; cat $f; end", &["cat"]); + } + + #[test] + fn a21_case_statement() { + // Two echo branches + assert_posix( + "case $x in foo) echo foo;; bar) echo bar;; esac", + &["echo", "echo"], + ); + } + + #[test] + fn f11_function_def() { + assert_fish("function greet; echo hello $argv; end", &["echo"]); + } + + #[test] + fn f12_redirect() { + assert_fish("ls > output.txt", &["ls"]); + } + + #[test] + fn f13_redirect_append() { + assert_fish("ls >> output.txt", &["ls"]); + } + + #[test] + fn f14_here_string() { + assert_fish("grep foo <<< \"hello\"", &["grep"]); + } + + #[test] + fn f15_curl_pipe() { + assert_fish( + "curl -sSL https://example.com/install.sh | bash", + &["curl", "bash"], + ); + } + + #[test] + fn f16_double_ampersand() { + assert_fish("git add . && git commit -m hi", &["git", "git"]); + } + + #[test] + fn f17_double_pipe() { + assert_fish("test -f foo || echo missing", &["test", "echo"]); + } + + #[test] + fn f18_empty() { + let result = parse_shell_command("", ShellKind::Fish); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn f19_whitespace() { + let result = parse_shell_command(" ", ShellKind::Fish); + assert!(result.subcommands.is_empty()); + } + + // ──────────────────────────────────────────────────────────── + // Level 12: Scope matching adversarial + // ──────────────────────────────────────────────────────────── + + #[test] + fn s01_empty_scope() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // Empty scope matches everything (nothing to constrain) + assert!(any_subcommand_matches(&commands, "")); + } + + #[test] + fn s03_only_wildcard_space_star() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // " *" with empty prefix = match anything + assert!(any_subcommand_matches(&commands, " *")); + } + + #[test] + fn s04_glob_matches_empty() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // `ls*` matches `ls` (prefix match with nothing after) + assert!(any_subcommand_matches(&commands, "ls*")); + } + + #[test] + fn s05_middle_wildcard_empty_match() { + // `git * commit` matches `git commit` (* = zero words) + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * commit")); + } + + #[test] + fn s06_consecutive_wildcards() { + // `git ** commit` should behave like `git * commit` + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + assert!(any_subcommand_matches(&commands, "git ** commit")); + } + + #[test] + fn s07_case_sensitivity() { + let commands = vec![ShellCommand { + name: "LS".into(), + full: "LS -la".into(), + }]; + assert!(!any_subcommand_matches(&commands, "ls")); + assert!(any_subcommand_matches(&commands, "LS")); + } + + #[test] + fn s08_multi_word_exact_no_subcommand() { + // `git commit` should not match `git commit-amend` + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit-amend".into(), + }]; + assert!(!any_subcommand_matches(&commands, "git commit")); + } +} diff --git a/crates/atuin-ai/src/permissions/walker.rs b/crates/atuin-ai/src/permissions/walker.rs new file mode 100644 index 00000000..3bda01c3 --- /dev/null +++ b/crates/atuin-ai/src/permissions/walker.rs @@ -0,0 +1,121 @@ +use std::path::{Path, PathBuf}; + +use eyre::Result; +use tokio::task::JoinSet; + +use crate::permissions::file::{RuleFile, RuleFileContent}; + +#[derive(Debug)] +struct FoundRuleFile { + depth: usize, + file: RuleFile, +} + +pub(crate) struct PermissionWalker { + start: PathBuf, + /// Direct path to the global permissions file (e.g. `~/.config/atuin/permissions.ai.toml`). + global_permissions_file: Option<PathBuf>, + rules: Vec<RuleFile>, +} + +impl PermissionWalker { + pub fn new(start: PathBuf, global_permissions_file: Option<PathBuf>) -> Self { + Self { + start, + global_permissions_file, + rules: Vec::new(), + } + } + + pub fn rules(&self) -> &[RuleFile] { + &self.rules + } + + /// Walks the filesystem starting from the start path and collecting permission files along the way. + /// Walks to the root, then checks the global permissions file, if any. + pub async fn walk(&mut self) -> Result<()> { + let dirs_to_check: Vec<PathBuf> = self.start.ancestors().map(PathBuf::from).collect(); + let dir_count = dirs_to_check.len(); + + let mut set: JoinSet<Result<Option<FoundRuleFile>>> = JoinSet::new(); + + for (index, path) in dirs_to_check.into_iter().enumerate() { + set.spawn(async move { + match check_dir_for_permissions(&path).await { + Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { + depth: index, + file: rule_file, + })), + Ok(None) => Ok(None), + Err(e) => Err(e), + } + }); + } + + // Check the global file separately (it's a direct file path, not a dir/.atuin/ pattern) + if let Some(global_path) = self.global_permissions_file.clone() { + let depth = dir_count; // sorts after all directory-walk entries + set.spawn(async move { + match load_permissions_file(&global_path).await { + Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { + depth, + file: rule_file, + })), + Ok(None) => Ok(None), + Err(e) => Err(e), + } + }); + } + + let capacity = dir_count + usize::from(self.global_permissions_file.is_some()); + let mut found = Vec::with_capacity(capacity); + while let Some(result) = set.join_next().await { + let result = result?; // JoinErrors result in failure to walk the filesystem + + match result { + Ok(Some(FoundRuleFile { depth, file })) => { + found.push((depth, file)); + } + Ok(None) => { + continue; + } + Err(e) => { + tracing::error!( + "Error while walking filesystem for permissions check; skipping: {}", + e + ); + continue; + } + } + } + // join_next() returns in order of completion, not order of spawn + found.sort_by_key(|(depth, _)| *depth); + self.rules = found.into_iter().map(|(_, file)| file).collect(); + + Ok(()) + } +} + +/// Checks a directory for `.atuin/permissions.ai.toml` and returns the RuleFile if found. +async fn check_dir_for_permissions(path: &Path) -> Result<Option<RuleFile>> { + let file_path = path.join(".atuin").join("permissions.ai.toml"); + load_permissions_file(&file_path).await +} + +/// Load a permissions file from an exact path. Returns None if the file doesn't exist. +async fn load_permissions_file(file_path: &Path) -> Result<Option<RuleFile>> { + if !tokio::fs::try_exists(file_path).await? { + return Ok(None); + } + + let raw = tokio::fs::read_to_string(file_path).await?; + let content: RuleFileContent = toml::from_str(&raw)?; + + // Use the file's parent as the rule file path (for logging/debugging) + let path = file_path + .parent() + .map(Path::to_path_buf) + .unwrap_or_else(|| file_path.to_path_buf()); + + Ok(Some(RuleFile { path, content })) +} diff --git a/crates/atuin-ai/src/permissions/writer.rs b/crates/atuin-ai/src/permissions/writer.rs new file mode 100644 index 00000000..b2bd9482 --- /dev/null +++ b/crates/atuin-ai/src/permissions/writer.rs @@ -0,0 +1,198 @@ +use std::path::Path; + +use eyre::Result; + +use crate::permissions::rule::Rule; + +/// Whether a rule should be added to the allow or deny list. +#[allow(dead_code)] +pub(crate) enum RuleDisposition { + Allow, + Deny, +} + +/// Write a permission rule to a `permissions.ai.toml` file. +/// +/// If the file doesn't exist it is created (along with parent directories). +/// If it does exist, `toml_edit` is used to append the rule while preserving +/// existing formatting and comments. +/// +/// **Not concurrent-safe.** The read-modify-write cycle is not atomic. In the +/// current UI this is fine — the Select widget serializes permission decisions — +/// but callers should not invoke this concurrently for the same file. +pub(crate) async fn write_rule( + file_path: &Path, + rule: &Rule, + disposition: RuleDisposition, +) -> Result<()> { + let content = if tokio::fs::try_exists(file_path).await.unwrap_or(false) { + tokio::fs::read_to_string(file_path).await? + } else { + String::new() + }; + + let mut doc: toml_edit::DocumentMut = content.parse()?; + + // Ensure [permissions] table exists + if !doc.contains_key("permissions") { + doc["permissions"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + + let key = match disposition { + RuleDisposition::Allow => "allow", + RuleDisposition::Deny => "deny", + }; + + // Use as_table_like_mut so both standard and inline tables work. + let permissions = doc["permissions"] + .as_table_like_mut() + .ok_or_else(|| eyre::eyre!("[permissions] is not a table"))?; + + // Get or create the array + if !permissions.contains_key(key) { + permissions.insert(key, toml_edit::Item::Value(toml_edit::Array::new().into())); + } + + let array = permissions + .get_mut(key) + .and_then(|item| item.as_value_mut()) + .and_then(|v| v.as_array_mut()) + .ok_or_else(|| eyre::eyre!("permissions.{key} is not an array"))?; + + // Don't add duplicates + let rule_str = rule.to_string(); + let already_present = array.iter().any(|v| v.as_str() == Some(&rule_str)); + if !already_present { + array.push(rule_str); + } + + // Write back, creating parent directories as needed + if let Some(parent) = file_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + tokio::fs::write(file_path, doc.to_string()).await?; + + Ok(()) +} + +/// Build the path to the project-level permissions file. +/// `project_root` is typically a git root or the current working directory. +pub(crate) fn project_permissions_path(project_root: &Path) -> std::path::PathBuf { + project_root.join(".atuin").join("permissions.ai.toml") +} + +/// Build the path to the global permissions file (sibling of atuin config). +pub(crate) fn global_permissions_path() -> std::path::PathBuf { + atuin_common::utils::config_dir().join("permissions.ai.toml") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn creates_new_file_with_allow_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains("[permissions]")); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn appends_to_existing_file() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let existing = r#"# My permissions +[permissions] +allow = ["Read"] +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + // Comment preserved + assert!(content.contains("# My permissions")); + // Both rules present + assert!(content.contains(r#""Read""#)); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn does_not_duplicate_existing_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let existing = r#"[permissions] +allow = ["AtuinHistory"] +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + // Should appear exactly once + assert_eq!(content.matches("AtuinHistory").count(), 1); + } + + #[tokio::test] + async fn handles_inline_table_permissions() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + // Inline table style — as_table_mut() would return None for this + let existing = r#"permissions = { allow = ["Read"] } +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains(r#""Read""#)); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn writes_deny_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let rule = Rule { + tool: "Shell".to_string(), + scope: None, + }; + + write_rule(&file, &rule, RuleDisposition::Deny) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains("deny")); + assert!(content.contains(r#""Shell""#)); + } +} diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs new file mode 100644 index 00000000..4673f2cd --- /dev/null +++ b/crates/atuin-ai/src/stream.rs @@ -0,0 +1,372 @@ +// ─────────────────────────────────────────────────────────────────── +// SSE streaming +// ─────────────────────────────────────────────────────────────────── + +use std::sync::mpsc; + +use atuin_client::settings::AiCapabilities; +use atuin_common::tls::ensure_crypto_provider; + +use eventsource_stream::Eventsource; +use eye_declare::Handle; +use eyre::{Context, Result}; +use futures::StreamExt; +use reqwest::Url; + +use crate::{ + context::{AppContext, ClientContext}, + tools::ClientToolCall, + tui::{Session, events::AiTuiEvent}, +}; + +/// Frames that alter the stream lifecycle — terminal or state-changing. +#[derive(Debug, Clone)] +pub(crate) enum StreamControl { + Done { session_id: String }, + Error(String), + StatusChanged(String), +} + +/// Frames that carry conversation content — they mutate the event log. +#[derive(Debug, Clone)] +pub(crate) enum StreamContent { + TextChunk(String), + ToolCall { + id: String, + name: String, + input: serde_json::Value, + }, + ToolResult { + tool_use_id: String, + content: String, + is_error: bool, + }, +} + +/// A frame from the SSE stream, classified as control or content. +#[derive(Debug, Clone)] +pub(crate) enum StreamFrame { + Content(StreamContent), + Control(StreamControl), +} + +/// Per-turn request payload for the chat API. +pub(crate) struct ChatRequest { + pub messages: Vec<serde_json::Value>, + pub session_id: Option<String>, + pub capabilities: Vec<String>, +} + +impl ChatRequest { + pub(crate) fn new( + messages: Vec<serde_json::Value>, + session_id: Option<String>, + capabilities: &AiCapabilities, + ) -> Self { + let mut caps = vec![]; + if capabilities.enable_history_search.unwrap_or(true) { + caps.push("client_v1_atuin_history".to_string()); + } + if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { + caps.extend( + extra + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()), + ); + } + + Self { + messages, + session_id, + capabilities: caps, + } + } +} + +fn create_chat_stream( + hub_address: String, + token: String, + request: ChatRequest, + client_ctx: ClientContext, + send_cwd: bool, + last_command: Option<String>, +) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<StreamFrame>> + Send>> { + Box::pin(async_stream::stream! { + ensure_crypto_provider(); + let endpoint = match hub_url(&hub_address, "/api/cli/chat") { + Ok(url) => url, + Err(e) => { + yield Err(e); + return; + } + }; + + tracing::debug!("Sending SSE request to {endpoint}"); + + let context = client_ctx.to_json(send_cwd, last_command.as_deref()); + + let mut request_body = serde_json::json!({ + "messages": request.messages, + "context": context, + "capabilities": request.capabilities, + }); + + if let Some(ref sid) = request.session_id { + tracing::trace!("Including session_id in request: {sid}"); + request_body["session_id"] = serde_json::json!(sid); + } + + let client = reqwest::Client::new(); + let response = match client + .post(endpoint.clone()) + .header("Accept", "text/event-stream") + .bearer_auth(&token) + .json(&request_body) + .send() + .await + { + Ok(resp) => resp, + Err(e) => { + yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); + return; + } + }; + + let status = response.status(); + if status == reqwest::StatusCode::UNAUTHORIZED { + tracing::error!("SSE request failed with status: {status}, clearing session"); + let _ = atuin_client::hub::delete_session().await; + yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); + return; + } + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + tracing::error!("SSE request failed ({}): {}", status, body); + yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); + return; + } + + let byte_stream = response.bytes_stream(); + let mut stream = byte_stream.eventsource(); + + while let Some(event) = stream.next().await { + match event { + Ok(sse_event) => { + let event_type = sse_event.event.as_str(); + let data = sse_event.data.clone(); + + tracing::debug!(event_type = %event_type, "SSE event received"); + + match event_type { + "text" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) + && let Some(content) = json.get("content").and_then(|v| v.as_str()) + { + yield Ok(StreamFrame::Content(StreamContent::TextChunk(content.to_string()))); + } + } + "tool_call" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); + yield Ok(StreamFrame::Content(StreamContent::ToolCall { id, name, input })); + } + } + "tool_result" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); + yield Ok(StreamFrame::Content(StreamContent::ToolResult { tool_use_id, content, is_error })); + } + } + "status" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) + && let Some(state) = json.get("state").and_then(|v| v.as_str()) + { + yield Ok(StreamFrame::Control(StreamControl::StatusChanged(state.to_string()))); + } + } + "done" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let session_id = json.get("session_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + yield Ok(StreamFrame::Control(StreamControl::Done { session_id })); + } else { + yield Ok(StreamFrame::Control(StreamControl::Done { session_id: String::new() })); + } + break; + } + "error" => { + if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { + let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); + tracing::error!("SSE error: {}", message); + yield Ok(StreamFrame::Control(StreamControl::Error(message))); + } else { + tracing::error!("SSE error: {}", data); + yield Ok(StreamFrame::Control(StreamControl::Error(data))); + } + break; + } + _ => {} + } + } + Err(e) => { + yield Err(eyre::eyre!("SSE error: {}", e)); + break; + } + } + } + }) +} + +// ─────────────────────────────────────────────────────────────────── +// Async streaming task — pushes updates to app state via Handle +// ─────────────────────────────────────────────────────────────────── + +pub(crate) async fn run_chat_stream( + handle: Handle<Session>, + tx: mpsc::Sender<AiTuiEvent>, + app_ctx: AppContext, + client_ctx: ClientContext, + request: ChatRequest, +) { + let capabilities = request.capabilities.clone(); + let stream = create_chat_stream( + app_ctx.endpoint.clone(), + app_ctx.token.clone(), + request, + client_ctx, + app_ctx.send_cwd, + app_ctx.last_command.clone(), + ); + futures::pin_mut!(stream); + + while let Some(event) = stream.next().await { + match event { + Ok(StreamFrame::Content(content)) => { + apply_content_frame(&handle, &tx, &capabilities, content); + } + Ok(StreamFrame::Control(control)) => { + let terminal = apply_control_frame(&handle, control); + if terminal { + break; + } + } + Err(e) => { + let msg = e.to_string(); + handle.update(move |state| { + state.streaming_error(msg); + }); + break; + } + } + } +} + +/// Apply a content frame to session state. +/// Control flow: always continues the stream. +fn apply_content_frame( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + capabilities: &[String], + content: StreamContent, +) { + match content { + StreamContent::TextChunk(text) => { + handle.update(move |state| { + state.conversation.append_streaming_text(&text); + }); + } + StreamContent::ToolCall { id, name, input } => { + if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) { + // Enforce capability gating: reject tool calls the client didn't advertise. + if let Some(required_cap) = tool.descriptor().capability + && !capabilities.iter().any(|c| c == required_cap) + { + tracing::warn!( + tool = name, + capability = required_cap, + "Rejecting tool call: capability not advertised" + ); + handle.update(move |state| { + state.add_tool_call(id.clone(), name, input.clone()); + state.conversation.add_tool_result( + id, + format!("Tool not enabled: capability '{required_cap}' was not advertised by this client"), + true, + ); + }); + return; + } + + // Client-side tool — add to tracker and conversation, queue permission check + let id_for_event = id.clone(); + handle.update(move |state| { + state.handle_client_tool_call(id_for_event, tool, input); + }); + let _ = tx.send(AiTuiEvent::CheckToolCallPermission(id)); + } else { + // Server-side tool — just add to conversation events + handle.update(move |state| { + state.add_tool_call(id, name, input); + }); + } + } + StreamContent::ToolResult { + tool_use_id, + content, + is_error, + } => { + handle.update(move |state| { + state + .conversation + .add_tool_result(tool_use_id, content, is_error); + }); + } + } +} + +/// Apply a control frame to session state. +/// Returns true if the stream should terminate. +fn apply_control_frame(handle: &Handle<Session>, control: StreamControl) -> bool { + match control { + StreamControl::StatusChanged(status) => { + handle.update(move |state| { + state.update_streaming_status(&status); + }); + false + } + StreamControl::Done { session_id } => { + handle.update(move |state| { + if !session_id.is_empty() { + state.conversation.store_session_id(session_id); + } + state.finalize_streaming(); + }); + true + } + StreamControl::Error(msg) => { + handle.update(move |state| { + state.streaming_error(msg); + }); + true + } + } +} + +fn hub_url(base: &str, path: &str) -> Result<Url> { + let base_with_slash = if base.ends_with('/') { + base.to_string() + } else { + format!("{base}/") + }; + let stripped = path.strip_prefix('/').unwrap_or(path); + Url::parse(&base_with_slash)? + .join(stripped) + .context("failed to build hub URL") +} diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs new file mode 100644 index 00000000..3b2b7ebf --- /dev/null +++ b/crates/atuin-ai/src/tools/descriptor.rs @@ -0,0 +1,98 @@ +/// Centralized metadata for a tool type. +/// +/// Covers both client-side tools (ones the CLI executes locally) and +/// server-side tools (ones the API executes remotely). This is the single +/// source of truth for display text and classification. +pub(crate) struct ToolDescriptor { + /// Canonical wire names for this tool (the names the server sends). + pub canonical_names: &'static [&'static str], + /// The capability string the client must advertise for this tool to be + /// accepted. `None` for server-side tools (always accepted). + pub capability: Option<&'static str>, + /// Imperative verb for permission prompts (e.g. "read", "run"). + pub display_verb: &'static str, + /// Present-tense progressive verb for spinners (e.g. "Reading file..."). + pub progressive_verb: &'static str, + /// Past-tense verb for summaries (e.g. "Read file"). + pub past_verb: &'static str, + /// Whether this tool is executed client-side (by the CLI). + pub is_client: bool, +} + +// ── Client-side tool descriptors ── + +pub(crate) const READ: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["read_file"], + capability: Some("client_v1_read"), + display_verb: "read", + progressive_verb: "Reading file...", + past_verb: "Read file", + is_client: true, +}; + +pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["str_replace", "file_create", "file_insert"], + capability: Some("client_v1_write"), + display_verb: "write to", + progressive_verb: "Writing file...", + past_verb: "Wrote file", + is_client: true, +}; + +pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["execute_shell_command"], + capability: Some("client_v1_shell"), + display_verb: "run", + progressive_verb: "Running command...", + past_verb: "Ran command", + is_client: true, +}; + +pub(crate) const ATUIN_HISTORY: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["atuin_history"], + capability: Some("client_v1_atuin_history"), + display_verb: "search your Atuin history for", + progressive_verb: "Searching...", + past_verb: "Searched", + is_client: true, +}; + +// ── Server-side tool descriptors ── +// These appear in tool summaries but aren't client-side tools. + +pub(crate) const SERVER_SEARCH: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["web_search"], + capability: None, + display_verb: "search", + progressive_verb: "Searching...", + past_verb: "Searched", + is_client: false, +}; + +pub(crate) const SERVER_SCRAPE: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["web_scrape"], + capability: None, + display_verb: "scrape", + progressive_verb: "Scraping...", + past_verb: "Scraped", + is_client: false, +}; + +/// All known tool descriptors, for lookup by name. +const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[ + READ, + WRITE, + SHELL, + ATUIN_HISTORY, + SERVER_SEARCH, + SERVER_SCRAPE, +]; + +/// Look up a tool descriptor by its canonical wire name. +/// Returns None for unknown tool names. +pub(crate) fn by_name(name: &str) -> Option<&'static ToolDescriptor> { + ALL_DESCRIPTORS + .iter() + .find(|d| d.canonical_names.contains(&name)) + .copied() +} diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs new file mode 100644 index 00000000..8f2183b7 --- /dev/null +++ b/crates/atuin-ai/src/tools/mod.rs @@ -0,0 +1,1111 @@ +use std::{ + io::BufRead, + path::{Path, PathBuf}, + time::Duration, +}; + +use eyre::Result; + +const DEFAULT_FILE_READ_LINES: u64 = 100; +const MAX_FILE_READ_LINES: u64 = 1000; + +pub(crate) mod descriptor; + +use crate::permissions::rule::Rule; + +/// Check whether a file path matches a scope glob pattern. +/// +/// Resolves relative paths against the current directory before matching so +/// that `./foo.md` and `/cwd/foo.md` match the same glob. Supports `*`, `**`, +/// `?`, and `[...]` via `glob_match`. +fn path_matches_scope(path: &Path, scope: &str) -> bool { + let path = if path.is_relative() { + std::env::current_dir() + .map(|cwd| cwd.join(path)) + .unwrap_or_else(|_| path.to_path_buf()) + } else { + path.to_path_buf() + }; + // Normalize to forward slashes so globs work on Windows too. + let path_str = path.to_string_lossy().replace('\\', "/"); + + // If the scope is also relative, try matching against both the absolute + // path and just the filename/relative portion. + if !scope.starts_with('/') { + // Match against filename (e.g. "*.md" matches any .md file) + if let Some(name) = path.file_name().and_then(|n| n.to_str()) + && glob_match::glob_match(scope, name) + { + return true; + } + // Also try matching against the full absolute path in case the scope + // is a relative multi-segment pattern like "crates/**/*.rs" + if glob_match::glob_match(scope, &path_str) { + return true; + } + // And match relative to cwd (so "src/*.rs" works from project root) + if let Ok(cwd) = std::env::current_dir() + && let Ok(rel) = path.strip_prefix(&cwd) + { + let rel_str = rel.to_string_lossy().replace('\\', "/"); + return glob_match::glob_match(scope, &rel_str); + } + return false; + } + + // Absolute scope — match against absolute path + glob_match::glob_match(scope, &path_str) +} + +/// Result of executing a client-side tool. +pub(crate) enum ToolOutcome { + /// Simple success with a text result (used by Read, AtuinHistory). + Success(String), + /// Error with a message. + Error(String), + /// Structured shell result with separated stdout, stderr, exit code, and duration. + Structured { + stdout: String, + stderr: String, + exit_code: Option<i32>, + duration_ms: u64, + interrupted: bool, + }, +} + +impl ToolOutcome { + /// Format this outcome as a string for the tool result sent to the LLM. + pub fn format_for_llm(&self) -> String { + match self { + ToolOutcome::Success(s) => s.clone(), + ToolOutcome::Error(e) => e.clone(), + ToolOutcome::Structured { + stdout, + stderr, + exit_code, + duration_ms, + interrupted, + } => { + let mut parts = Vec::new(); + + if let Some(code) = exit_code { + parts.push(format!("Exit code: {code}")); + } + + parts.push(format!("Duration: {duration_ms}ms")); + + if !stdout.is_empty() { + parts.push(format!("stdout:\n{stdout}")); + } else { + parts.push("stdout: (empty)".to_string()); + } + + if !stderr.is_empty() { + parts.push(format!("stderr:\n{stderr}")); + } else { + parts.push("stderr: (empty)".to_string()); + } + + if *interrupted { + parts.push("[Interrupted by user]".to_string()); + } + + parts.join("\n\n") + } + } + } + + /// Whether this outcome represents an error. + pub fn is_error(&self) -> bool { + match self { + ToolOutcome::Error(_) => true, + ToolOutcome::Structured { + exit_code: Some(code), + .. + } if *code != 0 => true, + _ => false, + } + } +} + +/// Cached VT100 preview data for a shell tool call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ToolPreview { + pub lines: Vec<String>, + pub exit_code: Option<i32>, + pub interrupted: bool, +} + +/// Lifecycle phase of a tracked tool call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ToolPhase { + CheckingPermissions, + AskingForPermission, + #[expect(dead_code)] + Denied(String), + #[expect(dead_code)] + Executing, + /// Shell command is executing with live preview output. + ExecutingWithPreview { + command: String, + /// Current VT100 screen lines (plain text, viewport-sized). + output_lines: Vec<String>, + /// Exit code once the process completes. + exit_code: Option<i32>, + /// Whether the command was interrupted by the user. + interrupted: bool, + }, + /// Tool execution has completed. Preview is cached for rendering history. + Completed { + preview: Option<ToolPreview>, + }, +} + +/// A tracked tool call through its full lifecycle. +#[derive(Debug)] +pub(crate) struct TrackedTool { + pub id: String, + pub tool: ClientToolCall, + pub phase: ToolPhase, + /// Sender to interrupt a running shell command (only set during ExecutingWithPreview). + pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>, +} + +impl TrackedTool { + pub(crate) fn target_dir(&self) -> Option<&Path> { + self.tool.target_dir() + } + + pub fn mark_asking(&mut self) { + self.phase = ToolPhase::AskingForPermission; + } + + pub fn mark_executing_preview(&mut self, command: String) { + self.phase = ToolPhase::ExecutingWithPreview { + command, + output_lines: Vec::new(), + exit_code: None, + interrupted: false, + }; + } + + pub fn complete(&mut self, preview: Option<ToolPreview>) { + self.phase = ToolPhase::Completed { preview }; + self.abort_tx = None; + } + + /// Extract the current preview, whether live or completed. + pub fn preview(&self) -> Option<ToolPreview> { + match &self.phase { + ToolPhase::ExecutingWithPreview { + output_lines, + exit_code, + interrupted, + .. + } => Some(ToolPreview { + lines: output_lines.clone(), + exit_code: *exit_code, + interrupted: *interrupted, + }), + ToolPhase::Completed { preview } => preview.clone(), + _ => None, + } + } +} + +/// Tracks all tool calls through their full lifecycle. +/// +/// Single source of truth for tool execution state. Entries persist after +/// completion so cached previews remain available for rendering history. +#[derive(Debug)] +pub(crate) struct ToolTracker { + tools: Vec<TrackedTool>, +} + +impl ToolTracker { + pub fn new() -> Self { + Self { tools: Vec::new() } + } + + /// Insert a new tool call in CheckingPermissions phase. + pub fn insert(&mut self, id: String, tool: ClientToolCall) { + self.tools.push(TrackedTool { + id, + tool, + phase: ToolPhase::CheckingPermissions, + abort_tx: None, + }); + } + + pub fn get(&self, id: &str) -> Option<&TrackedTool> { + self.tools.iter().find(|t| t.id == id) + } + + pub fn get_mut(&mut self, id: &str) -> Option<&mut TrackedTool> { + self.tools.iter_mut().find(|t| t.id == id) + } + + /// Remove a tool by ID and return it. + #[expect(dead_code)] + pub fn remove(&mut self, id: &str) -> Option<TrackedTool> { + let pos = self.tools.iter().position(|t| t.id == id)?; + Some(self.tools.remove(pos)) + } + + /// True if any tool is still awaiting a permission decision. + #[expect(dead_code)] + pub fn has_unresolved(&self) -> bool { + self.tools.iter().any(|t| { + matches!( + t.phase, + ToolPhase::CheckingPermissions | ToolPhase::AskingForPermission + ) + }) + } + + /// True if any tool has not yet reached the Completed phase. + /// Use this to gate `ContinueAfterTools` — we must wait for all tools + /// (including those still executing) before resuming the conversation. + pub fn has_pending(&self) -> bool { + self.tools + .iter() + .any(|t| !matches!(t.phase, ToolPhase::Completed { .. })) + } + + /// True if any tool is currently executing with a preview. + pub fn has_executing_preview(&self) -> bool { + self.tools + .iter() + .any(|t| matches!(t.phase, ToolPhase::ExecutingWithPreview { .. })) + } + + /// Find the first tool that is asking for permission. + pub fn asking_for_permission(&self) -> Option<&TrackedTool> { + self.tools + .iter() + .find(|t| t.phase == ToolPhase::AskingForPermission) + } + + /// Find the first tool that is asking for permission (mutable). + #[expect(dead_code)] + pub fn asking_for_permission_mut(&mut self) -> Option<&mut TrackedTool> { + self.tools + .iter_mut() + .find(|t| t.phase == ToolPhase::AskingForPermission) + } + + /// Get the preview for a tool by ID (live or cached). + pub fn preview_for(&self, id: &str) -> Option<ToolPreview> { + self.get(id)?.preview() + } + + /// Iterate mutably over all tracked tools. + pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut TrackedTool> { + self.tools.iter_mut() + } +} + +/// A tool call from the server, with parsed input parameters. +#[derive(Debug, Clone)] +pub(crate) enum ClientToolCall { + Read(ReadToolCall), + Write(WriteToolCall), + Shell(ShellToolCall), + AtuinHistory(AtuinHistoryToolCall), +} + +impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { + type Error = eyre::Error; + + fn try_from((name, input): (&str, &serde_json::Value)) -> Result<Self, Self::Error> { + match name { + "read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)), + "create_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), + // "append_to_file" => Ok(ClientToolCall::Append(AppendToolCall::try_from(input)?)), + // "str_replace" => Ok(ClientToolCall::StrReplace(StrReplaceToolCall::try_from(input)?)), + "execute_shell_command" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), + "atuin_history" => Ok(ClientToolCall::AtuinHistory( + AtuinHistoryToolCall::try_from(input)?, + )), + _ => Err(eyre::eyre!("Unknown tool call: {name}")), + } + } +} + +impl ClientToolCall { + pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor { + match self { + ClientToolCall::Read(_) => descriptor::READ, + ClientToolCall::Write(_) => descriptor::WRITE, + ClientToolCall::Shell(_) => descriptor::SHELL, + ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, + } + } + + /// The permission rule name for this tool category (e.g. "Write" covers + /// str_replace, file_create, file_insert). + pub(crate) fn rule_name(&self) -> &'static str { + match self { + ClientToolCall::Read(_) => "Read", + ClientToolCall::Write(_) => "Write", + ClientToolCall::Shell(_) => "Shell", + ClientToolCall::AtuinHistory(_) => "AtuinHistory", + } + } + + pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { + match self { + ClientToolCall::Read(tool) => tool.matches_rule(rule), + ClientToolCall::Write(tool) => tool.matches_rule(rule), + ClientToolCall::Shell(tool) => tool.matches_rule(rule), + ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), + } + } + + pub(crate) fn target_dir(&self) -> Option<&Path> { + match self { + ClientToolCall::Read(tool) => tool.target_dir(), + ClientToolCall::Write(tool) => tool.target_dir(), + ClientToolCall::Shell(tool) => tool.target_dir(), + ClientToolCall::AtuinHistory(tool) => tool.target_dir(), + } + } + + /// Execute this client-side tool and return the result. + pub async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { + match self { + ClientToolCall::Read(tool) => tool.execute(), + ClientToolCall::AtuinHistory(tool) => tool.execute(db).await, + _ => ToolOutcome::Error("Client-side tool execution not yet implemented".to_string()), + } + } +} + +/// A trait for tool calls that can be checked against permission rules. +pub(crate) trait PermissableToolCall { + /// Checks if this tool call matches the given permission rule. + fn matches_rule(&self, rule: &Rule) -> bool; + /// Returns the target directory of this tool call, if applicable, for checking against directory-based rules. + fn target_dir(&self) -> Option<&Path> { + None + } +} + +impl PermissableToolCall for ClientToolCall { + fn matches_rule(&self, rule: &Rule) -> bool { + self.matches_rule(rule) + } + + fn target_dir(&self) -> Option<&Path> { + self.target_dir() + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ReadToolCall { + pub path: PathBuf, + pub offset: u64, + pub limit: u64, +} + +impl TryFrom<&serde_json::Value> for ReadToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let path = value + .get("file_path") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing path"))?; + + let offset = value.get("offset").and_then(|v| v.as_u64()).unwrap_or(0); + let limit = value + .get("limit") + .and_then(|v| v.as_u64()) + .unwrap_or(DEFAULT_FILE_READ_LINES) + .min(MAX_FILE_READ_LINES); + + Ok(ReadToolCall { + path: PathBuf::from(path), + offset, + limit, + }) + } +} + +impl ReadToolCall { + fn execute(&self) -> ToolOutcome { + let mut path = self.path.clone(); + + if path.is_relative() + && let Ok(current_dir) = std::env::current_dir() + { + path = current_dir.join(path); + } + + if !path.exists() { + return ToolOutcome::Error(format!("Error: file does not exist: {}", path.display())); + } + + if path.is_dir() { + let Some(files) = std::fs::read_dir(&path).ok().and_then(|entries| { + entries + .filter_map(|entry| entry.ok()) + .map(|entry| entry.file_name().to_string_lossy().to_string()) + .collect::<Vec<_>>() + .into() + }) else { + return ToolOutcome::Error(format!( + "Error: could not read directory: {}", + path.display() + )); + }; + + return ToolOutcome::Success(format!("Directory contents:\n{}", files.join("\n"))); + } + + let file = match std::fs::File::open(&path) { + Ok(file) => file, + Err(e) => return ToolOutcome::Error(format!("Error opening file: {e}")), + }; + let reader = std::io::BufReader::new(file); + + let relevent_lines = reader + .lines() + .skip(self.offset as usize) + .take(self.limit as usize) + .collect::<Result<Vec<_>, _>>(); + + match relevent_lines { + Ok(lines) => { + let joined = lines.join("\n"); + if joined.len() > 100_000 { + ToolOutcome::Error(format!( + "Error: file is too large to read ({} bytes in {} lines); use view_range to read a subset of the file", + joined.len(), + lines.len() + )) + } else { + ToolOutcome::Success(joined) + } + } + Err(e) => ToolOutcome::Error(format!("Error reading file: {e}")), + } + } +} + +impl PermissableToolCall for ReadToolCall { + fn target_dir(&self) -> Option<&Path> { + Some(&self.path) + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Read" { + return false; + } + + match rule.scope.as_deref() { + None | Some("*") => true, + Some(scope) => path_matches_scope(&self.path, scope), + } + } +} + +#[derive(Debug, Clone)] +#[expect(dead_code)] +pub(crate) struct WriteToolCall { + pub path: PathBuf, + pub content: String, +} + +impl TryFrom<&serde_json::Value> for WriteToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let path = value + .get("path") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing path"))?; + + let content = value + .get("content") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing content"))?; + + Ok(WriteToolCall { + path: PathBuf::from(path), + content: content.to_string(), + }) + } +} + +impl PermissableToolCall for WriteToolCall { + fn target_dir(&self) -> Option<&Path> { + Some(&self.path) + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Write" { + return false; + } + + match rule.scope.as_deref() { + None | Some("*") => true, + Some(scope) => path_matches_scope(&self.path, scope), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ShellToolCall { + pub dir: Option<PathBuf>, + pub command: String, + pub shell: String, +} + +impl TryFrom<&serde_json::Value> for ShellToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let dir = value.get("dir").and_then(|v| v.as_str()); + + let command = value + .get("command") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing command"))?; + + let shell = value + .get("shell") + .and_then(|v| v.as_str()) + .unwrap_or("bash") + .to_string(); + + Ok(ShellToolCall { + dir: dir.map(PathBuf::from), + command: command.to_string(), + shell, + }) + } +} + +impl PermissableToolCall for ShellToolCall { + fn target_dir(&self) -> Option<&Path> { + self.dir.as_deref() + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Shell" { + return false; + } + + let Some(scope) = rule.scope.as_ref() else { + // Shell without scope matches all shell commands + return true; + }; + + let shell_kind = crate::permissions::shell::ShellKind::from_shell_name(&self.shell); + let parsed = crate::permissions::shell::parse_shell_command(&self.command, shell_kind); + crate::permissions::shell::any_subcommand_matches(&parsed.subcommands, scope) + } +} + +/// Preview viewport height for VT100 emulation. +const PREVIEW_HEIGHT: u16 = 10; + +/// Default terminal width for VT100 emulation. +const PREVIEW_WIDTH: u16 = 120; + +/// Extract plain text lines from a VT100 screen buffer. +fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> { + let (rows, cols) = screen.size(); + let mut lines = Vec::with_capacity(rows as usize); + for row in 0..rows { + let mut line = String::with_capacity(cols as usize); + for col in 0..cols { + if let Some(cell) = screen.cell(row, col) { + line.push_str(cell.contents()); + } + } + // Trim trailing whitespace for cleaner display + lines.push(line.trim_end().to_string()); + } + lines +} + +/// Strip ANSI escape sequences from raw bytes using a VT100 parser. +/// +/// Uses a large virtual screen so scrollback is preserved, then extracts +/// the plain text contents. This handles all escape sequences (colors, +/// cursor movement, progress bars, etc.) not just simple SGR codes. +fn strip_ansi_via_vt100(raw: &[u8]) -> String { + if raw.is_empty() { + return String::new(); + } + // Use the contents_formatted → screen approach: feed bytes into a parser + // with enough rows to hold everything, then read back the plain text. + // Estimate rows: one row per ~PREVIEW_WIDTH bytes, plus generous padding. + let estimated_rows = (raw.len() / PREVIEW_WIDTH as usize + 1).min(10_000) as u16; + let mut parser = vt100::Parser::new(estimated_rows, PREVIEW_WIDTH, 0); + parser.process(raw); + let screen = parser.screen(); + // screen.contents() returns the full plain-text content with trailing + // whitespace trimmed per line and trailing blank lines removed. + screen.contents() +} + +/// Execute a shell command with VT100 emulation and streaming output. +/// +/// Feeds stdout+stderr into a `vt100::Parser` so that ANSI escape sequences, +/// progress bars (`\r`), and cursor movement are handled correctly. Periodically +/// sends the current screen state as `Vec<String>` through `output_tx` for the +/// live preview. +/// +/// Captures the FULL stdout and stderr separately for the tool result sent to the LLM. +/// Returns a `ToolOutcome::Structured` with full output, exit code, and duration. +pub(crate) async fn execute_shell_command_streaming( + shell_call: &ShellToolCall, + output_tx: tokio::sync::mpsc::Sender<Vec<String>>, + mut interrupt_rx: tokio::sync::oneshot::Receiver<()>, +) -> ToolOutcome { + use tokio::io::AsyncReadExt; + + let start = std::time::Instant::now(); + + // TODO: check if this is proper for all shells we support + let mut cmd = tokio::process::Command::new(&shell_call.shell); + cmd.arg("-c").arg(&shell_call.command); + cmd.stdout(std::process::Stdio::piped()); + cmd.stderr(std::process::Stdio::piped()); + + if let Some(ref dir) = shell_call.dir { + cmd.current_dir(dir); + } + + let mut child = match cmd.spawn() { + Ok(child) => child, + Err(e) => return ToolOutcome::Error(format!("Failed to spawn command: {e}")), + }; + + let stdout = child.stdout.take().expect("stdout was piped"); + let stderr = child.stderr.take().expect("stderr was piped"); + + // VT100 emulator for the live preview (viewport-sized) + let mut parser = vt100::Parser::new(PREVIEW_HEIGHT, PREVIEW_WIDTH, 0); + + let mut stdout_reader = tokio::io::BufReader::new(stdout); + let mut stderr_reader = tokio::io::BufReader::new(stderr); + + let mut stdout_buf = [0u8; 4096]; + let mut stderr_buf = [0u8; 4096]; + let mut stdout_done = false; + let mut stderr_done = false; + + // Full output buffers (for the LLM, not the preview) + let mut full_stdout = Vec::<u8>::new(); + let mut full_stderr = Vec::<u8>::new(); + + let mut interval = tokio::time::interval(Duration::from_millis(50)); + + // Send initial empty screen + let initial_lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(initial_lines).await; + + let mut interrupted = false; + + loop { + tokio::select! { + biased; + + // Check for interrupt signal + _ = &mut interrupt_rx, if !interrupted => { + interrupted = true; + let _ = child.start_kill(); + } + + // Read stdout + result = stdout_reader.read(&mut stdout_buf), if !stdout_done => { + match result { + Ok(0) => stdout_done = true, + Ok(n) => { + full_stdout.extend_from_slice(&stdout_buf[..n]); + parser.process(&stdout_buf[..n]); + } + Err(_) => stdout_done = true, + } + } + + // Read stderr + result = stderr_reader.read(&mut stderr_buf), if !stderr_done => { + match result { + Ok(0) => stderr_done = true, + Ok(n) => { + full_stderr.extend_from_slice(&stderr_buf[..n]); + // Feed stderr to the preview parser too, so it shows in the VT100 screen + parser.process(&stderr_buf[..n]); + } + Err(_) => stderr_done = true, + } + } + + // Periodic screen snapshot for preview + _ = interval.tick() => { + let lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(lines).await; + } + } + + // Exit when both streams are done + if stdout_done && stderr_done { + break; + } + } + + // Wait for process to finish + let exit_code = match child.wait().await { + Ok(status) => status.code(), + Err(e) => { + if interrupted { + None + } else { + return ToolOutcome::Error(format!("Failed to wait for command: {e}")); + } + } + }; + + let duration = start.elapsed(); + + // Send final screen state + let final_lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(final_lines).await; + + // Strip ANSI escape sequences for clean LLM output by running + // the raw bytes through a VT100 parser and extracting plain text. + let stdout_text = strip_ansi_via_vt100(&full_stdout); + let stderr_text = strip_ansi_via_vt100(&full_stderr); + + ToolOutcome::Structured { + stdout: stdout_text, + stderr: stderr_text, + exit_code, + duration_ms: duration.as_millis() as u64, + interrupted, + } +} + +#[derive(Debug, Clone)] +pub(crate) struct AtuinHistoryToolCall { + pub filter_modes: Vec<HistorySearchFilterMode>, + pub query: String, + pub limit: i64, +} + +#[derive(Debug, Clone)] +pub(crate) enum HistorySearchFilterMode { + Global, + Host, + Session, + Directory, + Workspace, +} + +impl From<&HistorySearchFilterMode> for atuin_client::settings::FilterMode { + fn from(mode: &HistorySearchFilterMode) -> Self { + match mode { + HistorySearchFilterMode::Global => Self::Global, + HistorySearchFilterMode::Host => Self::Host, + HistorySearchFilterMode::Session => Self::Session, + HistorySearchFilterMode::Directory => Self::Directory, + HistorySearchFilterMode::Workspace => Self::Workspace, + } + } +} + +impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let filter_modes = value + .get("filter_modes") + .and_then(|v| v.as_array()) + .ok_or(eyre::eyre!("Missing filter_modes"))?; + + let filter_modes = filter_modes + .iter() + .map(|v| { + let mode = v.as_str().ok_or(eyre::eyre!("Invalid filter mode"))?; + match mode { + "global" => Ok(HistorySearchFilterMode::Global), + "host" => Ok(HistorySearchFilterMode::Host), + "session" => Ok(HistorySearchFilterMode::Session), + "directory" => Ok(HistorySearchFilterMode::Directory), + "workspace" => Ok(HistorySearchFilterMode::Workspace), + _ => Err(eyre::eyre!("Invalid filter mode: {mode}")), + } + }) + .collect::<Result<Vec<HistorySearchFilterMode>>>()?; + + let query = value + .get("query") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing query"))?; + + let limit = value + .get("limit") + .and_then(|v| v.as_i64()) + .unwrap_or(10) + .clamp(1, 50); + + Ok(AtuinHistoryToolCall { + filter_modes, + query: query.to_string(), + limit, + }) + } +} + +impl PermissableToolCall for AtuinHistoryToolCall { + fn target_dir(&self) -> Option<&Path> { + None + } + + fn matches_rule(&self, rule: &Rule) -> bool { + rule.tool == "AtuinHistory" + } +} + +impl AtuinHistoryToolCall { + pub(crate) async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { + use atuin_client::database::{self, Database as _, OptFilters}; + use atuin_client::settings::SearchMode; + use time::UtcOffset; + + let context = match database::current_context().await { + Ok(ctx) => ctx, + Err(e) => return ToolOutcome::Error(format!("Failed to get history context: {e}")), + }; + + let filter_mode = self + .filter_modes + .first() + .map(atuin_client::settings::FilterMode::from) + .unwrap_or(atuin_client::settings::FilterMode::Global); + + let filter_options = OptFilters { + limit: Some(self.limit), + ..Default::default() + }; + + let results = match db + .search( + SearchMode::Fuzzy, + filter_mode, + &context, + &self.query, + filter_options, + ) + .await + { + Ok(results) => results, + Err(e) => return ToolOutcome::Error(format!("History search failed: {e}")), + }; + + if results.is_empty() { + return ToolOutcome::Success("No matching history entries found.".to_string()); + } + + let local_offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + + let formatted: Vec<String> = results + .iter() + .enumerate() + .map(|(i, h)| { + let ts = h.timestamp.to_offset(local_offset); + let time_str = format!( + "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", + ts.year(), + ts.month() as u8, + ts.day(), + ts.hour(), + ts.minute(), + ts.second(), + ); + + let duration_str = format_duration(h.duration); + + format!( + "{}. `{}` [{}] ({}, exit: {}){}", + i + 1, + h.command, + time_str, + h.cwd, + h.exit, + duration_str, + ) + }) + .collect(); + + ToolOutcome::Success(formatted.join("\n")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn read_rule(scope: Option<&str>) -> Rule { + Rule { + tool: "Read".to_string(), + scope: scope.map(String::from), + } + } + + fn write_rule(scope: Option<&str>) -> Rule { + Rule { + tool: "Write".to_string(), + scope: scope.map(String::from), + } + } + + fn read_tool(path: &str) -> ReadToolCall { + ReadToolCall { + path: PathBuf::from(path), + offset: 0, + limit: 100, + } + } + + fn write_tool(path: &str) -> WriteToolCall { + WriteToolCall { + path: PathBuf::from(path), + content: String::new(), + } + } + + // ── Cross-platform tests ── + + #[test] + fn no_scope_matches_everything() { + assert!(read_tool("any/path.txt").matches_rule(&read_rule(None))); + assert!(write_tool("any/path.txt").matches_rule(&write_rule(None))); + } + + #[test] + fn wildcard_star_matches_everything() { + assert!(read_tool("foo/bar.rs").matches_rule(&read_rule(Some("*")))); + } + + #[test] + fn wrong_tool_never_matches() { + assert!(!read_tool("foo.txt").matches_rule(&write_rule(None))); + assert!(!write_tool("foo.txt").matches_rule(&read_rule(None))); + } + + #[test] + fn extension_glob() { + assert!(read_tool("notes.md").matches_rule(&read_rule(Some("*.md")))); + assert!(!read_tool("notes.txt").matches_rule(&read_rule(Some("*.md")))); + } + + #[test] + fn relative_multi_segment_glob() { + // This matches against the path relative to cwd + let cwd = std::env::current_dir().unwrap(); + let abs = cwd + .join("crates") + .join("atuin-ai") + .join("src") + .join("lib.rs"); + let tool = read_tool(abs.to_str().unwrap()); + assert!(tool.matches_rule(&read_rule(Some("crates/**/*.rs")))); + assert!(!tool.matches_rule(&read_rule(Some("crates/**/*.py")))); + } + + // ── Unix-specific tests (absolute paths with forward slashes) ── + + #[cfg(unix)] + mod unix { + use super::*; + + #[test] + fn absolute_glob() { + assert!( + read_tool("/home/user/src/main.rs") + .matches_rule(&read_rule(Some("/home/user/src/*.rs"))) + ); + assert!( + !read_tool("/home/user/docs/readme.md") + .matches_rule(&read_rule(Some("/home/user/src/*.rs"))) + ); + } + + #[test] + fn double_star_glob() { + assert!( + read_tool("/project/crates/foo/src/lib.rs") + .matches_rule(&read_rule(Some("/project/crates/**/*.rs"))) + ); + assert!( + !read_tool("/project/crates/foo/src/lib.py") + .matches_rule(&read_rule(Some("/project/crates/**/*.rs"))) + ); + } + } + + // ── Windows-specific tests (absolute paths with drive letters) ── + + #[cfg(windows)] + mod windows { + use super::*; + + #[test] + fn absolute_glob() { + assert!( + read_tool(r"C:\Users\dev\src\main.rs") + .matches_rule(&read_rule(Some("C:/Users/dev/src/*.rs"))) + ); + assert!( + !read_tool(r"C:\Users\dev\docs\readme.md") + .matches_rule(&read_rule(Some("C:/Users/dev/src/*.rs"))) + ); + } + + #[test] + fn double_star_glob() { + assert!( + read_tool(r"C:\project\crates\foo\src\lib.rs") + .matches_rule(&read_rule(Some("C:/project/crates/**/*.rs"))) + ); + assert!( + !read_tool(r"C:\project\crates\foo\src\lib.py") + .matches_rule(&read_rule(Some("C:/project/crates/**/*.rs"))) + ); + } + } +} + +fn format_duration(nanos: i64) -> String { + if nanos <= 0 { + return String::new(); + } + + let total_secs = nanos / 1_000_000_000; + let millis = (nanos % 1_000_000_000) / 1_000_000; + + if total_secs >= 3600 { + let hours = total_secs / 3600; + let mins = (total_secs % 3600) / 60; + let secs = total_secs % 60; + format!(", {hours}h{mins}m{secs}s") + } else if total_secs >= 60 { + let mins = total_secs / 60; + let secs = total_secs % 60; + format!(", {mins}m{secs}s") + } else if total_secs > 0 { + if millis > 0 { + format!(", {total_secs}.{millis:03}s") + } else { + format!(", {total_secs}s") + } + } else { + format!(", {millis}ms") + } +} diff --git a/crates/atuin-ai/src/tui/components/atuin_ai.rs b/crates/atuin-ai/src/tui/components/atuin_ai.rs index fab29502..c04ac722 100644 --- a/crates/atuin-ai/src/tui/components/atuin_ai.rs +++ b/crates/atuin-ai/src/tui/components/atuin_ai.rs @@ -22,10 +22,11 @@ pub(crate) struct AtuinAi { pub has_command: bool, pub is_input_blank: bool, pub pending_confirmation: bool, + pub has_executing_preview: bool, } #[derive(Default)] -pub struct AtuinAiState { +pub(crate) struct AtuinAiState { tx: Option<mpsc::Sender<AiTuiEvent>>, } @@ -55,15 +56,24 @@ fn atuin_ai( return EventResult::Ignored; }; - // Ctrl+C always exits + // Ctrl+C — interrupt executing command or exit if modifiers.contains(KeyModifiers::CONTROL) && *code == KeyCode::Char('c') { - let _ = tx.send(AiTuiEvent::Exit); + if props.has_executing_preview { + let _ = tx.send(AiTuiEvent::InterruptToolExecution); + } else { + let _ = tx.send(AiTuiEvent::Exit); + } return EventResult::Consumed; } match props.mode { AppMode::Input => match code { KeyCode::Esc => { + if props.has_executing_preview { + let _ = tx.send(AiTuiEvent::InterruptToolExecution); + return EventResult::Consumed; + } + if props.pending_confirmation { let _ = tx.send(AiTuiEvent::CancelConfirmation); return EventResult::Consumed; diff --git a/crates/atuin-ai/src/tui/components/markdown.rs b/crates/atuin-ai/src/tui/components/markdown.rs index 1cd7dbcf..f164fdc5 100644 --- a/crates/atuin-ai/src/tui/components/markdown.rs +++ b/crates/atuin-ai/src/tui/components/markdown.rs @@ -16,20 +16,12 @@ use ratatui_widgets::paragraph::{Paragraph, Wrap}; /// A markdown rendering component backed by pulldown-cmark. #[props] -pub struct Markdown { +pub(crate) struct Markdown { pub source: String, } -impl Markdown { - pub fn new(source: impl Into<String>) -> Self { - Self { - source: source.into(), - } - } -} - /// Style configuration for markdown rendering. -pub struct MarkdownStyles { +pub(crate) struct MarkdownStyles { pub base: Style, pub code_inline: Style, pub code_block: Style, @@ -98,26 +90,22 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat let mut style_stack: Vec<Style> = vec![styles.base]; let mut in_code_block = false; + let mut in_list_item = false; + // True until the first paragraph inside a list item has been opened. + // The first paragraph should flow inline with the "- " prefix. + let mut list_item_first_para = false; for event in parser { match event { Event::Start(Tag::Strong) => { - let bold = style_stack - .last() - .copied() - .unwrap_or(styles.base) - .add_modifier(Modifier::BOLD); + let bold = style_stack.last().copied().unwrap_or(styles.bold); style_stack.push(bold); } Event::End(TagEnd::Strong) => { style_stack.pop(); } Event::Start(Tag::Emphasis) => { - let italic = style_stack - .last() - .copied() - .unwrap_or(styles.base) - .add_modifier(Modifier::ITALIC); + let italic = style_stack.last().copied().unwrap_or(styles.italic); style_stack.push(italic); } Event::End(TagEnd::Emphasis) => { @@ -170,12 +158,17 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat lines.push(Vec::new()); } Event::Start(Tag::Paragraph) => { - if current_line > 0 || !lines[0].is_empty() { - // Two line advances: one to end the current line, one for a blank separator. - current_line += 1; - lines.push(Vec::new()); + if in_list_item && list_item_first_para { + // First paragraph flows inline with the "- " prefix + list_item_first_para = false; + } else if current_line > 0 || !lines[0].is_empty() { current_line += 1; lines.push(Vec::new()); + if !in_list_item { + // Blank separator between paragraphs (but not inside list items) + current_line += 1; + lines.push(Vec::new()); + } } } Event::End(TagEnd::Paragraph) => {} @@ -197,8 +190,12 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat lines.push(Vec::new()); } lines[current_line].push(Span::styled("- ", Style::default().fg(Color::DarkGray))); + in_list_item = true; + list_item_first_para = true; + } + Event::End(TagEnd::Item) => { + in_list_item = false; } - Event::End(TagEnd::Item) => {} Event::Start(Tag::List(_)) => { if current_line > 0 || !lines[0].is_empty() { current_line += 1; diff --git a/crates/atuin-ai/src/tui/components/mod.rs b/crates/atuin-ai/src/tui/components/mod.rs index 2f684f5f..3458327d 100644 --- a/crates/atuin-ai/src/tui/components/mod.rs +++ b/crates/atuin-ai/src/tui/components/mod.rs @@ -1,3 +1,4 @@ -pub mod atuin_ai; -pub mod input_box; -pub mod markdown; +pub(crate) mod atuin_ai; +pub(crate) mod input_box; +pub(crate) mod markdown; +pub(crate) mod select; diff --git a/crates/atuin-ai/src/tui/components/select.rs b/crates/atuin-ai/src/tui/components/select.rs new file mode 100644 index 00000000..5abbe655 --- /dev/null +++ b/crates/atuin-ai/src/tui/components/select.rs @@ -0,0 +1,96 @@ +use std::sync::mpsc; + +use crossterm::event::KeyCode; +use eye_declare::{Elements, EventResult, Hooks, Span, Text, View, component, element, props}; +use ratatui::style::Style; +use typed_builder::TypedBuilder; + +use crate::tui::events::AiTuiEvent; + +type OnSelectFn = Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync + 'static>; + +#[derive(TypedBuilder)] +pub(crate) struct SelectOption { + #[builder(setter(into))] + pub label: String, + #[builder(setter(into))] + pub value: String, + #[builder(default = Style::default())] + pub label_style: Style, + #[builder(default = Style::default().reversed())] + pub selected_style: Style, +} + +#[derive(Default)] +pub(crate) struct PermissionSelectorState { + selected_option: usize, + tx: Option<mpsc::Sender<AiTuiEvent>>, +} + +#[props] +pub(crate) struct Select { + pub options: Vec<SelectOption>, + pub on_select: OnSelectFn, +} + +#[component(props = Select, state = PermissionSelectorState)] +pub(crate) fn permission_selector( + props: &Select, + state: &PermissionSelectorState, + hooks: &mut Hooks<Select, PermissionSelectorState>, +) -> Elements { + hooks.use_focusable(true); + hooks.use_autofocus(); + + hooks.use_context::<mpsc::Sender<AiTuiEvent>>(|tx, _, state| { + state.tx = tx.cloned(); + }); + + hooks.use_event(move |event, props, state| { + if !event.is_key_press() { + return EventResult::Ignored; + } + + if let crossterm::event::Event::Key(key) = event { + if key.kind != crossterm::event::KeyEventKind::Press { + return EventResult::Ignored; + } + + match key.code { + KeyCode::Up => { + state.selected_option = + (state.selected_option + props.options.len() - 1) % props.options.len(); + return EventResult::Consumed; + } + KeyCode::Down => { + state.selected_option = (state.selected_option + 1) % props.options.len(); + return EventResult::Consumed; + } + KeyCode::Enter => { + let option = &props.options[state.selected_option]; + if let Some(event) = (props.on_select)(option) + && let Some(ref tx) = state.tx + { + let _ = tx.send(event); + } + return EventResult::Consumed; + } + _ => {} + } + } + + EventResult::Ignored + }); + + element!( + View { + #(for (index, option) in props.options.iter().enumerate() { + Text { Span(text: &option.label, style: if index == state.selected_option { + option.selected_style + } else { + option.label_style + }) } + }) + } + ) +} diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs new file mode 100644 index 00000000..b3e84757 --- /dev/null +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -0,0 +1,571 @@ +use std::path::PathBuf; +use std::sync::mpsc; + +use crate::context::{AppContext, ClientContext}; +use crate::permissions::check::PermissionResponse; +use crate::permissions::resolver::PermissionResolver; +use crate::permissions::rule::Rule; +use crate::permissions::writer::{self, RuleDisposition}; +use crate::stream::{ChatRequest, run_chat_stream}; +use crate::tools::{ClientToolCall, ToolPhase}; +use crate::tui::events::{AiTuiEvent, PermissionResult}; +use crate::tui::state::{ExitAction, Session}; +use eye_declare::Handle; +use tokio::task::JoinHandle; + +pub(crate) fn dispatch( + handle: &Handle<Session>, + event: AiTuiEvent, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + client_ctx: &ClientContext, +) { + match event { + AiTuiEvent::ContinueAfterTools => { + on_continue_after_tools(handle, tx, app_ctx, client_ctx); + } + AiTuiEvent::InputUpdated(input) => { + on_input_updated(handle, input); + } + AiTuiEvent::SubmitInput(input) => { + on_submit_input(handle, tx, app_ctx, client_ctx, input); + } + AiTuiEvent::SlashCommand(cmd) => { + on_slash_command(handle, cmd); + } + AiTuiEvent::CheckToolCallPermission(id) => { + on_check_tool_permission(handle, tx, app_ctx, id); + } + AiTuiEvent::SelectPermission(result) => { + on_select_permission(handle, tx, app_ctx, result); + } + AiTuiEvent::CancelGeneration => { + on_cancel_generation(handle); + } + AiTuiEvent::ExecuteCommand => { + on_execute_command(handle); + } + AiTuiEvent::CancelConfirmation => { + on_cancel_confirmation(handle); + } + AiTuiEvent::InterruptToolExecution => { + on_interrupt_tool_execution(handle); + } + AiTuiEvent::InsertCommand => { + on_insert_command(handle); + } + AiTuiEvent::Retry => { + on_retry(handle, tx, app_ctx, client_ctx); + } + AiTuiEvent::Exit => { + on_exit(handle); + } + } +} + +fn launch_stream( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + client_ctx: &ClientContext, + setup: impl FnOnce(&mut Session) + Send + 'static, +) { + let h2 = handle.clone(); + let tx2 = tx.clone(); + let app = app_ctx.clone(); + let cc = client_ctx.clone(); + let caps = app_ctx.capabilities.clone(); + handle.update(move |state| { + (setup)(state); + state.start_streaming(); + let messages = state.conversation.events_to_messages(); + let sid = state.conversation.session_id.clone(); + let request = ChatRequest::new(messages, sid, &caps); + let task: JoinHandle<()> = tokio::spawn(async move { + run_chat_stream(h2, tx2, app, cc, request).await; + }); + state.stream_abort = Some(task.abort_handle()); + }); +} + +fn on_continue_after_tools( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + client_ctx: &ClientContext, +) { + launch_stream(handle, tx, app_ctx, client_ctx, |_state| {}); +} + +fn on_input_updated(handle: &Handle<Session>, input: String) { + let input_blank = input.trim().is_empty(); + + handle.update(move |state| { + state.interaction.is_input_blank = input_blank; + }); +} + +fn on_submit_input( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + client_ctx: &ClientContext, + input: String, +) { + let input = input.trim().to_string(); + if input.is_empty() { + let h2 = handle.clone(); + handle.update(move |state| { + if state.conversation.has_any_command() { + state.exit_action = Some(ExitAction::Execute( + state.conversation.current_command().unwrap().to_string(), + )); + } else { + state.exit_action = Some(ExitAction::Cancel); + } + h2.exit(); + }); + return; + } + + if input.starts_with('/') { + handle.update(move |state| { + state.conversation.handle_slash_command(&input); + }); + return; + } + + // Start generation and spawn streaming task + launch_stream(handle, tx, app_ctx, client_ctx, |state| { + state.start_generating(input); + state.interaction.is_input_blank = true; + }); +} + +fn on_slash_command(handle: &Handle<Session>, command: String) { + handle.update(move |state| { + state.conversation.handle_slash_command(&command); + }); +} + +// ─────────────────────────────────────────────────────────────────── +// Tool execution dispatch +// ─────────────────────────────────────────────────────────────────── + +/// Execute a tool call. Handles Shell tools (streaming with preview) and +/// non-shell tools (synchronous) uniformly. +fn execute_tool( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + tool_id: String, + tool: ClientToolCall, + db: &std::sync::Arc<atuin_client::database::Sqlite>, +) { + match &tool { + ClientToolCall::Shell(shell_call) => { + let shell_call = shell_call.clone(); + execute_shell_tool(handle, tx, &tool_id, &shell_call); + } + _ => { + execute_simple_tool(handle, tx, tool_id, tool, db); + } + } +} + +/// Execute a non-shell tool and finish the tool call. +/// The ToolCall event is already in the conversation (added by handle_client_tool_call). +fn execute_simple_tool( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + tool_id: String, + tool: ClientToolCall, + db: &std::sync::Arc<atuin_client::database::Sqlite>, +) { + let h = handle.clone(); + let tx = tx.clone(); + let db = db.clone(); + + tokio::spawn(async move { + let outcome = tool.execute(&db).await; + h.update(move |state| { + state.finish_tool_call(&tool_id, outcome); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + }); +} + +/// Execute a shell tool with streaming VT100 preview. +fn execute_shell_tool( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + tool_id: &str, + shell_call: &crate::tools::ShellToolCall, +) { + let h = handle.clone(); + let tx = tx.clone(); + let shell_call = shell_call.clone(); + let command = shell_call.command.clone(); + let tc_id = tool_id.to_string(); + + // 1. Set up channels for streaming output and interruption + let (output_tx, mut output_rx) = tokio::sync::mpsc::channel::<Vec<String>>(32); + let (abort_tx, abort_rx) = tokio::sync::oneshot::channel::<()>(); + + // 2. Mark as executing with preview and store the abort sender on the tracker entry + let tc_id_setup = tc_id.clone(); + h.update(move |state| { + if let Some(tracked) = state.tool_tracker.get_mut(&tc_id_setup) { + tracked.mark_executing_preview(command); + tracked.abort_tx = Some(abort_tx); + } + }); + + // 3. Spawn a task to consume output updates and feed them to state + let h_output = h.clone(); + let preview_id = tc_id.clone(); + let output_task = tokio::spawn(async move { + while let Some(lines) = output_rx.recv().await { + let id = preview_id.clone(); + h_output.update(move |state| { + if let Some(tracked) = state.tool_tracker.get_mut(&id) + && let ToolPhase::ExecutingWithPreview { + ref mut output_lines, + .. + } = tracked.phase + { + *output_lines = lines; + } + }); + } + }); + + // 4. Spawn the streaming execution task + let tc_id_finish = tc_id; + tokio::spawn(async move { + let outcome = + crate::tools::execute_shell_command_streaming(&shell_call, output_tx, abort_rx).await; + + // Wait for the output task to finish so the final preview lines are captured + let _ = output_task.await; + + h.update(move |state| { + state.finish_tool_call(&tc_id_finish, outcome); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + }); +} + +// ─────────────────────────────────────────────────────────────────── +// Permission handlers +// ─────────────────────────────────────────────────────────────────── + +fn on_check_tool_permission( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + id: String, +) { + let h2 = handle.clone(); + let tx_for_task = tx.clone(); + let db = app_ctx.history_db.clone(); + + tokio::spawn(async move { + let id_for_error = id.clone(); + let result = check_tool_permission_inner(&h2, &tx_for_task, &db, id).await; + + // If the inner function didn't handle the tool (returned an error message), + // finish the tool call with that error so the conversation doesn't stall. + if let Err(error_msg) = result { + let tx = tx_for_task.clone(); + h2.update(move |state| { + state.finish_tool_call(&id_for_error, crate::tools::ToolOutcome::Error(error_msg)); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + } + }); +} + +/// Inner permission check that returns Err(message) if the tool call should be +/// finished with an error. Returns Ok(()) if the tool was handled (executed, +/// denied, or sent to the permission UI). +async fn check_tool_permission_inner( + h2: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + db: &std::sync::Arc<atuin_client::database::Sqlite>, + id: String, +) -> Result<(), String> { + // 1. Fetch the tracked tool's data + let id_for_fetch = id.clone(); + let (tool, target_dir) = h2 + .fetch(move |state| { + state + .tool_tracker + .get(&id_for_fetch) + .map(|t| (t.tool.clone(), t.target_dir().map(PathBuf::from))) + }) + .await + .map_err(|e| format!("Internal error fetching tool state: {e}"))? + .ok_or_else(|| "Internal error: tool not found in tracker".to_string())?; + + // 2. Resolve working directory + let working_dir = target_dir + .or_else(|| std::env::current_dir().ok()) + .ok_or_else(|| "Could not determine working directory".to_string())?; + + // 3. Create permission resolver and check + let resolver = PermissionResolver::new(working_dir) + .await + .map_err(|e| format!("Permission check failed: {e}"))?; + + let response = resolver + .check(&tool) + .await + .map_err(|e| format!("Permission check failed: {e}"))?; + + // 4. Handle response — all paths here handle the tool, so return Ok + let id_clone = id.clone(); + match response { + PermissionResponse::Allowed => { + execute_tool(h2, tx, id, tool, db); + } + PermissionResponse::Denied => { + let tx = tx.clone(); + h2.update(move |state| { + state.finish_tool_call( + &id_clone, + crate::tools::ToolOutcome::Error( + "Permission denied on the user's system".to_string(), + ), + ); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + } + PermissionResponse::Ask => { + h2.update(move |state| { + if let Some(tracked) = state.tool_tracker.get_mut(&id_clone) { + tracked.mark_asking(); + } + }); + } + } + + Ok(()) +} + +fn on_select_permission( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + permission: PermissionResult, +) { + let tx = tx.clone(); + let h2 = handle.clone(); + + match permission { + PermissionResult::Allow => { + // Fetch the tool that's asking for permission, then execute it + let db = app_ctx.history_db.clone(); + tokio::spawn(async move { + let Ok(Some((tool_id, tool))) = h2 + .fetch(move |state| { + state + .tool_tracker + .asking_for_permission() + .map(|t| (t.id.clone(), t.tool.clone())) + }) + .await + else { + return; + }; + + execute_tool(&h2, &tx, tool_id, tool, &db); + }); + } + PermissionResult::AlwaysAllowInDir => { + let db = app_ctx.history_db.clone(); + let git_root = app_ctx.git_root.clone(); + tokio::spawn(async move { + let Ok(Some((tool_id, tool))) = h2 + .fetch(move |state| { + state + .tool_tracker + .asking_for_permission() + .map(|t| (t.id.clone(), t.tool.clone())) + }) + .await + else { + return; + }; + + // Write the rule to the project (git root) or cwd permissions file + let project_root = git_root + .or_else(|| std::env::current_dir().ok()) + .unwrap_or_else(|| PathBuf::from(".")); + let file_path = writer::project_permissions_path(&project_root); + let rule = Rule { + tool: tool.rule_name().to_string(), + scope: None, + }; + if let Err(e) = writer::write_rule(&file_path, &rule, RuleDisposition::Allow).await + { + tracing::error!("Failed to write project permission rule: {e}"); + } + + execute_tool(&h2, &tx, tool_id, tool, &db); + }); + } + PermissionResult::AlwaysAllow => { + let db = app_ctx.history_db.clone(); + tokio::spawn(async move { + let Ok(Some((tool_id, tool))) = h2 + .fetch(move |state| { + state + .tool_tracker + .asking_for_permission() + .map(|t| (t.id.clone(), t.tool.clone())) + }) + .await + else { + return; + }; + + // Write the rule to the global permissions file + let file_path = writer::global_permissions_path(); + let rule = Rule { + tool: tool.rule_name().to_string(), + scope: None, + }; + if let Err(e) = writer::write_rule(&file_path, &rule, RuleDisposition::Allow).await + { + tracing::error!("Failed to write global permission rule: {e}"); + } + + execute_tool(&h2, &tx, tool_id, tool, &db); + }); + } + PermissionResult::Deny => { + h2.update(move |state| { + let Some(tracked) = state.tool_tracker.asking_for_permission() else { + return; + }; + let tool_id = tracked.id.clone(); + + state.finish_tool_call( + &tool_id, + crate::tools::ToolOutcome::Error("Permission denied by the user".to_string()), + ); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + } + } +} + +// ─────────────────────────────────────────────────────────────────── +// Other handlers +// ─────────────────────────────────────────────────────────────────── + +fn on_cancel_generation(handle: &Handle<Session>) { + handle.update(|state| match state.interaction.mode { + crate::tui::state::AppMode::Generating => { + state.cancel_generation(); + } + crate::tui::state::AppMode::Streaming => { + state.cancel_streaming(); + } + _ => {} + }); +} + +fn on_execute_command(handle: &Handle<Session>) { + let h2 = handle.clone(); + handle.update(move |state| { + let cmd = state.conversation.current_command().map(|c| c.to_string()); + if let Some(cmd) = cmd { + if state.conversation.is_current_command_dangerous() + && !state.interaction.confirmation_pending + { + state.interaction.confirmation_pending = true; + } else { + state.interaction.confirmation_pending = false; + state.exit_action = Some(ExitAction::Execute(cmd)); + h2.exit(); + } + } + }); +} + +fn on_cancel_confirmation(handle: &Handle<Session>) { + handle.update(move |state| { + state.interaction.confirmation_pending = false; + }); +} + +fn on_insert_command(handle: &Handle<Session>) { + let h2 = handle.clone(); + handle.update(move |state| { + let cmd = state.conversation.current_command().map(|c| c.to_string()); + if let Some(cmd) = cmd { + state.interaction.confirmation_pending = false; + state.exit_action = Some(ExitAction::Insert(cmd)); + h2.exit(); + } + }); +} + +fn on_retry( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + app_ctx: &AppContext, + client_ctx: &ClientContext, +) { + launch_stream(handle, tx, app_ctx, client_ctx, |state| { + state.retry(); + }); +} + +fn on_exit(handle: &Handle<Session>) { + let h2 = handle.clone(); + handle.update(move |state| { + if let Some(abort) = state.stream_abort.take() { + abort.abort(); + } + state.exit_action = Some(ExitAction::Cancel); + h2.exit(); + }); +} + +fn on_interrupt_tool_execution(handle: &Handle<Session>) { + handle.update(move |state| { + // Find executing previews, send interrupt, and mark as interrupted + for tracked in state.tool_tracker.iter_mut() { + if let ToolPhase::ExecutingWithPreview { + ref mut interrupted, + ref mut exit_code, + .. + } = tracked.phase + { + *interrupted = true; + if exit_code.is_none() { + *exit_code = Some(-1); + } + // Send interrupt signal via the tracker entry's abort channel + if let Some(abort_tx) = tracked.abort_tx.take() { + let _ = abort_tx.send(()); + } + } + } + + // The spawned execution task will handle finalizing and sending + // ContinueAfterTools when the process exits. Input mode is already active. + }); +} diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs index a791bb80..1a422fef 100644 --- a/crates/atuin-ai/src/tui/events.rs +++ b/crates/atuin-ai/src/tui/events.rs @@ -5,13 +5,20 @@ /// eye-declare's context system. The main event loop in `inline.rs` /// receives them and mutates `AppState` accordingly. #[derive(Debug)] -pub enum AiTuiEvent { +pub(crate) enum AiTuiEvent { /// User updated the input text InputUpdated(String), /// User submitted text input (Enter in Input mode) SubmitInput(String), /// User entered a slash command (e.g. "/help") + #[allow(unused)] SlashCommand(String), + /// Check the permission for a tool call + CheckToolCallPermission(String), + /// User selected a permission + SelectPermission(PermissionResult), + /// Continue after client tools have completed + ContinueAfterTools, /// Cancel active generation or streaming (Esc during Generating/Streaming) CancelGeneration, /// Execute the suggested command @@ -20,8 +27,18 @@ pub enum AiTuiEvent { InsertCommand, /// Cancel confirmation of dangerous command CancelConfirmation, + /// Interrupt a running tool execution (Ctrl+C during ExecutingPreview) + InterruptToolExecution, /// Retry after error Retry, /// Exit the application Exit, } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum PermissionResult { + Allow, + AlwaysAllowInDir, + AlwaysAllow, + Deny, +} diff --git a/crates/atuin-ai/src/tui/mod.rs b/crates/atuin-ai/src/tui/mod.rs index acb251a7..afd63312 100644 --- a/crates/atuin-ai/src/tui/mod.rs +++ b/crates/atuin-ai/src/tui/mod.rs @@ -1,6 +1,7 @@ -pub mod components; -pub mod events; -pub mod state; -pub mod view; +pub(crate) mod components; +pub(crate) mod dispatch; +pub(crate) mod events; +pub(crate) mod state; +pub(crate) mod view; -pub use state::{AppMode, AppState, ConversationEvent, ExitAction}; +pub(crate) use state::{ConversationEvent, Session}; diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index 4c5c2a1e..69b35909 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -5,9 +5,11 @@ use tokio::task::AbortHandle; +use crate::tools::{ClientToolCall, ToolOutcome, ToolTracker}; + /// Streaming status indicators from server #[derive(Debug, Clone, PartialEq, Eq)] -pub enum StreamingStatus { +pub(crate) enum StreamingStatus { Processing, Searching, Thinking, @@ -15,7 +17,7 @@ pub enum StreamingStatus { } impl StreamingStatus { - pub fn from_status_str(s: &str) -> Self { + pub(crate) fn from_status_str(s: &str) -> Self { match s { "processing" => Self::Processing, "searching" => Self::Searching, @@ -23,20 +25,11 @@ impl StreamingStatus { _ => Self::Thinking, } } - - pub fn display_text(&self) -> &'static str { - match self { - Self::Processing => "Processing...", - Self::Searching => "Searching...", - Self::Thinking => "Thinking...", - Self::WaitingForTools => "Waiting for tools...", - } - } } /// Conversation event types matching the API protocol #[derive(Debug, Clone)] -pub enum ConversationEvent { +pub(crate) enum ConversationEvent { /// User message (what the user typed) UserMessage { content: String }, /// Text content from assistant (streamed or complete) @@ -62,48 +55,8 @@ pub enum ConversationEvent { } impl ConversationEvent { - /// Convert to JSON for API calls - pub fn to_json(&self) -> serde_json::Value { - match self { - ConversationEvent::UserMessage { content } => serde_json::json!({ - "type": "user_message", - "content": content - }), - ConversationEvent::Text { content } => serde_json::json!({ - "type": "text", - "content": content - }), - ConversationEvent::ToolCall { id, name, input } => serde_json::json!({ - "type": "tool_call", - "id": id, - "name": name, - "input": input - }), - ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - } => serde_json::json!({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": content, - "is_error": is_error - }), - ConversationEvent::OutOfBandOutput { - name, - command, - content, - } => serde_json::json!({ - "type": "out_of_band_output", - "name": name, - "command": command, - "content": content - }), - } - } - /// Extract command from a suggest_command tool call - pub fn as_command(&self) -> Option<&str> { + pub(crate) fn as_command(&self) -> Option<&str> { if let ConversationEvent::ToolCall { name, input, .. } = self && name == "suggest_command" { @@ -113,8 +66,9 @@ impl ConversationEvent { } } +/// Application mode for key handling and footer text. #[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum AppMode { +pub(crate) enum AppMode { /// User is typing input Input, /// Waiting for generation (showing spinner) @@ -126,7 +80,7 @@ pub enum AppMode { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum ExitAction { +pub(crate) enum ExitAction { /// Run the command Execute(String), /// Insert command without running @@ -135,47 +89,20 @@ pub enum ExitAction { Cancel, } -/// Application state — the domain model -/// -/// Conversation is stored as a sequence of events matching the API protocol. -/// The view function derives the UI from this state. +/// Owned event log and session ID #[derive(Debug)] -pub struct AppState { - /// Current application mode - pub mode: AppMode, +pub(crate) struct Conversation { /// Conversation events (source of truth, matches API protocol) pub events: Vec<ConversationEvent>, - /// Current error message - pub error: Option<String>, - /// Exit action (set when exiting) - pub exit_action: Option<ExitAction>, /// Session ID from server pub session_id: Option<String>, - /// Current streaming status - pub streaming_status: Option<StreamingStatus>, - /// Whether the input is blank - pub is_input_blank: bool, - /// Whether current turn was interrupted by user - pub was_interrupted: bool, - /// True when user has pressed Enter once on a dangerous command - pub confirmation_pending: bool, - /// Abort handle for the active streaming task, if any - pub stream_abort: Option<AbortHandle>, } -impl AppState { +impl Conversation { pub fn new() -> Self { Self { - mode: AppMode::Input, events: Vec::new(), - error: None, - exit_action: None, session_id: None, - streaming_status: None, - is_input_blank: false, - was_interrupted: false, - confirmation_pending: false, - stream_abort: None, } } @@ -195,16 +122,57 @@ impl AppState { i += 1; } ConversationEvent::Text { content } => { - messages.push(serde_json::json!({ - "role": "assistant", - "content": content - })); - i += 1; + // Check if the next event(s) are ToolCalls — if so, combine + // into a single assistant message with mixed content blocks. + let next_is_tool_call = events + .get(i + 1) + .is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. })); + + if next_is_tool_call { + let mut content_blocks = Vec::new(); + + if !content.is_empty() { + content_blocks.push(serde_json::json!({ + "type": "text", + "text": content + })); + } + + while let Some(ConversationEvent::ToolCall { + id, name, input, .. + }) = events.get(i + 1) + { + content_blocks.push(serde_json::json!({ + "type": "tool_use", + "id": id, + "name": name, + "input": input + })); + i += 1; + } + + messages.push(serde_json::json!({ + "role": "assistant", + "content": content_blocks + })); + i += 1; + } else { + messages.push(serde_json::json!({ + "role": "assistant", + "content": content + })); + i += 1; + } } ConversationEvent::ToolCall { .. } => { + // ToolCalls without preceding Text (shouldn't normally happen, + // but handle defensively) let mut tool_uses = Vec::new(); while i < events.len() { - if let ConversationEvent::ToolCall { id, name, input } = &events[i] { + if let ConversationEvent::ToolCall { + id, name, input, .. + } = &events[i] + { tool_uses.push(serde_json::json!({ "type": "tool_use", "id": id, @@ -247,53 +215,42 @@ impl AppState { messages } - // ===== Generation lifecycle methods ===== - - /// Start generating from submitted input - pub fn start_generating(&mut self, input: String) { - self.events - .push(ConversationEvent::UserMessage { content: input }); - self.mode = AppMode::Generating; - } - - /// Generation error occurred - pub fn generation_error(&mut self, error: String) { - self.error = Some(error); - self.mode = AppMode::Error; - } - - /// Cancel during generation - pub fn cancel_generation(&mut self) { - if let Some(abort) = self.stream_abort.take() { - abort.abort(); - } - if let Some(ConversationEvent::UserMessage { .. }) = self.events.last() { - self.events.pop(); - } - self.mode = AppMode::Input; - } - - // ===== Streaming lifecycle methods ===== - - /// Start streaming response. - /// Pushes an empty Text event that will be mutated in-place as chunks arrive. - pub fn start_streaming(&mut self) { - self.events.push(ConversationEvent::Text { - content: String::new(), - }); - self.streaming_status = None; - self.was_interrupted = false; - self.mode = AppMode::Streaming; + /// Get the most recent command from events + pub fn current_command(&self) -> Option<&str> { + self.events.iter().rev().find_map(|e| e.as_command()) } - /// Store session ID from server response - pub fn store_session_id(&mut self, session_id: String) { - self.session_id = Some(session_id); + /// Check if any turn in the conversation has a command + pub fn has_any_command(&self) -> bool { + self.events.iter().any(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e { + name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() + } else { + false + } + }) } - /// Update streaming status from SSE event - pub fn update_streaming_status(&mut self, status: &str) { - self.streaming_status = Some(StreamingStatus::from_status_str(status)); + /// Check if the most recent command is marked dangerous + pub fn is_current_command_dangerous(&self) -> bool { + self.events + .iter() + .rev() + .find_map(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e + && name == "suggest_command" + { + let danger_level = input + .get("danger") + .and_then(|v| v.as_str()) + .unwrap_or("low"); + return Some( + danger_level == "high" || danger_level == "medium" || danger_level == "med", + ); + } + None + }) + .unwrap_or(false) } /// Get a mutable reference to the last Text event's content (the streaming buffer). @@ -307,28 +264,15 @@ impl AppState { }) } - /// Cancel streaming with context preservation - pub fn cancel_streaming(&mut self) { - if let Some(abort) = self.stream_abort.take() { - abort.abort(); - } - self.was_interrupted = true; - - if let Some(content) = self.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - if trimmed.is_empty() { - // Remove the empty text event - *content = String::new(); + /// Remove trailing empty Text events from the events list + fn remove_empty_trailing_text(&mut self) { + while let Some(ConversationEvent::Text { content }) = self.events.last() { + if content.is_empty() { + self.events.pop(); } else { - *content = format!("{trimmed}\n\n[User cancelled this generation]"); + break; } } - // Remove trailing empty Text events - self.remove_empty_trailing_text(); - - self.streaming_status = None; - self.confirmation_pending = false; - self.mode = AppMode::Input; } /// Append text chunk during streaming (mutates the last Text event in-place) @@ -354,26 +298,6 @@ impl AppState { } } - /// Add a tool call event during streaming. - /// The current streaming text is already in events, so we just push the tool call. - pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) { - // Trim the streaming text event - if let Some(content) = self.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - *content = trimmed; - } - self.remove_empty_trailing_text(); - - let is_suggest_command = name == "suggest_command"; - self.events - .push(ConversationEvent::ToolCall { id, name, input }); - - if is_suggest_command { - self.streaming_status = None; - self.mode = AppMode::Input; - } - } - /// Add a tool result event during streaming pub fn add_tool_result(&mut self, tool_use_id: String, content: String, is_error: bool) { self.events.push(ConversationEvent::ToolResult { @@ -383,47 +307,9 @@ impl AppState { }); } - /// Finalize streaming — trim the accumulated text and change mode - pub fn finalize_streaming(&mut self) { - if let Some(content) = self.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - *content = trimmed; - } - self.remove_empty_trailing_text(); - self.streaming_status = None; - self.mode = AppMode::Input; - } - - /// Streaming error — remove the partial text event - pub fn streaming_error(&mut self, error: String) { - self.remove_empty_trailing_text(); - self.error = Some(error); - self.mode = AppMode::Error; - } - - /// Remove trailing empty Text events from the events list - fn remove_empty_trailing_text(&mut self) { - while let Some(ConversationEvent::Text { content }) = self.events.last() { - if content.is_empty() { - self.events.pop(); - } else { - break; - } - } - } - - // ===== Edit mode and exit methods ===== - - /// Start edit mode for refinement - pub fn start_edit_mode(&mut self) { - self.confirmation_pending = false; - self.mode = AppMode::Input; - } - - /// Retry after error - pub fn retry(&mut self) { - self.error = None; - self.mode = AppMode::Generating; + /// Store session ID from server response + pub fn store_session_id(&mut self, session_id: String) { + self.session_id = Some(session_id); } /// Handle a slash command @@ -445,85 +331,247 @@ impl AppState { }), } } +} - // ===== Query methods ===== +/// Ephemeral UI/presentation state +#[derive(Debug)] +pub(crate) struct Interaction { + /// Current application mode + pub mode: AppMode, + /// Whether the input is blank + pub is_input_blank: bool, + /// True when user has pressed Enter once on a dangerous command + pub confirmation_pending: bool, + /// Current streaming status + pub streaming_status: Option<StreamingStatus>, + /// Whether current turn was interrupted by user + pub was_interrupted: bool, + /// Current error message + pub error: Option<String>, +} - /// Get the most recent command from events - pub fn current_command(&self) -> Option<&str> { - self.events.iter().rev().find_map(|e| e.as_command()) +impl Interaction { + pub fn new() -> Self { + Self { + mode: AppMode::Input, + is_input_blank: false, + confirmation_pending: false, + streaming_status: None, + was_interrupted: false, + error: None, + } } +} - /// Check if the most recent command is marked dangerous - pub fn is_current_command_dangerous(&self) -> bool { - self.events - .iter() - .rev() - .find_map(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e - && name == "suggest_command" - { - let danger_level = input - .get("danger") - .and_then(|v| v.as_str()) - .unwrap_or("low"); - return Some( - danger_level == "high" || danger_level == "medium" || danger_level == "med", - ); - } - None - }) - .unwrap_or(false) +/// Top-level session state +/// +/// Decomposed into `Conversation` (event log + session ID) and +/// `Interaction` (ephemeral UI state). Session methods that cross +/// both sub-structs live here. +#[derive(Debug)] +pub(crate) struct Session { + pub conversation: Conversation, + pub interaction: Interaction, + /// Tracks all tool calls through their full lifecycle. + pub tool_tracker: ToolTracker, + /// Whether the session is running inside a git project (for permission UI labels). + pub in_git_project: bool, + /// Exit action (set when exiting) + pub exit_action: Option<ExitAction>, + /// Abort handle for the active streaming task, if any + pub stream_abort: Option<AbortHandle>, +} + +impl Session { + pub fn new(in_git_project: bool) -> Self { + Self { + conversation: Conversation::new(), + interaction: Interaction::new(), + tool_tracker: ToolTracker::new(), + in_git_project, + exit_action: None, + stream_abort: None, + } } - /// Count non-suggest_command tool calls since the last user message - pub fn tool_count_since_last_user(&self) -> usize { - let last_user_idx = self + // ===== Generation lifecycle methods ===== + + /// Start generating from submitted input + pub fn start_generating(&mut self, input: String) { + self.conversation .events - .iter() - .rposition(|e| matches!(e, ConversationEvent::UserMessage { .. })) - .unwrap_or(0); + .push(ConversationEvent::UserMessage { content: input }); + self.interaction.mode = AppMode::Generating; + } - let mut completed = 0; - let mut in_flight = false; + /// Generation error occurred + #[expect(dead_code)] + pub fn generation_error(&mut self, error: String) { + self.interaction.error = Some(error); + self.interaction.mode = AppMode::Error; + } - for event in &self.events[last_user_idx..] { - match event { - ConversationEvent::ToolCall { name, .. } if name != "suggest_command" => { - if in_flight { - completed += 1; - } - in_flight = true; - } - ConversationEvent::ToolResult { .. } => { - if in_flight { - completed += 1; - in_flight = false; - } - } - _ => {} - } + /// Cancel during generation + pub fn cancel_generation(&mut self) { + if let Some(abort) = self.stream_abort.take() { + abort.abort(); + } + if let Some(ConversationEvent::UserMessage { .. }) = self.conversation.events.last() { + self.conversation.events.pop(); } + self.interaction.mode = AppMode::Input; + } - completed + // ===== Streaming lifecycle methods ===== + + /// Start streaming response. + /// Pushes an empty Text event that will be mutated in-place as chunks arrive. + pub fn start_streaming(&mut self) { + self.conversation.events.push(ConversationEvent::Text { + content: String::new(), + }); + self.interaction.streaming_status = None; + self.interaction.was_interrupted = false; + self.interaction.mode = AppMode::Streaming; } - /// Check if any turn in the conversation has a command - pub fn has_any_command(&self) -> bool { - self.events.iter().any(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e { - name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() + /// Update streaming status from SSE event + pub fn update_streaming_status(&mut self, status: &str) { + self.interaction.streaming_status = Some(StreamingStatus::from_status_str(status)); + } + + /// Cancel streaming with context preservation + pub fn cancel_streaming(&mut self) { + if let Some(abort) = self.stream_abort.take() { + abort.abort(); + } + self.interaction.was_interrupted = true; + + if let Some(content) = self.conversation.streaming_content_mut() { + let trimmed = content.trim_start().to_string(); + if trimmed.is_empty() { + // Remove the empty text event + *content = String::new(); } else { - false + *content = format!("{trimmed}\n\n[User cancelled this generation]"); } - }) + } + // Remove trailing empty Text events + self.conversation.remove_empty_trailing_text(); + + self.interaction.streaming_status = None; + self.interaction.confirmation_pending = false; + self.interaction.mode = AppMode::Input; + } + + /// Add a tool call event during streaming. + /// The current streaming text is already in events, so we just push the tool call. + pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) { + // Trim the streaming text event + if let Some(content) = self.conversation.streaming_content_mut() { + let trimmed = content.trim_start().to_string(); + *content = trimmed; + } + self.conversation.remove_empty_trailing_text(); + + let is_suggest_command = name == "suggest_command"; + self.conversation + .events + .push(ConversationEvent::ToolCall { id, name, input }); + + if is_suggest_command { + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; + } + } + + /// Finalize streaming — trim the accumulated text and change mode + pub fn finalize_streaming(&mut self) { + if let Some(content) = self.conversation.streaming_content_mut() { + let trimmed = content.trim_start().to_string(); + *content = trimmed; + } + self.conversation.remove_empty_trailing_text(); + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; + } + + /// Streaming error — remove the partial text event + pub fn streaming_error(&mut self, error: String) { + self.conversation.remove_empty_trailing_text(); + self.interaction.error = Some(error); + self.interaction.mode = AppMode::Error; + } + + pub(crate) fn handle_client_tool_call( + &mut self, + id: String, + tool: ClientToolCall, + input: serde_json::Value, + ) { + let desc = tool.descriptor(); + let name = desc.canonical_names[0].to_string(); + + self.tool_tracker.insert(id.clone(), tool); + + // Add the ToolCall event to the conversation immediately so it appears + // in the view. Preview data is sourced from tool_tracker. + self.conversation + .events + .push(ConversationEvent::ToolCall { id, name, input }); + + // Client tool calls can only happen at the last part of a turn + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; + } + + /// Retry after error + pub fn retry(&mut self) { + self.interaction.error = None; + self.interaction.mode = AppMode::Generating; + } + + // ===== Tool lifecycle methods ===== + + /// Finish a tool call: transition tracker to Completed, push ToolResult to conversation. + /// + /// For shell commands, captures the final preview from the ExecutingWithPreview phase + /// and patches exit_code/interrupted from the authoritative ToolOutcome. + pub fn finish_tool_call(&mut self, tool_id: &str, outcome: ToolOutcome) { + let mut preview = self.tool_tracker.get(tool_id).and_then(|t| t.preview()); + + // Patch preview with authoritative outcome data (handles race where + // final VT100 update hasn't been applied yet). + if let Some(ref mut p) = preview + && let ToolOutcome::Structured { + exit_code, + interrupted, + .. + } = &outcome + { + p.interrupted = *interrupted; + if p.exit_code.is_none() { + p.exit_code = *exit_code; + } + } + + // Transition tracker entry to Completed + if let Some(tracked) = self.tool_tracker.get_mut(tool_id) { + tracked.complete(preview); + } + + let content = outcome.format_for_llm(); + let is_error = outcome.is_error(); + self.conversation + .add_tool_result(tool_id.to_string(), content, is_error); } /// Get the footer text for current mode pub fn footer_text(&self) -> &'static str { - match self.mode { + match self.interaction.mode { AppMode::Input => { - if self.has_any_command() && self.is_input_blank { - if self.confirmation_pending { + if self.conversation.has_any_command() && self.interaction.is_input_blank { + if self.interaction.confirmation_pending { "[Enter] Confirm dangerous command [Esc] Cancel" } else { "[Enter] Execute suggested command [Tab] Insert Command" @@ -542,9 +590,3 @@ impl AppState { self.exit_action.is_some() } } - -impl Default for AppState { - fn default() -> Self { - Self::new() - } -} diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 0cd51dfa..ee5483d8 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -1,14 +1,20 @@ //! View function that builds the eye-declare element tree from app state. use eye_declare::{ - Cells, Column, Elements, HStack, Span, Spinner, Text, View, WidthConstraint, element, + BorderType, Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport, + WidthConstraint, element, }; use ratatui_core::style::{Color, Modifier, Style}; +use crate::tools::{ClientToolCall, TrackedTool}; +use crate::tui::components::select::SelectOption; +use crate::tui::events::{AiTuiEvent, PermissionResult}; + use super::components::atuin_ai::AtuinAi; use super::components::input_box::InputBox; use super::components::markdown::Markdown; -use super::state::{AppMode, AppState}; +use super::components::select::Select; +use super::state::{AppMode, Session}; mod turn; @@ -20,23 +26,25 @@ mod turn; /// - Error display (if in error state) /// - Spacer /// - Input box (bordered, with contextual keybindings) -pub fn ai_view(state: &AppState) -> Elements { - let mut turn_builder = turn::TurnBuilder::new(); +pub(crate) fn ai_view(state: &Session) -> Elements { + let mut turn_builder = turn::TurnBuilder::new(&state.tool_tracker); - for event in &state.events { + for event in &state.conversation.events { turn_builder.add_event(event); } let turns = turn_builder.build(); - let busy = state.mode == AppMode::Streaming || state.mode == AppMode::Generating; + let busy = state.interaction.mode == AppMode::Streaming + || state.interaction.mode == AppMode::Generating; let last_index = turns.len().saturating_sub(1); element! { AtuinAi( - mode: state.mode, - has_command: state.has_any_command(), - is_input_blank: state.is_input_blank, - pending_confirmation: state.confirmation_pending, + mode: state.interaction.mode, + has_command: state.conversation.has_any_command(), + is_input_blank: state.interaction.is_input_blank, + pending_confirmation: state.interaction.confirmation_pending, + has_executing_preview: state.tool_tracker.has_executing_preview(), ) { #(for (index, turn) in turns.iter().enumerate() { #(match turn { @@ -53,25 +61,94 @@ pub fn ai_view(state: &AppState) -> Elements { }) #(if !state.is_exiting() { - View(key: "input-box", padding_top: Cells::from(1)) { - InputBox( - key: "input", - title: "Generate a command or ask a question", - title_right: "Atuin AI", - footer: state.footer_text(), - active: state.mode == AppMode::Input && !state.confirmation_pending, - ) + #(input_view(state)) + }) + } + } +} - #(if state.is_input_blank && state.has_any_command() && state.mode == AppMode::Input { - #(if state.confirmation_pending { - Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) } - } else { - Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) } - }) +fn input_view(state: &Session) -> Elements { + let asking_tool = state.tool_tracker.asking_for_permission(); + let in_git_project = state.in_git_project; + + element! { + #(if let Some(tc) = asking_tool { + #(tool_call_view(tc, in_git_project)) + }) + + #(if asking_tool.is_none() { + View(key: "input-box", padding_top: Cells::from(1)) { + InputBox( + key: "input", + title: "Generate a command or ask a question", + title_right: "Atuin AI", + footer: state.footer_text(), + active: state.interaction.mode == AppMode::Input && !state.interaction.confirmation_pending, + ) + + #(if state.interaction.is_input_blank && state.conversation.has_any_command() && state.interaction.mode == AppMode::Input { + #(if state.interaction.confirmation_pending { + Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) } + } else { + Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) } }) + }) + } + }) + } +} - } - }) +fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements { + let verb = tool_call.tool.descriptor().display_verb; + let tool_desc = match &tool_call.tool { + ClientToolCall::Read(tool) => tool.path.display().to_string(), + ClientToolCall::Write(tool) => tool.path.display().to_string(), + ClientToolCall::Shell(tool) => tool.command.clone(), + ClientToolCall::AtuinHistory(tool) => tool.query.clone(), + }; + + let dir_label = if in_git_project { + "Always allow in this workspace" + } else { + "Always allow in this directory" + }; + + element! { + View(key: format!("tool-call-{}", tool_call.id), padding_left: Cells::from(2), padding_top: Cells::from(1)) { + Text { + Span(text: format!("Atuin AI would like to {}: ", verb), style: Style::default()) + Span(text: &tool_desc, style: Style::default().fg(Color::Yellow)) + } + View(padding_left: Cells::from(2)) { + Select(options: [ + SelectOption::builder() + .label("Allow") + .value("allow") + .build(), + SelectOption::builder() + .label(dir_label) + .value("always-allow-in-dir") + .build(), + SelectOption::builder() + .label("Always allow") + .value("always-allow") + .build(), + SelectOption::builder() + .label("Deny") + .value("deny") + .build(), + ], on_select: Box::new(move |option: &SelectOption| { + let value = match option.value.as_str() { + "allow" => PermissionResult::Allow, + "always-allow-in-dir" => PermissionResult::AlwaysAllowInDir, + "always-allow" => PermissionResult::AlwaysAllow, + "deny" => PermissionResult::Deny, + _ => unreachable!(), + }; + + Some(AiTuiEvent::SelectPermission(value)) + }) as Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync>) + } } } } @@ -86,7 +163,7 @@ fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements { element! { View(padding_top: Cells::from(padding)) { Text { - Span(text: "You", style: label_style) + Span(text: " You ", style: label_style.reversed()) } #(for event in events { #(match event { @@ -114,9 +191,9 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { element! { View { Spinner( - label: "Atuin AI", - label_style: label_style, - done_label_style: label_style, + label: " Atuin AI ", + label_style: label_style.reversed(), + done_label_style: label_style.reversed(), hide_checkmark: true, label_first: true, done: !busy, @@ -136,6 +213,52 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { turn::UiEvent::SuggestedCommand(details) => { suggested_command_view(details) }, + turn::UiEvent::ToolCall(details) => { + let preview_done = details.preview.as_ref().is_some_and(|p| p.exit_code.is_some() || p.interrupted); + let tool_key = details.tool_use_id.clone(); + + element! { + View(key: format!("tool-output-{tool_key}"), padding_left: Cells::from(2)) { + #(if let Some(ref preview) = details.preview { + View(key: format!("preview-{tool_key}")) { + #(preview_spinner_view(&details.name, preview_done)) + Viewport( + key: format!("viewport-{tool_key}"), + lines: preview.lines.clone(), + height: 10, + border: BorderType::Plain, + border_style: Style::default().fg(Color::DarkGray), + style: Style::default().fg(Color::White), + wrap: false, + ) + #(if let Some(code) = preview.exit_code { + #(if code == 0 { + Text { + Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Green)) + } + } else { + Text { + Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Red)) + } + }) + }) + #(if preview.interrupted { + Text { + Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) + } + }) + #(if !preview_done { + Text { + Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray)) + } + }) + } + } else { + #(tool_status_view(&details.name, &details.status)) + }) + } + } + } _ => element!{} }) }) @@ -180,6 +303,48 @@ fn tool_summary_view(summary: &turn::ToolSummary) -> Elements { } } +/// Render a status indicator for a non-preview tool call (e.g. atuin_history, read_file). +fn tool_status_view(name: &str, status: &turn::ToolResultStatus) -> Elements { + match status { + turn::ToolResultStatus::Pending => { + element! { + Spinner( + label: format!("Running: {name}"), + label_style: Style::default().fg(Color::Yellow), + done: false, + ) + } + } + turn::ToolResultStatus::Success => { + element! { + Spinner( + label: format!("Ran: {name}"), + done: true, + ) + } + } + turn::ToolResultStatus::Error => { + element! { + Text { + Span(text: "✗ ", style: Style::default().fg(Color::Red)) + Span(text: format!("{name}: denied"), style: Style::default().fg(Color::Red)) + } + } + } + } +} + +/// Render a spinner/status line for a command preview (shell tools). +fn preview_spinner_view(name: &str, done: bool) -> Elements { + element! { + Spinner( + label: if done { format!("Ran: {name}") } else { format!("Running: {name}") }, + label_style: Style::default().fg(Color::Yellow), + done: done, + ) + } +} + fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements { let is_dangerous = matches!( details.danger_level, diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index 861da64c..6949236c 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -1,5 +1,8 @@ +use crate::tools::descriptor; +use crate::tools::{ToolPreview, ToolTracker}; use crate::tui::ConversationEvent; +/// Server-sent danger level for a suggested command #[derive(Debug)] pub(crate) enum DangerLevel { Low(Option<String>), @@ -37,6 +40,7 @@ impl From<(&String, &String)> for DangerLevel { } } +/// Server-sent confidence level for a suggested command #[derive(Debug)] pub(crate) enum ConfidenceLevel { Low(Option<String>), @@ -85,9 +89,11 @@ pub(crate) enum UiEvent { #[derive(Debug)] pub(crate) struct ToolCallDetails { - tool_use_id: String, - name: String, - status: ToolResultStatus, + pub(crate) tool_use_id: String, + pub(crate) name: String, + pub(crate) status: ToolResultStatus, + pub(crate) is_client: bool, + pub(crate) preview: Option<ToolPreview>, } #[derive(Debug)] @@ -118,16 +124,19 @@ pub(crate) enum UiTurn { OutOfBand { events: Vec<UiEvent> }, } -pub(crate) struct TurnBuilder { +pub(crate) struct TurnBuilder<'a> { turns: Vec<UiTurn>, current_turn: Option<UiTurn>, + tracker: &'a ToolTracker, } -impl TurnBuilder { - pub(crate) fn new() -> Self { +/// A struct to iteratively build [UiTurn] events from [ConversationEvent]s. +impl<'a> TurnBuilder<'a> { + pub(crate) fn new(tracker: &'a ToolTracker) -> Self { Self { turns: Vec::new(), current_turn: None, + tracker, } } @@ -174,7 +183,7 @@ impl TurnBuilder { for event in events.drain(..) { match event { - UiEvent::ToolCall(details) => { + UiEvent::ToolCall(details) if !details.is_client => { pending_tools.push(details); } other => { @@ -306,12 +315,17 @@ impl TurnBuilder { } fn add_tool_call(&mut self, id: &str, name: &str, _input: &serde_json::Value) { + let is_client = descriptor::by_name(name).is_some_and(|d| d.is_client); + let preview = self.tracker.preview_for(id); + self.start_agent_turn(); if let UiTurn::Agent { events } = self.turn_mut_unsafe() { events.push(UiEvent::ToolCall(ToolCallDetails { tool_use_id: id.to_string(), name: name.to_string(), status: ToolResultStatus::Pending, + is_client, + preview, })); } } @@ -385,25 +399,15 @@ impl ToolSummary { /// Present-tense progressive verb for a tool name (e.g. "Searching...") fn progressive_verb(name: &str) -> String { - match name { - "search" => "Searching...".into(), - "read" | "read_file" => "Reading file...".into(), - "write" | "write_file" => "Writing file...".into(), - "execute" | "run" | "bash" => "Running command...".into(), - "list" | "list_files" => "Listing files...".into(), - _ => format!("Running {}...", name.replace('_', " ")), - } + descriptor::by_name(name) + .map(|d| d.progressive_verb.to_string()) + .unwrap_or_else(|| format!("Running {}...", name.replace('_', " "))) } /// Past-tense verb for a tool name (e.g. "Searched") fn past_verb(name: &str) -> String { - match name { - "search" => "Searched".into(), - "read" | "read_file" => "Read file".into(), - "write" | "write_file" => "Wrote file".into(), - "execute" | "run" | "bash" => "Ran command".into(), - "list" | "list_files" => "Listed files".into(), - _ => format!("Ran {}", name.replace('_', " ")), - } + descriptor::by_name(name) + .map(|d| d.past_verb.to_string()) + .unwrap_or_else(|| format!("Ran {}", name.replace('_', " "))) } } diff --git a/crates/atuin-client/Cargo.toml b/crates/atuin-client/Cargo.toml index 763f9d4e..2860c82b 100644 --- a/crates/atuin-client/Cargo.toml +++ b/crates/atuin-client/Cargo.toml @@ -41,7 +41,7 @@ rand = { workspace = true } shellexpand = "3" sqlx = { workspace = true, features = ["sqlite", "regexp"] } minspan = "0.1.5" -regex = "1.10.5" +regex = { workspace = true } serde_regex = "1.1.0" fs-err = { workspace = true } sql-builder = { workspace = true } diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs index e2624136..25c3bd65 100644 --- a/crates/atuin-client/src/settings.rs +++ b/crates/atuin-client/src/settings.rs @@ -671,6 +671,16 @@ pub struct Ai { /// Configuration for what context is sent in the opening AI request. #[serde(default)] pub opening: AiOpening, + + /// Tool capability flags. + #[serde(default)] + pub capabilities: AiCapabilities, +} + +#[derive(Default, Clone, Debug, Deserialize, Serialize)] +pub struct AiCapabilities { + /// Whether the AI can request to search Atuin history. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_history_search: Option<bool>, } #[derive(Default, Clone, Debug, Deserialize, Serialize)] diff --git a/crates/atuin-hex/Cargo.toml b/crates/atuin-hex/Cargo.toml index 8a574a55..9625f67d 100644 --- a/crates/atuin-hex/Cargo.toml +++ b/crates/atuin-hex/Cargo.toml @@ -18,4 +18,4 @@ crossterm = { workspace = true } eyre = { workspace = true } portable-pty = "0.8" signal-hook = "0.3" -vt100 = "0.15" +vt100 = { workspace = true } diff --git a/crates/atuin-hex/src/lib.rs b/crates/atuin-hex/src/lib.rs index 0bfe54d8..f449a386 100644 --- a/crates/atuin-hex/src/lib.rs +++ b/crates/atuin-hex/src/lib.rs @@ -222,7 +222,7 @@ mod app { fn handle_parser_msg(parser: &mut vt100::Parser, msg: ParserMsg) { match msg { ParserMsg::Data(data) => parser.process(&data), - ParserMsg::Resize { rows, cols } => parser.set_size(rows, cols), + ParserMsg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols), ParserMsg::ScreenRequest(reply_tx) => { let _ = reply_tx.send(encode_screen(parser)); } diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml index fedb1831..c83b16af 100644 --- a/crates/atuin/Cargo.toml +++ b/crates/atuin/Cargo.toml @@ -92,7 +92,7 @@ tempfile = { workspace = true } shlex = "1.3.0" # settings editor with comment and relative ordering preservation -toml_edit = "0.25.4" +toml_edit = { workspace = true } [target.'cfg(any(target_os = "windows", target_os = "macos"))'.dependencies] arboard = { version = "3.4", optional = true, default-features = false } @@ -105,6 +105,12 @@ arboard = { version = "3.4", optional = true, default-features = false, features [target.'cfg(unix)'.dependencies] daemonize = "0.5.0" +# Enable tree-sitter shell parsing on platforms where tree-sitter's bundled C +# compiles cleanly. tree-sitter 0.26's portable/endian.h fails on illumos, +# Windows cross-compiles, and potentially other exotic targets. +[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dependencies] +atuin-ai = { path = "../atuin-ai", version = "18.13.6", optional = true, default-features = false, features = ["tree-sitter"] } + [target.'cfg(windows)'.dependencies] windows-sys = { version = "0.61.2", features = ["Win32_System_Console"] } diff --git a/crates/atuin/src/command/client.rs b/crates/atuin/src/command/client.rs index 6ed4a17b..917bee1a 100644 --- a/crates/atuin/src/command/client.rs +++ b/crates/atuin/src/command/client.rs @@ -154,11 +154,21 @@ impl Cmd { daemon::daemonize_current_process()?; } - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); + #[cfg(feature = "ai")] + let mut runtime = if matches!(&self, Self::Ai(_)) { + tokio::runtime::Builder::new_multi_thread() + } else { + tokio::runtime::Builder::new_current_thread() + }; + + #[cfg(not(feature = "ai"))] + let mut runtime = tokio::runtime::Builder::new_current_thread(); + + let runtime = runtime.enable_all().build().unwrap(); + // For non-history commands, we want to initialize logging and the theme manager before + // doing anything else. History commands are performance-sensitive and run before and after + // every shell command, so we want to skip any unnecessary initialization for them. let settings = Settings::new().wrap_err("could not load client settings")?; let theme_manager = theme::ThemeManager::new(settings.theme.debug, None); let res = runtime.block_on(self.run_inner(settings, theme_manager)); diff --git a/docs/docs/ai/settings.md b/docs/docs/ai/settings.md index ad7c7af7..83b6bf63 100644 --- a/docs/docs/ai/settings.md +++ b/docs/docs/ai/settings.md @@ -8,6 +8,35 @@ Default: `false` Whether or not the AI feature are enabled. When set to `false`, the question mark keybinding will output a message with instructions to run `atuin setup` to enable the feature. +### endpoint + +Default: `null` + +The address of the Atuin AI endpoint. Used for AI features like command generation. Most users will not need this setting; it is only necessary for custom AI endpoints. + +### api_token + +Default: `null` + +The API token for the Atuin AI endpoint. Used for AI features like command generation. Most users will not need this setting; it is only necessary for custom AI endpoints. + +## Capabilities + +Settings that control what capabilities are sent to the LLM. These are specified under `[ai.capabilities]`. + +### enable_history_search + +Default: `true` + +Whether or not to include the "history search" capability in the context sent to the LLM. This allows the AI to request to search your command history for relevant commands when generating suggestions or answering questions. + +**Example config** + +```toml +[ai.capabilities] +enable_history_search = false +``` + ## Opening context Settings that control what context is sent in the opening AI request. These are specified under `[ai.opening]`. @@ -37,15 +66,3 @@ Whether or not to send your previous command as context in the initial request, [ai.opening] send_last_command = true ``` - -### endpoint - -Default: `null` - -The address of the Atuin AI endpoint. Used for AI features like command generation. Most users will not need this setting; it is only necessary for custom AI endpoints. - -### api_token - -Default: `null` - -The API token for the Atuin AI endpoint. Used for AI features like command generation. Most users will not need this setting; it is only necessary for custom AI endpoints. |
