From 2f702ad446fcd6a261a3bea0ab2807d70eca43e2 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 21 Apr 2026 13:07:27 -0700 Subject: refactor: Replace ad-hoc dispatch with FSM + driver architecture (#3434) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the tangled dispatch handler system (`tui/dispatch.rs`, `tui/state.rs`) with a pure finite state machine + driver architecture. The FSM handles all state transitions as explicit `(State, Event) → (NewState, Effects)` mappings. The driver executes IO effects and bridges the TUI to the FSM. --- crates/atuin-ai/src/commands/inline.rs | 181 +++-- crates/atuin-ai/src/context.rs | 26 + crates/atuin-ai/src/driver.rs | 838 ++++++++++++++++++++++ crates/atuin-ai/src/fsm/effects.rs | 81 +++ crates/atuin-ai/src/fsm/events.rs | 121 ++++ crates/atuin-ai/src/fsm/mod.rs | 917 ++++++++++++++++++++++++ crates/atuin-ai/src/fsm/tests.rs | 541 ++++++++++++++ crates/atuin-ai/src/fsm/tools.rs | 165 +++++ crates/atuin-ai/src/lib.rs | 2 + crates/atuin-ai/src/permissions/writer.rs | 1 + crates/atuin-ai/src/snapshots.rs | 2 +- crates/atuin-ai/src/stream.rs | 154 +--- crates/atuin-ai/src/tools/mod.rs | 200 +----- crates/atuin-ai/src/tui/components/atuin_ai.rs | 7 +- crates/atuin-ai/src/tui/components/input_box.rs | 12 +- crates/atuin-ai/src/tui/components/select.rs | 7 +- crates/atuin-ai/src/tui/dispatch.rs | 894 ----------------------- crates/atuin-ai/src/tui/events.rs | 4 - crates/atuin-ai/src/tui/mod.rs | 3 +- crates/atuin-ai/src/tui/state.rs | 511 +------------ crates/atuin-ai/src/tui/view/mod.rs | 62 +- crates/atuin-ai/src/tui/view/turn.rs | 13 +- 22 files changed, 2904 insertions(+), 1838 deletions(-) create mode 100644 crates/atuin-ai/src/driver.rs create mode 100644 crates/atuin-ai/src/fsm/effects.rs create mode 100644 crates/atuin-ai/src/fsm/events.rs create mode 100644 crates/atuin-ai/src/fsm/mod.rs create mode 100644 crates/atuin-ai/src/fsm/tests.rs create mode 100644 crates/atuin-ai/src/fsm/tools.rs delete mode 100644 crates/atuin-ai/src/tui/dispatch.rs (limited to 'crates') diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index e0a92ab4..adedc542 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -2,10 +2,12 @@ use std::path::PathBuf; use std::sync::mpsc; use crate::context::{AppContext, ClientContext}; +use crate::driver::{DriverEvent, IoContext, ViewState, run_driver}; +use crate::fsm::AgentFsm; +use crate::fsm::effects::ExitAction; use crate::session::{LocalSessionService, SessionManager, SessionService}; -use crate::tui::dispatch; use crate::tui::events::AiTuiEvent; -use crate::tui::state::{ExitAction, Session}; +use crate::tui::state::ConversationEvent; use crate::tui::view::ai_view; use atuin_client::database::{Database, Sqlite}; use eye_declare::{Application, CtrlCBehavior}; @@ -175,124 +177,127 @@ async fn run_inline_tui( .find_resumable(cwd.as_deref(), git_root_str.as_deref(), max_age_secs) .await?; - let (mut session_mgr, mut initial_state) = if let Some(stored) = resumable { + // ─── Build FSM ─────────────────────────────────────────────── + let (session_mgr, fsm, file_tracker, edit_permissions) = if let Some(stored) = resumable { debug!(session_id = %stored.id, "resuming AI session"); - let (mgr, events, server_sid, last_event_ts, invocation_id) = + let (mgr, mut events, server_sid, last_event_ts, invocation_id) = SessionManager::resume(Box::new(service), &stored).await?; - // Only treat this as a meaningful resume if there are API-visible events - // (not just OutOfBandOutput or SystemContext). let has_api_content = events.iter().any(|e| e.is_api_content()); if has_api_content { - let mut session = Session::new(ctx.git_root.is_some(), Some(invocation_id)); - session.conversation.events = events; - session.conversation.session_id = server_sid; - // Inject an invocation boundary so the LLM knows prior messages - // are from an earlier interaction. - session.conversation.events.push( - crate::tui::state::ConversationEvent::SystemContext { + events.push(ConversationEvent::SystemContext { content: "[Note: The user has started a new invocation of Atuin AI. Prior messages from this session are from an earlier invocation.]".to_string(), - }, - ); - session.view_start_index = session.conversation.events.len(); - session.is_resumed = true; - session.last_event_time = - last_event_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)); + }); + let view_start = events.len(); + let last_time = last_event_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)); - // Restore file read tracker from session metadata - if let Ok(Some(json)) = mgr.get_metadata(crate::file_tracker::METADATA_KEY).await + let ft = if let Ok(Some(json)) = + mgr.get_metadata(crate::file_tracker::METADATA_KEY).await && let Ok(tracker) = crate::file_tracker::FileReadTracker::from_json(&json) { - session.file_tracker = tracker; - } + tracker + } else { + Default::default() + }; - // Restore edit permission grants from session metadata - if let Ok(Some(json)) = mgr + let ep = if let Ok(Some(json)) = mgr .get_metadata(crate::edit_permissions::METADATA_KEY) .await && let Ok(cache) = crate::edit_permissions::EditPermissionCache::from_json(&json) { - session.edit_permissions = cache; - } - - (mgr, session) + cache + } else { + Default::default() + }; + + let caps = ctx.capabilities_as_strings(); + let fsm = AgentFsm::from_session( + events, + server_sid, + caps, + invocation_id, + view_start, + true, + last_time, + ); + (mgr, fsm, ft, ep) } else { - // No meaningful content — treat as a fresh session debug!("resumable session has no API-visible content, starting fresh"); - ( - mgr, - Session::new(ctx.git_root.is_some(), Some(invocation_id)), - ) + let caps = ctx.capabilities_as_strings(); + let fsm = AgentFsm::new(caps, invocation_id); + (mgr, fsm, Default::default(), Default::default()) } } else { debug!("creating new AI session"); let mgr = SessionManager::create_new(Box::new(service), cwd.as_deref(), git_root_str.as_deref()); - (mgr, Session::new(ctx.git_root.is_some(), None)) + let invocation_id = uuid::Uuid::now_v7().to_string(); + let caps = ctx.capabilities_as_strings(); + let fsm = AgentFsm::new(caps, invocation_id); + (mgr, fsm, Default::default(), Default::default()) }; - // Initialize the snapshot store now that we know the session ID. + // ─── Snapshot store ───────────────────────────────────────── let snapshot_dir = atuin_common::utils::data_dir() .join("ai") .join("snapshots") .join(session_mgr.session_id()); - match crate::snapshots::SnapshotStore::open(snapshot_dir) { - Ok(store) => initial_state.snapshot_store = Some(store), - Err(e) => tracing::warn!("failed to open snapshot store: {e}"), - } + let snapshot_store = crate::snapshots::SnapshotStore::open(snapshot_dir).ok(); + + let in_git_project = ctx.git_root.is_some(); + + // ─── Build initial ViewState from FSM ─────────────────────── + let initial_view = build_view_state(&fsm, in_git_project); + + // ─── Build IoContext ──────────────────────────────────────── + let io = IoContext { + app_ctx: ctx.clone(), + client_ctx: client_ctx.clone(), + session_mgr, + file_tracker, + edit_permissions, + snapshot_store, + }; + + // ─── Channel + Application ────────────────────────────────── + // Components emit DriverEvent::Tui(AiTuiEvent) via a wrapping sender. + // Spawned tasks emit DriverEvent::Fsm(Event) directly. + let (tx, rx) = mpsc::channel::(); - let (tx, rx) = mpsc::channel::(); + // Wrap sender for components: they send AiTuiEvent, we wrap it + let tui_tx = DriverEventSender(tx.clone()); println!(); - // If there's an initial prompt, send it as a SubmitInput event - // so it flows through the same path as user-typed input. if let Some(prompt) = initial_prompt { - let _ = tx.send(AiTuiEvent::SubmitInput(prompt)); + let _ = tui_tx + .0 + .send(DriverEvent::Tui(AiTuiEvent::SubmitInput(prompt))); } let (mut app, handle) = Application::builder() - .state(initial_state) + .state(initial_view) .view(ai_view) .ctrl_c(CtrlCBehavior::Deliver) .keyboard_protocol(eye_declare::KeyboardProtocol::Enhanced) .bracketed_paste(true) - .with_context(tx.clone()) + .with_context(tui_tx) .extra_newlines_at_exit(1) .build()?; - // Event loop: receives AiTuiEvent from components, mutates state via Handle. - // The dispatch thread processes events synchronously, including async persistence - // via block_on. It signals exit via an AtomicBool rather than querying the handle - // (which would hang if the TUI thread has already stopped processing). + // ─── Driver loop ──────────────────────────────────────────── let h = handle.clone(); + let exiting = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let exiting_clone = exiting.clone(); let dispatch_handle = tokio::task::spawn_blocking(move || { - let mut dctx = dispatch::DispatchContext { - handle: &h, - tx: &tx, - app_ctx: &ctx, - client_ctx: &client_ctx, - session_mgr: &mut session_mgr, - exiting: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)), - }; - while let Ok(event) = rx.recv() { - if !dispatch::dispatch(&mut dctx, event) { - break; - } - } + run_driver(fsm, io, h, rx, tx, exiting_clone, in_git_project); }); let run_result = app.run_loop().await; - - // Wait for the dispatch thread to finish its final persist before the - // tokio runtime tears down. This prevents panics from block_on calls - // racing with runtime shutdown — including on the error path. let _ = dispatch_handle.await; - run_result?; - // Map exit action to return value let result = match app.state().exit_action { Some(ExitAction::Execute(ref cmd)) => Action::Execute(cmd.clone()), Some(ExitAction::Insert(ref cmd)) => Action::Insert(cmd.clone()), @@ -302,6 +307,44 @@ async fn run_inline_tui( Ok(result) } +/// Wrapper around `mpsc::Sender` that components use as context. +/// +/// Components call `tx.send(AiTuiEvent::...)` via eye-declare's context system. +/// This wrapper implements the same interface but wraps events in `DriverEvent::Tui`. +#[derive(Debug, Clone)] +pub(crate) struct DriverEventSender(pub mpsc::Sender); + +impl DriverEventSender { + pub fn send(&self, event: AiTuiEvent) -> Result<(), mpsc::SendError> { + self.0 + .send(DriverEvent::Tui(event)) + .map_err(|_| mpsc::SendError(AiTuiEvent::Exit)) + } +} + +/// Build a ViewState snapshot from FSM state. Used for the initial view +/// and by the driver for ongoing sync. +fn build_view_state(fsm: &AgentFsm, in_git_project: bool) -> ViewState { + let safe_start = fsm.ctx.view_start_index.min(fsm.ctx.events.len()); + ViewState { + agent_state: fsm.state.clone(), + visible_events: fsm.ctx.events[safe_start..].to_vec(), + all_events: fsm.ctx.events.clone(), + session_id: fsm.ctx.session_id.clone(), + tools: fsm.ctx.tools.clone(), + current_response: fsm.ctx.current_response.clone(), + is_resumed: fsm.ctx.is_resumed, + last_event_time: fsm.ctx.last_event_time, + in_git_project, + archived_events: fsm.ctx.archived_events.clone(), + is_input_blank: true, + slash_command_input: None, + slash_command_search_results: Vec::new(), + exit_action: None, + slash_registry: Default::default(), + } +} + // ─────────────────────────────────────────────────────────────────── // Helpers // ─────────────────────────────────────────────────────────────────── diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs index dabb5c5e..93fcf9b9 100644 --- a/crates/atuin-ai/src/context.rs +++ b/crates/atuin-ai/src/context.rs @@ -19,6 +19,32 @@ pub(crate) struct AppContext { pub capabilities: AiCapabilities, } +impl AppContext { + pub(crate) fn capabilities_as_strings(&self) -> Vec { + let mut caps = vec!["client_invocations".to_string()]; + if self.capabilities.enable_history_search.unwrap_or(true) { + caps.push("client_v1_atuin_history".to_string()); + } + if self.capabilities.enable_file_tools.unwrap_or(true) { + caps.push("client_v1_read_file".to_string()); + caps.push("client_v1_edit_file".to_string()); + caps.push("client_v1_write_file".to_string()); + } + if self.capabilities.enable_command_execution.unwrap_or(true) { + caps.push("client_v1_execute_shell_command".to_string()); + } + if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { + caps.extend( + extra + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()), + ); + } + caps + } +} + /// Machine identity — computed once per session. #[derive(Clone, Debug)] pub(crate) struct ClientContext { diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs new file mode 100644 index 00000000..2d610203 --- /dev/null +++ b/crates/atuin-ai/src/driver.rs @@ -0,0 +1,838 @@ +//! Driver loop for the agent FSM. +//! +//! Receives events from the channel, calls `fsm.handle()`, syncs ViewState +//! to the Handle, and executes effects (spawning async tasks for IO). +//! +//! The driver runs on a blocking thread (`spawn_blocking`) so it can call +//! `blocking_recv()` on the Handle and `block_on()` for async persistence. + +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc; + +use eye_declare::Handle; + +use crate::context::{AppContext, ClientContext}; +use crate::edit_permissions::EditPermissionCache; +use crate::file_tracker::FileReadTracker; +use crate::fsm::effects::{Effect, ExitAction, PermissionTarget}; +use crate::fsm::events::{Event, PermissionChoice, PermissionResponse}; +use crate::fsm::tools::ToolPreviewData; +use crate::fsm::{AgentFsm, AgentState}; +use crate::permissions::resolver::PermissionResolver; +use crate::permissions::writer; +use crate::session::SessionManager; +use crate::stream::ChatRequest; +use crate::tools::ClientToolCall; +use crate::tui::events::{AiTuiEvent, PermissionResult}; +use crate::tui::state::ConversationEvent; + +// ============================================================================ +// Driver event — the unified channel type +// ============================================================================ + +/// Events processed by the driver loop. +/// +/// Components emit `Tui` variants via the channel. Spawned async tasks +/// (stream, tool execution) emit `Fsm` variants directly. +#[derive(Debug)] +pub(crate) enum DriverEvent { + /// Event from a TUI component (key press, input change, etc.) + Tui(AiTuiEvent), + /// Internal FSM event (from spawned stream/tool tasks) + Fsm(Event), +} + +// ============================================================================ +// IO context (driver-owned, not visible to FSM) +// ============================================================================ + +pub(crate) struct IoContext { + pub app_ctx: AppContext, + pub client_ctx: ClientContext, + pub session_mgr: SessionManager, + pub file_tracker: FileReadTracker, + pub edit_permissions: EditPermissionCache, + pub snapshot_store: Option, +} + +// ============================================================================ +// ViewState (Handle payload for the render thread) +// ============================================================================ + +/// State pushed to the Handle for the view/render thread. +/// Synced from the FSM after each transition. +#[derive(Debug)] +pub(crate) struct ViewState { + // ─── From FSM ─────────────────────────────────────────────── + pub agent_state: AgentState, + pub visible_events: Vec, + pub all_events: Vec, + pub session_id: Option, + pub tools: crate::fsm::tools::ToolManager, + pub current_response: String, + + // ─── Session metadata (set once) ──────────────────────────── + pub is_resumed: bool, + pub last_event_time: Option>, + pub in_git_project: bool, + + // ─── View-only ────────────────────────────────────────────── + pub archived_events: Vec, + + // ─── Ephemeral interaction state ──────────────────────────── + pub is_input_blank: bool, + pub slash_command_input: Option, + pub slash_command_search_results: Vec, + pub exit_action: Option, + pub slash_registry: crate::tui::slash::SlashCommandRegistry, +} + +impl ViewState { + pub fn is_exiting(&self) -> bool { + self.exit_action.is_some() + } + + pub fn is_busy(&self) -> bool { + matches!(self.agent_state, AgentState::Turn { .. }) + } + + pub fn has_confirmation(&self) -> bool { + matches!( + self.agent_state, + AgentState::Idle { + confirmation: Some(_) + } + ) + } + + pub fn is_input_active(&self) -> bool { + matches!(self.agent_state, AgentState::Idle { .. }) && !self.has_confirmation() + } + + /// Whether any command has been suggested in the current invocation. + pub fn has_command(&self) -> bool { + self.visible_events.iter().any(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e { + name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() + } else { + false + } + }) + } + + pub fn footer_text(&self) -> &'static str { + match &self.agent_state { + AgentState::Idle { confirmation: None } => { + if self.has_command() && self.is_input_blank { + "[Enter] Execute suggested command [Tab] Insert Command" + } else { + "[Enter] Send [Shift+Enter] New line [Esc] Exit" + } + } + AgentState::Idle { + confirmation: Some(_), + } => "[Enter] Confirm dangerous command [Esc] Cancel", + AgentState::Turn { .. } => "[Esc] Cancel", + AgentState::Error(_) => "[Enter]/[r] Retry [Esc] Exit", + } + } +} + +// ============================================================================ +// Main driver loop +// ============================================================================ + +struct DriverContext<'a> { + fsm: &'a mut AgentFsm, + io: &'a mut IoContext, + handle: &'a Handle, + tx: &'a mpsc::Sender, + exiting: &'a Arc, + stream_cancel_tx: &'a mut Option>, + tool_abort_txs: &'a mut std::collections::HashMap>, +} + +/// Main driver loop. Processes events, transitions FSM, syncs view, executes effects. +/// +/// Runs on a blocking thread. Returns when the event channel closes or exit is requested. +/// The Handle already contains the initial ViewState (set by Application::builder). +pub(crate) fn run_driver( + mut fsm: AgentFsm, + mut io: IoContext, + handle: Handle, + rx: mpsc::Receiver, + tx: mpsc::Sender, + exiting: Arc, + in_git_project: bool, +) { + // Dropping the sender cancels the stream (receiver sees Err on changed()). + let mut stream_cancel_tx: Option> = None; + // Per-tool interrupt senders for shell commands. + let mut tool_abort_txs: std::collections::HashMap> = + std::collections::HashMap::new(); + + while let Ok(driver_event) = rx.recv() { + // Log and translate DriverEvent to FSM Event (or handle directly) + let fsm_event = match driver_event { + DriverEvent::Fsm(event) => { + tracing::trace!(?event, state = ?fsm.state, "FSM event"); + Some(event) + } + DriverEvent::Tui(tui_event) => { + tracing::trace!(?tui_event, state = ?fsm.state, "TUI event"); + translate_tui_event(tui_event, &handle) + } + }; + + if let Some(event) = fsm_event { + // Feed event to FSM + let effects = fsm.handle(event); + tracing::trace!(?effects, state = ?fsm.state, "FSM transition"); + + // Sync ViewState to Handle (FSM owns all state now) + sync_view_state(&handle, &fsm, in_git_project); + + // Execute effects (only persist when FSM says to) + for effect in &effects { + if matches!(effect, Effect::Persist) { + persist(&fsm, &mut io); + } + + let ctx = DriverContext { + fsm: &mut fsm, + io: &mut io, + handle: &handle, + tx: &tx, + exiting: &exiting, + stream_cancel_tx: &mut stream_cancel_tx, + tool_abort_txs: &mut tool_abort_txs, + }; + + execute_effect(effect, ctx); + } + + // Final sync after effects — ensures the render thread sees + // the absolute final state even if effects modified anything. + if !effects.is_empty() { + sync_view_state(&handle, &fsm, in_git_project); + } + } else { + // Event was handled directly (e.g. InputUpdated) — just sync + sync_view_state(&handle, &fsm, in_git_project); + } + + if exiting.load(Ordering::Acquire) { + break; + } + tracing::trace!(state = ?fsm.state, "driver loop iteration complete, waiting for next event"); + } +} + +// ============================================================================ +// TUI event translation +// ============================================================================ + +/// Translate a TUI event into an FSM event. +/// Returns None for events handled directly (e.g. InputUpdated). +fn translate_tui_event(event: AiTuiEvent, handle: &Handle) -> Option { + match event { + AiTuiEvent::SubmitInput(input) => { + // Clear slash state and reset is_input_blank (the InputBox clears + // its text on submit but doesn't fire InputUpdated for the clear). + handle.update(|vs| { + vs.slash_command_input = None; + vs.slash_command_search_results.clear(); + vs.is_input_blank = true; + }); + + let input = input.trim().to_string(); + if input.is_empty() { + Some(Event::ExecuteCommand) + } else if input == "/new" { + Some(Event::NewSession) + } else if input.starts_with('/') { + let content = resolve_slash_command(&input, handle); + Some(Event::SlashCommand { + command: input, + content, + }) + } else { + Some(Event::UserSubmit(input)) + } + } + AiTuiEvent::InputUpdated(text) => { + let is_blank = text.is_empty(); + handle.update(move |vs| { + vs.is_input_blank = is_blank; + if text.starts_with('/') { + let query = text.trim_start_matches('/').to_string(); + let mut results = vs.slash_registry.search_fuzzy(&query); + results.sort_by(|a, b| { + b.relevance + .partial_cmp(&a.relevance) + .unwrap_or(std::cmp::Ordering::Equal) + }); + vs.slash_command_input = Some(query); + vs.slash_command_search_results = results; + } else { + vs.slash_command_input = None; + vs.slash_command_search_results.clear(); + } + }); + None + } + AiTuiEvent::CancelGeneration => Some(Event::Cancel), + AiTuiEvent::ExecuteCommand => Some(Event::ExecuteCommand), + AiTuiEvent::InsertCommand => Some(Event::InsertCommand), + AiTuiEvent::CancelConfirmation => Some(Event::Cancel), + AiTuiEvent::InterruptToolExecution => Some(Event::InterruptTools), + AiTuiEvent::Retry => Some(Event::Retry), + AiTuiEvent::Exit => Some(Event::Cancel), + AiTuiEvent::SelectPermission(result) => { + let tool_id = handle + .fetch(|vs| vs.tools.awaiting_permission().map(|t| t.id.clone())) + .blocking_recv() + .ok() + .flatten()?; + + let choice = match result { + PermissionResult::Allow => PermissionChoice::Allow, + PermissionResult::AllowFileForSession => PermissionChoice::AllowForSession, + PermissionResult::AlwaysAllowInDir => PermissionChoice::AlwaysAllowInProject, + PermissionResult::AlwaysAllow => PermissionChoice::AlwaysAllow, + PermissionResult::Deny => PermissionChoice::Deny, + }; + Some(Event::PermissionUserChoice { tool_id, choice }) + } + AiTuiEvent::SlashCommand(cmd) => { + let content = resolve_slash_command(&cmd, handle); + Some(Event::SlashCommand { + command: cmd, + content, + }) + } + } +} + +/// Resolve a slash command to its output content. +fn resolve_slash_command(command: &str, handle: &Handle) -> String { + match command.trim() { + "/help" => { + let commands = handle + .fetch(|vs| { + vs.slash_registry + .get_commands() + .iter() + .map(|cmd| format!("- `/{}` — {}", cmd.name, cmd.description)) + .collect::>() + .join("\n") + }) + .blocking_recv() + .unwrap_or_default(); + include_str!("tui/content/help.md").replace("{commands}", &commands) + } + _ => format!("Unknown command: {command}"), + } +} + +// ============================================================================ +// ViewState sync +// ============================================================================ + +fn sync_view_state(handle: &Handle, fsm: &AgentFsm, in_git_project: bool) { + let state = fsm.state.clone(); + let safe_start = fsm.ctx.view_start_index.min(fsm.ctx.events.len()); + let mut visible_events = fsm.ctx.events[safe_start..].to_vec(); + let all_events = fsm.ctx.events.clone(); + let tools = fsm.ctx.tools.clone(); + let current_response = fsm.ctx.current_response.clone(); + let session_id = fsm.ctx.session_id.clone(); + let is_resumed = fsm.ctx.is_resumed; + let last_event_time = fsm.ctx.last_event_time; + let archived_events = fsm.ctx.archived_events.clone(); + + // Inject streaming text as a synthetic event for live rendering. + // The FSM commits text to events on stream end; this makes it visible during streaming. + let trimmed = current_response.trim_start(); + if !trimmed.is_empty() { + visible_events.push(ConversationEvent::Text { + content: trimmed.to_string(), + }); + } + + tracing::trace!(?state, "sync_view_state pushing to handle"); + handle.update(move |vs| { + vs.agent_state = state; + vs.visible_events = visible_events; + vs.all_events = all_events; + vs.tools = tools; + vs.current_response = current_response; + vs.session_id = session_id; + vs.is_resumed = is_resumed; + vs.last_event_time = last_event_time; + vs.in_git_project = in_git_project; + vs.archived_events = archived_events; + }); +} + +// ============================================================================ +// Effect execution +// ============================================================================ + +fn execute_effect(effect: &Effect, ctx: DriverContext) { + let DriverContext { + fsm, + io, + handle, + tx, + exiting, + stream_cancel_tx, + tool_abort_txs, + } = ctx; + + match effect { + Effect::StartStream { + messages, + session_id, + } => { + // Cancel any existing stream before starting a new one + stream_cancel_tx.take(); + + let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(()); + *stream_cancel_tx = Some(cancel_tx); + + let tx = tx.clone(); + let app = io.app_ctx.clone(); + let cc = io.client_ctx.clone(); + let request = ChatRequest::new( + messages.clone(), + session_id.clone(), + &app.capabilities, + fsm.ctx.invocation_id.clone(), + ); + tokio::spawn(async move { + run_stream_bridge(request, app, cc, tx, cancel_rx).await; + }); + } + + Effect::AbortStream => { + // Drop the sender — the bridge's cancel_rx.changed() will error, + // breaking the stream loop and dropping the HTTP connection. + stream_cancel_tx.take(); + } + + Effect::CheckPermission { tool_id, tool } => { + let tool_id = tool_id.clone(); + let tool = tool.clone(); + let tx = tx.clone(); + let working_dir = tool + .target_dir() + .map(|p| p.to_path_buf()) + .or_else(|| std::env::current_dir().ok()) + .unwrap_or_else(|| PathBuf::from(".")); + + // Check session grants first (synchronous) + if let Some(resolved) = tool.resolved_file_path() + && io.edit_permissions.has_valid_grant(&resolved) + { + let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved { + tool_id, + response: PermissionResponse::SessionGranted, + })); + return; + } + + tokio::spawn(async move { + let response = match PermissionResolver::new(working_dir).await { + Ok(resolver) => match resolver.check(&tool).await { + Ok(crate::permissions::check::PermissionResponse::Allowed) => { + PermissionResponse::Allowed + } + Ok(crate::permissions::check::PermissionResponse::Denied) => { + PermissionResponse::Denied + } + Ok(crate::permissions::check::PermissionResponse::Ask) => { + PermissionResponse::Ask + } + Err(_) => PermissionResponse::Ask, + }, + Err(_) => PermissionResponse::Ask, + }; + let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved { + tool_id, + response, + })); + }); + } + + Effect::ExecuteTool { tool_id, tool } => { + let tool_id = tool_id.clone(); + let tool = tool.clone(); + let tx = tx.clone(); + let db = io.app_ctx.history_db.clone(); + + match &tool { + ClientToolCall::Shell(shell_call) => { + let shell_call = shell_call.clone(); + let tx_preview = tx.clone(); + let tool_id_for_preview = tool_id.clone(); + + // Create interrupt channel and store the sender for AbortTool + let (interrupt_tx, interrupt_rx) = tokio::sync::oneshot::channel(); + tool_abort_txs.insert(tool_id.clone(), interrupt_tx); + + tokio::spawn(async move { + let (output_tx, mut output_rx) = + tokio::sync::mpsc::channel::>(16); + + let preview_id = tool_id_for_preview; + let tx_fwd = tx_preview; + tokio::spawn(async move { + while let Some(lines) = output_rx.recv().await { + let _ = tx_fwd.send(DriverEvent::Fsm(Event::ToolPreviewUpdate { + tool_id: preview_id.clone(), + lines, + exit_code: None, + })); + } + }); + + let outcome = crate::tools::execute_shell_command_streaming( + &shell_call, + output_tx, + interrupt_rx, + ) + .await; + + let preview = if let crate::tools::ToolOutcome::Structured { + exit_code, + interrupted, + .. + } = &outcome + { + Some(ToolPreviewData::Shell { + lines: vec![], + exit_code: *exit_code, + interrupted: *interrupted, + }) + } else { + None + }; + + let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { + tool_id, + outcome, + preview, + })); + }); + } + ClientToolCall::Edit(edit_call) => { + let resolved = edit_call.resolved_path(); + + // Capture old content for snapshot + diff preview + let old_content = std::fs::read(&resolved).ok(); + if let Some(ref content) = old_content + && let Some(ref mut store) = io.snapshot_store + && let Err(e) = store.ensure_snapshot(&resolved, content) + { + tracing::warn!("Failed to snapshot before edit: {e}"); + } + + // Edit is fast (file read + string replace + write) — run inline + let (outcome, new_content) = edit_call.execute(&resolved, &io.file_tracker); + + // Update file tracker with new content + if let Some(new_bytes) = &new_content + && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) + { + io.file_tracker + .update_after_edit(&resolved, new_bytes, mtime); + } + + // Compute diff preview + let preview = match (&old_content, &new_content) { + (Some(old_bytes), Some(new_bytes)) => { + let old_str = String::from_utf8_lossy(old_bytes); + let new_str = String::from_utf8_lossy(new_bytes); + let diff = crate::diff::EditPreview::compute(&old_str, &new_str); + if diff.hunks.is_empty() { + None + } else { + Some(ToolPreviewData::Edit(diff)) + } + } + _ => None, + }; + + let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { + tool_id, + outcome, + preview, + })); + } + ClientToolCall::Write(write_call) => { + let resolved = write_call.resolved_path(); + + // Snapshot existing file before overwriting + if let Ok(content) = std::fs::read(&resolved) + && let Some(ref mut store) = io.snapshot_store + && let Err(e) = store.ensure_snapshot(&resolved, &content) + { + tracing::warn!("Failed to snapshot before write: {e}"); + } + + // Write is fast (atomic file write) — run inline + let (outcome, written_bytes) = write_call.execute(&resolved); + + // Update file tracker with new content + if let Some(new_bytes) = &written_bytes + && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) + { + io.file_tracker + .update_after_edit(&resolved, new_bytes, mtime); + } + + let preview = if !outcome.is_error() { + Some(ToolPreviewData::Write( + crate::diff::WritePreview::from_content(&write_call.content), + )) + } else { + None + }; + + let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { + tool_id, + outcome, + preview, + })); + } + ClientToolCall::Read(read_call) => { + // Read is fast (file read) — run inline so we can update file_tracker + let outcome = read_call.execute(); + + // Track the read for freshness checking on subsequent edits + if !outcome.is_error() { + let resolved = read_call.resolved_path(); + if resolved.is_file() + && let Ok(content) = std::fs::read(&resolved) + && let Ok(mtime) = + std::fs::metadata(&resolved).and_then(|m| m.modified()) + { + io.file_tracker.record_read(resolved, &content, mtime); + } + } + + let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { + tool_id, + outcome, + preview: None, + })); + } + ClientToolCall::AtuinHistory(_) => { + // History search needs async DB access + tokio::spawn(async move { + let outcome = tool.execute(&db).await; + let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { + tool_id, + outcome, + preview: None, + })); + }); + } + } + } + + Effect::AbortTool { tool_id } => { + if let Some(abort_tx) = tool_abort_txs.remove(tool_id) { + let _ = abort_tx.send(()); + } + } + + Effect::Persist => { + // Handled inline in the driver loop (before this function is called). + } + + Effect::WritePermissionRule { + target, + rule, + disposition, + } => { + let file_path = match target { + PermissionTarget::Project => { + let project_root = io + .app_ctx + .git_root + .clone() + .or_else(|| std::env::current_dir().ok()) + .unwrap_or_else(|| PathBuf::from(".")); + writer::project_permissions_path(&project_root) + } + PermissionTarget::Global => writer::global_permissions_path(), + }; + let rule = rule.clone(); + let disposition = disposition.clone(); + tokio::spawn(async move { + if let Err(e) = writer::write_rule(&file_path, &rule, disposition).await { + tracing::error!("Failed to write permission rule: {e}"); + } + }); + } + + Effect::CacheSessionGrant { path } => { + io.edit_permissions.grant(path.clone()); + } + + Effect::ArchiveSession => { + let rt = tokio::runtime::Handle::current(); + if let Err(e) = rt.block_on(io.session_mgr.archive_and_reset()) { + tracing::warn!("Failed to archive session: {e}"); + } + } + + Effect::ScheduleTimeout { + timeout_id, + duration, + } => { + let timeout_id = *timeout_id; + let duration = *duration; + let tx = tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(duration).await; + let _ = tx.send(DriverEvent::Fsm(Event::ConfirmationTimeout { timeout_id })); + }); + } + + Effect::ExitApp(action) => { + let action = action.clone(); + handle.update(move |vs| { + vs.exit_action = Some(action); + }); + exiting.store(true, Ordering::Release); + let h2 = handle.clone(); + h2.exit(); + } + } +} + +// ============================================================================ +// Persistence +// ============================================================================ + +fn persist(fsm: &AgentFsm, io: &mut IoContext) { + let start = std::time::Instant::now(); + let rt = tokio::runtime::Handle::current(); + + if let Err(e) = rt.block_on(io.session_mgr.persist_events(&fsm.ctx.events)) { + tracing::warn!("Failed to persist session events: {e}"); + } + if let Some(ref sid) = fsm.ctx.session_id + && let Err(e) = rt.block_on(io.session_mgr.persist_server_session_id(sid)) + { + tracing::warn!("Failed to persist server session ID: {e}"); + } + if let Ok(json) = io.file_tracker.to_json() + && let Err(e) = rt.block_on( + io.session_mgr + .set_metadata(crate::file_tracker::METADATA_KEY, &json), + ) + { + tracing::warn!("Failed to persist file tracker: {e}"); + } + if let Ok(json) = io.edit_permissions.to_json() + && let Err(e) = rt.block_on( + io.session_mgr + .set_metadata(crate::edit_permissions::METADATA_KEY, &json), + ) + { + tracing::warn!("Failed to persist edit permissions: {e}"); + } + tracing::trace!(elapsed_ms = start.elapsed().as_millis(), "persist complete"); +} + +// ============================================================================ +// Stream bridge +// ============================================================================ + +async fn run_stream_bridge( + request: ChatRequest, + app_ctx: AppContext, + client_ctx: ClientContext, + tx: mpsc::Sender, + mut cancel_rx: tokio::sync::watch::Receiver<()>, +) { + use crate::stream::{StreamContent, StreamControl, StreamFrame, create_chat_stream}; + use futures::StreamExt; + + let stream = create_chat_stream( + app_ctx.endpoint.clone(), + app_ctx.token.clone(), + request, + client_ctx, + app_ctx.send_cwd, + app_ctx.last_command.clone(), + ); + futures::pin_mut!(stream); + + let _ = tx.send(DriverEvent::Fsm(Event::StreamStarted)); + + loop { + // Select between the next stream frame and cancellation. + // When the driver drops the cancel sender, changed() returns Err + // and we break — dropping the HTTP stream and cancelling the request. + let frame = tokio::select! { + biased; + _ = cancel_rx.changed() => break, + frame = stream.next() => match frame { + Some(frame) => frame, + None => break, + }, + }; + + let event = match frame { + Ok(StreamFrame::Content(content)) => match content { + StreamContent::TextChunk(text) => Some(Event::StreamChunk(text)), + StreamContent::ToolCall { id, name, input } => { + if name == "suggest_command" { + Some(Event::SuggestCommand { id, input }) + } else { + Some(Event::StreamToolCall { id, name, input }) + } + } + StreamContent::ToolResult { + tool_use_id, + content, + is_error, + remote, + content_length, + } => Some(Event::StreamServerToolResult { + tool_use_id, + content, + is_error, + remote, + content_length, + }), + }, + Ok(StreamFrame::Control(control)) => match control { + StreamControl::StatusChanged(status) => Some(Event::StreamStatusChanged(status)), + StreamControl::Done { session_id } => Some(Event::StreamDone { session_id }), + StreamControl::Error(msg) => Some(Event::StreamError(msg)), + }, + Err(e) => Some(Event::StreamError(e.to_string())), + }; + + if let Some(event) = event { + // StreamDone and StreamError are terminal — the server won't send more. + // SuggestCommand is NOT terminal: the server sends StreamDone after it + // with the session_id we need to capture. + let is_terminal = matches!(event, Event::StreamDone { .. } | Event::StreamError(_)); + if tx.send(DriverEvent::Fsm(event)).is_err() { + break; + } + if is_terminal { + break; + } + } + } +} diff --git a/crates/atuin-ai/src/fsm/effects.rs b/crates/atuin-ai/src/fsm/effects.rs new file mode 100644 index 00000000..ede72a42 --- /dev/null +++ b/crates/atuin-ai/src/fsm/effects.rs @@ -0,0 +1,81 @@ +//! Effects (outputs) from the agent FSM. +//! +//! The FSM returns these as data; the driver is responsible for executing them. + +use std::path::PathBuf; +use std::time::Duration; + +use serde_json::Value; + +use crate::permissions::rule::Rule; +use crate::permissions::writer::RuleDisposition; +use crate::tools::ClientToolCall; + +/// Where to write a permission rule. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum PermissionTarget { + /// Project-level: `/.atuin/permissions.ai.toml` + Project, + /// Global: `~/.config/atuin/permissions.ai.toml` + Global, +} + +/// Side effects the driver should execute after a state transition. +#[derive(Debug, Clone)] +pub(crate) enum Effect { + // ─── Network ──────────────────────────────────────────────── + /// Start a new streaming request to the server. + StartStream { + messages: Vec, + session_id: Option, + }, + /// Abort the active stream connection. + AbortStream, + + // ─── Tool orchestration ───────────────────────────────────── + /// Run the permission resolver for a tool call. + CheckPermission { + tool_id: String, + tool: ClientToolCall, + }, + /// Execute a tool (file read, edit, write, shell, history search). + ExecuteTool { + tool_id: String, + tool: ClientToolCall, + }, + /// Kill a running tool (send interrupt to shell command). + AbortTool { tool_id: String }, + + // ─── Persistence ──────────────────────────────────────────── + /// Persist current conversation state to disk. + Persist, + /// Write a permanent permission rule to disk. + WritePermissionRule { + target: PermissionTarget, + rule: Rule, + disposition: RuleDisposition, + }, + /// Cache a session-scoped file permission grant. + CacheSessionGrant { path: PathBuf }, + /// Archive current session and start fresh (IO only — state already updated by FSM). + ArchiveSession, + + // ─── Timers ───────────────────────────────────────────────── + /// Schedule a timer that will fire ConfirmationTimeout after delay. + ScheduleTimeout { timeout_id: u64, duration: Duration }, + + // ─── Exit ─────────────────────────────────────────────────── + /// Exit the application with the given action. + ExitApp(ExitAction), +} + +/// What to do when exiting the TUI. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ExitAction { + /// Run the suggested command. + Execute(String), + /// Insert the command into the shell without running. + Insert(String), + /// Exit without action. + Cancel, +} diff --git a/crates/atuin-ai/src/fsm/events.rs b/crates/atuin-ai/src/fsm/events.rs new file mode 100644 index 00000000..62a624bf --- /dev/null +++ b/crates/atuin-ai/src/fsm/events.rs @@ -0,0 +1,121 @@ +//! Events (inputs) to the agent FSM. + +use serde_json::Value; + +use crate::tools::ToolOutcome; + +/// Events that drive state transitions in the agent FSM. +#[derive(Debug, Clone)] +pub(crate) enum Event { + // ─── User actions ─────────────────────────────────────────── + /// User submitted a message from the input box. + UserSubmit(String), + /// User pressed Esc or equivalent cancel action. + Cancel, + /// User pressed Enter to execute the suggested command. + ExecuteCommand, + /// User pressed Tab to insert the suggested command. + InsertCommand, + /// User chose to retry after an error. + Retry, + /// User interrupted executing tools (Ctrl+C / Esc during shell execution). + InterruptTools, + + // ─── Stream lifecycle ─────────────────────────────────────── + /// Stream connection established, first frame received. + StreamStarted, + /// Received a chunk of streamed text content. + StreamChunk(String), + /// Stream delivered a client-side tool call. + StreamToolCall { + id: String, + name: String, + input: Value, + }, + /// Stream delivered a server-side tool result (executed remotely). + StreamServerToolResult { + tool_use_id: String, + content: String, + is_error: bool, + remote: bool, + content_length: Option, + }, + /// Stream status changed (e.g. "thinking", "searching"). + StreamStatusChanged(String), + /// Stream ended normally. + StreamDone { session_id: String }, + /// Stream encountered an error. + StreamError(String), + + // ─── Suggest command (terminal tool call) ─────────────────── + /// The suggest_command tool call acts as a stream terminal event. + /// This is the server signaling "turn complete, here's the command." + SuggestCommand { id: String, input: Value }, + + // ─── Tool lifecycle ───────────────────────────────────────── + /// Permission resolver completed for a tool. + PermissionResolved { + tool_id: String, + response: PermissionResponse, + }, + /// User made a permission choice via the dialog. + PermissionUserChoice { + tool_id: String, + choice: PermissionChoice, + }, + /// Tool execution completed. + ToolExecutionDone { + tool_id: String, + outcome: ToolOutcome, + /// Preview data computed by the driver (diff, content preview, final shell state). + preview: Option, + }, + /// Live preview update for an executing shell command. + ToolPreviewUpdate { + tool_id: String, + lines: Vec, + exit_code: Option, + }, + + // ─── Timers ───────────────────────────────────────────────── + /// Confirmation timeout expired. + ConfirmationTimeout { timeout_id: u64 }, + + // ─── Session management ───────────────────────────────────── + /// User ran /new to start a fresh session. + NewSession, + + // ─── Slash commands ───────────────────────────────────────── + /// User submitted a slash command (other than /new). + /// The driver resolves known commands (like /help) and passes the + /// rendered content; the FSM just pushes an OOB event. + SlashCommand { command: String, content: String }, +} + +/// Result of the permission resolver check. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum PermissionResponse { + /// Rule allows this tool call — execute immediately. + Allowed, + /// Rule denies this tool call — reject with error. + Denied, + /// No matching rule — ask the user. + Ask, + /// Session-scoped grant exists — execute immediately (bypass resolver). + SessionGranted, +} + +/// User's choice from the permission dialog. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum PermissionChoice { + /// Allow this one time. + Allow, + /// Allow this file for the remainder of the session. + AllowForSession, + /// Always allow in this project (writes to project permissions file). + AlwaysAllowInProject, + /// Always allow globally (writes to global permissions file, scoped to file). + AlwaysAllow, + /// Deny this tool call. + Deny, +} diff --git a/crates/atuin-ai/src/fsm/mod.rs b/crates/atuin-ai/src/fsm/mod.rs new file mode 100644 index 00000000..92be1cd8 --- /dev/null +++ b/crates/atuin-ai/src/fsm/mod.rs @@ -0,0 +1,917 @@ +//! Agent conversation FSM. +//! +//! Pure state machine that returns effects as data. +//! The driver is responsible for executing effects and feeding events back. +//! +//! The FSM owns the conversation event log and tool lifecycle state. +//! It never performs IO directly. + +pub(crate) mod effects; +pub(crate) mod events; +pub(crate) mod tools; + +#[cfg(test)] +mod tests; + +use serde_json::Value; + +use crate::context_window::ContextWindowBuilder; +use crate::tui::state::ConversationEvent; + +use effects::{Effect, ExitAction, PermissionTarget}; +use events::{Event, PermissionChoice, PermissionResponse}; +use tools::{ToolManager, ToolState}; + +// ============================================================================ +// State +// ============================================================================ + +/// The discrete states of the agent FSM. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum AgentState { + /// Waiting for user input. + Idle { + confirmation: Option, + }, + + /// A conversation turn is in progress. + Turn { stream: StreamPhase }, + + /// Unrecoverable error. User can retry or exit. + Error(String), +} + +/// Stream connection lifecycle within a Turn. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum StreamPhase { + /// Request sent, awaiting first stream frame. + Connecting, + /// Actively receiving streamed response. + Streaming { status: Option }, + /// Stream connection has ended (Done received). + Done, +} + +/// Streaming status indicators from server. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum StreamingStatus { + Processing, + Searching, + Thinking, + WaitingForTools, +} + +impl StreamingStatus { + pub(crate) fn from_str(s: &str) -> Self { + match s { + "processing" => Self::Processing, + "searching" => Self::Searching, + "waiting_for_tools" => Self::WaitingForTools, + _ => Self::Thinking, + } + } +} + +/// Pending dangerous command confirmation state. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct PendingConfirmation { + pub command: String, + pub timeout_id: u64, +} + +// ============================================================================ +// Context +// ============================================================================ + +/// Shared context owned by the FSM. +#[derive(Debug, Clone)] +pub(crate) struct AgentContext { + /// The full conversation event log (source of truth for API + persistence). + pub events: Vec, + /// Server-assigned session ID. + pub session_id: Option, + /// Accumulated text from current stream (committed to events on tool call or stream end). + pub current_response: String, + /// Per-tool lifecycle state and cached render data. + /// Tools persist across turns for rendering history. + pub tools: ToolManager, + /// Tool IDs that belong to the current turn. Cleared on continuation start. + /// Used to determine whether a turn needs continuation (has unprocessed results). + current_turn_tool_ids: Vec, + /// Counter for generating unique timeout IDs. + next_timeout_id: u64, + /// Capabilities advertised to the server. + pub capabilities: Vec, + /// Unique invocation ID for this CLI invocation. + pub invocation_id: String, + + // ─── View state (owned by FSM for atomic transitions) ─────── + /// Index into events where the current TUI invocation starts. + /// Events before this are context for the API but not rendered. + pub view_start_index: usize, + /// Whether this session was resumed from a prior invocation. + pub is_resumed: bool, + /// Time of the last event from a previous invocation. + pub last_event_time: Option>, + /// Events from archived sessions (/new) still rendered on screen. + pub archived_events: Vec, +} + +impl AgentContext { + fn next_timeout_id(&mut self) -> u64 { + let id = self.next_timeout_id; + self.next_timeout_id += 1; + id + } +} + +// ============================================================================ +// The Agent FSM +// ============================================================================ + +/// The agent finite state machine. +/// +/// Pure state machine — `handle()` takes an event, mutates internal state, +/// and returns effects as data for the driver to execute. +#[derive(Debug, Clone)] +pub(crate) struct AgentFsm { + pub state: AgentState, + pub ctx: AgentContext, +} + +impl AgentFsm { + /// Create a new FSM in Idle state. + pub fn new(capabilities: Vec, invocation_id: String) -> Self { + Self { + state: AgentState::Idle { confirmation: None }, + ctx: AgentContext { + events: Vec::new(), + session_id: None, + current_response: String::new(), + tools: ToolManager::new(), + current_turn_tool_ids: Vec::new(), + next_timeout_id: 0, + capabilities, + invocation_id, + view_start_index: 0, + is_resumed: false, + last_event_time: None, + archived_events: Vec::new(), + }, + } + } + + /// Create an FSM from saved session state (for resume). + pub fn from_session( + events: Vec, + session_id: Option, + capabilities: Vec, + invocation_id: String, + view_start_index: usize, + is_resumed: bool, + last_event_time: Option>, + ) -> Self { + Self { + state: AgentState::Idle { confirmation: None }, + ctx: AgentContext { + events, + session_id, + current_response: String::new(), + tools: ToolManager::new(), + current_turn_tool_ids: Vec::new(), + next_timeout_id: 0, + capabilities, + invocation_id, + view_start_index, + is_resumed, + last_event_time, + archived_events: Vec::new(), + }, + } + } + + /// Handle an event, returning effects to execute. + pub fn handle(&mut self, event: Event) -> Vec { + match (&self.state, event) { + // ================================================================ + // Idle state + // ================================================================ + (AgentState::Idle { confirmation: None }, Event::UserSubmit(msg)) => { + self.start_turn(msg) + } + + ( + AgentState::Idle { + confirmation: Some(_), + }, + Event::UserSubmit(msg), + ) => self.start_turn(msg), + + (AgentState::Idle { confirmation: None }, Event::ExecuteCommand) => { + let cmd = self.current_command(); + let Some(cmd) = cmd else { + // No command suggested — exit + return vec![Effect::ExitApp(ExitAction::Cancel)]; + }; + if self.is_current_command_dangerous() { + let timeout_id = self.ctx.next_timeout_id(); + self.state = AgentState::Idle { + confirmation: Some(PendingConfirmation { + command: cmd, + timeout_id, + }), + }; + vec![Effect::ScheduleTimeout { + timeout_id, + duration: std::time::Duration::from_secs(5), + }] + } else { + vec![Effect::ExitApp(ExitAction::Execute(cmd))] + } + } + + ( + AgentState::Idle { + confirmation: Some(_), + }, + Event::ExecuteCommand, + ) => { + let confirm = self.state_confirmation().unwrap().clone(); + self.state = AgentState::Idle { confirmation: None }; + vec![Effect::ExitApp(ExitAction::Execute(confirm.command))] + } + + (AgentState::Idle { .. }, Event::InsertCommand) => { + let cmd = self.current_command(); + match cmd { + Some(cmd) => vec![Effect::ExitApp(ExitAction::Insert(cmd))], + None => vec![], + } + } + + ( + AgentState::Idle { + confirmation: Some(_), + }, + Event::Cancel, + ) => { + self.state = AgentState::Idle { confirmation: None }; + vec![] + } + + (AgentState::Idle { confirmation: None }, Event::Cancel) => { + vec![Effect::ExitApp(ExitAction::Cancel)] + } + + (AgentState::Idle { .. }, Event::ConfirmationTimeout { timeout_id }) => { + if self + .state_confirmation() + .is_some_and(|c| c.timeout_id == timeout_id) + { + self.state = AgentState::Idle { confirmation: None }; + } + vec![] + } + + (AgentState::Idle { .. }, Event::NewSession) => { + // Archive visible events so they remain on screen but aren't + // sent to the API. Tools persist for rendering. + let visible = self.ctx.events[self.ctx.view_start_index..].to_vec(); + self.ctx.archived_events.extend(visible); + + self.ctx.events.clear(); + self.ctx.session_id = None; + self.ctx.current_turn_tool_ids.clear(); + self.ctx.view_start_index = 0; + self.ctx.is_resumed = false; + + // Add OOB indicator for the new session + self.ctx.events.push(ConversationEvent::OutOfBandOutput { + name: "System".to_string(), + command: Some("/new".to_string()), + content: "Started a new session.".to_string(), + }); + + self.state = AgentState::Idle { confirmation: None }; + vec![Effect::ArchiveSession, Effect::Persist] + } + + (AgentState::Idle { .. }, Event::SlashCommand { command, content }) => { + self.handle_slash_command(&command, &content); + vec![] + } + + // ================================================================ + // Turn — stream lifecycle + // ================================================================ + ( + AgentState::Turn { + stream: StreamPhase::Connecting, + }, + Event::StreamStarted, + ) => { + self.state = AgentState::Turn { + stream: StreamPhase::Streaming { status: None }, + }; + vec![] + } + + ( + AgentState::Turn { + stream: StreamPhase::Connecting, + }, + Event::StreamError(e), + ) => { + self.state = AgentState::Error(e); + vec![] + } + + ( + AgentState::Turn { + stream: StreamPhase::Streaming { .. }, + }, + Event::StreamChunk(text), + ) => { + self.ctx.current_response.push_str(&text); + vec![] + } + + ( + AgentState::Turn { + stream: StreamPhase::Streaming { .. }, + }, + Event::StreamStatusChanged(status), + ) => { + self.state = AgentState::Turn { + stream: StreamPhase::Streaming { + status: Some(StreamingStatus::from_str(&status)), + }, + }; + vec![] + } + + (AgentState::Turn { .. }, Event::StreamToolCall { id, name, input }) => { + self.commit_streaming_text(); + self.handle_stream_tool_call(id, name, input) + } + + (AgentState::Turn { .. }, Event::SuggestCommand { id, input }) => { + self.commit_streaming_text(); + // Push the suggest_command as a ToolCall event (protocol requirement) + self.ctx.events.push(ConversationEvent::ToolCall { + id, + name: "suggest_command".to_string(), + input, + }); + self.state = AgentState::Idle { confirmation: None }; + vec![Effect::Persist] + } + + ( + AgentState::Turn { + stream: StreamPhase::Streaming { .. }, + }, + Event::StreamServerToolResult { + tool_use_id, + content, + is_error, + remote, + content_length, + }, + ) => { + self.ctx.events.push(ConversationEvent::ToolResult { + tool_use_id, + content, + is_error, + remote, + content_length, + }); + vec![] + } + + (AgentState::Turn { .. }, Event::StreamDone { session_id }) => { + self.commit_streaming_text(); + if !session_id.is_empty() { + self.ctx.session_id = Some(session_id); + } + self.state = AgentState::Turn { + stream: StreamPhase::Done, + }; + self.check_turn_completion() + } + + ( + AgentState::Turn { + stream: StreamPhase::Streaming { .. }, + }, + Event::StreamError(e), + ) => { + // Abort any executing tools on stream error + let abort_effects: Vec<_> = self + .ctx + .tools + .executing_ids() + .into_iter() + .map(|tool_id| Effect::AbortTool { tool_id }) + .collect(); + self.state = AgentState::Error(e); + abort_effects + } + + // ================================================================ + // Turn — tool lifecycle (any stream phase) + // ================================================================ + (AgentState::Turn { .. }, Event::PermissionResolved { tool_id, response }) => { + self.handle_permission_resolved(tool_id, response) + } + + (AgentState::Turn { .. }, Event::PermissionUserChoice { tool_id, choice }) => { + self.handle_permission_choice(tool_id, choice) + } + + ( + AgentState::Turn { .. }, + Event::ToolExecutionDone { + tool_id, + outcome, + preview, + }, + ) => self.handle_tool_done(tool_id, outcome, preview), + + ( + AgentState::Turn { .. }, + Event::ToolPreviewUpdate { + tool_id, + lines, + exit_code, + }, + ) => { + if let Some(tracked) = self.ctx.tools.get_mut(&tool_id) { + tracked.preview = Some(tools::ToolPreviewData::Shell { + lines, + exit_code, + interrupted: false, + }); + } + vec![] + } + + (AgentState::Turn { .. }, Event::InterruptTools) => { + let ids = self.ctx.tools.executing_ids(); + ids.into_iter() + .map(|tool_id| Effect::AbortTool { tool_id }) + .collect() + } + + // ─── Cancel during Turn ───────────────────────────────────── + (AgentState::Turn { stream }, Event::Cancel) => { + let mut effects = Vec::new(); + + // Abort stream if still active + if !matches!(stream, StreamPhase::Done) { + effects.push(Effect::AbortStream); + } + + // Cancel all pending tools + let pending = self.ctx.tools.pending_ids(); + for id in &pending { + if let Some(tracked) = self.ctx.tools.get_mut(id) { + if tracked.state == ToolState::Executing { + effects.push(Effect::AbortTool { + tool_id: id.clone(), + }); + } + tracked.state = ToolState::Completed; + } + self.ctx.events.push(ConversationEvent::ToolResult { + tool_use_id: id.clone(), + content: "Error: user cancelled this operation".to_string(), + is_error: true, + remote: false, + content_length: None, + }); + } + + // Commit any partial streaming text + self.commit_streaming_text_as_cancelled(); + + // Add context so the LLM knows what happened + if !pending.is_empty() { + self.ctx.events.push(ConversationEvent::SystemContext { + content: "The user cancelled the previous generation. Tool calls that were in progress have been aborted.".to_string(), + }); + } + + self.state = AgentState::Idle { confirmation: None }; + effects.push(Effect::Persist); + effects + } + + // ================================================================ + // Error state + // ================================================================ + (AgentState::Error(_), Event::Retry) => { + let messages = self.build_messages(); + let session_id = self.ctx.session_id.clone(); + self.state = AgentState::Turn { + stream: StreamPhase::Connecting, + }; + vec![Effect::StartStream { + messages, + session_id, + }] + } + + (AgentState::Error(_), Event::Cancel) => { + vec![Effect::ExitApp(ExitAction::Cancel)] + } + + // ================================================================ + // Fallthrough — ignore events with no valid transition + // ================================================================ + + // StreamDone can arrive after SuggestCommand (which already moved to Idle). + // We still need to capture the session_id from it. + (_, Event::StreamDone { session_id }) => { + if !session_id.is_empty() { + self.ctx.session_id = Some(session_id); + } + vec![Effect::Persist] + } + + (_, Event::SlashCommand { command, content }) => { + self.handle_slash_command(&command, &content); + vec![] + } + + _ => vec![], + } + } + + // ──────────────────────────────────────────────────────────────────── + // Private helpers + // ──────────────────────────────────────────────────────────────────── + + /// Start a new turn: push user message, build messages, emit StartStream. + fn start_turn(&mut self, msg: String) -> Vec { + self.ctx + .events + .push(ConversationEvent::UserMessage { content: msg }); + // Don't clear tools — completed tools persist for rendering history. + // Tools are only cleared on /new (session reset). + self.ctx.current_response.clear(); + self.ctx.current_turn_tool_ids.clear(); + + let messages = self.build_messages(); + let session_id = self.ctx.session_id.clone(); + self.state = AgentState::Turn { + stream: StreamPhase::Connecting, + }; + vec![Effect::StartStream { + messages, + session_id, + }] + } + + /// Build API messages from the conversation event log. + fn build_messages(&self) -> Vec { + ContextWindowBuilder::with_default_budget().build(&self.ctx.events) + } + + /// Commit accumulated streaming text to the event log. + fn commit_streaming_text(&mut self) { + let text = std::mem::take(&mut self.ctx.current_response); + let trimmed = text.trim_start().to_string(); + if !trimmed.is_empty() { + self.ctx + .events + .push(ConversationEvent::Text { content: trimmed }); + } + } + + /// Commit streaming text with a cancellation suffix. + fn commit_streaming_text_as_cancelled(&mut self) { + let text = std::mem::take(&mut self.ctx.current_response); + let trimmed = text.trim_start().to_string(); + if !trimmed.is_empty() { + self.ctx.events.push(ConversationEvent::Text { + content: format!("{trimmed}\n\n[User cancelled this generation]"), + }); + } + } + + /// Handle a client-side tool call from the stream. + fn handle_stream_tool_call(&mut self, id: String, name: String, input: Value) -> Vec { + // Parse the tool call + let tool = match crate::tools::ClientToolCall::try_from((name.as_str(), &input)) { + Ok(tool) => tool, + Err(_) => { + // Unknown tool — push as event but don't track + self.ctx + .events + .push(ConversationEvent::ToolCall { id, name, input }); + return vec![]; + } + }; + + // Capability gating + if let Some(required_cap) = tool.descriptor().capability + && !self.ctx.capabilities.iter().any(|c| c == required_cap) + { + self.ctx.events.push(ConversationEvent::ToolCall { + id: id.clone(), + name, + input, + }); + self.ctx.events.push(ConversationEvent::ToolResult { + tool_use_id: id, + content: format!( + "Tool not enabled: capability '{required_cap}' was not advertised by this client" + ), + is_error: true, + remote: false, + content_length: None, + }); + return vec![]; + } + + // Track the tool and push ToolCall event + let tool_for_effect = tool.clone(); + self.ctx.tools.insert(id.clone(), tool); + self.ctx.current_turn_tool_ids.push(id.clone()); + self.ctx.events.push(ConversationEvent::ToolCall { + id: id.clone(), + name, + input, + }); + + // Transition to Turn if we were Streaming + if let AgentState::Turn { + stream: StreamPhase::Streaming { .. }, + } = &self.state + { + self.state = AgentState::Turn { + stream: StreamPhase::Streaming { status: None }, + }; + } + + vec![Effect::CheckPermission { + tool_id: id, + tool: tool_for_effect, + }] + } + + /// Handle permission resolver result. + fn handle_permission_resolved( + &mut self, + tool_id: String, + response: PermissionResponse, + ) -> Vec { + let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { + return vec![]; + }; + + // If already resolved (e.g. cancelled while permission check was in flight), + // ignore the stale result to avoid re-executing a cancelled tool. + if tracked.is_resolved() { + return vec![]; + } + + match response { + PermissionResponse::Allowed | PermissionResponse::SessionGranted => { + tracked.state = ToolState::Executing; + let tool = tracked.tool.clone(); + vec![Effect::ExecuteTool { tool_id, tool }] + } + PermissionResponse::Ask => { + tracked.state = ToolState::AwaitingPermission; + vec![] + } + PermissionResponse::Denied => { + tracked.state = ToolState::Denied; + self.ctx.events.push(ConversationEvent::ToolResult { + tool_use_id: tool_id, + content: "Permission denied on the user's system".to_string(), + is_error: true, + remote: false, + content_length: None, + }); + self.check_turn_completion() + } + } + } + + /// Handle user's permission choice from the dialog. + fn handle_permission_choice( + &mut self, + tool_id: String, + choice: PermissionChoice, + ) -> Vec { + let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { + return vec![]; + }; + + if tracked.is_resolved() { + return vec![]; + } + + match choice { + PermissionChoice::Allow => { + tracked.state = ToolState::Executing; + let tool = tracked.tool.clone(); + vec![Effect::ExecuteTool { tool_id, tool }] + } + PermissionChoice::AllowForSession => { + tracked.state = ToolState::Executing; + let tool = tracked.tool.clone(); + let mut effects = vec![Effect::ExecuteTool { + tool_id, + tool: tool.clone(), + }]; + if let Some(path) = tool.resolved_file_path() { + effects.push(Effect::CacheSessionGrant { path }); + } + effects + } + PermissionChoice::AlwaysAllowInProject => { + tracked.state = ToolState::Executing; + let tool = tracked.tool.clone(); + let rule = crate::permissions::rule::Rule { + tool: tool.rule_name().to_string(), + scope: None, // project file provides the scoping + }; + vec![ + Effect::ExecuteTool { tool_id, tool }, + Effect::WritePermissionRule { + target: PermissionTarget::Project, + rule, + disposition: crate::permissions::writer::RuleDisposition::Allow, + }, + ] + } + PermissionChoice::AlwaysAllow => { + tracked.state = ToolState::Executing; + let tool = tracked.tool.clone(); + let scope = tool + .resolved_file_path() + .map(|p| p.to_string_lossy().to_string()); + let rule = crate::permissions::rule::Rule { + tool: tool.rule_name().to_string(), + scope, + }; + vec![ + Effect::ExecuteTool { tool_id, tool }, + Effect::WritePermissionRule { + target: PermissionTarget::Global, + rule, + disposition: crate::permissions::writer::RuleDisposition::Allow, + }, + ] + } + PermissionChoice::Deny => { + tracked.state = ToolState::Denied; + self.ctx.events.push(ConversationEvent::ToolResult { + tool_use_id: tool_id, + content: "Permission denied by the user".to_string(), + is_error: true, + remote: false, + content_length: None, + }); + self.check_turn_completion() + } + } + } + + /// Handle tool execution completion. + fn handle_tool_done( + &mut self, + tool_id: String, + outcome: crate::tools::ToolOutcome, + preview: Option, + ) -> Vec { + let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { + return vec![]; + }; + + // If already completed (e.g. cancelled), ignore stale result + if tracked.is_resolved() { + return vec![]; + } + + tracked.state = ToolState::Completed; + if preview.is_some() { + tracked.preview = preview; + } + + let content = outcome.format_for_llm(); + let is_error = outcome.is_error(); + self.ctx.events.push(ConversationEvent::ToolResult { + tool_use_id: tool_id, + content, + is_error, + remote: false, + content_length: None, + }); + + self.check_turn_completion() + } + + /// Check if the turn is complete (stream done + all tools resolved). + /// If so, either continue the conversation or go Idle. + fn check_turn_completion(&mut self) -> Vec { + // Stream must be done + if !matches!( + self.state, + AgentState::Turn { + stream: StreamPhase::Done + } + ) { + return vec![]; + } + + // All current-turn tools must be resolved before the turn can complete + if !self.ctx.tools.all_resolved(&self.ctx.current_turn_tool_ids) { + return vec![]; + } + + // Turn is complete. Check if we need to continue (tool results to send back). + // We continue if this turn had any client tool calls (the LLM needs to see + // the results and respond). + if !self.ctx.current_turn_tool_ids.is_empty() { + // Continue conversation with tool results. + // Don't clear tools — they persist for rendering history. + // Clear turn IDs so the continuation turn doesn't loop. + self.ctx.current_turn_tool_ids.clear(); + let messages = self.build_messages(); + let session_id = self.ctx.session_id.clone(); + self.ctx.current_response.clear(); + self.state = AgentState::Turn { + stream: StreamPhase::Connecting, + }; + vec![Effect::StartStream { + messages, + session_id, + }] + } else { + // No tools — turn is done, go idle + self.state = AgentState::Idle { confirmation: None }; + vec![Effect::Persist] + } + } + + /// Extract the current confirmation state (if any). + fn state_confirmation(&self) -> Option<&PendingConfirmation> { + if let AgentState::Idle { + confirmation: Some(ref c), + } = self.state + { + Some(c) + } else { + None + } + } + + /// Get the most recent suggested command from the conversation. + /// Get the most recent command from the current invocation only. + fn current_command(&self) -> Option { + self.current_invocation_events() + .rev() + .find_map(|e| e.as_command()) + .map(|s| s.to_string()) + } + + /// Check if the most recent command is dangerous. + fn is_current_command_dangerous(&self) -> bool { + self.current_invocation_events() + .rev() + .find_map(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e + && name == "suggest_command" + { + let danger = input + .get("danger") + .and_then(|v| v.as_str()) + .unwrap_or("low"); + Some(danger == "high" || danger == "medium" || danger == "med") + } else { + None + } + }) + .unwrap_or(false) + } + + /// Events from the current invocation only (from view_start_index onward). + fn current_invocation_events(&self) -> impl DoubleEndedIterator { + let start = self.ctx.view_start_index.min(self.ctx.events.len()); + self.ctx.events[start..].iter() + } + + /// Handle a slash command by pushing an OOB event. + fn handle_slash_command(&mut self, command: &str, content: &str) { + self.ctx.events.push(ConversationEvent::OutOfBandOutput { + name: "System".to_string(), + command: Some(command.to_string()), + content: content.to_string(), + }); + } +} diff --git a/crates/atuin-ai/src/fsm/tests.rs b/crates/atuin-ai/src/fsm/tests.rs new file mode 100644 index 00000000..9fc404c0 --- /dev/null +++ b/crates/atuin-ai/src/fsm/tests.rs @@ -0,0 +1,541 @@ +//! Pure FSM transition tests. No IO, no async. + +use serde_json::json; + +use super::*; +use effects::{Effect, ExitAction}; +use events::{Event, PermissionChoice, PermissionResponse}; + +fn new_fsm() -> AgentFsm { + AgentFsm::new( + vec!["client_v1_read_file".to_string()], + "test-inv".to_string(), + ) +} + +// ============================================================================ +// Idle → Turn +// ============================================================================ + +#[test] +fn user_submit_starts_turn() { + let mut fsm = new_fsm(); + + let effects = fsm.handle(Event::UserSubmit("hello".into())); + + assert!(matches!( + fsm.state, + AgentState::Turn { + stream: StreamPhase::Connecting + } + )); + assert_eq!(effects.len(), 1); + assert!(matches!(effects[0], Effect::StartStream { .. })); + // User message was pushed to events + assert!(fsm.ctx.events.iter().any(|e| matches!( + e, + ConversationEvent::UserMessage { content } if content == "hello" + ))); +} + +#[test] +fn stream_started_transitions_to_streaming() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("hello".into())); + + let effects = fsm.handle(Event::StreamStarted); + + assert!(matches!( + fsm.state, + AgentState::Turn { + stream: StreamPhase::Streaming { status: None } + } + )); + assert!(effects.is_empty()); +} + +#[test] +fn stream_chunk_accumulates_text() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("hello".into())); + fsm.handle(Event::StreamStarted); + + fsm.handle(Event::StreamChunk("Hello ".into())); + fsm.handle(Event::StreamChunk("world!".into())); + + assert_eq!(fsm.ctx.current_response, "Hello world!"); +} + +#[test] +fn stream_done_without_tools_goes_idle() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("hello".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamChunk("Hi there!".into())); + + let effects = fsm.handle(Event::StreamDone { + session_id: "s1".into(), + }); + + assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); + assert_eq!(fsm.ctx.session_id, Some("s1".to_string())); + assert!(effects.iter().any(|e| matches!(e, Effect::Persist))); + // Text was committed to events + assert!(fsm.ctx.events.iter().any(|e| matches!( + e, + ConversationEvent::Text { content } if content == "Hi there!" + ))); +} + +// ============================================================================ +// Tool lifecycle +// ============================================================================ + +#[test] +fn stream_tool_call_tracks_tool_and_emits_check_permission() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("read a file".into())); + fsm.handle(Event::StreamStarted); + + let effects = fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + + assert!(fsm.ctx.tools.get("t1").is_some()); + assert_eq!(effects.len(), 1); + assert!(matches!(effects[0], Effect::CheckPermission { .. })); +} + +#[test] +fn permission_allowed_transitions_to_executing() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("read".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + + let effects = fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + + assert_eq!(fsm.ctx.tools.get("t1").unwrap().state, ToolState::Executing); + assert!(matches!(effects[0], Effect::ExecuteTool { .. })); +} + +#[test] +fn permission_ask_transitions_to_awaiting() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("read".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + + let effects = fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Ask, + }); + + assert_eq!( + fsm.ctx.tools.get("t1").unwrap().state, + ToolState::AwaitingPermission + ); + assert!(effects.is_empty()); +} + +#[test] +fn tool_done_after_stream_done_continues_conversation() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("read".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + fsm.handle(Event::StreamDone { + session_id: "".into(), + }); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + + // Now in Turn { Done } with one tool Executing + let effects = fsm.handle(Event::ToolExecutionDone { + tool_id: "t1".into(), + outcome: crate::tools::ToolOutcome::Success("file contents".into()), + preview: None, + }); + + // Turn complete → continuation + assert!(matches!( + fsm.state, + AgentState::Turn { + stream: StreamPhase::Connecting + } + )); + assert!( + effects + .iter() + .any(|e| matches!(e, Effect::StartStream { .. })) + ); +} + +#[test] +fn continuation_turn_without_new_tools_goes_idle() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("read".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + fsm.handle(Event::StreamDone { + session_id: "".into(), + }); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + // Tool completes → continuation starts + fsm.handle(Event::ToolExecutionDone { + tool_id: "t1".into(), + outcome: crate::tools::ToolOutcome::Success("contents".into()), + preview: None, + }); + assert!(matches!( + fsm.state, + AgentState::Turn { + stream: StreamPhase::Connecting + } + )); + + // Continuation stream: text only, no new tools + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamChunk("Here's the file.".into())); + let effects = fsm.handle(Event::StreamDone { + session_id: "".into(), + }); + + // Should go Idle, NOT start another continuation + assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); + assert!(effects.iter().any(|e| matches!(e, Effect::Persist))); + assert!( + !effects + .iter() + .any(|e| matches!(e, Effect::StartStream { .. })) + ); +} + +#[test] +fn tool_done_before_stream_done_stays_in_turn() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("read".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + + // Tool completes but stream hasn't sent Done yet + let effects = fsm.handle(Event::ToolExecutionDone { + tool_id: "t1".into(), + outcome: crate::tools::ToolOutcome::Success("contents".into()), + preview: None, + }); + + // Still in Turn — stream phase is Streaming, not Done + assert!(matches!( + fsm.state, + AgentState::Turn { + stream: StreamPhase::Streaming { .. } + } + )); + assert!(effects.is_empty()); +} + +// ============================================================================ +// Cancel +// ============================================================================ + +#[test] +fn cancel_during_streaming_goes_idle() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("hello".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamChunk("partial text".into())); + + let effects = fsm.handle(Event::Cancel); + + assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); + assert!(effects.iter().any(|e| matches!(e, Effect::AbortStream))); + assert!(effects.iter().any(|e| matches!(e, Effect::Persist))); + // Partial text committed with cancel suffix + assert!(fsm.ctx.events.iter().any(|e| matches!( + e, + ConversationEvent::Text { content } if content.contains("[User cancelled") + ))); +} + +#[test] +fn stale_permission_resolved_after_cancel_is_ignored() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("read".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + fsm.handle(Event::StreamDone { + session_id: "".into(), + }); + // Tool is in CheckingPermission, cancel happens before permission resolves + fsm.handle(Event::Cancel); + assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); + + // Stale permission result arrives — tool is already Completed (cancelled) + let effects = fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + + // Should NOT emit ExecuteTool — the tool was cancelled + assert!(effects.is_empty()); +} + +#[test] +fn cancel_during_turn_with_pending_tools() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("hello".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + fsm.handle(Event::StreamDone { + session_id: "".into(), + }); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + // Tool is Executing, stream is Done + + let effects = fsm.handle(Event::Cancel); + + assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); + assert!( + effects + .iter() + .any(|e| matches!(e, Effect::AbortTool { .. })) + ); + // Error ToolResult injected + assert!(fsm.ctx.events.iter().any(|e| matches!( + e, + ConversationEvent::ToolResult { tool_use_id, is_error: true, .. } if tool_use_id == "t1" + ))); + // SystemContext about cancellation + assert!(fsm.ctx.events.iter().any(|e| matches!( + e, + ConversationEvent::SystemContext { content } if content.contains("cancelled") + ))); +} + +#[test] +fn stale_tool_result_after_cancel_is_ignored() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("hello".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + fsm.handle(Event::StreamDone { + session_id: "".into(), + }); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + fsm.handle(Event::Cancel); + + // Stale event arrives + let effects = fsm.handle(Event::ToolExecutionDone { + tool_id: "t1".into(), + outcome: crate::tools::ToolOutcome::Success("contents".into()), + preview: None, + }); + + assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); + assert!(effects.is_empty()); +} + +// ============================================================================ +// Confirmation +// ============================================================================ + +#[test] +fn dangerous_command_enters_confirmation() { + let mut fsm = new_fsm(); + // Simulate a dangerous command in history + fsm.ctx.events.push(ConversationEvent::ToolCall { + id: "sc1".into(), + name: "suggest_command".into(), + input: json!({"command": "rm -rf /", "description": "bad", "confidence": "high", "danger": "high"}), + }); + + let effects = fsm.handle(Event::ExecuteCommand); + + assert!(matches!( + fsm.state, + AgentState::Idle { + confirmation: Some(_) + } + )); + assert!( + effects + .iter() + .any(|e| matches!(e, Effect::ScheduleTimeout { .. })) + ); +} + +#[test] +fn second_execute_confirms_and_exits() { + let mut fsm = new_fsm(); + fsm.ctx.events.push(ConversationEvent::ToolCall { + id: "sc1".into(), + name: "suggest_command".into(), + input: json!({"command": "rm -rf /", "description": "bad", "confidence": "high", "danger": "high"}), + }); + fsm.handle(Event::ExecuteCommand); + + let effects = fsm.handle(Event::ExecuteCommand); + + assert!(effects.iter().any(|e| matches!( + e, + Effect::ExitApp(ExitAction::Execute(cmd)) if cmd == "rm -rf /" + ))); +} + +#[test] +fn confirmation_timeout_clears_confirmation() { + let mut fsm = new_fsm(); + fsm.ctx.events.push(ConversationEvent::ToolCall { + id: "sc1".into(), + name: "suggest_command".into(), + input: json!({"command": "rm -rf /", "description": "bad", "confidence": "high", "danger": "high"}), + }); + fsm.handle(Event::ExecuteCommand); + let timeout_id = match &fsm.state { + AgentState::Idle { + confirmation: Some(c), + } => c.timeout_id, + _ => panic!("expected confirmation"), + }; + + fsm.handle(Event::ConfirmationTimeout { timeout_id }); + + assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); +} + +// ============================================================================ +// Error / Retry +// ============================================================================ + +#[test] +fn stream_error_goes_to_error_state() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("hello".into())); + fsm.handle(Event::StreamStarted); + + fsm.handle(Event::StreamError("network error".into())); + + assert_eq!(fsm.state, AgentState::Error("network error".to_string())); +} + +#[test] +fn retry_from_error_starts_new_stream() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("hello".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamError("fail".into())); + + let effects = fsm.handle(Event::Retry); + + assert!(matches!( + fsm.state, + AgentState::Turn { + stream: StreamPhase::Connecting + } + )); + assert!( + effects + .iter() + .any(|e| matches!(e, Effect::StartStream { .. })) + ); +} + +// ============================================================================ +// Permission choices +// ============================================================================ + +#[test] +fn permission_deny_completes_turn_and_continues() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("read".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + fsm.handle(Event::StreamDone { + session_id: "".into(), + }); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Ask, + }); + + let effects = fsm.handle(Event::PermissionUserChoice { + tool_id: "t1".into(), + choice: PermissionChoice::Deny, + }); + + // Turn should complete since all tools resolved and stream is done + // → continuation needed (there was a tool result to send back) + assert!(matches!( + fsm.state, + AgentState::Turn { + stream: StreamPhase::Connecting + } + )); + assert!( + effects + .iter() + .any(|e| matches!(e, Effect::StartStream { .. })) + ); + // Error result was injected + assert!(fsm.ctx.events.iter().any(|e| matches!( + e, + ConversationEvent::ToolResult { tool_use_id, is_error: true, .. } if tool_use_id == "t1" + ))); +} diff --git a/crates/atuin-ai/src/fsm/tools.rs b/crates/atuin-ai/src/fsm/tools.rs new file mode 100644 index 00000000..a6b2e9ae --- /dev/null +++ b/crates/atuin-ai/src/fsm/tools.rs @@ -0,0 +1,165 @@ +//! Tool lifecycle management within the FSM. +//! +//! Each tool call goes through an independent lifecycle. The ToolManager +//! tracks all tools in the current turn and provides the "all resolved" +//! check that gates turn completion. + +use crate::diff::{EditPreview, WritePreview}; +use crate::tools::ClientToolCall; + +/// Per-tool lifecycle state. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ToolState { + /// Permission resolver is running asynchronously. + CheckingPermission, + /// Waiting for user to grant/deny via the permission dialog. + AwaitingPermission, + /// Actively executing. + Executing, + /// Execution completed (result injected into conversation). + Completed, + /// User denied permission (error result injected into conversation). + Denied, +} + +/// Cached preview data for rendering tool output. +#[derive(Debug, Clone)] +pub(crate) enum ToolPreviewData { + /// Shell command VT100 output lines. + Shell { + lines: Vec, + exit_code: Option, + interrupted: bool, + }, + /// File edit diff preview. + Edit(EditPreview), + /// File write content preview. + Write(WritePreview), +} + +/// A tracked tool call with its current lifecycle state. +#[derive(Debug, Clone)] +pub(crate) struct TrackedTool { + pub id: String, + pub tool: ClientToolCall, + pub state: ToolState, + /// Cached preview data for rendering (populated during/after execution). + pub preview: Option, +} + +impl TrackedTool { + /// Whether this tool has reached a terminal state. + pub fn is_resolved(&self) -> bool { + matches!(self.state, ToolState::Completed | ToolState::Denied) + } + + /// Extract shell preview data (for TurnBuilder compatibility). + pub fn shell_preview(&self) -> Option { + match &self.preview { + Some(ToolPreviewData::Shell { + lines, + exit_code, + interrupted, + }) => Some(crate::tools::ToolPreview { + lines: lines.clone(), + exit_code: *exit_code, + interrupted: *interrupted, + }), + _ => None, + } + } + + /// Extract edit diff preview (for TurnBuilder compatibility). + pub fn edit_preview(&self) -> Option<&EditPreview> { + match &self.preview { + Some(ToolPreviewData::Edit(p)) => Some(p), + _ => None, + } + } + + /// Extract write content preview (for TurnBuilder compatibility). + pub fn write_preview(&self) -> Option<&WritePreview> { + match &self.preview { + Some(ToolPreviewData::Write(p)) => Some(p), + _ => None, + } + } +} + +/// Manages tool call lifecycles for a single turn. +/// +/// Tools are inserted when received from the stream and progress through +/// their lifecycle independently. The manager provides aggregate queries +/// (all resolved, any awaiting permission, etc.) that the FSM uses for +/// state transitions. +#[derive(Debug, Clone, Default)] +pub(crate) struct ToolManager { + tools: Vec, +} + +impl ToolManager { + pub fn new() -> Self { + Self { tools: Vec::new() } + } + + /// Insert a new tool in CheckingPermission state. + pub fn insert(&mut self, id: String, tool: ClientToolCall) { + self.tools.push(TrackedTool { + id, + tool, + state: ToolState::CheckingPermission, + preview: None, + }); + } + + /// Look up a tool by ID. + pub fn get(&self, id: &str) -> Option<&TrackedTool> { + self.tools.iter().find(|t| t.id == id) + } + + /// Look up a tool mutably by ID. + pub fn get_mut(&mut self, id: &str) -> Option<&mut TrackedTool> { + self.tools.iter_mut().find(|t| t.id == id) + } + + /// True if all tools from the given set of IDs have reached a terminal state. + /// Returns true for an empty set (vacuously — no tools to wait for). + pub fn all_resolved(&self, tool_ids: &[String]) -> bool { + tool_ids + .iter() + .all(|id| self.get(id).is_some_and(|t| t.is_resolved())) + } + + /// Find the first tool awaiting user permission. + pub fn awaiting_permission(&self) -> Option<&TrackedTool> { + self.tools + .iter() + .find(|t| t.state == ToolState::AwaitingPermission) + } + + /// Get IDs of all non-resolved tools (for cancel). + pub fn pending_ids(&self) -> Vec { + self.tools + .iter() + .filter(|t| !t.is_resolved()) + .map(|t| t.id.clone()) + .collect() + } + + /// Get IDs of all currently executing tools (for interrupt/abort). + pub fn executing_ids(&self) -> Vec { + self.tools + .iter() + .filter(|t| t.state == ToolState::Executing) + .map(|t| t.id.clone()) + .collect() + } + + /// True if any tool has a shell preview with live output. + pub fn has_executing_preview(&self) -> bool { + self.tools.iter().any(|t| { + t.state == ToolState::Executing + && matches!(t.preview, Some(ToolPreviewData::Shell { .. })) + }) + } +} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index afe9c1e4..540aece3 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -2,9 +2,11 @@ pub mod commands; pub(crate) mod context; pub(crate) mod context_window; pub(crate) mod diff; +pub(crate) mod driver; pub(crate) mod edit_permissions; pub(crate) mod event_serde; pub(crate) mod file_tracker; +pub(crate) mod fsm; pub(crate) mod permissions; pub(crate) mod session; pub(crate) mod snapshots; diff --git a/crates/atuin-ai/src/permissions/writer.rs b/crates/atuin-ai/src/permissions/writer.rs index b2bd9482..ffef404e 100644 --- a/crates/atuin-ai/src/permissions/writer.rs +++ b/crates/atuin-ai/src/permissions/writer.rs @@ -5,6 +5,7 @@ use eyre::Result; use crate::permissions::rule::Rule; /// Whether a rule should be added to the allow or deny list. +#[derive(Debug, Clone)] #[allow(dead_code)] pub(crate) enum RuleDisposition { Allow, diff --git a/crates/atuin-ai/src/snapshots.rs b/crates/atuin-ai/src/snapshots.rs index 6c7b0c9c..d46223a8 100644 --- a/crates/atuin-ai/src/snapshots.rs +++ b/crates/atuin-ai/src/snapshots.rs @@ -94,7 +94,7 @@ impl SnapshotStore { } /// Whether a file has already been snapshotted in this session. - #[expect(dead_code)] + #[cfg(test)] pub fn has_snapshot(&self, canonical_path: &Path) -> bool { let filename = sanitize_path(canonical_path); self.manifest.files.contains_key(&filename) diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index 24770abe..19d287e7 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -2,23 +2,16 @@ // SSE streaming // ─────────────────────────────────────────────────────────────────── -use std::sync::mpsc; - use atuin_client::settings::AiCapabilities; use atuin_common::tls::ensure_crypto_provider; use eventsource_stream::Eventsource; -use eye_declare::Handle; use eyre::{Context, Result}; use futures::StreamExt; use reqwest::Url; use reqwest::header::USER_AGENT; -use crate::{ - context::{AppContext, ClientContext}, - tools::ClientToolCall, - tui::{Session, events::AiTuiEvent}, -}; +use crate::context::ClientContext; static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION")); @@ -100,7 +93,7 @@ impl ChatRequest { } } -fn create_chat_stream( +pub(crate) fn create_chat_stream( hub_address: String, token: String, request: ChatRequest, @@ -244,149 +237,6 @@ fn create_chat_stream( }) } -// ─────────────────────────────────────────────────────────────────── -// Async streaming task — pushes updates to app state via Handle -// ─────────────────────────────────────────────────────────────────── - -pub(crate) async fn run_chat_stream( - handle: Handle, - tx: mpsc::Sender, - app_ctx: AppContext, - client_ctx: ClientContext, - request: ChatRequest, -) { - let capabilities = request.capabilities.clone(); - let stream = create_chat_stream( - app_ctx.endpoint.clone(), - app_ctx.token.clone(), - request, - client_ctx, - app_ctx.send_cwd, - app_ctx.last_command.clone(), - ); - futures::pin_mut!(stream); - - while let Some(event) = stream.next().await { - match event { - Ok(StreamFrame::Content(content)) => { - apply_content_frame(&handle, &tx, &capabilities, content); - } - Ok(StreamFrame::Control(control)) => { - let terminal = apply_control_frame(&handle, control); - if terminal { - break; - } - } - Err(e) => { - let msg = e.to_string(); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - } - } -} - -/// Apply a content frame to session state. -/// Control flow: always continues the stream. -fn apply_content_frame( - handle: &Handle, - tx: &mpsc::Sender, - capabilities: &[String], - content: StreamContent, -) { - match content { - StreamContent::TextChunk(text) => { - handle.update(move |state| { - state.conversation.append_streaming_text(&text); - }); - } - StreamContent::ToolCall { id, name, input } => { - if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) { - // Enforce capability gating: reject tool calls the client didn't advertise. - if let Some(required_cap) = tool.descriptor().capability - && !capabilities.iter().any(|c| c == required_cap) - { - tracing::warn!( - tool = name, - capability = required_cap, - "Rejecting tool call: capability not advertised" - ); - handle.update(move |state| { - state.add_tool_call(id.clone(), name, input.clone()); - state.conversation.add_tool_result( - id, - format!("Tool not enabled: capability '{required_cap}' was not advertised by this client"), - true, - false, - None, - ); - }); - return; - } - - // Client-side tool — add to tracker and conversation, queue permission check - let id_for_event = id.clone(); - handle.update(move |state| { - state.handle_client_tool_call(id_for_event, tool, input); - }); - let _ = tx.send(AiTuiEvent::CheckToolCallPermission(id)); - } else { - // Server-side tool — just add to conversation events - handle.update(move |state| { - state.add_tool_call(id, name, input); - }); - } - } - StreamContent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - } => { - handle.update(move |state| { - state.conversation.add_tool_result( - tool_use_id, - content, - is_error, - remote, - content_length, - ); - }); - } - } -} - -/// Apply a control frame to session state. -/// Returns true if the stream should terminate. -fn apply_control_frame(handle: &Handle, control: StreamControl) -> bool { - match control { - StreamControl::StatusChanged(status) => { - handle.update(move |state| { - state.update_streaming_status(&status); - }); - false - } - StreamControl::Done { session_id } => { - handle.update(move |state| { - if !session_id.is_empty() { - state.conversation.store_session_id(session_id); - } - state.finalize_streaming(); - }); - true - } - StreamControl::Error(msg) => { - handle.update(move |state| { - state.streaming_error(msg); - }); - true - } - } -} - fn hub_url(base: &str, path: &str) -> Result { let base_with_slash = if base.ends_with('/') { base.to_string() diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 890ea734..530f0e83 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -58,6 +58,7 @@ fn path_matches_scope(path: &Path, scope: &str) -> bool { } /// Result of executing a client-side tool. +#[derive(Debug, Clone)] pub(crate) enum ToolOutcome { /// Simple success with a text result (used by Read, AtuinHistory). Success(String), @@ -136,176 +137,6 @@ pub(crate) struct ToolPreview { pub interrupted: bool, } -/// Lifecycle phase of a tracked tool call. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum ToolPhase { - CheckingPermissions, - AskingForPermission, - #[expect(dead_code)] - Denied(String), - #[expect(dead_code)] - Executing, - /// Shell command is executing with live preview output. - ExecutingWithPreview { - command: String, - /// Current VT100 screen lines (plain text, viewport-sized). - output_lines: Vec, - /// Exit code once the process completes. - exit_code: Option, - /// Whether the command was interrupted by the user. - interrupted: bool, - }, - /// Tool execution has completed. Preview is cached for rendering history. - Completed { - preview: Option, - }, -} - -/// A tracked tool call through its full lifecycle. -#[derive(Debug)] -pub(crate) struct TrackedTool { - pub id: String, - pub tool: ClientToolCall, - pub phase: ToolPhase, - /// Sender to interrupt a running shell command (only set during ExecutingWithPreview). - pub abort_tx: Option>, - /// Diff preview for completed edit tool calls. - pub edit_preview: Option, - /// Content preview for completed write tool calls. - pub write_preview: Option, -} - -impl TrackedTool { - pub(crate) fn target_dir(&self) -> Option<&Path> { - self.tool.target_dir() - } - - pub fn mark_asking(&mut self) { - self.phase = ToolPhase::AskingForPermission; - } - - pub fn mark_executing_preview(&mut self, command: String) { - self.phase = ToolPhase::ExecutingWithPreview { - command, - output_lines: Vec::new(), - exit_code: None, - interrupted: false, - }; - } - - pub fn complete(&mut self, preview: Option) { - self.phase = ToolPhase::Completed { preview }; - self.abort_tx = None; - } - - /// Extract the current preview, whether live or completed. - pub fn preview(&self) -> Option { - match &self.phase { - ToolPhase::ExecutingWithPreview { - output_lines, - exit_code, - interrupted, - .. - } => Some(ToolPreview { - lines: output_lines.clone(), - exit_code: *exit_code, - interrupted: *interrupted, - }), - ToolPhase::Completed { preview } => preview.clone(), - _ => None, - } - } -} - -/// Tracks all tool calls through their full lifecycle. -/// -/// Single source of truth for tool execution state. Entries persist after -/// completion so cached previews remain available for rendering history. -#[derive(Debug)] -pub(crate) struct ToolTracker { - tools: Vec, -} - -impl ToolTracker { - pub fn new() -> Self { - Self { tools: Vec::new() } - } - - /// Insert a new tool call in CheckingPermissions phase. - pub fn insert(&mut self, id: String, tool: ClientToolCall) { - self.tools.push(TrackedTool { - id, - tool, - phase: ToolPhase::CheckingPermissions, - abort_tx: None, - edit_preview: None, - write_preview: None, - }); - } - - pub fn get(&self, id: &str) -> Option<&TrackedTool> { - self.tools.iter().find(|t| t.id == id) - } - - pub fn get_mut(&mut self, id: &str) -> Option<&mut TrackedTool> { - self.tools.iter_mut().find(|t| t.id == id) - } - - /// Remove a tool by ID and return it. - #[expect(dead_code)] - pub fn remove(&mut self, id: &str) -> Option { - let pos = self.tools.iter().position(|t| t.id == id)?; - Some(self.tools.remove(pos)) - } - - /// True if any tool is still awaiting a permission decision. - #[expect(dead_code)] - pub fn has_unresolved(&self) -> bool { - self.tools.iter().any(|t| { - matches!( - t.phase, - ToolPhase::CheckingPermissions | ToolPhase::AskingForPermission - ) - }) - } - - /// True if any tool has not yet reached the Completed phase. - /// Use this to gate `ContinueAfterTools` — we must wait for all tools - /// (including those still executing) before resuming the conversation. - pub fn has_pending(&self) -> bool { - self.tools - .iter() - .any(|t| !matches!(t.phase, ToolPhase::Completed { .. })) - } - - /// True if any tool is currently executing with a preview. - pub fn has_executing_preview(&self) -> bool { - self.tools - .iter() - .any(|t| matches!(t.phase, ToolPhase::ExecutingWithPreview { .. })) - } - - /// Find the first tool that is asking for permission. - pub fn asking_for_permission(&self) -> Option<&TrackedTool> { - self.tools - .iter() - .find(|t| t.phase == ToolPhase::AskingForPermission) - } - - /// Find the first tool that is asking for permission (mutable). - #[expect(dead_code)] - pub fn asking_for_permission_mut(&mut self) -> Option<&mut TrackedTool> { - self.tools - .iter_mut() - .find(|t| t.phase == ToolPhase::AskingForPermission) - } - - /// Iterate mutably over all tracked tools. - pub fn iter_mut(&mut self) -> impl Iterator { - self.tools.iter_mut() - } -} - /// A tool call from the server, with parsed input parameters. #[derive(Debug, Clone)] pub(crate) enum ClientToolCall { @@ -359,6 +190,17 @@ impl ClientToolCall { } } + /// The resolved file path for this tool call, if it's a file-based tool. + /// Used to build scoped permission rules like `Write(/abs/path/to/file)`. + pub(crate) fn resolved_file_path(&self) -> Option { + match self { + ClientToolCall::Read(tool) => Some(tool.resolved_path()), + ClientToolCall::Edit(tool) => Some(tool.resolved_path()), + ClientToolCall::Write(tool) => Some(tool.resolved_path()), + _ => None, + } + } + pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { match self { ClientToolCall::Read(tool) => tool.matches_rule(rule), @@ -449,14 +291,18 @@ impl TryFrom<&serde_json::Value> for ReadToolCall { } impl ReadToolCall { - fn execute(&self) -> ToolOutcome { - let mut path = self.path.clone(); - - if path.is_relative() - && let Ok(current_dir) = std::env::current_dir() - { - path = current_dir.join(path); + pub fn resolved_path(&self) -> PathBuf { + if self.path.is_relative() { + std::env::current_dir() + .map(|cwd| cwd.join(&self.path)) + .unwrap_or_else(|_| self.path.clone()) + } else { + self.path.clone() } + } + + pub fn execute(&self) -> ToolOutcome { + let path = self.resolved_path(); if !path.exists() { return ToolOutcome::Error(format!("Error: file does not exist: {}", path.display())); diff --git a/crates/atuin-ai/src/tui/components/atuin_ai.rs b/crates/atuin-ai/src/tui/components/atuin_ai.rs index c7227fbd..31dff1c3 100644 --- a/crates/atuin-ai/src/tui/components/atuin_ai.rs +++ b/crates/atuin-ai/src/tui/components/atuin_ai.rs @@ -5,11 +5,10 @@ //! Tab) are handled in the bubble phase so child components like the //! permission Select can consume them first. -use std::sync::mpsc; - use crossterm::event::{Event, KeyCode, KeyEvent, KeyEventKind, KeyModifiers}; use eye_declare::{Elements, EventResult, Hooks, component, props}; +use crate::commands::inline::DriverEventSender; use crate::tui::events::AiTuiEvent; use crate::tui::state::AppMode; @@ -28,7 +27,7 @@ pub(crate) struct AtuinAi { #[derive(Default)] pub(crate) struct AtuinAiState { - tx: Option>, + tx: Option, } #[component(props = AtuinAi, state = AtuinAiState, children = Elements)] @@ -38,7 +37,7 @@ fn atuin_ai( hooks: &mut Hooks, children: Elements, ) -> Elements { - hooks.use_context::>(|tx, _, state| { + hooks.use_context::(|tx, _, state| { state.tx = tx.cloned(); }); diff --git a/crates/atuin-ai/src/tui/components/input_box.rs b/crates/atuin-ai/src/tui/components/input_box.rs index 6e041418..6b81322c 100644 --- a/crates/atuin-ai/src/tui/components/input_box.rs +++ b/crates/atuin-ai/src/tui/components/input_box.rs @@ -6,7 +6,7 @@ //! //! On Enter, sends `AiTuiEvent::SubmitInput` via the context-provided channel. -use std::sync::{Arc, Mutex, mpsc}; +use std::sync::{Arc, Mutex}; use crossterm::event::KeyModifiers; use eye_declare::{Canvas, Elements, EventResult, Hooks, component, element, props}; @@ -19,6 +19,7 @@ use ratatui_core::{ }; use tui_textarea::TextArea; +use crate::commands::inline::DriverEventSender; use crate::tui::{events::AiTuiEvent, slash::SlashCommandSearchResult}; /// A bordered text input box backed by tui-textarea. @@ -41,7 +42,7 @@ pub(crate) struct InputBox { pub(crate) struct InputBoxState { textarea: Arc>>, - tx: Option>, + tx: Option, } impl Default for InputBoxState { @@ -97,10 +98,13 @@ fn input_box( state: &InputBoxState, hooks: &mut Hooks, ) -> Elements { - hooks.use_focusable(props.active); + // Always focusable so focus isn't lost when the permission Select is + // removed from the tree. The `active` prop controls visual state and + // whether keystrokes are processed, not focusability. + hooks.use_focusable(true); hooks.use_autofocus(); - hooks.use_context::>(|tx, _, state| { + hooks.use_context::(|tx, _, state| { state.tx = tx.cloned(); }); diff --git a/crates/atuin-ai/src/tui/components/select.rs b/crates/atuin-ai/src/tui/components/select.rs index 5abbe655..771d7830 100644 --- a/crates/atuin-ai/src/tui/components/select.rs +++ b/crates/atuin-ai/src/tui/components/select.rs @@ -1,10 +1,9 @@ -use std::sync::mpsc; - use crossterm::event::KeyCode; use eye_declare::{Elements, EventResult, Hooks, Span, Text, View, component, element, props}; use ratatui::style::Style; use typed_builder::TypedBuilder; +use crate::commands::inline::DriverEventSender; use crate::tui::events::AiTuiEvent; type OnSelectFn = Box Option + Send + Sync + 'static>; @@ -24,7 +23,7 @@ pub(crate) struct SelectOption { #[derive(Default)] pub(crate) struct PermissionSelectorState { selected_option: usize, - tx: Option>, + tx: Option, } #[props] @@ -42,7 +41,7 @@ pub(crate) fn permission_selector( hooks.use_focusable(true); hooks.use_autofocus(); - hooks.use_context::>(|tx, _, state| { + hooks.use_context::(|tx, _, state| { state.tx = tx.cloned(); }); diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs deleted file mode 100644 index 46eebd9b..00000000 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ /dev/null @@ -1,894 +0,0 @@ -use std::path::PathBuf; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::mpsc; - -use crate::context::{AppContext, ClientContext}; -use crate::context_window::ContextWindowBuilder; -use crate::permissions::check::PermissionResponse; -use crate::permissions::resolver::PermissionResolver; -use crate::permissions::rule::Rule; -use crate::permissions::writer::{self, RuleDisposition}; -use crate::session::SessionManager; -use crate::stream::{ChatRequest, run_chat_stream}; -use crate::tools::{ClientToolCall, ToolPhase}; -use crate::tui::events::{AiTuiEvent, PermissionResult}; -use crate::tui::state::{ConversationEvent, ExitAction, Session}; -use eye_declare::Handle; -use tokio::task::JoinHandle; - -/// Shared context for the dispatch loop. Bundles the references every -/// handler might need so `dispatch` doesn't forward a different subset -/// to each one. -pub(crate) struct DispatchContext<'a> { - pub handle: &'a Handle, - pub tx: &'a mpsc::Sender, - pub app_ctx: &'a AppContext, - pub client_ctx: &'a ClientContext, - pub session_mgr: &'a mut SessionManager, - /// Set by any handler that calls `h.exit()`. Read by `dispatch()` - /// to break the loop — without round-tripping through the handle, - /// which would hang if the TUI has already stopped. - pub exiting: Arc, -} - -/// Dispatch a single event. Returns `true` to keep the loop running, -/// `false` to shut down (after the final persist has completed). -pub(crate) fn dispatch(ctx: &mut DispatchContext, event: AiTuiEvent) -> bool { - match event { - AiTuiEvent::ContinueAfterTools => on_continue_after_tools(ctx), - AiTuiEvent::InputUpdated(input) => on_input_updated(ctx, input), - AiTuiEvent::SubmitInput(input) => on_submit_input(ctx, input), - AiTuiEvent::SlashCommand(cmd) => on_slash_command(ctx, cmd), - AiTuiEvent::CheckToolCallPermission(id) => on_check_tool_permission(ctx, id), - AiTuiEvent::SelectPermission(result) => on_select_permission(ctx, result), - AiTuiEvent::CancelGeneration => on_cancel_generation(ctx), - AiTuiEvent::ExecuteCommand => on_execute_command(ctx), - AiTuiEvent::CancelConfirmation => on_cancel_confirmation(ctx), - AiTuiEvent::InterruptToolExecution => on_interrupt_tool_execution(ctx), - AiTuiEvent::InsertCommand => on_insert_command(ctx), - AiTuiEvent::Retry => on_retry(ctx), - AiTuiEvent::Exit => on_exit(ctx), - } - - // Persist any new conversation events after each dispatch cycle. - persist_session(ctx); - - // The exiting flag is set by any handler that calls h.exit(). We - // read it here rather than querying state through the handle, - // because the TUI thread may have already stopped processing - // handle requests by this point. - !ctx.exiting.load(Ordering::Acquire) -} - -/// Persist new events, server session ID, file tracker, and edit permissions. -/// Called from the dispatch thread (sync), bridges to async via the tokio handle. -fn persist_session(ctx: &mut DispatchContext) { - let Ok((events, server_sid, file_tracker_json, edit_perms_json)) = ctx - .handle - .fetch(|state| { - ( - state.conversation.events.clone(), - state.conversation.session_id.clone(), - state.file_tracker.to_json().ok(), - state.edit_permissions.to_json().ok(), - ) - }) - .blocking_recv() - else { - return; - }; - - let rt = tokio::runtime::Handle::current(); - if let Err(e) = rt.block_on(ctx.session_mgr.persist_events(&events)) { - tracing::warn!("failed to persist session events: {e}"); - } - if let Some(ref sid) = server_sid - && let Err(e) = rt.block_on(ctx.session_mgr.persist_server_session_id(sid)) - { - tracing::warn!("failed to persist server session ID: {e}"); - } - if let Some(ref json) = file_tracker_json - && let Err(e) = rt.block_on( - ctx.session_mgr - .set_metadata(crate::file_tracker::METADATA_KEY, json), - ) - { - tracing::warn!("failed to persist file tracker: {e}"); - } - if let Some(ref json) = edit_perms_json - && let Err(e) = rt.block_on( - ctx.session_mgr - .set_metadata(crate::edit_permissions::METADATA_KEY, json), - ) - { - tracing::warn!("failed to persist edit permissions: {e}"); - } -} - -fn launch_stream(ctx: &DispatchContext, setup: impl FnOnce(&mut Session) + Send + 'static) { - let h2 = ctx.handle.clone(); - let tx2 = ctx.tx.clone(); - let app = ctx.app_ctx.clone(); - let cc = ctx.client_ctx.clone(); - let caps = ctx.app_ctx.capabilities.clone(); - ctx.handle.update(move |state| { - (setup)(state); - state.start_streaming(); - let messages = - ContextWindowBuilder::with_default_budget().build(&state.conversation.events); - let sid = state.conversation.session_id.clone(); - let request = ChatRequest::new(messages, sid, &caps, state.invocation_id.clone()); - let task: JoinHandle<()> = tokio::spawn(async move { - run_chat_stream(h2, tx2, app, cc, request).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); -} - -fn on_continue_after_tools(ctx: &mut DispatchContext) { - launch_stream(ctx, |_state| {}); -} - -fn on_input_updated(ctx: &mut DispatchContext, input: String) { - let input_blank = input.is_empty(); - let slash_command = if input.starts_with('/') { - Some(input.trim_start_matches('/').to_string()) - } else { - None - }; - - ctx.handle.update(move |state| { - state.interaction.is_input_blank = input_blank; - state.interaction.slash_command_input = slash_command; - - if let Some(query) = state.interaction.slash_command_input.as_ref() { - let mut results = state.slash_registry.search_fuzzy(query); - - results.sort_by(|a, b| { - b.relevance - .partial_cmp(&a.relevance) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - state.interaction.slash_command_search_results = results; - } else { - state.interaction.slash_command_search_results.clear(); - } - }); -} - -fn on_submit_input(ctx: &mut DispatchContext, input: String) { - ctx.handle.update(move |state| { - state.interaction.slash_command_input = None; - state.interaction.slash_command_search_results.clear(); - }); - - let input = input.trim().to_string(); - if input.is_empty() { - let h2 = ctx.handle.clone(); - let exiting = ctx.exiting.clone(); - ctx.handle.update(move |state| { - if state.conversation.has_any_command() { - state.exit_action = Some(ExitAction::Execute( - state.conversation.current_command().unwrap().to_string(), - )); - } else { - state.exit_action = Some(ExitAction::Cancel); - } - exiting.store(true, Ordering::Release); - h2.exit(); - }); - return; - } - - if input.starts_with('/') { - if input.trim() == "/new" { - on_new_session(ctx); - } else { - ctx.handle.update(move |state| { - state - .conversation - .handle_slash_command(&input, &state.slash_registry); - }); - } - return; - } - - // Start generation and spawn streaming task - launch_stream(ctx, |state| { - state.start_generating(input); - state.interaction.is_input_blank = true; - }); -} - -fn on_slash_command(ctx: &mut DispatchContext, command: String) { - ctx.handle.update(move |state| { - state - .conversation - .handle_slash_command(&command, &state.slash_registry); - }); -} - -// ─────────────────────────────────────────────────────────────────── -// Tool execution dispatch -// ─────────────────────────────────────────────────────────────────── - -/// Execute a tool call. Handles Shell tools (streaming with preview) and -/// non-shell tools (synchronous) uniformly. -fn execute_tool( - handle: &Handle, - tx: &mpsc::Sender, - tool_id: String, - tool: ClientToolCall, - db: &std::sync::Arc, -) { - match &tool { - ClientToolCall::Shell(shell_call) => { - let shell_call = shell_call.clone(); - execute_shell_tool(handle, tx, &tool_id, &shell_call); - } - ClientToolCall::Edit(edit_call) => { - let edit_call = edit_call.clone(); - execute_edit_tool(handle, tx, tool_id, edit_call); - } - ClientToolCall::Write(write_call) => { - let write_call = write_call.clone(); - execute_write_tool(handle, tx, tool_id, write_call); - } - _ => { - execute_simple_tool(handle, tx, tool_id, tool, db); - } - } -} - -/// Execute a non-shell tool and finish the tool call. -/// The ToolCall event is already in the conversation (added by handle_client_tool_call). -fn execute_simple_tool( - handle: &Handle, - tx: &mpsc::Sender, - tool_id: String, - tool: ClientToolCall, - db: &std::sync::Arc, -) { - let h = handle.clone(); - let tx = tx.clone(); - let db = db.clone(); - - tokio::spawn(async move { - let outcome = tool.execute(&db).await; - - // After a successful file read, capture tracking data for freshness - // checking. This re-stats the file to get content hash and mtime. - let read_tracking = if let ClientToolCall::Read(ref read_tool) = tool - && !outcome.is_error() - { - capture_read_tracking(&read_tool.path) - } else { - None - }; - - h.update(move |state| { - if let Some((path, content, mtime)) = read_tracking { - state.file_tracker.record_read(path, &content, mtime); - } - state.finish_tool_call(&tool_id, outcome); - if !state.tool_tracker.has_pending() { - let _ = tx.send(AiTuiEvent::ContinueAfterTools); - } - }); - }); -} - -/// Capture file content and mtime for the read tracker. -/// Returns None for directories or if the file can't be read. -fn capture_read_tracking( - path: &std::path::Path, -) -> Option<(std::path::PathBuf, Vec, std::time::SystemTime)> { - let resolved = if path.is_relative() { - std::env::current_dir().ok()?.join(path) - } else { - path.to_path_buf() - }; - if !resolved.is_file() { - return None; - } - let content = std::fs::read(&resolved).ok()?; - let mtime = std::fs::metadata(&resolved).ok()?.modified().ok()?; - Some((resolved, content, mtime)) -} - -/// Execute an edit_file tool call. -/// -/// Orchestrates snapshot → execute → tracker update. The snapshot and -/// tracker mutations happen via `h.update()` (on the TUI thread) since -/// they need mutable Session state. The actual file I/O (freshness check, -/// read, match, atomic write) runs in the tokio task. -fn execute_edit_tool( - handle: &Handle, - tx: &mpsc::Sender, - tool_id: String, - edit_call: crate::tools::EditToolCall, -) { - let h = handle.clone(); - let tx = tx.clone(); - - tokio::spawn(async move { - let resolved = edit_call.resolved_path(); - - // 1. Read the original file content (used for snapshot + diff). - let old_content = std::fs::read(&resolved).ok(); - - // 2. Snapshot the original file before editing. - if let Some(ref content) = old_content { - let snap_path = resolved.clone(); - let snap_content = content.clone(); - h.update(move |state| { - if let Some(ref mut store) = state.snapshot_store - && let Err(e) = store.ensure_snapshot(&snap_path, &snap_content) - { - tracing::warn!("failed to create file snapshot: {e}"); - } - }); - } - - // 3. Fetch a clone of the file tracker for freshness checking. - let Ok(tracker) = h.fetch(|state| state.file_tracker.clone()).await else { - let tc_id = tool_id.clone(); - h.update(move |state| { - state.finish_tool_call( - &tc_id, - crate::tools::ToolOutcome::Error("Internal error: TUI unavailable".into()), - ); - if !state.tool_tracker.has_pending() { - let _ = tx.send(AiTuiEvent::ContinueAfterTools); - } - }); - return; - }; - - // 4. Execute: freshness check → read → match → atomic write - let (outcome, new_bytes) = edit_call.execute(&resolved, &tracker); - - // 5. Compute diff preview on success - let edit_preview = if let Some(ref new_bytes) = new_bytes { - if let Some(ref old_bytes) = old_content { - let old_str = String::from_utf8_lossy(old_bytes); - let new_str = String::from_utf8_lossy(new_bytes); - let preview = crate::diff::EditPreview::compute(&old_str, &new_str); - if preview.hunks.is_empty() { - None - } else { - Some(preview) - } - } else { - None - } - } else { - None - }; - - // 6. Update tracker, store diff preview, and finish the tool call - let tc_id = tool_id; - h.update(move |state| { - if let Some(ref new_bytes) = new_bytes - && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) - { - state - .file_tracker - .update_after_edit(&resolved, new_bytes, mtime); - } - if let Some(preview) = edit_preview - && let Some(tracked) = state.tool_tracker.get_mut(&tc_id) - { - tracked.edit_preview = Some(preview); - } - state.finish_tool_call(&tc_id, outcome); - if !state.tool_tracker.has_pending() { - let _ = tx.send(AiTuiEvent::ContinueAfterTools); - } - }); - }); -} - -/// Execute a write_file tool call. -/// -/// Snapshots the existing file (if any) before overwriting, writes atomically, -/// stores a content preview on the tracker, and updates the file tracker. -fn execute_write_tool( - handle: &Handle, - tx: &mpsc::Sender, - tool_id: String, - write_call: crate::tools::WriteToolCall, -) { - let h = handle.clone(); - let tx = tx.clone(); - - tokio::spawn(async move { - let resolved = write_call.resolved_path(); - - // 1. Snapshot the existing file before overwriting (if it exists). - if resolved.exists() - && let Ok(original_content) = std::fs::read(&resolved) - { - let snap_path = resolved.clone(); - h.update(move |state| { - if let Some(ref mut store) = state.snapshot_store - && let Err(e) = store.ensure_snapshot(&snap_path, &original_content) - { - tracing::warn!("failed to create file snapshot: {e}"); - } - }); - } - - // 2. Execute: check exists/overwrite, atomic write - let (outcome, new_bytes) = write_call.execute(&resolved); - - // 3. Build content preview on success - let write_preview = if new_bytes.is_some() { - Some(crate::diff::WritePreview::from_content(&write_call.content)) - } else { - None - }; - - // 4. Update tracker, store preview, and finish - let tc_id = tool_id; - h.update(move |state| { - if let Some(ref new_bytes) = new_bytes - && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) - { - state - .file_tracker - .update_after_edit(&resolved, new_bytes, mtime); - } - if let Some(preview) = write_preview - && let Some(tracked) = state.tool_tracker.get_mut(&tc_id) - { - tracked.write_preview = Some(preview); - } - state.finish_tool_call(&tc_id, outcome); - if !state.tool_tracker.has_pending() { - let _ = tx.send(AiTuiEvent::ContinueAfterTools); - } - }); - }); -} - -/// Execute a shell tool with streaming VT100 preview. -fn execute_shell_tool( - handle: &Handle, - tx: &mpsc::Sender, - tool_id: &str, - shell_call: &crate::tools::ShellToolCall, -) { - let h = handle.clone(); - let tx = tx.clone(); - let shell_call = shell_call.clone(); - let command = shell_call.command.clone(); - let tc_id = tool_id.to_string(); - - // 1. Set up channels for streaming output and interruption - let (output_tx, mut output_rx) = tokio::sync::mpsc::channel::>(32); - let (abort_tx, abort_rx) = tokio::sync::oneshot::channel::<()>(); - - // 2. Mark as executing with preview and store the abort sender on the tracker entry - let tc_id_setup = tc_id.clone(); - h.update(move |state| { - if let Some(tracked) = state.tool_tracker.get_mut(&tc_id_setup) { - tracked.mark_executing_preview(command); - tracked.abort_tx = Some(abort_tx); - } - }); - - // 3. Spawn a task to consume output updates and feed them to state - let h_output = h.clone(); - let preview_id = tc_id.clone(); - let output_task = tokio::spawn(async move { - while let Some(lines) = output_rx.recv().await { - let id = preview_id.clone(); - h_output.update(move |state| { - if let Some(tracked) = state.tool_tracker.get_mut(&id) - && let ToolPhase::ExecutingWithPreview { - ref mut output_lines, - .. - } = tracked.phase - { - *output_lines = lines; - } - }); - } - }); - - // 4. Spawn the streaming execution task - let tc_id_finish = tc_id; - tokio::spawn(async move { - let outcome = - crate::tools::execute_shell_command_streaming(&shell_call, output_tx, abort_rx).await; - - // Wait for the output task to finish so the final preview lines are captured - let _ = output_task.await; - - h.update(move |state| { - state.finish_tool_call(&tc_id_finish, outcome); - if !state.tool_tracker.has_pending() { - let _ = tx.send(AiTuiEvent::ContinueAfterTools); - } - }); - }); -} - -// ─────────────────────────────────────────────────────────────────── -// Permission handlers -// ─────────────────────────────────────────────────────────────────── - -fn on_check_tool_permission(ctx: &mut DispatchContext, id: String) { - let h2 = ctx.handle.clone(); - let tx_for_task = ctx.tx.clone(); - let db = ctx.app_ctx.history_db.clone(); - - tokio::spawn(async move { - let id_for_error = id.clone(); - let result = check_tool_permission_inner(&h2, &tx_for_task, &db, id).await; - - // If the inner function didn't handle the tool (returned an error message), - // finish the tool call with that error so the conversation doesn't stall. - if let Err(error_msg) = result { - let tx = tx_for_task.clone(); - h2.update(move |state| { - state.finish_tool_call(&id_for_error, crate::tools::ToolOutcome::Error(error_msg)); - if !state.tool_tracker.has_pending() { - let _ = tx.send(AiTuiEvent::ContinueAfterTools); - } - }); - } - }); -} - -/// Inner permission check that returns Err(message) if the tool call should be -/// finished with an error. Returns Ok(()) if the tool was handled (executed, -/// denied, or sent to the permission UI). -async fn check_tool_permission_inner( - h2: &Handle, - tx: &mpsc::Sender, - db: &std::sync::Arc, - id: String, -) -> Result<(), String> { - // 1. Fetch the tracked tool's data - let id_for_fetch = id.clone(); - let (tool, target_dir) = h2 - .fetch(move |state| { - state - .tool_tracker - .get(&id_for_fetch) - .map(|t| (t.tool.clone(), t.target_dir().map(PathBuf::from))) - }) - .await - .map_err(|e| format!("Internal error fetching tool state: {e}"))? - .ok_or_else(|| "Internal error: tool not found in tracker".to_string())?; - - // 2. For edit tools, check session-scoped permission grants before - // hitting the filesystem-based resolver. A valid grant means the user - // already approved this file recently. - if let ClientToolCall::Edit(ref edit) = tool { - let resolved = edit.resolved_path(); - let has_grant = h2 - .fetch(move |state| state.edit_permissions.has_valid_grant(&resolved)) - .await - .unwrap_or(false); - - if has_grant { - execute_tool(h2, tx, id, tool, db); - return Ok(()); - } - } - - // 3. Resolve working directory - let working_dir = target_dir - .or_else(|| std::env::current_dir().ok()) - .ok_or_else(|| "Could not determine working directory".to_string())?; - - // 4. Create permission resolver and check - let resolver = PermissionResolver::new(working_dir) - .await - .map_err(|e| format!("Permission check failed: {e}"))?; - - let response = resolver - .check(&tool) - .await - .map_err(|e| format!("Permission check failed: {e}"))?; - - // 5. Handle response — all paths here handle the tool, so return Ok - let id_clone = id.clone(); - match response { - PermissionResponse::Allowed => { - execute_tool(h2, tx, id, tool, db); - } - PermissionResponse::Denied => { - let tx = tx.clone(); - h2.update(move |state| { - state.finish_tool_call( - &id_clone, - crate::tools::ToolOutcome::Error( - "Permission denied on the user's system".to_string(), - ), - ); - if !state.tool_tracker.has_pending() { - let _ = tx.send(AiTuiEvent::ContinueAfterTools); - } - }); - } - PermissionResponse::Ask => { - h2.update(move |state| { - if let Some(tracked) = state.tool_tracker.get_mut(&id_clone) { - tracked.mark_asking(); - } - }); - } - } - - Ok(()) -} - -fn on_select_permission(ctx: &mut DispatchContext, permission: PermissionResult) { - let tx = ctx.tx.clone(); - let h2 = ctx.handle.clone(); - - match permission { - PermissionResult::Allow => { - // Fetch the tool that's asking for permission, then execute it - let db = ctx.app_ctx.history_db.clone(); - tokio::spawn(async move { - let Ok(Some((tool_id, tool))) = h2 - .fetch(move |state| { - state - .tool_tracker - .asking_for_permission() - .map(|t| (t.id.clone(), t.tool.clone())) - }) - .await - else { - return; - }; - - execute_tool(&h2, &tx, tool_id, tool, &db); - }); - } - PermissionResult::AllowFileForSession => { - // Cache a session-scoped, time-limited grant for this file - let db = ctx.app_ctx.history_db.clone(); - tokio::spawn(async move { - let Ok(Some((tool_id, tool))) = h2 - .fetch(move |state| { - state - .tool_tracker - .asking_for_permission() - .map(|t| (t.id.clone(), t.tool.clone())) - }) - .await - else { - return; - }; - - if let ClientToolCall::Edit(ref edit) = tool { - let resolved = edit.resolved_path(); - h2.update(move |state| { - state.edit_permissions.grant(resolved); - }); - } - - execute_tool(&h2, &tx, tool_id, tool, &db); - }); - } - PermissionResult::AlwaysAllowInDir => { - let db = ctx.app_ctx.history_db.clone(); - let git_root = ctx.app_ctx.git_root.clone(); - tokio::spawn(async move { - let Ok(Some((tool_id, tool))) = h2 - .fetch(move |state| { - state - .tool_tracker - .asking_for_permission() - .map(|t| (t.id.clone(), t.tool.clone())) - }) - .await - else { - return; - }; - - // Write the rule to the project (git root) or cwd permissions file - let project_root = git_root - .or_else(|| std::env::current_dir().ok()) - .unwrap_or_else(|| PathBuf::from(".")); - let file_path = writer::project_permissions_path(&project_root); - let rule = Rule { - tool: tool.rule_name().to_string(), - scope: None, - }; - if let Err(e) = writer::write_rule(&file_path, &rule, RuleDisposition::Allow).await - { - tracing::error!("Failed to write project permission rule: {e}"); - } - - execute_tool(&h2, &tx, tool_id, tool, &db); - }); - } - PermissionResult::AlwaysAllow => { - let db = ctx.app_ctx.history_db.clone(); - tokio::spawn(async move { - let Ok(Some((tool_id, tool))) = h2 - .fetch(move |state| { - state - .tool_tracker - .asking_for_permission() - .map(|t| (t.id.clone(), t.tool.clone())) - }) - .await - else { - return; - }; - - // Write the rule to the global permissions file - let file_path = writer::global_permissions_path(); - let rule = Rule { - tool: tool.rule_name().to_string(), - scope: None, - }; - if let Err(e) = writer::write_rule(&file_path, &rule, RuleDisposition::Allow).await - { - tracing::error!("Failed to write global permission rule: {e}"); - } - - execute_tool(&h2, &tx, tool_id, tool, &db); - }); - } - PermissionResult::Deny => { - h2.update(move |state| { - let Some(tracked) = state.tool_tracker.asking_for_permission() else { - return; - }; - let tool_id = tracked.id.clone(); - - state.finish_tool_call( - &tool_id, - crate::tools::ToolOutcome::Error("Permission denied by the user".to_string()), - ); - if !state.tool_tracker.has_pending() { - let _ = tx.send(AiTuiEvent::ContinueAfterTools); - } - }); - } - } -} - -// ─────────────────────────────────────────────────────────────────── -// Other handlers -// ─────────────────────────────────────────────────────────────────── - -fn on_cancel_generation(ctx: &mut DispatchContext) { - ctx.handle.update(|state| match state.interaction.mode { - crate::tui::state::AppMode::Generating => { - state.cancel_generation(); - } - crate::tui::state::AppMode::Streaming => { - state.cancel_streaming(); - } - _ => {} - }); -} - -fn on_execute_command(ctx: &mut DispatchContext) { - let h2 = ctx.handle.clone(); - let exiting = ctx.exiting.clone(); - ctx.handle.update(move |state| { - let cmd = state.conversation.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - if state.conversation.is_current_command_dangerous() - && !state.interaction.confirmation_pending - { - state.interaction.confirmation_pending = true; - } else { - state.interaction.confirmation_pending = false; - state.exit_action = Some(ExitAction::Execute(cmd)); - exiting.store(true, Ordering::Release); - h2.exit(); - } - } - }); -} - -fn on_cancel_confirmation(ctx: &mut DispatchContext) { - ctx.handle.update(move |state| { - state.interaction.confirmation_pending = false; - }); -} - -fn on_insert_command(ctx: &mut DispatchContext) { - let h2 = ctx.handle.clone(); - let exiting = ctx.exiting.clone(); - ctx.handle.update(move |state| { - let cmd = state.conversation.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - state.interaction.confirmation_pending = false; - state.exit_action = Some(ExitAction::Insert(cmd)); - exiting.store(true, Ordering::Release); - h2.exit(); - } - }); -} - -fn on_retry(ctx: &mut DispatchContext) { - launch_stream(ctx, |state| { - state.retry(); - }); -} - -fn on_new_session(ctx: &mut DispatchContext) { - let rt = tokio::runtime::Handle::current(); - - if let Err(e) = rt.block_on(ctx.session_mgr.archive_and_reset()) { - tracing::warn!("failed to start new session: {e}"); - return; - } - - ctx.handle.update(|state| { - // Move the current invocation's visible events to the archived view - // so they remain on screen but are no longer sent to the API. - let visible_events: Vec = - state.conversation.events[state.view_start_index..].to_vec(); - state.archived_view_events.extend(visible_events); - - state.conversation.events.clear(); - state.conversation.session_id = None; - state.tool_tracker = crate::tools::ToolTracker::new(); - state.view_start_index = 0; - state.is_resumed = false; - state.last_event_time = None; - state - .conversation - .events - .push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: Some("/new".to_string()), - content: "Started a new session.".to_string(), - }); - }); -} - -fn on_exit(ctx: &mut DispatchContext) { - let h2 = ctx.handle.clone(); - let exiting = ctx.exiting.clone(); - ctx.handle.update(move |state| { - if let Some(abort) = state.stream_abort.take() { - abort.abort(); - } - state.exit_action = Some(ExitAction::Cancel); - exiting.store(true, Ordering::Release); - h2.exit(); - }); -} - -fn on_interrupt_tool_execution(ctx: &mut DispatchContext) { - ctx.handle.update(move |state| { - // Find executing previews, send interrupt, and mark as interrupted - for tracked in state.tool_tracker.iter_mut() { - if let ToolPhase::ExecutingWithPreview { - ref mut interrupted, - ref mut exit_code, - .. - } = tracked.phase - { - *interrupted = true; - if exit_code.is_none() { - *exit_code = Some(-1); - } - // Send interrupt signal via the tracker entry's abort channel - if let Some(abort_tx) = tracked.abort_tx.take() { - let _ = abort_tx.send(()); - } - } - } - - // The spawned execution task will handle finalizing and sending - // ContinueAfterTools when the process exits. Input mode is already active. - }); -} diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs index 969f6ae5..abcb1bd9 100644 --- a/crates/atuin-ai/src/tui/events.rs +++ b/crates/atuin-ai/src/tui/events.rs @@ -13,12 +13,8 @@ pub(crate) enum AiTuiEvent { /// User entered a slash command (e.g. "/help") #[allow(unused)] SlashCommand(String), - /// Check the permission for a tool call - CheckToolCallPermission(String), /// User selected a permission SelectPermission(PermissionResult), - /// Continue after client tools have completed - ContinueAfterTools, /// Cancel active generation or streaming (Esc during Generating/Streaming) CancelGeneration, /// Execute the suggested command diff --git a/crates/atuin-ai/src/tui/mod.rs b/crates/atuin-ai/src/tui/mod.rs index 05a040a1..9727f362 100644 --- a/crates/atuin-ai/src/tui/mod.rs +++ b/crates/atuin-ai/src/tui/mod.rs @@ -1,8 +1,7 @@ pub(crate) mod components; -pub(crate) mod dispatch; pub(crate) mod events; pub(crate) mod slash; pub(crate) mod state; pub(crate) mod view; -pub(crate) use state::{ConversationEvent, Session, events_to_messages}; +pub(crate) use state::{ConversationEvent, events_to_messages}; diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index af1ebffe..e008bd3c 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -1,36 +1,10 @@ -//! Domain state types for the TUI application +//! Core state types for the conversation protocol. //! -//! This module contains the core state types that represent the application's -//! domain model. Conversation events match the API protocol format. +//! ConversationEvent and events_to_messages are the canonical representations +//! used by both the FSM and the context window builder. AppMode is used by +//! the view layer for component prop derivation. -use tokio::task::AbortHandle; - -use crate::{ - tools::{ClientToolCall, ToolOutcome, ToolTracker}, - tui::slash::{SlashCommandRegistry, SlashCommandSearchResult}, -}; - -/// Streaming status indicators from server -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum StreamingStatus { - Processing, - Searching, - Thinking, - WaitingForTools, -} - -impl StreamingStatus { - pub(crate) fn from_status_str(s: &str) -> Self { - match s { - "processing" => Self::Processing, - "searching" => Self::Searching, - "waiting_for_tools" => Self::WaitingForTools, - _ => Self::Thinking, - } - } -} - -/// Conversation event types matching the API protocol +/// Conversation event types matching the API protocol. #[derive(Debug, Clone)] pub(crate) enum ConversationEvent { /// User message (what the user typed) @@ -54,7 +28,7 @@ pub(crate) enum ConversationEvent { /// Approximate content length for token estimation of remote results. content_length: Option, }, - /// Out-of-band output from the system - not sent to the server + /// Out-of-band output from the system — not sent to the server OutOfBandOutput { name: String, command: Option, @@ -67,7 +41,6 @@ pub(crate) enum ConversationEvent { impl ConversationEvent { /// Whether this event represents actual conversation content sent to the API. - /// Used to determine if a resumed session has meaningful context. pub(crate) fn is_api_content(&self) -> bool { match self { ConversationEvent::UserMessage { .. } => true, @@ -79,7 +52,7 @@ impl ConversationEvent { } } - /// Extract command from a suggest_command tool call + /// Extract command from a suggest_command tool call. pub(crate) fn as_command(&self) -> Option<&str> { if let ConversationEvent::ToolCall { name, input, .. } = self && name == "suggest_command" @@ -90,7 +63,9 @@ impl ConversationEvent { } } -/// Application mode for key handling and footer text. +/// Application mode for key handling and component props. +/// +/// Derived from AgentState in the view layer via `From<&AgentState>`. #[derive(Debug, Clone, PartialEq, Eq, Copy)] pub(crate) enum AppMode { /// User is typing input @@ -103,167 +78,6 @@ pub(crate) enum AppMode { Error, } -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum ExitAction { - /// Run the command - Execute(String), - /// Insert command without running - Insert(String), - /// User canceled - Cancel, -} - -/// Owned event log and session ID -#[derive(Debug)] -pub(crate) struct Conversation { - /// Conversation events (source of truth, matches API protocol) - pub events: Vec, - /// Session ID from server - pub session_id: Option, -} - -impl Conversation { - pub fn new() -> Self { - Self { - events: Vec::new(), - session_id: None, - } - } - - /// Get the most recent command from events - pub fn current_command(&self) -> Option<&str> { - self.events.iter().rev().find_map(|e| e.as_command()) - } - - /// Check if any turn in the conversation has a command - pub fn has_any_command(&self) -> bool { - self.events.iter().any(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e { - name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() - } else { - false - } - }) - } - - /// Check if the most recent command is marked dangerous - pub fn is_current_command_dangerous(&self) -> bool { - self.events - .iter() - .rev() - .find_map(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e - && name == "suggest_command" - { - let danger_level = input - .get("danger") - .and_then(|v| v.as_str()) - .unwrap_or("low"); - return Some( - danger_level == "high" || danger_level == "medium" || danger_level == "med", - ); - } - None - }) - .unwrap_or(false) - } - - /// Get a mutable reference to the last Text event's content (the streaming buffer). - fn streaming_content_mut(&mut self) -> Option<&mut String> { - self.events.iter_mut().rev().find_map(|e| { - if let ConversationEvent::Text { content } = e { - Some(content) - } else { - None - } - }) - } - - /// Remove trailing empty Text events from the events list - fn remove_empty_trailing_text(&mut self) { - while let Some(ConversationEvent::Text { content }) = self.events.last() { - if content.is_empty() { - self.events.pop(); - } else { - break; - } - } - } - - /// Append text chunk during streaming (mutates the last Text event in-place) - pub fn append_streaming_text(&mut self, chunk: &str) { - // If the last event isn't a Text, we need a fresh buffer - // (e.g. after a tool call removed the empty streaming buffer) - if !matches!(self.events.last(), Some(ConversationEvent::Text { .. })) { - self.events.push(ConversationEvent::Text { - content: String::new(), - }); - } - - if let Some(content) = self.streaming_content_mut() { - if content.is_empty() { - // First chunk(s): trim leading whitespace - let trimmed = chunk.trim_start(); - if !trimmed.is_empty() { - content.push_str(trimmed); - } - } else { - content.push_str(chunk); - } - } - } - - /// Add a tool result event during streaming - pub fn add_tool_result( - &mut self, - tool_use_id: String, - content: String, - is_error: bool, - remote: bool, - content_length: Option, - ) { - self.events.push(ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - }); - } - - /// Store session ID from server response - pub fn store_session_id(&mut self, session_id: String) { - self.session_id = Some(session_id); - } - - /// Handle a slash command - pub fn handle_slash_command(&mut self, command: &str, registry: &SlashCommandRegistry) { - match command.trim() { - "/help" => { - let commands = registry - .get_commands() - .iter() - .map(|cmd| format!("- `/{}` - {}", cmd.name, cmd.description)) - .collect::>() - .join("\n"); - - let content = include_str!("./content/help.md").replace("{commands}", &commands); - - self.events.push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: Some("/help".to_string()), - content, - }); - } - _ => self.events.push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: None, - content: (format!("Unknown command: {command}")), - }), - } - } -} - /// Convert a slice of conversation events to Claude API message format. /// /// This is the canonical event-to-message conversion, used by the context window @@ -284,13 +98,9 @@ pub(crate) fn events_to_messages(events: &[ConversationEvent]) -> Vec { - // Skip empty text events (e.g. streaming buffer before - // any data arrived). i += 1; } ConversationEvent::Text { content } => { - // Check if the next event(s) are ToolCalls — if so, combine - // into a single assistant message with mixed content blocks. let next_is_tool_call = events .get(i + 1) .is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. })); @@ -332,8 +142,6 @@ pub(crate) fn events_to_messages(events: &[ConversationEvent]) -> Vec { - // ToolCalls without preceding Text (shouldn't normally happen, - // but handle defensively) let mut tool_uses = Vec::new(); while i < events.len() { if let ConversationEvent::ToolCall { @@ -389,7 +197,6 @@ pub(crate) fn events_to_messages(events: &[ConversationEvent]) -> Vec { - // Out-of-band output is not sent to the server i += 1; } ConversationEvent::SystemContext { content } => { @@ -404,301 +211,3 @@ pub(crate) fn events_to_messages(events: &[ConversationEvent]) -> Vec, - /// Search results for the current slash command input - pub slash_command_search_results: Vec, - /// True when user has pressed Enter once on a dangerous command - pub confirmation_pending: bool, - /// Current streaming status - pub streaming_status: Option, - /// Whether current turn was interrupted by user - pub was_interrupted: bool, - /// Current error message - pub error: Option, -} - -impl Interaction { - pub fn new() -> Self { - Self { - mode: AppMode::Input, - is_input_blank: false, - slash_command_input: None, - slash_command_search_results: Vec::new(), - confirmation_pending: false, - streaming_status: None, - was_interrupted: false, - error: None, - } - } -} - -/// Top-level session state -/// -/// Decomposed into `Conversation` (event log + session ID) and -/// `Interaction` (ephemeral UI state). Session methods that cross -/// both sub-structs live here. -#[derive(Debug)] -pub(crate) struct Session { - pub conversation: Conversation, - pub interaction: Interaction, - /// Tracks all tool calls through their full lifecycle. - pub tool_tracker: ToolTracker, - /// Whether the session is running inside a git project (for permission UI labels). - pub in_git_project: bool, - /// Exit action (set when exiting) - pub exit_action: Option, - /// Abort handle for the active streaming task, if any - pub stream_abort: Option, - /// Index into `conversation.events` where the current TUI invocation starts. - /// Events before this index are historical context sent to the API but not - /// rendered in the TUI. - pub view_start_index: usize, - /// Whether this session was resumed from a prior invocation. - pub is_resumed: bool, - /// Time of the last event from a previous invocation when resuming a session - pub last_event_time: Option>, - /// Events from archived sessions that are still rendered on screen but no - /// longer sent to the API. Accumulated by `/new` commands within a single - /// TUI lifetime. - pub archived_view_events: Vec, - /// A registry of available slash commands - pub slash_registry: SlashCommandRegistry, - /// The unique ID for this invocation - pub invocation_id: String, - /// Tracks which files have been read, for freshness checking before edits. - pub file_tracker: crate::file_tracker::FileReadTracker, - /// Session-scoped edit permission grants (per-file, time-limited). - pub edit_permissions: crate::edit_permissions::EditPermissionCache, - /// Backs up files before the first edit in a session. - pub snapshot_store: Option, -} - -impl Session { - pub fn new(in_git_project: bool, invocation_id: Option) -> Self { - Self { - conversation: Conversation::new(), - interaction: Interaction::new(), - tool_tracker: ToolTracker::new(), - in_git_project, - exit_action: None, - stream_abort: None, - view_start_index: 0, - is_resumed: false, - last_event_time: None, - archived_view_events: Vec::new(), - slash_registry: Default::default(), - invocation_id: invocation_id.unwrap_or_else(|| uuid::Uuid::now_v7().to_string()), - file_tracker: Default::default(), - edit_permissions: Default::default(), - snapshot_store: None, - } - } - - // ===== Generation lifecycle methods ===== - - /// Start generating from submitted input - pub fn start_generating(&mut self, input: String) { - self.conversation - .events - .push(ConversationEvent::UserMessage { content: input }); - self.interaction.mode = AppMode::Generating; - } - - /// Generation error occurred - #[expect(dead_code)] - pub fn generation_error(&mut self, error: String) { - self.interaction.error = Some(error); - self.interaction.mode = AppMode::Error; - } - - /// Cancel during generation - pub fn cancel_generation(&mut self) { - if let Some(abort) = self.stream_abort.take() { - abort.abort(); - } - if let Some(ConversationEvent::UserMessage { .. }) = self.conversation.events.last() { - self.conversation.events.pop(); - } - self.interaction.mode = AppMode::Input; - } - - // ===== Streaming lifecycle methods ===== - - /// Start streaming response. - /// Pushes an empty Text event so the UI immediately creates an agent - /// turn (which renders the spinner). The empty event is skipped by - /// `events_to_messages` so it never becomes an empty assistant turn - /// in the API payload. - pub fn start_streaming(&mut self) { - self.conversation.events.push(ConversationEvent::Text { - content: String::new(), - }); - self.interaction.streaming_status = None; - self.interaction.was_interrupted = false; - self.interaction.mode = AppMode::Streaming; - } - - /// Update streaming status from SSE event - pub fn update_streaming_status(&mut self, status: &str) { - self.interaction.streaming_status = Some(StreamingStatus::from_status_str(status)); - } - - /// Cancel streaming with context preservation - pub fn cancel_streaming(&mut self) { - if let Some(abort) = self.stream_abort.take() { - abort.abort(); - } - self.interaction.was_interrupted = true; - - if let Some(content) = self.conversation.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - if trimmed.is_empty() { - // Remove the empty text event - *content = String::new(); - } else { - *content = format!("{trimmed}\n\n[User cancelled this generation]"); - } - } - // Remove trailing empty Text events - self.conversation.remove_empty_trailing_text(); - - self.interaction.streaming_status = None; - self.interaction.confirmation_pending = false; - self.interaction.mode = AppMode::Input; - } - - /// Add a tool call event during streaming. - /// The current streaming text is already in events, so we just push the tool call. - pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) { - // Trim the streaming text event - if let Some(content) = self.conversation.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - *content = trimmed; - } - self.conversation.remove_empty_trailing_text(); - - let is_suggest_command = name == "suggest_command"; - self.conversation - .events - .push(ConversationEvent::ToolCall { id, name, input }); - - if is_suggest_command { - self.interaction.streaming_status = None; - self.interaction.mode = AppMode::Input; - } - } - - /// Finalize streaming — trim the accumulated text and change mode - pub fn finalize_streaming(&mut self) { - if let Some(content) = self.conversation.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - *content = trimmed; - } - self.conversation.remove_empty_trailing_text(); - self.interaction.streaming_status = None; - self.interaction.mode = AppMode::Input; - } - - /// Streaming error — remove the partial text event - pub fn streaming_error(&mut self, error: String) { - self.conversation.remove_empty_trailing_text(); - self.interaction.error = Some(error); - self.interaction.mode = AppMode::Error; - } - - pub(crate) fn handle_client_tool_call( - &mut self, - id: String, - tool: ClientToolCall, - input: serde_json::Value, - ) { - let desc = tool.descriptor(); - let name = desc.canonical_names[0].to_string(); - - self.tool_tracker.insert(id.clone(), tool); - - // Add the ToolCall event to the conversation immediately so it appears - // in the view. Preview data is sourced from tool_tracker. - self.conversation - .events - .push(ConversationEvent::ToolCall { id, name, input }); - - // Client tool calls can only happen at the last part of a turn - self.interaction.streaming_status = None; - self.interaction.mode = AppMode::Input; - } - - /// Retry after error - pub fn retry(&mut self) { - self.interaction.error = None; - self.interaction.mode = AppMode::Generating; - } - - // ===== Tool lifecycle methods ===== - - /// Finish a tool call: transition tracker to Completed, push ToolResult to conversation. - /// - /// For shell commands, captures the final preview from the ExecutingWithPreview phase - /// and patches exit_code/interrupted from the authoritative ToolOutcome. - pub fn finish_tool_call(&mut self, tool_id: &str, outcome: ToolOutcome) { - let mut preview = self.tool_tracker.get(tool_id).and_then(|t| t.preview()); - - // Patch preview with authoritative outcome data (handles race where - // final VT100 update hasn't been applied yet). - if let Some(ref mut p) = preview - && let ToolOutcome::Structured { - exit_code, - interrupted, - .. - } = &outcome - { - p.interrupted = *interrupted; - if p.exit_code.is_none() { - p.exit_code = *exit_code; - } - } - - // Transition tracker entry to Completed - if let Some(tracked) = self.tool_tracker.get_mut(tool_id) { - tracked.complete(preview); - } - - let content = outcome.format_for_llm(); - let is_error = outcome.is_error(); - self.conversation - .add_tool_result(tool_id.to_string(), content, is_error, false, None); - } - - /// Get the footer text for current mode - pub fn footer_text(&self) -> &'static str { - match self.interaction.mode { - AppMode::Input => { - if self.conversation.has_any_command() && self.interaction.is_input_blank { - if self.interaction.confirmation_pending { - "[Enter] Confirm dangerous command [Esc] Cancel" - } else { - "[Enter] Execute suggested command [Tab] Insert Command" - } - } else { - "[Enter] Send [Shift+Enter] New line [Esc] Exit" - } - } - AppMode::Generating | AppMode::Streaming => "[Esc] Cancel", - AppMode::Error => "[Enter]/[r] Retry [Esc] Exit", - } - } - - /// Check if the application is exiting - pub fn is_exiting(&self) -> bool { - self.exit_action.is_some() - } -} diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 6e13e406..d40a44d4 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -5,7 +5,9 @@ use eye_declare::{ }; use ratatui_core::style::{Color, Modifier, Style}; -use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview, TrackedTool}; +use crate::driver::ViewState; +use crate::fsm::{AgentState, StreamPhase}; +use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview}; use crate::tui::components::select::SelectOption; use crate::tui::components::session_continue::SessionContinue; use crate::tui::events::{AiTuiEvent, PermissionResult}; @@ -14,10 +16,23 @@ use super::components::atuin_ai::AtuinAi; use super::components::input_box::InputBox; use super::components::markdown::Markdown; use super::components::select::Select; -use super::state::{AppMode, Session}; +use super::state::AppMode; mod turn; +impl From<&AgentState> for AppMode { + fn from(state: &AgentState) -> Self { + match state { + AgentState::Idle { .. } => AppMode::Input, + AgentState::Turn { + stream: StreamPhase::Connecting, + } => AppMode::Generating, + AgentState::Turn { .. } => AppMode::Streaming, + AgentState::Error(_) => AppMode::Error, + } + } +} + /// Build the element tree from current state. /// /// Layout (top to bottom): @@ -26,28 +41,27 @@ mod turn; /// - Error display (if in error state) /// - Spacer /// - Input box (bordered, with contextual keybindings) -pub(crate) fn ai_view(state: &Session) -> Elements { - let mut turn_builder = turn::TurnBuilder::new(&state.tool_tracker); +pub(crate) fn ai_view(state: &ViewState) -> Elements { + let mut turn_builder = turn::TurnBuilder::new(&state.tools); - for event in &state.archived_view_events { + for event in &state.archived_events { turn_builder.add_event(event); } - for event in &state.conversation.events[state.view_start_index..] { + for event in &state.visible_events { turn_builder.add_event(event); } let turns = turn_builder.build(); - let busy = state.interaction.mode == AppMode::Streaming - || state.interaction.mode == AppMode::Generating; + let busy = state.is_busy(); let last_index = turns.len().saturating_sub(1); element! { AtuinAi( - mode: state.interaction.mode, - has_command: state.conversation.has_any_command(), - is_input_blank: state.interaction.is_input_blank, - pending_confirmation: state.interaction.confirmation_pending, - has_executing_preview: state.tool_tracker.has_executing_preview(), + mode: AppMode::from(&state.agent_state), + has_command: state.has_command(), + is_input_blank: state.is_input_blank, + pending_confirmation: state.has_confirmation(), + has_executing_preview: state.tools.has_executing_preview(), ) { #(if state.is_resumed && (!state.is_exiting() || !turns.is_empty()) { SessionContinue(key: "continuation-notice", continued_at: state.last_event_time) @@ -77,6 +91,15 @@ pub(crate) fn ai_view(state: &Session) -> Elements { } }) + #(if let AgentState::Error(ref msg) = state.agent_state { + View(key: "error-display", padding_left: Cells::from(2), padding_top: Cells::from(1)) { + Text { + Span(text: "Error: ", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) + Span(text: msg, style: Style::default().fg(Color::Red)) + } + } + }) + #(if !state.is_exiting() { #(input_view(state)) }) @@ -84,11 +107,10 @@ pub(crate) fn ai_view(state: &Session) -> Elements { } } -fn input_view(state: &Session) -> Elements { - let asking_tool = state.tool_tracker.asking_for_permission(); +fn input_view(state: &ViewState) -> Elements { + let asking_tool = state.tools.awaiting_permission(); let in_git_project = state.in_git_project; let slash_results = state - .interaction .slash_command_search_results .iter() .take(4) @@ -107,12 +129,12 @@ fn input_view(state: &Session) -> Elements { title: "Generate a command or ask a question", title_right: "Atuin AI", footer: state.footer_text(), - active: state.interaction.mode == AppMode::Input && !state.interaction.confirmation_pending, + active: state.is_input_active(), slash_suggestion: first_slash_result.cloned() ) - #(if state.interaction.is_input_blank && state.conversation.has_any_command() && state.interaction.mode == AppMode::Input { - #(if state.interaction.confirmation_pending { + #(if state.is_input_blank && state.has_command() && state.is_input_active() { + #(if state.has_confirmation() { Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) } } else { Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) } @@ -140,7 +162,7 @@ fn input_view(state: &Session) -> Elements { } } -fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements { +fn tool_call_view(tool_call: &crate::fsm::tools::TrackedTool, in_git_project: bool) -> Elements { let verb = tool_call.tool.descriptor().display_verb; let tool_desc = match &tool_call.tool { ClientToolCall::Read(tool) => tool.path.display().to_string(), diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index 6c3d5c29..98ae5eff 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -1,7 +1,8 @@ use std::path::PathBuf; +use crate::fsm::tools::ToolManager; use crate::tools::descriptor; -use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview, ToolTracker}; +use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview}; use crate::tui::ConversationEvent; /// Server-sent danger level for a suggested command @@ -210,12 +211,12 @@ pub(crate) enum UiTurn { pub(crate) struct TurnBuilder<'a> { turns: Vec, current_turn: Option, - tracker: &'a ToolTracker, + tracker: &'a ToolManager, } /// A struct to iteratively build [UiTurn] events from [ConversationEvent]s. impl<'a> TurnBuilder<'a> { - pub(crate) fn new(tracker: &'a ToolTracker) -> Self { + pub(crate) fn new(tracker: &'a ToolManager) -> Self { Self { turns: Vec::new(), current_turn: None, @@ -441,18 +442,18 @@ impl<'a> TurnBuilder<'a> { match &tracked.tool { ClientToolCall::Shell(shell) => ToolRenderData::Shell { command: shell.command.clone(), - preview: tracked.preview(), + preview: tracked.shell_preview(), }, ClientToolCall::Read(read) => ToolRenderData::FileRead { path: read.path.clone(), }, ClientToolCall::Edit(edit) => ToolRenderData::FileEdit { path: edit.path.clone(), - preview: tracked.edit_preview.clone(), + preview: tracked.edit_preview().cloned(), }, ClientToolCall::Write(write) => ToolRenderData::FileWrite { path: write.path.clone(), - preview: tracked.write_preview.clone(), + preview: tracked.write_preview().cloned(), }, ClientToolCall::AtuinHistory(history) => ToolRenderData::HistorySearch { query: history.query.clone(), -- cgit v1.3.1