diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-04-21 10:32:54 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-04-21 10:32:54 -0700 |
| commit | 0f20ee4eb871907defe7848f0d3e2203cfff057e (patch) | |
| tree | cda9034c4c6e7b5ecf0fe957978284e9138b80ff /crates/atuin-ai/src/tools/mod.rs | |
| parent | chore: Clarified note about regular expressions matching in path. (#3427) (diff) | |
| download | atuin-0f20ee4eb871907defe7848f0d3e2203cfff057e.zip | |
feat: AI tool rendering overhaul + edit_file tool (#3423)
Overhaul of how AI tool calls are modeled, rendered, and displayed in
the Atuin AI TUI. Fixes bugs in shell command output capture, implements
the `edit_file` tool with full safety infrastructure, and adds a diff
preview for edits.
Diffstat (limited to 'crates/atuin-ai/src/tools/mod.rs')
| -rw-r--r-- | crates/atuin-ai/src/tools/mod.rs | 737 |
1 file changed, 711 insertions, 26 deletions
diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 8f2183b7..8fe1ad73 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -169,6 +169,8 @@ pub(crate) struct TrackedTool { pub phase: ToolPhase, /// Sender to interrupt a running shell command (only set during ExecutingWithPreview). pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>, + /// Diff preview for completed edit tool calls. + pub edit_preview: Option<crate::diff::EditPreview>, } impl TrackedTool { @@ -234,6 +236,7 @@ impl ToolTracker { tool, phase: ToolPhase::CheckingPermissions, abort_tx: None, + edit_preview: None, }); } @@ -294,11 +297,6 @@ impl ToolTracker { .find(|t| t.phase == ToolPhase::AskingForPermission) } - /// Get the preview for a tool by ID (live or cached). - pub fn preview_for(&self, id: &str) -> Option<ToolPreview> { - self.get(id)?.preview() - } - /// Iterate mutably over all tracked tools. pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut TrackedTool> { self.tools.iter_mut() @@ -309,6 +307,7 @@ impl ToolTracker { #[derive(Debug, Clone)] pub(crate) enum ClientToolCall { Read(ReadToolCall), + Edit(EditToolCall), Write(WriteToolCall), Shell(ShellToolCall), AtuinHistory(AtuinHistoryToolCall), @@ -320,9 +319,8 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { fn try_from((name, input): (&str, &serde_json::Value)) -> Result<Self, Self::Error> { match name { "read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)), - "create_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), - // "append_to_file" => Ok(ClientToolCall::Append(AppendToolCall::try_from(input)?)), - // "str_replace" => Ok(ClientToolCall::StrReplace(StrReplaceToolCall::try_from(input)?)), + "edit_file" => Ok(ClientToolCall::Edit(EditToolCall::try_from(input)?)), + "write_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), "execute_shell_command" => 
Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), "atuin_history" => Ok(ClientToolCall::AtuinHistory( AtuinHistoryToolCall::try_from(input)?, @@ -336,17 +334,22 @@ impl ClientToolCall { pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor { match self { ClientToolCall::Read(_) => descriptor::READ, + ClientToolCall::Edit(_) => descriptor::EDIT, ClientToolCall::Write(_) => descriptor::WRITE, ClientToolCall::Shell(_) => descriptor::SHELL, ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, } } - /// The permission rule name for this tool category (e.g. "Write" covers - /// str_replace, file_create, file_insert). + /// The permission rule name for this tool category. + /// + /// Edit and Write share the `"Write"` rule name — a Write permission + /// covers both str_replace edits and full file creates. Write also + /// implies Read (checked in `ReadToolCall::matches_rule`). pub(crate) fn rule_name(&self) -> &'static str { match self { ClientToolCall::Read(_) => "Read", + ClientToolCall::Edit(_) => "Write", ClientToolCall::Write(_) => "Write", ClientToolCall::Shell(_) => "Shell", ClientToolCall::AtuinHistory(_) => "AtuinHistory", @@ -356,6 +359,7 @@ impl ClientToolCall { pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { match self { ClientToolCall::Read(tool) => tool.matches_rule(rule), + ClientToolCall::Edit(tool) => tool.matches_rule(rule), ClientToolCall::Write(tool) => tool.matches_rule(rule), ClientToolCall::Shell(tool) => tool.matches_rule(rule), ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), @@ -365,6 +369,7 @@ impl ClientToolCall { pub(crate) fn target_dir(&self) -> Option<&Path> { match self { ClientToolCall::Read(tool) => tool.target_dir(), + ClientToolCall::Edit(tool) => tool.target_dir(), ClientToolCall::Write(tool) => tool.target_dir(), ClientToolCall::Shell(tool) => tool.target_dir(), ClientToolCall::AtuinHistory(tool) => tool.target_dir(), @@ -401,6 +406,14 @@ impl PermissableToolCall for 
ClientToolCall { } } +/// Expand shell constructs (`~`, `$HOME`, etc.) in a path string. +/// +/// Tool call paths arrive as raw strings from the API without shell +/// expansion. Uses `shellexpand` (same as `atuin-client`). +fn expand_path(path: &str) -> PathBuf { + PathBuf::from(shellexpand::tilde(path).into_owned()) +} + #[derive(Debug, Clone)] pub(crate) struct ReadToolCall { pub path: PathBuf, @@ -425,7 +438,7 @@ impl TryFrom<&serde_json::Value> for ReadToolCall { .min(MAX_FILE_READ_LINES); Ok(ReadToolCall { - path: PathBuf::from(path), + path: expand_path(path), offset, limit, }) @@ -499,7 +512,207 @@ impl PermissableToolCall for ReadToolCall { } fn matches_rule(&self, rule: &Rule) -> bool { - if rule.tool != "Read" { + // Write implies Read — a Write permission on a path also permits reading it. + if rule.tool != "Read" && rule.tool != "Write" { + return false; + } + + match rule.scope.as_deref() { + None | Some("*") => true, + Some(scope) => path_matches_scope(&self.path, scope), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct EditToolCall { + pub path: PathBuf, + pub old_string: String, + pub new_string: String, + pub replace_all: bool, +} + +impl TryFrom<&serde_json::Value> for EditToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let path = value + .get("file_path") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing file_path"))?; + + let old_string = value + .get("old_string") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing old_string"))?; + + let new_string = value + .get("new_string") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing new_string"))?; + + let replace_all = value + .get("replace_all") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + Ok(EditToolCall { + path: expand_path(path), + old_string: old_string.to_string(), + new_string: new_string.to_string(), + replace_all, + }) + } +} + +impl EditToolCall { + /// Resolve the edit path to an 
absolute path. + pub fn resolved_path(&self) -> PathBuf { + if self.path.is_relative() { + std::env::current_dir() + .map(|cwd| cwd.join(&self.path)) + .unwrap_or_else(|_| self.path.clone()) + } else { + self.path.clone() + } + } + + /// Execute the edit against the filesystem. + /// + /// Checks freshness via the provided tracker, validates matches, applies + /// the replacement, and writes atomically. Returns the outcome and (on + /// success) the new file content bytes for tracker updates. + /// + /// Callers should snapshot the file before calling this method and + /// update the file tracker after a successful return. + pub fn execute( + &self, + resolved_path: &Path, + file_tracker: &crate::file_tracker::FileReadTracker, + ) -> (ToolOutcome, Option<Vec<u8>>) { + use crate::file_tracker::FreshnessCheck; + + // 1. Basic validation + if !resolved_path.exists() { + return ( + ToolOutcome::Error(format!( + "Error: file does not exist: {}", + resolved_path.display() + )), + None, + ); + } + if resolved_path.is_dir() { + return ( + ToolOutcome::Error(format!( + "Error: path is a directory, not a file: {}", + resolved_path.display() + )), + None, + ); + } + if self.old_string.is_empty() { + return ( + ToolOutcome::Error( + "old_string must not be empty. To create a new file, use write_file instead." + .to_string(), + ), + None, + ); + } + + // 2. Freshness check + match file_tracker.check_freshness(resolved_path) { + Ok(FreshnessCheck::NotRead) => { + return ( + ToolOutcome::Error( + "File has not been read yet. Read it first before editing.".to_string(), + ), + None, + ); + } + Ok(FreshnessCheck::Stale) => { + return ( + ToolOutcome::Error( + "File has been modified since read, either by the user or by a linter. Read it again before attempting to edit it.".to_string(), + ), + None, + ); + } + Err(e) => { + return ( + ToolOutcome::Error(format!("Error checking file state: {e}")), + None, + ); + } + Ok(FreshnessCheck::Fresh) => {} + } + + // 3. 
Read current contents + let content = match std::fs::read_to_string(resolved_path) { + Ok(c) => c, + Err(e) => return (ToolOutcome::Error(format!("Error reading file: {e}")), None), + }; + + // 4. Find and validate matches + let match_count = content.matches(&self.old_string).count(); + + if match_count == 0 { + return ( + ToolOutcome::Error(format!( + "old_string not found in {}. Make sure it matches exactly, including whitespace and indentation.", + resolved_path.display() + )), + None, + ); + } + + if match_count > 1 && !self.replace_all { + return ( + ToolOutcome::Error(format!( + "Found {match_count} matches of old_string in {}, but replace_all is false. Either provide more context to make the match unique, or set replace_all to true.", + resolved_path.display() + )), + None, + ); + } + + // 5. Apply replacement + let new_content = if self.replace_all { + content.replace(&self.old_string, &self.new_string) + } else { + content.replacen(&self.old_string, &self.new_string, 1) + }; + + // 6. Write atomically + let new_bytes = new_content.into_bytes(); + if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &new_bytes) { + return (ToolOutcome::Error(format!("Error writing file: {e}")), None); + } + + // 7. 
Success + let verb = if match_count == 1 { + "occurrence" + } else { + "occurrences" + }; + ( + ToolOutcome::Success(format!( + "Edited {}: replaced {match_count} {verb} of old_string with new_string.", + resolved_path.display() + )), + Some(new_bytes), + ) + } +} + +impl PermissableToolCall for EditToolCall { + fn target_dir(&self) -> Option<&Path> { + Some(&self.path) + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Write" { return false; } @@ -532,7 +745,7 @@ impl TryFrom<&serde_json::Value> for WriteToolCall { .ok_or(eyre::eyre!("Missing content"))?; Ok(WriteToolCall { - path: PathBuf::from(path), + path: expand_path(path), content: content.to_string(), }) } @@ -560,6 +773,9 @@ pub(crate) struct ShellToolCall { pub dir: Option<PathBuf>, pub command: String, pub shell: String, + // allow dead code here; this will be tied into o11y and user-facing descriptions + #[expect(dead_code)] + pub description: Option<String>, } impl TryFrom<&serde_json::Value> for ShellToolCall { @@ -579,10 +795,16 @@ impl TryFrom<&serde_json::Value> for ShellToolCall { .unwrap_or("bash") .to_string(); + let description = value + .get("description") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + Ok(ShellToolCall { - dir: dir.map(PathBuf::from), + dir: dir.map(expand_path), command: command.to_string(), shell, + description, }) } } @@ -614,7 +836,34 @@ const PREVIEW_HEIGHT: u16 = 10; /// Default terminal width for VT100 emulation. const PREVIEW_WIDTH: u16 = 120; +/// Normalize newlines for VT100 processing. +/// +/// When subprocess output is captured via pipes (no PTY), bare `\n` (LF) bytes +/// are not translated to `\r\n` (CR+LF) the way a kernel terminal driver would +/// with the `ONLCR` flag. In VT100, LF only moves the cursor down without +/// returning to column 0. This causes lines to start at progressively higher +/// column offsets and eventually wrap, producing garbled output. 
+/// +/// This function inserts `\r` before any `\n` that isn't already preceded by +/// `\r`, mimicking the terminal driver's ONLCR behavior. +fn normalize_newlines_for_vt100(data: &[u8]) -> Vec<u8> { + let mut out = Vec::with_capacity(data.len() + data.len() / 8); + for (i, &b) in data.iter().enumerate() { + if b == b'\n' && (i == 0 || data[i - 1] != b'\r') { + out.push(b'\r'); + } + out.push(b); + } + out +} + /// Extract plain text lines from a VT100 screen buffer. +/// +/// Strips trailing blank lines so the result only contains rows with actual +/// content. Without this, the fixed-size VT100 screen (PREVIEW_HEIGHT rows) +/// would always return that many lines, and downstream components that use +/// tail-mode display (like the Viewport) would show the blank padding rows +/// instead of the real output. fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> { let (rows, cols) = screen.size(); let mut lines = Vec::with_capacity(rows as usize); @@ -625,9 +874,11 @@ fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> { line.push_str(cell.contents()); } } - // Trim trailing whitespace for cleaner display lines.push(line.trim_end().to_string()); } + while lines.last().is_some_and(|l| l.is_empty()) { + lines.pop(); + } lines } @@ -640,12 +891,17 @@ fn strip_ansi_via_vt100(raw: &[u8]) -> String { if raw.is_empty() { return String::new(); } - // Use the contents_formatted → screen approach: feed bytes into a parser - // with enough rows to hold everything, then read back the plain text. - // Estimate rows: one row per ~PREVIEW_WIDTH bytes, plus generous padding. - let estimated_rows = (raw.len() / PREVIEW_WIDTH as usize + 1).min(10_000) as u16; + // Normalize bare LF to CR+LF so lines start at column 0 in the VT100 screen. + let normalized = normalize_newlines_for_vt100(raw); + // Feed bytes into a VT100 parser large enough to hold all output, then + // read back the plain text. 
We estimate rows from the number of newlines + // (not total byte length) because real output typically has short lines + // that would be severely under-counted by a bytes÷width estimate. + let newline_count = normalized.iter().filter(|&&b| b == b'\n').count(); + let wrap_estimate = normalized.len() / PREVIEW_WIDTH as usize; + let estimated_rows = (newline_count + wrap_estimate + 1).min(10_000) as u16; let mut parser = vt100::Parser::new(estimated_rows, PREVIEW_WIDTH, 0); - parser.process(raw); + parser.process(&normalized); let screen = parser.screen(); // screen.contents() returns the full plain-text content with trailing // whitespace trimmed per line and trailing blank lines removed. @@ -727,7 +983,8 @@ pub(crate) async fn execute_shell_command_streaming( Ok(0) => stdout_done = true, Ok(n) => { full_stdout.extend_from_slice(&stdout_buf[..n]); - parser.process(&stdout_buf[..n]); + let normalized = normalize_newlines_for_vt100(&stdout_buf[..n]); + parser.process(&normalized); } Err(_) => stdout_done = true, } @@ -740,7 +997,8 @@ pub(crate) async fn execute_shell_command_streaming( Ok(n) => { full_stderr.extend_from_slice(&stderr_buf[..n]); // Feed stderr to the preview parser too, so it shows in the VT100 screen - parser.process(&stderr_buf[..n]); + let normalized = normalize_newlines_for_vt100(&stderr_buf[..n]); + parser.process(&normalized); } Err(_) => stderr_done = true, } @@ -967,7 +1225,7 @@ mod tests { fn read_tool(path: &str) -> ReadToolCall { ReadToolCall { - path: PathBuf::from(path), + path: expand_path(path), offset: 0, limit: 100, } @@ -975,7 +1233,7 @@ mod tests { fn write_tool(path: &str) -> WriteToolCall { WriteToolCall { - path: PathBuf::from(path), + path: expand_path(path), content: String::new(), } } @@ -994,12 +1252,26 @@ mod tests { } #[test] - fn wrong_tool_never_matches() { - assert!(!read_tool("foo.txt").matches_rule(&write_rule(None))); + fn write_implies_read() { + // A Write rule also permits reads on the same path + 
assert!(read_tool("foo.txt").matches_rule(&write_rule(None))); + // But a Read rule does not permit writes assert!(!write_tool("foo.txt").matches_rule(&read_rule(None))); } #[test] + fn edit_uses_write_rule() { + let edit = EditToolCall { + path: expand_path("/home/user/config.toml"), + old_string: "x".into(), + new_string: "y".into(), + replace_all: false, + }; + assert!(edit.matches_rule(&write_rule(None))); + assert!(!edit.matches_rule(&read_rule(None))); + } + + #[test] fn extension_glob() { assert!(read_tool("notes.md").matches_rule(&read_rule(Some("*.md")))); assert!(!read_tool("notes.txt").matches_rule(&read_rule(Some("*.md")))); @@ -1050,6 +1322,419 @@ mod tests { } } + // ── edit_file execution tests ── + + mod edit { + use super::*; + use crate::file_tracker::FileReadTracker; + + /// Helper: create a temp file (with a closed handle), record it in a tracker. + /// Returns the TempDir (keeps the path alive) and tracker. + /// The file handle is closed so atomic_write_file can rename over it on Windows. 
+ fn setup_tracked_file(content: &str) -> (tempfile::TempDir, PathBuf, FileReadTracker) { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test_file.toml"); + std::fs::write(&path, content).unwrap(); + + let file_content = std::fs::read(&path).unwrap(); + let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); + + let mut tracker = FileReadTracker::default(); + tracker.record_read(path.clone(), &file_content, mtime); + + (dir, path, tracker) + } + + fn edit_call(path: &Path, old: &str, new: &str, replace_all: bool) -> EditToolCall { + EditToolCall { + path: path.to_path_buf(), + old_string: old.to_string(), + new_string: new.to_string(), + replace_all, + } + } + + #[test] + fn successful_single_replacement() { + let (_dir, path, tracker) = setup_tracked_file("[section]\nkey = old_value\n"); + + let call = edit_call(&path, "old_value", "new_value", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(matches!(outcome, ToolOutcome::Success(_))); + assert!(new_bytes.is_some()); + assert_eq!( + std::fs::read_to_string(&path).unwrap(), + "[section]\nkey = new_value\n" + ); + } + + #[test] + fn successful_replace_all() { + let (_dir, path, tracker) = setup_tracked_file("aaa bbb aaa ccc aaa"); + + let call = edit_call(&path, "aaa", "xxx", true); + let (outcome, _) = call.execute(&path, &tracker); + + assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("3 occurrences"))); + assert_eq!( + std::fs::read_to_string(&path).unwrap(), + "xxx bbb xxx ccc xxx" + ); + } + + #[test] + fn error_file_not_read() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("unread.txt"); + std::fs::write(&path, "content").unwrap(); + let tracker = FileReadTracker::default(); // empty — never read + + let call = edit_call(&path, "x", "y", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + 
assert!(msg.contains("not been read yet"), "got: {msg}"); + } + _ => panic!("expected error"), + } + } + + #[test] + fn error_file_modified_since_read() { + let (_dir, path, tracker) = setup_tracked_file("original"); + + // Modify the file after the read was recorded + std::thread::sleep(std::time::Duration::from_millis(10)); + std::fs::write(&path, "modified externally").unwrap(); + + let call = edit_call(&path, "original", "replaced", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("modified since read"), "got: {msg}"); + } + _ => panic!("expected error"), + } + } + + #[test] + fn error_no_match() { + let (_dir, path, tracker) = setup_tracked_file("hello world"); + + let call = edit_call(&path, "nonexistent", "replacement", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("not found"), "got: {msg}"); + } + _ => panic!("expected error"), + } + } + + #[test] + fn error_multiple_matches_without_replace_all() { + let (_dir, path, tracker) = setup_tracked_file("foo bar foo baz foo"); + + let call = edit_call(&path, "foo", "qux", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("3 matches"), "got: {msg}"); + assert!(msg.contains("replace_all"), "got: {msg}"); + } + _ => panic!("expected error"), + } + // File should be unchanged + assert_eq!( + std::fs::read_to_string(&path).unwrap(), + "foo bar foo baz foo" + ); + } + + #[test] + fn error_empty_old_string() { + let (_dir, path, tracker) = setup_tracked_file("content"); + + let call = edit_call(&path, "", "something", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + assert!(matches!(outcome, 
ToolOutcome::Error(_))); + } + + #[test] + fn error_file_does_not_exist() { + let tracker = FileReadTracker::default(); + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("nonexistent.txt"); + + let call = edit_call(&path, "x", "y", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("does not exist"), "got: {msg}"); + } + _ => panic!("expected error"), + } + } + + #[test] + fn preserves_file_when_no_match() { + let original = "[config]\nport = 8080\nhost = localhost\n"; + let (_dir, path, tracker) = setup_tracked_file(original); + + let call = edit_call(&path, "port = 9090", "port = 3000", false); + let (outcome, _) = call.execute(&path, &tracker); + + assert!(matches!(outcome, ToolOutcome::Error(_))); + assert_eq!(std::fs::read_to_string(&path).unwrap(), original); + } + + #[test] + fn multiline_replacement() { + let content = "[section]\nkey1 = val1\nkey2 = val2\n[other]\n"; + let (_dir, path, tracker) = setup_tracked_file(content); + + let call = edit_call( + &path, + "key1 = val1\nkey2 = val2", + "key1 = new1\nkey2 = new2", + false, + ); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(matches!(outcome, ToolOutcome::Success(_))); + assert!(new_bytes.is_some()); + assert_eq!( + std::fs::read_to_string(&path).unwrap(), + "[section]\nkey1 = new1\nkey2 = new2\n[other]\n" + ); + } + } + + // ── Integration tests: full edit lifecycle ── + // + // These exercise the cross-component flow that dispatch orchestrates: + // FileReadTracker → SnapshotStore → EditToolCall.execute → tracker update + + mod edit_integration { + use super::*; + use crate::edit_permissions::EditPermissionCache; + use crate::file_tracker::FileReadTracker; + use crate::snapshots::SnapshotStore; + + /// Simulate a file read (what dispatch does after ReadToolCall.execute). 
+ fn simulate_read(tracker: &mut FileReadTracker, path: &std::path::Path) { + let content = std::fs::read(path).unwrap(); + let mtime = std::fs::metadata(path).unwrap().modified().unwrap(); + tracker.record_read(path.to_path_buf(), &content, mtime); + } + + /// Simulate a tracker update after edit (what dispatch does after execute). + fn simulate_tracker_update( + tracker: &mut FileReadTracker, + path: &std::path::Path, + new_bytes: &[u8], + ) { + let mtime = std::fs::metadata(path).unwrap().modified().unwrap(); + tracker.update_after_edit(path, new_bytes, mtime); + } + + #[test] + fn full_read_snapshot_edit_cycle() { + let dir = tempfile::tempdir().unwrap(); + let file_path = dir.path().join("config.toml"); + std::fs::write(&file_path, "[db]\nhost = localhost\nport = 5432\n").unwrap(); + + let snapshot_dir = dir.path().join("snapshots").join("session-1"); + let mut tracker = FileReadTracker::default(); + let mut store = SnapshotStore::open(snapshot_dir.clone()).unwrap(); + + // 1. Simulate reading the file + simulate_read(&mut tracker, &file_path); + + // 2. Snapshot before edit + let original = std::fs::read(&file_path).unwrap(); + store.ensure_snapshot(&file_path, &original).unwrap(); + + // 3. Execute edit + let call = EditToolCall { + path: file_path.clone(), + old_string: "host = localhost".to_string(), + new_string: "host = 10.0.0.1".to_string(), + replace_all: false, + }; + let (outcome, new_bytes) = call.execute(&file_path, &tracker); + assert!(matches!(outcome, ToolOutcome::Success(_))); + let new_bytes = new_bytes.unwrap(); + + // 4. 
Update tracker (simulating what dispatch does) + simulate_tracker_update(&mut tracker, &file_path, &new_bytes); + + // Verify: file was edited + assert_eq!( + std::fs::read_to_string(&file_path).unwrap(), + "[db]\nhost = 10.0.0.1\nport = 5432\n" + ); + + // Verify: snapshot has original content + assert!(store.has_snapshot(&file_path)); + let snapshot_name = crate::snapshots::sanitize_path(&file_path); + let snapshot_content = + std::fs::read_to_string(snapshot_dir.join(snapshot_name)).unwrap(); + assert_eq!(snapshot_content, "[db]\nhost = localhost\nport = 5432\n"); + } + + #[test] + fn second_edit_without_reread() { + let dir = tempfile::tempdir().unwrap(); + let file_path = dir.path().join("config.toml"); + std::fs::write(&file_path, "key1 = aaa\nkey2 = bbb\n").unwrap(); + + let mut tracker = FileReadTracker::default(); + + // Read the file + simulate_read(&mut tracker, &file_path); + + // First edit + let call1 = EditToolCall { + path: file_path.clone(), + old_string: "key1 = aaa".to_string(), + new_string: "key1 = xxx".to_string(), + replace_all: false, + }; + let (outcome, new_bytes) = call1.execute(&file_path, &tracker); + assert!(matches!(outcome, ToolOutcome::Success(_))); + simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); + + // Second edit — should work without re-reading because tracker was updated + let call2 = EditToolCall { + path: file_path.clone(), + old_string: "key2 = bbb".to_string(), + new_string: "key2 = yyy".to_string(), + replace_all: false, + }; + let (outcome, new_bytes) = call2.execute(&file_path, &tracker); + assert!(matches!(outcome, ToolOutcome::Success(_))); + assert!(new_bytes.is_some()); + assert_eq!( + std::fs::read_to_string(&file_path).unwrap(), + "key1 = xxx\nkey2 = yyy\n" + ); + } + + #[test] + fn external_modification_between_edits() { + let dir = tempfile::tempdir().unwrap(); + let file_path = dir.path().join("config.toml"); + std::fs::write(&file_path, "value = original\n").unwrap(); + + let mut 
tracker = FileReadTracker::default(); + simulate_read(&mut tracker, &file_path); + + // First edit succeeds + let call1 = EditToolCall { + path: file_path.clone(), + old_string: "value = original".to_string(), + new_string: "value = edited".to_string(), + replace_all: false, + }; + let (outcome, new_bytes) = call1.execute(&file_path, &tracker); + assert!(matches!(outcome, ToolOutcome::Success(_))); + simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); + + // External modification (e.g., user edits the file) + std::thread::sleep(std::time::Duration::from_millis(10)); + std::fs::write(&file_path, "value = user_changed\n").unwrap(); + + // Second edit should fail (stale) + let call2 = EditToolCall { + path: file_path.clone(), + old_string: "value = edited".to_string(), + new_string: "value = second_edit".to_string(), + replace_all: false, + }; + let (outcome, new_bytes) = call2.execute(&file_path, &tracker); + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => assert!(msg.contains("modified since read")), + _ => panic!("expected stale error"), + } + + // File should be unchanged (the user's edit preserved) + assert_eq!( + std::fs::read_to_string(&file_path).unwrap(), + "value = user_changed\n" + ); + } + + #[test] + fn snapshot_only_created_once_per_file() { + let dir = tempfile::tempdir().unwrap(); + let file_path = dir.path().join("config.toml"); + std::fs::write(&file_path, "a = 1\nb = 2\n").unwrap(); + + let snapshot_dir = dir.path().join("snapshots").join("session-1"); + let mut tracker = FileReadTracker::default(); + let mut store = SnapshotStore::open(snapshot_dir).unwrap(); + + simulate_read(&mut tracker, &file_path); + + // First edit — snapshot should be created + let original = std::fs::read(&file_path).unwrap(); + let created = store.ensure_snapshot(&file_path, &original).unwrap(); + assert!(created); + + let call1 = EditToolCall { + path: file_path.clone(), + old_string: "a = 1".to_string(), + new_string: "a = 
10".to_string(), + replace_all: false, + }; + let (_, new_bytes) = call1.execute(&file_path, &tracker); + simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); + + // Second edit — snapshot should NOT be recreated + let content_before_second = std::fs::read(&file_path).unwrap(); + let created = store + .ensure_snapshot(&file_path, &content_before_second) + .unwrap(); + assert!(!created); // idempotent — already snapshotted + } + + #[test] + fn permission_cache_grant_and_check() { + let mut cache = EditPermissionCache::default(); + let path = std::path::PathBuf::from("/Users/me/.config/atuin/config.toml"); + + // Initially no grant + assert!(!cache.has_valid_grant(&path)); + + // Grant permission + cache.grant(path.clone()); + assert!(cache.has_valid_grant(&path)); + + // Different file has no grant + assert!(!cache.has_valid_grant(std::path::Path::new("/other/file.toml"))); + + // Roundtrip through JSON (simulates session persistence) + let json = cache.to_json().unwrap(); + let restored = EditPermissionCache::from_json(&json).unwrap(); + assert!(restored.has_valid_grant(&path)); + } + } + // ── Windows-specific tests (absolute paths with drive letters) ── #[cfg(windows)] |
