aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src/fsm/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-ai/src/fsm/mod.rs')
-rw-r--r-- crates/atuin-ai/src/fsm/mod.rs 157
1 file changed, 129 insertions, 28 deletions
diff --git a/crates/atuin-ai/src/fsm/mod.rs b/crates/atuin-ai/src/fsm/mod.rs
index d32d6d7b..25de41f3 100644
--- a/crates/atuin-ai/src/fsm/mod.rs
+++ b/crates/atuin-ai/src/fsm/mod.rs
@@ -13,12 +13,14 @@ pub(crate) mod tools;
#[cfg(test)]
mod tests;
+use std::collections::HashMap;
+
use serde_json::Value;
use crate::context_window::ContextWindowBuilder;
use crate::tui::state::ConversationEvent;
-use effects::{Effect, ExitAction, PermissionTarget};
+use effects::{Effect, ExitAction, PermissionTarget, TimeoutKind};
use events::{Event, PermissionChoice, PermissionResponse};
use tools::{ToolManager, ToolState};
@@ -98,6 +100,9 @@ pub(crate) struct AgentContext {
/// Tool IDs that belong to the current turn. Cleared on continuation start.
/// Used to determine whether a turn needs continuation (has unprocessed results).
current_turn_tool_ids: Vec<String>,
+ /// Maps timeout_id → tool_id for active tool execution timeouts.
+ /// Cleaned up when a tool completes naturally, so stale timeouts are ignored.
+ tool_timeout_ids: HashMap<u64, String>,
/// Counter for generating unique timeout IDs.
next_timeout_id: u64,
/// Capabilities advertised to the server.
@@ -150,6 +155,7 @@ impl AgentFsm {
current_response: String::new(),
tools: ToolManager::new(),
current_turn_tool_ids: Vec::new(),
+ tool_timeout_ids: HashMap::new(),
next_timeout_id: 0,
capabilities,
invocation_id,
@@ -179,6 +185,7 @@ impl AgentFsm {
current_response: String::new(),
tools: ToolManager::new(),
current_turn_tool_ids: Vec::new(),
+ tool_timeout_ids: HashMap::new(),
next_timeout_id: 0,
capabilities,
invocation_id,
@@ -224,6 +231,7 @@ impl AgentFsm {
vec![Effect::ScheduleTimeout {
timeout_id,
duration: std::time::Duration::from_secs(5),
+ kind: TimeoutKind::Confirmation,
}]
} else {
vec![Effect::ExitApp(ExitAction::Execute(cmd))]
@@ -414,6 +422,7 @@ impl AgentFsm {
.into_iter()
.map(|tool_id| Effect::AbortTool { tool_id })
.collect();
+ self.ctx.tool_timeout_ids.clear();
self.state = AgentState::Error(e);
abort_effects
}
@@ -462,7 +471,7 @@ impl AgentFsm {
tracked.preview = Some(tools::ToolPreviewData::Shell {
lines,
exit_code,
- interrupted: false,
+ interrupted: None,
});
}
}
@@ -471,11 +480,26 @@ impl AgentFsm {
(AgentState::Turn { .. }, Event::InterruptTools) => {
let ids = self.ctx.tools.executing_ids();
+ for id in &ids {
+ if let Some(tracked) = self.ctx.tools.get_mut(id) {
+ tracked.interrupt_reason = Some(tools::InterruptReason::User);
+ }
+ // Clear any pending execution timeout for this tool
+ self.ctx.tool_timeout_ids.retain(|_, tid| tid != id);
+ }
ids.into_iter()
.map(|tool_id| Effect::AbortTool { tool_id })
.collect()
}
+ (
+ AgentState::Turn { .. },
+ Event::ToolExecutionTimeout {
+ timeout_id,
+ tool_id,
+ },
+ ) => self.handle_tool_execution_timeout(timeout_id, tool_id),
+
// ─── Cancel during Turn ─────────────────────────────────────
(AgentState::Turn { stream }, Event::Cancel) => {
let mut effects = Vec::new();
@@ -515,6 +539,9 @@ impl AgentFsm {
});
}
+ // Clear timeout mappings — stale timeouts will be ignored by the guard
+ self.ctx.tool_timeout_ids.clear();
+
self.state = AgentState::Idle { confirmation: None };
effects.push(Effect::Persist);
effects
@@ -694,7 +721,7 @@ impl AgentFsm {
PermissionResponse::Allowed | PermissionResponse::SessionGranted => {
tracked.state = ToolState::Executing;
let tool = tracked.tool.clone();
- vec![Effect::ExecuteTool { tool_id, tool }]
+ self.emit_execute_tool(tool_id, tool)
}
PermissionResponse::Ask => {
tracked.state = ToolState::AwaitingPermission;
@@ -732,15 +759,12 @@ impl AgentFsm {
PermissionChoice::Allow => {
tracked.state = ToolState::Executing;
let tool = tracked.tool.clone();
- vec![Effect::ExecuteTool { tool_id, tool }]
+ self.emit_execute_tool(tool_id, tool)
}
PermissionChoice::AllowForSession => {
tracked.state = ToolState::Executing;
let tool = tracked.tool.clone();
- let mut effects = vec![Effect::ExecuteTool {
- tool_id,
- tool: tool.clone(),
- }];
+ let mut effects = self.emit_execute_tool(tool_id, tool.clone());
if let Some(path) = tool.resolved_file_path() {
effects.push(Effect::CacheSessionGrant { path });
}
@@ -753,14 +777,13 @@ impl AgentFsm {
tool: tool.rule_name().to_string(),
scope: None, // project file provides the scoping
};
- vec![
- Effect::ExecuteTool { tool_id, tool },
- Effect::WritePermissionRule {
- target: PermissionTarget::Project,
- rule,
- disposition: crate::permissions::writer::RuleDisposition::Allow,
- },
- ]
+ let mut effects = self.emit_execute_tool(tool_id, tool);
+ effects.push(Effect::WritePermissionRule {
+ target: PermissionTarget::Project,
+ rule,
+ disposition: crate::permissions::writer::RuleDisposition::Allow,
+ });
+ effects
}
PermissionChoice::AlwaysAllow => {
tracked.state = ToolState::Executing;
@@ -772,14 +795,13 @@ impl AgentFsm {
tool: tool.rule_name().to_string(),
scope,
};
- vec![
- Effect::ExecuteTool { tool_id, tool },
- Effect::WritePermissionRule {
- target: PermissionTarget::Global,
- rule,
- disposition: crate::permissions::writer::RuleDisposition::Allow,
- },
- ]
+ let mut effects = self.emit_execute_tool(tool_id, tool);
+ effects.push(Effect::WritePermissionRule {
+ target: PermissionTarget::Global,
+ rule,
+ disposition: crate::permissions::writer::RuleDisposition::Allow,
+ });
+ effects
}
PermissionChoice::Deny => {
tracked.state = ToolState::Denied;
@@ -813,6 +835,19 @@ impl AgentFsm {
tracked.state = ToolState::Completed;
+ // If the FSM tagged this tool with an interrupt reason (user or timeout),
+ // use it; otherwise derive from the outcome's interrupted flag.
+ let reason = tracked.interrupt_reason.take().or({
+ if let crate::tools::ToolOutcome::Structured {
+ interrupted: true, ..
+ } = &outcome
+ {
+ Some(tools::InterruptReason::User)
+ } else {
+ None
+ }
+ });
+
// Merge shell preview: the final ToolExecutionDone carries exit_code/interrupted
// but has empty lines (the live lines were accumulated via ToolPreviewUpdate).
// Preserve the accumulated lines and fold in the terminal metadata.
@@ -825,20 +860,29 @@ impl AgentFsm {
}),
Some(tools::ToolPreviewData::Shell {
exit_code: final_exit,
- interrupted: final_interrupted,
..
}),
) => {
*exit_code = final_exit;
- *interrupted = final_interrupted;
+ *interrupted = reason.clone();
}
- (_, Some(p)) => {
+ (_, Some(mut p)) => {
+ if let tools::ToolPreviewData::Shell {
+ ref mut interrupted,
+ ..
+ } = p
+ {
+ *interrupted = reason.clone();
+ }
tracked.preview = Some(p);
}
_ => {}
}
- let content = outcome.format_for_llm();
+ // Clean up any pending execution timeout for this tool
+ self.ctx.tool_timeout_ids.retain(|_, tid| tid != &tool_id);
+
+ let content = outcome.format_for_llm(reason.as_ref());
let is_error = outcome.is_error();
self.ctx.events.push(ConversationEvent::ToolResult {
tool_use_id: tool_id,
@@ -851,6 +895,63 @@ impl AgentFsm {
self.check_turn_completion()
}
+ /// Handle a tool execution timeout. Aborts the tool if it's still running.
+ fn handle_tool_execution_timeout(&mut self, timeout_id: u64, tool_id: String) -> Vec<Effect> {
+ // Guard: only act if this timeout is still registered (not cleaned up by natural completion)
+ if self.ctx.tool_timeout_ids.remove(&timeout_id).is_none() {
+ return vec![];
+ }
+
+ let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else {
+ return vec![];
+ };
+
+ if tracked.is_resolved() {
+ return vec![];
+ }
+
+ // Tag the tool so handle_tool_done can distinguish timeout from user interrupt.
+ // Only shell tools have entries in tool_timeout_ids, so this is always Shell.
+ let timeout_secs = match &tracked.tool {
+ crate::tools::ClientToolCall::Shell(s) => s.timeout_secs,
+ _ => unreachable!("only shell tools have execution timeouts"),
+ };
+ tracked.interrupt_reason = Some(tools::InterruptReason::Timeout(timeout_secs));
+
+ // Abort the tool — the driver sends the interrupt signal via oneshot,
+ // and execute_shell_command_streaming returns a Structured outcome with
+ // interrupted: true and partial stdout/stderr. This flows through the
+ // normal ToolExecutionDone path.
+ vec![Effect::AbortTool { tool_id }]
+ }
+
+ /// Emit effects to begin executing a tool. For shell commands, also schedules
+ /// an execution timeout based on the LLM-specified timeout_secs.
+ fn emit_execute_tool(
+ &mut self,
+ tool_id: String,
+ tool: crate::tools::ClientToolCall,
+ ) -> Vec<Effect> {
+ let mut effects = vec![Effect::ExecuteTool {
+ tool_id: tool_id.clone(),
+ tool: tool.clone(),
+ }];
+
+ if let crate::tools::ClientToolCall::Shell(ref shell) = tool {
+ let timeout_id = self.ctx.next_timeout_id();
+ self.ctx
+ .tool_timeout_ids
+ .insert(timeout_id, tool_id.clone());
+ effects.push(Effect::ScheduleTimeout {
+ timeout_id,
+ duration: std::time::Duration::from_secs(shell.timeout_secs),
+ kind: TimeoutKind::ToolExecution { tool_id },
+ });
+ }
+
+ effects
+ }
+
/// Check if the turn is complete (stream done + all tools resolved).
/// If so, either continue the conversation or go Idle.
fn check_turn_completion(&mut self) -> Vec<Effect> {