aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-ai/src')
-rw-r--r--crates/atuin-ai/src/commands/inline.rs29
-rw-r--r--crates/atuin-ai/src/diff.rs294
-rw-r--r--crates/atuin-ai/src/edit_permissions.rs108
-rw-r--r--crates/atuin-ai/src/file_tracker.rs234
-rw-r--r--crates/atuin-ai/src/lib.rs4
-rw-r--r--crates/atuin-ai/src/session.rs27
-rw-r--r--crates/atuin-ai/src/snapshots.rs414
-rw-r--r--crates/atuin-ai/src/store.rs32
-rw-r--r--crates/atuin-ai/src/stream.rs8
-rw-r--r--crates/atuin-ai/src/tools/descriptor.rs17
-rw-r--r--crates/atuin-ai/src/tools/mod.rs737
-rw-r--r--crates/atuin-ai/src/tui/dispatch.rs199
-rw-r--r--crates/atuin-ai/src/tui/events.rs27
-rw-r--r--crates/atuin-ai/src/tui/state.rs9
-rw-r--r--crates/atuin-ai/src/tui/view/mod.rs570
-rw-r--r--crates/atuin-ai/src/tui/view/turn.rs198
16 files changed, 2763 insertions, 144 deletions
diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs
index b7aae51f..e0a92ab4 100644
--- a/crates/atuin-ai/src/commands/inline.rs
+++ b/crates/atuin-ai/src/commands/inline.rs
@@ -175,7 +175,7 @@ async fn run_inline_tui(
.find_resumable(cwd.as_deref(), git_root_str.as_deref(), max_age_secs)
.await?;
- let (mut session_mgr, initial_state) = if let Some(stored) = resumable {
+ let (mut session_mgr, mut initial_state) = if let Some(stored) = resumable {
debug!(session_id = %stored.id, "resuming AI session");
let (mgr, events, server_sid, last_event_ts, invocation_id) =
SessionManager::resume(Box::new(service), &stored).await?;
@@ -199,6 +199,23 @@ async fn run_inline_tui(
session.is_resumed = true;
session.last_event_time =
last_event_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0));
+
+ // Restore file read tracker from session metadata
+ if let Ok(Some(json)) = mgr.get_metadata(crate::file_tracker::METADATA_KEY).await
+ && let Ok(tracker) = crate::file_tracker::FileReadTracker::from_json(&json)
+ {
+ session.file_tracker = tracker;
+ }
+
+ // Restore edit permission grants from session metadata
+ if let Ok(Some(json)) = mgr
+ .get_metadata(crate::edit_permissions::METADATA_KEY)
+ .await
+ && let Ok(cache) = crate::edit_permissions::EditPermissionCache::from_json(&json)
+ {
+ session.edit_permissions = cache;
+ }
+
(mgr, session)
} else {
// No meaningful content — treat as a fresh session
@@ -215,6 +232,16 @@ async fn run_inline_tui(
(mgr, Session::new(ctx.git_root.is_some(), None))
};
+ // Initialize the snapshot store now that we know the session ID.
+ let snapshot_dir = atuin_common::utils::data_dir()
+ .join("ai")
+ .join("snapshots")
+ .join(session_mgr.session_id());
+ match crate::snapshots::SnapshotStore::open(snapshot_dir) {
+ Ok(store) => initial_state.snapshot_store = Some(store),
+ Err(e) => tracing::warn!("failed to open snapshot store: {e}"),
+ }
+
let (tx, rx) = mpsc::channel::<AiTuiEvent>();
println!();
diff --git a/crates/atuin-ai/src/diff.rs b/crates/atuin-ai/src/diff.rs
new file mode 100644
index 00000000..663481c0
--- /dev/null
+++ b/crates/atuin-ai/src/diff.rs
@@ -0,0 +1,294 @@
+//! Structured diff computation for edit previews.
+//!
+//! Computes a line-level diff between old and new file content using
+//! imara-diff's Histogram algorithm, producing structured hunks with
+//! typed lines (Context, Added, Removed) suitable for TUI rendering.
+
+use imara_diff::{Algorithm, Diff, InternedInput};
+
+/// Number of context lines to show around each change.
+const CONTEXT_LINES: u32 = 3;
+
+/// A structured diff preview for a file edit, ready for rendering.
+#[derive(Debug, Clone)]
+pub(crate) struct EditPreview {
+ pub hunks: Vec<DiffHunk>,
+}
+
+/// A contiguous group of diff lines (context + changes).
+#[derive(Debug, Clone)]
+pub(crate) struct DiffHunk {
+ /// 1-indexed line number of the first line in this hunk (in the original file).
+ pub before_start: u32,
+ /// 1-indexed line number of the first line in this hunk (in the new file).
+ pub after_start: u32,
+ pub lines: Vec<DiffLine>,
+}
+
+/// A single line in a diff hunk.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) enum DiffLine {
+ /// Unchanged line (shown for context).
+ Context(String),
+ /// Line added in the new version.
+ Added(String),
+ /// Line removed from the old version.
+ Removed(String),
+}
+
+impl EditPreview {
+ /// Compute a structured diff between old and new file content.
+ ///
+ /// Uses the Histogram algorithm with line-level granularity and
+ /// indentation-aware postprocessing for readable output.
+ pub fn compute(old: &str, new: &str) -> Self {
+ let input = InternedInput::new(old, new);
+ let mut diff = Diff::compute(Algorithm::Histogram, &input);
+ diff.postprocess_lines(&input);
+
+ let raw_hunks: Vec<_> = diff.hunks().collect();
+ if raw_hunks.is_empty() {
+ return EditPreview { hunks: Vec::new() };
+ }
+
+ // Merge hunks that are within 2*CONTEXT_LINES of each other
+ // (same logic as unified diff format).
+ let mut merged_groups: Vec<Vec<&imara_diff::Hunk>> = Vec::new();
+ let mut current_group: Vec<&imara_diff::Hunk> = vec![&raw_hunks[0]];
+
+ for hunk in &raw_hunks[1..] {
+ let prev = current_group.last().unwrap();
+ if hunk.before.start.saturating_sub(prev.before.end) <= 2 * CONTEXT_LINES {
+ current_group.push(hunk);
+ } else {
+ merged_groups.push(current_group);
+ current_group = vec![hunk];
+ }
+ }
+ merged_groups.push(current_group);
+
+ // Build structured hunks from merged groups
+ let hunks = merged_groups
+ .into_iter()
+ .map(|group| build_hunk(&group, &input))
+ .collect();
+
+ EditPreview { hunks }
+ }
+
+ /// The highest line number (from either file) that will be displayed.
+ /// Used to calculate gutter width.
+ pub fn max_line_number(&self) -> u32 {
+ self.hunks
+ .iter()
+ .map(|h| {
+ let mut before_pos = h.before_start;
+ let mut after_pos = h.after_start;
+ for line in &h.lines {
+ match line {
+ DiffLine::Context(_) => {
+ before_pos += 1;
+ after_pos += 1;
+ }
+ DiffLine::Removed(_) => before_pos += 1,
+ DiffLine::Added(_) => after_pos += 1,
+ }
+ }
+ before_pos.max(after_pos).saturating_sub(1)
+ })
+ .max()
+ .unwrap_or(0)
+ }
+}
+
+/// Build a single DiffHunk from a group of adjacent raw hunks.
+fn build_hunk(group: &[&imara_diff::Hunk], input: &InternedInput<&str>) -> DiffHunk {
+ let first = group.first().unwrap();
+ let last = group.last().unwrap();
+
+ let context_start = first.before.start.saturating_sub(CONTEXT_LINES);
+ let context_end = (last.before.end + CONTEXT_LINES).min(input.before.len() as u32);
+
+ // The after-file position of context_start: same offset as before since
+ // context before the first change is identical in both files.
+ let after_context_start = first.after.start - (first.before.start - context_start);
+
+ let mut lines = Vec::new();
+ let mut pos = context_start;
+
+ for hunk in group {
+ // Context lines before this hunk
+ for i in pos..hunk.before.start {
+ lines.push(DiffLine::Context(token_text(input, true, i)));
+ }
+
+ // Removed lines
+ for i in hunk.before.start..hunk.before.end {
+ lines.push(DiffLine::Removed(token_text(input, true, i)));
+ }
+
+ // Added lines
+ for i in hunk.after.start..hunk.after.end {
+ lines.push(DiffLine::Added(token_text(input, false, i)));
+ }
+
+ pos = hunk.before.end;
+ }
+
+ // Trailing context
+ for i in pos..context_end {
+ lines.push(DiffLine::Context(token_text(input, true, i)));
+ }
+
+ DiffHunk {
+ before_start: context_start + 1, // 1-indexed
+ after_start: after_context_start + 1, // 1-indexed
+ lines,
+ }
+}
+
+/// Extract the text content of a token, trimming the trailing newline
+/// that imara-diff includes in line-based tokenization.
+fn token_text(input: &InternedInput<&str>, is_before: bool, idx: u32) -> String {
+ let tokens = if is_before {
+ &input.before
+ } else {
+ &input.after
+ };
+ let text = input.interner[tokens[idx as usize]];
+ text.strip_suffix('\n')
+ .unwrap_or(text)
+ .strip_suffix('\r')
+ .unwrap_or(text.strip_suffix('\n').unwrap_or(text))
+ .to_string()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn no_changes_produces_empty_preview() {
+ let preview = EditPreview::compute("hello\nworld\n", "hello\nworld\n");
+ assert!(preview.hunks.is_empty());
+ }
+
+ #[test]
+ fn single_line_replacement() {
+ let old = "line1\nline2\nline3\n";
+ let new = "line1\nchanged\nline3\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 1);
+ let hunk = &preview.hunks[0];
+
+ // Should have: context(line1), removed(line2), added(changed), context(line3)
+ assert!(hunk.lines.contains(&DiffLine::Context("line1".into())));
+ assert!(hunk.lines.contains(&DiffLine::Removed("line2".into())));
+ assert!(hunk.lines.contains(&DiffLine::Added("changed".into())));
+ assert!(hunk.lines.contains(&DiffLine::Context("line3".into())));
+ }
+
+ #[test]
+ fn addition_only() {
+ let old = "aaa\nbbb\n";
+ let new = "aaa\nnew_line\nbbb\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 1);
+ let hunk = &preview.hunks[0];
+ assert!(hunk.lines.contains(&DiffLine::Added("new_line".into())));
+ // Original lines are context
+ assert!(hunk.lines.contains(&DiffLine::Context("aaa".into())));
+ assert!(hunk.lines.contains(&DiffLine::Context("bbb".into())));
+ }
+
+ #[test]
+ fn removal_only() {
+ let old = "aaa\nremove_me\nbbb\n";
+ let new = "aaa\nbbb\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 1);
+ let hunk = &preview.hunks[0];
+ assert!(hunk.lines.contains(&DiffLine::Removed("remove_me".into())));
+ }
+
+ #[test]
+ fn distant_changes_produce_separate_hunks() {
+ // Two changes separated by more than 2*CONTEXT_LINES (6) lines
+ let old = "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n";
+ let new = "1\nX\n3\n4\n5\n6\n7\n8\n9\n10\n11\nY\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 2);
+ }
+
+ #[test]
+ fn close_changes_merge_into_one_hunk() {
+ // Two changes separated by fewer than 2*CONTEXT_LINES lines
+ let old = "1\n2\n3\n4\n5\n";
+ let new = "X\n2\n3\n4\nY\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 1);
+ }
+
+ #[test]
+ fn context_is_limited() {
+ // With CONTEXT_LINES=3, a change at line 10 shouldn't include line 1
+ let mut old_lines: Vec<&str> = (1..=20).map(|_| "unchanged").collect();
+ old_lines[9] = "target";
+ let old = old_lines.join("\n") + "\n";
+ let new = old.replace("target", "replaced");
+
+ let preview = EditPreview::compute(&old, &new);
+ assert_eq!(preview.hunks.len(), 1);
+
+ // Should have at most 3 context lines before + 3 after + 1 removed + 1 added = 8 lines
+ assert!(preview.hunks[0].lines.len() <= 8);
+ }
+
+ #[test]
+ fn max_line_number_reflects_file_position() {
+ let old = "a\nb\nc\n";
+ let new = "a\nX\nc\n";
+ let preview = EditPreview::compute(old, new);
+ // 3-line file, context + removed lines span positions 1-3
+ assert_eq!(preview.max_line_number(), 3);
+ }
+
+ #[test]
+ fn start_line_is_correct_for_later_changes() {
+ // Change at line 10 with 3 context lines → before_start = 7
+ let mut lines: Vec<String> = (1..=15).map(|i| format!("line{i}")).collect();
+ let old = lines.join("\n") + "\n";
+ lines[9] = "CHANGED".to_string();
+ let new = lines.join("\n") + "\n";
+
+ let preview = EditPreview::compute(&old, &new);
+ assert_eq!(preview.hunks.len(), 1);
+ assert_eq!(preview.hunks[0].before_start, 7); // line 10 - 3 context = line 7
+ assert_eq!(preview.hunks[0].after_start, 7); // same for a simple replacement
+ }
+
+ #[test]
+ fn multiline_replacement() {
+ let old = "[section]\nkey1 = old1\nkey2 = old2\n[other]\n";
+ let new = "[section]\nkey1 = new1\nkey2 = new2\n[other]\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 1);
+ let hunk = &preview.hunks[0];
+ assert!(
+ hunk.lines
+ .contains(&DiffLine::Removed("key1 = old1".into()))
+ );
+ assert!(
+ hunk.lines
+ .contains(&DiffLine::Removed("key2 = old2".into()))
+ );
+ assert!(hunk.lines.contains(&DiffLine::Added("key1 = new1".into())));
+ assert!(hunk.lines.contains(&DiffLine::Added("key2 = new2".into())));
+ }
+}
diff --git a/crates/atuin-ai/src/edit_permissions.rs b/crates/atuin-ai/src/edit_permissions.rs
new file mode 100644
index 00000000..5015a007
--- /dev/null
+++ b/crates/atuin-ai/src/edit_permissions.rs
@@ -0,0 +1,108 @@
+//! Session-scoped permission cache for file edits.
+//!
+//! When the user selects "Allow this file for this session", the grant is
+//! recorded here with a timestamp. Subsequent edits to the same file skip
+//! the permission prompt as long as the grant hasn't expired.
+//!
+//! Grants are time-limited (1 hour TTL) so they don't outlive the user's
+//! attention in long-running sessions. Persisted as JSON in session
+//! metadata so they survive across CLI invocations.
+
+use std::collections::HashMap;
+use std::path::{Path, PathBuf};
+use std::time::SystemTime;
+
+use eyre::Result;
+use serde::{Deserialize, Serialize};
+
+/// Session metadata key for persistence.
+pub(crate) const METADATA_KEY: &str = "edit_permissions";
+
+/// How long a session-scoped edit permission remains valid.
+const TTL_MS: i64 = 60 * 60 * 1000; // 1 hour
+
+/// Cache of per-file edit permission grants within a session.
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub(crate) struct EditPermissionCache {
+ /// Maps canonical file paths to the grant timestamp (unix millis).
+ grants: HashMap<PathBuf, i64>,
+}
+
+impl EditPermissionCache {
+ /// Record a permission grant for a file.
+ pub fn grant(&mut self, path: PathBuf) {
+ self.grants.insert(path, now_ms());
+ }
+
+ /// Check whether there's a valid (non-expired) grant for a file.
+ pub fn has_valid_grant(&self, path: &Path) -> bool {
+ if let Some(&granted_at) = self.grants.get(path) {
+ (now_ms() - granted_at) < TTL_MS
+ } else {
+ false
+ }
+ }
+
+ pub fn to_json(&self) -> Result<String> {
+ Ok(serde_json::to_string(self)?)
+ }
+
+ pub fn from_json(json: &str) -> Result<Self> {
+ Ok(serde_json::from_str(json)?)
+ }
+}
+
+fn now_ms() -> i64 {
+ SystemTime::now()
+ .duration_since(SystemTime::UNIX_EPOCH)
+ .map(|d| d.as_millis() as i64)
+ .unwrap_or(0)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn grant_and_check() {
+ let mut cache = EditPermissionCache::default();
+ let path = PathBuf::from("/Users/me/.config/foo.toml");
+
+ assert!(!cache.has_valid_grant(&path));
+ cache.grant(path.clone());
+ assert!(cache.has_valid_grant(&path));
+ }
+
+ #[test]
+ fn different_paths_are_independent() {
+ let mut cache = EditPermissionCache::default();
+ let a = PathBuf::from("/etc/hosts");
+ let b = PathBuf::from("/etc/resolv.conf");
+
+ cache.grant(a.clone());
+ assert!(cache.has_valid_grant(&a));
+ assert!(!cache.has_valid_grant(&b));
+ }
+
+ #[test]
+ fn roundtrip_json() {
+ let mut cache = EditPermissionCache::default();
+ cache.grant(PathBuf::from("/some/file.toml"));
+
+ let json = cache.to_json().unwrap();
+ let restored = EditPermissionCache::from_json(&json).unwrap();
+ assert!(restored.has_valid_grant(Path::new("/some/file.toml")));
+ }
+
+ #[test]
+ fn expired_grant_is_invalid() {
+ let mut cache = EditPermissionCache::default();
+ let path = PathBuf::from("/expired/file.toml");
+
+ // Insert a grant from 2 hours ago
+ let two_hours_ago = now_ms() - (2 * 60 * 60 * 1000);
+ cache.grants.insert(path.clone(), two_hours_ago);
+
+ assert!(!cache.has_valid_grant(&path));
+ }
+}
diff --git a/crates/atuin-ai/src/file_tracker.rs b/crates/atuin-ai/src/file_tracker.rs
new file mode 100644
index 00000000..feee1ee8
--- /dev/null
+++ b/crates/atuin-ai/src/file_tracker.rs
@@ -0,0 +1,234 @@
+//! Tracks which files have been read in the current session, for freshness
+//! checking before edits.
+//!
+//! The tracker records the content hash and mtime of each file at the time
+//! it was last read. Before an edit, the tracker verifies the file hasn't
+//! changed since the last read — catching both external modifications and
+//! concurrent tool calls.
+//!
+//! Persisted as JSON in session metadata so it survives across CLI
+//! invocations within the same logical session.
+
+use std::collections::HashMap;
+use std::path::{Path, PathBuf};
+use std::time::SystemTime;
+
+use eyre::Result;
+use serde::{Deserialize, Serialize};
+
+/// Metadata key used for session_metadata persistence.
+pub(crate) const METADATA_KEY: &str = "file_read_tracker";
+
+/// State recorded for a single file read.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct FileReadState {
+ /// Hash of the file contents at the time of the last read.
+ pub content_hash: u64,
+ /// File mtime (as milliseconds since epoch) at the time of the last read.
+ /// Millisecond precision ensures sub-second modifications are detected.
+ pub mtime_ms: i64,
+}
+
+/// Tracks file read state for freshness checking.
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub(crate) struct FileReadTracker {
+ reads: HashMap<PathBuf, FileReadState>,
+}
+
+/// Result of a freshness check.
+pub(crate) enum FreshnessCheck {
+ /// File is fresh — the content hasn't changed since the last read.
+ Fresh,
+ /// File has never been read in this session.
+ NotRead,
+ /// File has been modified since the last read.
+ Stale,
+}
+
+impl FileReadTracker {
+ /// Record that a file was read. Call this after a successful `read_file`
+ /// execution. The `path` should be canonical (absolute, tilde-expanded).
+ pub fn record_read(&mut self, path: PathBuf, content: &[u8], mtime: SystemTime) {
+ let content_hash = hash_content(content);
+ let mtime_ms = system_time_to_ms(mtime);
+
+ self.reads.insert(
+ path,
+ FileReadState {
+ content_hash,
+ mtime_ms,
+ },
+ );
+ }
+
+ /// Check whether a file is fresh (unchanged since last read).
+ ///
+ /// Uses mtime as a fast path — only re-hashes if mtime differs.
+ pub fn check_freshness(&self, path: &Path) -> Result<FreshnessCheck> {
+ let state = match self.reads.get(path) {
+ Some(s) => s,
+ None => return Ok(FreshnessCheck::NotRead),
+ };
+
+ // Stat the file
+ let metadata = match std::fs::metadata(path) {
+ Ok(m) => m,
+ Err(_) => return Ok(FreshnessCheck::Stale), // file deleted or inaccessible
+ };
+
+ let current_mtime_ms =
+ system_time_to_ms(metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH));
+
+ // Fast path: mtime unchanged → fresh
+ if current_mtime_ms == state.mtime_ms {
+ return Ok(FreshnessCheck::Fresh);
+ }
+
+ // Mtime changed — re-hash to confirm
+ let content = std::fs::read(path)?;
+ let current_hash = hash_content(&content);
+
+ if current_hash == state.content_hash {
+ Ok(FreshnessCheck::Fresh)
+ } else {
+ Ok(FreshnessCheck::Stale)
+ }
+ }
+
+ /// Update the tracker entry after a successful edit (new content written).
+ pub fn update_after_edit(&mut self, path: &Path, new_content: &[u8], new_mtime: SystemTime) {
+ let content_hash = hash_content(new_content);
+ let mtime_ms = system_time_to_ms(new_mtime);
+
+ self.reads.insert(
+ path.to_path_buf(),
+ FileReadState {
+ content_hash,
+ mtime_ms,
+ },
+ );
+ }
+
+ /// Serialize to JSON for session metadata persistence.
+ pub fn to_json(&self) -> Result<String> {
+ Ok(serde_json::to_string(self)?)
+ }
+
+ /// Deserialize from JSON session metadata.
+ pub fn from_json(json: &str) -> Result<Self> {
+ Ok(serde_json::from_str(json)?)
+ }
+}
+
+fn system_time_to_ms(t: SystemTime) -> i64 {
+ t.duration_since(SystemTime::UNIX_EPOCH)
+ .map(|d| d.as_millis() as i64)
+ .unwrap_or(0)
+}
+
+fn hash_content(content: &[u8]) -> u64 {
+ xxhash_rust::xxh3::xxh3_64(content)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::io::Write;
+ use tempfile::NamedTempFile;
+
+ #[test]
+ fn record_and_check_fresh() {
+ let mut tracker = FileReadTracker::default();
+ let mut tmp = NamedTempFile::new().unwrap();
+ write!(tmp, "hello world").unwrap();
+
+ let path = tmp.path().to_path_buf();
+ let content = std::fs::read(&path).unwrap();
+ let mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+
+ tracker.record_read(path.clone(), &content, mtime);
+
+ assert!(matches!(
+ tracker.check_freshness(&path).unwrap(),
+ FreshnessCheck::Fresh
+ ));
+ }
+
+ #[test]
+ fn check_not_read() {
+ let tracker = FileReadTracker::default();
+ let path = PathBuf::from("/nonexistent/file.txt");
+ assert!(matches!(
+ tracker.check_freshness(&path).unwrap(),
+ FreshnessCheck::NotRead
+ ));
+ }
+
+ #[test]
+ fn check_stale_after_modification() {
+ let mut tracker = FileReadTracker::default();
+ let mut tmp = NamedTempFile::new().unwrap();
+ write!(tmp, "original").unwrap();
+
+ let path = tmp.path().to_path_buf();
+ let content = std::fs::read(&path).unwrap();
+ let mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+
+ tracker.record_read(path.clone(), &content, mtime);
+
+ // Small delay to ensure the filesystem mtime advances
+ std::thread::sleep(std::time::Duration::from_millis(10));
+
+ // Modify the file
+ std::fs::write(&path, "modified").unwrap();
+
+ assert!(matches!(
+ tracker.check_freshness(&path).unwrap(),
+ FreshnessCheck::Stale
+ ));
+ }
+
+ #[test]
+ fn update_after_edit_makes_fresh() {
+ let mut tracker = FileReadTracker::default();
+ let mut tmp = NamedTempFile::new().unwrap();
+ write!(tmp, "original").unwrap();
+
+ let path = tmp.path().to_path_buf();
+ let content = std::fs::read(&path).unwrap();
+ let mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+
+ tracker.record_read(path.clone(), &content, mtime);
+
+ // Simulate an edit
+ let new_content = b"edited content";
+ std::fs::write(&path, new_content).unwrap();
+ let new_mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+ tracker.update_after_edit(&path, new_content, new_mtime);
+
+ assert!(matches!(
+ tracker.check_freshness(&path).unwrap(),
+ FreshnessCheck::Fresh
+ ));
+ }
+
+ #[test]
+ fn roundtrip_json() {
+ let mut tracker = FileReadTracker::default();
+ tracker.reads.insert(
+ PathBuf::from("/some/file.toml"),
+ FileReadState {
+ content_hash: 12345,
+ mtime_ms: 1700000000000,
+ },
+ );
+
+ let json = tracker.to_json().unwrap();
+ let restored = FileReadTracker::from_json(&json).unwrap();
+ assert_eq!(restored.reads.len(), 1);
+ assert_eq!(
+ restored.reads[&PathBuf::from("/some/file.toml")].content_hash,
+ 12345
+ );
+ }
+}
diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs
index febb488e..afe9c1e4 100644
--- a/crates/atuin-ai/src/lib.rs
+++ b/crates/atuin-ai/src/lib.rs
@@ -1,9 +1,13 @@
pub mod commands;
pub(crate) mod context;
pub(crate) mod context_window;
+pub(crate) mod diff;
+pub(crate) mod edit_permissions;
pub(crate) mod event_serde;
+pub(crate) mod file_tracker;
pub(crate) mod permissions;
pub(crate) mod session;
+pub(crate) mod snapshots;
pub(crate) mod store;
pub(crate) mod stream;
pub(crate) mod tools;
diff --git a/crates/atuin-ai/src/session.rs b/crates/atuin-ai/src/session.rs
index d8314343..848330fc 100644
--- a/crates/atuin-ai/src/session.rs
+++ b/crates/atuin-ai/src/session.rs
@@ -51,6 +51,9 @@ pub(crate) trait SessionService: Send + Sync {
) -> Result<()>;
async fn archive(&self, session_id: &str) -> Result<()>;
+
+ async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>>;
+ async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()>;
}
// ---------------------------------------------------------------------------
@@ -128,6 +131,14 @@ impl SessionService for LocalSessionService {
async fn archive(&self, session_id: &str) -> Result<()> {
self.store.archive_session(session_id).await
}
+
+ async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> {
+ self.store.get_metadata(session_id, key).await
+ }
+
+ async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> {
+ self.store.set_metadata(session_id, key, value).await
+ }
}
// ---------------------------------------------------------------------------
@@ -310,6 +321,22 @@ impl SessionManager {
pub fn invocation_id(&self) -> &str {
&self.invocation_id
}
+
+ /// Read a metadata value for the current session.
+ pub async fn get_metadata(&self, key: &str) -> Result<Option<String>> {
+ if !self.persisted_to_db {
+ return Ok(None);
+ }
+ self.service.get_metadata(&self.session_id, key).await
+ }
+
+ /// Write a metadata value for the current session.
+ pub async fn set_metadata(&mut self, key: &str, value: &str) -> Result<()> {
+ self.ensure_persisted().await?;
+ self.service
+ .set_metadata(&self.session_id, key, value)
+ .await
+ }
}
#[cfg(test)]
diff --git a/crates/atuin-ai/src/snapshots.rs b/crates/atuin-ai/src/snapshots.rs
new file mode 100644
index 00000000..6c7b0c9c
--- /dev/null
+++ b/crates/atuin-ai/src/snapshots.rs
@@ -0,0 +1,414 @@
+//! Backup snapshots for files before AI edits.
+//!
+//! Before the first edit to a file within a session, a snapshot of the
+//! original content is saved so the user can recover if needed. Snapshots
+//! are stored alongside a manifest that maps sanitized filenames back to
+//! their original paths.
+//!
+//! Filenames use percent-encoding (`/` → `%2F`) so the snapshot directory
+//! is human-readable via `ls`.
+
+use std::collections::HashMap;
+use std::io::Write;
+use std::path::{Path, PathBuf};
+
+use eyre::{Result, eyre};
+use serde::{Deserialize, Serialize};
+use time::OffsetDateTime;
+
+/// Snapshot store for a single session.
+///
+/// Each session gets its own directory under the snapshots root:
+/// `<data_dir>/ai/snapshots/<session_id>/`
+///
+/// Files are stored with percent-encoded filenames derived from their
+/// canonical paths, alongside a `manifest.json` that maps filenames
+/// back to original paths with timestamps.
+#[derive(Debug)]
+pub(crate) struct SnapshotStore {
+ session_dir: PathBuf,
+ manifest: SnapshotManifest,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+struct SnapshotManifest {
+ files: HashMap<String, SnapshotEntry>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct SnapshotEntry {
+ original_path: String,
+ snapshot_at: String,
+ size_bytes: u64,
+}
+
+impl SnapshotStore {
+ /// Open or create a snapshot store for the given session directory.
+ ///
+ /// If a manifest already exists (from a prior CLI invocation in the same
+ /// session), it's loaded so we don't re-snapshot files that were already
+ /// backed up.
+ pub fn open(session_dir: PathBuf) -> Result<Self> {
+ let manifest_path = session_dir.join("manifest.json");
+ let manifest = if manifest_path.exists() {
+ let data = fs_err::read_to_string(&manifest_path)?;
+ serde_json::from_str(&data)?
+ } else {
+ SnapshotManifest::default()
+ };
+
+ Ok(Self {
+ session_dir,
+ manifest,
+ })
+ }
+
+ /// Snapshot a file's contents if it hasn't been snapshotted yet this session.
+ ///
+ /// Returns `true` if a new snapshot was created, `false` if one already
+ /// existed. The `canonical_path` should be absolute (already tilde-expanded
+ /// and resolved).
+ pub fn ensure_snapshot(&mut self, canonical_path: &Path, content: &[u8]) -> Result<bool> {
+ let filename = sanitize_path(canonical_path);
+
+ if self.manifest.files.contains_key(&filename) {
+ return Ok(false);
+ }
+
+ fs_err::create_dir_all(&self.session_dir)?;
+
+ let snapshot_path = self.session_dir.join(&filename);
+ atomic_write_file(&snapshot_path, content)?;
+
+ let now = OffsetDateTime::now_utc();
+ let entry = SnapshotEntry {
+ original_path: canonical_path.to_string_lossy().into_owned(),
+ snapshot_at: format_iso8601(now),
+ size_bytes: content.len() as u64,
+ };
+
+ self.manifest.files.insert(filename, entry);
+ self.save_manifest()?;
+
+ Ok(true)
+ }
+
+ /// Whether a file has already been snapshotted in this session.
+ #[expect(dead_code)]
+ pub fn has_snapshot(&self, canonical_path: &Path) -> bool {
+ let filename = sanitize_path(canonical_path);
+ self.manifest.files.contains_key(&filename)
+ }
+
+ fn save_manifest(&self) -> Result<()> {
+ let json = serde_json::to_string_pretty(&self.manifest)?;
+ atomic_write_file(&self.session_dir.join("manifest.json"), json.as_bytes())
+ }
+}
+
+/// Percent-encode a path for use as a filename.
+///
+/// Encodes `%` as `%25`, `/` as `%2F`, and `\` as `%5C`, then strips
+/// leading separators and drive prefixes (e.g. `C:\`). The result is
+/// always a flat filename safe for use with `Path::join` on any platform.
+///
+/// Example (Unix): `/Users/me/.config/foo.toml` → `Users%2Fme%2F.config%2Ffoo.toml`
+/// Example (Windows): `C:\Users\me\config.toml` → `Users%5Cme%5Cconfig.toml`
+pub(crate) fn sanitize_path(path: &Path) -> String {
+ let s = path.to_string_lossy();
+ // Strip drive letter prefix on Windows (e.g. "C:\")
+ let s = s.strip_prefix('/').unwrap_or_else(|| {
+ // Handle Windows drive prefix like "C:\" or "C:/"
+ if s.len() >= 3
+ && s.as_bytes()[0].is_ascii_alphabetic()
+ && s.as_bytes()[1] == b':'
+ && (s.as_bytes()[2] == b'\\' || s.as_bytes()[2] == b'/')
+ {
+ &s[3..]
+ } else {
+ &s
+ }
+ });
+ s.replace('%', "%25")
+ .replace('/', "%2F")
+ .replace('\\', "%5C")
+}
+
+/// Write a file atomically using temp-file-then-rename.
+///
+/// Creates a temporary file in the same directory as `target`, writes
+/// content, fsyncs, then renames into place. Preserves permissions from
+/// the original file if it exists.
+pub(crate) fn atomic_write_file(target: &Path, content: &[u8]) -> Result<()> {
+ let dir = target
+ .parent()
+ .ok_or_else(|| eyre!("target path has no parent directory"))?;
+ fs_err::create_dir_all(dir)?;
+
+ let mut tmp = tempfile::NamedTempFile::new_in(dir)?;
+ tmp.write_all(content)?;
+ tmp.as_file().sync_all()?;
+
+ // Preserve permissions from original if it exists
+ if let Ok(meta) = std::fs::metadata(target) {
+ std::fs::set_permissions(tmp.path(), meta.permissions())?;
+ }
+
+ tmp.persist(target).map_err(|e| {
+ eyre!(
+ "failed to persist atomic write to {}: {}",
+ target.display(),
+ e
+ )
+ })?;
+ Ok(())
+}
+
+fn format_iso8601(dt: OffsetDateTime) -> String {
+ format!(
+ "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
+ dt.year(),
+ dt.month() as u8,
+ dt.day(),
+ dt.hour(),
+ dt.minute(),
+ dt.second(),
+ )
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ // ── sanitize_path ──────────────────────────────────────────
+
+ #[test]
+ fn sanitize_absolute_path() {
+ let path = Path::new("/Users/me/.config/atuin/config.toml");
+ assert_eq!(
+ sanitize_path(path),
+ "Users%2Fme%2F.config%2Fatuin%2Fconfig.toml"
+ );
+ }
+
+ #[test]
+ fn sanitize_preserves_existing_percent() {
+ let path = Path::new("/data/100%done/file.txt");
+ assert_eq!(sanitize_path(path), "data%2F100%25done%2Ffile.txt");
+ }
+
+ #[test]
+ fn sanitize_relative_path() {
+ let path = Path::new("relative/path.txt");
+ assert_eq!(sanitize_path(path), "relative%2Fpath.txt");
+ }
+
+ #[test]
+ fn sanitize_no_collision_between_similar_paths() {
+ let a = sanitize_path(Path::new("/foo/bar-baz"));
+ let b = sanitize_path(Path::new("/foo/bar/baz"));
+ assert_ne!(a, b);
+ }
+
+ #[test]
+ fn sanitize_backslash_encoded() {
+ // Windows-style path: backslashes become %5C, drive prefix stripped
+ let s = sanitize_path(Path::new("C:\\Users\\me\\config.toml"));
+ assert!(!s.contains('\\'), "backslashes must be encoded: {s}");
+ assert!(!s.starts_with("C:"), "drive prefix must be stripped: {s}");
+ assert!(s.contains("Users"));
+ assert!(s.contains("config.toml"));
+ }
+
+ #[test]
+ fn sanitize_result_is_flat_filename() {
+ // The result must not be interpreted as a path with separators
+ // when passed to Path::join — no raw / or \ allowed.
+ let unix = sanitize_path(Path::new("/home/user/file.txt"));
+ assert!(!unix.contains('/'));
+ // Construct as if on Windows
+ let win = "C:\\Users\\me\\file.txt";
+ let encoded = win
+ .strip_prefix("C:\\")
+ .unwrap()
+ .replace('%', "%25")
+ .replace('/', "%2F")
+ .replace('\\', "%5C");
+ assert!(!encoded.contains('\\'));
+ assert!(!encoded.contains('/'));
+ }
+
+ // ── atomic_write_file ──────────────────────────────────────
+
+ #[test]
+ fn atomic_write_creates_file() {
+ let dir = tempfile::tempdir().unwrap();
+ let target = dir.path().join("test.txt");
+
+ atomic_write_file(&target, b"hello world").unwrap();
+
+ assert_eq!(std::fs::read_to_string(&target).unwrap(), "hello world");
+ }
+
+ #[test]
+ fn atomic_write_overwrites_existing() {
+ let dir = tempfile::tempdir().unwrap();
+ let target = dir.path().join("test.txt");
+
+ std::fs::write(&target, "old content").unwrap();
+ atomic_write_file(&target, b"new content").unwrap();
+
+ assert_eq!(std::fs::read_to_string(&target).unwrap(), "new content");
+ }
+
+ #[test]
+ fn atomic_write_creates_parent_dirs() {
+ let dir = tempfile::tempdir().unwrap();
+ let target = dir.path().join("sub").join("dir").join("test.txt");
+
+ atomic_write_file(&target, b"nested").unwrap();
+
+ assert_eq!(std::fs::read_to_string(&target).unwrap(), "nested");
+ }
+
+ #[cfg(unix)]
+ #[test]
+ fn atomic_write_preserves_permissions() {
+ use std::os::unix::fs::PermissionsExt;
+
+ let dir = tempfile::tempdir().unwrap();
+ let target = dir.path().join("test.txt");
+
+ std::fs::write(&target, "original").unwrap();
+ std::fs::set_permissions(&target, std::fs::Permissions::from_mode(0o600)).unwrap();
+
+ atomic_write_file(&target, b"updated").unwrap();
+
+ let mode = std::fs::metadata(&target).unwrap().permissions().mode() & 0o777;
+ assert_eq!(mode, 0o600);
+ }
+
+ // ── SnapshotStore ──────────────────────────────────────────
+
+ #[test]
+ fn snapshot_creates_file_and_manifest() {
+ let dir = tempfile::tempdir().unwrap();
+ let session_dir = dir.path().join("session-abc");
+ let mut store = SnapshotStore::open(session_dir.clone()).unwrap();
+
+ let file_path = Path::new("/Users/me/.config/foo.toml");
+ let created = store
+ .ensure_snapshot(file_path, b"[key]\nval = 1\n")
+ .unwrap();
+
+ assert!(created);
+ assert!(store.has_snapshot(file_path));
+
+ // Snapshot file on disk
+ let expected_file = session_dir.join("Users%2Fme%2F.config%2Ffoo.toml");
+ assert!(expected_file.exists());
+ assert_eq!(
+ std::fs::read_to_string(&expected_file).unwrap(),
+ "[key]\nval = 1\n"
+ );
+
+ // Manifest on disk
+ let manifest_path = session_dir.join("manifest.json");
+ assert!(manifest_path.exists());
+ let manifest: serde_json::Value =
+ serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap();
+ let files = manifest["files"].as_object().unwrap();
+ assert_eq!(files.len(), 1);
+ let entry = &files["Users%2Fme%2F.config%2Ffoo.toml"];
+ assert_eq!(
+ entry["original_path"].as_str().unwrap(),
+ "/Users/me/.config/foo.toml"
+ );
+ assert_eq!(entry["size_bytes"].as_u64().unwrap(), 14);
+ }
+
+ #[test]
+ fn snapshot_is_idempotent() {
+ let dir = tempfile::tempdir().unwrap();
+ let session_dir = dir.path().join("session-abc");
+ let mut store = SnapshotStore::open(session_dir.clone()).unwrap();
+
+ let path = Path::new("/etc/hosts");
+ let first = store.ensure_snapshot(path, b"first content").unwrap();
+ let second = store.ensure_snapshot(path, b"different content").unwrap();
+
+ assert!(first);
+ assert!(!second);
+
+ // Original content preserved, not overwritten
+ let snapshot_file = session_dir.join("etc%2Fhosts");
+ assert_eq!(
+ std::fs::read_to_string(snapshot_file).unwrap(),
+ "first content"
+ );
+ }
+
+ #[test]
+ fn snapshot_store_loads_existing_manifest() {
+ let dir = tempfile::tempdir().unwrap();
+ let session_dir = dir.path().join("session-abc");
+
+ // First store: create a snapshot
+ {
+ let mut store = SnapshotStore::open(session_dir.clone()).unwrap();
+ store
+ .ensure_snapshot(Path::new("/etc/hosts"), b"127.0.0.1")
+ .unwrap();
+ }
+
+ // Second store (simulates new CLI invocation): should see existing snapshot
+ {
+ let mut store = SnapshotStore::open(session_dir).unwrap();
+ assert!(store.has_snapshot(Path::new("/etc/hosts")));
+
+ let created = store
+ .ensure_snapshot(Path::new("/etc/hosts"), b"new content")
+ .unwrap();
+ assert!(!created);
+ }
+ }
+
+ #[test]
+ fn snapshot_multiple_files() {
+ let dir = tempfile::tempdir().unwrap();
+ let session_dir = dir.path().join("session-abc");
+ let mut store = SnapshotStore::open(session_dir.clone()).unwrap();
+
+ store
+ .ensure_snapshot(Path::new("/etc/hosts"), b"hosts content")
+ .unwrap();
+ store
+ .ensure_snapshot(Path::new("/Users/me/.bashrc"), b"bashrc content")
+ .unwrap();
+
+ assert!(store.has_snapshot(Path::new("/etc/hosts")));
+ assert!(store.has_snapshot(Path::new("/Users/me/.bashrc")));
+ assert!(!store.has_snapshot(Path::new("/nonexistent")));
+
+ // Both snapshot files exist
+ assert!(session_dir.join("etc%2Fhosts").exists());
+ assert!(session_dir.join("Users%2Fme%2F.bashrc").exists());
+
+ // Manifest has both entries
+ let manifest: serde_json::Value = serde_json::from_str(
+ &std::fs::read_to_string(session_dir.join("manifest.json")).unwrap(),
+ )
+ .unwrap();
+ assert_eq!(manifest["files"].as_object().unwrap().len(), 2);
+ }
+
+ #[test]
+ fn format_iso8601_produces_valid_format() {
+ let dt = OffsetDateTime::from_unix_timestamp(1700000000).unwrap();
+ let formatted = format_iso8601(dt);
+ assert_eq!(formatted.len(), 20);
+ assert!(formatted.starts_with("2023-"));
+ assert!(formatted.contains('T'));
+ assert!(formatted.ends_with('Z'));
+ }
+}
diff --git a/crates/atuin-ai/src/store.rs b/crates/atuin-ai/src/store.rs
index 2a75d8f4..20b9e881 100644
--- a/crates/atuin-ai/src/store.rs
+++ b/crates/atuin-ai/src/store.rs
@@ -299,6 +299,38 @@ impl AiSessionStore {
.await?;
Ok(())
}
+
+ // ── Session metadata (key-value per session) ──
+
+ /// Read a metadata value for a session. Returns `None` if the key doesn't
+ /// exist or the session hasn't been persisted yet.
+ pub async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> {
+ let row: Option<(String,)> =
+ sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ?1 AND key = ?2")
+ .bind(session_id)
+ .bind(key)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ Ok(row.map(|(v,)| v))
+ }
+
+ /// Write a metadata value for a session (upsert).
+ pub async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> {
+ let now = OffsetDateTime::now_utc().unix_timestamp();
+ sqlx::query(
+ "INSERT INTO session_metadata (session_id, key, value, updated_at)
+ VALUES (?1, ?2, ?3, ?4)
+ ON CONFLICT (session_id, key) DO UPDATE SET value = ?3, updated_at = ?4",
+ )
+ .bind(session_id)
+ .bind(key)
+ .bind(value)
+ .bind(now)
+ .execute(&self.pool)
+ .await?;
+ Ok(())
+ }
}
#[cfg(test)]
diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs
index f4f4d704..24770abe 100644
--- a/crates/atuin-ai/src/stream.rs
+++ b/crates/atuin-ai/src/stream.rs
@@ -74,6 +74,14 @@ impl ChatRequest {
if capabilities.enable_history_search.unwrap_or(true) {
caps.push("client_v1_atuin_history".to_string());
}
+ if capabilities.enable_file_tools.unwrap_or(true) {
+ caps.push("client_v1_read_file".to_string());
+ caps.push("client_v1_edit_file".to_string());
+ caps.push("client_v1_write_file".to_string());
+ }
+ if capabilities.enable_command_execution.unwrap_or(true) {
+ caps.push("client_v1_execute_shell_command".to_string());
+ }
if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") {
caps.extend(
extra
diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs
index fc44ec10..6ccb595f 100644
--- a/crates/atuin-ai/src/tools/descriptor.rs
+++ b/crates/atuin-ai/src/tools/descriptor.rs
@@ -16,6 +16,7 @@ pub(crate) struct ToolDescriptor {
/// Past-tense verb for summaries (e.g. "Read file").
pub past_verb: &'static str,
/// Whether this tool is executed client-side (by the CLI).
+ #[expect(dead_code)]
pub is_client: bool,
}
@@ -30,9 +31,18 @@ pub(crate) const READ: &ToolDescriptor = &ToolDescriptor {
is_client: true,
};
+pub(crate) const EDIT: &ToolDescriptor = &ToolDescriptor {
+ canonical_names: &["edit_file"],
+ capability: Some("client_v1_edit_file"),
+ display_verb: "edit",
+ progressive_verb: "Editing file...",
+ past_verb: "Edited file",
+ is_client: true,
+};
+
pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor {
- canonical_names: &["str_replace", "file_create", "file_insert"],
- capability: Some("client_v1_write"),
+ canonical_names: &["write_file"],
+ capability: Some("client_v1_write_file"),
display_verb: "write to",
progressive_verb: "Writing file...",
past_verb: "Wrote file",
@@ -41,7 +51,7 @@ pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor {
pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor {
canonical_names: &["execute_shell_command"],
- capability: Some("client_v1_shell"),
+ capability: Some("client_v1_execute_shell_command"),
display_verb: "run",
progressive_verb: "Running command...",
past_verb: "Ran command",
@@ -81,6 +91,7 @@ pub(crate) const SERVER_SCRAPE: &ToolDescriptor = &ToolDescriptor {
/// All known tool descriptors, for lookup by name.
const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[
READ,
+ EDIT,
WRITE,
SHELL,
ATUIN_HISTORY,
diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs
index 8f2183b7..8fe1ad73 100644
--- a/crates/atuin-ai/src/tools/mod.rs
+++ b/crates/atuin-ai/src/tools/mod.rs
@@ -169,6 +169,8 @@ pub(crate) struct TrackedTool {
pub phase: ToolPhase,
/// Sender to interrupt a running shell command (only set during ExecutingWithPreview).
pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>,
+ /// Diff preview for completed edit tool calls.
+ pub edit_preview: Option<crate::diff::EditPreview>,
}
impl TrackedTool {
@@ -234,6 +236,7 @@ impl ToolTracker {
tool,
phase: ToolPhase::CheckingPermissions,
abort_tx: None,
+ edit_preview: None,
});
}
@@ -294,11 +297,6 @@ impl ToolTracker {
.find(|t| t.phase == ToolPhase::AskingForPermission)
}
- /// Get the preview for a tool by ID (live or cached).
- pub fn preview_for(&self, id: &str) -> Option<ToolPreview> {
- self.get(id)?.preview()
- }
-
/// Iterate mutably over all tracked tools.
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut TrackedTool> {
self.tools.iter_mut()
@@ -309,6 +307,7 @@ impl ToolTracker {
#[derive(Debug, Clone)]
pub(crate) enum ClientToolCall {
Read(ReadToolCall),
+ Edit(EditToolCall),
Write(WriteToolCall),
Shell(ShellToolCall),
AtuinHistory(AtuinHistoryToolCall),
@@ -320,9 +319,8 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall {
fn try_from((name, input): (&str, &serde_json::Value)) -> Result<Self, Self::Error> {
match name {
"read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)),
- "create_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)),
- // "append_to_file" => Ok(ClientToolCall::Append(AppendToolCall::try_from(input)?)),
- // "str_replace" => Ok(ClientToolCall::StrReplace(StrReplaceToolCall::try_from(input)?)),
+ "edit_file" => Ok(ClientToolCall::Edit(EditToolCall::try_from(input)?)),
+ "write_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)),
"execute_shell_command" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)),
"atuin_history" => Ok(ClientToolCall::AtuinHistory(
AtuinHistoryToolCall::try_from(input)?,
@@ -336,17 +334,22 @@ impl ClientToolCall {
pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor {
match self {
ClientToolCall::Read(_) => descriptor::READ,
+ ClientToolCall::Edit(_) => descriptor::EDIT,
ClientToolCall::Write(_) => descriptor::WRITE,
ClientToolCall::Shell(_) => descriptor::SHELL,
ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY,
}
}
- /// The permission rule name for this tool category (e.g. "Write" covers
- /// str_replace, file_create, file_insert).
+ /// The permission rule name for this tool category.
+ ///
+ /// Edit and Write share the `"Write"` rule name — a Write permission
+ /// covers both str_replace edits and full file creates. Write also
+ /// implies Read (checked in `ReadToolCall::matches_rule`).
pub(crate) fn rule_name(&self) -> &'static str {
match self {
ClientToolCall::Read(_) => "Read",
+ ClientToolCall::Edit(_) => "Write",
ClientToolCall::Write(_) => "Write",
ClientToolCall::Shell(_) => "Shell",
ClientToolCall::AtuinHistory(_) => "AtuinHistory",
@@ -356,6 +359,7 @@ impl ClientToolCall {
pub(crate) fn matches_rule(&self, rule: &Rule) -> bool {
match self {
ClientToolCall::Read(tool) => tool.matches_rule(rule),
+ ClientToolCall::Edit(tool) => tool.matches_rule(rule),
ClientToolCall::Write(tool) => tool.matches_rule(rule),
ClientToolCall::Shell(tool) => tool.matches_rule(rule),
ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule),
@@ -365,6 +369,7 @@ impl ClientToolCall {
pub(crate) fn target_dir(&self) -> Option<&Path> {
match self {
ClientToolCall::Read(tool) => tool.target_dir(),
+ ClientToolCall::Edit(tool) => tool.target_dir(),
ClientToolCall::Write(tool) => tool.target_dir(),
ClientToolCall::Shell(tool) => tool.target_dir(),
ClientToolCall::AtuinHistory(tool) => tool.target_dir(),
@@ -401,6 +406,14 @@ impl PermissableToolCall for ClientToolCall {
}
}
+/// Expand shell constructs (`~`, `$HOME`, etc.) in a path string.
+///
+/// Tool call paths arrive as raw strings from the API without shell
+/// expansion. Uses `shellexpand` (same as `atuin-client`).
+fn expand_path(path: &str) -> PathBuf {
+ PathBuf::from(shellexpand::tilde(path).into_owned())
+}
+
#[derive(Debug, Clone)]
pub(crate) struct ReadToolCall {
pub path: PathBuf,
@@ -425,7 +438,7 @@ impl TryFrom<&serde_json::Value> for ReadToolCall {
.min(MAX_FILE_READ_LINES);
Ok(ReadToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
offset,
limit,
})
@@ -499,7 +512,207 @@ impl PermissableToolCall for ReadToolCall {
}
fn matches_rule(&self, rule: &Rule) -> bool {
- if rule.tool != "Read" {
+ // Write implies Read — a Write permission on a path also permits reading it.
+ if rule.tool != "Read" && rule.tool != "Write" {
+ return false;
+ }
+
+ match rule.scope.as_deref() {
+ None | Some("*") => true,
+ Some(scope) => path_matches_scope(&self.path, scope),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct EditToolCall {
+ pub path: PathBuf,
+ pub old_string: String,
+ pub new_string: String,
+ pub replace_all: bool,
+}
+
+impl TryFrom<&serde_json::Value> for EditToolCall {
+ type Error = eyre::Error;
+
+ fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> {
+ let path = value
+ .get("file_path")
+ .and_then(|v| v.as_str())
+ .ok_or(eyre::eyre!("Missing file_path"))?;
+
+ let old_string = value
+ .get("old_string")
+ .and_then(|v| v.as_str())
+ .ok_or(eyre::eyre!("Missing old_string"))?;
+
+ let new_string = value
+ .get("new_string")
+ .and_then(|v| v.as_str())
+ .ok_or(eyre::eyre!("Missing new_string"))?;
+
+ let replace_all = value
+ .get("replace_all")
+ .and_then(|v| v.as_bool())
+ .unwrap_or(false);
+
+ Ok(EditToolCall {
+ path: expand_path(path),
+ old_string: old_string.to_string(),
+ new_string: new_string.to_string(),
+ replace_all,
+ })
+ }
+}
+
+impl EditToolCall {
+ /// Resolve the edit path to an absolute path.
+ pub fn resolved_path(&self) -> PathBuf {
+ if self.path.is_relative() {
+ std::env::current_dir()
+ .map(|cwd| cwd.join(&self.path))
+ .unwrap_or_else(|_| self.path.clone())
+ } else {
+ self.path.clone()
+ }
+ }
+
+ /// Execute the edit against the filesystem.
+ ///
+ /// Checks freshness via the provided tracker, validates matches, applies
+ /// the replacement, and writes atomically. Returns the outcome and (on
+ /// success) the new file content bytes for tracker updates.
+ ///
+ /// Callers should snapshot the file before calling this method and
+ /// update the file tracker after a successful return.
+ pub fn execute(
+ &self,
+ resolved_path: &Path,
+ file_tracker: &crate::file_tracker::FileReadTracker,
+ ) -> (ToolOutcome, Option<Vec<u8>>) {
+ use crate::file_tracker::FreshnessCheck;
+
+ // 1. Basic validation
+ if !resolved_path.exists() {
+ return (
+ ToolOutcome::Error(format!(
+ "Error: file does not exist: {}",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+ if resolved_path.is_dir() {
+ return (
+ ToolOutcome::Error(format!(
+ "Error: path is a directory, not a file: {}",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+ if self.old_string.is_empty() {
+ return (
+ ToolOutcome::Error(
+ "old_string must not be empty. To create a new file, use write_file instead."
+ .to_string(),
+ ),
+ None,
+ );
+ }
+
+ // 2. Freshness check
+ match file_tracker.check_freshness(resolved_path) {
+ Ok(FreshnessCheck::NotRead) => {
+ return (
+ ToolOutcome::Error(
+ "File has not been read yet. Read it first before editing.".to_string(),
+ ),
+ None,
+ );
+ }
+ Ok(FreshnessCheck::Stale) => {
+ return (
+ ToolOutcome::Error(
+ "File has been modified since read, either by the user or by a linter. Read it again before attempting to edit it.".to_string(),
+ ),
+ None,
+ );
+ }
+ Err(e) => {
+ return (
+ ToolOutcome::Error(format!("Error checking file state: {e}")),
+ None,
+ );
+ }
+ Ok(FreshnessCheck::Fresh) => {}
+ }
+
+ // 3. Read current contents
+ let content = match std::fs::read_to_string(resolved_path) {
+ Ok(c) => c,
+ Err(e) => return (ToolOutcome::Error(format!("Error reading file: {e}")), None),
+ };
+
+ // 4. Find and validate matches
+ let match_count = content.matches(&self.old_string).count();
+
+ if match_count == 0 {
+ return (
+ ToolOutcome::Error(format!(
+ "old_string not found in {}. Make sure it matches exactly, including whitespace and indentation.",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+
+ if match_count > 1 && !self.replace_all {
+ return (
+ ToolOutcome::Error(format!(
+ "Found {match_count} matches of old_string in {}, but replace_all is false. Either provide more context to make the match unique, or set replace_all to true.",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+
+ // 5. Apply replacement
+ let new_content = if self.replace_all {
+ content.replace(&self.old_string, &self.new_string)
+ } else {
+ content.replacen(&self.old_string, &self.new_string, 1)
+ };
+
+ // 6. Write atomically
+ let new_bytes = new_content.into_bytes();
+ if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &new_bytes) {
+ return (ToolOutcome::Error(format!("Error writing file: {e}")), None);
+ }
+
+ // 7. Success
+ let verb = if match_count == 1 {
+ "occurrence"
+ } else {
+ "occurrences"
+ };
+ (
+ ToolOutcome::Success(format!(
+ "Edited {}: replaced {match_count} {verb} of old_string with new_string.",
+ resolved_path.display()
+ )),
+ Some(new_bytes),
+ )
+ }
+}
+
+impl PermissableToolCall for EditToolCall {
+ fn target_dir(&self) -> Option<&Path> {
+ Some(&self.path)
+ }
+
+ fn matches_rule(&self, rule: &Rule) -> bool {
+ if rule.tool != "Write" {
return false;
}
@@ -532,7 +745,7 @@ impl TryFrom<&serde_json::Value> for WriteToolCall {
.ok_or(eyre::eyre!("Missing content"))?;
Ok(WriteToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
content: content.to_string(),
})
}
@@ -560,6 +773,9 @@ pub(crate) struct ShellToolCall {
pub dir: Option<PathBuf>,
pub command: String,
pub shell: String,
+ // allow dead code here; this will be tied into o11y and user-facing descriptions
+ #[expect(dead_code)]
+ pub description: Option<String>,
}
impl TryFrom<&serde_json::Value> for ShellToolCall {
@@ -579,10 +795,16 @@ impl TryFrom<&serde_json::Value> for ShellToolCall {
.unwrap_or("bash")
.to_string();
+ let description = value
+ .get("description")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string());
+
Ok(ShellToolCall {
- dir: dir.map(PathBuf::from),
+ dir: dir.map(expand_path),
command: command.to_string(),
shell,
+ description,
})
}
}
@@ -614,7 +836,34 @@ const PREVIEW_HEIGHT: u16 = 10;
/// Default terminal width for VT100 emulation.
const PREVIEW_WIDTH: u16 = 120;
+/// Normalize newlines for VT100 processing.
+///
+/// When subprocess output is captured via pipes (no PTY), bare `\n` (LF) bytes
+/// are not translated to `\r\n` (CR+LF) the way a kernel terminal driver would
+/// with the `ONLCR` flag. In VT100, LF only moves the cursor down without
+/// returning to column 0. This causes lines to start at progressively higher
+/// column offsets and eventually wrap, producing garbled output.
+///
+/// This function inserts `\r` before any `\n` that isn't already preceded by
+/// `\r`, mimicking the terminal driver's ONLCR behavior.
+fn normalize_newlines_for_vt100(data: &[u8]) -> Vec<u8> {
+ let mut out = Vec::with_capacity(data.len() + data.len() / 8);
+ for (i, &b) in data.iter().enumerate() {
+ if b == b'\n' && (i == 0 || data[i - 1] != b'\r') {
+ out.push(b'\r');
+ }
+ out.push(b);
+ }
+ out
+}
+
/// Extract plain text lines from a VT100 screen buffer.
+///
+/// Strips trailing blank lines so the result only contains rows with actual
+/// content. Without this, the fixed-size VT100 screen (PREVIEW_HEIGHT rows)
+/// would always return that many lines, and downstream components that use
+/// tail-mode display (like the Viewport) would show the blank padding rows
+/// instead of the real output.
fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> {
let (rows, cols) = screen.size();
let mut lines = Vec::with_capacity(rows as usize);
@@ -625,9 +874,11 @@ fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> {
line.push_str(cell.contents());
}
}
- // Trim trailing whitespace for cleaner display
lines.push(line.trim_end().to_string());
}
+ while lines.last().is_some_and(|l| l.is_empty()) {
+ lines.pop();
+ }
lines
}
@@ -640,12 +891,17 @@ fn strip_ansi_via_vt100(raw: &[u8]) -> String {
if raw.is_empty() {
return String::new();
}
- // Use the contents_formatted → screen approach: feed bytes into a parser
- // with enough rows to hold everything, then read back the plain text.
- // Estimate rows: one row per ~PREVIEW_WIDTH bytes, plus generous padding.
- let estimated_rows = (raw.len() / PREVIEW_WIDTH as usize + 1).min(10_000) as u16;
+ // Normalize bare LF to CR+LF so lines start at column 0 in the VT100 screen.
+ let normalized = normalize_newlines_for_vt100(raw);
+ // Feed bytes into a VT100 parser large enough to hold all output, then
+ // read back the plain text. We estimate rows from the number of newlines
+ // (not total byte length) because real output typically has short lines
+ // that would be severely under-counted by a bytes÷width estimate.
+ let newline_count = normalized.iter().filter(|&&b| b == b'\n').count();
+ let wrap_estimate = normalized.len() / PREVIEW_WIDTH as usize;
+ let estimated_rows = (newline_count + wrap_estimate + 1).min(10_000) as u16;
let mut parser = vt100::Parser::new(estimated_rows, PREVIEW_WIDTH, 0);
- parser.process(raw);
+ parser.process(&normalized);
let screen = parser.screen();
// screen.contents() returns the full plain-text content with trailing
// whitespace trimmed per line and trailing blank lines removed.
@@ -727,7 +983,8 @@ pub(crate) async fn execute_shell_command_streaming(
Ok(0) => stdout_done = true,
Ok(n) => {
full_stdout.extend_from_slice(&stdout_buf[..n]);
- parser.process(&stdout_buf[..n]);
+ let normalized = normalize_newlines_for_vt100(&stdout_buf[..n]);
+ parser.process(&normalized);
}
Err(_) => stdout_done = true,
}
@@ -740,7 +997,8 @@ pub(crate) async fn execute_shell_command_streaming(
Ok(n) => {
full_stderr.extend_from_slice(&stderr_buf[..n]);
// Feed stderr to the preview parser too, so it shows in the VT100 screen
- parser.process(&stderr_buf[..n]);
+ let normalized = normalize_newlines_for_vt100(&stderr_buf[..n]);
+ parser.process(&normalized);
}
Err(_) => stderr_done = true,
}
@@ -967,7 +1225,7 @@ mod tests {
fn read_tool(path: &str) -> ReadToolCall {
ReadToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
offset: 0,
limit: 100,
}
@@ -975,7 +1233,7 @@ mod tests {
fn write_tool(path: &str) -> WriteToolCall {
WriteToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
content: String::new(),
}
}
@@ -994,12 +1252,26 @@ mod tests {
}
#[test]
- fn wrong_tool_never_matches() {
- assert!(!read_tool("foo.txt").matches_rule(&write_rule(None)));
+ fn write_implies_read() {
+ // A Write rule also permits reads on the same path
+ assert!(read_tool("foo.txt").matches_rule(&write_rule(None)));
+ // But a Read rule does not permit writes
assert!(!write_tool("foo.txt").matches_rule(&read_rule(None)));
}
#[test]
+ fn edit_uses_write_rule() {
+ let edit = EditToolCall {
+ path: expand_path("/home/user/config.toml"),
+ old_string: "x".into(),
+ new_string: "y".into(),
+ replace_all: false,
+ };
+ assert!(edit.matches_rule(&write_rule(None)));
+ assert!(!edit.matches_rule(&read_rule(None)));
+ }
+
+ #[test]
fn extension_glob() {
assert!(read_tool("notes.md").matches_rule(&read_rule(Some("*.md"))));
assert!(!read_tool("notes.txt").matches_rule(&read_rule(Some("*.md"))));
@@ -1050,6 +1322,419 @@ mod tests {
}
}
+ // ── edit_file execution tests ──
+
+ mod edit {
+ use super::*;
+ use crate::file_tracker::FileReadTracker;
+
+ /// Helper: create a temp file (with a closed handle), record it in a tracker.
+ /// Returns the TempDir (keeps the path alive) and tracker.
+ /// The file handle is closed so atomic_write_file can rename over it on Windows.
+ fn setup_tracked_file(content: &str) -> (tempfile::TempDir, PathBuf, FileReadTracker) {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("test_file.toml");
+ std::fs::write(&path, content).unwrap();
+
+ let file_content = std::fs::read(&path).unwrap();
+ let mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+
+ let mut tracker = FileReadTracker::default();
+ tracker.record_read(path.clone(), &file_content, mtime);
+
+ (dir, path, tracker)
+ }
+
+ fn edit_call(path: &Path, old: &str, new: &str, replace_all: bool) -> EditToolCall {
+ EditToolCall {
+ path: path.to_path_buf(),
+ old_string: old.to_string(),
+ new_string: new.to_string(),
+ replace_all,
+ }
+ }
+
+ #[test]
+ fn successful_single_replacement() {
+ let (_dir, path, tracker) = setup_tracked_file("[section]\nkey = old_value\n");
+
+ let call = edit_call(&path, "old_value", "new_value", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "[section]\nkey = new_value\n"
+ );
+ }
+
+ #[test]
+ fn successful_replace_all() {
+ let (_dir, path, tracker) = setup_tracked_file("aaa bbb aaa ccc aaa");
+
+ let call = edit_call(&path, "aaa", "xxx", true);
+ let (outcome, _) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("3 occurrences")));
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "xxx bbb xxx ccc xxx"
+ );
+ }
+
+ #[test]
+ fn error_file_not_read() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("unread.txt");
+ std::fs::write(&path, "content").unwrap();
+ let tracker = FileReadTracker::default(); // empty — never read
+
+ let call = edit_call(&path, "x", "y", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("not been read yet"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn error_file_modified_since_read() {
+ let (_dir, path, tracker) = setup_tracked_file("original");
+
+ // Modify the file after the read was recorded
+ std::thread::sleep(std::time::Duration::from_millis(10));
+ std::fs::write(&path, "modified externally").unwrap();
+
+ let call = edit_call(&path, "original", "replaced", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("modified since read"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn error_no_match() {
+ let (_dir, path, tracker) = setup_tracked_file("hello world");
+
+ let call = edit_call(&path, "nonexistent", "replacement", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("not found"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn error_multiple_matches_without_replace_all() {
+ let (_dir, path, tracker) = setup_tracked_file("foo bar foo baz foo");
+
+ let call = edit_call(&path, "foo", "qux", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("3 matches"), "got: {msg}");
+ assert!(msg.contains("replace_all"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ // File should be unchanged
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "foo bar foo baz foo"
+ );
+ }
+
+ #[test]
+ fn error_empty_old_string() {
+ let (_dir, path, tracker) = setup_tracked_file("content");
+
+ let call = edit_call(&path, "", "something", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ assert!(matches!(outcome, ToolOutcome::Error(_)));
+ }
+
+ #[test]
+ fn error_file_does_not_exist() {
+ let tracker = FileReadTracker::default();
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("nonexistent.txt");
+
+ let call = edit_call(&path, "x", "y", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("does not exist"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn preserves_file_when_no_match() {
+ let original = "[config]\nport = 8080\nhost = localhost\n";
+ let (_dir, path, tracker) = setup_tracked_file(original);
+
+ let call = edit_call(&path, "port = 9090", "port = 3000", false);
+ let (outcome, _) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Error(_)));
+ assert_eq!(std::fs::read_to_string(&path).unwrap(), original);
+ }
+
+ #[test]
+ fn multiline_replacement() {
+ let content = "[section]\nkey1 = val1\nkey2 = val2\n[other]\n";
+ let (_dir, path, tracker) = setup_tracked_file(content);
+
+ let call = edit_call(
+ &path,
+ "key1 = val1\nkey2 = val2",
+ "key1 = new1\nkey2 = new2",
+ false,
+ );
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "[section]\nkey1 = new1\nkey2 = new2\n[other]\n"
+ );
+ }
+ }
+
+ // ── Integration tests: full edit lifecycle ──
+ //
+ // These exercise the cross-component flow that dispatch orchestrates:
+ // FileReadTracker → SnapshotStore → EditToolCall.execute → tracker update
+
+ mod edit_integration {
+ use super::*;
+ use crate::edit_permissions::EditPermissionCache;
+ use crate::file_tracker::FileReadTracker;
+ use crate::snapshots::SnapshotStore;
+
+ /// Simulate a file read (what dispatch does after ReadToolCall.execute).
+ fn simulate_read(tracker: &mut FileReadTracker, path: &std::path::Path) {
+ let content = std::fs::read(path).unwrap();
+ let mtime = std::fs::metadata(path).unwrap().modified().unwrap();
+ tracker.record_read(path.to_path_buf(), &content, mtime);
+ }
+
+ /// Simulate a tracker update after edit (what dispatch does after execute).
+ fn simulate_tracker_update(
+ tracker: &mut FileReadTracker,
+ path: &std::path::Path,
+ new_bytes: &[u8],
+ ) {
+ let mtime = std::fs::metadata(path).unwrap().modified().unwrap();
+ tracker.update_after_edit(path, new_bytes, mtime);
+ }
+
+ #[test]
+ fn full_read_snapshot_edit_cycle() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "[db]\nhost = localhost\nport = 5432\n").unwrap();
+
+ let snapshot_dir = dir.path().join("snapshots").join("session-1");
+ let mut tracker = FileReadTracker::default();
+ let mut store = SnapshotStore::open(snapshot_dir.clone()).unwrap();
+
+ // 1. Simulate reading the file
+ simulate_read(&mut tracker, &file_path);
+
+ // 2. Snapshot before edit
+ let original = std::fs::read(&file_path).unwrap();
+ store.ensure_snapshot(&file_path, &original).unwrap();
+
+ // 3. Execute edit
+ let call = EditToolCall {
+ path: file_path.clone(),
+ old_string: "host = localhost".to_string(),
+ new_string: "host = 10.0.0.1".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ let new_bytes = new_bytes.unwrap();
+
+ // 4. Update tracker (simulating what dispatch does)
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes);
+
+ // Verify: file was edited
+ assert_eq!(
+ std::fs::read_to_string(&file_path).unwrap(),
+ "[db]\nhost = 10.0.0.1\nport = 5432\n"
+ );
+
+ // Verify: snapshot has original content
+ assert!(store.has_snapshot(&file_path));
+ let snapshot_name = crate::snapshots::sanitize_path(&file_path);
+ let snapshot_content =
+ std::fs::read_to_string(snapshot_dir.join(snapshot_name)).unwrap();
+ assert_eq!(snapshot_content, "[db]\nhost = localhost\nport = 5432\n");
+ }
+
+ #[test]
+ fn second_edit_without_reread() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "key1 = aaa\nkey2 = bbb\n").unwrap();
+
+ let mut tracker = FileReadTracker::default();
+
+ // Read the file
+ simulate_read(&mut tracker, &file_path);
+
+ // First edit
+ let call1 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "key1 = aaa".to_string(),
+ new_string: "key1 = xxx".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call1.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap());
+
+ // Second edit — should work without re-reading because tracker was updated
+ let call2 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "key2 = bbb".to_string(),
+ new_string: "key2 = yyy".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call2.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&file_path).unwrap(),
+ "key1 = xxx\nkey2 = yyy\n"
+ );
+ }
+
+ #[test]
+ fn external_modification_between_edits() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "value = original\n").unwrap();
+
+ let mut tracker = FileReadTracker::default();
+ simulate_read(&mut tracker, &file_path);
+
+ // First edit succeeds
+ let call1 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "value = original".to_string(),
+ new_string: "value = edited".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call1.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap());
+
+ // External modification (e.g., user edits the file)
+ std::thread::sleep(std::time::Duration::from_millis(10));
+ std::fs::write(&file_path, "value = user_changed\n").unwrap();
+
+ // Second edit should fail (stale)
+ let call2 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "value = edited".to_string(),
+ new_string: "value = second_edit".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call2.execute(&file_path, &tracker);
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => assert!(msg.contains("modified since read")),
+ _ => panic!("expected stale error"),
+ }
+
+ // File should be unchanged (the user's edit preserved)
+ assert_eq!(
+ std::fs::read_to_string(&file_path).unwrap(),
+ "value = user_changed\n"
+ );
+ }
+
+ #[test]
+ fn snapshot_only_created_once_per_file() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "a = 1\nb = 2\n").unwrap();
+
+ let snapshot_dir = dir.path().join("snapshots").join("session-1");
+ let mut tracker = FileReadTracker::default();
+ let mut store = SnapshotStore::open(snapshot_dir).unwrap();
+
+ simulate_read(&mut tracker, &file_path);
+
+ // First edit — snapshot should be created
+ let original = std::fs::read(&file_path).unwrap();
+ let created = store.ensure_snapshot(&file_path, &original).unwrap();
+ assert!(created);
+
+ let call1 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "a = 1".to_string(),
+ new_string: "a = 10".to_string(),
+ replace_all: false,
+ };
+ let (_, new_bytes) = call1.execute(&file_path, &tracker);
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap());
+
+ // Second edit — snapshot should NOT be recreated
+ let content_before_second = std::fs::read(&file_path).unwrap();
+ let created = store
+ .ensure_snapshot(&file_path, &content_before_second)
+ .unwrap();
+ assert!(!created); // idempotent — already snapshotted
+ }
+
+ #[test]
+ fn permission_cache_grant_and_check() {
+ let mut cache = EditPermissionCache::default();
+ let path = std::path::PathBuf::from("/Users/me/.config/atuin/config.toml");
+
+ // Initially no grant
+ assert!(!cache.has_valid_grant(&path));
+
+ // Grant permission
+ cache.grant(path.clone());
+ assert!(cache.has_valid_grant(&path));
+
+ // Different file has no grant
+ assert!(!cache.has_valid_grant(std::path::Path::new("/other/file.toml")));
+
+ // Roundtrip through JSON (simulates session persistence)
+ let json = cache.to_json().unwrap();
+ let restored = EditPermissionCache::from_json(&json).unwrap();
+ assert!(restored.has_valid_grant(&path));
+ }
+ }
+
// ── Windows-specific tests (absolute paths with drive letters) ──
#[cfg(windows)]
diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs
index ea895c01..fea26953 100644
--- a/crates/atuin-ai/src/tui/dispatch.rs
+++ b/crates/atuin-ai/src/tui/dispatch.rs
@@ -61,15 +61,17 @@ pub(crate) fn dispatch(ctx: &mut DispatchContext, event: AiTuiEvent) -> bool {
!ctx.exiting.load(Ordering::Acquire)
}
-/// Persist new events and the server session ID if it has changed.
+/// Persist new events, server session ID, file tracker, and edit permissions.
/// Called from the dispatch thread (sync), bridges to async via the tokio handle.
fn persist_session(ctx: &mut DispatchContext) {
- let Ok((events, server_sid)) = ctx
+ let Ok((events, server_sid, file_tracker_json, edit_perms_json)) = ctx
.handle
.fetch(|state| {
(
state.conversation.events.clone(),
state.conversation.session_id.clone(),
+ state.file_tracker.to_json().ok(),
+ state.edit_permissions.to_json().ok(),
)
})
.blocking_recv()
@@ -86,6 +88,22 @@ fn persist_session(ctx: &mut DispatchContext) {
{
tracing::warn!("failed to persist server session ID: {e}");
}
+ if let Some(ref json) = file_tracker_json
+ && let Err(e) = rt.block_on(
+ ctx.session_mgr
+ .set_metadata(crate::file_tracker::METADATA_KEY, json),
+ )
+ {
+ tracing::warn!("failed to persist file tracker: {e}");
+ }
+ if let Some(ref json) = edit_perms_json
+ && let Err(e) = rt.block_on(
+ ctx.session_mgr
+ .set_metadata(crate::edit_permissions::METADATA_KEY, json),
+ )
+ {
+ tracing::warn!("failed to persist edit permissions: {e}");
+ }
}
fn launch_stream(ctx: &DispatchContext, setup: impl FnOnce(&mut Session) + Send + 'static) {
@@ -210,6 +228,10 @@ fn execute_tool(
let shell_call = shell_call.clone();
execute_shell_tool(handle, tx, &tool_id, &shell_call);
}
+ ClientToolCall::Edit(edit_call) => {
+ let edit_call = edit_call.clone();
+ execute_edit_tool(handle, tx, tool_id, edit_call);
+ }
_ => {
execute_simple_tool(handle, tx, tool_id, tool, db);
}
@@ -231,7 +253,21 @@ fn execute_simple_tool(
tokio::spawn(async move {
let outcome = tool.execute(&db).await;
+
+ // After a successful file read, capture tracking data for freshness
+ // checking. This re-stats the file to get content hash and mtime.
+ let read_tracking = if let ClientToolCall::Read(ref read_tool) = tool
+ && !outcome.is_error()
+ {
+ capture_read_tracking(&read_tool.path)
+ } else {
+ None
+ };
+
h.update(move |state| {
+ if let Some((path, content, mtime)) = read_tracking {
+ state.file_tracker.record_read(path, &content, mtime);
+ }
state.finish_tool_call(&tool_id, outcome);
if !state.tool_tracker.has_pending() {
let _ = tx.send(AiTuiEvent::ContinueAfterTools);
@@ -240,6 +276,117 @@ fn execute_simple_tool(
});
}
+/// Capture file content and mtime for the read tracker.
+/// Returns None for directories or if the file can't be read.
+fn capture_read_tracking(
+ path: &std::path::Path,
+) -> Option<(std::path::PathBuf, Vec<u8>, std::time::SystemTime)> {
+ let resolved = if path.is_relative() {
+ std::env::current_dir().ok()?.join(path)
+ } else {
+ path.to_path_buf()
+ };
+ if !resolved.is_file() {
+ return None;
+ }
+ let content = std::fs::read(&resolved).ok()?;
+ let mtime = std::fs::metadata(&resolved).ok()?.modified().ok()?;
+ Some((resolved, content, mtime))
+}
+
+/// Execute an edit_file tool call.
+///
+/// Orchestrates snapshot → execute → tracker update. The snapshot and
+/// tracker mutations happen via `h.update()` (on the TUI thread) since
+/// they need mutable Session state. The actual file I/O (freshness check,
+/// read, match, atomic write) runs in the tokio task.
+fn execute_edit_tool(
+ handle: &Handle<Session>,
+ tx: &mpsc::Sender<AiTuiEvent>,
+ tool_id: String,
+ edit_call: crate::tools::EditToolCall,
+) {
+ let h = handle.clone();
+ let tx = tx.clone();
+
+ tokio::spawn(async move {
+ let resolved = edit_call.resolved_path();
+
+ // 1. Read the original file content (used for snapshot + diff).
+ let old_content = std::fs::read(&resolved).ok();
+
+ // 2. Snapshot the original file before editing.
+ if let Some(ref content) = old_content {
+ let snap_path = resolved.clone();
+ let snap_content = content.clone();
+ h.update(move |state| {
+ if let Some(ref mut store) = state.snapshot_store
+ && let Err(e) = store.ensure_snapshot(&snap_path, &snap_content)
+ {
+ tracing::warn!("failed to create file snapshot: {e}");
+ }
+ });
+ }
+
+ // 3. Fetch a clone of the file tracker for freshness checking.
+ let Ok(tracker) = h.fetch(|state| state.file_tracker.clone()).await else {
+ let tc_id = tool_id.clone();
+ h.update(move |state| {
+ state.finish_tool_call(
+ &tc_id,
+ crate::tools::ToolOutcome::Error("Internal error: TUI unavailable".into()),
+ );
+ if !state.tool_tracker.has_pending() {
+ let _ = tx.send(AiTuiEvent::ContinueAfterTools);
+ }
+ });
+ return;
+ };
+
+ // 4. Execute: freshness check → read → match → atomic write
+ let (outcome, new_bytes) = edit_call.execute(&resolved, &tracker);
+
+ // 5. Compute diff preview on success
+ let edit_preview = if let Some(ref new_bytes) = new_bytes {
+ if let Some(ref old_bytes) = old_content {
+ let old_str = String::from_utf8_lossy(old_bytes);
+ let new_str = String::from_utf8_lossy(new_bytes);
+ let preview = crate::diff::EditPreview::compute(&old_str, &new_str);
+ if preview.hunks.is_empty() {
+ None
+ } else {
+ Some(preview)
+ }
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ // 6. Update tracker, store diff preview, and finish the tool call
+ let tc_id = tool_id;
+ h.update(move |state| {
+ if let Some(ref new_bytes) = new_bytes
+ && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified())
+ {
+ state
+ .file_tracker
+ .update_after_edit(&resolved, new_bytes, mtime);
+ }
+ if let Some(preview) = edit_preview
+ && let Some(tracked) = state.tool_tracker.get_mut(&tc_id)
+ {
+ tracked.edit_preview = Some(preview);
+ }
+ state.finish_tool_call(&tc_id, outcome);
+ if !state.tool_tracker.has_pending() {
+ let _ = tx.send(AiTuiEvent::ContinueAfterTools);
+ }
+ });
+ });
+}
+
/// Execute a shell tool with streaming VT100 preview.
fn execute_shell_tool(
handle: &Handle<Session>,
@@ -352,12 +499,28 @@ async fn check_tool_permission_inner(
.map_err(|e| format!("Internal error fetching tool state: {e}"))?
.ok_or_else(|| "Internal error: tool not found in tracker".to_string())?;
- // 2. Resolve working directory
+ // 2. For edit tools, check session-scoped permission grants before
+ // hitting the filesystem-based resolver. A valid grant means the user
+ // already approved this file recently.
+ if let ClientToolCall::Edit(ref edit) = tool {
+ let resolved = edit.resolved_path();
+ let has_grant = h2
+ .fetch(move |state| state.edit_permissions.has_valid_grant(&resolved))
+ .await
+ .unwrap_or(false);
+
+ if has_grant {
+ execute_tool(h2, tx, id, tool, db);
+ return Ok(());
+ }
+ }
+
+ // 3. Resolve working directory
let working_dir = target_dir
.or_else(|| std::env::current_dir().ok())
.ok_or_else(|| "Could not determine working directory".to_string())?;
- // 3. Create permission resolver and check
+ // 4. Create permission resolver and check
let resolver = PermissionResolver::new(working_dir)
.await
.map_err(|e| format!("Permission check failed: {e}"))?;
@@ -367,7 +530,7 @@ async fn check_tool_permission_inner(
.await
.map_err(|e| format!("Permission check failed: {e}"))?;
- // 4. Handle response — all paths here handle the tool, so return Ok
+ // 5. Handle response — all paths here handle the tool, so return Ok
let id_clone = id.clone();
match response {
PermissionResponse::Allowed => {
@@ -423,6 +586,32 @@ fn on_select_permission(ctx: &mut DispatchContext, permission: PermissionResult)
execute_tool(&h2, &tx, tool_id, tool, &db);
});
}
+ PermissionResult::AllowFileForSession => {
+ // Cache a session-scoped, time-limited grant for this file
+ let db = ctx.app_ctx.history_db.clone();
+ tokio::spawn(async move {
+ let Ok(Some((tool_id, tool))) = h2
+ .fetch(move |state| {
+ state
+ .tool_tracker
+ .asking_for_permission()
+ .map(|t| (t.id.clone(), t.tool.clone()))
+ })
+ .await
+ else {
+ return;
+ };
+
+ if let ClientToolCall::Edit(ref edit) = tool {
+ let resolved = edit.resolved_path();
+ h2.update(move |state| {
+ state.edit_permissions.grant(resolved);
+ });
+ }
+
+ execute_tool(&h2, &tx, tool_id, tool, &db);
+ });
+ }
PermissionResult::AlwaysAllowInDir => {
let db = ctx.app_ctx.history_db.clone();
let git_root = ctx.app_ctx.git_root.clone();
diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs
index 1a422fef..969f6ae5 100644
--- a/crates/atuin-ai/src/tui/events.rs
+++ b/crates/atuin-ai/src/tui/events.rs
@@ -38,7 +38,34 @@ pub(crate) enum AiTuiEvent {
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum PermissionResult {
Allow,
+ /// Per-file, time-limited grant scoped to the current session.
+ AllowFileForSession,
AlwaysAllowInDir,
AlwaysAllow,
Deny,
}
+
+impl PermissionResult {
+ /// String identifier used as the SelectOption value.
+ pub fn as_value_str(&self) -> &'static str {
+ match self {
+ Self::Allow => "allow",
+ Self::AllowFileForSession => "allow-file-session",
+ Self::AlwaysAllowInDir => "always-allow-in-dir",
+ Self::AlwaysAllow => "always-allow",
+ Self::Deny => "deny",
+ }
+ }
+
+ /// Parse from a SelectOption value string.
+ pub fn from_value_str(s: &str) -> Option<Self> {
+ match s {
+ "allow" => Some(Self::Allow),
+ "allow-file-session" => Some(Self::AllowFileForSession),
+ "always-allow-in-dir" => Some(Self::AlwaysAllowInDir),
+ "always-allow" => Some(Self::AlwaysAllow),
+ "deny" => Some(Self::Deny),
+ _ => None,
+ }
+ }
+}
diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs
index e122918e..af1ebffe 100644
--- a/crates/atuin-ai/src/tui/state.rs
+++ b/crates/atuin-ai/src/tui/state.rs
@@ -474,6 +474,12 @@ pub(crate) struct Session {
pub slash_registry: SlashCommandRegistry,
/// The unique ID for this invocation
pub invocation_id: String,
+ /// Tracks which files have been read, for freshness checking before edits.
+ pub file_tracker: crate::file_tracker::FileReadTracker,
+ /// Session-scoped edit permission grants (per-file, time-limited).
+ pub edit_permissions: crate::edit_permissions::EditPermissionCache,
+ /// Backs up files before the first edit in a session.
+ pub snapshot_store: Option<crate::snapshots::SnapshotStore>,
}
impl Session {
@@ -491,6 +497,9 @@ impl Session {
archived_view_events: Vec::new(),
slash_registry: Default::default(),
invocation_id: invocation_id.unwrap_or_else(|| uuid::Uuid::now_v7().to_string()),
+ file_tracker: Default::default(),
+ edit_permissions: Default::default(),
+ snapshot_store: None,
}
}
diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs
index 565a0597..bdbece9c 100644
--- a/crates/atuin-ai/src/tui/view/mod.rs
+++ b/crates/atuin-ai/src/tui/view/mod.rs
@@ -1,12 +1,11 @@
//! View function that builds the eye-declare element tree from app state.
use eye_declare::{
- BorderType, Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport,
- WidthConstraint, element,
+ Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport, WidthConstraint, element,
};
use ratatui_core::style::{Color, Modifier, Style};
-use crate::tools::{ClientToolCall, TrackedTool};
+use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview, TrackedTool};
use crate::tui::components::select::SelectOption;
use crate::tui::components::session_continue::SessionContinue;
use crate::tui::events::{AiTuiEvent, PermissionResult};
@@ -68,6 +67,16 @@ pub(crate) fn ai_view(state: &Session) -> Elements {
})
})
+ #({
+ let needs_pending_banner = busy && !matches!(turns.last(), Some(turn::UiTurn::Agent { .. }));
+ if needs_pending_banner {
+ let empty: &[turn::UiEvent] = &[];
+ agent_turn_view(empty, true)
+ } else {
+ element! {}
+ }
+ })
+
#(if !state.is_exiting() {
#(input_view(state))
})
@@ -135,16 +144,13 @@ fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements {
let verb = tool_call.tool.descriptor().display_verb;
let tool_desc = match &tool_call.tool {
ClientToolCall::Read(tool) => tool.path.display().to_string(),
+ ClientToolCall::Edit(tool) => tool.path.display().to_string(),
ClientToolCall::Write(tool) => tool.path.display().to_string(),
ClientToolCall::Shell(tool) => tool.command.clone(),
ClientToolCall::AtuinHistory(tool) => tool.query.clone(),
};
- let dir_label = if in_git_project {
- "Always allow in this workspace"
- } else {
- "Always allow in this directory"
- };
+ let select_options = permission_options_for_tool(&tool_call.tool, in_git_project);
element! {
View(key: format!("tool-call-{}", tool_call.id), padding_left: Cells::from(2), padding_top: Cells::from(1)) {
@@ -153,39 +159,68 @@ fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements {
Span(text: &tool_desc, style: Style::default().fg(Color::Yellow))
}
View(padding_left: Cells::from(2)) {
- Select(options: [
- SelectOption::builder()
- .label("Allow")
- .value("allow")
- .build(),
- SelectOption::builder()
- .label(dir_label)
- .value("always-allow-in-dir")
- .build(),
- SelectOption::builder()
- .label("Always allow")
- .value("always-allow")
- .build(),
- SelectOption::builder()
- .label("Deny")
- .value("deny")
- .build(),
- ], on_select: Box::new(move |option: &SelectOption| {
- let value = match option.value.as_str() {
- "allow" => PermissionResult::Allow,
- "always-allow-in-dir" => PermissionResult::AlwaysAllowInDir,
- "always-allow" => PermissionResult::AlwaysAllow,
- "deny" => PermissionResult::Deny,
- _ => unreachable!(),
- };
-
- Some(AiTuiEvent::SelectPermission(value))
+ Select(options: select_options, on_select: Box::new(move |option: &SelectOption| {
+ PermissionResult::from_value_str(option.value.as_str())
+ .map(AiTuiEvent::SelectPermission)
}) as Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync>)
}
}
}
}
+/// Build the permission SelectOptions appropriate for a tool call.
+///
+/// Edit tools get a per-file session-scoped option instead of the
+/// workspace-level "Always allow in this directory". Other tools
+/// keep the standard set.
+fn permission_options_for_tool(tool: &ClientToolCall, in_git_project: bool) -> Vec<SelectOption> {
+ match tool {
+ ClientToolCall::Edit(_) => vec![
+ SelectOption::builder()
+ .label("Allow")
+ .value(PermissionResult::Allow.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label("Allow this file for this session")
+ .value(PermissionResult::AllowFileForSession.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label("Always allow")
+ .value(PermissionResult::AlwaysAllow.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label("Deny")
+ .value(PermissionResult::Deny.as_value_str())
+ .build(),
+ ],
+ _ => {
+ let dir_label = if in_git_project {
+ "Always allow in this workspace"
+ } else {
+ "Always allow in this directory"
+ };
+ vec![
+ SelectOption::builder()
+ .label("Allow")
+ .value(PermissionResult::Allow.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label(dir_label)
+ .value(PermissionResult::AlwaysAllowInDir.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label("Always allow")
+ .value(PermissionResult::AlwaysAllow.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label("Deny")
+ .value(PermissionResult::Deny.as_value_str())
+ .build(),
+ ]
+ }
+ }
+}
+
fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements {
let label_style = Style::default()
.fg(Color::Cyan)
@@ -231,7 +266,10 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements {
label_first: true,
done: !busy,
)
- #(for event in events {
+ #(for (i, event) in events.iter().enumerate() {
+ #(if i > 0 {
+ Text { Span(text: "") }
+ })
#(match event {
turn::UiEvent::Text { content } => {
element! {
@@ -247,47 +285,42 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements {
suggested_command_view(details)
},
turn::UiEvent::ToolCall(details) => {
- let preview_done = details.preview.as_ref().is_some_and(|p| p.exit_code.is_some() || p.interrupted);
let tool_key = details.tool_use_id.clone();
element! {
View(key: format!("tool-output-{tool_key}"), padding_left: Cells::from(2)) {
- #(if let Some(ref preview) = details.preview {
- View(key: format!("preview-{tool_key}")) {
- #(preview_spinner_view(&details.name, preview_done))
- Viewport(
- key: format!("viewport-{tool_key}"),
- lines: preview.lines.clone(),
- height: 10,
- border: BorderType::Plain,
- border_style: Style::default().fg(Color::DarkGray),
- style: Style::default().fg(Color::White),
- wrap: false,
- )
- #(if let Some(code) = preview.exit_code {
- #(if code == 0 {
- Text {
- Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Green))
- }
- } else {
- Text {
- Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Red))
- }
- })
- })
- #(if preview.interrupted {
- Text {
- Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD))
- }
- })
- #(if !preview_done {
- Text {
- Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray))
- }
- })
- }
- } else {
- #(tool_status_view(&details.name, &details.status))
+ #(match &details.render_data {
+ turn::ToolRenderData::Shell { command, preview } => {
+ shell_tool_view(&tool_key, command, preview.as_ref())
+ },
+ turn::ToolRenderData::FileEdit { path, preview } => {
+ file_edit_tool_view(&tool_key, &details.status, path, preview.as_ref())
+ },
+ turn::ToolRenderData::FileWrite { path } => {
+ file_write_tool_view(&details.status, path)
+ },
+ turn::ToolRenderData::Remote => {
+ tool_status_view(&details.name, &details.status)
+ },
+ turn::ToolRenderData::FileRead { .. }
+ | turn::ToolRenderData::HistorySearch { .. } => {
+ element!{}
+ },
+ })
+ }
+ }
+ }
+ turn::UiEvent::ToolGroup(group) => {
+ let group_key = group.calls
+ .first()
+ .map(|c| c.tool_use_id.as_str())
+ .unwrap_or("empty");
+
+ element! {
+ View(key: format!("group-{group_key}"), padding_left: Cells::from(2)) {
+ #(match group.kind {
+ turn::ToolGroupKind::FileRead => file_read_group_view(group),
+ turn::ToolGroupKind::HistorySearch => history_search_group_view(group),
})
}
}
@@ -367,17 +400,391 @@ fn tool_status_view(name: &str, status: &turn::ToolResultStatus) -> Elements {
}
}
-/// Render a spinner/status line for a command preview (shell tools).
-fn preview_spinner_view(name: &str, done: bool) -> Elements {
+// ───────────────────────────────────────────────────────────────────
+// Per-tool view functions
+// ───────────────────────────────────────────────────────────────────
+
+/// Max output lines shown for a shell command preview.
+const MAX_SHELL_PREVIEW_LINES: u16 = 5;
+
+/// Render a shell command execution with live VT100 output viewport.
+fn shell_tool_view(tool_key: &str, command: &str, preview: Option<&ToolPreview>) -> Elements {
+ let preview_done = preview.is_some_and(|p| p.exit_code.is_some() || p.interrupted);
+
+ element! {
+ #(if let Some(preview) = preview {
+ View(key: format!("preview-{tool_key}")) {
+ Spinner(
+ label: if preview_done { format!("Ran: {command}") } else { format!("Running: {command}") },
+ done: preview_done,
+ hide_checkmark: true,
+ )
+ HStack {
+ View(width: WidthConstraint::Fixed(2)) {
+ Text { Span(text: "└ ") }
+ }
+ Column {
+ Viewport(
+ key: format!("viewport-{tool_key}"),
+ lines: preview.lines.clone(),
+ height: (preview.lines.len() as u16).clamp(1, MAX_SHELL_PREVIEW_LINES),
+ style: Style::default().fg(Color::Gray),
+ wrap: false,
+ )
+ }
+ }
+ #(shell_tool_footer(preview, preview_done))
+ }
+ } else {
+ Spinner(
+ label: format!("Running: {command}"),
+ label_style: Style::default().fg(Color::Yellow),
+ done: false,
+ )
+ })
+ }
+}
+
+fn shell_tool_footer(preview: &ToolPreview, preview_done: bool) -> Elements {
+ if preview.interrupted {
+ return element! {
+ Text {
+ Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD))
+ }
+ };
+ }
+ if !preview_done {
+ return element! {
+ Text {
+ Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray))
+ }
+ };
+ }
+ if let Some(code) = preview.exit_code {
+ let style = if code == 0 {
+ Style::default().fg(Color::Green)
+ } else {
+ Style::default().fg(Color::Red)
+ };
+ return element! {
+ Text { Span(text: format!("Exit code: {code}"), style: style) }
+ };
+ }
+ element! {}
+}
+
+/// Render a file edit tool call with diff preview.
+fn file_edit_tool_view(
+ key: &str,
+ status: &turn::ToolResultStatus,
+ path: &std::path::Path,
+ preview: Option<&crate::diff::EditPreview>,
+) -> Elements {
+ use crate::diff::DiffLine;
+
+ let display_path = format_path_for_display(path);
+
+ let status_line = match status {
+ turn::ToolResultStatus::Pending => {
+ element! {
+ Spinner(
+ label: format!("Editing: {display_path}"),
+ label_style: Style::default().fg(Color::Yellow),
+ done: false,
+ )
+ }
+ }
+ turn::ToolResultStatus::Success => {
+ element! {
+ Spinner(label: format!("Edited: {display_path}"), done: true)
+ }
+ }
+ turn::ToolResultStatus::Error => {
+ element! {
+ Text {
+ Span(text: "✗ ", style: Style::default().fg(Color::Red))
+ Span(text: format!("Edit {display_path}: failed"), style: Style::default().fg(Color::Red))
+ }
+ }
+ }
+ };
+
+ // If no preview, just show the status line
+ let Some(preview) = preview else {
+ return status_line;
+ };
+ if preview.hunks.is_empty() {
+ return status_line;
+ }
+
+ // Calculate the line number gutter width from the highest line number
+ let max_line_num = preview.max_line_number();
+ let gutter_width = max_line_num.to_string().len().max(2) as u16 + 1; // +1 for spacing
+
+ element! {
+ View(key: key.to_string()) {
+ #(status_line)
+
+ View(key: format!("{key}-diff"), padding_left: Cells::from(2)) {
+ #(for (hunk_idx, hunk) in preview.hunks.iter().enumerate() {
+ #({
+ let gutter_w = gutter_width;
+ let mut before_pos = hunk.before_start;
+ let mut after_pos = hunk.after_start;
+ let lines_rendered: Vec<_> = hunk.lines.iter().enumerate().map(|(line_idx, line)| {
+ let (prefix, text, style, gutter_text, gutter_style) = match line {
+ DiffLine::Context(t) => {
+ let num = format!("{:>width$}", after_pos, width = (gutter_w - 1) as usize);
+ before_pos += 1;
+ after_pos += 1;
+ (" ", t.as_str(), Style::default().fg(Color::DarkGray), num, Style::default().fg(Color::DarkGray))
+ }
+ DiffLine::Removed(t) => {
+ let num = format!("{:>width$}", before_pos, width = (gutter_w - 1) as usize);
+ before_pos += 1;
+ ("-", t.as_str(), Style::default().fg(Color::Red), num, Style::default().fg(Color::Red))
+ }
+ DiffLine::Added(t) => {
+ let num = format!("{:>width$}", after_pos, width = (gutter_w - 1) as usize);
+ after_pos += 1;
+ ("+", t.as_str(), Style::default().fg(Color::Green), num, Style::default().fg(Color::Green))
+ }
+ };
+ (line_idx, prefix, text.to_string(), style, gutter_text, gutter_style)
+ }).collect();
+
+ element! {
+ View(key: format!("{key}-hunk-{hunk_idx}")) {
+ #(for (line_idx, prefix, text, style, gutter_text, gutter_style) in &lines_rendered {
+ HStack(key: format!("{key}-hunk-{hunk_idx}-line-{line_idx}")) {
+ View(width: WidthConstraint::Fixed(gutter_w)) {
+ Text { Span(text: gutter_text, style: *gutter_style) }
+ }
+ View {
+ Text {
+ Span(text: *prefix, style: *style)
+ Span(text: text, style: *style)
+ }
+ }
+ }
+ })
+ }
+ }
+ })
+ })
+ }
+ }
+ }
+}
+
+/// Render a file write tool call status with the target path.
+fn file_write_tool_view(status: &turn::ToolResultStatus, path: &std::path::Path) -> Elements {
+ let display_path = path.display();
+ match status {
+ turn::ToolResultStatus::Pending => {
+ element! {
+ Spinner(
+ label: format!("Writing: {display_path}"),
+ label_style: Style::default().fg(Color::Yellow),
+ done: false,
+ )
+ }
+ }
+ turn::ToolResultStatus::Success => {
+ element! {
+ Spinner(label: format!("Wrote: {display_path}"), done: true)
+ }
+ }
+ turn::ToolResultStatus::Error => {
+ element! {
+ Text {
+ Span(text: "✗ ", style: Style::default().fg(Color::Red))
+ Span(text: format!("Write {display_path}: denied"), style: Style::default().fg(Color::Red))
+ }
+ }
+ }
+ }
+}
+
+// ───────────────────────────────────────────────────────────────────
+// Tool group view functions
+// ───────────────────────────────────────────────────────────────────
+
+/// Max entries shown under a tool group header. When the group holds more
+/// than this, only the most recent `MAX_GROUP_ENTRIES` are displayed; the
+/// count in the header line tells the full story.
+const MAX_GROUP_ENTRIES: usize = 5;
+
+/// Format a filesystem path for display in tool rows.
+///
+/// - Relative to the current working directory if the path is under it
+/// - `~/...` prefix if the path is under the user's home directory
+/// - Absolute otherwise (and relative paths pass through unchanged)
+fn format_path_for_display(path: &std::path::Path) -> String {
+ if let Ok(cwd) = std::env::current_dir()
+ && let Ok(relative) = path.strip_prefix(&cwd)
+ {
+ return relative.display().to_string();
+ }
+
+ if let Ok(home) = std::env::var("HOME")
+ && let Ok(relative) = path.strip_prefix(&home)
+ {
+ return format!("~/{}", relative.display());
+ }
+
+ path.display().to_string()
+}
+
+fn filter_mode_label(mode: &HistorySearchFilterMode) -> &'static str {
+ match mode {
+ HistorySearchFilterMode::Global => "global",
+ HistorySearchFilterMode::Host => "host",
+ HistorySearchFilterMode::Session => "session",
+ HistorySearchFilterMode::Directory => "directory",
+ HistorySearchFilterMode::Workspace => "workspace",
+ }
+}
+
+/// Format a list of filter modes as `"(global, workspace)"`, or an empty
+/// string if the list is empty.
+fn format_filter_modes(modes: &[HistorySearchFilterMode]) -> String {
+ if modes.is_empty() {
+ return String::new();
+ }
+ let parts: Vec<&'static str> = modes.iter().map(filter_mode_label).collect();
+ format!("({})", parts.join(", "))
+}
+
+/// Tree-connector marker for a row in a grouped list: `└ ` for the first
+/// visible row, two spaces for subsequent rows.
+fn tree_marker(is_first: bool) -> &'static str {
+ if is_first { "└ " } else { " " }
+}
+
+/// 2-char status marker column: ✓ / ✗ / blank.
+fn status_marker_view(status: &turn::ToolResultStatus) -> Elements {
+ match status {
+ turn::ToolResultStatus::Pending => element! {
+ Text { Span(text: " ") }
+ },
+ turn::ToolResultStatus::Success => element! {
+ Text { Span(text: "✓ ", style: Style::default().fg(Color::Green)) }
+ },
+ turn::ToolResultStatus::Error => element! {
+ Text { Span(text: "✗ ", style: Style::default().fg(Color::Red)) }
+ },
+ }
+}
+
+/// Compute the slice of calls to show — the most recent `MAX_GROUP_ENTRIES`.
+fn visible_group_calls(group: &turn::ToolGroup) -> &[turn::ToolCallDetails] {
+ let start = group.calls.len().saturating_sub(MAX_GROUP_ENTRIES);
+ &group.calls[start..]
+}
+
+/// Render a single row in a grouped list: [tree marker][status][content].
+fn group_row_view(is_first: bool, status: &turn::ToolResultStatus, content: Elements) -> Elements {
+ element! {
+ HStack {
+ View(width: WidthConstraint::Fixed(2)) {
+ Text { Span(text: tree_marker(is_first)) }
+ }
+ View(width: WidthConstraint::Fixed(2)) {
+ #(status_marker_view(status))
+ }
+ Column {
+ #(content)
+ }
+ }
+ }
+}
+
+/// Render a group of consecutive `read_file` tool calls.
+fn file_read_group_view(group: &turn::ToolGroup) -> Elements {
+ let count = group.calls.len();
+ let label = if count == 1 {
+ "Read 1 file".to_string()
+ } else {
+ format!("Read {count} files")
+ };
+ let done = !group.any_pending();
+ let visible = visible_group_calls(group);
+
+ element! {
+ Spinner(label: label, done: done, hide_checkmark: true)
+ #(for (i, details) in visible.iter().enumerate() {
+ #(file_read_row(i == 0, details))
+ })
+ }
+}
+
+fn file_read_row(is_first: bool, details: &turn::ToolCallDetails) -> Elements {
+ let path_str = match &details.render_data {
+ turn::ToolRenderData::FileRead { path } => format_path_for_display(path),
+ _ => String::new(),
+ };
+
+ let content = element! {
+ Text { Span(text: path_str) }
+ };
+
+ group_row_view(is_first, &details.status, content)
+}
+
+/// Render a group of consecutive `atuin_history` tool calls.
+fn history_search_group_view(group: &turn::ToolGroup) -> Elements {
+ let done = !group.any_pending();
+ let visible = visible_group_calls(group);
+
element! {
- Spinner(
- label: if done { format!("Ran: {name}") } else { format!("Running: {name}") },
- label_style: Style::default().fg(Color::Yellow),
- done: done,
- )
+ Spinner(label: "Searched Atuin history:", done: done, hide_checkmark: true)
+ #(for (i, details) in visible.iter().enumerate() {
+ #(history_search_row(i == 0, details))
+ })
}
}
+fn history_search_row(is_first: bool, details: &turn::ToolCallDetails) -> Elements {
+ let (query, filter_modes) = match &details.render_data {
+ turn::ToolRenderData::HistorySearch {
+ query,
+ filter_modes,
+ } => (query.as_str(), filter_modes.as_slice()),
+ _ => ("", [].as_slice()),
+ };
+
+ let is_empty_query = query.trim().is_empty();
+ let filter_label = format_filter_modes(filter_modes);
+
+ let content = if is_empty_query {
+ element! {
+ Text {
+ Span(
+ text: "recent commands",
+ style: Style::default().fg(Color::Gray).add_modifier(Modifier::ITALIC),
+ )
+ #(if !filter_label.is_empty() {
+ Span(text: " ")
+ Span(text: filter_label, style: Style::default().fg(Color::DarkGray))
+ })
+ }
+ }
+ } else {
+ element! {
+ Text {
+ Span(text: query.to_string())
+ #(if !filter_label.is_empty() {
+ Span(text: " ")
+ Span(text: filter_label, style: Style::default().fg(Color::DarkGray))
+ })
+ }
+ }
+ };
+
+ group_row_view(is_first, &details.status, content)
+}
+
fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements {
let is_dangerous = matches!(
details.danger_level,
@@ -413,9 +820,6 @@ fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements {
element! {
View {
- #(if !details.first_event_in_turn {
- Text { Span(text: "") }
- })
Text {
Span(text: " Suggested command:", style: Style::default().fg(Color::Cyan))
}
diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs
index a2555dc6..1c19a6b2 100644
--- a/crates/atuin-ai/src/tui/view/turn.rs
+++ b/crates/atuin-ai/src/tui/view/turn.rs
@@ -1,5 +1,7 @@
+use std::path::PathBuf;
+
use crate::tools::descriptor;
-use crate::tools::{ToolPreview, ToolTracker};
+use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview, ToolTracker};
use crate::tui::ConversationEvent;
/// Server-sent danger level for a suggested command
@@ -80,20 +82,99 @@ impl From<(&String, &String)> for ConfidenceLevel {
#[derive(Debug)]
pub(crate) enum UiEvent {
- Text { content: String },
+ Text {
+ content: String,
+ },
ToolCall(ToolCallDetails),
+ /// Consecutive client-side tool calls of the same groupable kind, collapsed
+ /// into one unit so the view can render a shared status line + a list of
+ /// individual entries.
+ ToolGroup(ToolGroup),
ToolSummary(ToolSummary),
SuggestedCommand(SuggestedCommandDetails),
OutOfBandOutput(OutOfBandOutputDetails),
}
+/// A run of consecutive client-side tool calls of the same groupable kind.
+#[derive(Debug)]
+pub(crate) struct ToolGroup {
+ pub(crate) kind: ToolGroupKind,
+ pub(crate) calls: Vec<ToolCallDetails>,
+}
+
+impl ToolGroup {
+ /// True if any call in the group is still pending.
+ pub(crate) fn any_pending(&self) -> bool {
+ self.calls
+ .iter()
+ .any(|c| c.status == ToolResultStatus::Pending)
+ }
+}
+
+/// Which kind of client-side tools this group holds.
+///
+/// Only tool types that benefit from grouped presentation appear here.
+/// Shell (needs its own viewport) and FileWrite (wants diffs/contents) are
+/// intentionally absent — those render as individual `UiEvent::ToolCall`s.
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub(crate) enum ToolGroupKind {
+ FileRead,
+ HistorySearch,
+}
+
+/// Tool-type-specific data for rendering in the view layer.
+///
+/// Each variant carries the data a per-tool renderer component needs.
+/// Built by TurnBuilder from ToolTracker + ConversationEvent data.
+#[derive(Debug)]
+pub(crate) enum ToolRenderData {
+ /// Shell command with live/cached VT100 output preview.
+ Shell {
+ command: String,
+ preview: Option<ToolPreview>,
+ },
+ /// File read operation.
+ FileRead { path: PathBuf },
+ /// File edit (str_replace) operation.
+ FileEdit {
+ path: PathBuf,
+ preview: Option<crate::diff::EditPreview>,
+ },
+ /// File write/create operation.
+ FileWrite { path: PathBuf },
+ /// Atuin history search.
+ HistorySearch {
+ query: String,
+ filter_modes: Vec<HistorySearchFilterMode>,
+ },
+ /// Server-side tool — no client rendering data available.
+ Remote,
+}
+
+impl ToolRenderData {
+ pub(crate) fn is_remote(&self) -> bool {
+ matches!(self, ToolRenderData::Remote)
+ }
+
+ /// The group kind this tool should collapse into, if any.
+ ///
+ /// Returns `None` for tools that render as individual `UiEvent::ToolCall`s
+ /// (shell, file writes, remote).
+ pub(crate) fn group_kind(&self) -> Option<ToolGroupKind> {
+ match self {
+ ToolRenderData::FileRead { .. } => Some(ToolGroupKind::FileRead),
+ ToolRenderData::HistorySearch { .. } => Some(ToolGroupKind::HistorySearch),
+ _ => None,
+ }
+ }
+}
+
#[derive(Debug)]
pub(crate) struct ToolCallDetails {
pub(crate) tool_use_id: String,
pub(crate) name: String,
pub(crate) status: ToolResultStatus,
- pub(crate) is_client: bool,
- pub(crate) preview: Option<ToolPreview>,
+ pub(crate) render_data: ToolRenderData,
}
#[derive(Debug)]
@@ -101,7 +182,6 @@ pub(crate) struct SuggestedCommandDetails {
pub(crate) command: String,
pub(crate) danger_level: DangerLevel,
pub(crate) confidence_level: ConfidenceLevel,
- pub(crate) first_event_in_turn: bool,
}
#[derive(Debug)]
@@ -179,33 +259,49 @@ impl<'a> TurnBuilder<'a> {
pub(crate) fn build(&mut self) -> Vec<UiTurn> {
self.commit_turn();
- // Collapse consecutive tool calls within each agent turn into ToolSummary
+ // Within each agent turn:
+ // - Consecutive remote tool calls collapse into a ToolSummary
+ // - Consecutive client-side tool calls of the same group kind collapse
+ // into a ToolGroup (e.g. N file reads → one group)
+ // - All other events pass through unchanged
for turn in &mut self.turns {
if let UiTurn::Agent { events } = turn {
let mut new_events: Vec<UiEvent> = Vec::new();
- let mut pending_tools: Vec<ToolCallDetails> = Vec::new();
+ let mut pending_remote: Vec<ToolCallDetails> = Vec::new();
+ let mut pending_group: Option<(ToolGroupKind, Vec<ToolCallDetails>)> = None;
for event in events.drain(..) {
match event {
- UiEvent::ToolCall(details) if !details.is_client => {
- pending_tools.push(details);
+ UiEvent::ToolCall(details) if details.render_data.is_remote() => {
+ flush_group(&mut pending_group, &mut new_events);
+ pending_remote.push(details);
}
- other => {
- if !pending_tools.is_empty() {
- new_events.push(UiEvent::ToolSummary(ToolSummary {
- tool_calls: std::mem::take(&mut pending_tools),
- }));
+ UiEvent::ToolCall(details)
+ if details.render_data.group_kind().is_some() =>
+ {
+ flush_remote(&mut pending_remote, &mut new_events);
+
+ let kind = details.render_data.group_kind().unwrap();
+ match pending_group.as_mut() {
+ Some((current_kind, calls)) if *current_kind == kind => {
+ calls.push(details);
+ }
+ _ => {
+ flush_group(&mut pending_group, &mut new_events);
+ pending_group = Some((kind, vec![details]));
+ }
}
+ }
+ other => {
+ flush_remote(&mut pending_remote, &mut new_events);
+ flush_group(&mut pending_group, &mut new_events);
new_events.push(other);
}
}
}
- if !pending_tools.is_empty() {
- new_events.push(UiEvent::ToolSummary(ToolSummary {
- tool_calls: pending_tools,
- }));
- }
+ flush_remote(&mut pending_remote, &mut new_events);
+ flush_group(&mut pending_group, &mut new_events);
*events = new_events;
}
@@ -255,6 +351,9 @@ impl<'a> TurnBuilder<'a> {
}
fn add_agent_text(&mut self, content: &str) {
+ if content.trim().is_empty() {
+ return;
+ }
self.start_agent_turn();
if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
events.push(UiEvent::Text {
@@ -303,8 +402,6 @@ impl<'a> TurnBuilder<'a> {
let danger = DangerLevel::from((&danger_level, &danger_notes));
let confidence = ConfidenceLevel::from((&confidence_level, &confidence_notes));
- let first_event_in_turn = events.is_empty();
-
events.push(UiEvent::SuggestedCommand(SuggestedCommandDetails {
command: input
.get("command")
@@ -313,14 +410,12 @@ impl<'a> TurnBuilder<'a> {
.to_string(),
danger_level: danger,
confidence_level: confidence,
- first_event_in_turn,
}));
}
}
fn add_tool_call(&mut self, id: &str, name: &str, _input: &serde_json::Value) {
- let is_client = descriptor::by_name(name).is_some_and(|d| d.is_client);
- let preview = self.tracker.preview_for(id);
+ let render_data = self.build_render_data(id, name);
self.start_agent_turn();
if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
@@ -328,12 +423,44 @@ impl<'a> TurnBuilder<'a> {
tool_use_id: id.to_string(),
name: name.to_string(),
status: ToolResultStatus::Pending,
- is_client,
- preview,
+ render_data,
}));
}
}
+ /// Build tool-type-specific render data from the ToolTracker.
+ ///
+ /// For client-side tools, the tracker holds the typed `ClientToolCall` and
+ /// any live/cached preview data. For server-side (or unknown) tools, we
+ /// fall back to `ToolRenderData::Remote`.
+ fn build_render_data(&self, id: &str, _name: &str) -> ToolRenderData {
+ if let Some(tracked) = self.tracker.get(id) {
+ match &tracked.tool {
+ ClientToolCall::Shell(shell) => ToolRenderData::Shell {
+ command: shell.command.clone(),
+ preview: tracked.preview(),
+ },
+ ClientToolCall::Read(read) => ToolRenderData::FileRead {
+ path: read.path.clone(),
+ },
+ ClientToolCall::Edit(edit) => ToolRenderData::FileEdit {
+ path: edit.path.clone(),
+ preview: tracked.edit_preview.clone(),
+ },
+ ClientToolCall::Write(write) => ToolRenderData::FileWrite {
+ path: write.path.clone(),
+ },
+ ClientToolCall::AtuinHistory(history) => ToolRenderData::HistorySearch {
+ query: history.query.clone(),
+ filter_modes: history.filter_modes.clone(),
+ },
+ }
+ } else {
+ // Not in tracker → server-side tool
+ ToolRenderData::Remote
+ }
+ }
+
fn add_tool_result(&mut self, tool_use_id: &str, _content: &str, is_error: bool) {
self.start_agent_turn();
if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
@@ -364,6 +491,25 @@ impl<'a> TurnBuilder<'a> {
}
}
+/// Drain pending remote tool calls into a `ToolSummary`.
+fn flush_remote(pending: &mut Vec<ToolCallDetails>, out: &mut Vec<UiEvent>) {
+ if !pending.is_empty() {
+ out.push(UiEvent::ToolSummary(ToolSummary {
+ tool_calls: std::mem::take(pending),
+ }));
+ }
+}
+
+/// Drain a pending client-side tool group into a `ToolGroup`.
+fn flush_group(
+ pending: &mut Option<(ToolGroupKind, Vec<ToolCallDetails>)>,
+ out: &mut Vec<UiEvent>,
+) {
+ if let Some((kind, calls)) = pending.take() {
+ out.push(UiEvent::ToolGroup(ToolGroup { kind, calls }));
+ }
+}
+
#[derive(Debug)]
pub(crate) struct ToolSummary {
tool_calls: Vec<ToolCallDetails>,