diff options
Diffstat (limited to 'crates/atuin-ai/src')
| -rw-r--r-- | crates/atuin-ai/src/diff.rs | 34 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tools/mod.rs | 174 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/dispatch.rs | 67 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/view/mod.rs | 68 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/view/turn.rs | 6 |
5 files changed, 336 insertions, 13 deletions
diff --git a/crates/atuin-ai/src/diff.rs b/crates/atuin-ai/src/diff.rs index 663481c0..e704175c 100644 --- a/crates/atuin-ai/src/diff.rs +++ b/crates/atuin-ai/src/diff.rs @@ -101,6 +101,40 @@ impl EditPreview { } } +/// Maximum lines to show in a write preview. +const WRITE_PREVIEW_LINES: usize = 10; + +/// A content preview for a write_file operation. +/// +/// Shows the first N lines of the written content plus a count of +/// remaining lines if truncated. +#[derive(Debug, Clone)] +pub(crate) struct WritePreview { + /// First lines of content (up to WRITE_PREVIEW_LINES). + pub lines: Vec<String>, + /// Total number of lines in the written file. + pub total_lines: usize, +} + +impl WritePreview { + /// Create a preview from file content. + pub fn from_content(content: &str) -> Self { + let all_lines: Vec<&str> = content.lines().collect(); + let total_lines = all_lines.len(); + let lines = all_lines + .into_iter() + .take(WRITE_PREVIEW_LINES) + .map(String::from) + .collect(); + WritePreview { lines, total_lines } + } + + /// Number of lines not shown in the preview. + pub fn remaining_lines(&self) -> usize { + self.total_lines.saturating_sub(self.lines.len()) + } +} + /// Build a single DiffHunk from a group of adjacent raw hunks. fn build_hunk(group: &[&imara_diff::Hunk], input: &InternedInput<&str>) -> DiffHunk { let first = group.first().unwrap(); diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 8fe1ad73..890ea734 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -171,6 +171,8 @@ pub(crate) struct TrackedTool { pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>, /// Diff preview for completed edit tool calls. pub edit_preview: Option<crate::diff::EditPreview>, + /// Content preview for completed write tool calls. + pub write_preview: Option<crate::diff::WritePreview>, } impl TrackedTool { @@ -237,6 +239,7 @@ impl ToolTracker { phase: ToolPhase::CheckingPermissions, abort_tx: None, edit_preview: None, + write_preview: None, }); } @@ -724,10 +727,10 @@ impl PermissableToolCall for EditToolCall { } #[derive(Debug, Clone)] -#[expect(dead_code)] pub(crate) struct WriteToolCall { pub path: PathBuf, pub content: String, + pub overwrite: bool, } impl TryFrom<&serde_json::Value> for WriteToolCall { @@ -735,22 +738,85 @@ impl TryFrom<&serde_json::Value> for WriteToolCall { fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { let path = value - .get("path") + .get("file_path") .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing path"))?; + .ok_or(eyre::eyre!("Missing file_path"))?; let content = value .get("content") .and_then(|v| v.as_str()) .ok_or(eyre::eyre!("Missing content"))?; + let overwrite = value + .get("overwrite") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + Ok(WriteToolCall { path: expand_path(path), content: content.to_string(), + overwrite, }) } } +impl WriteToolCall { + /// Resolve the write path to an absolute path. + pub fn resolved_path(&self) -> PathBuf { + if self.path.is_relative() { + std::env::current_dir() + .map(|cwd| cwd.join(&self.path)) + .unwrap_or_else(|_| self.path.clone()) + } else { + self.path.clone() + } + } + + /// Execute the write operation. + /// + /// Creates a new file or overwrites an existing one (if `overwrite` is set). + /// Returns the outcome and the written bytes (for tracker updates). + pub fn execute(&self, resolved_path: &Path) -> (ToolOutcome, Option<Vec<u8>>) { + if resolved_path.is_dir() { + return ( + ToolOutcome::Error(format!( + "Error: path is a directory, not a file: {}", + resolved_path.display() + )), + None, + ); + } + if resolved_path.exists() && !self.overwrite { + return ( + ToolOutcome::Error(format!( + "File already exists: {}. Set overwrite to true to replace it, or use edit_file to make targeted changes.", + resolved_path.display() + )), + None, + ); + } + + // Capture before the write — after atomic_write the file always exists. + let existed = resolved_path.exists(); + + // Write atomically + let content_bytes = self.content.as_bytes().to_vec(); + if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &content_bytes) { + return (ToolOutcome::Error(format!("Error writing file: {e}")), None); + } + + let line_count = self.content.lines().count(); + let verb = if existed { "Overwrote" } else { "Created" }; + ( + ToolOutcome::Success(format!( + "{verb} {} ({line_count} lines).", + resolved_path.display() + )), + Some(content_bytes), + ) + } +} + impl PermissableToolCall for WriteToolCall { fn target_dir(&self) -> Option<&Path> { Some(&self.path) @@ -1235,6 +1301,7 @@ mod tests { WriteToolCall { path: expand_path(path), content: String::new(), + overwrite: false, } } @@ -1735,6 +1802,107 @@ mod tests { } } + // ── write_file execution tests ── + + mod write { + use super::*; + + #[test] + fn creates_new_file() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("new_file.txt"); + + let call = WriteToolCall { + path: path.clone(), + content: "hello\nworld\n".to_string(), + overwrite: false, + }; + let (outcome, new_bytes) = call.execute(&path); + + assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("Created"))); + assert!(new_bytes.is_some()); + assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello\nworld\n"); + } + + #[test] + fn error_file_exists_without_overwrite() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("existing.txt"); + std::fs::write(&path, "original").unwrap(); + + let call = WriteToolCall { + path: path.clone(), + content: "new content".to_string(), + overwrite: false, + }; + let (outcome, new_bytes) = call.execute(&path); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("already exists"), "got: {msg}"); + assert!(msg.contains("overwrite"), "got: {msg}"); + } + _ => panic!("expected error"), + } + // Original preserved + assert_eq!(std::fs::read_to_string(&path).unwrap(), "original"); + } + + #[test] + fn overwrites_existing_file_when_flag_set() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("existing.txt"); + std::fs::write(&path, "original").unwrap(); + + let call = WriteToolCall { + path: path.clone(), + content: "replaced content\n".to_string(), + overwrite: true, + }; + let (outcome, new_bytes) = call.execute(&path); + + assert!(matches!(outcome, ToolOutcome::Success(_))); + assert!(new_bytes.is_some()); + assert_eq!( + std::fs::read_to_string(&path).unwrap(), + "replaced content\n" + ); + } + + #[test] + fn creates_parent_directories() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("sub").join("dir").join("file.txt"); + + let call = WriteToolCall { + path: path.clone(), + content: "nested\n".to_string(), + overwrite: false, + }; + let (outcome, _) = call.execute(&path); + + assert!(matches!(outcome, ToolOutcome::Success(_))); + assert_eq!(std::fs::read_to_string(&path).unwrap(), "nested\n"); + } + + #[test] + fn error_path_is_directory() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().to_path_buf(); + + let call = WriteToolCall { + path: path.clone(), + content: "content".to_string(), + overwrite: false, + }; + let (outcome, new_bytes) = call.execute(&path); + + assert!(new_bytes.is_none()); + assert!(matches!(outcome, ToolOutcome::Error(ref msg) if msg.contains("directory"))); + } + } + // ── Windows-specific tests (absolute paths with drive letters) ── #[cfg(windows)] diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index fea26953..46eebd9b 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -232,6 +232,10 @@ fn execute_tool( let edit_call = edit_call.clone(); execute_edit_tool(handle, tx, tool_id, edit_call); } + ClientToolCall::Write(write_call) => { + let write_call = write_call.clone(); + execute_write_tool(handle, tx, tool_id, write_call); + } _ => { execute_simple_tool(handle, tx, tool_id, tool, db); } @@ -387,6 +391,69 @@ fn execute_edit_tool( }); } +/// Execute a write_file tool call. +/// +/// Snapshots the existing file (if any) before overwriting, writes atomically, +/// stores a content preview on the tracker, and updates the file tracker. +fn execute_write_tool( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + tool_id: String, + write_call: crate::tools::WriteToolCall, +) { + let h = handle.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let resolved = write_call.resolved_path(); + + // 1. Snapshot the existing file before overwriting (if it exists). + if resolved.exists() + && let Ok(original_content) = std::fs::read(&resolved) + { + let snap_path = resolved.clone(); + h.update(move |state| { + if let Some(ref mut store) = state.snapshot_store + && let Err(e) = store.ensure_snapshot(&snap_path, &original_content) + { + tracing::warn!("failed to create file snapshot: {e}"); + } + }); + } + + // 2. Execute: check exists/overwrite, atomic write + let (outcome, new_bytes) = write_call.execute(&resolved); + + // 3. Build content preview on success + let write_preview = if new_bytes.is_some() { + Some(crate::diff::WritePreview::from_content(&write_call.content)) + } else { + None + }; + + // 4. Update tracker, store preview, and finish + let tc_id = tool_id; + h.update(move |state| { + if let Some(ref new_bytes) = new_bytes + && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) + { + state + .file_tracker + .update_after_edit(&resolved, new_bytes, mtime); + } + if let Some(preview) = write_preview + && let Some(tracked) = state.tool_tracker.get_mut(&tc_id) + { + tracked.write_preview = Some(preview); + } + state.finish_tool_call(&tc_id, outcome); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + }); +} + /// Execute a shell tool with streaming VT100 preview. fn execute_shell_tool( handle: &Handle<Session>, diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index bdbece9c..6e13e406 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -175,7 +175,7 @@ fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements { /// keep the standard set. fn permission_options_for_tool(tool: &ClientToolCall, in_git_project: bool) -> Vec<SelectOption> { match tool { - ClientToolCall::Edit(_) => vec![ + ClientToolCall::Edit(_) | ClientToolCall::Write(_) => vec![ SelectOption::builder() .label("Allow") .value(PermissionResult::Allow.as_value_str()) @@ -296,8 +296,8 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { turn::ToolRenderData::FileEdit { path, preview } => { file_edit_tool_view(&tool_key, &details.status, path, preview.as_ref()) }, - turn::ToolRenderData::FileWrite { path } => { - file_write_tool_view(&details.status, path) + turn::ToolRenderData::FileWrite { path, preview } => { + file_write_tool_view(&tool_key, &details.status, path, preview.as_ref()) }, turn::ToolRenderData::Remote => { tool_status_view(&details.name, &details.status) @@ -577,10 +577,16 @@ fn file_edit_tool_view( } } -/// Render a file write tool call status with the target path. -fn file_write_tool_view(status: &turn::ToolResultStatus, path: &std::path::Path) -> Elements { - let display_path = path.display(); - match status { +/// Render a file write tool call with content preview. +fn file_write_tool_view( + key: &str, + status: &turn::ToolResultStatus, + path: &std::path::Path, + preview: Option<&crate::diff::WritePreview>, +) -> Elements { + let display_path = format_path_for_display(path); + + let status_line = match status { turn::ToolResultStatus::Pending => { element! { Spinner( @@ -591,18 +597,62 @@ fn file_write_tool_view(status: &turn::ToolResultStatus, path: &std::path::Path) } } turn::ToolResultStatus::Success => { + let line_info = preview + .map(|p| format!(" ({} lines)", p.total_lines)) + .unwrap_or_default(); element! { - Spinner(label: format!("Wrote: {display_path}"), done: true) + Spinner(label: format!("Wrote: {display_path}{line_info}"), done: true) } } turn::ToolResultStatus::Error => { element! { Text { Span(text: "✗ ", style: Style::default().fg(Color::Red)) - Span(text: format!("Write {display_path}: denied"), style: Style::default().fg(Color::Red)) + Span(text: format!("Write {display_path}: failed"), style: Style::default().fg(Color::Red)) } } } + }; + + let Some(preview) = preview else { + return status_line; + }; + if preview.lines.is_empty() { + return status_line; + } + + let gutter_width = preview.total_lines.to_string().len().max(2) as u16 + 1; + let remaining = preview.remaining_lines(); + + element! { + View(key: key.to_string()) { + #(status_line) + + View(key: format!("{key}-content"), padding_left: Cells::from(2)) { + #(for (idx, line) in preview.lines.iter().enumerate() { + HStack(key: format!("{key}-line-{idx}")) { + View(width: WidthConstraint::Fixed(gutter_width)) { + Text { Span( + text: format!("{:>width$}", idx + 1, width = (gutter_width - 1) as usize), + style: Style::default().fg(Color::DarkGray) + ) } + } + View { + Text { Span(text: line, style: Style::default().fg(Color::DarkGray)) } + } + } + }) + + #(if remaining > 0 { + Text { + Span( + text: format!(" ... +{remaining} more lines"), + style: Style::default().fg(Color::DarkGray) + ) + } + }) + } + } } } diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index 1c19a6b2..6c3d5c29 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -141,7 +141,10 @@ pub(crate) enum ToolRenderData { preview: Option<crate::diff::EditPreview>, }, /// File write/create operation. - FileWrite { path: PathBuf }, + FileWrite { + path: PathBuf, + preview: Option<crate::diff::WritePreview>, + }, /// Atuin history search. HistorySearch { query: String, @@ -449,6 +452,7 @@ impl<'a> TurnBuilder<'a> { }, ClientToolCall::Write(write) => ToolRenderData::FileWrite { path: write.path.clone(), + preview: tracked.write_preview.clone(), }, ClientToolCall::AtuinHistory(history) => ToolRenderData::HistorySearch { query: history.query.clone(), |
