diff options
Diffstat (limited to 'crates/atuin-ai/src/tools')
| -rw-r--r-- | crates/atuin-ai/src/tools/mod.rs | 200 |
1 files changed, 23 insertions, 177 deletions
diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 890ea734..530f0e83 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -58,6 +58,7 @@ fn path_matches_scope(path: &Path, scope: &str) -> bool { } /// Result of executing a client-side tool. +#[derive(Debug, Clone)] pub(crate) enum ToolOutcome { /// Simple success with a text result (used by Read, AtuinHistory). Success(String), @@ -136,176 +137,6 @@ pub(crate) struct ToolPreview { pub interrupted: bool, } -/// Lifecycle phase of a tracked tool call. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum ToolPhase { - CheckingPermissions, - AskingForPermission, - #[expect(dead_code)] - Denied(String), - #[expect(dead_code)] - Executing, - /// Shell command is executing with live preview output. - ExecutingWithPreview { - command: String, - /// Current VT100 screen lines (plain text, viewport-sized). - output_lines: Vec<String>, - /// Exit code once the process completes. - exit_code: Option<i32>, - /// Whether the command was interrupted by the user. - interrupted: bool, - }, - /// Tool execution has completed. Preview is cached for rendering history. - Completed { - preview: Option<ToolPreview>, - }, -} - -/// A tracked tool call through its full lifecycle. -#[derive(Debug)] -pub(crate) struct TrackedTool { - pub id: String, - pub tool: ClientToolCall, - pub phase: ToolPhase, - /// Sender to interrupt a running shell command (only set during ExecutingWithPreview). - pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>, - /// Diff preview for completed edit tool calls. - pub edit_preview: Option<crate::diff::EditPreview>, - /// Content preview for completed write tool calls. - pub write_preview: Option<crate::diff::WritePreview>, -} - -impl TrackedTool { - pub(crate) fn target_dir(&self) -> Option<&Path> { - self.tool.target_dir() - } - - pub fn mark_asking(&mut self) { - self.phase = ToolPhase::AskingForPermission; - } - - pub fn mark_executing_preview(&mut self, command: String) { - self.phase = ToolPhase::ExecutingWithPreview { - command, - output_lines: Vec::new(), - exit_code: None, - interrupted: false, - }; - } - - pub fn complete(&mut self, preview: Option<ToolPreview>) { - self.phase = ToolPhase::Completed { preview }; - self.abort_tx = None; - } - - /// Extract the current preview, whether live or completed. - pub fn preview(&self) -> Option<ToolPreview> { - match &self.phase { - ToolPhase::ExecutingWithPreview { - output_lines, - exit_code, - interrupted, - .. - } => Some(ToolPreview { - lines: output_lines.clone(), - exit_code: *exit_code, - interrupted: *interrupted, - }), - ToolPhase::Completed { preview } => preview.clone(), - _ => None, - } - } -} - -/// Tracks all tool calls through their full lifecycle. -/// -/// Single source of truth for tool execution state. Entries persist after -/// completion so cached previews remain available for rendering history. -#[derive(Debug)] -pub(crate) struct ToolTracker { - tools: Vec<TrackedTool>, -} - -impl ToolTracker { - pub fn new() -> Self { - Self { tools: Vec::new() } - } - - /// Insert a new tool call in CheckingPermissions phase. - pub fn insert(&mut self, id: String, tool: ClientToolCall) { - self.tools.push(TrackedTool { - id, - tool, - phase: ToolPhase::CheckingPermissions, - abort_tx: None, - edit_preview: None, - write_preview: None, - }); - } - - pub fn get(&self, id: &str) -> Option<&TrackedTool> { - self.tools.iter().find(|t| t.id == id) - } - - pub fn get_mut(&mut self, id: &str) -> Option<&mut TrackedTool> { - self.tools.iter_mut().find(|t| t.id == id) - } - - /// Remove a tool by ID and return it. - #[expect(dead_code)] - pub fn remove(&mut self, id: &str) -> Option<TrackedTool> { - let pos = self.tools.iter().position(|t| t.id == id)?; - Some(self.tools.remove(pos)) - } - - /// True if any tool is still awaiting a permission decision. - #[expect(dead_code)] - pub fn has_unresolved(&self) -> bool { - self.tools.iter().any(|t| { - matches!( - t.phase, - ToolPhase::CheckingPermissions | ToolPhase::AskingForPermission - ) - }) - } - - /// True if any tool has not yet reached the Completed phase. - /// Use this to gate `ContinueAfterTools` — we must wait for all tools - /// (including those still executing) before resuming the conversation. - pub fn has_pending(&self) -> bool { - self.tools - .iter() - .any(|t| !matches!(t.phase, ToolPhase::Completed { .. })) - } - - /// True if any tool is currently executing with a preview. - pub fn has_executing_preview(&self) -> bool { - self.tools - .iter() - .any(|t| matches!(t.phase, ToolPhase::ExecutingWithPreview { .. })) - } - - /// Find the first tool that is asking for permission. - pub fn asking_for_permission(&self) -> Option<&TrackedTool> { - self.tools - .iter() - .find(|t| t.phase == ToolPhase::AskingForPermission) - } - - /// Find the first tool that is asking for permission (mutable). - #[expect(dead_code)] - pub fn asking_for_permission_mut(&mut self) -> Option<&mut TrackedTool> { - self.tools - .iter_mut() - .find(|t| t.phase == ToolPhase::AskingForPermission) - } - - /// Iterate mutably over all tracked tools. - pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut TrackedTool> { - self.tools.iter_mut() - } -} - /// A tool call from the server, with parsed input parameters. #[derive(Debug, Clone)] pub(crate) enum ClientToolCall { @@ -359,6 +190,17 @@ impl ClientToolCall { } } + /// The resolved file path for this tool call, if it's a file-based tool. + /// Used to build scoped permission rules like `Write(/abs/path/to/file)`. + pub(crate) fn resolved_file_path(&self) -> Option<PathBuf> { + match self { + ClientToolCall::Read(tool) => Some(tool.resolved_path()), + ClientToolCall::Edit(tool) => Some(tool.resolved_path()), + ClientToolCall::Write(tool) => Some(tool.resolved_path()), + _ => None, + } + } + pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { match self { ClientToolCall::Read(tool) => tool.matches_rule(rule), @@ -449,14 +291,18 @@ impl TryFrom<&serde_json::Value> for ReadToolCall { } impl ReadToolCall { - fn execute(&self) -> ToolOutcome { - let mut path = self.path.clone(); - - if path.is_relative() - && let Ok(current_dir) = std::env::current_dir() - { - path = current_dir.join(path); + pub fn resolved_path(&self) -> PathBuf { + if self.path.is_relative() { + std::env::current_dir() + .map(|cwd| cwd.join(&self.path)) + .unwrap_or_else(|_| self.path.clone()) + } else { + self.path.clone() } + } + + pub fn execute(&self) -> ToolOutcome { + let path = self.resolved_path(); if !path.exists() { return ToolOutcome::Error(format!("Error: file does not exist: {}", path.display())); |
