aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichelle Tilley <michelle@michelletilley.net>2026-04-21 10:32:54 -0700
committerGitHub <noreply@github.com>2026-04-21 10:32:54 -0700
commit0f20ee4eb871907defe7848f0d3e2203cfff057e (patch)
treecda9034c4c6e7b5ecf0fe957978284e9138b80ff
parentchore: Clarified note about regular expressions matching in path. (#3427) (diff)
downloadatuin-0f20ee4eb871907defe7848f0d3e2203cfff057e.zip
feat: AI tool rendering overhaul + edit_file tool (#3423)
Overhaul of how AI tool calls are modeled, rendered, and displayed in the Atuin AI TUI. Fixes bugs in shell command output capture, implements the `edit_file` tool with full safety infrastructure, and adds a diff preview for edits.
Diffstat (limited to '')
-rw-r--r--Cargo.lock19
-rw-r--r--Cargo.toml3
-rw-r--r--crates/atuin-ai/Cargo.toml5
-rw-r--r--crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql9
-rw-r--r--crates/atuin-ai/src/commands/inline.rs29
-rw-r--r--crates/atuin-ai/src/diff.rs294
-rw-r--r--crates/atuin-ai/src/edit_permissions.rs108
-rw-r--r--crates/atuin-ai/src/file_tracker.rs234
-rw-r--r--crates/atuin-ai/src/lib.rs4
-rw-r--r--crates/atuin-ai/src/session.rs27
-rw-r--r--crates/atuin-ai/src/snapshots.rs414
-rw-r--r--crates/atuin-ai/src/store.rs32
-rw-r--r--crates/atuin-ai/src/stream.rs8
-rw-r--r--crates/atuin-ai/src/tools/descriptor.rs17
-rw-r--r--crates/atuin-ai/src/tools/mod.rs737
-rw-r--r--crates/atuin-ai/src/tui/dispatch.rs199
-rw-r--r--crates/atuin-ai/src/tui/events.rs27
-rw-r--r--crates/atuin-ai/src/tui/state.rs9
-rw-r--r--crates/atuin-ai/src/tui/view/mod.rs570
-rw-r--r--crates/atuin-ai/src/tui/view/turn.rs198
-rw-r--r--crates/atuin-client/Cargo.toml2
-rw-r--r--crates/atuin-client/src/settings.rs4
-rw-r--r--docs/docs/ai/settings.md12
23 files changed, 2815 insertions, 146 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 0eaf100f..68f93205 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -285,6 +285,7 @@ dependencies = [
"fs-err",
"futures",
"glob-match",
+ "imara-diff",
"pretty_assertions",
"pulldown-cmark",
"ratatui",
@@ -294,6 +295,7 @@ dependencies = [
"reqwest",
"serde",
"serde_json",
+ "shellexpand",
"sqlx",
"tempfile",
"thiserror 2.0.18",
@@ -312,6 +314,7 @@ dependencies = [
"unicode-width 0.2.2",
"uuid",
"vt100",
+ "xxhash-rust",
]
[[package]]
@@ -2258,6 +2261,16 @@ dependencies = [
]
[[package]]
+name = "imara-diff"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2f01d462f766df78ab820dd06f5eb700233c51f0f4c2e846520eaf4ba6aa5c5c"
+dependencies = [
+ "hashbrown 0.15.5",
+ "memchr",
+]
+
+[[package]]
name = "indenter"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -6711,6 +6724,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd"
[[package]]
+name = "xxhash-rust"
+version = "0.8.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3"
+
+[[package]]
name = "yansi"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 80975d44..eb698f99 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -51,6 +51,7 @@ rand = { version = "0.8.5", features = ["std"] }
semver = "1.0.20"
serde = { version = "1.0.202", features = ["derive"] }
serde_json = "1.0.119"
+shellexpand = "3"
tokio = { version = "1", features = ["full"] }
uuid = { version = "1.9", features = ["v4", "v7", "serde"] }
whoami = "2.1.0"
@@ -70,6 +71,8 @@ rustls = { version = "0.23", default-features = false, features = [
"tls12",
] }
glob-match = "0.2.1"
+imara-diff = "0.2"
+xxhash-rust = { version = "0.8", features = ["xxh3"] }
vt100 = "0.16"
regex = "1.10.5"
toml_edit = "0.25.4"
diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml
index 3be127de..167b625c 100644
--- a/crates/atuin-ai/Cargo.toml
+++ b/crates/atuin-ai/Cargo.toml
@@ -59,10 +59,13 @@ tree-sitter-bash = { version = "0.25.1", optional = true }
tree-sitter-fish = { version = "3.6.0", optional = true }
sqlx = { workspace = true, features = ["sqlite"] }
typed-builder = { workspace = true }
+shellexpand = { workspace = true }
+imara-diff = { workspace = true }
+xxhash-rust = { workspace = true }
vt100 = { workspace = true }
+tempfile = { workspace = true }
chrono = "0.4"
chrono-humanize = "0.2"
[dev-dependencies]
pretty_assertions = { workspace = true }
-tempfile = { workspace = true }
diff --git a/crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql b/crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql
new file mode 100644
index 00000000..f97dfd1b
--- /dev/null
+++ b/crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql
@@ -0,0 +1,9 @@
+CREATE TABLE IF NOT EXISTS session_metadata (
+ session_id TEXT NOT NULL,
+ key TEXT NOT NULL,
+ value TEXT NOT NULL,
+ updated_at INTEGER NOT NULL,
+
+ PRIMARY KEY (session_id, key),
+ FOREIGN KEY (session_id) REFERENCES sessions(id)
+);
diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs
index b7aae51f..e0a92ab4 100644
--- a/crates/atuin-ai/src/commands/inline.rs
+++ b/crates/atuin-ai/src/commands/inline.rs
@@ -175,7 +175,7 @@ async fn run_inline_tui(
.find_resumable(cwd.as_deref(), git_root_str.as_deref(), max_age_secs)
.await?;
- let (mut session_mgr, initial_state) = if let Some(stored) = resumable {
+ let (mut session_mgr, mut initial_state) = if let Some(stored) = resumable {
debug!(session_id = %stored.id, "resuming AI session");
let (mgr, events, server_sid, last_event_ts, invocation_id) =
SessionManager::resume(Box::new(service), &stored).await?;
@@ -199,6 +199,23 @@ async fn run_inline_tui(
session.is_resumed = true;
session.last_event_time =
last_event_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0));
+
+ // Restore file read tracker from session metadata
+ if let Ok(Some(json)) = mgr.get_metadata(crate::file_tracker::METADATA_KEY).await
+ && let Ok(tracker) = crate::file_tracker::FileReadTracker::from_json(&json)
+ {
+ session.file_tracker = tracker;
+ }
+
+ // Restore edit permission grants from session metadata
+ if let Ok(Some(json)) = mgr
+ .get_metadata(crate::edit_permissions::METADATA_KEY)
+ .await
+ && let Ok(cache) = crate::edit_permissions::EditPermissionCache::from_json(&json)
+ {
+ session.edit_permissions = cache;
+ }
+
(mgr, session)
} else {
// No meaningful content — treat as a fresh session
@@ -215,6 +232,16 @@ async fn run_inline_tui(
(mgr, Session::new(ctx.git_root.is_some(), None))
};
+ // Initialize the snapshot store now that we know the session ID.
+ let snapshot_dir = atuin_common::utils::data_dir()
+ .join("ai")
+ .join("snapshots")
+ .join(session_mgr.session_id());
+ match crate::snapshots::SnapshotStore::open(snapshot_dir) {
+ Ok(store) => initial_state.snapshot_store = Some(store),
+ Err(e) => tracing::warn!("failed to open snapshot store: {e}"),
+ }
+
let (tx, rx) = mpsc::channel::<AiTuiEvent>();
println!();
diff --git a/crates/atuin-ai/src/diff.rs b/crates/atuin-ai/src/diff.rs
new file mode 100644
index 00000000..663481c0
--- /dev/null
+++ b/crates/atuin-ai/src/diff.rs
@@ -0,0 +1,294 @@
+//! Structured diff computation for edit previews.
+//!
+//! Computes a line-level diff between old and new file content using
+//! imara-diff's Histogram algorithm, producing structured hunks with
+//! typed lines (Context, Added, Removed) suitable for TUI rendering.
+
+use imara_diff::{Algorithm, Diff, InternedInput};
+
+/// Number of context lines to show around each change.
+const CONTEXT_LINES: u32 = 3;
+
+/// A structured diff preview for a file edit, ready for rendering.
+#[derive(Debug, Clone)]
+pub(crate) struct EditPreview {
+ pub hunks: Vec<DiffHunk>,
+}
+
+/// A contiguous group of diff lines (context + changes).
+#[derive(Debug, Clone)]
+pub(crate) struct DiffHunk {
+ /// 1-indexed line number of the first line in this hunk (in the original file).
+ pub before_start: u32,
+ /// 1-indexed line number of the first line in this hunk (in the new file).
+ pub after_start: u32,
+ pub lines: Vec<DiffLine>,
+}
+
+/// A single line in a diff hunk.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) enum DiffLine {
+ /// Unchanged line (shown for context).
+ Context(String),
+ /// Line added in the new version.
+ Added(String),
+ /// Line removed from the old version.
+ Removed(String),
+}
+
+impl EditPreview {
+ /// Compute a structured diff between old and new file content.
+ ///
+ /// Uses the Histogram algorithm with line-level granularity and
+ /// indentation-aware postprocessing for readable output.
+ pub fn compute(old: &str, new: &str) -> Self {
+ let input = InternedInput::new(old, new);
+ let mut diff = Diff::compute(Algorithm::Histogram, &input);
+ diff.postprocess_lines(&input);
+
+ let raw_hunks: Vec<_> = diff.hunks().collect();
+ if raw_hunks.is_empty() {
+ return EditPreview { hunks: Vec::new() };
+ }
+
+ // Merge hunks that are within 2*CONTEXT_LINES of each other
+ // (same logic as unified diff format).
+ let mut merged_groups: Vec<Vec<&imara_diff::Hunk>> = Vec::new();
+ let mut current_group: Vec<&imara_diff::Hunk> = vec![&raw_hunks[0]];
+
+ for hunk in &raw_hunks[1..] {
+ let prev = current_group.last().unwrap();
+ if hunk.before.start.saturating_sub(prev.before.end) <= 2 * CONTEXT_LINES {
+ current_group.push(hunk);
+ } else {
+ merged_groups.push(current_group);
+ current_group = vec![hunk];
+ }
+ }
+ merged_groups.push(current_group);
+
+ // Build structured hunks from merged groups
+ let hunks = merged_groups
+ .into_iter()
+ .map(|group| build_hunk(&group, &input))
+ .collect();
+
+ EditPreview { hunks }
+ }
+
+ /// The highest line number (from either file) that will be displayed.
+ /// Used to calculate gutter width.
+ pub fn max_line_number(&self) -> u32 {
+ self.hunks
+ .iter()
+ .map(|h| {
+ let mut before_pos = h.before_start;
+ let mut after_pos = h.after_start;
+ for line in &h.lines {
+ match line {
+ DiffLine::Context(_) => {
+ before_pos += 1;
+ after_pos += 1;
+ }
+ DiffLine::Removed(_) => before_pos += 1,
+ DiffLine::Added(_) => after_pos += 1,
+ }
+ }
+ before_pos.max(after_pos).saturating_sub(1)
+ })
+ .max()
+ .unwrap_or(0)
+ }
+}
+
+/// Build a single DiffHunk from a group of adjacent raw hunks.
+fn build_hunk(group: &[&imara_diff::Hunk], input: &InternedInput<&str>) -> DiffHunk {
+ let first = group.first().unwrap();
+ let last = group.last().unwrap();
+
+ let context_start = first.before.start.saturating_sub(CONTEXT_LINES);
+ let context_end = (last.before.end + CONTEXT_LINES).min(input.before.len() as u32);
+
+ // The after-file position of context_start: same offset as before since
+ // context before the first change is identical in both files.
+ let after_context_start = first.after.start - (first.before.start - context_start);
+
+ let mut lines = Vec::new();
+ let mut pos = context_start;
+
+ for hunk in group {
+ // Context lines before this hunk
+ for i in pos..hunk.before.start {
+ lines.push(DiffLine::Context(token_text(input, true, i)));
+ }
+
+ // Removed lines
+ for i in hunk.before.start..hunk.before.end {
+ lines.push(DiffLine::Removed(token_text(input, true, i)));
+ }
+
+ // Added lines
+ for i in hunk.after.start..hunk.after.end {
+ lines.push(DiffLine::Added(token_text(input, false, i)));
+ }
+
+ pos = hunk.before.end;
+ }
+
+ // Trailing context
+ for i in pos..context_end {
+ lines.push(DiffLine::Context(token_text(input, true, i)));
+ }
+
+ DiffHunk {
+ before_start: context_start + 1, // 1-indexed
+ after_start: after_context_start + 1, // 1-indexed
+ lines,
+ }
+}
+
+/// Extract the text content of a token, trimming the trailing newline
+/// that imara-diff includes in line-based tokenization.
+fn token_text(input: &InternedInput<&str>, is_before: bool, idx: u32) -> String {
+ let tokens = if is_before {
+ &input.before
+ } else {
+ &input.after
+ };
+ let text = input.interner[tokens[idx as usize]];
+ text.strip_suffix('\n')
+ .unwrap_or(text)
+ .strip_suffix('\r')
+ .unwrap_or(text.strip_suffix('\n').unwrap_or(text))
+ .to_string()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn no_changes_produces_empty_preview() {
+ let preview = EditPreview::compute("hello\nworld\n", "hello\nworld\n");
+ assert!(preview.hunks.is_empty());
+ }
+
+ #[test]
+ fn single_line_replacement() {
+ let old = "line1\nline2\nline3\n";
+ let new = "line1\nchanged\nline3\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 1);
+ let hunk = &preview.hunks[0];
+
+ // Should have: context(line1), removed(line2), added(changed), context(line3)
+ assert!(hunk.lines.contains(&DiffLine::Context("line1".into())));
+ assert!(hunk.lines.contains(&DiffLine::Removed("line2".into())));
+ assert!(hunk.lines.contains(&DiffLine::Added("changed".into())));
+ assert!(hunk.lines.contains(&DiffLine::Context("line3".into())));
+ }
+
+ #[test]
+ fn addition_only() {
+ let old = "aaa\nbbb\n";
+ let new = "aaa\nnew_line\nbbb\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 1);
+ let hunk = &preview.hunks[0];
+ assert!(hunk.lines.contains(&DiffLine::Added("new_line".into())));
+ // Original lines are context
+ assert!(hunk.lines.contains(&DiffLine::Context("aaa".into())));
+ assert!(hunk.lines.contains(&DiffLine::Context("bbb".into())));
+ }
+
+ #[test]
+ fn removal_only() {
+ let old = "aaa\nremove_me\nbbb\n";
+ let new = "aaa\nbbb\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 1);
+ let hunk = &preview.hunks[0];
+ assert!(hunk.lines.contains(&DiffLine::Removed("remove_me".into())));
+ }
+
+ #[test]
+ fn distant_changes_produce_separate_hunks() {
+ // Two changes separated by more than 2*CONTEXT_LINES (6) lines
+ let old = "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n";
+ let new = "1\nX\n3\n4\n5\n6\n7\n8\n9\n10\n11\nY\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 2);
+ }
+
+ #[test]
+ fn close_changes_merge_into_one_hunk() {
+ // Two changes separated by fewer than 2*CONTEXT_LINES lines
+ let old = "1\n2\n3\n4\n5\n";
+ let new = "X\n2\n3\n4\nY\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 1);
+ }
+
+ #[test]
+ fn context_is_limited() {
+ // With CONTEXT_LINES=3, a change at line 10 shouldn't include line 1
+ let mut old_lines: Vec<&str> = (1..=20).map(|_| "unchanged").collect();
+ old_lines[9] = "target";
+ let old = old_lines.join("\n") + "\n";
+ let new = old.replace("target", "replaced");
+
+ let preview = EditPreview::compute(&old, &new);
+ assert_eq!(preview.hunks.len(), 1);
+
+ // Should have at most 3 context lines before + 3 after + 1 removed + 1 added = 8 lines
+ assert!(preview.hunks[0].lines.len() <= 8);
+ }
+
+ #[test]
+ fn max_line_number_reflects_file_position() {
+ let old = "a\nb\nc\n";
+ let new = "a\nX\nc\n";
+ let preview = EditPreview::compute(old, new);
+ // 3-line file, context + removed lines span positions 1-3
+ assert_eq!(preview.max_line_number(), 3);
+ }
+
+ #[test]
+ fn start_line_is_correct_for_later_changes() {
+ // Change at line 10 with 3 context lines → before_start = 7
+ let mut lines: Vec<String> = (1..=15).map(|i| format!("line{i}")).collect();
+ let old = lines.join("\n") + "\n";
+ lines[9] = "CHANGED".to_string();
+ let new = lines.join("\n") + "\n";
+
+ let preview = EditPreview::compute(&old, &new);
+ assert_eq!(preview.hunks.len(), 1);
+ assert_eq!(preview.hunks[0].before_start, 7); // line 10 - 3 context = line 7
+ assert_eq!(preview.hunks[0].after_start, 7); // same for a simple replacement
+ }
+
+ #[test]
+ fn multiline_replacement() {
+ let old = "[section]\nkey1 = old1\nkey2 = old2\n[other]\n";
+ let new = "[section]\nkey1 = new1\nkey2 = new2\n[other]\n";
+ let preview = EditPreview::compute(old, new);
+
+ assert_eq!(preview.hunks.len(), 1);
+ let hunk = &preview.hunks[0];
+ assert!(
+ hunk.lines
+ .contains(&DiffLine::Removed("key1 = old1".into()))
+ );
+ assert!(
+ hunk.lines
+ .contains(&DiffLine::Removed("key2 = old2".into()))
+ );
+ assert!(hunk.lines.contains(&DiffLine::Added("key1 = new1".into())));
+ assert!(hunk.lines.contains(&DiffLine::Added("key2 = new2".into())));
+ }
+}
diff --git a/crates/atuin-ai/src/edit_permissions.rs b/crates/atuin-ai/src/edit_permissions.rs
new file mode 100644
index 00000000..5015a007
--- /dev/null
+++ b/crates/atuin-ai/src/edit_permissions.rs
@@ -0,0 +1,108 @@
+//! Session-scoped permission cache for file edits.
+//!
+//! When the user selects "Allow this file for this session", the grant is
+//! recorded here with a timestamp. Subsequent edits to the same file skip
+//! the permission prompt as long as the grant hasn't expired.
+//!
+//! Grants are time-limited (1 hour TTL) so they don't outlive the user's
+//! attention in long-running sessions. Persisted as JSON in session
+//! metadata so they survive across CLI invocations.
+
+use std::collections::HashMap;
+use std::path::{Path, PathBuf};
+use std::time::SystemTime;
+
+use eyre::Result;
+use serde::{Deserialize, Serialize};
+
+/// Session metadata key for persistence.
+pub(crate) const METADATA_KEY: &str = "edit_permissions";
+
+/// How long a session-scoped edit permission remains valid.
+const TTL_MS: i64 = 60 * 60 * 1000; // 1 hour
+
+/// Cache of per-file edit permission grants within a session.
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub(crate) struct EditPermissionCache {
+ /// Maps canonical file paths to the grant timestamp (unix millis).
+ grants: HashMap<PathBuf, i64>,
+}
+
+impl EditPermissionCache {
+ /// Record a permission grant for a file.
+ pub fn grant(&mut self, path: PathBuf) {
+ self.grants.insert(path, now_ms());
+ }
+
+ /// Check whether there's a valid (non-expired) grant for a file.
+ pub fn has_valid_grant(&self, path: &Path) -> bool {
+ if let Some(&granted_at) = self.grants.get(path) {
+ (now_ms() - granted_at) < TTL_MS
+ } else {
+ false
+ }
+ }
+
+ pub fn to_json(&self) -> Result<String> {
+ Ok(serde_json::to_string(self)?)
+ }
+
+ pub fn from_json(json: &str) -> Result<Self> {
+ Ok(serde_json::from_str(json)?)
+ }
+}
+
+fn now_ms() -> i64 {
+ SystemTime::now()
+ .duration_since(SystemTime::UNIX_EPOCH)
+ .map(|d| d.as_millis() as i64)
+ .unwrap_or(0)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn grant_and_check() {
+ let mut cache = EditPermissionCache::default();
+ let path = PathBuf::from("/Users/me/.config/foo.toml");
+
+ assert!(!cache.has_valid_grant(&path));
+ cache.grant(path.clone());
+ assert!(cache.has_valid_grant(&path));
+ }
+
+ #[test]
+ fn different_paths_are_independent() {
+ let mut cache = EditPermissionCache::default();
+ let a = PathBuf::from("/etc/hosts");
+ let b = PathBuf::from("/etc/resolv.conf");
+
+ cache.grant(a.clone());
+ assert!(cache.has_valid_grant(&a));
+ assert!(!cache.has_valid_grant(&b));
+ }
+
+ #[test]
+ fn roundtrip_json() {
+ let mut cache = EditPermissionCache::default();
+ cache.grant(PathBuf::from("/some/file.toml"));
+
+ let json = cache.to_json().unwrap();
+ let restored = EditPermissionCache::from_json(&json).unwrap();
+ assert!(restored.has_valid_grant(Path::new("/some/file.toml")));
+ }
+
+ #[test]
+ fn expired_grant_is_invalid() {
+ let mut cache = EditPermissionCache::default();
+ let path = PathBuf::from("/expired/file.toml");
+
+ // Insert a grant from 2 hours ago
+ let two_hours_ago = now_ms() - (2 * 60 * 60 * 1000);
+ cache.grants.insert(path.clone(), two_hours_ago);
+
+ assert!(!cache.has_valid_grant(&path));
+ }
+}
diff --git a/crates/atuin-ai/src/file_tracker.rs b/crates/atuin-ai/src/file_tracker.rs
new file mode 100644
index 00000000..feee1ee8
--- /dev/null
+++ b/crates/atuin-ai/src/file_tracker.rs
@@ -0,0 +1,234 @@
+//! Tracks which files have been read in the current session, for freshness
+//! checking before edits.
+//!
+//! The tracker records the content hash and mtime of each file at the time
+//! it was last read. Before an edit, the tracker verifies the file hasn't
+//! changed since the last read — catching both external modifications and
+//! concurrent tool calls.
+//!
+//! Persisted as JSON in session metadata so it survives across CLI
+//! invocations within the same logical session.
+
+use std::collections::HashMap;
+use std::path::{Path, PathBuf};
+use std::time::SystemTime;
+
+use eyre::Result;
+use serde::{Deserialize, Serialize};
+
+/// Metadata key used for session_metadata persistence.
+pub(crate) const METADATA_KEY: &str = "file_read_tracker";
+
+/// State recorded for a single file read.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct FileReadState {
+ /// Hash of the file contents at the time of the last read.
+ pub content_hash: u64,
+ /// File mtime (as milliseconds since epoch) at the time of the last read.
+ /// Millisecond precision ensures sub-second modifications are detected.
+ pub mtime_ms: i64,
+}
+
+/// Tracks file read state for freshness checking.
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub(crate) struct FileReadTracker {
+ reads: HashMap<PathBuf, FileReadState>,
+}
+
+/// Result of a freshness check.
+pub(crate) enum FreshnessCheck {
+ /// File is fresh — the content hasn't changed since the last read.
+ Fresh,
+ /// File has never been read in this session.
+ NotRead,
+ /// File has been modified since the last read.
+ Stale,
+}
+
+impl FileReadTracker {
+ /// Record that a file was read. Call this after a successful `read_file`
+ /// execution. The `path` should be canonical (absolute, tilde-expanded).
+ pub fn record_read(&mut self, path: PathBuf, content: &[u8], mtime: SystemTime) {
+ let content_hash = hash_content(content);
+ let mtime_ms = system_time_to_ms(mtime);
+
+ self.reads.insert(
+ path,
+ FileReadState {
+ content_hash,
+ mtime_ms,
+ },
+ );
+ }
+
+ /// Check whether a file is fresh (unchanged since last read).
+ ///
+ /// Uses mtime as a fast path — only re-hashes if mtime differs.
+ pub fn check_freshness(&self, path: &Path) -> Result<FreshnessCheck> {
+ let state = match self.reads.get(path) {
+ Some(s) => s,
+ None => return Ok(FreshnessCheck::NotRead),
+ };
+
+ // Stat the file
+ let metadata = match std::fs::metadata(path) {
+ Ok(m) => m,
+ Err(_) => return Ok(FreshnessCheck::Stale), // file deleted or inaccessible
+ };
+
+ let current_mtime_ms =
+ system_time_to_ms(metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH));
+
+ // Fast path: mtime unchanged → fresh
+ if current_mtime_ms == state.mtime_ms {
+ return Ok(FreshnessCheck::Fresh);
+ }
+
+ // Mtime changed — re-hash to confirm
+ let content = std::fs::read(path)?;
+ let current_hash = hash_content(&content);
+
+ if current_hash == state.content_hash {
+ Ok(FreshnessCheck::Fresh)
+ } else {
+ Ok(FreshnessCheck::Stale)
+ }
+ }
+
+ /// Update the tracker entry after a successful edit (new content written).
+ pub fn update_after_edit(&mut self, path: &Path, new_content: &[u8], new_mtime: SystemTime) {
+ let content_hash = hash_content(new_content);
+ let mtime_ms = system_time_to_ms(new_mtime);
+
+ self.reads.insert(
+ path.to_path_buf(),
+ FileReadState {
+ content_hash,
+ mtime_ms,
+ },
+ );
+ }
+
+ /// Serialize to JSON for session metadata persistence.
+ pub fn to_json(&self) -> Result<String> {
+ Ok(serde_json::to_string(self)?)
+ }
+
+ /// Deserialize from JSON session metadata.
+ pub fn from_json(json: &str) -> Result<Self> {
+ Ok(serde_json::from_str(json)?)
+ }
+}
+
+fn system_time_to_ms(t: SystemTime) -> i64 {
+ t.duration_since(SystemTime::UNIX_EPOCH)
+ .map(|d| d.as_millis() as i64)
+ .unwrap_or(0)
+}
+
+fn hash_content(content: &[u8]) -> u64 {
+ xxhash_rust::xxh3::xxh3_64(content)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::io::Write;
+ use tempfile::NamedTempFile;
+
+ #[test]
+ fn record_and_check_fresh() {
+ let mut tracker = FileReadTracker::default();
+ let mut tmp = NamedTempFile::new().unwrap();
+ write!(tmp, "hello world").unwrap();
+
+ let path = tmp.path().to_path_buf();
+ let content = std::fs::read(&path).unwrap();
+ let mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+
+ tracker.record_read(path.clone(), &content, mtime);
+
+ assert!(matches!(
+ tracker.check_freshness(&path).unwrap(),
+ FreshnessCheck::Fresh
+ ));
+ }
+
+ #[test]
+ fn check_not_read() {
+ let tracker = FileReadTracker::default();
+ let path = PathBuf::from("/nonexistent/file.txt");
+ assert!(matches!(
+ tracker.check_freshness(&path).unwrap(),
+ FreshnessCheck::NotRead
+ ));
+ }
+
+ #[test]
+ fn check_stale_after_modification() {
+ let mut tracker = FileReadTracker::default();
+ let mut tmp = NamedTempFile::new().unwrap();
+ write!(tmp, "original").unwrap();
+
+ let path = tmp.path().to_path_buf();
+ let content = std::fs::read(&path).unwrap();
+ let mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+
+ tracker.record_read(path.clone(), &content, mtime);
+
+ // Small delay to ensure the filesystem mtime advances
+ std::thread::sleep(std::time::Duration::from_millis(10));
+
+ // Modify the file
+ std::fs::write(&path, "modified").unwrap();
+
+ assert!(matches!(
+ tracker.check_freshness(&path).unwrap(),
+ FreshnessCheck::Stale
+ ));
+ }
+
+ #[test]
+ fn update_after_edit_makes_fresh() {
+ let mut tracker = FileReadTracker::default();
+ let mut tmp = NamedTempFile::new().unwrap();
+ write!(tmp, "original").unwrap();
+
+ let path = tmp.path().to_path_buf();
+ let content = std::fs::read(&path).unwrap();
+ let mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+
+ tracker.record_read(path.clone(), &content, mtime);
+
+ // Simulate an edit
+ let new_content = b"edited content";
+ std::fs::write(&path, new_content).unwrap();
+ let new_mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+ tracker.update_after_edit(&path, new_content, new_mtime);
+
+ assert!(matches!(
+ tracker.check_freshness(&path).unwrap(),
+ FreshnessCheck::Fresh
+ ));
+ }
+
+ #[test]
+ fn roundtrip_json() {
+ let mut tracker = FileReadTracker::default();
+ tracker.reads.insert(
+ PathBuf::from("/some/file.toml"),
+ FileReadState {
+ content_hash: 12345,
+ mtime_ms: 1700000000000,
+ },
+ );
+
+ let json = tracker.to_json().unwrap();
+ let restored = FileReadTracker::from_json(&json).unwrap();
+ assert_eq!(restored.reads.len(), 1);
+ assert_eq!(
+ restored.reads[&PathBuf::from("/some/file.toml")].content_hash,
+ 12345
+ );
+ }
+}
diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs
index febb488e..afe9c1e4 100644
--- a/crates/atuin-ai/src/lib.rs
+++ b/crates/atuin-ai/src/lib.rs
@@ -1,9 +1,13 @@
pub mod commands;
pub(crate) mod context;
pub(crate) mod context_window;
+pub(crate) mod diff;
+pub(crate) mod edit_permissions;
pub(crate) mod event_serde;
+pub(crate) mod file_tracker;
pub(crate) mod permissions;
pub(crate) mod session;
+pub(crate) mod snapshots;
pub(crate) mod store;
pub(crate) mod stream;
pub(crate) mod tools;
diff --git a/crates/atuin-ai/src/session.rs b/crates/atuin-ai/src/session.rs
index d8314343..848330fc 100644
--- a/crates/atuin-ai/src/session.rs
+++ b/crates/atuin-ai/src/session.rs
@@ -51,6 +51,9 @@ pub(crate) trait SessionService: Send + Sync {
) -> Result<()>;
async fn archive(&self, session_id: &str) -> Result<()>;
+
+ async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>>;
+ async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()>;
}
// ---------------------------------------------------------------------------
@@ -128,6 +131,14 @@ impl SessionService for LocalSessionService {
async fn archive(&self, session_id: &str) -> Result<()> {
self.store.archive_session(session_id).await
}
+
+ async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> {
+ self.store.get_metadata(session_id, key).await
+ }
+
+ async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> {
+ self.store.set_metadata(session_id, key, value).await
+ }
}
// ---------------------------------------------------------------------------
@@ -310,6 +321,22 @@ impl SessionManager {
pub fn invocation_id(&self) -> &str {
&self.invocation_id
}
+
+ /// Read a metadata value for the current session.
+ pub async fn get_metadata(&self, key: &str) -> Result<Option<String>> {
+ if !self.persisted_to_db {
+ return Ok(None);
+ }
+ self.service.get_metadata(&self.session_id, key).await
+ }
+
+ /// Write a metadata value for the current session.
+ pub async fn set_metadata(&mut self, key: &str, value: &str) -> Result<()> {
+ self.ensure_persisted().await?;
+ self.service
+ .set_metadata(&self.session_id, key, value)
+ .await
+ }
}
#[cfg(test)]
diff --git a/crates/atuin-ai/src/snapshots.rs b/crates/atuin-ai/src/snapshots.rs
new file mode 100644
index 00000000..6c7b0c9c
--- /dev/null
+++ b/crates/atuin-ai/src/snapshots.rs
@@ -0,0 +1,414 @@
+//! Backup snapshots for files before AI edits.
+//!
+//! Before the first edit to a file within a session, a snapshot of the
+//! original content is saved so the user can recover if needed. Snapshots
+//! are stored alongside a manifest that maps sanitized filenames back to
+//! their original paths.
+//!
+//! Filenames use percent-encoding (`/` → `%2F`) so the snapshot directory
+//! is human-readable via `ls`.
+
+use std::collections::HashMap;
+use std::io::Write;
+use std::path::{Path, PathBuf};
+
+use eyre::{Result, eyre};
+use serde::{Deserialize, Serialize};
+use time::OffsetDateTime;
+
+/// Snapshot store for a single session.
+///
+/// Each session gets its own directory under the snapshots root:
+/// `<data_dir>/ai/snapshots/<session_id>/`
+///
+/// Files are stored with percent-encoded filenames derived from their
+/// canonical paths, alongside a `manifest.json` that maps filenames
+/// back to original paths with timestamps.
+#[derive(Debug)]
+pub(crate) struct SnapshotStore {
+ session_dir: PathBuf,
+ manifest: SnapshotManifest,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+struct SnapshotManifest {
+ files: HashMap<String, SnapshotEntry>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct SnapshotEntry {
+ original_path: String,
+ snapshot_at: String,
+ size_bytes: u64,
+}
+
+impl SnapshotStore {
+ /// Open or create a snapshot store for the given session directory.
+ ///
+ /// If a manifest already exists (from a prior CLI invocation in the same
+ /// session), it's loaded so we don't re-snapshot files that were already
+ /// backed up.
+ pub fn open(session_dir: PathBuf) -> Result<Self> {
+ let manifest_path = session_dir.join("manifest.json");
+ let manifest = if manifest_path.exists() {
+ let data = fs_err::read_to_string(&manifest_path)?;
+ serde_json::from_str(&data)?
+ } else {
+ SnapshotManifest::default()
+ };
+
+ Ok(Self {
+ session_dir,
+ manifest,
+ })
+ }
+
+ /// Snapshot a file's contents if it hasn't been snapshotted yet this session.
+ ///
+ /// Returns `true` if a new snapshot was created, `false` if one already
+ /// existed. The `canonical_path` should be absolute (already tilde-expanded
+ /// and resolved).
+ pub fn ensure_snapshot(&mut self, canonical_path: &Path, content: &[u8]) -> Result<bool> {
+ let filename = sanitize_path(canonical_path);
+
+ if self.manifest.files.contains_key(&filename) {
+ return Ok(false);
+ }
+
+ fs_err::create_dir_all(&self.session_dir)?;
+
+ let snapshot_path = self.session_dir.join(&filename);
+ atomic_write_file(&snapshot_path, content)?;
+
+ let now = OffsetDateTime::now_utc();
+ let entry = SnapshotEntry {
+ original_path: canonical_path.to_string_lossy().into_owned(),
+ snapshot_at: format_iso8601(now),
+ size_bytes: content.len() as u64,
+ };
+
+ self.manifest.files.insert(filename, entry);
+ self.save_manifest()?;
+
+ Ok(true)
+ }
+
+ /// Whether a file has already been snapshotted in this session.
+ #[expect(dead_code)]
+ pub fn has_snapshot(&self, canonical_path: &Path) -> bool {
+ let filename = sanitize_path(canonical_path);
+ self.manifest.files.contains_key(&filename)
+ }
+
+ fn save_manifest(&self) -> Result<()> {
+ let json = serde_json::to_string_pretty(&self.manifest)?;
+ atomic_write_file(&self.session_dir.join("manifest.json"), json.as_bytes())
+ }
+}
+
+/// Percent-encode a path for use as a filename.
+///
+/// Encodes `%` as `%25`, `/` as `%2F`, and `\` as `%5C`, then strips
+/// leading separators and drive prefixes (e.g. `C:\`). The result is
+/// always a flat filename safe for use with `Path::join` on any platform.
+///
+/// Example (Unix): `/Users/me/.config/foo.toml` → `Users%2Fme%2F.config%2Ffoo.toml`
+/// Example (Windows): `C:\Users\me\config.toml` → `Users%5Cme%5Cconfig.toml`
+pub(crate) fn sanitize_path(path: &Path) -> String {
+ let s = path.to_string_lossy();
+ // Strip drive letter prefix on Windows (e.g. "C:\")
+ let s = s.strip_prefix('/').unwrap_or_else(|| {
+ // Handle Windows drive prefix like "C:\" or "C:/"
+ if s.len() >= 3
+ && s.as_bytes()[0].is_ascii_alphabetic()
+ && s.as_bytes()[1] == b':'
+ && (s.as_bytes()[2] == b'\\' || s.as_bytes()[2] == b'/')
+ {
+ &s[3..]
+ } else {
+ &s
+ }
+ });
+ s.replace('%', "%25")
+ .replace('/', "%2F")
+ .replace('\\', "%5C")
+}
+
+/// Write a file atomically using temp-file-then-rename.
+///
+/// Creates a temporary file in the same directory as `target`, writes
+/// content, fsyncs, then renames into place. Preserves permissions from
+/// the original file if it exists.
+pub(crate) fn atomic_write_file(target: &Path, content: &[u8]) -> Result<()> {
+ let dir = target
+ .parent()
+ .ok_or_else(|| eyre!("target path has no parent directory"))?;
+ fs_err::create_dir_all(dir)?;
+
+ let mut tmp = tempfile::NamedTempFile::new_in(dir)?;
+ tmp.write_all(content)?;
+ tmp.as_file().sync_all()?;
+
+ // Preserve permissions from original if it exists
+ if let Ok(meta) = std::fs::metadata(target) {
+ std::fs::set_permissions(tmp.path(), meta.permissions())?;
+ }
+
+ tmp.persist(target).map_err(|e| {
+ eyre!(
+ "failed to persist atomic write to {}: {}",
+ target.display(),
+ e
+ )
+ })?;
+ Ok(())
+}
+
+fn format_iso8601(dt: OffsetDateTime) -> String {
+ format!(
+ "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
+ dt.year(),
+ dt.month() as u8,
+ dt.day(),
+ dt.hour(),
+ dt.minute(),
+ dt.second(),
+ )
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ // ── sanitize_path ──────────────────────────────────────────
+
+ #[test]
+ fn sanitize_absolute_path() {
+ let path = Path::new("/Users/me/.config/atuin/config.toml");
+ assert_eq!(
+ sanitize_path(path),
+ "Users%2Fme%2F.config%2Fatuin%2Fconfig.toml"
+ );
+ }
+
+ #[test]
+ fn sanitize_preserves_existing_percent() {
+ let path = Path::new("/data/100%done/file.txt");
+ assert_eq!(sanitize_path(path), "data%2F100%25done%2Ffile.txt");
+ }
+
+ #[test]
+ fn sanitize_relative_path() {
+ let path = Path::new("relative/path.txt");
+ assert_eq!(sanitize_path(path), "relative%2Fpath.txt");
+ }
+
+ #[test]
+ fn sanitize_no_collision_between_similar_paths() {
+ let a = sanitize_path(Path::new("/foo/bar-baz"));
+ let b = sanitize_path(Path::new("/foo/bar/baz"));
+ assert_ne!(a, b);
+ }
+
+ #[test]
+ fn sanitize_backslash_encoded() {
+ // Windows-style path: backslashes become %5C, drive prefix stripped
+ let s = sanitize_path(Path::new("C:\\Users\\me\\config.toml"));
+ assert!(!s.contains('\\'), "backslashes must be encoded: {s}");
+ assert!(!s.starts_with("C:"), "drive prefix must be stripped: {s}");
+ assert!(s.contains("Users"));
+ assert!(s.contains("config.toml"));
+ }
+
+ #[test]
+ fn sanitize_result_is_flat_filename() {
+ // The result must not be interpreted as a path with separators
+ // when passed to Path::join — no raw / or \ allowed.
+ let unix = sanitize_path(Path::new("/home/user/file.txt"));
+ assert!(!unix.contains('/'));
+ // Construct as if on Windows
+ let win = "C:\\Users\\me\\file.txt";
+ let encoded = win
+ .strip_prefix("C:\\")
+ .unwrap()
+ .replace('%', "%25")
+ .replace('/', "%2F")
+ .replace('\\', "%5C");
+ assert!(!encoded.contains('\\'));
+ assert!(!encoded.contains('/'));
+ }
+
+ // ── atomic_write_file ──────────────────────────────────────
+
+ #[test]
+ fn atomic_write_creates_file() {
+ let dir = tempfile::tempdir().unwrap();
+ let target = dir.path().join("test.txt");
+
+ atomic_write_file(&target, b"hello world").unwrap();
+
+ assert_eq!(std::fs::read_to_string(&target).unwrap(), "hello world");
+ }
+
+ #[test]
+ fn atomic_write_overwrites_existing() {
+ let dir = tempfile::tempdir().unwrap();
+ let target = dir.path().join("test.txt");
+
+ std::fs::write(&target, "old content").unwrap();
+ atomic_write_file(&target, b"new content").unwrap();
+
+ assert_eq!(std::fs::read_to_string(&target).unwrap(), "new content");
+ }
+
+ #[test]
+ fn atomic_write_creates_parent_dirs() {
+ let dir = tempfile::tempdir().unwrap();
+ let target = dir.path().join("sub").join("dir").join("test.txt");
+
+ atomic_write_file(&target, b"nested").unwrap();
+
+ assert_eq!(std::fs::read_to_string(&target).unwrap(), "nested");
+ }
+
+ #[cfg(unix)]
+ #[test]
+ fn atomic_write_preserves_permissions() {
+ use std::os::unix::fs::PermissionsExt;
+
+ let dir = tempfile::tempdir().unwrap();
+ let target = dir.path().join("test.txt");
+
+ std::fs::write(&target, "original").unwrap();
+ std::fs::set_permissions(&target, std::fs::Permissions::from_mode(0o600)).unwrap();
+
+ atomic_write_file(&target, b"updated").unwrap();
+
+ let mode = std::fs::metadata(&target).unwrap().permissions().mode() & 0o777;
+ assert_eq!(mode, 0o600);
+ }
+
+ // ── SnapshotStore ──────────────────────────────────────────
+
+ #[test]
+ fn snapshot_creates_file_and_manifest() {
+ let dir = tempfile::tempdir().unwrap();
+ let session_dir = dir.path().join("session-abc");
+ let mut store = SnapshotStore::open(session_dir.clone()).unwrap();
+
+ let file_path = Path::new("/Users/me/.config/foo.toml");
+ let created = store
+ .ensure_snapshot(file_path, b"[key]\nval = 1\n")
+ .unwrap();
+
+ assert!(created);
+ assert!(store.has_snapshot(file_path));
+
+ // Snapshot file on disk
+ let expected_file = session_dir.join("Users%2Fme%2F.config%2Ffoo.toml");
+ assert!(expected_file.exists());
+ assert_eq!(
+ std::fs::read_to_string(&expected_file).unwrap(),
+ "[key]\nval = 1\n"
+ );
+
+ // Manifest on disk
+ let manifest_path = session_dir.join("manifest.json");
+ assert!(manifest_path.exists());
+ let manifest: serde_json::Value =
+ serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap();
+ let files = manifest["files"].as_object().unwrap();
+ assert_eq!(files.len(), 1);
+ let entry = &files["Users%2Fme%2F.config%2Ffoo.toml"];
+ assert_eq!(
+ entry["original_path"].as_str().unwrap(),
+ "/Users/me/.config/foo.toml"
+ );
+ assert_eq!(entry["size_bytes"].as_u64().unwrap(), 14);
+ }
+
+ #[test]
+ fn snapshot_is_idempotent() {
+ let dir = tempfile::tempdir().unwrap();
+ let session_dir = dir.path().join("session-abc");
+ let mut store = SnapshotStore::open(session_dir.clone()).unwrap();
+
+ let path = Path::new("/etc/hosts");
+ let first = store.ensure_snapshot(path, b"first content").unwrap();
+ let second = store.ensure_snapshot(path, b"different content").unwrap();
+
+ assert!(first);
+ assert!(!second);
+
+ // Original content preserved, not overwritten
+ let snapshot_file = session_dir.join("etc%2Fhosts");
+ assert_eq!(
+ std::fs::read_to_string(snapshot_file).unwrap(),
+ "first content"
+ );
+ }
+
+ #[test]
+ fn snapshot_store_loads_existing_manifest() {
+ let dir = tempfile::tempdir().unwrap();
+ let session_dir = dir.path().join("session-abc");
+
+ // First store: create a snapshot
+ {
+ let mut store = SnapshotStore::open(session_dir.clone()).unwrap();
+ store
+ .ensure_snapshot(Path::new("/etc/hosts"), b"127.0.0.1")
+ .unwrap();
+ }
+
+ // Second store (simulates new CLI invocation): should see existing snapshot
+ {
+ let mut store = SnapshotStore::open(session_dir).unwrap();
+ assert!(store.has_snapshot(Path::new("/etc/hosts")));
+
+ let created = store
+ .ensure_snapshot(Path::new("/etc/hosts"), b"new content")
+ .unwrap();
+ assert!(!created);
+ }
+ }
+
+ #[test]
+ fn snapshot_multiple_files() {
+ let dir = tempfile::tempdir().unwrap();
+ let session_dir = dir.path().join("session-abc");
+ let mut store = SnapshotStore::open(session_dir.clone()).unwrap();
+
+ store
+ .ensure_snapshot(Path::new("/etc/hosts"), b"hosts content")
+ .unwrap();
+ store
+ .ensure_snapshot(Path::new("/Users/me/.bashrc"), b"bashrc content")
+ .unwrap();
+
+ assert!(store.has_snapshot(Path::new("/etc/hosts")));
+ assert!(store.has_snapshot(Path::new("/Users/me/.bashrc")));
+ assert!(!store.has_snapshot(Path::new("/nonexistent")));
+
+ // Both snapshot files exist
+ assert!(session_dir.join("etc%2Fhosts").exists());
+ assert!(session_dir.join("Users%2Fme%2F.bashrc").exists());
+
+ // Manifest has both entries
+ let manifest: serde_json::Value = serde_json::from_str(
+ &std::fs::read_to_string(session_dir.join("manifest.json")).unwrap(),
+ )
+ .unwrap();
+ assert_eq!(manifest["files"].as_object().unwrap().len(), 2);
+ }
+
+ #[test]
+ fn format_iso8601_produces_valid_format() {
+ let dt = OffsetDateTime::from_unix_timestamp(1700000000).unwrap();
+ let formatted = format_iso8601(dt);
+ assert_eq!(formatted.len(), 20);
+ assert!(formatted.starts_with("2023-"));
+ assert!(formatted.contains('T'));
+ assert!(formatted.ends_with('Z'));
+ }
+}
diff --git a/crates/atuin-ai/src/store.rs b/crates/atuin-ai/src/store.rs
index 2a75d8f4..20b9e881 100644
--- a/crates/atuin-ai/src/store.rs
+++ b/crates/atuin-ai/src/store.rs
@@ -299,6 +299,38 @@ impl AiSessionStore {
.await?;
Ok(())
}
+
+ // ── Session metadata (key-value per session) ──
+
+ /// Read a metadata value for a session. Returns `None` if the key doesn't
+ /// exist or the session hasn't been persisted yet.
+ pub async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> {
+ let row: Option<(String,)> =
+ sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ?1 AND key = ?2")
+ .bind(session_id)
+ .bind(key)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ Ok(row.map(|(v,)| v))
+ }
+
+ /// Write a metadata value for a session (upsert).
+ pub async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> {
+ let now = OffsetDateTime::now_utc().unix_timestamp();
+ sqlx::query(
+ "INSERT INTO session_metadata (session_id, key, value, updated_at)
+ VALUES (?1, ?2, ?3, ?4)
+ ON CONFLICT (session_id, key) DO UPDATE SET value = ?3, updated_at = ?4",
+ )
+ .bind(session_id)
+ .bind(key)
+ .bind(value)
+ .bind(now)
+ .execute(&self.pool)
+ .await?;
+ Ok(())
+ }
}
#[cfg(test)]
diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs
index f4f4d704..24770abe 100644
--- a/crates/atuin-ai/src/stream.rs
+++ b/crates/atuin-ai/src/stream.rs
@@ -74,6 +74,14 @@ impl ChatRequest {
if capabilities.enable_history_search.unwrap_or(true) {
caps.push("client_v1_atuin_history".to_string());
}
+ if capabilities.enable_file_tools.unwrap_or(true) {
+ caps.push("client_v1_read_file".to_string());
+ caps.push("client_v1_edit_file".to_string());
+ caps.push("client_v1_write_file".to_string());
+ }
+ if capabilities.enable_command_execution.unwrap_or(true) {
+ caps.push("client_v1_execute_shell_command".to_string());
+ }
if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") {
caps.extend(
extra
diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs
index fc44ec10..6ccb595f 100644
--- a/crates/atuin-ai/src/tools/descriptor.rs
+++ b/crates/atuin-ai/src/tools/descriptor.rs
@@ -16,6 +16,7 @@ pub(crate) struct ToolDescriptor {
/// Past-tense verb for summaries (e.g. "Read file").
pub past_verb: &'static str,
/// Whether this tool is executed client-side (by the CLI).
+ #[expect(dead_code)]
pub is_client: bool,
}
@@ -30,9 +31,18 @@ pub(crate) const READ: &ToolDescriptor = &ToolDescriptor {
is_client: true,
};
+pub(crate) const EDIT: &ToolDescriptor = &ToolDescriptor {
+ canonical_names: &["edit_file"],
+ capability: Some("client_v1_edit_file"),
+ display_verb: "edit",
+ progressive_verb: "Editing file...",
+ past_verb: "Edited file",
+ is_client: true,
+};
+
pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor {
- canonical_names: &["str_replace", "file_create", "file_insert"],
- capability: Some("client_v1_write"),
+ canonical_names: &["write_file"],
+ capability: Some("client_v1_write_file"),
display_verb: "write to",
progressive_verb: "Writing file...",
past_verb: "Wrote file",
@@ -41,7 +51,7 @@ pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor {
pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor {
canonical_names: &["execute_shell_command"],
- capability: Some("client_v1_shell"),
+ capability: Some("client_v1_execute_shell_command"),
display_verb: "run",
progressive_verb: "Running command...",
past_verb: "Ran command",
@@ -81,6 +91,7 @@ pub(crate) const SERVER_SCRAPE: &ToolDescriptor = &ToolDescriptor {
/// All known tool descriptors, for lookup by name.
const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[
READ,
+ EDIT,
WRITE,
SHELL,
ATUIN_HISTORY,
diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs
index 8f2183b7..8fe1ad73 100644
--- a/crates/atuin-ai/src/tools/mod.rs
+++ b/crates/atuin-ai/src/tools/mod.rs
@@ -169,6 +169,8 @@ pub(crate) struct TrackedTool {
pub phase: ToolPhase,
/// Sender to interrupt a running shell command (only set during ExecutingWithPreview).
pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>,
+ /// Diff preview for completed edit tool calls.
+ pub edit_preview: Option<crate::diff::EditPreview>,
}
impl TrackedTool {
@@ -234,6 +236,7 @@ impl ToolTracker {
tool,
phase: ToolPhase::CheckingPermissions,
abort_tx: None,
+ edit_preview: None,
});
}
@@ -294,11 +297,6 @@ impl ToolTracker {
.find(|t| t.phase == ToolPhase::AskingForPermission)
}
- /// Get the preview for a tool by ID (live or cached).
- pub fn preview_for(&self, id: &str) -> Option<ToolPreview> {
- self.get(id)?.preview()
- }
-
/// Iterate mutably over all tracked tools.
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut TrackedTool> {
self.tools.iter_mut()
@@ -309,6 +307,7 @@ impl ToolTracker {
#[derive(Debug, Clone)]
pub(crate) enum ClientToolCall {
Read(ReadToolCall),
+ Edit(EditToolCall),
Write(WriteToolCall),
Shell(ShellToolCall),
AtuinHistory(AtuinHistoryToolCall),
@@ -320,9 +319,8 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall {
fn try_from((name, input): (&str, &serde_json::Value)) -> Result<Self, Self::Error> {
match name {
"read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)),
- "create_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)),
- // "append_to_file" => Ok(ClientToolCall::Append(AppendToolCall::try_from(input)?)),
- // "str_replace" => Ok(ClientToolCall::StrReplace(StrReplaceToolCall::try_from(input)?)),
+ "edit_file" => Ok(ClientToolCall::Edit(EditToolCall::try_from(input)?)),
+ "write_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)),
"execute_shell_command" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)),
"atuin_history" => Ok(ClientToolCall::AtuinHistory(
AtuinHistoryToolCall::try_from(input)?,
@@ -336,17 +334,22 @@ impl ClientToolCall {
pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor {
match self {
ClientToolCall::Read(_) => descriptor::READ,
+ ClientToolCall::Edit(_) => descriptor::EDIT,
ClientToolCall::Write(_) => descriptor::WRITE,
ClientToolCall::Shell(_) => descriptor::SHELL,
ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY,
}
}
- /// The permission rule name for this tool category (e.g. "Write" covers
- /// str_replace, file_create, file_insert).
+ /// The permission rule name for this tool category.
+ ///
+ /// Edit and Write share the `"Write"` rule name — a Write permission
+ /// covers both str_replace edits and full file creates. Write also
+ /// implies Read (checked in `ReadToolCall::matches_rule`).
pub(crate) fn rule_name(&self) -> &'static str {
match self {
ClientToolCall::Read(_) => "Read",
+ ClientToolCall::Edit(_) => "Write",
ClientToolCall::Write(_) => "Write",
ClientToolCall::Shell(_) => "Shell",
ClientToolCall::AtuinHistory(_) => "AtuinHistory",
@@ -356,6 +359,7 @@ impl ClientToolCall {
pub(crate) fn matches_rule(&self, rule: &Rule) -> bool {
match self {
ClientToolCall::Read(tool) => tool.matches_rule(rule),
+ ClientToolCall::Edit(tool) => tool.matches_rule(rule),
ClientToolCall::Write(tool) => tool.matches_rule(rule),
ClientToolCall::Shell(tool) => tool.matches_rule(rule),
ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule),
@@ -365,6 +369,7 @@ impl ClientToolCall {
pub(crate) fn target_dir(&self) -> Option<&Path> {
match self {
ClientToolCall::Read(tool) => tool.target_dir(),
+ ClientToolCall::Edit(tool) => tool.target_dir(),
ClientToolCall::Write(tool) => tool.target_dir(),
ClientToolCall::Shell(tool) => tool.target_dir(),
ClientToolCall::AtuinHistory(tool) => tool.target_dir(),
@@ -401,6 +406,14 @@ impl PermissableToolCall for ClientToolCall {
}
}
+/// Expand shell constructs (`~`, `$HOME`, etc.) in a path string.
+///
+/// Tool call paths arrive as raw strings from the API without shell
+/// expansion. Uses `shellexpand` (same as `atuin-client`).
+fn expand_path(path: &str) -> PathBuf {
+ PathBuf::from(shellexpand::tilde(path).into_owned())
+}
+
#[derive(Debug, Clone)]
pub(crate) struct ReadToolCall {
pub path: PathBuf,
@@ -425,7 +438,7 @@ impl TryFrom<&serde_json::Value> for ReadToolCall {
.min(MAX_FILE_READ_LINES);
Ok(ReadToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
offset,
limit,
})
@@ -499,7 +512,207 @@ impl PermissableToolCall for ReadToolCall {
}
fn matches_rule(&self, rule: &Rule) -> bool {
- if rule.tool != "Read" {
+ // Write implies Read — a Write permission on a path also permits reading it.
+ if rule.tool != "Read" && rule.tool != "Write" {
+ return false;
+ }
+
+ match rule.scope.as_deref() {
+ None | Some("*") => true,
+ Some(scope) => path_matches_scope(&self.path, scope),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct EditToolCall {
+ pub path: PathBuf,
+ pub old_string: String,
+ pub new_string: String,
+ pub replace_all: bool,
+}
+
+impl TryFrom<&serde_json::Value> for EditToolCall {
+ type Error = eyre::Error;
+
+ fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> {
+ let path = value
+ .get("file_path")
+ .and_then(|v| v.as_str())
+ .ok_or(eyre::eyre!("Missing file_path"))?;
+
+ let old_string = value
+ .get("old_string")
+ .and_then(|v| v.as_str())
+ .ok_or(eyre::eyre!("Missing old_string"))?;
+
+ let new_string = value
+ .get("new_string")
+ .and_then(|v| v.as_str())
+ .ok_or(eyre::eyre!("Missing new_string"))?;
+
+ let replace_all = value
+ .get("replace_all")
+ .and_then(|v| v.as_bool())
+ .unwrap_or(false);
+
+ Ok(EditToolCall {
+ path: expand_path(path),
+ old_string: old_string.to_string(),
+ new_string: new_string.to_string(),
+ replace_all,
+ })
+ }
+}
+
+impl EditToolCall {
+ /// Resolve the edit path to an absolute path.
+ pub fn resolved_path(&self) -> PathBuf {
+ if self.path.is_relative() {
+ std::env::current_dir()
+ .map(|cwd| cwd.join(&self.path))
+ .unwrap_or_else(|_| self.path.clone())
+ } else {
+ self.path.clone()
+ }
+ }
+
+ /// Execute the edit against the filesystem.
+ ///
+ /// Checks freshness via the provided tracker, validates matches, applies
+ /// the replacement, and writes atomically. Returns the outcome and (on
+ /// success) the new file content bytes for tracker updates.
+ ///
+ /// Callers should snapshot the file before calling this method and
+ /// update the file tracker after a successful return.
+ pub fn execute(
+ &self,
+ resolved_path: &Path,
+ file_tracker: &crate::file_tracker::FileReadTracker,
+ ) -> (ToolOutcome, Option<Vec<u8>>) {
+ use crate::file_tracker::FreshnessCheck;
+
+ // 1. Basic validation
+ if !resolved_path.exists() {
+ return (
+ ToolOutcome::Error(format!(
+ "Error: file does not exist: {}",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+ if resolved_path.is_dir() {
+ return (
+ ToolOutcome::Error(format!(
+ "Error: path is a directory, not a file: {}",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+ if self.old_string.is_empty() {
+ return (
+ ToolOutcome::Error(
+ "old_string must not be empty. To create a new file, use write_file instead."
+ .to_string(),
+ ),
+ None,
+ );
+ }
+
+ // 2. Freshness check
+ match file_tracker.check_freshness(resolved_path) {
+ Ok(FreshnessCheck::NotRead) => {
+ return (
+ ToolOutcome::Error(
+ "File has not been read yet. Read it first before editing.".to_string(),
+ ),
+ None,
+ );
+ }
+ Ok(FreshnessCheck::Stale) => {
+ return (
+ ToolOutcome::Error(
+ "File has been modified since read, either by the user or by a linter. Read it again before attempting to edit it.".to_string(),
+ ),
+ None,
+ );
+ }
+ Err(e) => {
+ return (
+ ToolOutcome::Error(format!("Error checking file state: {e}")),
+ None,
+ );
+ }
+ Ok(FreshnessCheck::Fresh) => {}
+ }
+
+ // 3. Read current contents
+ let content = match std::fs::read_to_string(resolved_path) {
+ Ok(c) => c,
+ Err(e) => return (ToolOutcome::Error(format!("Error reading file: {e}")), None),
+ };
+
+ // 4. Find and validate matches
+ let match_count = content.matches(&self.old_string).count();
+
+ if match_count == 0 {
+ return (
+ ToolOutcome::Error(format!(
+ "old_string not found in {}. Make sure it matches exactly, including whitespace and indentation.",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+
+ if match_count > 1 && !self.replace_all {
+ return (
+ ToolOutcome::Error(format!(
+ "Found {match_count} matches of old_string in {}, but replace_all is false. Either provide more context to make the match unique, or set replace_all to true.",
+ resolved_path.display()
+ )),
+ None,
+ );
+ }
+
+ // 5. Apply replacement
+ let new_content = if self.replace_all {
+ content.replace(&self.old_string, &self.new_string)
+ } else {
+ content.replacen(&self.old_string, &self.new_string, 1)
+ };
+
+ // 6. Write atomically
+ let new_bytes = new_content.into_bytes();
+ if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &new_bytes) {
+ return (ToolOutcome::Error(format!("Error writing file: {e}")), None);
+ }
+
+ // 7. Success
+ let verb = if match_count == 1 {
+ "occurrence"
+ } else {
+ "occurrences"
+ };
+ (
+ ToolOutcome::Success(format!(
+ "Edited {}: replaced {match_count} {verb} of old_string with new_string.",
+ resolved_path.display()
+ )),
+ Some(new_bytes),
+ )
+ }
+}
+
+impl PermissableToolCall for EditToolCall {
+ fn target_dir(&self) -> Option<&Path> {
+ Some(&self.path)
+ }
+
+ fn matches_rule(&self, rule: &Rule) -> bool {
+ if rule.tool != "Write" {
return false;
}
@@ -532,7 +745,7 @@ impl TryFrom<&serde_json::Value> for WriteToolCall {
.ok_or(eyre::eyre!("Missing content"))?;
Ok(WriteToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
content: content.to_string(),
})
}
@@ -560,6 +773,9 @@ pub(crate) struct ShellToolCall {
pub dir: Option<PathBuf>,
pub command: String,
pub shell: String,
+ // allow dead code here; this will be tied into o11y and user-facing descriptions
+ #[expect(dead_code)]
+ pub description: Option<String>,
}
impl TryFrom<&serde_json::Value> for ShellToolCall {
@@ -579,10 +795,16 @@ impl TryFrom<&serde_json::Value> for ShellToolCall {
.unwrap_or("bash")
.to_string();
+ let description = value
+ .get("description")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string());
+
Ok(ShellToolCall {
- dir: dir.map(PathBuf::from),
+ dir: dir.map(expand_path),
command: command.to_string(),
shell,
+ description,
})
}
}
@@ -614,7 +836,34 @@ const PREVIEW_HEIGHT: u16 = 10;
/// Default terminal width for VT100 emulation.
const PREVIEW_WIDTH: u16 = 120;
+/// Normalize newlines for VT100 processing.
+///
+/// When subprocess output is captured via pipes (no PTY), bare `\n` (LF) bytes
+/// are not translated to `\r\n` (CR+LF) the way a kernel terminal driver would
+/// with the `ONLCR` flag. In VT100, LF only moves the cursor down without
+/// returning to column 0. This causes lines to start at progressively higher
+/// column offsets and eventually wrap, producing garbled output.
+///
+/// This function inserts `\r` before any `\n` that isn't already preceded by
+/// `\r`, mimicking the terminal driver's ONLCR behavior.
+fn normalize_newlines_for_vt100(data: &[u8]) -> Vec<u8> {
+ let mut out = Vec::with_capacity(data.len() + data.len() / 8);
+ for (i, &b) in data.iter().enumerate() {
+ if b == b'\n' && (i == 0 || data[i - 1] != b'\r') {
+ out.push(b'\r');
+ }
+ out.push(b);
+ }
+ out
+}
+
/// Extract plain text lines from a VT100 screen buffer.
+///
+/// Strips trailing blank lines so the result only contains rows with actual
+/// content. Without this, the fixed-size VT100 screen (PREVIEW_HEIGHT rows)
+/// would always return that many lines, and downstream components that use
+/// tail-mode display (like the Viewport) would show the blank padding rows
+/// instead of the real output.
fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> {
let (rows, cols) = screen.size();
let mut lines = Vec::with_capacity(rows as usize);
@@ -625,9 +874,11 @@ fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> {
line.push_str(cell.contents());
}
}
- // Trim trailing whitespace for cleaner display
lines.push(line.trim_end().to_string());
}
+ while lines.last().is_some_and(|l| l.is_empty()) {
+ lines.pop();
+ }
lines
}
@@ -640,12 +891,17 @@ fn strip_ansi_via_vt100(raw: &[u8]) -> String {
if raw.is_empty() {
return String::new();
}
- // Use the contents_formatted → screen approach: feed bytes into a parser
- // with enough rows to hold everything, then read back the plain text.
- // Estimate rows: one row per ~PREVIEW_WIDTH bytes, plus generous padding.
- let estimated_rows = (raw.len() / PREVIEW_WIDTH as usize + 1).min(10_000) as u16;
+ // Normalize bare LF to CR+LF so lines start at column 0 in the VT100 screen.
+ let normalized = normalize_newlines_for_vt100(raw);
+ // Feed bytes into a VT100 parser large enough to hold all output, then
+ // read back the plain text. We estimate rows from the number of newlines
+ // (not total byte length) because real output typically has short lines
+ // that would be severely under-counted by a bytes÷width estimate.
+ let newline_count = normalized.iter().filter(|&&b| b == b'\n').count();
+ let wrap_estimate = normalized.len() / PREVIEW_WIDTH as usize;
+ let estimated_rows = (newline_count + wrap_estimate + 1).min(10_000) as u16;
let mut parser = vt100::Parser::new(estimated_rows, PREVIEW_WIDTH, 0);
- parser.process(raw);
+ parser.process(&normalized);
let screen = parser.screen();
// screen.contents() returns the full plain-text content with trailing
// whitespace trimmed per line and trailing blank lines removed.
@@ -727,7 +983,8 @@ pub(crate) async fn execute_shell_command_streaming(
Ok(0) => stdout_done = true,
Ok(n) => {
full_stdout.extend_from_slice(&stdout_buf[..n]);
- parser.process(&stdout_buf[..n]);
+ let normalized = normalize_newlines_for_vt100(&stdout_buf[..n]);
+ parser.process(&normalized);
}
Err(_) => stdout_done = true,
}
@@ -740,7 +997,8 @@ pub(crate) async fn execute_shell_command_streaming(
Ok(n) => {
full_stderr.extend_from_slice(&stderr_buf[..n]);
// Feed stderr to the preview parser too, so it shows in the VT100 screen
- parser.process(&stderr_buf[..n]);
+ let normalized = normalize_newlines_for_vt100(&stderr_buf[..n]);
+ parser.process(&normalized);
}
Err(_) => stderr_done = true,
}
@@ -967,7 +1225,7 @@ mod tests {
fn read_tool(path: &str) -> ReadToolCall {
ReadToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
offset: 0,
limit: 100,
}
@@ -975,7 +1233,7 @@ mod tests {
fn write_tool(path: &str) -> WriteToolCall {
WriteToolCall {
- path: PathBuf::from(path),
+ path: expand_path(path),
content: String::new(),
}
}
@@ -994,12 +1252,26 @@ mod tests {
}
#[test]
- fn wrong_tool_never_matches() {
- assert!(!read_tool("foo.txt").matches_rule(&write_rule(None)));
+ fn write_implies_read() {
+ // A Write rule also permits reads on the same path
+ assert!(read_tool("foo.txt").matches_rule(&write_rule(None)));
+ // But a Read rule does not permit writes
assert!(!write_tool("foo.txt").matches_rule(&read_rule(None)));
}
#[test]
+ fn edit_uses_write_rule() {
+ let edit = EditToolCall {
+ path: expand_path("/home/user/config.toml"),
+ old_string: "x".into(),
+ new_string: "y".into(),
+ replace_all: false,
+ };
+ assert!(edit.matches_rule(&write_rule(None)));
+ assert!(!edit.matches_rule(&read_rule(None)));
+ }
+
+ #[test]
fn extension_glob() {
assert!(read_tool("notes.md").matches_rule(&read_rule(Some("*.md"))));
assert!(!read_tool("notes.txt").matches_rule(&read_rule(Some("*.md"))));
@@ -1050,6 +1322,419 @@ mod tests {
}
}
+ // ── edit_file execution tests ──
+
+ mod edit {
+ use super::*;
+ use crate::file_tracker::FileReadTracker;
+
+ /// Helper: create a temp file (with a closed handle), record it in a tracker.
+ /// Returns the TempDir (keeps the path alive) and tracker.
+ /// The file handle is closed so atomic_write_file can rename over it on Windows.
+ fn setup_tracked_file(content: &str) -> (tempfile::TempDir, PathBuf, FileReadTracker) {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("test_file.toml");
+ std::fs::write(&path, content).unwrap();
+
+ let file_content = std::fs::read(&path).unwrap();
+ let mtime = std::fs::metadata(&path).unwrap().modified().unwrap();
+
+ let mut tracker = FileReadTracker::default();
+ tracker.record_read(path.clone(), &file_content, mtime);
+
+ (dir, path, tracker)
+ }
+
+ fn edit_call(path: &Path, old: &str, new: &str, replace_all: bool) -> EditToolCall {
+ EditToolCall {
+ path: path.to_path_buf(),
+ old_string: old.to_string(),
+ new_string: new.to_string(),
+ replace_all,
+ }
+ }
+
+ #[test]
+ fn successful_single_replacement() {
+ let (_dir, path, tracker) = setup_tracked_file("[section]\nkey = old_value\n");
+
+ let call = edit_call(&path, "old_value", "new_value", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "[section]\nkey = new_value\n"
+ );
+ }
+
+ #[test]
+ fn successful_replace_all() {
+ let (_dir, path, tracker) = setup_tracked_file("aaa bbb aaa ccc aaa");
+
+ let call = edit_call(&path, "aaa", "xxx", true);
+ let (outcome, _) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("3 occurrences")));
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "xxx bbb xxx ccc xxx"
+ );
+ }
+
+ #[test]
+ fn error_file_not_read() {
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("unread.txt");
+ std::fs::write(&path, "content").unwrap();
+ let tracker = FileReadTracker::default(); // empty — never read
+
+ let call = edit_call(&path, "x", "y", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("not been read yet"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn error_file_modified_since_read() {
+ let (_dir, path, tracker) = setup_tracked_file("original");
+
+ // Modify the file after the read was recorded
+ std::thread::sleep(std::time::Duration::from_millis(10));
+ std::fs::write(&path, "modified externally").unwrap();
+
+ let call = edit_call(&path, "original", "replaced", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("modified since read"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn error_no_match() {
+ let (_dir, path, tracker) = setup_tracked_file("hello world");
+
+ let call = edit_call(&path, "nonexistent", "replacement", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("not found"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn error_multiple_matches_without_replace_all() {
+ let (_dir, path, tracker) = setup_tracked_file("foo bar foo baz foo");
+
+ let call = edit_call(&path, "foo", "qux", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("3 matches"), "got: {msg}");
+ assert!(msg.contains("replace_all"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ // File should be unchanged
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "foo bar foo baz foo"
+ );
+ }
+
+ #[test]
+ fn error_empty_old_string() {
+ let (_dir, path, tracker) = setup_tracked_file("content");
+
+ let call = edit_call(&path, "", "something", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ assert!(matches!(outcome, ToolOutcome::Error(_)));
+ }
+
+ #[test]
+ fn error_file_does_not_exist() {
+ let tracker = FileReadTracker::default();
+ let dir = tempfile::tempdir().unwrap();
+ let path = dir.path().join("nonexistent.txt");
+
+ let call = edit_call(&path, "x", "y", false);
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => {
+ assert!(msg.contains("does not exist"), "got: {msg}");
+ }
+ _ => panic!("expected error"),
+ }
+ }
+
+ #[test]
+ fn preserves_file_when_no_match() {
+ let original = "[config]\nport = 8080\nhost = localhost\n";
+ let (_dir, path, tracker) = setup_tracked_file(original);
+
+ let call = edit_call(&path, "port = 9090", "port = 3000", false);
+ let (outcome, _) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Error(_)));
+ assert_eq!(std::fs::read_to_string(&path).unwrap(), original);
+ }
+
+ #[test]
+ fn multiline_replacement() {
+ let content = "[section]\nkey1 = val1\nkey2 = val2\n[other]\n";
+ let (_dir, path, tracker) = setup_tracked_file(content);
+
+ let call = edit_call(
+ &path,
+ "key1 = val1\nkey2 = val2",
+ "key1 = new1\nkey2 = new2",
+ false,
+ );
+ let (outcome, new_bytes) = call.execute(&path, &tracker);
+
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&path).unwrap(),
+ "[section]\nkey1 = new1\nkey2 = new2\n[other]\n"
+ );
+ }
+ }
+
+ // ── Integration tests: full edit lifecycle ──
+ //
+ // These exercise the cross-component flow that dispatch orchestrates:
+ // FileReadTracker → SnapshotStore → EditToolCall.execute → tracker update
+
+ mod edit_integration {
+ use super::*;
+ use crate::edit_permissions::EditPermissionCache;
+ use crate::file_tracker::FileReadTracker;
+ use crate::snapshots::SnapshotStore;
+
+ /// Simulate a file read (what dispatch does after ReadToolCall.execute).
+ fn simulate_read(tracker: &mut FileReadTracker, path: &std::path::Path) {
+ let content = std::fs::read(path).unwrap();
+ let mtime = std::fs::metadata(path).unwrap().modified().unwrap();
+ tracker.record_read(path.to_path_buf(), &content, mtime);
+ }
+
+ /// Simulate a tracker update after edit (what dispatch does after execute).
+ fn simulate_tracker_update(
+ tracker: &mut FileReadTracker,
+ path: &std::path::Path,
+ new_bytes: &[u8],
+ ) {
+ let mtime = std::fs::metadata(path).unwrap().modified().unwrap();
+ tracker.update_after_edit(path, new_bytes, mtime);
+ }
+
+ #[test]
+ fn full_read_snapshot_edit_cycle() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "[db]\nhost = localhost\nport = 5432\n").unwrap();
+
+ let snapshot_dir = dir.path().join("snapshots").join("session-1");
+ let mut tracker = FileReadTracker::default();
+ let mut store = SnapshotStore::open(snapshot_dir.clone()).unwrap();
+
+ // 1. Simulate reading the file
+ simulate_read(&mut tracker, &file_path);
+
+ // 2. Snapshot before edit
+ let original = std::fs::read(&file_path).unwrap();
+ store.ensure_snapshot(&file_path, &original).unwrap();
+
+ // 3. Execute edit
+ let call = EditToolCall {
+ path: file_path.clone(),
+ old_string: "host = localhost".to_string(),
+ new_string: "host = 10.0.0.1".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ let new_bytes = new_bytes.unwrap();
+
+ // 4. Update tracker (simulating what dispatch does)
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes);
+
+ // Verify: file was edited
+ assert_eq!(
+ std::fs::read_to_string(&file_path).unwrap(),
+ "[db]\nhost = 10.0.0.1\nport = 5432\n"
+ );
+
+ // Verify: snapshot has original content
+ assert!(store.has_snapshot(&file_path));
+ let snapshot_name = crate::snapshots::sanitize_path(&file_path);
+ let snapshot_content =
+ std::fs::read_to_string(snapshot_dir.join(snapshot_name)).unwrap();
+ assert_eq!(snapshot_content, "[db]\nhost = localhost\nport = 5432\n");
+ }
+
+ #[test]
+ fn second_edit_without_reread() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "key1 = aaa\nkey2 = bbb\n").unwrap();
+
+ let mut tracker = FileReadTracker::default();
+
+ // Read the file
+ simulate_read(&mut tracker, &file_path);
+
+ // First edit
+ let call1 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "key1 = aaa".to_string(),
+ new_string: "key1 = xxx".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call1.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap());
+
+ // Second edit — should work without re-reading because tracker was updated
+ let call2 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "key2 = bbb".to_string(),
+ new_string: "key2 = yyy".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call2.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ assert!(new_bytes.is_some());
+ assert_eq!(
+ std::fs::read_to_string(&file_path).unwrap(),
+ "key1 = xxx\nkey2 = yyy\n"
+ );
+ }
+
+ #[test]
+ fn external_modification_between_edits() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "value = original\n").unwrap();
+
+ let mut tracker = FileReadTracker::default();
+ simulate_read(&mut tracker, &file_path);
+
+ // First edit succeeds
+ let call1 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "value = original".to_string(),
+ new_string: "value = edited".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call1.execute(&file_path, &tracker);
+ assert!(matches!(outcome, ToolOutcome::Success(_)));
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap());
+
+ // External modification (e.g., user edits the file)
+ std::thread::sleep(std::time::Duration::from_millis(10));
+ std::fs::write(&file_path, "value = user_changed\n").unwrap();
+
+ // Second edit should fail (stale)
+ let call2 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "value = edited".to_string(),
+ new_string: "value = second_edit".to_string(),
+ replace_all: false,
+ };
+ let (outcome, new_bytes) = call2.execute(&file_path, &tracker);
+ assert!(new_bytes.is_none());
+ match outcome {
+ ToolOutcome::Error(msg) => assert!(msg.contains("modified since read")),
+ _ => panic!("expected stale error"),
+ }
+
+ // File should be unchanged (the user's edit preserved)
+ assert_eq!(
+ std::fs::read_to_string(&file_path).unwrap(),
+ "value = user_changed\n"
+ );
+ }
+
+ #[test]
+ fn snapshot_only_created_once_per_file() {
+ let dir = tempfile::tempdir().unwrap();
+ let file_path = dir.path().join("config.toml");
+ std::fs::write(&file_path, "a = 1\nb = 2\n").unwrap();
+
+ let snapshot_dir = dir.path().join("snapshots").join("session-1");
+ let mut tracker = FileReadTracker::default();
+ let mut store = SnapshotStore::open(snapshot_dir).unwrap();
+
+ simulate_read(&mut tracker, &file_path);
+
+ // First edit — snapshot should be created
+ let original = std::fs::read(&file_path).unwrap();
+ let created = store.ensure_snapshot(&file_path, &original).unwrap();
+ assert!(created);
+
+ let call1 = EditToolCall {
+ path: file_path.clone(),
+ old_string: "a = 1".to_string(),
+ new_string: "a = 10".to_string(),
+ replace_all: false,
+ };
+ let (_, new_bytes) = call1.execute(&file_path, &tracker);
+ simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap());
+
+ // Second edit — snapshot should NOT be recreated
+ let content_before_second = std::fs::read(&file_path).unwrap();
+ let created = store
+ .ensure_snapshot(&file_path, &content_before_second)
+ .unwrap();
+ assert!(!created); // idempotent — already snapshotted
+ }
+
+ #[test]
+ fn permission_cache_grant_and_check() {
+ let mut cache = EditPermissionCache::default();
+ let path = std::path::PathBuf::from("/Users/me/.config/atuin/config.toml");
+
+ // Initially no grant
+ assert!(!cache.has_valid_grant(&path));
+
+ // Grant permission
+ cache.grant(path.clone());
+ assert!(cache.has_valid_grant(&path));
+
+ // Different file has no grant
+ assert!(!cache.has_valid_grant(std::path::Path::new("/other/file.toml")));
+
+ // Roundtrip through JSON (simulates session persistence)
+ let json = cache.to_json().unwrap();
+ let restored = EditPermissionCache::from_json(&json).unwrap();
+ assert!(restored.has_valid_grant(&path));
+ }
+ }
+
// ── Windows-specific tests (absolute paths with drive letters) ──
#[cfg(windows)]
diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs
index ea895c01..fea26953 100644
--- a/crates/atuin-ai/src/tui/dispatch.rs
+++ b/crates/atuin-ai/src/tui/dispatch.rs
@@ -61,15 +61,17 @@ pub(crate) fn dispatch(ctx: &mut DispatchContext, event: AiTuiEvent) -> bool {
!ctx.exiting.load(Ordering::Acquire)
}
-/// Persist new events and the server session ID if it has changed.
+/// Persist new events, server session ID, file tracker, and edit permissions.
/// Called from the dispatch thread (sync), bridges to async via the tokio handle.
fn persist_session(ctx: &mut DispatchContext) {
- let Ok((events, server_sid)) = ctx
+ let Ok((events, server_sid, file_tracker_json, edit_perms_json)) = ctx
.handle
.fetch(|state| {
(
state.conversation.events.clone(),
state.conversation.session_id.clone(),
+ state.file_tracker.to_json().ok(),
+ state.edit_permissions.to_json().ok(),
)
})
.blocking_recv()
@@ -86,6 +88,22 @@ fn persist_session(ctx: &mut DispatchContext) {
{
tracing::warn!("failed to persist server session ID: {e}");
}
+ if let Some(ref json) = file_tracker_json
+ && let Err(e) = rt.block_on(
+ ctx.session_mgr
+ .set_metadata(crate::file_tracker::METADATA_KEY, json),
+ )
+ {
+ tracing::warn!("failed to persist file tracker: {e}");
+ }
+ if let Some(ref json) = edit_perms_json
+ && let Err(e) = rt.block_on(
+ ctx.session_mgr
+ .set_metadata(crate::edit_permissions::METADATA_KEY, json),
+ )
+ {
+ tracing::warn!("failed to persist edit permissions: {e}");
+ }
}
fn launch_stream(ctx: &DispatchContext, setup: impl FnOnce(&mut Session) + Send + 'static) {
@@ -210,6 +228,10 @@ fn execute_tool(
let shell_call = shell_call.clone();
execute_shell_tool(handle, tx, &tool_id, &shell_call);
}
+ ClientToolCall::Edit(edit_call) => {
+ let edit_call = edit_call.clone();
+ execute_edit_tool(handle, tx, tool_id, edit_call);
+ }
_ => {
execute_simple_tool(handle, tx, tool_id, tool, db);
}
@@ -231,7 +253,21 @@ fn execute_simple_tool(
tokio::spawn(async move {
let outcome = tool.execute(&db).await;
+
+ // After a successful file read, capture tracking data for freshness
+ // checking. This re-stats the file to get content hash and mtime.
+ let read_tracking = if let ClientToolCall::Read(ref read_tool) = tool
+ && !outcome.is_error()
+ {
+ capture_read_tracking(&read_tool.path)
+ } else {
+ None
+ };
+
h.update(move |state| {
+ if let Some((path, content, mtime)) = read_tracking {
+ state.file_tracker.record_read(path, &content, mtime);
+ }
state.finish_tool_call(&tool_id, outcome);
if !state.tool_tracker.has_pending() {
let _ = tx.send(AiTuiEvent::ContinueAfterTools);
@@ -240,6 +276,117 @@ fn execute_simple_tool(
});
}
+/// Capture file content and mtime for the read tracker.
+/// Returns None for directories or if the file can't be read.
+fn capture_read_tracking(
+ path: &std::path::Path,
+) -> Option<(std::path::PathBuf, Vec<u8>, std::time::SystemTime)> {
+ let resolved = if path.is_relative() {
+ std::env::current_dir().ok()?.join(path)
+ } else {
+ path.to_path_buf()
+ };
+ if !resolved.is_file() {
+ return None;
+ }
+ let content = std::fs::read(&resolved).ok()?;
+ let mtime = std::fs::metadata(&resolved).ok()?.modified().ok()?;
+ Some((resolved, content, mtime))
+}
+
+/// Execute an edit_file tool call.
+///
+/// Orchestrates snapshot → execute → tracker update. The snapshot and
+/// tracker mutations happen via `h.update()` (on the TUI thread) since
+/// they need mutable Session state. The actual file I/O (freshness check,
+/// read, match, atomic write) runs in the tokio task.
+fn execute_edit_tool(
+ handle: &Handle<Session>,
+ tx: &mpsc::Sender<AiTuiEvent>,
+ tool_id: String,
+ edit_call: crate::tools::EditToolCall,
+) {
+ let h = handle.clone();
+ let tx = tx.clone();
+
+ tokio::spawn(async move {
+ let resolved = edit_call.resolved_path();
+
+ // 1. Read the original file content (used for snapshot + diff).
+ let old_content = std::fs::read(&resolved).ok();
+
+ // 2. Snapshot the original file before editing.
+ if let Some(ref content) = old_content {
+ let snap_path = resolved.clone();
+ let snap_content = content.clone();
+ h.update(move |state| {
+ if let Some(ref mut store) = state.snapshot_store
+ && let Err(e) = store.ensure_snapshot(&snap_path, &snap_content)
+ {
+ tracing::warn!("failed to create file snapshot: {e}");
+ }
+ });
+ }
+
+ // 3. Fetch a clone of the file tracker for freshness checking.
+ let Ok(tracker) = h.fetch(|state| state.file_tracker.clone()).await else {
+ let tc_id = tool_id.clone();
+ h.update(move |state| {
+ state.finish_tool_call(
+ &tc_id,
+ crate::tools::ToolOutcome::Error("Internal error: TUI unavailable".into()),
+ );
+ if !state.tool_tracker.has_pending() {
+ let _ = tx.send(AiTuiEvent::ContinueAfterTools);
+ }
+ });
+ return;
+ };
+
+ // 4. Execute: freshness check → read → match → atomic write
+ let (outcome, new_bytes) = edit_call.execute(&resolved, &tracker);
+
+ // 5. Compute diff preview on success
+ let edit_preview = if let Some(ref new_bytes) = new_bytes {
+ if let Some(ref old_bytes) = old_content {
+ let old_str = String::from_utf8_lossy(old_bytes);
+ let new_str = String::from_utf8_lossy(new_bytes);
+ let preview = crate::diff::EditPreview::compute(&old_str, &new_str);
+ if preview.hunks.is_empty() {
+ None
+ } else {
+ Some(preview)
+ }
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ // 6. Update tracker, store diff preview, and finish the tool call
+ let tc_id = tool_id;
+ h.update(move |state| {
+ if let Some(ref new_bytes) = new_bytes
+ && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified())
+ {
+ state
+ .file_tracker
+ .update_after_edit(&resolved, new_bytes, mtime);
+ }
+ if let Some(preview) = edit_preview
+ && let Some(tracked) = state.tool_tracker.get_mut(&tc_id)
+ {
+ tracked.edit_preview = Some(preview);
+ }
+ state.finish_tool_call(&tc_id, outcome);
+ if !state.tool_tracker.has_pending() {
+ let _ = tx.send(AiTuiEvent::ContinueAfterTools);
+ }
+ });
+ });
+}
+
/// Execute a shell tool with streaming VT100 preview.
fn execute_shell_tool(
handle: &Handle<Session>,
@@ -352,12 +499,28 @@ async fn check_tool_permission_inner(
.map_err(|e| format!("Internal error fetching tool state: {e}"))?
.ok_or_else(|| "Internal error: tool not found in tracker".to_string())?;
- // 2. Resolve working directory
+ // 2. For edit tools, check session-scoped permission grants before
+ // hitting the filesystem-based resolver. A valid grant means the user
+ // already approved this file recently.
+ if let ClientToolCall::Edit(ref edit) = tool {
+ let resolved = edit.resolved_path();
+ let has_grant = h2
+ .fetch(move |state| state.edit_permissions.has_valid_grant(&resolved))
+ .await
+ .unwrap_or(false);
+
+ if has_grant {
+ execute_tool(h2, tx, id, tool, db);
+ return Ok(());
+ }
+ }
+
+ // 3. Resolve working directory
let working_dir = target_dir
.or_else(|| std::env::current_dir().ok())
.ok_or_else(|| "Could not determine working directory".to_string())?;
- // 3. Create permission resolver and check
+ // 4. Create permission resolver and check
let resolver = PermissionResolver::new(working_dir)
.await
.map_err(|e| format!("Permission check failed: {e}"))?;
@@ -367,7 +530,7 @@ async fn check_tool_permission_inner(
.await
.map_err(|e| format!("Permission check failed: {e}"))?;
- // 4. Handle response — all paths here handle the tool, so return Ok
+ // 5. Handle response — all paths here handle the tool, so return Ok
let id_clone = id.clone();
match response {
PermissionResponse::Allowed => {
@@ -423,6 +586,32 @@ fn on_select_permission(ctx: &mut DispatchContext, permission: PermissionResult)
execute_tool(&h2, &tx, tool_id, tool, &db);
});
}
+ PermissionResult::AllowFileForSession => {
+ // Cache a session-scoped, time-limited grant for this file
+ let db = ctx.app_ctx.history_db.clone();
+ tokio::spawn(async move {
+ let Ok(Some((tool_id, tool))) = h2
+ .fetch(move |state| {
+ state
+ .tool_tracker
+ .asking_for_permission()
+ .map(|t| (t.id.clone(), t.tool.clone()))
+ })
+ .await
+ else {
+ return;
+ };
+
+ if let ClientToolCall::Edit(ref edit) = tool {
+ let resolved = edit.resolved_path();
+ h2.update(move |state| {
+ state.edit_permissions.grant(resolved);
+ });
+ }
+
+ execute_tool(&h2, &tx, tool_id, tool, &db);
+ });
+ }
PermissionResult::AlwaysAllowInDir => {
let db = ctx.app_ctx.history_db.clone();
let git_root = ctx.app_ctx.git_root.clone();
diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs
index 1a422fef..969f6ae5 100644
--- a/crates/atuin-ai/src/tui/events.rs
+++ b/crates/atuin-ai/src/tui/events.rs
@@ -38,7 +38,34 @@ pub(crate) enum AiTuiEvent {
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum PermissionResult {
Allow,
+ /// Per-file, time-limited grant scoped to the current session.
+ AllowFileForSession,
AlwaysAllowInDir,
AlwaysAllow,
Deny,
}
+
+impl PermissionResult {
+ /// String identifier used as the SelectOption value.
+ pub fn as_value_str(&self) -> &'static str {
+ match self {
+ Self::Allow => "allow",
+ Self::AllowFileForSession => "allow-file-session",
+ Self::AlwaysAllowInDir => "always-allow-in-dir",
+ Self::AlwaysAllow => "always-allow",
+ Self::Deny => "deny",
+ }
+ }
+
+ /// Parse from a SelectOption value string.
+ pub fn from_value_str(s: &str) -> Option<Self> {
+ match s {
+ "allow" => Some(Self::Allow),
+ "allow-file-session" => Some(Self::AllowFileForSession),
+ "always-allow-in-dir" => Some(Self::AlwaysAllowInDir),
+ "always-allow" => Some(Self::AlwaysAllow),
+ "deny" => Some(Self::Deny),
+ _ => None,
+ }
+ }
+}
diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs
index e122918e..af1ebffe 100644
--- a/crates/atuin-ai/src/tui/state.rs
+++ b/crates/atuin-ai/src/tui/state.rs
@@ -474,6 +474,12 @@ pub(crate) struct Session {
pub slash_registry: SlashCommandRegistry,
/// The unique ID for this invocation
pub invocation_id: String,
+ /// Tracks which files have been read, for freshness checking before edits.
+ pub file_tracker: crate::file_tracker::FileReadTracker,
+ /// Session-scoped edit permission grants (per-file, time-limited).
+ pub edit_permissions: crate::edit_permissions::EditPermissionCache,
+ /// Backs up files before the first edit in a session.
+ pub snapshot_store: Option<crate::snapshots::SnapshotStore>,
}
impl Session {
@@ -491,6 +497,9 @@ impl Session {
archived_view_events: Vec::new(),
slash_registry: Default::default(),
invocation_id: invocation_id.unwrap_or_else(|| uuid::Uuid::now_v7().to_string()),
+ file_tracker: Default::default(),
+ edit_permissions: Default::default(),
+ snapshot_store: None,
}
}
diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs
index 565a0597..bdbece9c 100644
--- a/crates/atuin-ai/src/tui/view/mod.rs
+++ b/crates/atuin-ai/src/tui/view/mod.rs
@@ -1,12 +1,11 @@
//! View function that builds the eye-declare element tree from app state.
use eye_declare::{
- BorderType, Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport,
- WidthConstraint, element,
+ Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport, WidthConstraint, element,
};
use ratatui_core::style::{Color, Modifier, Style};
-use crate::tools::{ClientToolCall, TrackedTool};
+use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview, TrackedTool};
use crate::tui::components::select::SelectOption;
use crate::tui::components::session_continue::SessionContinue;
use crate::tui::events::{AiTuiEvent, PermissionResult};
@@ -68,6 +67,16 @@ pub(crate) fn ai_view(state: &Session) -> Elements {
})
})
+ #({
+ let needs_pending_banner = busy && !matches!(turns.last(), Some(turn::UiTurn::Agent { .. }));
+ if needs_pending_banner {
+ let empty: &[turn::UiEvent] = &[];
+ agent_turn_view(empty, true)
+ } else {
+ element! {}
+ }
+ })
+
#(if !state.is_exiting() {
#(input_view(state))
})
@@ -135,16 +144,13 @@ fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements {
let verb = tool_call.tool.descriptor().display_verb;
let tool_desc = match &tool_call.tool {
ClientToolCall::Read(tool) => tool.path.display().to_string(),
+ ClientToolCall::Edit(tool) => tool.path.display().to_string(),
ClientToolCall::Write(tool) => tool.path.display().to_string(),
ClientToolCall::Shell(tool) => tool.command.clone(),
ClientToolCall::AtuinHistory(tool) => tool.query.clone(),
};
- let dir_label = if in_git_project {
- "Always allow in this workspace"
- } else {
- "Always allow in this directory"
- };
+ let select_options = permission_options_for_tool(&tool_call.tool, in_git_project);
element! {
View(key: format!("tool-call-{}", tool_call.id), padding_left: Cells::from(2), padding_top: Cells::from(1)) {
@@ -153,39 +159,68 @@ fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements {
Span(text: &tool_desc, style: Style::default().fg(Color::Yellow))
}
View(padding_left: Cells::from(2)) {
- Select(options: [
- SelectOption::builder()
- .label("Allow")
- .value("allow")
- .build(),
- SelectOption::builder()
- .label(dir_label)
- .value("always-allow-in-dir")
- .build(),
- SelectOption::builder()
- .label("Always allow")
- .value("always-allow")
- .build(),
- SelectOption::builder()
- .label("Deny")
- .value("deny")
- .build(),
- ], on_select: Box::new(move |option: &SelectOption| {
- let value = match option.value.as_str() {
- "allow" => PermissionResult::Allow,
- "always-allow-in-dir" => PermissionResult::AlwaysAllowInDir,
- "always-allow" => PermissionResult::AlwaysAllow,
- "deny" => PermissionResult::Deny,
- _ => unreachable!(),
- };
-
- Some(AiTuiEvent::SelectPermission(value))
+ Select(options: select_options, on_select: Box::new(move |option: &SelectOption| {
+ PermissionResult::from_value_str(option.value.as_str())
+ .map(AiTuiEvent::SelectPermission)
}) as Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync>)
}
}
}
}
+/// Build the permission SelectOptions appropriate for a tool call.
+///
+/// Edit tools get a per-file session-scoped option instead of the
+/// workspace-level "Always allow in this directory". Other tools
+/// keep the standard set.
+fn permission_options_for_tool(tool: &ClientToolCall, in_git_project: bool) -> Vec<SelectOption> {
+ match tool {
+ ClientToolCall::Edit(_) => vec![
+ SelectOption::builder()
+ .label("Allow")
+ .value(PermissionResult::Allow.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label("Allow this file for this session")
+ .value(PermissionResult::AllowFileForSession.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label("Always allow")
+ .value(PermissionResult::AlwaysAllow.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label("Deny")
+ .value(PermissionResult::Deny.as_value_str())
+ .build(),
+ ],
+ _ => {
+ let dir_label = if in_git_project {
+ "Always allow in this workspace"
+ } else {
+ "Always allow in this directory"
+ };
+ vec![
+ SelectOption::builder()
+ .label("Allow")
+ .value(PermissionResult::Allow.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label(dir_label)
+ .value(PermissionResult::AlwaysAllowInDir.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label("Always allow")
+ .value(PermissionResult::AlwaysAllow.as_value_str())
+ .build(),
+ SelectOption::builder()
+ .label("Deny")
+ .value(PermissionResult::Deny.as_value_str())
+ .build(),
+ ]
+ }
+ }
+}
+
fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements {
let label_style = Style::default()
.fg(Color::Cyan)
@@ -231,7 +266,10 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements {
label_first: true,
done: !busy,
)
- #(for event in events {
+ #(for (i, event) in events.iter().enumerate() {
+ #(if i > 0 {
+ Text { Span(text: "") }
+ })
#(match event {
turn::UiEvent::Text { content } => {
element! {
@@ -247,47 +285,42 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements {
suggested_command_view(details)
},
turn::UiEvent::ToolCall(details) => {
- let preview_done = details.preview.as_ref().is_some_and(|p| p.exit_code.is_some() || p.interrupted);
let tool_key = details.tool_use_id.clone();
element! {
View(key: format!("tool-output-{tool_key}"), padding_left: Cells::from(2)) {
- #(if let Some(ref preview) = details.preview {
- View(key: format!("preview-{tool_key}")) {
- #(preview_spinner_view(&details.name, preview_done))
- Viewport(
- key: format!("viewport-{tool_key}"),
- lines: preview.lines.clone(),
- height: 10,
- border: BorderType::Plain,
- border_style: Style::default().fg(Color::DarkGray),
- style: Style::default().fg(Color::White),
- wrap: false,
- )
- #(if let Some(code) = preview.exit_code {
- #(if code == 0 {
- Text {
- Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Green))
- }
- } else {
- Text {
- Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Red))
- }
- })
- })
- #(if preview.interrupted {
- Text {
- Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD))
- }
- })
- #(if !preview_done {
- Text {
- Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray))
- }
- })
- }
- } else {
- #(tool_status_view(&details.name, &details.status))
+ #(match &details.render_data {
+ turn::ToolRenderData::Shell { command, preview } => {
+ shell_tool_view(&tool_key, command, preview.as_ref())
+ },
+ turn::ToolRenderData::FileEdit { path, preview } => {
+ file_edit_tool_view(&tool_key, &details.status, path, preview.as_ref())
+ },
+ turn::ToolRenderData::FileWrite { path } => {
+ file_write_tool_view(&details.status, path)
+ },
+ turn::ToolRenderData::Remote => {
+ tool_status_view(&details.name, &details.status)
+ },
+ turn::ToolRenderData::FileRead { .. }
+ | turn::ToolRenderData::HistorySearch { .. } => {
+ element!{}
+ },
+ })
+ }
+ }
+ }
+ turn::UiEvent::ToolGroup(group) => {
+ let group_key = group.calls
+ .first()
+ .map(|c| c.tool_use_id.as_str())
+ .unwrap_or("empty");
+
+ element! {
+ View(key: format!("group-{group_key}"), padding_left: Cells::from(2)) {
+ #(match group.kind {
+ turn::ToolGroupKind::FileRead => file_read_group_view(group),
+ turn::ToolGroupKind::HistorySearch => history_search_group_view(group),
})
}
}
@@ -367,17 +400,391 @@ fn tool_status_view(name: &str, status: &turn::ToolResultStatus) -> Elements {
}
}
-/// Render a spinner/status line for a command preview (shell tools).
-fn preview_spinner_view(name: &str, done: bool) -> Elements {
+// ───────────────────────────────────────────────────────────────────
+// Per-tool view functions
+// ───────────────────────────────────────────────────────────────────
+
+/// Max output lines shown for a shell command preview.
+const MAX_SHELL_PREVIEW_LINES: u16 = 5;
+
+/// Render a shell command execution with live VT100 output viewport.
+fn shell_tool_view(tool_key: &str, command: &str, preview: Option<&ToolPreview>) -> Elements {
+ let preview_done = preview.is_some_and(|p| p.exit_code.is_some() || p.interrupted);
+
+ element! {
+ #(if let Some(preview) = preview {
+ View(key: format!("preview-{tool_key}")) {
+ Spinner(
+ label: if preview_done { format!("Ran: {command}") } else { format!("Running: {command}") },
+ done: preview_done,
+ hide_checkmark: true,
+ )
+ HStack {
+ View(width: WidthConstraint::Fixed(2)) {
+ Text { Span(text: "└ ") }
+ }
+ Column {
+ Viewport(
+ key: format!("viewport-{tool_key}"),
+ lines: preview.lines.clone(),
+ height: (preview.lines.len() as u16).clamp(1, MAX_SHELL_PREVIEW_LINES),
+ style: Style::default().fg(Color::Gray),
+ wrap: false,
+ )
+ }
+ }
+ #(shell_tool_footer(preview, preview_done))
+ }
+ } else {
+ Spinner(
+ label: format!("Running: {command}"),
+ label_style: Style::default().fg(Color::Yellow),
+ done: false,
+ )
+ })
+ }
+}
+
+fn shell_tool_footer(preview: &ToolPreview, preview_done: bool) -> Elements {
+ if preview.interrupted {
+ return element! {
+ Text {
+ Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD))
+ }
+ };
+ }
+ if !preview_done {
+ return element! {
+ Text {
+ Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray))
+ }
+ };
+ }
+ if let Some(code) = preview.exit_code {
+ let style = if code == 0 {
+ Style::default().fg(Color::Green)
+ } else {
+ Style::default().fg(Color::Red)
+ };
+ return element! {
+ Text { Span(text: format!("Exit code: {code}"), style: style) }
+ };
+ }
+ element! {}
+}
+
+/// Render a file edit tool call with diff preview.
+fn file_edit_tool_view(
+ key: &str,
+ status: &turn::ToolResultStatus,
+ path: &std::path::Path,
+ preview: Option<&crate::diff::EditPreview>,
+) -> Elements {
+ use crate::diff::DiffLine;
+
+ let display_path = format_path_for_display(path);
+
+ let status_line = match status {
+ turn::ToolResultStatus::Pending => {
+ element! {
+ Spinner(
+ label: format!("Editing: {display_path}"),
+ label_style: Style::default().fg(Color::Yellow),
+ done: false,
+ )
+ }
+ }
+ turn::ToolResultStatus::Success => {
+ element! {
+ Spinner(label: format!("Edited: {display_path}"), done: true)
+ }
+ }
+ turn::ToolResultStatus::Error => {
+ element! {
+ Text {
+ Span(text: "✗ ", style: Style::default().fg(Color::Red))
+ Span(text: format!("Edit {display_path}: failed"), style: Style::default().fg(Color::Red))
+ }
+ }
+ }
+ };
+
+ // If no preview, just show the status line
+ let Some(preview) = preview else {
+ return status_line;
+ };
+ if preview.hunks.is_empty() {
+ return status_line;
+ }
+
+ // Calculate the line number gutter width from the highest line number
+ let max_line_num = preview.max_line_number();
+ let gutter_width = max_line_num.to_string().len().max(2) as u16 + 1; // +1 for spacing
+
+ element! {
+ View(key: key.to_string()) {
+ #(status_line)
+
+ View(key: format!("{key}-diff"), padding_left: Cells::from(2)) {
+ #(for (hunk_idx, hunk) in preview.hunks.iter().enumerate() {
+ #({
+ let gutter_w = gutter_width;
+ let mut before_pos = hunk.before_start;
+ let mut after_pos = hunk.after_start;
+ let lines_rendered: Vec<_> = hunk.lines.iter().enumerate().map(|(line_idx, line)| {
+ let (prefix, text, style, gutter_text, gutter_style) = match line {
+ DiffLine::Context(t) => {
+ let num = format!("{:>width$}", after_pos, width = (gutter_w - 1) as usize);
+ before_pos += 1;
+ after_pos += 1;
+ (" ", t.as_str(), Style::default().fg(Color::DarkGray), num, Style::default().fg(Color::DarkGray))
+ }
+ DiffLine::Removed(t) => {
+ let num = format!("{:>width$}", before_pos, width = (gutter_w - 1) as usize);
+ before_pos += 1;
+ ("-", t.as_str(), Style::default().fg(Color::Red), num, Style::default().fg(Color::Red))
+ }
+ DiffLine::Added(t) => {
+ let num = format!("{:>width$}", after_pos, width = (gutter_w - 1) as usize);
+ after_pos += 1;
+ ("+", t.as_str(), Style::default().fg(Color::Green), num, Style::default().fg(Color::Green))
+ }
+ };
+ (line_idx, prefix, text.to_string(), style, gutter_text, gutter_style)
+ }).collect();
+
+ element! {
+ View(key: format!("{key}-hunk-{hunk_idx}")) {
+ #(for (line_idx, prefix, text, style, gutter_text, gutter_style) in &lines_rendered {
+ HStack(key: format!("{key}-hunk-{hunk_idx}-line-{line_idx}")) {
+ View(width: WidthConstraint::Fixed(gutter_w)) {
+ Text { Span(text: gutter_text, style: *gutter_style) }
+ }
+ View {
+ Text {
+ Span(text: *prefix, style: *style)
+ Span(text: text, style: *style)
+ }
+ }
+ }
+ })
+ }
+ }
+ })
+ })
+ }
+ }
+ }
+}
+
+/// Render a file write tool call status with the target path.
+fn file_write_tool_view(status: &turn::ToolResultStatus, path: &std::path::Path) -> Elements {
+ let display_path = path.display();
+ match status {
+ turn::ToolResultStatus::Pending => {
+ element! {
+ Spinner(
+ label: format!("Writing: {display_path}"),
+ label_style: Style::default().fg(Color::Yellow),
+ done: false,
+ )
+ }
+ }
+ turn::ToolResultStatus::Success => {
+ element! {
+ Spinner(label: format!("Wrote: {display_path}"), done: true)
+ }
+ }
+ turn::ToolResultStatus::Error => {
+ element! {
+ Text {
+ Span(text: "✗ ", style: Style::default().fg(Color::Red))
+ Span(text: format!("Write {display_path}: denied"), style: Style::default().fg(Color::Red))
+ }
+ }
+ }
+ }
+}
+
+// ───────────────────────────────────────────────────────────────────
+// Tool group view functions
+// ───────────────────────────────────────────────────────────────────
+
+/// Max entries shown under a tool group header. When the group holds more
+/// than this, only the most recent `MAX_GROUP_ENTRIES` are displayed; the
+/// count in the header line tells the full story.
+const MAX_GROUP_ENTRIES: usize = 5;
+
+/// Format a filesystem path for display in tool rows.
+///
+/// - Relative to the current working directory if the path is under it
+/// - `~/...` prefix if the path is under the user's home directory
+/// - Absolute otherwise (and relative paths pass through unchanged)
+fn format_path_for_display(path: &std::path::Path) -> String {
+ if let Ok(cwd) = std::env::current_dir()
+ && let Ok(relative) = path.strip_prefix(&cwd)
+ {
+ return relative.display().to_string();
+ }
+
+ if let Ok(home) = std::env::var("HOME")
+ && let Ok(relative) = path.strip_prefix(&home)
+ {
+ return format!("~/{}", relative.display());
+ }
+
+ path.display().to_string()
+}
+
+fn filter_mode_label(mode: &HistorySearchFilterMode) -> &'static str {
+ match mode {
+ HistorySearchFilterMode::Global => "global",
+ HistorySearchFilterMode::Host => "host",
+ HistorySearchFilterMode::Session => "session",
+ HistorySearchFilterMode::Directory => "directory",
+ HistorySearchFilterMode::Workspace => "workspace",
+ }
+}
+
+/// Format a list of filter modes as `"(global, workspace)"`, or an empty
+/// string if the list is empty.
+fn format_filter_modes(modes: &[HistorySearchFilterMode]) -> String {
+ if modes.is_empty() {
+ return String::new();
+ }
+ let parts: Vec<&'static str> = modes.iter().map(filter_mode_label).collect();
+ format!("({})", parts.join(", "))
+}
+
+/// Tree-connector marker for a row in a grouped list: `└ ` for the first
+/// visible row, two spaces for subsequent rows.
+fn tree_marker(is_first: bool) -> &'static str {
+ if is_first { "└ " } else { " " }
+}
+
+/// 2-char status marker column: ✓ / ✗ / blank.
+fn status_marker_view(status: &turn::ToolResultStatus) -> Elements {
+ match status {
+ turn::ToolResultStatus::Pending => element! {
+ Text { Span(text: " ") }
+ },
+ turn::ToolResultStatus::Success => element! {
+ Text { Span(text: "✓ ", style: Style::default().fg(Color::Green)) }
+ },
+ turn::ToolResultStatus::Error => element! {
+ Text { Span(text: "✗ ", style: Style::default().fg(Color::Red)) }
+ },
+ }
+}
+
+/// Compute the slice of calls to show — the most recent `MAX_GROUP_ENTRIES`.
+fn visible_group_calls(group: &turn::ToolGroup) -> &[turn::ToolCallDetails] {
+ let start = group.calls.len().saturating_sub(MAX_GROUP_ENTRIES);
+ &group.calls[start..]
+}
+
+/// Render a single row in a grouped list: [tree marker][status][content].
+fn group_row_view(is_first: bool, status: &turn::ToolResultStatus, content: Elements) -> Elements {
+ element! {
+ HStack {
+ View(width: WidthConstraint::Fixed(2)) {
+ Text { Span(text: tree_marker(is_first)) }
+ }
+ View(width: WidthConstraint::Fixed(2)) {
+ #(status_marker_view(status))
+ }
+ Column {
+ #(content)
+ }
+ }
+ }
+}
+
+/// Render a group of consecutive `read_file` tool calls.
+fn file_read_group_view(group: &turn::ToolGroup) -> Elements {
+ let count = group.calls.len();
+ let label = if count == 1 {
+ "Read 1 file".to_string()
+ } else {
+ format!("Read {count} files")
+ };
+ let done = !group.any_pending();
+ let visible = visible_group_calls(group);
+
+ element! {
+ Spinner(label: label, done: done, hide_checkmark: true)
+ #(for (i, details) in visible.iter().enumerate() {
+ #(file_read_row(i == 0, details))
+ })
+ }
+}
+
+fn file_read_row(is_first: bool, details: &turn::ToolCallDetails) -> Elements {
+ let path_str = match &details.render_data {
+ turn::ToolRenderData::FileRead { path } => format_path_for_display(path),
+ _ => String::new(),
+ };
+
+ let content = element! {
+ Text { Span(text: path_str) }
+ };
+
+ group_row_view(is_first, &details.status, content)
+}
+
+/// Render a group of consecutive `atuin_history` tool calls.
+fn history_search_group_view(group: &turn::ToolGroup) -> Elements {
+ let done = !group.any_pending();
+ let visible = visible_group_calls(group);
+
element! {
- Spinner(
- label: if done { format!("Ran: {name}") } else { format!("Running: {name}") },
- label_style: Style::default().fg(Color::Yellow),
- done: done,
- )
+ Spinner(label: "Searched Atuin history:", done: done, hide_checkmark: true)
+ #(for (i, details) in visible.iter().enumerate() {
+ #(history_search_row(i == 0, details))
+ })
}
}
+fn history_search_row(is_first: bool, details: &turn::ToolCallDetails) -> Elements {
+ let (query, filter_modes) = match &details.render_data {
+ turn::ToolRenderData::HistorySearch {
+ query,
+ filter_modes,
+ } => (query.as_str(), filter_modes.as_slice()),
+ _ => ("", [].as_slice()),
+ };
+
+ let is_empty_query = query.trim().is_empty();
+ let filter_label = format_filter_modes(filter_modes);
+
+ let content = if is_empty_query {
+ element! {
+ Text {
+ Span(
+ text: "recent commands",
+ style: Style::default().fg(Color::Gray).add_modifier(Modifier::ITALIC),
+ )
+ #(if !filter_label.is_empty() {
+ Span(text: " ")
+ Span(text: filter_label, style: Style::default().fg(Color::DarkGray))
+ })
+ }
+ }
+ } else {
+ element! {
+ Text {
+ Span(text: query.to_string())
+ #(if !filter_label.is_empty() {
+ Span(text: " ")
+ Span(text: filter_label, style: Style::default().fg(Color::DarkGray))
+ })
+ }
+ }
+ };
+
+ group_row_view(is_first, &details.status, content)
+}
+
fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements {
let is_dangerous = matches!(
details.danger_level,
@@ -413,9 +820,6 @@ fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements {
element! {
View {
- #(if !details.first_event_in_turn {
- Text { Span(text: "") }
- })
Text {
Span(text: " Suggested command:", style: Style::default().fg(Color::Cyan))
}
diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs
index a2555dc6..1c19a6b2 100644
--- a/crates/atuin-ai/src/tui/view/turn.rs
+++ b/crates/atuin-ai/src/tui/view/turn.rs
@@ -1,5 +1,7 @@
+use std::path::PathBuf;
+
use crate::tools::descriptor;
-use crate::tools::{ToolPreview, ToolTracker};
+use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview, ToolTracker};
use crate::tui::ConversationEvent;
/// Server-sent danger level for a suggested command
@@ -80,20 +82,99 @@ impl From<(&String, &String)> for ConfidenceLevel {
#[derive(Debug)]
pub(crate) enum UiEvent {
- Text { content: String },
+ Text {
+ content: String,
+ },
ToolCall(ToolCallDetails),
+ /// Consecutive client-side tool calls of the same groupable kind, collapsed
+ /// into one unit so the view can render a shared status line + a list of
+ /// individual entries.
+ ToolGroup(ToolGroup),
ToolSummary(ToolSummary),
SuggestedCommand(SuggestedCommandDetails),
OutOfBandOutput(OutOfBandOutputDetails),
}
+/// A run of consecutive client-side tool calls of the same groupable kind.
+#[derive(Debug)]
+pub(crate) struct ToolGroup {
+ pub(crate) kind: ToolGroupKind,
+ pub(crate) calls: Vec<ToolCallDetails>,
+}
+
+impl ToolGroup {
+ /// True if any call in the group is still pending.
+ pub(crate) fn any_pending(&self) -> bool {
+ self.calls
+ .iter()
+ .any(|c| c.status == ToolResultStatus::Pending)
+ }
+}
+
+/// Which kind of client-side tools this group holds.
+///
+/// Only tool types that benefit from grouped presentation appear here.
+/// Shell (needs its own viewport) and FileWrite (wants diffs/contents) are
+/// intentionally absent — those render as individual `UiEvent::ToolCall`s.
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub(crate) enum ToolGroupKind {
+ FileRead,
+ HistorySearch,
+}
+
+/// Tool-type-specific data for rendering in the view layer.
+///
+/// Each variant carries the data a per-tool renderer component needs.
+/// Built by TurnBuilder from ToolTracker + ConversationEvent data.
+#[derive(Debug)]
+pub(crate) enum ToolRenderData {
+ /// Shell command with live/cached VT100 output preview.
+ Shell {
+ command: String,
+ preview: Option<ToolPreview>,
+ },
+ /// File read operation.
+ FileRead { path: PathBuf },
+ /// File edit (str_replace) operation.
+ FileEdit {
+ path: PathBuf,
+ preview: Option<crate::diff::EditPreview>,
+ },
+ /// File write/create operation.
+ FileWrite { path: PathBuf },
+ /// Atuin history search.
+ HistorySearch {
+ query: String,
+ filter_modes: Vec<HistorySearchFilterMode>,
+ },
+ /// Server-side tool — no client rendering data available.
+ Remote,
+}
+
+impl ToolRenderData {
+ pub(crate) fn is_remote(&self) -> bool {
+ matches!(self, ToolRenderData::Remote)
+ }
+
+ /// The group kind this tool should collapse into, if any.
+ ///
+ /// Returns `None` for tools that render as individual `UiEvent::ToolCall`s
+ /// (shell, file writes, remote).
+ pub(crate) fn group_kind(&self) -> Option<ToolGroupKind> {
+ match self {
+ ToolRenderData::FileRead { .. } => Some(ToolGroupKind::FileRead),
+ ToolRenderData::HistorySearch { .. } => Some(ToolGroupKind::HistorySearch),
+ _ => None,
+ }
+ }
+}
+
#[derive(Debug)]
pub(crate) struct ToolCallDetails {
pub(crate) tool_use_id: String,
pub(crate) name: String,
pub(crate) status: ToolResultStatus,
- pub(crate) is_client: bool,
- pub(crate) preview: Option<ToolPreview>,
+ pub(crate) render_data: ToolRenderData,
}
#[derive(Debug)]
@@ -101,7 +182,6 @@ pub(crate) struct SuggestedCommandDetails {
pub(crate) command: String,
pub(crate) danger_level: DangerLevel,
pub(crate) confidence_level: ConfidenceLevel,
- pub(crate) first_event_in_turn: bool,
}
#[derive(Debug)]
@@ -179,33 +259,49 @@ impl<'a> TurnBuilder<'a> {
pub(crate) fn build(&mut self) -> Vec<UiTurn> {
self.commit_turn();
- // Collapse consecutive tool calls within each agent turn into ToolSummary
+ // Within each agent turn:
+ // - Consecutive remote tool calls collapse into a ToolSummary
+ // - Consecutive client-side tool calls of the same group kind collapse
+ // into a ToolGroup (e.g. N file reads → one group)
+ // - All other events pass through unchanged
for turn in &mut self.turns {
if let UiTurn::Agent { events } = turn {
let mut new_events: Vec<UiEvent> = Vec::new();
- let mut pending_tools: Vec<ToolCallDetails> = Vec::new();
+ let mut pending_remote: Vec<ToolCallDetails> = Vec::new();
+ let mut pending_group: Option<(ToolGroupKind, Vec<ToolCallDetails>)> = None;
for event in events.drain(..) {
match event {
- UiEvent::ToolCall(details) if !details.is_client => {
- pending_tools.push(details);
+ UiEvent::ToolCall(details) if details.render_data.is_remote() => {
+ flush_group(&mut pending_group, &mut new_events);
+ pending_remote.push(details);
}
- other => {
- if !pending_tools.is_empty() {
- new_events.push(UiEvent::ToolSummary(ToolSummary {
- tool_calls: std::mem::take(&mut pending_tools),
- }));
+ UiEvent::ToolCall(details)
+ if details.render_data.group_kind().is_some() =>
+ {
+ flush_remote(&mut pending_remote, &mut new_events);
+
+ let kind = details.render_data.group_kind().unwrap();
+ match pending_group.as_mut() {
+ Some((current_kind, calls)) if *current_kind == kind => {
+ calls.push(details);
+ }
+ _ => {
+ flush_group(&mut pending_group, &mut new_events);
+ pending_group = Some((kind, vec![details]));
+ }
}
+ }
+ other => {
+ flush_remote(&mut pending_remote, &mut new_events);
+ flush_group(&mut pending_group, &mut new_events);
new_events.push(other);
}
}
}
- if !pending_tools.is_empty() {
- new_events.push(UiEvent::ToolSummary(ToolSummary {
- tool_calls: pending_tools,
- }));
- }
+ flush_remote(&mut pending_remote, &mut new_events);
+ flush_group(&mut pending_group, &mut new_events);
*events = new_events;
}
@@ -255,6 +351,9 @@ impl<'a> TurnBuilder<'a> {
}
fn add_agent_text(&mut self, content: &str) {
+ if content.trim().is_empty() {
+ return;
+ }
self.start_agent_turn();
if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
events.push(UiEvent::Text {
@@ -303,8 +402,6 @@ impl<'a> TurnBuilder<'a> {
let danger = DangerLevel::from((&danger_level, &danger_notes));
let confidence = ConfidenceLevel::from((&confidence_level, &confidence_notes));
- let first_event_in_turn = events.is_empty();
-
events.push(UiEvent::SuggestedCommand(SuggestedCommandDetails {
command: input
.get("command")
@@ -313,14 +410,12 @@ impl<'a> TurnBuilder<'a> {
.to_string(),
danger_level: danger,
confidence_level: confidence,
- first_event_in_turn,
}));
}
}
fn add_tool_call(&mut self, id: &str, name: &str, _input: &serde_json::Value) {
- let is_client = descriptor::by_name(name).is_some_and(|d| d.is_client);
- let preview = self.tracker.preview_for(id);
+ let render_data = self.build_render_data(id, name);
self.start_agent_turn();
if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
@@ -328,12 +423,44 @@ impl<'a> TurnBuilder<'a> {
tool_use_id: id.to_string(),
name: name.to_string(),
status: ToolResultStatus::Pending,
- is_client,
- preview,
+ render_data,
}));
}
}
+ /// Build tool-type-specific render data from the ToolTracker.
+ ///
+ /// For client-side tools, the tracker holds the typed `ClientToolCall` and
+ /// any live/cached preview data. For server-side (or unknown) tools, we
+ /// fall back to `ToolRenderData::Remote`.
+ fn build_render_data(&self, id: &str, _name: &str) -> ToolRenderData {
+ if let Some(tracked) = self.tracker.get(id) {
+ match &tracked.tool {
+ ClientToolCall::Shell(shell) => ToolRenderData::Shell {
+ command: shell.command.clone(),
+ preview: tracked.preview(),
+ },
+ ClientToolCall::Read(read) => ToolRenderData::FileRead {
+ path: read.path.clone(),
+ },
+ ClientToolCall::Edit(edit) => ToolRenderData::FileEdit {
+ path: edit.path.clone(),
+ preview: tracked.edit_preview.clone(),
+ },
+ ClientToolCall::Write(write) => ToolRenderData::FileWrite {
+ path: write.path.clone(),
+ },
+ ClientToolCall::AtuinHistory(history) => ToolRenderData::HistorySearch {
+ query: history.query.clone(),
+ filter_modes: history.filter_modes.clone(),
+ },
+ }
+ } else {
+ // Not in tracker → server-side tool
+ ToolRenderData::Remote
+ }
+ }
+
fn add_tool_result(&mut self, tool_use_id: &str, _content: &str, is_error: bool) {
self.start_agent_turn();
if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
@@ -364,6 +491,25 @@ impl<'a> TurnBuilder<'a> {
}
}
+/// Drain pending remote tool calls into a `ToolSummary`.
+fn flush_remote(pending: &mut Vec<ToolCallDetails>, out: &mut Vec<UiEvent>) {
+ if !pending.is_empty() {
+ out.push(UiEvent::ToolSummary(ToolSummary {
+ tool_calls: std::mem::take(pending),
+ }));
+ }
+}
+
+/// Drain a pending client-side tool group into a `ToolGroup`.
+fn flush_group(
+ pending: &mut Option<(ToolGroupKind, Vec<ToolCallDetails>)>,
+ out: &mut Vec<UiEvent>,
+) {
+ if let Some((kind, calls)) = pending.take() {
+ out.push(UiEvent::ToolGroup(ToolGroup { kind, calls }));
+ }
+}
+
#[derive(Debug)]
pub(crate) struct ToolSummary {
tool_calls: Vec<ToolCallDetails>,
diff --git a/crates/atuin-client/Cargo.toml b/crates/atuin-client/Cargo.toml
index 1e2407b2..1faaaa81 100644
--- a/crates/atuin-client/Cargo.toml
+++ b/crates/atuin-client/Cargo.toml
@@ -38,7 +38,7 @@ humantime = "2.1.0"
async-trait = { workspace = true }
itertools = { workspace = true }
rand = { workspace = true }
-shellexpand = "3"
+shellexpand = { workspace = true }
sqlx = { workspace = true, features = ["sqlite", "regexp"] }
minspan = "0.1.5"
regex = { workspace = true }
diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs
index 9a2b84f5..4df404c4 100644
--- a/crates/atuin-client/src/settings.rs
+++ b/crates/atuin-client/src/settings.rs
@@ -687,6 +687,10 @@ pub struct Ai {
pub struct AiCapabilities {
/// Whether the AI can request to search Atuin history. `None` = unset (defaults to enabled, and the ai will ask for permission).
pub enable_history_search: Option<bool>,
+ /// Whether the AI can request to read and write files. `None` = unset (defaults to enabled, and the ai will ask for permission).
+ pub enable_file_tools: Option<bool>,
+ /// Whether the AI can request to execute bash commands. `None` = unset (defaults to enabled, and the ai will ask for permission).
+ pub enable_command_execution: Option<bool>,
}
#[derive(Default, Clone, Debug, Deserialize, Serialize)]
diff --git a/docs/docs/ai/settings.md b/docs/docs/ai/settings.md
index c61f611e..a8d3dab3 100644
--- a/docs/docs/ai/settings.md
+++ b/docs/docs/ai/settings.md
@@ -42,6 +42,18 @@ Default: `true`
Whether or not to include the "history search" capability in the context sent to the LLM. This allows the AI to request to search your Atuin history for relevant commands when generating suggestions or answering questions.
+### enable_file_tools
+
+Default: `true`
+
+Whether or not to include the "file tools" capability in the context sent to the LLM. This allows the AI to request to read and update files on your system.
+
+### enable_command_execution
+
+Default: `true`
+
+Whether or not to include the "command execution" capability in the context sent to the LLM. This allows the AI to request to execute commands on your system.
+
**Example config**
```toml