diff options
Diffstat (limited to 'crates/atuin-ai')
| -rw-r--r-- | crates/atuin-ai/src/driver.rs | 40 | ||||
| -rw-r--r-- | crates/atuin-ai/src/fsm/effects.rs | 17 | ||||
| -rw-r--r-- | crates/atuin-ai/src/fsm/events.rs | 2 | ||||
| -rw-r--r-- | crates/atuin-ai/src/fsm/mod.rs | 157 | ||||
| -rw-r--r-- | crates/atuin-ai/src/fsm/tests.rs | 349 | ||||
| -rw-r--r-- | crates/atuin-ai/src/fsm/tools.rs | 17 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tools/mod.rs | 29 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/view/mod.rs | 12 |
8 files changed, 570 insertions, 53 deletions
diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs index 2d610203..3acb9798 100644 --- a/crates/atuin-ai/src/driver.rs +++ b/crates/atuin-ai/src/driver.rs @@ -506,20 +506,20 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { ) .await; - let preview = if let crate::tools::ToolOutcome::Structured { - exit_code, - interrupted, - .. - } = &outcome - { - Some(ToolPreviewData::Shell { - lines: vec![], - exit_code: *exit_code, - interrupted: *interrupted, - }) - } else { - None - }; + let preview = + if let crate::tools::ToolOutcome::Structured { exit_code, .. } = + &outcome + { + Some(ToolPreviewData::Shell { + lines: vec![], + exit_code: *exit_code, + // Reason is set by the FSM in handle_tool_done + // based on whether it was a user interrupt or timeout. + interrupted: None, + }) + } else { + None + }; let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { tool_id, @@ -694,13 +694,23 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { Effect::ScheduleTimeout { timeout_id, duration, + kind, } => { let timeout_id = *timeout_id; let duration = *duration; + let kind = kind.clone(); let tx = tx.clone(); tokio::spawn(async move { tokio::time::sleep(duration).await; - let _ = tx.send(DriverEvent::Fsm(Event::ConfirmationTimeout { timeout_id })); + use crate::fsm::effects::TimeoutKind; + let event = match kind { + TimeoutKind::Confirmation => Event::ConfirmationTimeout { timeout_id }, + TimeoutKind::ToolExecution { tool_id } => Event::ToolExecutionTimeout { + timeout_id, + tool_id, + }, + }; + let _ = tx.send(DriverEvent::Fsm(event)); }); } diff --git a/crates/atuin-ai/src/fsm/effects.rs b/crates/atuin-ai/src/fsm/effects.rs index ede72a42..306f1401 100644 --- a/crates/atuin-ai/src/fsm/effects.rs +++ b/crates/atuin-ai/src/fsm/effects.rs @@ -61,14 +61,27 @@ pub(crate) enum Effect { ArchiveSession, // ─── Timers ───────────────────────────────────────────────── - /// Schedule a timer that will fire ConfirmationTimeout after delay. - ScheduleTimeout { timeout_id: u64, duration: Duration }, + /// Schedule a timer that fires an event after the given delay. + ScheduleTimeout { + timeout_id: u64, + duration: Duration, + kind: TimeoutKind, + }, // ─── Exit ─────────────────────────────────────────────────── /// Exit the application with the given action. ExitApp(ExitAction), } +/// What kind of timeout was scheduled. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum TimeoutKind { + /// Dangerous command confirmation dialog auto-dismiss. + Confirmation, + /// Shell tool execution timeout — abort the tool if it's still running. + ToolExecution { tool_id: String }, +} + /// What to do when exiting the TUI. #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum ExitAction { diff --git a/crates/atuin-ai/src/fsm/events.rs b/crates/atuin-ai/src/fsm/events.rs index 62a624bf..6fecda08 100644 --- a/crates/atuin-ai/src/fsm/events.rs +++ b/crates/atuin-ai/src/fsm/events.rs @@ -80,6 +80,8 @@ pub(crate) enum Event { // ─── Timers ───────────────────────────────────────────────── /// Confirmation timeout expired. ConfirmationTimeout { timeout_id: u64 }, + /// Shell tool execution timeout expired. + ToolExecutionTimeout { timeout_id: u64, tool_id: String }, // ─── Session management ───────────────────────────────────── /// User ran /new to start a fresh session. diff --git a/crates/atuin-ai/src/fsm/mod.rs b/crates/atuin-ai/src/fsm/mod.rs index d32d6d7b..25de41f3 100644 --- a/crates/atuin-ai/src/fsm/mod.rs +++ b/crates/atuin-ai/src/fsm/mod.rs @@ -13,12 +13,14 @@ pub(crate) mod tools; #[cfg(test)] mod tests; +use std::collections::HashMap; + use serde_json::Value; use crate::context_window::ContextWindowBuilder; use crate::tui::state::ConversationEvent; -use effects::{Effect, ExitAction, PermissionTarget}; +use effects::{Effect, ExitAction, PermissionTarget, TimeoutKind}; use events::{Event, PermissionChoice, PermissionResponse}; use tools::{ToolManager, ToolState}; @@ -98,6 +100,9 @@ pub(crate) struct AgentContext { /// Tool IDs that belong to the current turn. Cleared on continuation start. /// Used to determine whether a turn needs continuation (has unprocessed results). current_turn_tool_ids: Vec<String>, + /// Maps timeout_id → tool_id for active tool execution timeouts. + /// Cleaned up when a tool completes naturally, so stale timeouts are ignored. + tool_timeout_ids: HashMap<u64, String>, /// Counter for generating unique timeout IDs. next_timeout_id: u64, /// Capabilities advertised to the server. @@ -150,6 +155,7 @@ impl AgentFsm { current_response: String::new(), tools: ToolManager::new(), current_turn_tool_ids: Vec::new(), + tool_timeout_ids: HashMap::new(), next_timeout_id: 0, capabilities, invocation_id, @@ -179,6 +185,7 @@ impl AgentFsm { current_response: String::new(), tools: ToolManager::new(), current_turn_tool_ids: Vec::new(), + tool_timeout_ids: HashMap::new(), next_timeout_id: 0, capabilities, invocation_id, @@ -224,6 +231,7 @@ impl AgentFsm { vec![Effect::ScheduleTimeout { timeout_id, duration: std::time::Duration::from_secs(5), + kind: TimeoutKind::Confirmation, }] } else { vec![Effect::ExitApp(ExitAction::Execute(cmd))] @@ -414,6 +422,7 @@ impl AgentFsm { .into_iter() .map(|tool_id| Effect::AbortTool { tool_id }) .collect(); + self.ctx.tool_timeout_ids.clear(); self.state = AgentState::Error(e); abort_effects } @@ -462,7 +471,7 @@ impl AgentFsm { tracked.preview = Some(tools::ToolPreviewData::Shell { lines, exit_code, - interrupted: false, + interrupted: None, }); } } @@ -471,11 +480,26 @@ impl AgentFsm { (AgentState::Turn { .. }, Event::InterruptTools) => { let ids = self.ctx.tools.executing_ids(); + for id in &ids { + if let Some(tracked) = self.ctx.tools.get_mut(id) { + tracked.interrupt_reason = Some(tools::InterruptReason::User); + } + // Clear any pending execution timeout for this tool + self.ctx.tool_timeout_ids.retain(|_, tid| tid != id); + } ids.into_iter() .map(|tool_id| Effect::AbortTool { tool_id }) .collect() } + ( + AgentState::Turn { .. }, + Event::ToolExecutionTimeout { + timeout_id, + tool_id, + }, + ) => self.handle_tool_execution_timeout(timeout_id, tool_id), + // ─── Cancel during Turn ───────────────────────────────────── (AgentState::Turn { stream }, Event::Cancel) => { let mut effects = Vec::new(); @@ -515,6 +539,9 @@ impl AgentFsm { }); } + // Clear timeout mappings — stale timeouts will be ignored by the guard + self.ctx.tool_timeout_ids.clear(); + self.state = AgentState::Idle { confirmation: None }; effects.push(Effect::Persist); effects @@ -694,7 +721,7 @@ impl AgentFsm { PermissionResponse::Allowed | PermissionResponse::SessionGranted => { tracked.state = ToolState::Executing; let tool = tracked.tool.clone(); - vec![Effect::ExecuteTool { tool_id, tool }] + self.emit_execute_tool(tool_id, tool) } PermissionResponse::Ask => { tracked.state = ToolState::AwaitingPermission; @@ -732,15 +759,12 @@ impl AgentFsm { PermissionChoice::Allow => { tracked.state = ToolState::Executing; let tool = tracked.tool.clone(); - vec![Effect::ExecuteTool { tool_id, tool }] + self.emit_execute_tool(tool_id, tool) } PermissionChoice::AllowForSession => { tracked.state = ToolState::Executing; let tool = tracked.tool.clone(); - let mut effects = vec![Effect::ExecuteTool { - tool_id, - tool: tool.clone(), - }]; + let mut effects = self.emit_execute_tool(tool_id, tool.clone()); if let Some(path) = tool.resolved_file_path() { effects.push(Effect::CacheSessionGrant { path }); } @@ -753,14 +777,13 @@ impl AgentFsm { tool: tool.rule_name().to_string(), scope: None, // project file provides the scoping }; - vec![ - Effect::ExecuteTool { tool_id, tool }, - Effect::WritePermissionRule { - target: PermissionTarget::Project, - rule, - disposition: crate::permissions::writer::RuleDisposition::Allow, - }, - ] + let mut effects = self.emit_execute_tool(tool_id, tool); + effects.push(Effect::WritePermissionRule { + target: PermissionTarget::Project, + rule, + disposition: crate::permissions::writer::RuleDisposition::Allow, + }); + effects } PermissionChoice::AlwaysAllow => { tracked.state = ToolState::Executing; @@ -772,14 +795,13 @@ impl AgentFsm { tool: tool.rule_name().to_string(), scope, }; - vec![ - Effect::ExecuteTool { tool_id, tool }, - Effect::WritePermissionRule { - target: PermissionTarget::Global, - rule, - disposition: crate::permissions::writer::RuleDisposition::Allow, - }, - ] + let mut effects = self.emit_execute_tool(tool_id, tool); + effects.push(Effect::WritePermissionRule { + target: PermissionTarget::Global, + rule, + disposition: crate::permissions::writer::RuleDisposition::Allow, + }); + effects } PermissionChoice::Deny => { tracked.state = ToolState::Denied; @@ -813,6 +835,19 @@ impl AgentFsm { tracked.state = ToolState::Completed; + // If the FSM tagged this tool with an interrupt reason (user or timeout), + // use it; otherwise derive from the outcome's interrupted flag. + let reason = tracked.interrupt_reason.take().or({ + if let crate::tools::ToolOutcome::Structured { + interrupted: true, .. + } = &outcome + { + Some(tools::InterruptReason::User) + } else { + None + } + }); + // Merge shell preview: the final ToolExecutionDone carries exit_code/interrupted // but has empty lines (the live lines were accumulated via ToolPreviewUpdate). // Preserve the accumulated lines and fold in the terminal metadata. @@ -825,20 +860,29 @@ impl AgentFsm { }), Some(tools::ToolPreviewData::Shell { exit_code: final_exit, - interrupted: final_interrupted, .. }), ) => { *exit_code = final_exit; - *interrupted = final_interrupted; + *interrupted = reason.clone(); } - (_, Some(p)) => { + (_, Some(mut p)) => { + if let tools::ToolPreviewData::Shell { + ref mut interrupted, + .. + } = p + { + *interrupted = reason.clone(); + } tracked.preview = Some(p); } _ => {} } - let content = outcome.format_for_llm(); + // Clean up any pending execution timeout for this tool + self.ctx.tool_timeout_ids.retain(|_, tid| tid != &tool_id); + + let content = outcome.format_for_llm(reason.as_ref()); let is_error = outcome.is_error(); self.ctx.events.push(ConversationEvent::ToolResult { tool_use_id: tool_id, @@ -851,6 +895,63 @@ impl AgentFsm { self.check_turn_completion() } + /// Handle a tool execution timeout. Aborts the tool if it's still running. + fn handle_tool_execution_timeout(&mut self, timeout_id: u64, tool_id: String) -> Vec<Effect> { + // Guard: only act if this timeout is still registered (not cleaned up by natural completion) + if self.ctx.tool_timeout_ids.remove(&timeout_id).is_none() { + return vec![]; + } + + let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { + return vec![]; + }; + + if tracked.is_resolved() { + return vec![]; + } + + // Tag the tool so handle_tool_done can distinguish timeout from user interrupt. + // Only shell tools have entries in tool_timeout_ids, so this is always Shell. + let timeout_secs = match &tracked.tool { + crate::tools::ClientToolCall::Shell(s) => s.timeout_secs, + _ => unreachable!("only shell tools have execution timeouts"), + }; + tracked.interrupt_reason = Some(tools::InterruptReason::Timeout(timeout_secs)); + + // Abort the tool — the driver sends the interrupt signal via oneshot, + // and execute_shell_command_streaming returns a Structured outcome with + // interrupted: true and partial stdout/stderr. This flows through the + // normal ToolExecutionDone path. + vec![Effect::AbortTool { tool_id }] + } + + /// Emit effects to begin executing a tool. For shell commands, also schedules + /// an execution timeout based on the LLM-specified timeout_secs. + fn emit_execute_tool( + &mut self, + tool_id: String, + tool: crate::tools::ClientToolCall, + ) -> Vec<Effect> { + let mut effects = vec![Effect::ExecuteTool { + tool_id: tool_id.clone(), + tool: tool.clone(), + }]; + + if let crate::tools::ClientToolCall::Shell(ref shell) = tool { + let timeout_id = self.ctx.next_timeout_id(); + self.ctx + .tool_timeout_ids + .insert(timeout_id, tool_id.clone()); + effects.push(Effect::ScheduleTimeout { + timeout_id, + duration: std::time::Duration::from_secs(shell.timeout_secs), + kind: TimeoutKind::ToolExecution { tool_id }, + }); + } + + effects + } + /// Check if the turn is complete (stream done + all tools resolved). /// If so, either continue the conversation or go Idle. fn check_turn_completion(&mut self) -> Vec<Effect> { diff --git a/crates/atuin-ai/src/fsm/tests.rs b/crates/atuin-ai/src/fsm/tests.rs index 9fc404c0..51c23915 100644 --- a/crates/atuin-ai/src/fsm/tests.rs +++ b/crates/atuin-ai/src/fsm/tests.rs @@ -539,3 +539,352 @@ fn permission_deny_completes_turn_and_continues() { ConversationEvent::ToolResult { tool_use_id, is_error: true, .. } if tool_use_id == "t1" ))); } + +// ============================================================================ +// Shell execution timeouts +// ============================================================================ + +fn fsm_with_shell() -> AgentFsm { + AgentFsm::new( + vec![ + "client_v1_read_file".to_string(), + "client_v1_execute_shell_command".to_string(), + ], + "test-inv".to_string(), + ) +} + +fn shell_tool_call_event(id: &str) -> Event { + Event::StreamToolCall { + id: id.into(), + name: "execute_shell_command".into(), + input: json!({ + "command": "sleep 999", + "shell": "bash", + "timeout": 60, + "description": "test" + }), + } +} + +#[test] +fn shell_tool_schedules_execution_timeout() { + let mut fsm = fsm_with_shell(); + fsm.handle(Event::UserSubmit("run something".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(shell_tool_call_event("t1")); + + let effects = fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + + // Should have ExecuteTool + ScheduleTimeout + assert!( + effects + .iter() + .any(|e| matches!(e, Effect::ExecuteTool { .. })) + ); + assert!(effects.iter().any(|e| matches!( + e, + Effect::ScheduleTimeout { kind: effects::TimeoutKind::ToolExecution { tool_id }, .. } + if tool_id == "t1" + ))); + assert!(!fsm.ctx.tool_timeout_ids.is_empty()); +} + +#[test] +fn read_tool_does_not_schedule_timeout() { + let mut fsm = new_fsm(); + fsm.handle(Event::UserSubmit("read".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "read_file".into(), + input: json!({"file_path": "/tmp/test.txt"}), + }); + + let effects = fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + + assert!( + effects + .iter() + .any(|e| matches!(e, Effect::ExecuteTool { .. })) + ); + assert!( + !effects + .iter() + .any(|e| matches!(e, Effect::ScheduleTimeout { .. })) + ); + assert!(fsm.ctx.tool_timeout_ids.is_empty()); +} + +#[test] +fn tool_completion_clears_timeout_mapping() { + let mut fsm = fsm_with_shell(); + fsm.handle(Event::UserSubmit("run".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(shell_tool_call_event("t1")); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + fsm.handle(Event::StreamDone { + session_id: "s1".into(), + }); + + assert!(!fsm.ctx.tool_timeout_ids.is_empty()); + + // Tool completes naturally + fsm.handle(Event::ToolExecutionDone { + tool_id: "t1".into(), + outcome: crate::tools::ToolOutcome::Success("done".into()), + preview: None, + }); + + assert!(fsm.ctx.tool_timeout_ids.is_empty()); +} + +#[test] +fn stale_timeout_after_natural_completion_is_ignored() { + let mut fsm = fsm_with_shell(); + fsm.handle(Event::UserSubmit("run".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(shell_tool_call_event("t1")); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + fsm.handle(Event::StreamDone { + session_id: "s1".into(), + }); + + // Tool completes naturally + fsm.handle(Event::ToolExecutionDone { + tool_id: "t1".into(), + outcome: crate::tools::ToolOutcome::Success("done".into()), + preview: None, + }); + + // Stale timeout fires — should be no-op + let effects = fsm.handle(Event::ToolExecutionTimeout { + timeout_id: 0, + tool_id: "t1".into(), + }); + + assert!(effects.is_empty()); +} + +#[test] +fn timeout_fires_before_completion_emits_abort() { + let mut fsm = fsm_with_shell(); + fsm.handle(Event::UserSubmit("run".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(shell_tool_call_event("t1")); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + fsm.handle(Event::StreamDone { + session_id: "s1".into(), + }); + + // Timeout fires while tool is still executing + let effects = fsm.handle(Event::ToolExecutionTimeout { + timeout_id: 0, + tool_id: "t1".into(), + }); + + assert_eq!(effects.len(), 1); + assert!(matches!( + effects[0], + Effect::AbortTool { ref tool_id } if tool_id == "t1" + )); + // Timeout mapping cleaned up + assert!(fsm.ctx.tool_timeout_ids.is_empty()); +} + +#[test] +fn timeout_respects_llm_specified_duration() { + let mut fsm = fsm_with_shell(); + fsm.handle(Event::UserSubmit("run".into())); + fsm.handle(Event::StreamStarted); + + // Tool call with timeout: 120 + fsm.handle(Event::StreamToolCall { + id: "t1".into(), + name: "execute_shell_command".into(), + input: json!({ + "command": "cargo build", + "shell": "bash", + "timeout": 120, + "description": "build" + }), + }); + + let effects = fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + + let timeout_effect = effects + .iter() + .find(|e| matches!(e, Effect::ScheduleTimeout { .. })); + assert!(matches!( + timeout_effect, + Some(Effect::ScheduleTimeout { duration, .. }) if *duration == std::time::Duration::from_secs(120) + )); +} + +#[test] +fn cancel_clears_timeout_mappings() { + let mut fsm = fsm_with_shell(); + fsm.handle(Event::UserSubmit("run".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(shell_tool_call_event("t1")); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + + assert!(!fsm.ctx.tool_timeout_ids.is_empty()); + + fsm.handle(Event::Cancel); + + assert!(fsm.ctx.tool_timeout_ids.is_empty()); +} + +#[test] +fn timeout_abort_propagates_timeout_reason_to_preview_and_llm() { + use super::tools::InterruptReason; + + let mut fsm = fsm_with_shell(); + fsm.handle(Event::UserSubmit("run".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(shell_tool_call_event("t1")); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + fsm.handle(Event::StreamDone { + session_id: "s1".into(), + }); + + // Timeout fires + fsm.handle(Event::ToolExecutionTimeout { + timeout_id: 0, + tool_id: "t1".into(), + }); + + // Tool completes after abort (interrupted: true from execute_shell_command_streaming) + fsm.handle(Event::ToolExecutionDone { + tool_id: "t1".into(), + outcome: crate::tools::ToolOutcome::Structured { + stdout: "partial output".into(), + stderr: String::new(), + exit_code: None, + duration_ms: 60000, + interrupted: true, + }, + preview: Some(super::tools::ToolPreviewData::Shell { + lines: vec!["partial output".into()], + exit_code: None, + interrupted: None, // FSM overrides this with the reason + }), + }); + + // Preview should carry Timeout reason + let tracked = fsm.ctx.tools.get("t1").unwrap(); + let preview = tracked.shell_preview().unwrap(); + assert_eq!(preview.interrupted, Some(InterruptReason::Timeout(60))); + + // LLM content should say "Timed out" not "Interrupted by user" + let tool_result = fsm.ctx.events.iter().find( + |e| matches!(e, ConversationEvent::ToolResult { tool_use_id, .. } if tool_use_id == "t1"), + ); + if let Some(ConversationEvent::ToolResult { content, .. }) = tool_result { + assert!( + content.contains("[Timed out after 60s]"), + "Expected timeout message, got: {content}" + ); + assert!(!content.contains("[Interrupted by user]")); + } else { + panic!("No ToolResult found for t1"); + } +} + +#[test] +fn user_interrupt_propagates_user_reason_to_preview_and_llm() { + use super::tools::InterruptReason; + + let mut fsm = fsm_with_shell(); + fsm.handle(Event::UserSubmit("run".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(shell_tool_call_event("t1")); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + fsm.handle(Event::StreamDone { + session_id: "s1".into(), + }); + + // User interrupts + fsm.handle(Event::InterruptTools); + + // Tool completes after abort + fsm.handle(Event::ToolExecutionDone { + tool_id: "t1".into(), + outcome: crate::tools::ToolOutcome::Structured { + stdout: "partial".into(), + stderr: String::new(), + exit_code: None, + duration_ms: 5000, + interrupted: true, + }, + preview: Some(super::tools::ToolPreviewData::Shell { + lines: vec!["partial".into()], + exit_code: None, + interrupted: None, // FSM overrides this with the reason + }), + }); + + // Preview should carry User reason + let tracked = fsm.ctx.tools.get("t1").unwrap(); + let preview = tracked.shell_preview().unwrap(); + assert_eq!(preview.interrupted, Some(InterruptReason::User)); + + // LLM content should say "Interrupted by user" + let tool_result = fsm.ctx.events.iter().find( + |e| matches!(e, ConversationEvent::ToolResult { tool_use_id, .. } if tool_use_id == "t1"), + ); + if let Some(ConversationEvent::ToolResult { content, .. }) = tool_result { + assert!( + content.contains("[Interrupted by user]"), + "Expected user interrupt message, got: {content}" + ); + } else { + panic!("No ToolResult found for t1"); + } +} + +#[test] +fn user_interrupt_clears_timeout_mappings_for_aborted_tools() { + let mut fsm = fsm_with_shell(); + fsm.handle(Event::UserSubmit("run".into())); + fsm.handle(Event::StreamStarted); + fsm.handle(shell_tool_call_event("t1")); + fsm.handle(Event::PermissionResolved { + tool_id: "t1".into(), + response: PermissionResponse::Allowed, + }); + + assert!(!fsm.ctx.tool_timeout_ids.is_empty()); + + fsm.handle(Event::InterruptTools); + + assert!(fsm.ctx.tool_timeout_ids.is_empty()); +} diff --git a/crates/atuin-ai/src/fsm/tools.rs b/crates/atuin-ai/src/fsm/tools.rs index a6b2e9ae..96348672 100644 --- a/crates/atuin-ai/src/fsm/tools.rs +++ b/crates/atuin-ai/src/fsm/tools.rs @@ -7,6 +7,15 @@ use crate::diff::{EditPreview, WritePreview}; use crate::tools::ClientToolCall; +/// Why a tool execution was interrupted. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum InterruptReason { + /// User pressed Ctrl+C or Esc during execution. + User, + /// The LLM-specified execution timeout expired. + Timeout(u64), +} + /// Per-tool lifecycle state. #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum ToolState { @@ -29,7 +38,7 @@ pub(crate) enum ToolPreviewData { Shell { lines: Vec<String>, exit_code: Option<i32>, - interrupted: bool, + interrupted: Option<InterruptReason>, }, /// File edit diff preview. Edit(EditPreview), @@ -45,6 +54,9 @@ pub(crate) struct TrackedTool { pub state: ToolState, /// Cached preview data for rendering (populated during/after execution). pub preview: Option<ToolPreviewData>, + /// Set by the FSM when it emits AbortTool, so that ToolExecutionDone + /// can distinguish user interrupts from timeouts. + pub interrupt_reason: Option<InterruptReason>, } impl TrackedTool { @@ -63,7 +75,7 @@ impl TrackedTool { }) => Some(crate::tools::ToolPreview { lines: lines.clone(), exit_code: *exit_code, - interrupted: *interrupted, + interrupted: interrupted.clone(), }), _ => None, } @@ -109,6 +121,7 @@ impl ToolManager { tool, state: ToolState::CheckingPermission, preview: None, + interrupt_reason: None, }); } diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 783bb953..8a670be0 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -76,7 +76,13 @@ pub(crate) enum ToolOutcome { impl ToolOutcome { /// Format this outcome as a string for the tool result sent to the LLM. - pub fn format_for_llm(&self) -> String { + /// + /// The optional `interrupt_reason` overrides the generic interrupted message + /// with a specific one (user interrupt vs timeout). + pub fn format_for_llm( + &self, + interrupt_reason: Option<&crate::fsm::tools::InterruptReason>, + ) -> String { match self { ToolOutcome::Success(s) => s.clone(), ToolOutcome::Error(e) => e.clone(), @@ -108,7 +114,14 @@ impl ToolOutcome { } if *interrupted { - parts.push("[Interrupted by user]".to_string()); + use crate::fsm::tools::InterruptReason; + let msg = match interrupt_reason { + Some(InterruptReason::Timeout(secs)) => { + format!("[Timed out after {secs}s]") + } + _ => "[Interrupted by user]".to_string(), + }; + parts.push(msg); } parts.join("\n\n") @@ -134,7 +147,7 @@ impl ToolOutcome { pub(crate) struct ToolPreview { pub lines: Vec<String>, pub exit_code: Option<i32>, - pub interrupted: bool, + pub interrupted: Option<crate::fsm::tools::InterruptReason>, } /// A tool call from the server, with parsed input parameters. @@ -695,6 +708,8 @@ pub(crate) struct ShellToolCall { pub dir: Option<PathBuf>, pub command: String, pub shell: String, + /// Maximum execution time in seconds (from LLM). Clamped to 1..=600, default 30. + pub timeout_secs: u64, // allow dead code here; this will be tied into o11y and user-facing descriptions #[expect(dead_code)] pub description: Option<String>, @@ -717,6 +732,13 @@ impl TryFrom<&serde_json::Value> for ShellToolCall { .unwrap_or("bash") .to_string(); + let timeout_secs = value + .get("timeout") + .and_then(|v| v.as_u64()) + .filter(|&v| v > 0) + .unwrap_or(30) + .min(600); + let description = value .get("description") .and_then(|v| v.as_str()) @@ -726,6 +748,7 @@ impl TryFrom<&serde_json::Value> for ShellToolCall { dir: dir.map(expand_path), command: command.to_string(), shell, + timeout_secs, description, }) } diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index d40a44d4..2061ec38 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -431,7 +431,7 @@ const MAX_SHELL_PREVIEW_LINES: u16 = 5; /// Render a shell command execution with live VT100 output viewport. fn shell_tool_view(tool_key: &str, command: &str, preview: Option<&ToolPreview>) -> Elements { - let preview_done = preview.is_some_and(|p| p.exit_code.is_some() || p.interrupted); + let preview_done = preview.is_some_and(|p| p.exit_code.is_some() || p.interrupted.is_some()); element! { #(if let Some(preview) = preview { @@ -468,10 +468,16 @@ fn shell_tool_view(tool_key: &str, command: &str, preview: Option<&ToolPreview>) } fn shell_tool_footer(preview: &ToolPreview, preview_done: bool) -> Elements { - if preview.interrupted { + use crate::fsm::tools::InterruptReason; + + if let Some(reason) = &preview.interrupted { + let text = match reason { + InterruptReason::User => "Interrupted".to_string(), + InterruptReason::Timeout(secs) => format!("Timed out ({secs}s)"), + }; return element! { Text { - Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) + Span(text: text, style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) } }; } |
