diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-04-21 10:53:31 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-04-21 10:53:31 -0700 |
| commit | 33c779aa9894e1347aeaa4c73e536cf842aee684 (patch) | |
| tree | bfe0f60252518798ccf4621c7bea06021e64d2f4 /crates/atuin-ai/src/tools | |
| parent | feat: AI tool rendering overhaul + edit_file tool (#3423) (diff) | |
| download | atuin-33c779aa9894e1347aeaa4c73e536cf842aee684.zip | |
feat: Implement write_file tool with overwrite safety (#3432)
## Summary
Implements the `write_file` client-side tool — creates new files or
overwrites existing ones with an explicit `overwrite` flag for safety.
- **Overwrite flag**: Writing to an existing file without `overwrite:
true` returns an error directing the LLM to set the flag or use
`edit_file` for targeted changes. Prevents accidental overwrites.
- **Snapshots**: Existing files are backed up before overwriting (same
infrastructure as `edit_file`).
- **Content preview**: Completed writes show the first 10 lines in gray
with line numbers, plus "+ N more lines" for longer files.
- **Atomic writes**: Uses `tempfile` + fsync + rename (same as
`edit_file`).
- **File tracker update**: After writing, the file is registered in the
tracker so subsequent `edit_file` calls work without a separate read.
- **Permission**: Shares the `"Write"` rule with `edit_file` — one
permission covers both tools.
Diffstat (limited to 'crates/atuin-ai/src/tools')
| -rw-r--r-- | crates/atuin-ai/src/tools/mod.rs | 174 |
1 files changed, 171 insertions, 3 deletions
diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 8fe1ad73..890ea734 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -171,6 +171,8 @@ pub(crate) struct TrackedTool { pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>, /// Diff preview for completed edit tool calls. pub edit_preview: Option<crate::diff::EditPreview>, + /// Content preview for completed write tool calls. + pub write_preview: Option<crate::diff::WritePreview>, } impl TrackedTool { @@ -237,6 +239,7 @@ impl ToolTracker { phase: ToolPhase::CheckingPermissions, abort_tx: None, edit_preview: None, + write_preview: None, }); } @@ -724,10 +727,10 @@ impl PermissableToolCall for EditToolCall { } #[derive(Debug, Clone)] -#[expect(dead_code)] pub(crate) struct WriteToolCall { pub path: PathBuf, pub content: String, + pub overwrite: bool, } impl TryFrom<&serde_json::Value> for WriteToolCall { @@ -735,22 +738,85 @@ impl TryFrom<&serde_json::Value> for WriteToolCall { fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { let path = value - .get("path") + .get("file_path") .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing path"))?; + .ok_or(eyre::eyre!("Missing file_path"))?; let content = value .get("content") .and_then(|v| v.as_str()) .ok_or(eyre::eyre!("Missing content"))?; + let overwrite = value + .get("overwrite") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + Ok(WriteToolCall { path: expand_path(path), content: content.to_string(), + overwrite, }) } } +impl WriteToolCall { + /// Resolve the write path to an absolute path. + pub fn resolved_path(&self) -> PathBuf { + if self.path.is_relative() { + std::env::current_dir() + .map(|cwd| cwd.join(&self.path)) + .unwrap_or_else(|_| self.path.clone()) + } else { + self.path.clone() + } + } + + /// Execute the write operation. + /// + /// Creates a new file or overwrites an existing one (if `overwrite` is set). + /// Returns the outcome and the written bytes (for tracker updates). + pub fn execute(&self, resolved_path: &Path) -> (ToolOutcome, Option<Vec<u8>>) { + if resolved_path.is_dir() { + return ( + ToolOutcome::Error(format!( + "Error: path is a directory, not a file: {}", + resolved_path.display() + )), + None, + ); + } + if resolved_path.exists() && !self.overwrite { + return ( + ToolOutcome::Error(format!( + "File already exists: {}. Set overwrite to true to replace it, or use edit_file to make targeted changes.", + resolved_path.display() + )), + None, + ); + } + + // Capture before the write — after atomic_write the file always exists. + let existed = resolved_path.exists(); + + // Write atomically + let content_bytes = self.content.as_bytes().to_vec(); + if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &content_bytes) { + return (ToolOutcome::Error(format!("Error writing file: {e}")), None); + } + + let line_count = self.content.lines().count(); + let verb = if existed { "Overwrote" } else { "Created" }; + ( + ToolOutcome::Success(format!( + "{verb} {} ({line_count} lines).", + resolved_path.display() + )), + Some(content_bytes), + ) + } +} + impl PermissableToolCall for WriteToolCall { fn target_dir(&self) -> Option<&Path> { Some(&self.path) @@ -1235,6 +1301,7 @@ mod tests { WriteToolCall { path: expand_path(path), content: String::new(), + overwrite: false, } } @@ -1735,6 +1802,107 @@ mod tests { } } + // ── write_file execution tests ── + + mod write { + use super::*; + + #[test] + fn creates_new_file() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("new_file.txt"); + + let call = WriteToolCall { + path: path.clone(), + content: "hello\nworld\n".to_string(), + overwrite: false, + }; + let (outcome, new_bytes) = call.execute(&path); + + assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("Created"))); + assert!(new_bytes.is_some()); + assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello\nworld\n"); + } + + #[test] + fn error_file_exists_without_overwrite() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("existing.txt"); + std::fs::write(&path, "original").unwrap(); + + let call = WriteToolCall { + path: path.clone(), + content: "new content".to_string(), + overwrite: false, + }; + let (outcome, new_bytes) = call.execute(&path); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("already exists"), "got: {msg}"); + assert!(msg.contains("overwrite"), "got: {msg}"); + } + _ => panic!("expected error"), + } + // Original preserved + assert_eq!(std::fs::read_to_string(&path).unwrap(), "original"); + } + + #[test] + fn overwrites_existing_file_when_flag_set() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("existing.txt"); + std::fs::write(&path, "original").unwrap(); + + let call = WriteToolCall { + path: path.clone(), + content: "replaced content\n".to_string(), + overwrite: true, + }; + let (outcome, new_bytes) = call.execute(&path); + + assert!(matches!(outcome, ToolOutcome::Success(_))); + assert!(new_bytes.is_some()); + assert_eq!( + std::fs::read_to_string(&path).unwrap(), + "replaced content\n" + ); + } + + #[test] + fn creates_parent_directories() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("sub").join("dir").join("file.txt"); + + let call = WriteToolCall { + path: path.clone(), + content: "nested\n".to_string(), + overwrite: false, + }; + let (outcome, _) = call.execute(&path); + + assert!(matches!(outcome, ToolOutcome::Success(_))); + assert_eq!(std::fs::read_to_string(&path).unwrap(), "nested\n"); + } + + #[test] + fn error_path_is_directory() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().to_path_buf(); + + let call = WriteToolCall { + path: path.clone(), + content: "content".to_string(), + overwrite: false, + }; + let (outcome, new_bytes) = call.execute(&path); + + assert!(new_bytes.is_none()); + assert!(matches!(outcome, ToolOutcome::Error(ref msg) if msg.contains("directory"))); + } + } + // ── Windows-specific tests (absolute paths with drive letters) ── #[cfg(windows)] |
