aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src/tools/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-ai/src/tools/mod.rs')
-rw-r--r--crates/atuin-ai/src/tools/mod.rs737
1 files changed, 711 insertions, 26 deletions
diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs
index 8f2183b7..8fe1ad73 100644
--- a/crates/atuin-ai/src/tools/mod.rs
+++ b/crates/atuin-ai/src/tools/mod.rs
@@ -169,6 +169,8 @@ pub(crate) struct TrackedTool {
pub phase: ToolPhase,
/// Sender to interrupt a running shell command (only set during ExecutingWithPreview).
pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>,
+ /// Diff preview for completed edit tool calls.
+ pub edit_preview: Option<crate::diff::EditPreview>,
}
impl TrackedTool {
@@ -234,6 +236,7 @@ impl ToolTracker {
tool,
phase: ToolPhase::CheckingPermissions,
abort_tx: None,
+ edit_preview: None,
});
}
@@ -294,11 +297,6 @@ impl ToolTracker {
.find(|t| t.phase == ToolPhase::AskingForPermission)
}
- /// Get the preview for a tool by ID (live or cached).
- pub fn preview_for(&self, id: &str) -> Option<ToolPreview> {
- self.get(id)?.preview()
- }
-
/// Iterate mutably over all tracked tools.
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut TrackedTool> {
self.tools.iter_mut()
@@ -309,6 +307,7 @@ impl ToolTracker {
#[derive(Debug, Clone)]
pub(crate) enum ClientToolCall {
Read(ReadToolCall),
+ Edit(EditToolCall),
Write(WriteToolCall),
Shell(ShellToolCall),
AtuinHistory(AtuinHistoryToolCall),
@@ -320,9 +319,8 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall {
fn try_from((name, input): (&str, &serde_json::Value)) -> Result<Self, Self::Error> {
match name {
"read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)),
- "create_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)),
- // "append_to_file" => Ok(ClientToolCall::Append(AppendToolCall::try_from(input)?)),
- // "str_replace" => Ok(ClientToolCall::StrReplace(StrReplaceToolCall::try_from(input)?)),
+ "edit_file" => Ok(ClientToolCall::Edit(EditToolCall::try_from(input)?)),
+ "write_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)),
"execute_shell_command" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)),
"atuin_history" => Ok(ClientToolCall::AtuinHistory(
AtuinHistoryToolCall::try_from(input)?,
@@ -336,17 +334,22 @@ impl ClientToolCall {
pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor {
match self {
ClientToolCall::Read(_) => descriptor::READ,
+ ClientToolCall::Edit(_) => descriptor::EDIT,
ClientToolCall::Write(_) => descriptor::WRITE,
ClientToolCall::Shell(_) => descriptor::SHELL,
ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY,
}
}
- /// The permission rule name for this tool category (e.g. "Write" covers
- /// str_replace, file_create, file_insert).
+ /// The permission rule name for this tool category.
+ ///
+ /// Edit and Write share the `"Write"` rule name — a Write permission
+ /// covers both str_replace edits and full file creates. Write also
+ /// implies Read (checked in `ReadToolCall::matches_rule`).
pub(crate) fn rule_name(&self) -> &'static str {
match self {
ClientToolCall::Read(_) => "Read",
+ ClientToolCall::Edit(_) => "Write",
ClientToolCall::Write(_) => "Write",
ClientToolCall::Shell(_) => "Shell",
ClientToolCall::AtuinHistory(_) => "AtuinHistory",
@@ -356,6 +359,7 @@ impl ClientToolCall {
pub(crate) fn matches_rule(&self, rule: &Rule) -> bool {
match self {
ClientToolCall::Read(tool) => tool.matches_rule(rule),
+ ClientToolCall::Edit(tool) => tool.matches_rule(rule),
ClientToolCall::Write(tool) => tool.matches_rule(rule),
ClientToolCall::Shell(tool) => tool.matches_rule(rule),
ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule),
@@ -365,6 +369,7 @@ impl ClientToolCall {
pub(crate) fn target_dir(&self) -> Option<&Path> {
match self {
ClientToolCall::Read(tool) => tool.target_dir(),
+ ClientToolCall::Edit(tool) => tool.target_dir(),
ClientToolCall::Write(tool) => tool.target_dir(),
ClientToolCall::Shell(tool) => tool.target_dir(),
ClientToolCall::AtuinHistory(tool) => tool.target_dir(),
@@ -401,6 +406,14 @@ impl PermissableToolCall for ClientToolCall {
}
}
+/// Expand shell constructs (`~`, `$HOME`, etc.) in a path string.
+///
+/// Tool call paths arrive as raw strings from the API without shell
+/// expansion. Uses `shellexpand` (same as `atuin-client`).
+fn expand_path(path: &str) -> PathBuf {
+ PathBuf::from(shellexpand::tilde(path).into_owned())
+}
+
#[derive(Debug, Clone)]
pub(crate) struct ReadToolCall {
pub path: PathBuf,
@@ -425,7 +438,7 @@ impl TryFrom<&serde_json::Value> for ReadToolCall {
.min(MAX_FILE_READ_LINES);
Ok(ReadToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
offset,
limit,
})
@@ -499,7 +512,207 @@ impl PermissableToolCall for ReadToolCall {
}
fn matches_rule(&self, rule: &Rule) -> bool {
- if rule.tool != "Read" {
+ // Write implies Read — a Write permission on a path also permits reading it.
+ if rule.tool != "Read" && rule.tool != "Write" {
+ return false;
+ }
+
+ match rule.scope.as_deref() {
+ None | Some("*") => true,
+ Some(scope) => path_matches_scope(&self.path, scope),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct EditToolCall {
+ pub path: PathBuf,
+ pub old_string: String,
+ pub new_string: String,
+ pub replace_all: bool,
+}
+
+impl TryFrom<&serde_json::Value> for EditToolCall {
+ type Error = eyre::Error;
+
+ fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> {
+ let path = value
+ .get("file_path")
+ .and_then(|v| v.as_str())
+ .ok_or(eyre::eyre!("Missing file_path"))?;
+
+ let old_string = value
+ .get("old_string")
+ .and_then(|v| v.as_str())
+ .ok_or(eyre::eyre!("Missing old_string"))?;
+
+ let new_string = value
+ .get("new_string")
+ .and_then(|v| v.as_str())
+ .ok_or(eyre::eyre!("Missing new_string"))?;
+
+ let replace_all = value
+ .get("replace_all")
+ .and_then(|v| v.as_bool())
+ .unwrap_or(false);
+
+ Ok(EditToolCall {
+ path: expand_path(path),
+ old_string: old_string.to_string(),
+ new_string: new_string.to_string(),
+ replace_all,
+ })
+ }
+}
+
+impl EditToolCall {
+ /// Resolve the edit path to an absolute path.
+ pub fn resolved_path(&self) -> PathBuf {
+ if self.path.is_relative() {
+ std::env::current_dir()
+ .map(|cwd| cwd.join(&self.path))
+ .unwrap_or_else(|_| self.path.clone())
+ } else {
+ self.path.clone()
+ }
+ }
+
+ /// Execute the edit against the filesystem.
+ ///
+ /// Checks freshness via the provided tracker, validates matches, applies
+ /// the replacement, and writes atomically. Returns the outcome and (on
+ /// success) the new file content bytes for tracker updates.
+ ///
+ /// Callers should snapshot the file before calling this method and
+ /// update the file tracker after a successful return.
+ pub fn execute(
+ &self,
+ resolved_path: &Path,
+ file_tracker: &crate::file_tracker::FileReadTracker,
+ ) -> (ToolOutcome, Option<Vec<u8>>) {
+ use crate::file_tracker::FreshnessCheck;
+
+ // 1. Basic validation
+ if !resolved_path.exists() {
+ return (
+ ToolOutcome::Error(format!(
+ "Error: file does not exist: {}",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+ if resolved_path.is_dir() {
+ return (
+ ToolOutcome::Error(format!(
+ "Error: path is a directory, not a file: {}",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+ if self.old_string.is_empty() {
+ return (
+ ToolOutcome::Error(
+ "old_string must not be empty. To create a new file, use write_file instead."
+ .to_string(),
+ ),
+ None,
+ );
+ }
+
+ // 2. Freshness check
+ match file_tracker.check_freshness(resolved_path) {
+ Ok(FreshnessCheck::NotRead) => {
+ return (
+ ToolOutcome::Error(
+ "File has not been read yet. Read it first before editing.".to_string(),
+ ),
+ None,
+ );
+ }
+ Ok(FreshnessCheck::Stale) => {
+ return (
+ ToolOutcome::Error(
+ "File has been modified since read, either by the user or by a linter. Read it again before attempting to edit it.".to_string(),
+ ),
+ None,
+ );
+ }
+ Err(e) => {
+ return (
+ ToolOutcome::Error(format!("Error checking file state: {e}")),
+ None,
+ );
+ }
+ Ok(FreshnessCheck::Fresh) => {}
+ }
+
+ // 3. Read current contents
+ let content = match std::fs::read_to_string(resolved_path) {
+ Ok(c) => c,
+ Err(e) => return (ToolOutcome::Error(format!("Error reading file: {e}")), None),
+ };
+
+ // 4. Find and validate matches
+ let match_count = content.matches(&self.old_string).count();
+
+ if match_count == 0 {
+ return (
+ ToolOutcome::Error(format!(
+ "old_string not found in {}. Make sure it matches exactly, including whitespace and indentation.",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+
+ if match_count > 1 && !self.replace_all {
+ return (
+ ToolOutcome::Error(format!(
+ "Found {match_count} matches of old_string in {}, but replace_all is false. Either provide more context to make the match unique, or set replace_all to true.",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+
+ // 5. Apply replacement
+ let new_content = if self.replace_all {
+ content.replace(&self.old_string, &self.new_string)
+ } else {
+ content.replacen(&self.old_string, &self.new_string, 1)
+ };
+
+ // 6. Write atomically
+ let new_bytes = new_content.into_bytes();
+ if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &new_bytes) {
+ return (ToolOutcome::Error(format!("Error writing file: {e}")), None);
+ }
+
+ // 7. Success
+ let verb = if match_count == 1 {
+ "occurrence"
+ } else {
+ "occurrences"
+ };
+ (
+ ToolOutcome::Success(format!(
+ "Edited {}: replaced {match_count} {verb} of old_string with new_string.",
+ resolved_path.display()
+ )),
+ Some(new_bytes),
+ )
+ }
+}
+
+impl PermissableToolCall for EditToolCall {
+ fn target_dir(&self) -> Option<&Path> {
+ Some(&self.path)
+ }
+
+ fn matches_rule(&self, rule: &Rule) -> bool {
+ if rule.tool != "Write" {
return false;
}
@@ -532,7 +745,7 @@ impl TryFrom<&serde_json::Value> for WriteToolCall {
.ok_or(eyre::eyre!("Missing content"))?;
Ok(WriteToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
content: content.to_string(),
})
}
@@ -560,6 +773,9 @@ pub(crate) struct ShellToolCall {
pub dir: Option<PathBuf>,
pub command: String,
pub shell: String,
+ // allow dead code here; this will be tied into o11y and user-facing descriptions
+ #[expect(dead_code)]
+ pub description: Option<String>,
}
impl TryFrom<&serde_json::Value> for ShellToolCall {
@@ -579,10 +795,16 @@ impl TryFrom<&serde_json::Value> for ShellToolCall {
.unwrap_or("bash")
.to_string();
+ let description = value
+ .get("description")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string());
+
Ok(ShellToolCall {
- dir: dir.map(PathBuf::from),
+ dir: dir.map(expand_path),
command: command.to_string(),
shell,
+ description,
})
}
}
@@ -614,7 +836,34 @@ const PREVIEW_HEIGHT: u16 = 10;
/// Default terminal width for VT100 emulation.
const PREVIEW_WIDTH: u16 = 120;
+/// Normalize newlines for VT100 processing.
+///
+/// When subprocess output is captured via pipes (no PTY), bare `\n` (LF) bytes
+/// are not translated to `\r\n` (CR+LF) the way a kernel terminal driver would
+/// with the `ONLCR` flag. In VT100, LF only moves the cursor down without
+/// returning to column 0. This causes lines to start at progressively higher
+/// column offsets and eventually wrap, producing garbled output.
+///
+/// This function inserts `\r` before any `\n` that isn't already preceded by
+/// `\r`, mimicking the terminal driver's ONLCR behavior.
+fn normalize_newlines_for_vt100(data: &[u8]) -> Vec<u8> {
+ let mut out = Vec::with_capacity(data.len() + data.len() / 8);
+ for (i, &b) in data.iter().enumerate() {
+ if b == b'\n' && (i == 0 || data[i - 1] != b'\r') {
+ out.push(b'\r');
+ }
+ out.push(b);
+ }
+ out
+}
+
/// Extract plain text lines from a VT100 screen buffer.
+///
+/// Strips trailing blank lines so the result only contains rows with actual
+/// content. Without this, the fixed-size VT100 screen (PREVIEW_HEIGHT rows)
+/// would always return that many lines, and downstream components that use
+/// tail-mode display (like the Viewport) would show the blank padding rows
+/// instead of the real output.
fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> {
let (rows, cols) = screen.size();
let mut lines = Vec::with_capacity(rows as usize);
@@ -625,9 +874,11 @@ fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> {
line.push_str(cell.contents());
}
}
- // Trim trailing whitespace for cleaner display
lines.push(line.trim_end().to_string());
}
+ while lines.last().is_some_and(|l| l.is_empty()) {
+ lines.pop();
+ }
lines
}
@@ -640,12 +891,17 @@ fn strip_ansi_via_vt100(raw: &[u8]) -> String {
if raw.is_empty() {
return String::new();
}
- // Use the contents_formatted → screen approach: feed bytes into a parser
- // with enough rows to hold everything, then read back the plain text.
- // Estimate rows: one row per ~PREVIEW_WIDTH bytes, plus generous padding.
- let estimated_rows = (raw.len() / PREVIEW_WIDTH as usize + 1).min(10_000) as u16;
+ // Normalize bare LF to CR+LF so lines start at column 0 in the VT100 screen.
+ let normalized = normalize_newlines_for_vt100(raw);
+ // Feed bytes into a VT100 parser large enough to hold all output, then
+ // read back the plain text. We estimate rows from the number of newlines
+ // (not total byte length) because real output typically has short lines
+ // that would be severely under-counted by a bytes÷width estimate.
+ let newline_count = normalized.iter().filter(|&&b| b == b'\n').count();
+ let wrap_estimate = normalized.len() / PREVIEW_WIDTH as usize;
+ let estimated_rows = (newline_count + wrap_estimate + 1).min(10_000) as u16;
let mut parser = vt100::Parser::new(estimated_rows, PREVIEW_WIDTH, 0);
- parser.process(raw);
+ parser.process(&normalized);
let screen = parser.screen();
// screen.contents() returns the full plain-text content with trailing
// whitespace trimmed per line and trailing blank lines removed.
@@ -727,7 +983,8 @@ pub(crate) async fn execute_shell_command_streaming(
Ok(0) => stdout_done = true,
Ok(n) => {
full_stdout.extend_from_slice(&stdout_buf[..n]);
- parser.process(&stdout_buf[..n]);
+ let normalized = normalize_newlines_for_vt100(&stdout_buf[..n]);
+ parser.process(&normalized);
}
Err(_) => stdout_done = true,
}
@@ -740,7 +997,8 @@ pub(crate) async fn execute_shell_command_streaming(
Ok(n) => {
full_stderr.extend_from_slice(&stderr_buf[..n]);
// Feed stderr to the preview parser too, so it shows in the VT100 screen
- parser.process(&stderr_buf[..n]);
+ let normalized = normalize_newlines_for_vt100(&stderr_buf[..n]);
+ parser.process(&normalized);
}
Err(_) => stderr_done = true,
}
@@ -967,7 +1225,7 @@ mod tests {
fn read_tool(path: &str) -> ReadToolCall {
ReadToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
offset: 0,
limit: 100,
}
@@ -975,7 +1233,7 @@ mod tests {
fn write_tool(path: &str) -> WriteToolCall {
WriteToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
content: String::new(),
}
}
@@ -994,12 +1252,26 @@ mod tests {
}
#[test]
- fn wrong_tool_never_matches() {
- assert!(!read_tool("foo.txt").matches_rule(&write_rule(None)));
+ fn write_implies_read() {
+ // A Write rule also permits reads on the same path
+ assert!(read_tool("foo.txt").matches_rule(&write_rule(None)));
+ // But a Read rule does not permit writes
assert!(!write_tool("foo.txt").matches_rule(&read_rule(None)));
}
#[test]
+ fn edit_uses_write_rule() {
+ let edit = EditToolCall {
+ path: expand_path("/home/user/config.toml"),
+ old_string: "x".into(),
+ new_string: "y".into(),
+ replace_all: false,
+ };
+ assert!(edit.matches_rule(&write_rule(None)));
+ assert!(!edit.matches_rule(&read_rule(None)));
+ }
+
+ #[test]
fn extension_glob() {
assert!(read_tool("notes.md").matches_rule(&read_rule(Some("*.md"))));
assert!(!read_tool("notes.txt").matches_rule(&read_rule(Some("*.md"))));
@@ -1050,6 +1322,419 @@ mod tests {
}
}
+ // ── edit_file execution tests ──
+
+ mod edit {
+ use super::*;
+ use crate::file_tracker::FileReadTracker;
+
+ /// Helper: create a temp file (with a closed handle), record it in a tracker.
+ /// Returns the TempDir (keeps the path alive) and tracker.
+ /// The file handle is closed so atomic_write_file can rename over it on Windows.
+ fn setup_tracked_file(content: &str) -> (tempfile::TempDir, PathBuf, FileReadTracker) {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("test_file.toml");
+ std::fs::write(&path, content).unwrap();
+
+ let file_content = std::fs::read(&path).unwrap();
+ let mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+
+ let mut tracker = FileReadTracker::default();
+ tracker.record_read(path.clone(), &file_content, mtime);
+
+ (dir, path, tracker)
+ }
+
+ fn edit_call(path: &Path, old: &str, new: &str, replace_all: bool) -> EditToolCall {
+ EditToolCall {
+ path: path.to_path_buf(),
+ old_string: old.to_string(),
+ new_string: new.to_string(),
+ replace_all,
+ }
+ }
+
+ #[test]
+ fn successful_single_replacement() {
+ let (_dir, path, tracker) = setup_tracked_file("[section]\nkey = old_value\n");
+
+ let call = edit_call(&path, "old_value", "new_value", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "[section]\nkey = new_value\n"
+ );
+ }
+
+ #[test]
+ fn successful_replace_all() {
+ let (_dir, path, tracker) = setup_tracked_file("aaa bbb aaa ccc aaa");
+
+ let call = edit_call(&path, "aaa", "xxx", true);
+ let (outcome, _) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("3 occurrences")));
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "xxx bbb xxx ccc xxx"
+ );
+ }
+
+ #[test]
+ fn error_file_not_read() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("unread.txt");
+ std::fs::write(&path, "content").unwrap();
+ let tracker = FileReadTracker::default(); // empty — never read
+
+ let call = edit_call(&path, "x", "y", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("not been read yet"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn error_file_modified_since_read() {
+ let (_dir, path, tracker) = setup_tracked_file("original");
+
+ // Modify the file after the read was recorded
+ std::thread::sleep(std::time::Duration::from_millis(10));
+ std::fs::write(&path, "modified externally").unwrap();
+
+ let call = edit_call(&path, "original", "replaced", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("modified since read"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn error_no_match() {
+ let (_dir, path, tracker) = setup_tracked_file("hello world");
+
+ let call = edit_call(&path, "nonexistent", "replacement", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("not found"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn error_multiple_matches_without_replace_all() {
+ let (_dir, path, tracker) = setup_tracked_file("foo bar foo baz foo");
+
+ let call = edit_call(&path, "foo", "qux", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("3 matches"), "got: {msg}");
+ assert!(msg.contains("replace_all"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ // File should be unchanged
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "foo bar foo baz foo"
+ );
+ }
+
+ #[test]
+ fn error_empty_old_string() {
+ let (_dir, path, tracker) = setup_tracked_file("content");
+
+ let call = edit_call(&path, "", "something", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ assert!(matches!(outcome, ToolOutcome::Error(_)));
+ }
+
+ #[test]
+ fn error_file_does_not_exist() {
+ let tracker = FileReadTracker::default();
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("nonexistent.txt");
+
+ let call = edit_call(&path, "x", "y", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("does not exist"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn preserves_file_when_no_match() {
+ let original = "[config]\nport = 8080\nhost = localhost\n";
+ let (_dir, path, tracker) = setup_tracked_file(original);
+
+ let call = edit_call(&path, "port = 9090", "port = 3000", false);
+ let (outcome, _) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Error(_)));
+ assert_eq!(std::fs::read_to_string(&path).unwrap(), original);
+ }
+
+ #[test]
+ fn multiline_replacement() {
+ let content = "[section]\nkey1 = val1\nkey2 = val2\n[other]\n";
+ let (_dir, path, tracker) = setup_tracked_file(content);
+
+ let call = edit_call(
+ &path,
+ "key1 = val1\nkey2 = val2",
+ "key1 = new1\nkey2 = new2",
+ false,
+ );
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "[section]\nkey1 = new1\nkey2 = new2\n[other]\n"
+ );
+ }
+ }
+
+ // ── Integration tests: full edit lifecycle ──
+ //
+ // These exercise the cross-component flow that dispatch orchestrates:
+ // FileReadTracker → SnapshotStore → EditToolCall.execute → tracker update
+
+ mod edit_integration {
+ use super::*;
+ use crate::edit_permissions::EditPermissionCache;
+ use crate::file_tracker::FileReadTracker;
+ use crate::snapshots::SnapshotStore;
+
+ /// Simulate a file read (what dispatch does after ReadToolCall.execute).
+ fn simulate_read(tracker: &mut FileReadTracker, path: &std::path::Path) {
+ let content = std::fs::read(path).unwrap();
+ let mtime = std::fs::metadata(path).unwrap().modified().unwrap();
+ tracker.record_read(path.to_path_buf(), &content, mtime);
+ }
+
+ /// Simulate a tracker update after edit (what dispatch does after execute).
+ fn simulate_tracker_update(
+ tracker: &mut FileReadTracker,
+ path: &std::path::Path,
+ new_bytes: &[u8],
+ ) {
+ let mtime = std::fs::metadata(path).unwrap().modified().unwrap();
+ tracker.update_after_edit(path, new_bytes, mtime);
+ }
+
+ #[test]
+ fn full_read_snapshot_edit_cycle() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "[db]\nhost = localhost\nport = 5432\n").unwrap();
+
+ let snapshot_dir = dir.path().join("snapshots").join("session-1");
+ let mut tracker = FileReadTracker::default();
+ let mut store = SnapshotStore::open(snapshot_dir.clone()).unwrap();
+
+ // 1. Simulate reading the file
+ simulate_read(&mut tracker, &file_path);
+
+ // 2. Snapshot before edit
+ let original = std::fs::read(&file_path).unwrap();
+ store.ensure_snapshot(&file_path, &original).unwrap();
+
+ // 3. Execute edit
+ let call = EditToolCall {
+ path: file_path.clone(),
+ old_string: "host = localhost".to_string(),
+ new_string: "host = 10.0.0.1".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ let new_bytes = new_bytes.unwrap();
+
+ // 4. Update tracker (simulating what dispatch does)
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes);
+
+ // Verify: file was edited
+ assert_eq!(
+ std::fs::read_to_string(&file_path).unwrap(),
+ "[db]\nhost = 10.0.0.1\nport = 5432\n"
+ );
+
+ // Verify: snapshot has original content
+ assert!(store.has_snapshot(&file_path));
+ let snapshot_name = crate::snapshots::sanitize_path(&file_path);
+ let snapshot_content =
+ std::fs::read_to_string(snapshot_dir.join(snapshot_name)).unwrap();
+ assert_eq!(snapshot_content, "[db]\nhost = localhost\nport = 5432\n");
+ }
+
+ #[test]
+ fn second_edit_without_reread() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "key1 = aaa\nkey2 = bbb\n").unwrap();
+
+ let mut tracker = FileReadTracker::default();
+
+ // Read the file
+ simulate_read(&mut tracker, &file_path);
+
+ // First edit
+ let call1 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "key1 = aaa".to_string(),
+ new_string: "key1 = xxx".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call1.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap());
+
+ // Second edit — should work without re-reading because tracker was updated
+ let call2 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "key2 = bbb".to_string(),
+ new_string: "key2 = yyy".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call2.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&file_path).unwrap(),
+ "key1 = xxx\nkey2 = yyy\n"
+ );
+ }
+
+ #[test]
+ fn external_modification_between_edits() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "value = original\n").unwrap();
+
+ let mut tracker = FileReadTracker::default();
+ simulate_read(&mut tracker, &file_path);
+
+ // First edit succeeds
+ let call1 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "value = original".to_string(),
+ new_string: "value = edited".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call1.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap());
+
+ // External modification (e.g., user edits the file)
+ std::thread::sleep(std::time::Duration::from_millis(10));
+ std::fs::write(&file_path, "value = user_changed\n").unwrap();
+
+ // Second edit should fail (stale)
+ let call2 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "value = edited".to_string(),
+ new_string: "value = second_edit".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call2.execute(&file_path, &tracker);
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => assert!(msg.contains("modified since read")),
+ _ => panic!("expected stale error"),
+ }
+
+ // File should be unchanged (the user's edit preserved)
+ assert_eq!(
+ std::fs::read_to_string(&file_path).unwrap(),
+ "value = user_changed\n"
+ );
+ }
+
+ #[test]
+ fn snapshot_only_created_once_per_file() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "a = 1\nb = 2\n").unwrap();
+
+ let snapshot_dir = dir.path().join("snapshots").join("session-1");
+ let mut tracker = FileReadTracker::default();
+ let mut store = SnapshotStore::open(snapshot_dir).unwrap();
+
+ simulate_read(&mut tracker, &file_path);
+
+ // First edit — snapshot should be created
+ let original = std::fs::read(&file_path).unwrap();
+ let created = store.ensure_snapshot(&file_path, &original).unwrap();
+ assert!(created);
+
+ let call1 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "a = 1".to_string(),
+ new_string: "a = 10".to_string(),
+ replace_all: false,
+ };
+ let (_, new_bytes) = call1.execute(&file_path, &tracker);
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap());
+
+ // Second edit — snapshot should NOT be recreated
+ let content_before_second = std::fs::read(&file_path).unwrap();
+ let created = store
+ .ensure_snapshot(&file_path, &content_before_second)
+ .unwrap();
+ assert!(!created); // idempotent — already snapshotted
+ }
+
+ #[test]
+ fn permission_cache_grant_and_check() {
+ let mut cache = EditPermissionCache::default();
+ let path = std::path::PathBuf::from("/Users/me/.config/atuin/config.toml");
+
+ // Initially no grant
+ assert!(!cache.has_valid_grant(&path));
+
+ // Grant permission
+ cache.grant(path.clone());
+ assert!(cache.has_valid_grant(&path));
+
+ // Different file has no grant
+ assert!(!cache.has_valid_grant(std::path::Path::new("/other/file.toml")));
+
+ // Roundtrip through JSON (simulates session persistence)
+ let json = cache.to_json().unwrap();
+ let restored = EditPermissionCache::from_json(&json).unwrap();
+ assert!(restored.has_valid_grant(&path));
+ }
+ }
+
// ── Windows-specific tests (absolute paths with drive letters) ──
#[cfg(windows)]