diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-04-21 10:32:54 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-04-21 10:32:54 -0700 |
| commit | 0f20ee4eb871907defe7848f0d3e2203cfff057e (patch) | |
| tree | cda9034c4c6e7b5ecf0fe957978284e9138b80ff /crates/atuin-ai/src/tui/dispatch.rs | |
| parent | chore: Clarified note about regular expressions matching in path. (#3427) (diff) | |
| download | atuin-0f20ee4eb871907defe7848f0d3e2203cfff057e.zip | |
feat: AI tool rendering overhaul + edit_file tool (#3423)
Overhaul of how AI tool calls are modeled, rendered, and displayed in
the Atuin AI TUI. Fixes bugs in shell command output capture, implements
the `edit_file` tool with full safety infrastructure, and adds a diff
preview for edits.
Diffstat (limited to 'crates/atuin-ai/src/tui/dispatch.rs')
| -rw-r--r-- | crates/atuin-ai/src/tui/dispatch.rs | 199 |
1 files changed, 194 insertions, 5 deletions
diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index ea895c01..fea26953 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -61,15 +61,17 @@ pub(crate) fn dispatch(ctx: &mut DispatchContext, event: AiTuiEvent) -> bool { !ctx.exiting.load(Ordering::Acquire) } -/// Persist new events and the server session ID if it has changed. +/// Persist new events, server session ID, file tracker, and edit permissions. /// Called from the dispatch thread (sync), bridges to async via the tokio handle. fn persist_session(ctx: &mut DispatchContext) { - let Ok((events, server_sid)) = ctx + let Ok((events, server_sid, file_tracker_json, edit_perms_json)) = ctx .handle .fetch(|state| { ( state.conversation.events.clone(), state.conversation.session_id.clone(), + state.file_tracker.to_json().ok(), + state.edit_permissions.to_json().ok(), ) }) .blocking_recv() @@ -86,6 +88,22 @@ fn persist_session(ctx: &mut DispatchContext) { { tracing::warn!("failed to persist server session ID: {e}"); } + if let Some(ref json) = file_tracker_json + && let Err(e) = rt.block_on( + ctx.session_mgr + .set_metadata(crate::file_tracker::METADATA_KEY, json), + ) + { + tracing::warn!("failed to persist file tracker: {e}"); + } + if let Some(ref json) = edit_perms_json + && let Err(e) = rt.block_on( + ctx.session_mgr + .set_metadata(crate::edit_permissions::METADATA_KEY, json), + ) + { + tracing::warn!("failed to persist edit permissions: {e}"); + } } fn launch_stream(ctx: &DispatchContext, setup: impl FnOnce(&mut Session) + Send + 'static) { @@ -210,6 +228,10 @@ fn execute_tool( let shell_call = shell_call.clone(); execute_shell_tool(handle, tx, &tool_id, &shell_call); } + ClientToolCall::Edit(edit_call) => { + let edit_call = edit_call.clone(); + execute_edit_tool(handle, tx, tool_id, edit_call); + } _ => { execute_simple_tool(handle, tx, tool_id, tool, db); } @@ -231,7 +253,21 @@ fn execute_simple_tool( tokio::spawn(async move { let outcome = tool.execute(&db).await; + + // After a successful file read, capture tracking data for freshness + // checking. This re-stats the file to get content hash and mtime. + let read_tracking = if let ClientToolCall::Read(ref read_tool) = tool + && !outcome.is_error() + { + capture_read_tracking(&read_tool.path) + } else { + None + }; + h.update(move |state| { + if let Some((path, content, mtime)) = read_tracking { + state.file_tracker.record_read(path, &content, mtime); + } state.finish_tool_call(&tool_id, outcome); if !state.tool_tracker.has_pending() { let _ = tx.send(AiTuiEvent::ContinueAfterTools); @@ -240,6 +276,117 @@ fn execute_simple_tool( }); } +/// Capture file content and mtime for the read tracker. +/// Returns None for directories or if the file can't be read. +fn capture_read_tracking( + path: &std::path::Path, +) -> Option<(std::path::PathBuf, Vec<u8>, std::time::SystemTime)> { + let resolved = if path.is_relative() { + std::env::current_dir().ok()?.join(path) + } else { + path.to_path_buf() + }; + if !resolved.is_file() { + return None; + } + let content = std::fs::read(&resolved).ok()?; + let mtime = std::fs::metadata(&resolved).ok()?.modified().ok()?; + Some((resolved, content, mtime)) +} + +/// Execute an edit_file tool call. +/// +/// Orchestrates snapshot → execute → tracker update. The snapshot and +/// tracker mutations happen via `h.update()` (on the TUI thread) since +/// they need mutable Session state. The actual file I/O (freshness check, +/// read, match, atomic write) runs in the tokio task. +fn execute_edit_tool( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + tool_id: String, + edit_call: crate::tools::EditToolCall, +) { + let h = handle.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let resolved = edit_call.resolved_path(); + + // 1. Read the original file content (used for snapshot + diff). + let old_content = std::fs::read(&resolved).ok(); + + // 2. Snapshot the original file before editing. + if let Some(ref content) = old_content { + let snap_path = resolved.clone(); + let snap_content = content.clone(); + h.update(move |state| { + if let Some(ref mut store) = state.snapshot_store + && let Err(e) = store.ensure_snapshot(&snap_path, &snap_content) + { + tracing::warn!("failed to create file snapshot: {e}"); + } + }); + } + + // 3. Fetch a clone of the file tracker for freshness checking. + let Ok(tracker) = h.fetch(|state| state.file_tracker.clone()).await else { + let tc_id = tool_id.clone(); + h.update(move |state| { + state.finish_tool_call( + &tc_id, + crate::tools::ToolOutcome::Error("Internal error: TUI unavailable".into()), + ); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + return; + }; + + // 4. Execute: freshness check → read → match → atomic write + let (outcome, new_bytes) = edit_call.execute(&resolved, &tracker); + + // 5. Compute diff preview on success + let edit_preview = if let Some(ref new_bytes) = new_bytes { + if let Some(ref old_bytes) = old_content { + let old_str = String::from_utf8_lossy(old_bytes); + let new_str = String::from_utf8_lossy(new_bytes); + let preview = crate::diff::EditPreview::compute(&old_str, &new_str); + if preview.hunks.is_empty() { + None + } else { + Some(preview) + } + } else { + None + } + } else { + None + }; + + // 6. Update tracker, store diff preview, and finish the tool call + let tc_id = tool_id; + h.update(move |state| { + if let Some(ref new_bytes) = new_bytes + && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) + { + state + .file_tracker + .update_after_edit(&resolved, new_bytes, mtime); + } + if let Some(preview) = edit_preview + && let Some(tracked) = state.tool_tracker.get_mut(&tc_id) + { + tracked.edit_preview = Some(preview); + } + state.finish_tool_call(&tc_id, outcome); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + }); +} + /// Execute a shell tool with streaming VT100 preview. fn execute_shell_tool( handle: &Handle<Session>, @@ -352,12 +499,28 @@ async fn check_tool_permission_inner( .map_err(|e| format!("Internal error fetching tool state: {e}"))? .ok_or_else(|| "Internal error: tool not found in tracker".to_string())?; - // 2. Resolve working directory + // 2. For edit tools, check session-scoped permission grants before + // hitting the filesystem-based resolver. A valid grant means the user + // already approved this file recently. + if let ClientToolCall::Edit(ref edit) = tool { + let resolved = edit.resolved_path(); + let has_grant = h2 + .fetch(move |state| state.edit_permissions.has_valid_grant(&resolved)) + .await + .unwrap_or(false); + + if has_grant { + execute_tool(h2, tx, id, tool, db); + return Ok(()); + } + } + + // 3. Resolve working directory let working_dir = target_dir .or_else(|| std::env::current_dir().ok()) .ok_or_else(|| "Could not determine working directory".to_string())?; - // 3. Create permission resolver and check + // 4. Create permission resolver and check let resolver = PermissionResolver::new(working_dir) .await .map_err(|e| format!("Permission check failed: {e}"))?; @@ -367,7 +530,7 @@ async fn check_tool_permission_inner( .await .map_err(|e| format!("Permission check failed: {e}"))?; - // 4. Handle response — all paths here handle the tool, so return Ok + // 5. Handle response — all paths here handle the tool, so return Ok let id_clone = id.clone(); match response { PermissionResponse::Allowed => { @@ -423,6 +586,32 @@ fn on_select_permission(ctx: &mut DispatchContext, permission: PermissionResult) execute_tool(&h2, &tx, tool_id, tool, &db); }); } + PermissionResult::AllowFileForSession => { + // Cache a session-scoped, time-limited grant for this file + let db = ctx.app_ctx.history_db.clone(); + tokio::spawn(async move { + let Ok(Some((tool_id, tool))) = h2 + .fetch(move |state| { + state + .tool_tracker + .asking_for_permission() + .map(|t| (t.id.clone(), t.tool.clone())) + }) + .await + else { + return; + }; + + if let ClientToolCall::Edit(ref edit) = tool { + let resolved = edit.resolved_path(); + h2.update(move |state| { + state.edit_permissions.grant(resolved); + }); + } + + execute_tool(&h2, &tx, tool_id, tool, &db); + }); + } PermissionResult::AlwaysAllowInDir => { let db = ctx.app_ctx.history_db.clone(); let git_root = ctx.app_ctx.git_root.clone(); |
