aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src/tools/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-ai/src/tools/mod.rs')
-rw-r--r--crates/atuin-ai/src/tools/mod.rs174
1 files changed, 171 insertions, 3 deletions
diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs
index 8fe1ad73..890ea734 100644
--- a/crates/atuin-ai/src/tools/mod.rs
+++ b/crates/atuin-ai/src/tools/mod.rs
@@ -171,6 +171,8 @@ pub(crate) struct TrackedTool {
pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>,
/// Diff preview for completed edit tool calls.
pub edit_preview: Option<crate::diff::EditPreview>,
+ /// Content preview for completed write tool calls.
+ pub write_preview: Option<crate::diff::WritePreview>,
}
impl TrackedTool {
@@ -237,6 +239,7 @@ impl ToolTracker {
phase: ToolPhase::CheckingPermissions,
abort_tx: None,
edit_preview: None,
+ write_preview: None,
});
}
@@ -724,10 +727,10 @@ impl PermissableToolCall for EditToolCall {
}
#[derive(Debug, Clone)]
-#[expect(dead_code)]
pub(crate) struct WriteToolCall {
pub path: PathBuf,
pub content: String,
+ pub overwrite: bool,
}
impl TryFrom<&serde_json::Value> for WriteToolCall {
@@ -735,22 +738,85 @@ impl TryFrom<&serde_json::Value> for WriteToolCall {
fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> {
let path = value
- .get("path")
+ .get("file_path")
.and_then(|v| v.as_str())
- .ok_or(eyre::eyre!("Missing path"))?;
+ .ok_or(eyre::eyre!("Missing file_path"))?;
let content = value
.get("content")
.and_then(|v| v.as_str())
.ok_or(eyre::eyre!("Missing content"))?;
+ let overwrite = value
+ .get("overwrite")
+ .and_then(|v| v.as_bool())
+ .unwrap_or(false);
+
Ok(WriteToolCall {
path: expand_path(path),
content: content.to_string(),
+ overwrite,
})
}
}
+impl WriteToolCall {
+ /// Resolve the write path to an absolute path.
+ pub fn resolved_path(&self) -> PathBuf {
+ if self.path.is_relative() {
+ std::env::current_dir()
+ .map(|cwd| cwd.join(&self.path))
+ .unwrap_or_else(|_| self.path.clone())
+ } else {
+ self.path.clone()
+ }
+ }
+
+ /// Execute the write operation.
+ ///
+ /// Creates a new file or overwrites an existing one (if `overwrite` is set).
+ /// Returns the outcome and the written bytes (for tracker updates).
+ pub fn execute(&self, resolved_path: &Path) -> (ToolOutcome, Option<Vec<u8>>) {
+ if resolved_path.is_dir() {
+ return (
+ ToolOutcome::Error(format!(
+ "Error: path is a directory, not a file: {}",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+ if resolved_path.exists() && !self.overwrite {
+ return (
+ ToolOutcome::Error(format!(
+ "File already exists: {}. Set overwrite to true to replace it, or use edit_file to make targeted changes.",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+
+ // Capture before the write — after atomic_write the file always exists.
+ let existed = resolved_path.exists();
+
+ // Write atomically
+ let content_bytes = self.content.as_bytes().to_vec();
+ if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &content_bytes) {
+ return (ToolOutcome::Error(format!("Error writing file: {e}")), None);
+ }
+
+ let line_count = self.content.lines().count();
+ let verb = if existed { "Overwrote" } else { "Created" };
+ (
+ ToolOutcome::Success(format!(
+ "{verb} {} ({line_count} lines).",
+ resolved_path.display()
+ )),
+ Some(content_bytes),
+ )
+ }
+}
+
impl PermissableToolCall for WriteToolCall {
fn target_dir(&self) -> Option<&Path> {
Some(&self.path)
@@ -1235,6 +1301,7 @@ mod tests {
WriteToolCall {
path: expand_path(path),
content: String::new(),
+ overwrite: false,
}
}
@@ -1735,6 +1802,107 @@ mod tests {
}
}
+ // ── write_file execution tests ──
+
+ mod write {
+ use super::*;
+
+ #[test]
+ fn creates_new_file() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("new_file.txt");
+
+ let call = WriteToolCall {
+ path: path.clone(),
+ content: "hello\nworld\n".to_string(),
+ overwrite: false,
+ };
+ let (outcome, new_bytes) = call.execute(&path);
+
+ assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("Created")));
+ assert!(new_bytes.is_some());
+ assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello\nworld\n");
+ }
+
+ #[test]
+ fn error_file_exists_without_overwrite() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("existing.txt");
+ std::fs::write(&path, "original").unwrap();
+
+ let call = WriteToolCall {
+ path: path.clone(),
+ content: "new content".to_string(),
+ overwrite: false,
+ };
+ let (outcome, new_bytes) = call.execute(&path);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("already exists"), "got: {msg}");
+ assert!(msg.contains("overwrite"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ // Original preserved
+ assert_eq!(std::fs::read_to_string(&path).unwrap(), "original");
+ }
+
+ #[test]
+ fn overwrites_existing_file_when_flag_set() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("existing.txt");
+ std::fs::write(&path, "original").unwrap();
+
+ let call = WriteToolCall {
+ path: path.clone(),
+ content: "replaced content\n".to_string(),
+ overwrite: true,
+ };
+ let (outcome, new_bytes) = call.execute(&path);
+
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "replaced content\n"
+ );
+ }
+
+ #[test]
+ fn creates_parent_directories() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("sub").join("dir").join("file.txt");
+
+ let call = WriteToolCall {
+ path: path.clone(),
+ content: "nested\n".to_string(),
+ overwrite: false,
+ };
+ let (outcome, _) = call.execute(&path);
+
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert_eq!(std::fs::read_to_string(&path).unwrap(), "nested\n");
+ }
+
+ #[test]
+ fn error_path_is_directory() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().to_path_buf();
+
+ let call = WriteToolCall {
+ path: path.clone(),
+ content: "content".to_string(),
+ overwrite: false,
+ };
+ let (outcome, new_bytes) = call.execute(&path);
+
+ assert!(new_bytes.is_none());
+ assert!(matches!(outcome, ToolOutcome::Error(ref msg) if msg.contains("directory")));
+ }
+ }
+
// ── Windows-specific tests (absolute paths with drive letters) ──
#[cfg(windows)]