aboutsummaryrefslogtreecommitdiffstats
path: root/crates
diff options
context:
space:
mode:
Diffstat (limited to 'crates')
-rw-r--r--crates/atuin-ai/src/diff.rs34
-rw-r--r--crates/atuin-ai/src/tools/mod.rs174
-rw-r--r--crates/atuin-ai/src/tui/dispatch.rs67
-rw-r--r--crates/atuin-ai/src/tui/view/mod.rs68
-rw-r--r--crates/atuin-ai/src/tui/view/turn.rs6
5 files changed, 336 insertions, 13 deletions
diff --git a/crates/atuin-ai/src/diff.rs b/crates/atuin-ai/src/diff.rs
index 663481c0..e704175c 100644
--- a/crates/atuin-ai/src/diff.rs
+++ b/crates/atuin-ai/src/diff.rs
@@ -101,6 +101,40 @@ impl EditPreview {
}
}
+/// Maximum lines to show in a write preview.
+const WRITE_PREVIEW_LINES: usize = 10;
+
+/// A content preview for a write_file operation.
+///
+/// Shows the first N lines of the written content plus a count of
+/// remaining lines if truncated.
+#[derive(Debug, Clone)]
+pub(crate) struct WritePreview {
+ /// First lines of content (up to WRITE_PREVIEW_LINES).
+ pub lines: Vec<String>,
+ /// Total number of lines in the written file.
+ pub total_lines: usize,
+}
+
+impl WritePreview {
+ /// Create a preview from file content.
+ pub fn from_content(content: &str) -> Self {
+ let all_lines: Vec<&str> = content.lines().collect();
+ let total_lines = all_lines.len();
+ let lines = all_lines
+ .into_iter()
+ .take(WRITE_PREVIEW_LINES)
+ .map(String::from)
+ .collect();
+ WritePreview { lines, total_lines }
+ }
+
+ /// Number of lines not shown in the preview.
+ pub fn remaining_lines(&self) -> usize {
+ self.total_lines.saturating_sub(self.lines.len())
+ }
+}
+
/// Build a single DiffHunk from a group of adjacent raw hunks.
fn build_hunk(group: &[&imara_diff::Hunk], input: &InternedInput<&str>) -> DiffHunk {
let first = group.first().unwrap();
diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs
index 8fe1ad73..890ea734 100644
--- a/crates/atuin-ai/src/tools/mod.rs
+++ b/crates/atuin-ai/src/tools/mod.rs
@@ -171,6 +171,8 @@ pub(crate) struct TrackedTool {
pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>,
/// Diff preview for completed edit tool calls.
pub edit_preview: Option<crate::diff::EditPreview>,
+ /// Content preview for completed write tool calls.
+ pub write_preview: Option<crate::diff::WritePreview>,
}
impl TrackedTool {
@@ -237,6 +239,7 @@ impl ToolTracker {
phase: ToolPhase::CheckingPermissions,
abort_tx: None,
edit_preview: None,
+ write_preview: None,
});
}
@@ -724,10 +727,10 @@ impl PermissableToolCall for EditToolCall {
}
#[derive(Debug, Clone)]
-#[expect(dead_code)]
pub(crate) struct WriteToolCall {
pub path: PathBuf,
pub content: String,
+ pub overwrite: bool,
}
impl TryFrom<&serde_json::Value> for WriteToolCall {
@@ -735,22 +738,85 @@ impl TryFrom<&serde_json::Value> for WriteToolCall {
fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> {
let path = value
- .get("path")
+ .get("file_path")
.and_then(|v| v.as_str())
- .ok_or(eyre::eyre!("Missing path"))?;
+ .ok_or(eyre::eyre!("Missing file_path"))?;
let content = value
.get("content")
.and_then(|v| v.as_str())
.ok_or(eyre::eyre!("Missing content"))?;
+ let overwrite = value
+ .get("overwrite")
+ .and_then(|v| v.as_bool())
+ .unwrap_or(false);
+
Ok(WriteToolCall {
path: expand_path(path),
content: content.to_string(),
+ overwrite,
})
}
}
+impl WriteToolCall {
+ /// Resolve the write path to an absolute path.
+ pub fn resolved_path(&self) -> PathBuf {
+ if self.path.is_relative() {
+ std::env::current_dir()
+ .map(|cwd| cwd.join(&self.path))
+ .unwrap_or_else(|_| self.path.clone())
+ } else {
+ self.path.clone()
+ }
+ }
+
+ /// Execute the write operation.
+ ///
+ /// Creates a new file or overwrites an existing one (if `overwrite` is set).
+ /// Returns the outcome and the written bytes (for tracker updates).
+ pub fn execute(&self, resolved_path: &Path) -> (ToolOutcome, Option<Vec<u8>>) {
+ if resolved_path.is_dir() {
+ return (
+ ToolOutcome::Error(format!(
+ "Error: path is a directory, not a file: {}",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+ if resolved_path.exists() && !self.overwrite {
+ return (
+ ToolOutcome::Error(format!(
+ "File already exists: {}. Set overwrite to true to replace it, or use edit_file to make targeted changes.",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+
+ // Capture before the write — after atomic_write the file always exists.
+ let existed = resolved_path.exists();
+
+ // Write atomically
+ let content_bytes = self.content.as_bytes().to_vec();
+ if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &content_bytes) {
+ return (ToolOutcome::Error(format!("Error writing file: {e}")), None);
+ }
+
+ let line_count = self.content.lines().count();
+ let verb = if existed { "Overwrote" } else { "Created" };
+ (
+ ToolOutcome::Success(format!(
+ "{verb} {} ({line_count} lines).",
+ resolved_path.display()
+ )),
+ Some(content_bytes),
+ )
+ }
+}
+
impl PermissableToolCall for WriteToolCall {
fn target_dir(&self) -> Option<&Path> {
Some(&self.path)
@@ -1235,6 +1301,7 @@ mod tests {
WriteToolCall {
path: expand_path(path),
content: String::new(),
+ overwrite: false,
}
}
@@ -1735,6 +1802,107 @@ mod tests {
}
}
+ // ── write_file execution tests ──
+
+ mod write {
+ use super::*;
+
+ #[test]
+ fn creates_new_file() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("new_file.txt");
+
+ let call = WriteToolCall {
+ path: path.clone(),
+ content: "hello\nworld\n".to_string(),
+ overwrite: false,
+ };
+ let (outcome, new_bytes) = call.execute(&path);
+
+ assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("Created")));
+ assert!(new_bytes.is_some());
+ assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello\nworld\n");
+ }
+
+ #[test]
+ fn error_file_exists_without_overwrite() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("existing.txt");
+ std::fs::write(&path, "original").unwrap();
+
+ let call = WriteToolCall {
+ path: path.clone(),
+ content: "new content".to_string(),
+ overwrite: false,
+ };
+ let (outcome, new_bytes) = call.execute(&path);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("already exists"), "got: {msg}");
+ assert!(msg.contains("overwrite"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ // Original preserved
+ assert_eq!(std::fs::read_to_string(&path).unwrap(), "original");
+ }
+
+ #[test]
+ fn overwrites_existing_file_when_flag_set() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("existing.txt");
+ std::fs::write(&path, "original").unwrap();
+
+ let call = WriteToolCall {
+ path: path.clone(),
+ content: "replaced content\n".to_string(),
+ overwrite: true,
+ };
+ let (outcome, new_bytes) = call.execute(&path);
+
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "replaced content\n"
+ );
+ }
+
+ #[test]
+ fn creates_parent_directories() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("sub").join("dir").join("file.txt");
+
+ let call = WriteToolCall {
+ path: path.clone(),
+ content: "nested\n".to_string(),
+ overwrite: false,
+ };
+ let (outcome, _) = call.execute(&path);
+
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert_eq!(std::fs::read_to_string(&path).unwrap(), "nested\n");
+ }
+
+ #[test]
+ fn error_path_is_directory() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().to_path_buf();
+
+ let call = WriteToolCall {
+ path: path.clone(),
+ content: "content".to_string(),
+ overwrite: false,
+ };
+ let (outcome, new_bytes) = call.execute(&path);
+
+ assert!(new_bytes.is_none());
+ assert!(matches!(outcome, ToolOutcome::Error(ref msg) if msg.contains("directory")));
+ }
+ }
+
// ── Windows-specific tests (absolute paths with drive letters) ──
#[cfg(windows)]
diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs
index fea26953..46eebd9b 100644
--- a/crates/atuin-ai/src/tui/dispatch.rs
+++ b/crates/atuin-ai/src/tui/dispatch.rs
@@ -232,6 +232,10 @@ fn execute_tool(
let edit_call = edit_call.clone();
execute_edit_tool(handle, tx, tool_id, edit_call);
}
+ ClientToolCall::Write(write_call) => {
+ let write_call = write_call.clone();
+ execute_write_tool(handle, tx, tool_id, write_call);
+ }
_ => {
execute_simple_tool(handle, tx, tool_id, tool, db);
}
@@ -387,6 +391,69 @@ fn execute_edit_tool(
});
}
+/// Execute a write_file tool call.
+///
+/// Snapshots the existing file (if any) before overwriting, writes atomically,
+/// stores a content preview on the tracker, and updates the file tracker.
+fn execute_write_tool(
+ handle: &Handle<Session>,
+ tx: &mpsc::Sender<AiTuiEvent>,
+ tool_id: String,
+ write_call: crate::tools::WriteToolCall,
+) {
+ let h = handle.clone();
+ let tx = tx.clone();
+
+ tokio::spawn(async move {
+ let resolved = write_call.resolved_path();
+
+ // 1. Snapshot the existing file before overwriting (if it exists).
+ if resolved.exists()
+ && let Ok(original_content) = std::fs::read(&resolved)
+ {
+ let snap_path = resolved.clone();
+ h.update(move |state| {
+ if let Some(ref mut store) = state.snapshot_store
+ && let Err(e) = store.ensure_snapshot(&snap_path, &original_content)
+ {
+ tracing::warn!("failed to create file snapshot: {e}");
+ }
+ });
+ }
+
+ // 2. Execute: check exists/overwrite, atomic write
+ let (outcome, new_bytes) = write_call.execute(&resolved);
+
+ // 3. Build content preview on success
+ let write_preview = if new_bytes.is_some() {
+ Some(crate::diff::WritePreview::from_content(&write_call.content))
+ } else {
+ None
+ };
+
+ // 4. Update tracker, store preview, and finish
+ let tc_id = tool_id;
+ h.update(move |state| {
+ if let Some(ref new_bytes) = new_bytes
+ && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified())
+ {
+ state
+ .file_tracker
+ .update_after_edit(&resolved, new_bytes, mtime);
+ }
+ if let Some(preview) = write_preview
+ && let Some(tracked) = state.tool_tracker.get_mut(&tc_id)
+ {
+ tracked.write_preview = Some(preview);
+ }
+ state.finish_tool_call(&tc_id, outcome);
+ if !state.tool_tracker.has_pending() {
+ let _ = tx.send(AiTuiEvent::ContinueAfterTools);
+ }
+ });
+ });
+}
+
/// Execute a shell tool with streaming VT100 preview.
fn execute_shell_tool(
handle: &Handle<Session>,
diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs
index bdbece9c..6e13e406 100644
--- a/crates/atuin-ai/src/tui/view/mod.rs
+++ b/crates/atuin-ai/src/tui/view/mod.rs
@@ -175,7 +175,7 @@ fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements {
/// keep the standard set.
fn permission_options_for_tool(tool: &ClientToolCall, in_git_project: bool) -> Vec<SelectOption> {
match tool {
- ClientToolCall::Edit(_) => vec![
+ ClientToolCall::Edit(_) | ClientToolCall::Write(_) => vec![
SelectOption::builder()
.label("Allow")
.value(PermissionResult::Allow.as_value_str())
@@ -296,8 +296,8 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements {
turn::ToolRenderData::FileEdit { path, preview } => {
file_edit_tool_view(&tool_key, &details.status, path, preview.as_ref())
},
- turn::ToolRenderData::FileWrite { path } => {
- file_write_tool_view(&details.status, path)
+ turn::ToolRenderData::FileWrite { path, preview } => {
+ file_write_tool_view(&tool_key, &details.status, path, preview.as_ref())
},
turn::ToolRenderData::Remote => {
tool_status_view(&details.name, &details.status)
@@ -577,10 +577,16 @@ fn file_edit_tool_view(
}
}
-/// Render a file write tool call status with the target path.
-fn file_write_tool_view(status: &turn::ToolResultStatus, path: &std::path::Path) -> Elements {
- let display_path = path.display();
- match status {
+/// Render a file write tool call with content preview.
+fn file_write_tool_view(
+ key: &str,
+ status: &turn::ToolResultStatus,
+ path: &std::path::Path,
+ preview: Option<&crate::diff::WritePreview>,
+) -> Elements {
+ let display_path = format_path_for_display(path);
+
+ let status_line = match status {
turn::ToolResultStatus::Pending => {
element! {
Spinner(
@@ -591,18 +597,62 @@ fn file_write_tool_view(status: &turn::ToolResultStatus, path: &std::path::Path)
}
}
turn::ToolResultStatus::Success => {
+ let line_info = preview
+ .map(|p| format!(" ({} lines)", p.total_lines))
+ .unwrap_or_default();
element! {
- Spinner(label: format!("Wrote: {display_path}"), done: true)
+ Spinner(label: format!("Wrote: {display_path}{line_info}"), done: true)
}
}
turn::ToolResultStatus::Error => {
element! {
Text {
Span(text: "✗ ", style: Style::default().fg(Color::Red))
- Span(text: format!("Write {display_path}: denied"), style: Style::default().fg(Color::Red))
+ Span(text: format!("Write {display_path}: failed"), style: Style::default().fg(Color::Red))
}
}
}
+ };
+
+ let Some(preview) = preview else {
+ return status_line;
+ };
+ if preview.lines.is_empty() {
+ return status_line;
+ }
+
+ let gutter_width = preview.total_lines.to_string().len().max(2) as u16 + 1;
+ let remaining = preview.remaining_lines();
+
+ element! {
+ View(key: key.to_string()) {
+ #(status_line)
+
+ View(key: format!("{key}-content"), padding_left: Cells::from(2)) {
+ #(for (idx, line) in preview.lines.iter().enumerate() {
+ HStack(key: format!("{key}-line-{idx}")) {
+ View(width: WidthConstraint::Fixed(gutter_width)) {
+ Text { Span(
+ text: format!("{:>width$}", idx + 1, width = (gutter_width - 1) as usize),
+ style: Style::default().fg(Color::DarkGray)
+ ) }
+ }
+ View {
+ Text { Span(text: line, style: Style::default().fg(Color::DarkGray)) }
+ }
+ }
+ })
+
+ #(if remaining > 0 {
+ Text {
+ Span(
+ text: format!(" ... +{remaining} more lines"),
+ style: Style::default().fg(Color::DarkGray)
+ )
+ }
+ })
+ }
+ }
}
}
diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs
index 1c19a6b2..6c3d5c29 100644
--- a/crates/atuin-ai/src/tui/view/turn.rs
+++ b/crates/atuin-ai/src/tui/view/turn.rs
@@ -141,7 +141,10 @@ pub(crate) enum ToolRenderData {
preview: Option<crate::diff::EditPreview>,
},
/// File write/create operation.
- FileWrite { path: PathBuf },
+ FileWrite {
+ path: PathBuf,
+ preview: Option<crate::diff::WritePreview>,
+ },
/// Atuin history search.
HistorySearch {
query: String,
@@ -449,6 +452,7 @@ impl<'a> TurnBuilder<'a> {
},
ClientToolCall::Write(write) => ToolRenderData::FileWrite {
path: write.path.clone(),
+ preview: tracked.write_preview.clone(),
},
ClientToolCall::AtuinHistory(history) => ToolRenderData::HistorySearch {
query: history.query.clone(),