diff options
Diffstat (limited to 'crates/atuin-ai/src')
| -rw-r--r-- | crates/atuin-ai/src/commands/inline.rs | 29 | ||||
| -rw-r--r-- | crates/atuin-ai/src/diff.rs | 294 | ||||
| -rw-r--r-- | crates/atuin-ai/src/edit_permissions.rs | 108 | ||||
| -rw-r--r-- | crates/atuin-ai/src/file_tracker.rs | 234 | ||||
| -rw-r--r-- | crates/atuin-ai/src/lib.rs | 4 | ||||
| -rw-r--r-- | crates/atuin-ai/src/session.rs | 27 | ||||
| -rw-r--r-- | crates/atuin-ai/src/snapshots.rs | 414 | ||||
| -rw-r--r-- | crates/atuin-ai/src/store.rs | 32 | ||||
| -rw-r--r-- | crates/atuin-ai/src/stream.rs | 8 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tools/descriptor.rs | 17 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tools/mod.rs | 737 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/dispatch.rs | 199 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/events.rs | 27 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/state.rs | 9 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/view/mod.rs | 570 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/view/turn.rs | 198 |
16 files changed, 2763 insertions, 144 deletions
diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index b7aae51f..e0a92ab4 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -175,7 +175,7 @@ async fn run_inline_tui( .find_resumable(cwd.as_deref(), git_root_str.as_deref(), max_age_secs) .await?; - let (mut session_mgr, initial_state) = if let Some(stored) = resumable { + let (mut session_mgr, mut initial_state) = if let Some(stored) = resumable { debug!(session_id = %stored.id, "resuming AI session"); let (mgr, events, server_sid, last_event_ts, invocation_id) = SessionManager::resume(Box::new(service), &stored).await?; @@ -199,6 +199,23 @@ async fn run_inline_tui( session.is_resumed = true; session.last_event_time = last_event_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)); + + // Restore file read tracker from session metadata + if let Ok(Some(json)) = mgr.get_metadata(crate::file_tracker::METADATA_KEY).await + && let Ok(tracker) = crate::file_tracker::FileReadTracker::from_json(&json) + { + session.file_tracker = tracker; + } + + // Restore edit permission grants from session metadata + if let Ok(Some(json)) = mgr + .get_metadata(crate::edit_permissions::METADATA_KEY) + .await + && let Ok(cache) = crate::edit_permissions::EditPermissionCache::from_json(&json) + { + session.edit_permissions = cache; + } + (mgr, session) } else { // No meaningful content — treat as a fresh session @@ -215,6 +232,16 @@ async fn run_inline_tui( (mgr, Session::new(ctx.git_root.is_some(), None)) }; + // Initialize the snapshot store now that we know the session ID. + let snapshot_dir = atuin_common::utils::data_dir() + .join("ai") + .join("snapshots") + .join(session_mgr.session_id()); + match crate::snapshots::SnapshotStore::open(snapshot_dir) { + Ok(store) => initial_state.snapshot_store = Some(store), + Err(e) => tracing::warn!("failed to open snapshot store: {e}"), + } + let (tx, rx) = mpsc::channel::<AiTuiEvent>(); println!(); diff --git a/crates/atuin-ai/src/diff.rs b/crates/atuin-ai/src/diff.rs new file mode 100644 index 00000000..663481c0 --- /dev/null +++ b/crates/atuin-ai/src/diff.rs @@ -0,0 +1,294 @@ +//! Structured diff computation for edit previews. +//! +//! Computes a line-level diff between old and new file content using +//! imara-diff's Histogram algorithm, producing structured hunks with +//! typed lines (Context, Added, Removed) suitable for TUI rendering. + +use imara_diff::{Algorithm, Diff, InternedInput}; + +/// Number of context lines to show around each change. +const CONTEXT_LINES: u32 = 3; + +/// A structured diff preview for a file edit, ready for rendering. +#[derive(Debug, Clone)] +pub(crate) struct EditPreview { + pub hunks: Vec<DiffHunk>, +} + +/// A contiguous group of diff lines (context + changes). +#[derive(Debug, Clone)] +pub(crate) struct DiffHunk { + /// 1-indexed line number of the first line in this hunk (in the original file). + pub before_start: u32, + /// 1-indexed line number of the first line in this hunk (in the new file). + pub after_start: u32, + pub lines: Vec<DiffLine>, +} + +/// A single line in a diff hunk. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum DiffLine { + /// Unchanged line (shown for context). + Context(String), + /// Line added in the new version. + Added(String), + /// Line removed from the old version. + Removed(String), +} + +impl EditPreview { + /// Compute a structured diff between old and new file content. + /// + /// Uses the Histogram algorithm with line-level granularity and + /// indentation-aware postprocessing for readable output. + pub fn compute(old: &str, new: &str) -> Self { + let input = InternedInput::new(old, new); + let mut diff = Diff::compute(Algorithm::Histogram, &input); + diff.postprocess_lines(&input); + + let raw_hunks: Vec<_> = diff.hunks().collect(); + if raw_hunks.is_empty() { + return EditPreview { hunks: Vec::new() }; + } + + // Merge hunks that are within 2*CONTEXT_LINES of each other + // (same logic as unified diff format). + let mut merged_groups: Vec<Vec<&imara_diff::Hunk>> = Vec::new(); + let mut current_group: Vec<&imara_diff::Hunk> = vec![&raw_hunks[0]]; + + for hunk in &raw_hunks[1..] { + let prev = current_group.last().unwrap(); + if hunk.before.start.saturating_sub(prev.before.end) <= 2 * CONTEXT_LINES { + current_group.push(hunk); + } else { + merged_groups.push(current_group); + current_group = vec![hunk]; + } + } + merged_groups.push(current_group); + + // Build structured hunks from merged groups + let hunks = merged_groups + .into_iter() + .map(|group| build_hunk(&group, &input)) + .collect(); + + EditPreview { hunks } + } + + /// The highest line number (from either file) that will be displayed. + /// Used to calculate gutter width. + pub fn max_line_number(&self) -> u32 { + self.hunks + .iter() + .map(|h| { + let mut before_pos = h.before_start; + let mut after_pos = h.after_start; + for line in &h.lines { + match line { + DiffLine::Context(_) => { + before_pos += 1; + after_pos += 1; + } + DiffLine::Removed(_) => before_pos += 1, + DiffLine::Added(_) => after_pos += 1, + } + } + before_pos.max(after_pos).saturating_sub(1) + }) + .max() + .unwrap_or(0) + } +} + +/// Build a single DiffHunk from a group of adjacent raw hunks. +fn build_hunk(group: &[&imara_diff::Hunk], input: &InternedInput<&str>) -> DiffHunk { + let first = group.first().unwrap(); + let last = group.last().unwrap(); + + let context_start = first.before.start.saturating_sub(CONTEXT_LINES); + let context_end = (last.before.end + CONTEXT_LINES).min(input.before.len() as u32); + + // The after-file position of context_start: same offset as before since + // context before the first change is identical in both files. + let after_context_start = first.after.start - (first.before.start - context_start); + + let mut lines = Vec::new(); + let mut pos = context_start; + + for hunk in group { + // Context lines before this hunk + for i in pos..hunk.before.start { + lines.push(DiffLine::Context(token_text(input, true, i))); + } + + // Removed lines + for i in hunk.before.start..hunk.before.end { + lines.push(DiffLine::Removed(token_text(input, true, i))); + } + + // Added lines + for i in hunk.after.start..hunk.after.end { + lines.push(DiffLine::Added(token_text(input, false, i))); + } + + pos = hunk.before.end; + } + + // Trailing context + for i in pos..context_end { + lines.push(DiffLine::Context(token_text(input, true, i))); + } + + DiffHunk { + before_start: context_start + 1, // 1-indexed + after_start: after_context_start + 1, // 1-indexed + lines, + } +} + +/// Extract the text content of a token, trimming the trailing newline +/// that imara-diff includes in line-based tokenization. +fn token_text(input: &InternedInput<&str>, is_before: bool, idx: u32) -> String { + let tokens = if is_before { + &input.before + } else { + &input.after + }; + let text = input.interner[tokens[idx as usize]]; + text.strip_suffix('\n') + .unwrap_or(text) + .strip_suffix('\r') + .unwrap_or(text.strip_suffix('\n').unwrap_or(text)) + .to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn no_changes_produces_empty_preview() { + let preview = EditPreview::compute("hello\nworld\n", "hello\nworld\n"); + assert!(preview.hunks.is_empty()); + } + + #[test] + fn single_line_replacement() { + let old = "line1\nline2\nline3\n"; + let new = "line1\nchanged\nline3\n"; + let preview = EditPreview::compute(old, new); + + assert_eq!(preview.hunks.len(), 1); + let hunk = &preview.hunks[0]; + + // Should have: context(line1), removed(line2), added(changed), context(line3) + assert!(hunk.lines.contains(&DiffLine::Context("line1".into()))); + assert!(hunk.lines.contains(&DiffLine::Removed("line2".into()))); + assert!(hunk.lines.contains(&DiffLine::Added("changed".into()))); + assert!(hunk.lines.contains(&DiffLine::Context("line3".into()))); + } + + #[test] + fn addition_only() { + let old = "aaa\nbbb\n"; + let new = "aaa\nnew_line\nbbb\n"; + let preview = EditPreview::compute(old, new); + + assert_eq!(preview.hunks.len(), 1); + let hunk = &preview.hunks[0]; + assert!(hunk.lines.contains(&DiffLine::Added("new_line".into()))); + // Original lines are context + assert!(hunk.lines.contains(&DiffLine::Context("aaa".into()))); + assert!(hunk.lines.contains(&DiffLine::Context("bbb".into()))); + } + + #[test] + fn removal_only() { + let old = "aaa\nremove_me\nbbb\n"; + let new = "aaa\nbbb\n"; + let preview = EditPreview::compute(old, new); + + assert_eq!(preview.hunks.len(), 1); + let hunk = &preview.hunks[0]; + assert!(hunk.lines.contains(&DiffLine::Removed("remove_me".into()))); + } + + #[test] + fn distant_changes_produce_separate_hunks() { + // Two changes separated by more than 2*CONTEXT_LINES (6) lines + let old = "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n"; + let new = "1\nX\n3\n4\n5\n6\n7\n8\n9\n10\n11\nY\n"; + let preview = EditPreview::compute(old, new); + + assert_eq!(preview.hunks.len(), 2); + } + + #[test] + fn close_changes_merge_into_one_hunk() { + // Two changes separated by fewer than 2*CONTEXT_LINES lines + let old = "1\n2\n3\n4\n5\n"; + let new = "X\n2\n3\n4\nY\n"; + let preview = EditPreview::compute(old, new); + + assert_eq!(preview.hunks.len(), 1); + } + + #[test] + fn context_is_limited() { + // With CONTEXT_LINES=3, a change at line 10 shouldn't include line 1 + let mut old_lines: Vec<&str> = (1..=20).map(|_| "unchanged").collect(); + old_lines[9] = "target"; + let old = old_lines.join("\n") + "\n"; + let new = old.replace("target", "replaced"); + + let preview = EditPreview::compute(&old, &new); + assert_eq!(preview.hunks.len(), 1); + + // Should have at most 3 context lines before + 3 after + 1 removed + 1 added = 8 lines + assert!(preview.hunks[0].lines.len() <= 8); + } + + #[test] + fn max_line_number_reflects_file_position() { + let old = "a\nb\nc\n"; + let new = "a\nX\nc\n"; + let preview = EditPreview::compute(old, new); + // 3-line file, context + removed lines span positions 1-3 + assert_eq!(preview.max_line_number(), 3); + } + + #[test] + fn start_line_is_correct_for_later_changes() { + // Change at line 10 with 3 context lines → before_start = 7 + let mut lines: Vec<String> = (1..=15).map(|i| format!("line{i}")).collect(); + let old = lines.join("\n") + "\n"; + lines[9] = "CHANGED".to_string(); + let new = lines.join("\n") + "\n"; + + let preview = EditPreview::compute(&old, &new); + assert_eq!(preview.hunks.len(), 1); + assert_eq!(preview.hunks[0].before_start, 7); // line 10 - 3 context = line 7 + assert_eq!(preview.hunks[0].after_start, 7); // same for a simple replacement + } + + #[test] + fn multiline_replacement() { + let old = "[section]\nkey1 = old1\nkey2 = old2\n[other]\n"; + let new = "[section]\nkey1 = new1\nkey2 = new2\n[other]\n"; + let preview = EditPreview::compute(old, new); + + assert_eq!(preview.hunks.len(), 1); + let hunk = &preview.hunks[0]; + assert!( + hunk.lines + .contains(&DiffLine::Removed("key1 = old1".into())) + ); + assert!( + hunk.lines + .contains(&DiffLine::Removed("key2 = old2".into())) + ); + assert!(hunk.lines.contains(&DiffLine::Added("key1 = new1".into()))); + assert!(hunk.lines.contains(&DiffLine::Added("key2 = new2".into()))); + } +} diff --git a/crates/atuin-ai/src/edit_permissions.rs b/crates/atuin-ai/src/edit_permissions.rs new file mode 100644 index 00000000..5015a007 --- /dev/null +++ b/crates/atuin-ai/src/edit_permissions.rs @@ -0,0 +1,108 @@ +//! Session-scoped permission cache for file edits. +//! +//! When the user selects "Allow this file for this session", the grant is +//! recorded here with a timestamp. Subsequent edits to the same file skip +//! the permission prompt as long as the grant hasn't expired. +//! +//! Grants are time-limited (1 hour TTL) so they don't outlive the user's +//! attention in long-running sessions. Persisted as JSON in session +//! metadata so they survive across CLI invocations. + +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::time::SystemTime; + +use eyre::Result; +use serde::{Deserialize, Serialize}; + +/// Session metadata key for persistence. +pub(crate) const METADATA_KEY: &str = "edit_permissions"; + +/// How long a session-scoped edit permission remains valid. +const TTL_MS: i64 = 60 * 60 * 1000; // 1 hour + +/// Cache of per-file edit permission grants within a session. +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub(crate) struct EditPermissionCache { + /// Maps canonical file paths to the grant timestamp (unix millis). + grants: HashMap<PathBuf, i64>, +} + +impl EditPermissionCache { + /// Record a permission grant for a file. + pub fn grant(&mut self, path: PathBuf) { + self.grants.insert(path, now_ms()); + } + + /// Check whether there's a valid (non-expired) grant for a file. + pub fn has_valid_grant(&self, path: &Path) -> bool { + if let Some(&granted_at) = self.grants.get(path) { + (now_ms() - granted_at) < TTL_MS + } else { + false + } + } + + pub fn to_json(&self) -> Result<String> { + Ok(serde_json::to_string(self)?) + } + + pub fn from_json(json: &str) -> Result<Self> { + Ok(serde_json::from_str(json)?) + } +} + +fn now_ms() -> i64 { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map(|d| d.as_millis() as i64) + .unwrap_or(0) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn grant_and_check() { + let mut cache = EditPermissionCache::default(); + let path = PathBuf::from("/Users/me/.config/foo.toml"); + + assert!(!cache.has_valid_grant(&path)); + cache.grant(path.clone()); + assert!(cache.has_valid_grant(&path)); + } + + #[test] + fn different_paths_are_independent() { + let mut cache = EditPermissionCache::default(); + let a = PathBuf::from("/etc/hosts"); + let b = PathBuf::from("/etc/resolv.conf"); + + cache.grant(a.clone()); + assert!(cache.has_valid_grant(&a)); + assert!(!cache.has_valid_grant(&b)); + } + + #[test] + fn roundtrip_json() { + let mut cache = EditPermissionCache::default(); + cache.grant(PathBuf::from("/some/file.toml")); + + let json = cache.to_json().unwrap(); + let restored = EditPermissionCache::from_json(&json).unwrap(); + assert!(restored.has_valid_grant(Path::new("/some/file.toml"))); + } + + #[test] + fn expired_grant_is_invalid() { + let mut cache = EditPermissionCache::default(); + let path = PathBuf::from("/expired/file.toml"); + + // Insert a grant from 2 hours ago + let two_hours_ago = now_ms() - (2 * 60 * 60 * 1000); + cache.grants.insert(path.clone(), two_hours_ago); + + assert!(!cache.has_valid_grant(&path)); + } +} diff --git a/crates/atuin-ai/src/file_tracker.rs b/crates/atuin-ai/src/file_tracker.rs new file mode 100644 index 00000000..feee1ee8 --- /dev/null +++ b/crates/atuin-ai/src/file_tracker.rs @@ -0,0 +1,234 @@ +//! Tracks which files have been read in the current session, for freshness +//! checking before edits. +//! +//! The tracker records the content hash and mtime of each file at the time +//! it was last read. Before an edit, the tracker verifies the file hasn't +//! changed since the last read — catching both external modifications and +//! concurrent tool calls. +//! +//! Persisted as JSON in session metadata so it survives across CLI +//! invocations within the same logical session. + +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::time::SystemTime; + +use eyre::Result; +use serde::{Deserialize, Serialize}; + +/// Metadata key used for session_metadata persistence. +pub(crate) const METADATA_KEY: &str = "file_read_tracker"; + +/// State recorded for a single file read. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct FileReadState { + /// Hash of the file contents at the time of the last read. + pub content_hash: u64, + /// File mtime (as milliseconds since epoch) at the time of the last read. + /// Millisecond precision ensures sub-second modifications are detected. + pub mtime_ms: i64, +} + +/// Tracks file read state for freshness checking. +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub(crate) struct FileReadTracker { + reads: HashMap<PathBuf, FileReadState>, +} + +/// Result of a freshness check. +pub(crate) enum FreshnessCheck { + /// File is fresh — the content hasn't changed since the last read. + Fresh, + /// File has never been read in this session. + NotRead, + /// File has been modified since the last read. + Stale, +} + +impl FileReadTracker { + /// Record that a file was read. Call this after a successful `read_file` + /// execution. The `path` should be canonical (absolute, tilde-expanded). + pub fn record_read(&mut self, path: PathBuf, content: &[u8], mtime: SystemTime) { + let content_hash = hash_content(content); + let mtime_ms = system_time_to_ms(mtime); + + self.reads.insert( + path, + FileReadState { + content_hash, + mtime_ms, + }, + ); + } + + /// Check whether a file is fresh (unchanged since last read). + /// + /// Uses mtime as a fast path — only re-hashes if mtime differs. + pub fn check_freshness(&self, path: &Path) -> Result<FreshnessCheck> { + let state = match self.reads.get(path) { + Some(s) => s, + None => return Ok(FreshnessCheck::NotRead), + }; + + // Stat the file + let metadata = match std::fs::metadata(path) { + Ok(m) => m, + Err(_) => return Ok(FreshnessCheck::Stale), // file deleted or inaccessible + }; + + let current_mtime_ms = + system_time_to_ms(metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH)); + + // Fast path: mtime unchanged → fresh + if current_mtime_ms == state.mtime_ms { + return Ok(FreshnessCheck::Fresh); + } + + // Mtime changed — re-hash to confirm + let content = std::fs::read(path)?; + let current_hash = hash_content(&content); + + if current_hash == state.content_hash { + Ok(FreshnessCheck::Fresh) + } else { + Ok(FreshnessCheck::Stale) + } + } + + /// Update the tracker entry after a successful edit (new content written). + pub fn update_after_edit(&mut self, path: &Path, new_content: &[u8], new_mtime: SystemTime) { + let content_hash = hash_content(new_content); + let mtime_ms = system_time_to_ms(new_mtime); + + self.reads.insert( + path.to_path_buf(), + FileReadState { + content_hash, + mtime_ms, + }, + ); + } + + /// Serialize to JSON for session metadata persistence. + pub fn to_json(&self) -> Result<String> { + Ok(serde_json::to_string(self)?) + } + + /// Deserialize from JSON session metadata. + pub fn from_json(json: &str) -> Result<Self> { + Ok(serde_json::from_str(json)?) + } +} + +fn system_time_to_ms(t: SystemTime) -> i64 { + t.duration_since(SystemTime::UNIX_EPOCH) + .map(|d| d.as_millis() as i64) + .unwrap_or(0) +} + +fn hash_content(content: &[u8]) -> u64 { + xxhash_rust::xxh3::xxh3_64(content) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use tempfile::NamedTempFile; + + #[test] + fn record_and_check_fresh() { + let mut tracker = FileReadTracker::default(); + let mut tmp = NamedTempFile::new().unwrap(); + write!(tmp, "hello world").unwrap(); + + let path = tmp.path().to_path_buf(); + let content = std::fs::read(&path).unwrap(); + let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); + + tracker.record_read(path.clone(), &content, mtime); + + assert!(matches!( + tracker.check_freshness(&path).unwrap(), + FreshnessCheck::Fresh + )); + } + + #[test] + fn check_not_read() { + let tracker = FileReadTracker::default(); + let path = PathBuf::from("/nonexistent/file.txt"); + assert!(matches!( + tracker.check_freshness(&path).unwrap(), + FreshnessCheck::NotRead + )); + } + + #[test] + fn check_stale_after_modification() { + let mut tracker = FileReadTracker::default(); + let mut tmp = NamedTempFile::new().unwrap(); + write!(tmp, "original").unwrap(); + + let path = tmp.path().to_path_buf(); + let content = std::fs::read(&path).unwrap(); + let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); + + tracker.record_read(path.clone(), &content, mtime); + + // Small delay to ensure the filesystem mtime advances + std::thread::sleep(std::time::Duration::from_millis(10)); + + // Modify the file + std::fs::write(&path, "modified").unwrap(); + + assert!(matches!( + tracker.check_freshness(&path).unwrap(), + FreshnessCheck::Stale + )); + } + + #[test] + fn update_after_edit_makes_fresh() { + let mut tracker = FileReadTracker::default(); + let mut tmp = NamedTempFile::new().unwrap(); + write!(tmp, "original").unwrap(); + + let path = tmp.path().to_path_buf(); + let content = std::fs::read(&path).unwrap(); + let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); + + tracker.record_read(path.clone(), &content, mtime); + + // Simulate an edit + let new_content = b"edited content"; + std::fs::write(&path, new_content).unwrap(); + let new_mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); + tracker.update_after_edit(&path, new_content, new_mtime); + + assert!(matches!( + tracker.check_freshness(&path).unwrap(), + FreshnessCheck::Fresh + )); + } + + #[test] + fn roundtrip_json() { + let mut tracker = FileReadTracker::default(); + tracker.reads.insert( + PathBuf::from("/some/file.toml"), + FileReadState { + content_hash: 12345, + mtime_ms: 1700000000000, + }, + ); + + let json = tracker.to_json().unwrap(); + let restored = FileReadTracker::from_json(&json).unwrap(); + assert_eq!(restored.reads.len(), 1); + assert_eq!( + restored.reads[&PathBuf::from("/some/file.toml")].content_hash, + 12345 + ); + } +} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index febb488e..afe9c1e4 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -1,9 +1,13 @@ pub mod commands; pub(crate) mod context; pub(crate) mod context_window; +pub(crate) mod diff; +pub(crate) mod edit_permissions; pub(crate) mod event_serde; +pub(crate) mod file_tracker; pub(crate) mod permissions; pub(crate) mod session; +pub(crate) mod snapshots; pub(crate) mod store; pub(crate) mod stream; pub(crate) mod tools; diff --git a/crates/atuin-ai/src/session.rs b/crates/atuin-ai/src/session.rs index d8314343..848330fc 100644 --- a/crates/atuin-ai/src/session.rs +++ b/crates/atuin-ai/src/session.rs @@ -51,6 +51,9 @@ pub(crate) trait SessionService: Send + Sync { ) -> Result<()>; async fn archive(&self, session_id: &str) -> Result<()>; + + async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>>; + async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()>; } // --------------------------------------------------------------------------- @@ -128,6 +131,14 @@ impl SessionService for LocalSessionService { async fn archive(&self, session_id: &str) -> Result<()> { self.store.archive_session(session_id).await } + + async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> { + self.store.get_metadata(session_id, key).await + } + + async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> { + self.store.set_metadata(session_id, key, value).await + } } // --------------------------------------------------------------------------- @@ -310,6 +321,22 @@ impl SessionManager { pub fn invocation_id(&self) -> &str { &self.invocation_id } + + /// Read a metadata value for the current session. + pub async fn get_metadata(&self, key: &str) -> Result<Option<String>> { + if !self.persisted_to_db { + return Ok(None); + } + self.service.get_metadata(&self.session_id, key).await + } + + /// Write a metadata value for the current session. + pub async fn set_metadata(&mut self, key: &str, value: &str) -> Result<()> { + self.ensure_persisted().await?; + self.service + .set_metadata(&self.session_id, key, value) + .await + } } #[cfg(test)] diff --git a/crates/atuin-ai/src/snapshots.rs b/crates/atuin-ai/src/snapshots.rs new file mode 100644 index 00000000..6c7b0c9c --- /dev/null +++ b/crates/atuin-ai/src/snapshots.rs @@ -0,0 +1,414 @@ +//! Backup snapshots for files before AI edits. +//! +//! Before the first edit to a file within a session, a snapshot of the +//! original content is saved so the user can recover if needed. Snapshots +//! are stored alongside a manifest that maps sanitized filenames back to +//! their original paths. +//! +//! Filenames use percent-encoding (`/` → `%2F`) so the snapshot directory +//! is human-readable via `ls`. + +use std::collections::HashMap; +use std::io::Write; +use std::path::{Path, PathBuf}; + +use eyre::{Result, eyre}; +use serde::{Deserialize, Serialize}; +use time::OffsetDateTime; + +/// Snapshot store for a single session. +/// +/// Each session gets its own directory under the snapshots root: +/// `<data_dir>/ai/snapshots/<session_id>/` +/// +/// Files are stored with percent-encoded filenames derived from their +/// canonical paths, alongside a `manifest.json` that maps filenames +/// back to original paths with timestamps. +#[derive(Debug)] +pub(crate) struct SnapshotStore { + session_dir: PathBuf, + manifest: SnapshotManifest, +} + +#[derive(Debug, Default, Serialize, Deserialize)] +struct SnapshotManifest { + files: HashMap<String, SnapshotEntry>, +} + +#[derive(Debug, Serialize, Deserialize)] +struct SnapshotEntry { + original_path: String, + snapshot_at: String, + size_bytes: u64, +} + +impl SnapshotStore { + /// Open or create a snapshot store for the given session directory. + /// + /// If a manifest already exists (from a prior CLI invocation in the same + /// session), it's loaded so we don't re-snapshot files that were already + /// backed up. + pub fn open(session_dir: PathBuf) -> Result<Self> { + let manifest_path = session_dir.join("manifest.json"); + let manifest = if manifest_path.exists() { + let data = fs_err::read_to_string(&manifest_path)?; + serde_json::from_str(&data)? + } else { + SnapshotManifest::default() + }; + + Ok(Self { + session_dir, + manifest, + }) + } + + /// Snapshot a file's contents if it hasn't been snapshotted yet this session. + /// + /// Returns `true` if a new snapshot was created, `false` if one already + /// existed. The `canonical_path` should be absolute (already tilde-expanded + /// and resolved). + pub fn ensure_snapshot(&mut self, canonical_path: &Path, content: &[u8]) -> Result<bool> { + let filename = sanitize_path(canonical_path); + + if self.manifest.files.contains_key(&filename) { + return Ok(false); + } + + fs_err::create_dir_all(&self.session_dir)?; + + let snapshot_path = self.session_dir.join(&filename); + atomic_write_file(&snapshot_path, content)?; + + let now = OffsetDateTime::now_utc(); + let entry = SnapshotEntry { + original_path: canonical_path.to_string_lossy().into_owned(), + snapshot_at: format_iso8601(now), + size_bytes: content.len() as u64, + }; + + self.manifest.files.insert(filename, entry); + self.save_manifest()?; + + Ok(true) + } + + /// Whether a file has already been snapshotted in this session. + #[expect(dead_code)] + pub fn has_snapshot(&self, canonical_path: &Path) -> bool { + let filename = sanitize_path(canonical_path); + self.manifest.files.contains_key(&filename) + } + + fn save_manifest(&self) -> Result<()> { + let json = serde_json::to_string_pretty(&self.manifest)?; + atomic_write_file(&self.session_dir.join("manifest.json"), json.as_bytes()) + } +} + +/// Percent-encode a path for use as a filename. +/// +/// Encodes `%` as `%25`, `/` as `%2F`, and `\` as `%5C`, then strips +/// leading separators and drive prefixes (e.g. `C:\`). The result is +/// always a flat filename safe for use with `Path::join` on any platform. +/// +/// Example (Unix): `/Users/me/.config/foo.toml` → `Users%2Fme%2F.config%2Ffoo.toml` +/// Example (Windows): `C:\Users\me\config.toml` → `Users%5Cme%5Cconfig.toml` +pub(crate) fn sanitize_path(path: &Path) -> String { + let s = path.to_string_lossy(); + // Strip drive letter prefix on Windows (e.g. "C:\") + let s = s.strip_prefix('/').unwrap_or_else(|| { + // Handle Windows drive prefix like "C:\" or "C:/" + if s.len() >= 3 + && s.as_bytes()[0].is_ascii_alphabetic() + && s.as_bytes()[1] == b':' + && (s.as_bytes()[2] == b'\\' || s.as_bytes()[2] == b'/') + { + &s[3..] + } else { + &s + } + }); + s.replace('%', "%25") + .replace('/', "%2F") + .replace('\\', "%5C") +} + +/// Write a file atomically using temp-file-then-rename. +/// +/// Creates a temporary file in the same directory as `target`, writes +/// content, fsyncs, then renames into place. Preserves permissions from +/// the original file if it exists. +pub(crate) fn atomic_write_file(target: &Path, content: &[u8]) -> Result<()> { + let dir = target + .parent() + .ok_or_else(|| eyre!("target path has no parent directory"))?; + fs_err::create_dir_all(dir)?; + + let mut tmp = tempfile::NamedTempFile::new_in(dir)?; + tmp.write_all(content)?; + tmp.as_file().sync_all()?; + + // Preserve permissions from original if it exists + if let Ok(meta) = std::fs::metadata(target) { + std::fs::set_permissions(tmp.path(), meta.permissions())?; + } + + tmp.persist(target).map_err(|e| { + eyre!( + "failed to persist atomic write to {}: {}", + target.display(), + e + ) + })?; + Ok(()) +} + +fn format_iso8601(dt: OffsetDateTime) -> String { + format!( + "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", + dt.year(), + dt.month() as u8, + dt.day(), + dt.hour(), + dt.minute(), + dt.second(), + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── sanitize_path ────────────────────────────────────────── + + #[test] + fn sanitize_absolute_path() { + let path = Path::new("/Users/me/.config/atuin/config.toml"); + assert_eq!( + sanitize_path(path), + "Users%2Fme%2F.config%2Fatuin%2Fconfig.toml" + ); + } + + #[test] + fn sanitize_preserves_existing_percent() { + let path = Path::new("/data/100%done/file.txt"); + assert_eq!(sanitize_path(path), "data%2F100%25done%2Ffile.txt"); + } + + #[test] + fn sanitize_relative_path() { + let path = Path::new("relative/path.txt"); + assert_eq!(sanitize_path(path), "relative%2Fpath.txt"); + } + + #[test] + fn sanitize_no_collision_between_similar_paths() { + let a = sanitize_path(Path::new("/foo/bar-baz")); + let b = sanitize_path(Path::new("/foo/bar/baz")); + assert_ne!(a, b); + } + + #[test] + fn sanitize_backslash_encoded() { + // Windows-style path: backslashes become %5C, drive prefix stripped + let s = sanitize_path(Path::new("C:\\Users\\me\\config.toml")); + assert!(!s.contains('\\'), "backslashes must be encoded: {s}"); + assert!(!s.starts_with("C:"), "drive prefix must be stripped: {s}"); + assert!(s.contains("Users")); + assert!(s.contains("config.toml")); + } + + #[test] + fn sanitize_result_is_flat_filename() { + // The result must not be interpreted as a path with separators + // when passed to Path::join — no raw / or \ allowed. + let unix = sanitize_path(Path::new("/home/user/file.txt")); + assert!(!unix.contains('/')); + // Construct as if on Windows + let win = "C:\\Users\\me\\file.txt"; + let encoded = win + .strip_prefix("C:\\") + .unwrap() + .replace('%', "%25") + .replace('/', "%2F") + .replace('\\', "%5C"); + assert!(!encoded.contains('\\')); + assert!(!encoded.contains('/')); + } + + // ── atomic_write_file ────────────────────────────────────── + + #[test] + fn atomic_write_creates_file() { + let dir = tempfile::tempdir().unwrap(); + let target = dir.path().join("test.txt"); + + atomic_write_file(&target, b"hello world").unwrap(); + + assert_eq!(std::fs::read_to_string(&target).unwrap(), "hello world"); + } + + #[test] + fn atomic_write_overwrites_existing() { + let dir = tempfile::tempdir().unwrap(); + let target = dir.path().join("test.txt"); + + std::fs::write(&target, "old content").unwrap(); + atomic_write_file(&target, b"new content").unwrap(); + + assert_eq!(std::fs::read_to_string(&target).unwrap(), "new content"); + } + + #[test] + fn atomic_write_creates_parent_dirs() { + let dir = tempfile::tempdir().unwrap(); + let target = dir.path().join("sub").join("dir").join("test.txt"); + + atomic_write_file(&target, b"nested").unwrap(); + + assert_eq!(std::fs::read_to_string(&target).unwrap(), "nested"); + } + + #[cfg(unix)] + #[test] + fn atomic_write_preserves_permissions() { + use std::os::unix::fs::PermissionsExt; + + let dir = tempfile::tempdir().unwrap(); + let target = dir.path().join("test.txt"); + + std::fs::write(&target, "original").unwrap(); + std::fs::set_permissions(&target, std::fs::Permissions::from_mode(0o600)).unwrap(); + + atomic_write_file(&target, b"updated").unwrap(); + + let mode = std::fs::metadata(&target).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o600); + } + + // ── SnapshotStore ────────────────────────────────────────── + + #[test] + fn snapshot_creates_file_and_manifest() { + let dir = tempfile::tempdir().unwrap(); + let session_dir = dir.path().join("session-abc"); + let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); + + let file_path = Path::new("/Users/me/.config/foo.toml"); + let created = store + .ensure_snapshot(file_path, b"[key]\nval = 1\n") + .unwrap(); + + assert!(created); + assert!(store.has_snapshot(file_path)); + + // Snapshot file on disk + let expected_file = session_dir.join("Users%2Fme%2F.config%2Ffoo.toml"); + assert!(expected_file.exists()); + assert_eq!( + std::fs::read_to_string(&expected_file).unwrap(), + "[key]\nval = 1\n" + ); + + // Manifest on disk + let manifest_path = session_dir.join("manifest.json"); + assert!(manifest_path.exists()); + let manifest: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); + let files = manifest["files"].as_object().unwrap(); + assert_eq!(files.len(), 1); + let entry = &files["Users%2Fme%2F.config%2Ffoo.toml"]; + assert_eq!( + entry["original_path"].as_str().unwrap(), + "/Users/me/.config/foo.toml" + ); + assert_eq!(entry["size_bytes"].as_u64().unwrap(), 14); + } + + #[test] + fn snapshot_is_idempotent() { + let dir = tempfile::tempdir().unwrap(); + let session_dir = dir.path().join("session-abc"); + let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); + + let path = Path::new("/etc/hosts"); + let first = store.ensure_snapshot(path, b"first content").unwrap(); + let second = store.ensure_snapshot(path, b"different content").unwrap(); + + assert!(first); + assert!(!second); + + // Original content preserved, not overwritten + let snapshot_file = session_dir.join("etc%2Fhosts"); + assert_eq!( + std::fs::read_to_string(snapshot_file).unwrap(), + "first content" + ); + } + + #[test] + fn snapshot_store_loads_existing_manifest() { + let dir = tempfile::tempdir().unwrap(); + let session_dir = dir.path().join("session-abc"); + + // First store: create a snapshot + { + let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); + store + .ensure_snapshot(Path::new("/etc/hosts"), b"127.0.0.1") + .unwrap(); + } + + // Second store (simulates new CLI invocation): should see existing snapshot + { + let mut store = SnapshotStore::open(session_dir).unwrap(); + assert!(store.has_snapshot(Path::new("/etc/hosts"))); + + let created = store + .ensure_snapshot(Path::new("/etc/hosts"), b"new content") + .unwrap(); + assert!(!created); + } + } + + #[test] + fn snapshot_multiple_files() { + let dir = tempfile::tempdir().unwrap(); + let session_dir = dir.path().join("session-abc"); + let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); + + store + .ensure_snapshot(Path::new("/etc/hosts"), b"hosts content") + .unwrap(); + store + .ensure_snapshot(Path::new("/Users/me/.bashrc"), b"bashrc content") + .unwrap(); + + assert!(store.has_snapshot(Path::new("/etc/hosts"))); + assert!(store.has_snapshot(Path::new("/Users/me/.bashrc"))); + assert!(!store.has_snapshot(Path::new("/nonexistent"))); + + // Both snapshot files exist + assert!(session_dir.join("etc%2Fhosts").exists()); + assert!(session_dir.join("Users%2Fme%2F.bashrc").exists()); + + // Manifest has both entries + let manifest: serde_json::Value = serde_json::from_str( + &std::fs::read_to_string(session_dir.join("manifest.json")).unwrap(), + ) + .unwrap(); + assert_eq!(manifest["files"].as_object().unwrap().len(), 2); + } + + #[test] + fn format_iso8601_produces_valid_format() { + let dt = OffsetDateTime::from_unix_timestamp(1700000000).unwrap(); + let formatted = format_iso8601(dt); + assert_eq!(formatted.len(), 20); + assert!(formatted.starts_with("2023-")); + assert!(formatted.contains('T')); + assert!(formatted.ends_with('Z')); + } +} diff --git a/crates/atuin-ai/src/store.rs b/crates/atuin-ai/src/store.rs index 2a75d8f4..20b9e881 100644 --- a/crates/atuin-ai/src/store.rs +++ b/crates/atuin-ai/src/store.rs @@ -299,6 +299,38 @@ impl AiSessionStore { .await?; Ok(()) } + + // ── Session metadata (key-value per session) ── + + /// Read a metadata value for a session. Returns `None` if the key doesn't + /// exist or the session hasn't been persisted yet. + pub async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> { + let row: Option<(String,)> = + sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ?1 AND key = ?2") + .bind(session_id) + .bind(key) + .fetch_optional(&self.pool) + .await?; + + Ok(row.map(|(v,)| v)) + } + + /// Write a metadata value for a session (upsert). + pub async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> { + let now = OffsetDateTime::now_utc().unix_timestamp(); + sqlx::query( + "INSERT INTO session_metadata (session_id, key, value, updated_at) + VALUES (?1, ?2, ?3, ?4) + ON CONFLICT (session_id, key) DO UPDATE SET value = ?3, updated_at = ?4", + ) + .bind(session_id) + .bind(key) + .bind(value) + .bind(now) + .execute(&self.pool) + .await?; + Ok(()) + } } #[cfg(test)] diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index f4f4d704..24770abe 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -74,6 +74,14 @@ impl ChatRequest { if capabilities.enable_history_search.unwrap_or(true) { caps.push("client_v1_atuin_history".to_string()); } + if capabilities.enable_file_tools.unwrap_or(true) { + caps.push("client_v1_read_file".to_string()); + caps.push("client_v1_edit_file".to_string()); + caps.push("client_v1_write_file".to_string()); + } + if capabilities.enable_command_execution.unwrap_or(true) { + caps.push("client_v1_execute_shell_command".to_string()); + } if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { caps.extend( extra diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs index fc44ec10..6ccb595f 100644 --- a/crates/atuin-ai/src/tools/descriptor.rs +++ b/crates/atuin-ai/src/tools/descriptor.rs @@ -16,6 +16,7 @@ pub(crate) struct ToolDescriptor { /// Past-tense verb for summaries (e.g. "Read file"). pub past_verb: &'static str, /// Whether this tool is executed client-side (by the CLI). + #[expect(dead_code)] pub is_client: bool, } @@ -30,9 +31,18 @@ pub(crate) const READ: &ToolDescriptor = &ToolDescriptor { is_client: true, }; +pub(crate) const EDIT: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["edit_file"], + capability: Some("client_v1_edit_file"), + display_verb: "edit", + progressive_verb: "Editing file...", + past_verb: "Edited file", + is_client: true, +}; + pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["str_replace", "file_create", "file_insert"], - capability: Some("client_v1_write"), + canonical_names: &["write_file"], + capability: Some("client_v1_write_file"), display_verb: "write to", progressive_verb: "Writing file...", past_verb: "Wrote file", @@ -41,7 +51,7 @@ pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor { pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor { canonical_names: &["execute_shell_command"], - capability: Some("client_v1_shell"), + capability: Some("client_v1_execute_shell_command"), display_verb: "run", progressive_verb: "Running command...", past_verb: "Ran command", @@ -81,6 +91,7 @@ pub(crate) const SERVER_SCRAPE: &ToolDescriptor = &ToolDescriptor { /// All known tool descriptors, for lookup by name. const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[ READ, + EDIT, WRITE, SHELL, ATUIN_HISTORY, diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 8f2183b7..8fe1ad73 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -169,6 +169,8 @@ pub(crate) struct TrackedTool { pub phase: ToolPhase, /// Sender to interrupt a running shell command (only set during ExecutingWithPreview). pub abort_tx: Option<tokio::sync::oneshot::Sender<()>>, + /// Diff preview for completed edit tool calls. + pub edit_preview: Option<crate::diff::EditPreview>, } impl TrackedTool { @@ -234,6 +236,7 @@ impl ToolTracker { tool, phase: ToolPhase::CheckingPermissions, abort_tx: None, + edit_preview: None, }); } @@ -294,11 +297,6 @@ impl ToolTracker { .find(|t| t.phase == ToolPhase::AskingForPermission) } - /// Get the preview for a tool by ID (live or cached). - pub fn preview_for(&self, id: &str) -> Option<ToolPreview> { - self.get(id)?.preview() - } - /// Iterate mutably over all tracked tools. pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut TrackedTool> { self.tools.iter_mut() @@ -309,6 +307,7 @@ impl ToolTracker { #[derive(Debug, Clone)] pub(crate) enum ClientToolCall { Read(ReadToolCall), + Edit(EditToolCall), Write(WriteToolCall), Shell(ShellToolCall), AtuinHistory(AtuinHistoryToolCall), @@ -320,9 +319,8 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { fn try_from((name, input): (&str, &serde_json::Value)) -> Result<Self, Self::Error> { match name { "read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)), - "create_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), - // "append_to_file" => Ok(ClientToolCall::Append(AppendToolCall::try_from(input)?)), - // "str_replace" => Ok(ClientToolCall::StrReplace(StrReplaceToolCall::try_from(input)?)), + "edit_file" => Ok(ClientToolCall::Edit(EditToolCall::try_from(input)?)), + "write_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), "execute_shell_command" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), "atuin_history" => Ok(ClientToolCall::AtuinHistory( AtuinHistoryToolCall::try_from(input)?, @@ -336,17 +334,22 @@ impl ClientToolCall { pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor { match self { ClientToolCall::Read(_) => descriptor::READ, + ClientToolCall::Edit(_) => descriptor::EDIT, ClientToolCall::Write(_) => descriptor::WRITE, ClientToolCall::Shell(_) => descriptor::SHELL, ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, } } - /// The permission rule name for this tool category (e.g. "Write" covers - /// str_replace, file_create, file_insert). + /// The permission rule name for this tool category. + /// + /// Edit and Write share the `"Write"` rule name — a Write permission + /// covers both str_replace edits and full file creates. Write also + /// implies Read (checked in `ReadToolCall::matches_rule`). pub(crate) fn rule_name(&self) -> &'static str { match self { ClientToolCall::Read(_) => "Read", + ClientToolCall::Edit(_) => "Write", ClientToolCall::Write(_) => "Write", ClientToolCall::Shell(_) => "Shell", ClientToolCall::AtuinHistory(_) => "AtuinHistory", @@ -356,6 +359,7 @@ impl ClientToolCall { pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { match self { ClientToolCall::Read(tool) => tool.matches_rule(rule), + ClientToolCall::Edit(tool) => tool.matches_rule(rule), ClientToolCall::Write(tool) => tool.matches_rule(rule), ClientToolCall::Shell(tool) => tool.matches_rule(rule), ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), @@ -365,6 +369,7 @@ impl ClientToolCall { pub(crate) fn target_dir(&self) -> Option<&Path> { match self { ClientToolCall::Read(tool) => tool.target_dir(), + ClientToolCall::Edit(tool) => tool.target_dir(), ClientToolCall::Write(tool) => tool.target_dir(), ClientToolCall::Shell(tool) => tool.target_dir(), ClientToolCall::AtuinHistory(tool) => tool.target_dir(), @@ -401,6 +406,14 @@ impl PermissableToolCall for ClientToolCall { } } +/// Expand shell constructs (`~`, `$HOME`, etc.) in a path string. +/// +/// Tool call paths arrive as raw strings from the API without shell +/// expansion. Uses `shellexpand` (same as `atuin-client`). +fn expand_path(path: &str) -> PathBuf { + PathBuf::from(shellexpand::tilde(path).into_owned()) +} + #[derive(Debug, Clone)] pub(crate) struct ReadToolCall { pub path: PathBuf, @@ -425,7 +438,7 @@ impl TryFrom<&serde_json::Value> for ReadToolCall { .min(MAX_FILE_READ_LINES); Ok(ReadToolCall { - path: PathBuf::from(path), + path: expand_path(path), offset, limit, }) @@ -499,7 +512,207 @@ impl PermissableToolCall for ReadToolCall { } fn matches_rule(&self, rule: &Rule) -> bool { - if rule.tool != "Read" { + // Write implies Read — a Write permission on a path also permits reading it. + if rule.tool != "Read" && rule.tool != "Write" { + return false; + } + + match rule.scope.as_deref() { + None | Some("*") => true, + Some(scope) => path_matches_scope(&self.path, scope), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct EditToolCall { + pub path: PathBuf, + pub old_string: String, + pub new_string: String, + pub replace_all: bool, +} + +impl TryFrom<&serde_json::Value> for EditToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let path = value + .get("file_path") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing file_path"))?; + + let old_string = value + .get("old_string") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing old_string"))?; + + let new_string = value + .get("new_string") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing new_string"))?; + + let replace_all = value + .get("replace_all") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + Ok(EditToolCall { + path: expand_path(path), + old_string: old_string.to_string(), + new_string: new_string.to_string(), + replace_all, + }) + } +} + +impl EditToolCall { + /// Resolve the edit path to an absolute path. + pub fn resolved_path(&self) -> PathBuf { + if self.path.is_relative() { + std::env::current_dir() + .map(|cwd| cwd.join(&self.path)) + .unwrap_or_else(|_| self.path.clone()) + } else { + self.path.clone() + } + } + + /// Execute the edit against the filesystem. + /// + /// Checks freshness via the provided tracker, validates matches, applies + /// the replacement, and writes atomically. Returns the outcome and (on + /// success) the new file content bytes for tracker updates. + /// + /// Callers should snapshot the file before calling this method and + /// update the file tracker after a successful return. + pub fn execute( + &self, + resolved_path: &Path, + file_tracker: &crate::file_tracker::FileReadTracker, + ) -> (ToolOutcome, Option<Vec<u8>>) { + use crate::file_tracker::FreshnessCheck; + + // 1. Basic validation + if !resolved_path.exists() { + return ( + ToolOutcome::Error(format!( + "Error: file does not exist: {}", + resolved_path.display() + )), + None, + ); + } + if resolved_path.is_dir() { + return ( + ToolOutcome::Error(format!( + "Error: path is a directory, not a file: {}", + resolved_path.display() + )), + None, + ); + } + if self.old_string.is_empty() { + return ( + ToolOutcome::Error( + "old_string must not be empty. To create a new file, use write_file instead." + .to_string(), + ), + None, + ); + } + + // 2. Freshness check + match file_tracker.check_freshness(resolved_path) { + Ok(FreshnessCheck::NotRead) => { + return ( + ToolOutcome::Error( + "File has not been read yet. Read it first before editing.".to_string(), + ), + None, + ); + } + Ok(FreshnessCheck::Stale) => { + return ( + ToolOutcome::Error( + "File has been modified since read, either by the user or by a linter. Read it again before attempting to edit it.".to_string(), + ), + None, + ); + } + Err(e) => { + return ( + ToolOutcome::Error(format!("Error checking file state: {e}")), + None, + ); + } + Ok(FreshnessCheck::Fresh) => {} + } + + // 3. Read current contents + let content = match std::fs::read_to_string(resolved_path) { + Ok(c) => c, + Err(e) => return (ToolOutcome::Error(format!("Error reading file: {e}")), None), + }; + + // 4. Find and validate matches + let match_count = content.matches(&self.old_string).count(); + + if match_count == 0 { + return ( + ToolOutcome::Error(format!( + "old_string not found in {}. Make sure it matches exactly, including whitespace and indentation.", + resolved_path.display() + )), + None, + ); + } + + if match_count > 1 && !self.replace_all { + return ( + ToolOutcome::Error(format!( + "Found {match_count} matches of old_string in {}, but replace_all is false. Either provide more context to make the match unique, or set replace_all to true.", + resolved_path.display() + )), + None, + ); + } + + // 5. Apply replacement + let new_content = if self.replace_all { + content.replace(&self.old_string, &self.new_string) + } else { + content.replacen(&self.old_string, &self.new_string, 1) + }; + + // 6. Write atomically + let new_bytes = new_content.into_bytes(); + if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &new_bytes) { + return (ToolOutcome::Error(format!("Error writing file: {e}")), None); + } + + // 7. Success + let verb = if match_count == 1 { + "occurrence" + } else { + "occurrences" + }; + ( + ToolOutcome::Success(format!( + "Edited {}: replaced {match_count} {verb} of old_string with new_string.", + resolved_path.display() + )), + Some(new_bytes), + ) + } +} + +impl PermissableToolCall for EditToolCall { + fn target_dir(&self) -> Option<&Path> { + Some(&self.path) + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Write" { return false; } @@ -532,7 +745,7 @@ impl TryFrom<&serde_json::Value> for WriteToolCall { .ok_or(eyre::eyre!("Missing content"))?; Ok(WriteToolCall { - path: PathBuf::from(path), + path: expand_path(path), content: content.to_string(), }) } @@ -560,6 +773,9 @@ pub(crate) struct ShellToolCall { pub dir: Option<PathBuf>, pub command: String, pub shell: String, + // allow dead code here; this will be tied into o11y and user-facing descriptions + #[expect(dead_code)] + pub description: Option<String>, } impl TryFrom<&serde_json::Value> for ShellToolCall { @@ -579,10 +795,16 @@ impl TryFrom<&serde_json::Value> for ShellToolCall { .unwrap_or("bash") .to_string(); + let description = value + .get("description") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + Ok(ShellToolCall { - dir: dir.map(PathBuf::from), + dir: dir.map(expand_path), command: command.to_string(), shell, + description, }) } } @@ -614,7 +836,34 @@ const PREVIEW_HEIGHT: u16 = 10; /// Default terminal width for VT100 emulation. const PREVIEW_WIDTH: u16 = 120; +/// Normalize newlines for VT100 processing. +/// +/// When subprocess output is captured via pipes (no PTY), bare `\n` (LF) bytes +/// are not translated to `\r\n` (CR+LF) the way a kernel terminal driver would +/// with the `ONLCR` flag. In VT100, LF only moves the cursor down without +/// returning to column 0. This causes lines to start at progressively higher +/// column offsets and eventually wrap, producing garbled output. +/// +/// This function inserts `\r` before any `\n` that isn't already preceded by +/// `\r`, mimicking the terminal driver's ONLCR behavior. +fn normalize_newlines_for_vt100(data: &[u8]) -> Vec<u8> { + let mut out = Vec::with_capacity(data.len() + data.len() / 8); + for (i, &b) in data.iter().enumerate() { + if b == b'\n' && (i == 0 || data[i - 1] != b'\r') { + out.push(b'\r'); + } + out.push(b); + } + out +} + /// Extract plain text lines from a VT100 screen buffer. +/// +/// Strips trailing blank lines so the result only contains rows with actual +/// content. Without this, the fixed-size VT100 screen (PREVIEW_HEIGHT rows) +/// would always return that many lines, and downstream components that use +/// tail-mode display (like the Viewport) would show the blank padding rows +/// instead of the real output. fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> { let (rows, cols) = screen.size(); let mut lines = Vec::with_capacity(rows as usize); @@ -625,9 +874,11 @@ fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> { line.push_str(cell.contents()); } } - // Trim trailing whitespace for cleaner display lines.push(line.trim_end().to_string()); } + while lines.last().is_some_and(|l| l.is_empty()) { + lines.pop(); + } lines } @@ -640,12 +891,17 @@ fn strip_ansi_via_vt100(raw: &[u8]) -> String { if raw.is_empty() { return String::new(); } - // Use the contents_formatted → screen approach: feed bytes into a parser - // with enough rows to hold everything, then read back the plain text. - // Estimate rows: one row per ~PREVIEW_WIDTH bytes, plus generous padding. - let estimated_rows = (raw.len() / PREVIEW_WIDTH as usize + 1).min(10_000) as u16; + // Normalize bare LF to CR+LF so lines start at column 0 in the VT100 screen. + let normalized = normalize_newlines_for_vt100(raw); + // Feed bytes into a VT100 parser large enough to hold all output, then + // read back the plain text. We estimate rows from the number of newlines + // (not total byte length) because real output typically has short lines + // that would be severely under-counted by a bytes÷width estimate. + let newline_count = normalized.iter().filter(|&&b| b == b'\n').count(); + let wrap_estimate = normalized.len() / PREVIEW_WIDTH as usize; + let estimated_rows = (newline_count + wrap_estimate + 1).min(10_000) as u16; let mut parser = vt100::Parser::new(estimated_rows, PREVIEW_WIDTH, 0); - parser.process(raw); + parser.process(&normalized); let screen = parser.screen(); // screen.contents() returns the full plain-text content with trailing // whitespace trimmed per line and trailing blank lines removed. @@ -727,7 +983,8 @@ pub(crate) async fn execute_shell_command_streaming( Ok(0) => stdout_done = true, Ok(n) => { full_stdout.extend_from_slice(&stdout_buf[..n]); - parser.process(&stdout_buf[..n]); + let normalized = normalize_newlines_for_vt100(&stdout_buf[..n]); + parser.process(&normalized); } Err(_) => stdout_done = true, } @@ -740,7 +997,8 @@ pub(crate) async fn execute_shell_command_streaming( Ok(n) => { full_stderr.extend_from_slice(&stderr_buf[..n]); // Feed stderr to the preview parser too, so it shows in the VT100 screen - parser.process(&stderr_buf[..n]); + let normalized = normalize_newlines_for_vt100(&stderr_buf[..n]); + parser.process(&normalized); } Err(_) => stderr_done = true, } @@ -967,7 +1225,7 @@ mod tests { fn read_tool(path: &str) -> ReadToolCall { ReadToolCall { - path: PathBuf::from(path), + path: expand_path(path), offset: 0, limit: 100, } @@ -975,7 +1233,7 @@ mod tests { fn write_tool(path: &str) -> WriteToolCall { WriteToolCall { - path: PathBuf::from(path), + path: expand_path(path), content: String::new(), } } @@ -994,12 +1252,26 @@ mod tests { } #[test] - fn wrong_tool_never_matches() { - assert!(!read_tool("foo.txt").matches_rule(&write_rule(None))); + fn write_implies_read() { + // A Write rule also permits reads on the same path + assert!(read_tool("foo.txt").matches_rule(&write_rule(None))); + // But a Read rule does not permit writes assert!(!write_tool("foo.txt").matches_rule(&read_rule(None))); } #[test] + fn edit_uses_write_rule() { + let edit = EditToolCall { + path: expand_path("/home/user/config.toml"), + old_string: "x".into(), + new_string: "y".into(), + replace_all: false, + }; + assert!(edit.matches_rule(&write_rule(None))); + assert!(!edit.matches_rule(&read_rule(None))); + } + + #[test] fn extension_glob() { assert!(read_tool("notes.md").matches_rule(&read_rule(Some("*.md")))); assert!(!read_tool("notes.txt").matches_rule(&read_rule(Some("*.md")))); @@ -1050,6 +1322,419 @@ mod tests { } } + // ── edit_file execution tests ── + + mod edit { + use super::*; + use crate::file_tracker::FileReadTracker; + + /// Helper: create a temp file (with a closed handle), record it in a tracker. + /// Returns the TempDir (keeps the path alive) and tracker. + /// The file handle is closed so atomic_write_file can rename over it on Windows. + fn setup_tracked_file(content: &str) -> (tempfile::TempDir, PathBuf, FileReadTracker) { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test_file.toml"); + std::fs::write(&path, content).unwrap(); + + let file_content = std::fs::read(&path).unwrap(); + let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); + + let mut tracker = FileReadTracker::default(); + tracker.record_read(path.clone(), &file_content, mtime); + + (dir, path, tracker) + } + + fn edit_call(path: &Path, old: &str, new: &str, replace_all: bool) -> EditToolCall { + EditToolCall { + path: path.to_path_buf(), + old_string: old.to_string(), + new_string: new.to_string(), + replace_all, + } + } + + #[test] + fn successful_single_replacement() { + let (_dir, path, tracker) = setup_tracked_file("[section]\nkey = old_value\n"); + + let call = edit_call(&path, "old_value", "new_value", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(matches!(outcome, ToolOutcome::Success(_))); + assert!(new_bytes.is_some()); + assert_eq!( + std::fs::read_to_string(&path).unwrap(), + "[section]\nkey = new_value\n" + ); + } + + #[test] + fn successful_replace_all() { + let (_dir, path, tracker) = setup_tracked_file("aaa bbb aaa ccc aaa"); + + let call = edit_call(&path, "aaa", "xxx", true); + let (outcome, _) = call.execute(&path, &tracker); + + assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("3 occurrences"))); + assert_eq!( + std::fs::read_to_string(&path).unwrap(), + "xxx bbb xxx ccc xxx" + ); + } + + #[test] + fn error_file_not_read() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("unread.txt"); + std::fs::write(&path, "content").unwrap(); + let tracker = FileReadTracker::default(); // empty — never read + + let call = edit_call(&path, "x", "y", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("not been read yet"), "got: {msg}"); + } + _ => panic!("expected error"), + } + } + + #[test] + fn error_file_modified_since_read() { + let (_dir, path, tracker) = setup_tracked_file("original"); + + // Modify the file after the read was recorded + std::thread::sleep(std::time::Duration::from_millis(10)); + std::fs::write(&path, "modified externally").unwrap(); + + let call = edit_call(&path, "original", "replaced", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("modified since read"), "got: {msg}"); + } + _ => panic!("expected error"), + } + } + + #[test] + fn error_no_match() { + let (_dir, path, tracker) = setup_tracked_file("hello world"); + + let call = edit_call(&path, "nonexistent", "replacement", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("not found"), "got: {msg}"); + } + _ => panic!("expected error"), + } + } + + #[test] + fn error_multiple_matches_without_replace_all() { + let (_dir, path, tracker) = setup_tracked_file("foo bar foo baz foo"); + + let call = edit_call(&path, "foo", "qux", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("3 matches"), "got: {msg}"); + assert!(msg.contains("replace_all"), "got: {msg}"); + } + _ => panic!("expected error"), + } + // File should be unchanged + assert_eq!( + std::fs::read_to_string(&path).unwrap(), + "foo bar foo baz foo" + ); + } + + #[test] + fn error_empty_old_string() { + let (_dir, path, tracker) = setup_tracked_file("content"); + + let call = edit_call(&path, "", "something", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + assert!(matches!(outcome, ToolOutcome::Error(_))); + } + + #[test] + fn error_file_does_not_exist() { + let tracker = FileReadTracker::default(); + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("nonexistent.txt"); + + let call = edit_call(&path, "x", "y", false); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => { + assert!(msg.contains("does not exist"), "got: {msg}"); + } + _ => panic!("expected error"), + } + } + + #[test] + fn preserves_file_when_no_match() { + let original = "[config]\nport = 8080\nhost = localhost\n"; + let (_dir, path, tracker) = setup_tracked_file(original); + + let call = edit_call(&path, "port = 9090", "port = 3000", false); + let (outcome, _) = call.execute(&path, &tracker); + + assert!(matches!(outcome, ToolOutcome::Error(_))); + assert_eq!(std::fs::read_to_string(&path).unwrap(), original); + } + + #[test] + fn multiline_replacement() { + let content = "[section]\nkey1 = val1\nkey2 = val2\n[other]\n"; + let (_dir, path, tracker) = setup_tracked_file(content); + + let call = edit_call( + &path, + "key1 = val1\nkey2 = val2", + "key1 = new1\nkey2 = new2", + false, + ); + let (outcome, new_bytes) = call.execute(&path, &tracker); + + assert!(matches!(outcome, ToolOutcome::Success(_))); + assert!(new_bytes.is_some()); + assert_eq!( + std::fs::read_to_string(&path).unwrap(), + "[section]\nkey1 = new1\nkey2 = new2\n[other]\n" + ); + } + } + + // ── Integration tests: full edit lifecycle ── + // + // These exercise the cross-component flow that dispatch orchestrates: + // FileReadTracker → SnapshotStore → EditToolCall.execute → tracker update + + mod edit_integration { + use super::*; + use crate::edit_permissions::EditPermissionCache; + use crate::file_tracker::FileReadTracker; + use crate::snapshots::SnapshotStore; + + /// Simulate a file read (what dispatch does after ReadToolCall.execute). + fn simulate_read(tracker: &mut FileReadTracker, path: &std::path::Path) { + let content = std::fs::read(path).unwrap(); + let mtime = std::fs::metadata(path).unwrap().modified().unwrap(); + tracker.record_read(path.to_path_buf(), &content, mtime); + } + + /// Simulate a tracker update after edit (what dispatch does after execute). + fn simulate_tracker_update( + tracker: &mut FileReadTracker, + path: &std::path::Path, + new_bytes: &[u8], + ) { + let mtime = std::fs::metadata(path).unwrap().modified().unwrap(); + tracker.update_after_edit(path, new_bytes, mtime); + } + + #[test] + fn full_read_snapshot_edit_cycle() { + let dir = tempfile::tempdir().unwrap(); + let file_path = dir.path().join("config.toml"); + std::fs::write(&file_path, "[db]\nhost = localhost\nport = 5432\n").unwrap(); + + let snapshot_dir = dir.path().join("snapshots").join("session-1"); + let mut tracker = FileReadTracker::default(); + let mut store = SnapshotStore::open(snapshot_dir.clone()).unwrap(); + + // 1. Simulate reading the file + simulate_read(&mut tracker, &file_path); + + // 2. Snapshot before edit + let original = std::fs::read(&file_path).unwrap(); + store.ensure_snapshot(&file_path, &original).unwrap(); + + // 3. Execute edit + let call = EditToolCall { + path: file_path.clone(), + old_string: "host = localhost".to_string(), + new_string: "host = 10.0.0.1".to_string(), + replace_all: false, + }; + let (outcome, new_bytes) = call.execute(&file_path, &tracker); + assert!(matches!(outcome, ToolOutcome::Success(_))); + let new_bytes = new_bytes.unwrap(); + + // 4. Update tracker (simulating what dispatch does) + simulate_tracker_update(&mut tracker, &file_path, &new_bytes); + + // Verify: file was edited + assert_eq!( + std::fs::read_to_string(&file_path).unwrap(), + "[db]\nhost = 10.0.0.1\nport = 5432\n" + ); + + // Verify: snapshot has original content + assert!(store.has_snapshot(&file_path)); + let snapshot_name = crate::snapshots::sanitize_path(&file_path); + let snapshot_content = + std::fs::read_to_string(snapshot_dir.join(snapshot_name)).unwrap(); + assert_eq!(snapshot_content, "[db]\nhost = localhost\nport = 5432\n"); + } + + #[test] + fn second_edit_without_reread() { + let dir = tempfile::tempdir().unwrap(); + let file_path = dir.path().join("config.toml"); + std::fs::write(&file_path, "key1 = aaa\nkey2 = bbb\n").unwrap(); + + let mut tracker = FileReadTracker::default(); + + // Read the file + simulate_read(&mut tracker, &file_path); + + // First edit + let call1 = EditToolCall { + path: file_path.clone(), + old_string: "key1 = aaa".to_string(), + new_string: "key1 = xxx".to_string(), + replace_all: false, + }; + let (outcome, new_bytes) = call1.execute(&file_path, &tracker); + assert!(matches!(outcome, ToolOutcome::Success(_))); + simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); + + // Second edit — should work without re-reading because tracker was updated + let call2 = EditToolCall { + path: file_path.clone(), + old_string: "key2 = bbb".to_string(), + new_string: "key2 = yyy".to_string(), + replace_all: false, + }; + let (outcome, new_bytes) = call2.execute(&file_path, &tracker); + assert!(matches!(outcome, ToolOutcome::Success(_))); + assert!(new_bytes.is_some()); + assert_eq!( + std::fs::read_to_string(&file_path).unwrap(), + "key1 = xxx\nkey2 = yyy\n" + ); + } + + #[test] + fn external_modification_between_edits() { + let dir = tempfile::tempdir().unwrap(); + let file_path = dir.path().join("config.toml"); + std::fs::write(&file_path, "value = original\n").unwrap(); + + let mut tracker = FileReadTracker::default(); + simulate_read(&mut tracker, &file_path); + + // First edit succeeds + let call1 = EditToolCall { + path: file_path.clone(), + old_string: "value = original".to_string(), + new_string: "value = edited".to_string(), + replace_all: false, + }; + let (outcome, new_bytes) = call1.execute(&file_path, &tracker); + assert!(matches!(outcome, ToolOutcome::Success(_))); + simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); + + // External modification (e.g., user edits the file) + std::thread::sleep(std::time::Duration::from_millis(10)); + std::fs::write(&file_path, "value = user_changed\n").unwrap(); + + // Second edit should fail (stale) + let call2 = EditToolCall { + path: file_path.clone(), + old_string: "value = edited".to_string(), + new_string: "value = second_edit".to_string(), + replace_all: false, + }; + let (outcome, new_bytes) = call2.execute(&file_path, &tracker); + assert!(new_bytes.is_none()); + match outcome { + ToolOutcome::Error(msg) => assert!(msg.contains("modified since read")), + _ => panic!("expected stale error"), + } + + // File should be unchanged (the user's edit preserved) + assert_eq!( + std::fs::read_to_string(&file_path).unwrap(), + "value = user_changed\n" + ); + } + + #[test] + fn snapshot_only_created_once_per_file() { + let dir = tempfile::tempdir().unwrap(); + let file_path = dir.path().join("config.toml"); + std::fs::write(&file_path, "a = 1\nb = 2\n").unwrap(); + + let snapshot_dir = dir.path().join("snapshots").join("session-1"); + let mut tracker = FileReadTracker::default(); + let mut store = SnapshotStore::open(snapshot_dir).unwrap(); + + simulate_read(&mut tracker, &file_path); + + // First edit — snapshot should be created + let original = std::fs::read(&file_path).unwrap(); + let created = store.ensure_snapshot(&file_path, &original).unwrap(); + assert!(created); + + let call1 = EditToolCall { + path: file_path.clone(), + old_string: "a = 1".to_string(), + new_string: "a = 10".to_string(), + replace_all: false, + }; + let (_, new_bytes) = call1.execute(&file_path, &tracker); + simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); + + // Second edit — snapshot should NOT be recreated + let content_before_second = std::fs::read(&file_path).unwrap(); + let created = store + .ensure_snapshot(&file_path, &content_before_second) + .unwrap(); + assert!(!created); // idempotent — already snapshotted + } + + #[test] + fn permission_cache_grant_and_check() { + let mut cache = EditPermissionCache::default(); + let path = std::path::PathBuf::from("/Users/me/.config/atuin/config.toml"); + + // Initially no grant + assert!(!cache.has_valid_grant(&path)); + + // Grant permission + cache.grant(path.clone()); + assert!(cache.has_valid_grant(&path)); + + // Different file has no grant + assert!(!cache.has_valid_grant(std::path::Path::new("/other/file.toml"))); + + // Roundtrip through JSON (simulates session persistence) + let json = cache.to_json().unwrap(); + let restored = EditPermissionCache::from_json(&json).unwrap(); + assert!(restored.has_valid_grant(&path)); + } + } + // ── Windows-specific tests (absolute paths with drive letters) ── #[cfg(windows)] diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index ea895c01..fea26953 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -61,15 +61,17 @@ pub(crate) fn dispatch(ctx: &mut DispatchContext, event: AiTuiEvent) -> bool { !ctx.exiting.load(Ordering::Acquire) } -/// Persist new events and the server session ID if it has changed. +/// Persist new events, server session ID, file tracker, and edit permissions. /// Called from the dispatch thread (sync), bridges to async via the tokio handle. fn persist_session(ctx: &mut DispatchContext) { - let Ok((events, server_sid)) = ctx + let Ok((events, server_sid, file_tracker_json, edit_perms_json)) = ctx .handle .fetch(|state| { ( state.conversation.events.clone(), state.conversation.session_id.clone(), + state.file_tracker.to_json().ok(), + state.edit_permissions.to_json().ok(), ) }) .blocking_recv() @@ -86,6 +88,22 @@ fn persist_session(ctx: &mut DispatchContext) { { tracing::warn!("failed to persist server session ID: {e}"); } + if let Some(ref json) = file_tracker_json + && let Err(e) = rt.block_on( + ctx.session_mgr + .set_metadata(crate::file_tracker::METADATA_KEY, json), + ) + { + tracing::warn!("failed to persist file tracker: {e}"); + } + if let Some(ref json) = edit_perms_json + && let Err(e) = rt.block_on( + ctx.session_mgr + .set_metadata(crate::edit_permissions::METADATA_KEY, json), + ) + { + tracing::warn!("failed to persist edit permissions: {e}"); + } } fn launch_stream(ctx: &DispatchContext, setup: impl FnOnce(&mut Session) + Send + 'static) { @@ -210,6 +228,10 @@ fn execute_tool( let shell_call = shell_call.clone(); execute_shell_tool(handle, tx, &tool_id, &shell_call); } + ClientToolCall::Edit(edit_call) => { + let edit_call = edit_call.clone(); + execute_edit_tool(handle, tx, tool_id, edit_call); + } _ => { execute_simple_tool(handle, tx, tool_id, tool, db); } @@ -231,7 +253,21 @@ fn execute_simple_tool( tokio::spawn(async move { let outcome = tool.execute(&db).await; + + // After a successful file read, capture tracking data for freshness + // checking. This re-stats the file to get content hash and mtime. + let read_tracking = if let ClientToolCall::Read(ref read_tool) = tool + && !outcome.is_error() + { + capture_read_tracking(&read_tool.path) + } else { + None + }; + h.update(move |state| { + if let Some((path, content, mtime)) = read_tracking { + state.file_tracker.record_read(path, &content, mtime); + } state.finish_tool_call(&tool_id, outcome); if !state.tool_tracker.has_pending() { let _ = tx.send(AiTuiEvent::ContinueAfterTools); @@ -240,6 +276,117 @@ fn execute_simple_tool( }); } +/// Capture file content and mtime for the read tracker. +/// Returns None for directories or if the file can't be read. +fn capture_read_tracking( + path: &std::path::Path, +) -> Option<(std::path::PathBuf, Vec<u8>, std::time::SystemTime)> { + let resolved = if path.is_relative() { + std::env::current_dir().ok()?.join(path) + } else { + path.to_path_buf() + }; + if !resolved.is_file() { + return None; + } + let content = std::fs::read(&resolved).ok()?; + let mtime = std::fs::metadata(&resolved).ok()?.modified().ok()?; + Some((resolved, content, mtime)) +} + +/// Execute an edit_file tool call. +/// +/// Orchestrates snapshot → execute → tracker update. The snapshot and +/// tracker mutations happen via `h.update()` (on the TUI thread) since +/// they need mutable Session state. The actual file I/O (freshness check, +/// read, match, atomic write) runs in the tokio task. +fn execute_edit_tool( + handle: &Handle<Session>, + tx: &mpsc::Sender<AiTuiEvent>, + tool_id: String, + edit_call: crate::tools::EditToolCall, +) { + let h = handle.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let resolved = edit_call.resolved_path(); + + // 1. Read the original file content (used for snapshot + diff). + let old_content = std::fs::read(&resolved).ok(); + + // 2. Snapshot the original file before editing. + if let Some(ref content) = old_content { + let snap_path = resolved.clone(); + let snap_content = content.clone(); + h.update(move |state| { + if let Some(ref mut store) = state.snapshot_store + && let Err(e) = store.ensure_snapshot(&snap_path, &snap_content) + { + tracing::warn!("failed to create file snapshot: {e}"); + } + }); + } + + // 3. Fetch a clone of the file tracker for freshness checking. + let Ok(tracker) = h.fetch(|state| state.file_tracker.clone()).await else { + let tc_id = tool_id.clone(); + h.update(move |state| { + state.finish_tool_call( + &tc_id, + crate::tools::ToolOutcome::Error("Internal error: TUI unavailable".into()), + ); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + return; + }; + + // 4. Execute: freshness check → read → match → atomic write + let (outcome, new_bytes) = edit_call.execute(&resolved, &tracker); + + // 5. Compute diff preview on success + let edit_preview = if let Some(ref new_bytes) = new_bytes { + if let Some(ref old_bytes) = old_content { + let old_str = String::from_utf8_lossy(old_bytes); + let new_str = String::from_utf8_lossy(new_bytes); + let preview = crate::diff::EditPreview::compute(&old_str, &new_str); + if preview.hunks.is_empty() { + None + } else { + Some(preview) + } + } else { + None + } + } else { + None + }; + + // 6. Update tracker, store diff preview, and finish the tool call + let tc_id = tool_id; + h.update(move |state| { + if let Some(ref new_bytes) = new_bytes + && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) + { + state + .file_tracker + .update_after_edit(&resolved, new_bytes, mtime); + } + if let Some(preview) = edit_preview + && let Some(tracked) = state.tool_tracker.get_mut(&tc_id) + { + tracked.edit_preview = Some(preview); + } + state.finish_tool_call(&tc_id, outcome); + if !state.tool_tracker.has_pending() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + }); +} + /// Execute a shell tool with streaming VT100 preview. fn execute_shell_tool( handle: &Handle<Session>, @@ -352,12 +499,28 @@ async fn check_tool_permission_inner( .map_err(|e| format!("Internal error fetching tool state: {e}"))? .ok_or_else(|| "Internal error: tool not found in tracker".to_string())?; - // 2. Resolve working directory + // 2. For edit tools, check session-scoped permission grants before + // hitting the filesystem-based resolver. A valid grant means the user + // already approved this file recently. + if let ClientToolCall::Edit(ref edit) = tool { + let resolved = edit.resolved_path(); + let has_grant = h2 + .fetch(move |state| state.edit_permissions.has_valid_grant(&resolved)) + .await + .unwrap_or(false); + + if has_grant { + execute_tool(h2, tx, id, tool, db); + return Ok(()); + } + } + + // 3. Resolve working directory let working_dir = target_dir .or_else(|| std::env::current_dir().ok()) .ok_or_else(|| "Could not determine working directory".to_string())?; - // 3. Create permission resolver and check + // 4. Create permission resolver and check let resolver = PermissionResolver::new(working_dir) .await .map_err(|e| format!("Permission check failed: {e}"))?; @@ -367,7 +530,7 @@ async fn check_tool_permission_inner( .await .map_err(|e| format!("Permission check failed: {e}"))?; - // 4. Handle response — all paths here handle the tool, so return Ok + // 5. Handle response — all paths here handle the tool, so return Ok let id_clone = id.clone(); match response { PermissionResponse::Allowed => { @@ -423,6 +586,32 @@ fn on_select_permission(ctx: &mut DispatchContext, permission: PermissionResult) execute_tool(&h2, &tx, tool_id, tool, &db); }); } + PermissionResult::AllowFileForSession => { + // Cache a session-scoped, time-limited grant for this file + let db = ctx.app_ctx.history_db.clone(); + tokio::spawn(async move { + let Ok(Some((tool_id, tool))) = h2 + .fetch(move |state| { + state + .tool_tracker + .asking_for_permission() + .map(|t| (t.id.clone(), t.tool.clone())) + }) + .await + else { + return; + }; + + if let ClientToolCall::Edit(ref edit) = tool { + let resolved = edit.resolved_path(); + h2.update(move |state| { + state.edit_permissions.grant(resolved); + }); + } + + execute_tool(&h2, &tx, tool_id, tool, &db); + }); + } PermissionResult::AlwaysAllowInDir => { let db = ctx.app_ctx.history_db.clone(); let git_root = ctx.app_ctx.git_root.clone(); diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs index 1a422fef..969f6ae5 100644 --- a/crates/atuin-ai/src/tui/events.rs +++ b/crates/atuin-ai/src/tui/events.rs @@ -38,7 +38,34 @@ pub(crate) enum AiTuiEvent { #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum PermissionResult { Allow, + /// Per-file, time-limited grant scoped to the current session. + AllowFileForSession, AlwaysAllowInDir, AlwaysAllow, Deny, } + +impl PermissionResult { + /// String identifier used as the SelectOption value. + pub fn as_value_str(&self) -> &'static str { + match self { + Self::Allow => "allow", + Self::AllowFileForSession => "allow-file-session", + Self::AlwaysAllowInDir => "always-allow-in-dir", + Self::AlwaysAllow => "always-allow", + Self::Deny => "deny", + } + } + + /// Parse from a SelectOption value string. + pub fn from_value_str(s: &str) -> Option<Self> { + match s { + "allow" => Some(Self::Allow), + "allow-file-session" => Some(Self::AllowFileForSession), + "always-allow-in-dir" => Some(Self::AlwaysAllowInDir), + "always-allow" => Some(Self::AlwaysAllow), + "deny" => Some(Self::Deny), + _ => None, + } + } +} diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index e122918e..af1ebffe 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -474,6 +474,12 @@ pub(crate) struct Session { pub slash_registry: SlashCommandRegistry, /// The unique ID for this invocation pub invocation_id: String, + /// Tracks which files have been read, for freshness checking before edits. + pub file_tracker: crate::file_tracker::FileReadTracker, + /// Session-scoped edit permission grants (per-file, time-limited). + pub edit_permissions: crate::edit_permissions::EditPermissionCache, + /// Backs up files before the first edit in a session. + pub snapshot_store: Option<crate::snapshots::SnapshotStore>, } impl Session { @@ -491,6 +497,9 @@ impl Session { archived_view_events: Vec::new(), slash_registry: Default::default(), invocation_id: invocation_id.unwrap_or_else(|| uuid::Uuid::now_v7().to_string()), + file_tracker: Default::default(), + edit_permissions: Default::default(), + snapshot_store: None, } } diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 565a0597..bdbece9c 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -1,12 +1,11 @@ //! View function that builds the eye-declare element tree from app state. use eye_declare::{ - BorderType, Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport, - WidthConstraint, element, + Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport, WidthConstraint, element, }; use ratatui_core::style::{Color, Modifier, Style}; -use crate::tools::{ClientToolCall, TrackedTool}; +use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview, TrackedTool}; use crate::tui::components::select::SelectOption; use crate::tui::components::session_continue::SessionContinue; use crate::tui::events::{AiTuiEvent, PermissionResult}; @@ -68,6 +67,16 @@ pub(crate) fn ai_view(state: &Session) -> Elements { }) }) + #({ + let needs_pending_banner = busy && !matches!(turns.last(), Some(turn::UiTurn::Agent { .. })); + if needs_pending_banner { + let empty: &[turn::UiEvent] = &[]; + agent_turn_view(empty, true) + } else { + element! {} + } + }) + #(if !state.is_exiting() { #(input_view(state)) }) @@ -135,16 +144,13 @@ fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements { let verb = tool_call.tool.descriptor().display_verb; let tool_desc = match &tool_call.tool { ClientToolCall::Read(tool) => tool.path.display().to_string(), + ClientToolCall::Edit(tool) => tool.path.display().to_string(), ClientToolCall::Write(tool) => tool.path.display().to_string(), ClientToolCall::Shell(tool) => tool.command.clone(), ClientToolCall::AtuinHistory(tool) => tool.query.clone(), }; - let dir_label = if in_git_project { - "Always allow in this workspace" - } else { - "Always allow in this directory" - }; + let select_options = permission_options_for_tool(&tool_call.tool, in_git_project); element! { View(key: format!("tool-call-{}", tool_call.id), padding_left: Cells::from(2), padding_top: Cells::from(1)) { @@ -153,39 +159,68 @@ fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements { Span(text: &tool_desc, style: Style::default().fg(Color::Yellow)) } View(padding_left: Cells::from(2)) { - Select(options: [ - SelectOption::builder() - .label("Allow") - .value("allow") - .build(), - SelectOption::builder() - .label(dir_label) - .value("always-allow-in-dir") - .build(), - SelectOption::builder() - .label("Always allow") - .value("always-allow") - .build(), - SelectOption::builder() - .label("Deny") - .value("deny") - .build(), - ], on_select: Box::new(move |option: &SelectOption| { - let value = match option.value.as_str() { - "allow" => PermissionResult::Allow, - "always-allow-in-dir" => PermissionResult::AlwaysAllowInDir, - "always-allow" => PermissionResult::AlwaysAllow, - "deny" => PermissionResult::Deny, - _ => unreachable!(), - }; - - Some(AiTuiEvent::SelectPermission(value)) + Select(options: select_options, on_select: Box::new(move |option: &SelectOption| { + PermissionResult::from_value_str(option.value.as_str()) + .map(AiTuiEvent::SelectPermission) }) as Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync>) } } } } +/// Build the permission SelectOptions appropriate for a tool call. +/// +/// Edit tools get a per-file session-scoped option instead of the +/// workspace-level "Always allow in this directory". Other tools +/// keep the standard set. +fn permission_options_for_tool(tool: &ClientToolCall, in_git_project: bool) -> Vec<SelectOption> { + match tool { + ClientToolCall::Edit(_) => vec![ + SelectOption::builder() + .label("Allow") + .value(PermissionResult::Allow.as_value_str()) + .build(), + SelectOption::builder() + .label("Allow this file for this session") + .value(PermissionResult::AllowFileForSession.as_value_str()) + .build(), + SelectOption::builder() + .label("Always allow") + .value(PermissionResult::AlwaysAllow.as_value_str()) + .build(), + SelectOption::builder() + .label("Deny") + .value(PermissionResult::Deny.as_value_str()) + .build(), + ], + _ => { + let dir_label = if in_git_project { + "Always allow in this workspace" + } else { + "Always allow in this directory" + }; + vec![ + SelectOption::builder() + .label("Allow") + .value(PermissionResult::Allow.as_value_str()) + .build(), + SelectOption::builder() + .label(dir_label) + .value(PermissionResult::AlwaysAllowInDir.as_value_str()) + .build(), + SelectOption::builder() + .label("Always allow") + .value(PermissionResult::AlwaysAllow.as_value_str()) + .build(), + SelectOption::builder() + .label("Deny") + .value(PermissionResult::Deny.as_value_str()) + .build(), + ] + } + } +} + fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements { let label_style = Style::default() .fg(Color::Cyan) @@ -231,7 +266,10 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { label_first: true, done: !busy, ) - #(for event in events { + #(for (i, event) in events.iter().enumerate() { + #(if i > 0 { + Text { Span(text: "") } + }) #(match event { turn::UiEvent::Text { content } => { element! { @@ -247,47 +285,42 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { suggested_command_view(details) }, turn::UiEvent::ToolCall(details) => { - let preview_done = details.preview.as_ref().is_some_and(|p| p.exit_code.is_some() || p.interrupted); let tool_key = details.tool_use_id.clone(); element! { View(key: format!("tool-output-{tool_key}"), padding_left: Cells::from(2)) { - #(if let Some(ref preview) = details.preview { - View(key: format!("preview-{tool_key}")) { - #(preview_spinner_view(&details.name, preview_done)) - Viewport( - key: format!("viewport-{tool_key}"), - lines: preview.lines.clone(), - height: 10, - border: BorderType::Plain, - border_style: Style::default().fg(Color::DarkGray), - style: Style::default().fg(Color::White), - wrap: false, - ) - #(if let Some(code) = preview.exit_code { - #(if code == 0 { - Text { - Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Green)) - } - } else { - Text { - Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Red)) - } - }) - }) - #(if preview.interrupted { - Text { - Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) - } - }) - #(if !preview_done { - Text { - Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray)) - } - }) - } - } else { - #(tool_status_view(&details.name, &details.status)) + #(match &details.render_data { + turn::ToolRenderData::Shell { command, preview } => { + shell_tool_view(&tool_key, command, preview.as_ref()) + }, + turn::ToolRenderData::FileEdit { path, preview } => { + file_edit_tool_view(&tool_key, &details.status, path, preview.as_ref()) + }, + turn::ToolRenderData::FileWrite { path } => { + file_write_tool_view(&details.status, path) + }, + turn::ToolRenderData::Remote => { + tool_status_view(&details.name, &details.status) + }, + turn::ToolRenderData::FileRead { .. } + | turn::ToolRenderData::HistorySearch { .. } => { + element!{} + }, + }) + } + } + } + turn::UiEvent::ToolGroup(group) => { + let group_key = group.calls + .first() + .map(|c| c.tool_use_id.as_str()) + .unwrap_or("empty"); + + element! { + View(key: format!("group-{group_key}"), padding_left: Cells::from(2)) { + #(match group.kind { + turn::ToolGroupKind::FileRead => file_read_group_view(group), + turn::ToolGroupKind::HistorySearch => history_search_group_view(group), }) } } @@ -367,17 +400,391 @@ fn tool_status_view(name: &str, status: &turn::ToolResultStatus) -> Elements { } } -/// Render a spinner/status line for a command preview (shell tools). -fn preview_spinner_view(name: &str, done: bool) -> Elements { +// ─────────────────────────────────────────────────────────────────── +// Per-tool view functions +// ─────────────────────────────────────────────────────────────────── + +/// Max output lines shown for a shell command preview. +const MAX_SHELL_PREVIEW_LINES: u16 = 5; + +/// Render a shell command execution with live VT100 output viewport. +fn shell_tool_view(tool_key: &str, command: &str, preview: Option<&ToolPreview>) -> Elements { + let preview_done = preview.is_some_and(|p| p.exit_code.is_some() || p.interrupted); + + element! { + #(if let Some(preview) = preview { + View(key: format!("preview-{tool_key}")) { + Spinner( + label: if preview_done { format!("Ran: {command}") } else { format!("Running: {command}") }, + done: preview_done, + hide_checkmark: true, + ) + HStack { + View(width: WidthConstraint::Fixed(2)) { + Text { Span(text: "└ ") } + } + Column { + Viewport( + key: format!("viewport-{tool_key}"), + lines: preview.lines.clone(), + height: (preview.lines.len() as u16).clamp(1, MAX_SHELL_PREVIEW_LINES), + style: Style::default().fg(Color::Gray), + wrap: false, + ) + } + } + #(shell_tool_footer(preview, preview_done)) + } + } else { + Spinner( + label: format!("Running: {command}"), + label_style: Style::default().fg(Color::Yellow), + done: false, + ) + }) + } +} + +fn shell_tool_footer(preview: &ToolPreview, preview_done: bool) -> Elements { + if preview.interrupted { + return element! { + Text { + Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) + } + }; + } + if !preview_done { + return element! { + Text { + Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray)) + } + }; + } + if let Some(code) = preview.exit_code { + let style = if code == 0 { + Style::default().fg(Color::Green) + } else { + Style::default().fg(Color::Red) + }; + return element! { + Text { Span(text: format!("Exit code: {code}"), style: style) } + }; + } + element! {} +} + +/// Render a file edit tool call with diff preview. +fn file_edit_tool_view( + key: &str, + status: &turn::ToolResultStatus, + path: &std::path::Path, + preview: Option<&crate::diff::EditPreview>, +) -> Elements { + use crate::diff::DiffLine; + + let display_path = format_path_for_display(path); + + let status_line = match status { + turn::ToolResultStatus::Pending => { + element! { + Spinner( + label: format!("Editing: {display_path}"), + label_style: Style::default().fg(Color::Yellow), + done: false, + ) + } + } + turn::ToolResultStatus::Success => { + element! { + Spinner(label: format!("Edited: {display_path}"), done: true) + } + } + turn::ToolResultStatus::Error => { + element! { + Text { + Span(text: "✗ ", style: Style::default().fg(Color::Red)) + Span(text: format!("Edit {display_path}: failed"), style: Style::default().fg(Color::Red)) + } + } + } + }; + + // If no preview, just show the status line + let Some(preview) = preview else { + return status_line; + }; + if preview.hunks.is_empty() { + return status_line; + } + + // Calculate the line number gutter width from the highest line number + let max_line_num = preview.max_line_number(); + let gutter_width = max_line_num.to_string().len().max(2) as u16 + 1; // +1 for spacing + + element! { + View(key: key.to_string()) { + #(status_line) + + View(key: format!("{key}-diff"), padding_left: Cells::from(2)) { + #(for (hunk_idx, hunk) in preview.hunks.iter().enumerate() { + #({ + let gutter_w = gutter_width; + let mut before_pos = hunk.before_start; + let mut after_pos = hunk.after_start; + let lines_rendered: Vec<_> = hunk.lines.iter().enumerate().map(|(line_idx, line)| { + let (prefix, text, style, gutter_text, gutter_style) = match line { + DiffLine::Context(t) => { + let num = format!("{:>width$}", after_pos, width = (gutter_w - 1) as usize); + before_pos += 1; + after_pos += 1; + (" ", t.as_str(), Style::default().fg(Color::DarkGray), num, Style::default().fg(Color::DarkGray)) + } + DiffLine::Removed(t) => { + let num = format!("{:>width$}", before_pos, width = (gutter_w - 1) as usize); + before_pos += 1; + ("-", t.as_str(), Style::default().fg(Color::Red), num, Style::default().fg(Color::Red)) + } + DiffLine::Added(t) => { + let num = format!("{:>width$}", after_pos, width = (gutter_w - 1) as usize); + after_pos += 1; + ("+", t.as_str(), Style::default().fg(Color::Green), num, Style::default().fg(Color::Green)) + } + }; + (line_idx, prefix, text.to_string(), style, gutter_text, gutter_style) + }).collect(); + + element! { + View(key: format!("{key}-hunk-{hunk_idx}")) { + #(for (line_idx, prefix, text, style, gutter_text, gutter_style) in &lines_rendered { + HStack(key: format!("{key}-hunk-{hunk_idx}-line-{line_idx}")) { + View(width: WidthConstraint::Fixed(gutter_w)) { + Text { Span(text: gutter_text, style: *gutter_style) } + } + View { + Text { + Span(text: *prefix, style: *style) + Span(text: text, style: *style) + } + } + } + }) + } + } + }) + }) + } + } + } +} + +/// Render a file write tool call status with the target path. +fn file_write_tool_view(status: &turn::ToolResultStatus, path: &std::path::Path) -> Elements { + let display_path = path.display(); + match status { + turn::ToolResultStatus::Pending => { + element! { + Spinner( + label: format!("Writing: {display_path}"), + label_style: Style::default().fg(Color::Yellow), + done: false, + ) + } + } + turn::ToolResultStatus::Success => { + element! { + Spinner(label: format!("Wrote: {display_path}"), done: true) + } + } + turn::ToolResultStatus::Error => { + element! { + Text { + Span(text: "✗ ", style: Style::default().fg(Color::Red)) + Span(text: format!("Write {display_path}: denied"), style: Style::default().fg(Color::Red)) + } + } + } + } +} + +// ─────────────────────────────────────────────────────────────────── +// Tool group view functions +// ─────────────────────────────────────────────────────────────────── + +/// Max entries shown under a tool group header. When the group holds more +/// than this, only the most recent `MAX_GROUP_ENTRIES` are displayed; the +/// count in the header line tells the full story. +const MAX_GROUP_ENTRIES: usize = 5; + +/// Format a filesystem path for display in tool rows. +/// +/// - Relative to the current working directory if the path is under it +/// - `~/...` prefix if the path is under the user's home directory +/// - Absolute otherwise (and relative paths pass through unchanged) +fn format_path_for_display(path: &std::path::Path) -> String { + if let Ok(cwd) = std::env::current_dir() + && let Ok(relative) = path.strip_prefix(&cwd) + { + return relative.display().to_string(); + } + + if let Ok(home) = std::env::var("HOME") + && let Ok(relative) = path.strip_prefix(&home) + { + return format!("~/{}", relative.display()); + } + + path.display().to_string() +} + +fn filter_mode_label(mode: &HistorySearchFilterMode) -> &'static str { + match mode { + HistorySearchFilterMode::Global => "global", + HistorySearchFilterMode::Host => "host", + HistorySearchFilterMode::Session => "session", + HistorySearchFilterMode::Directory => "directory", + HistorySearchFilterMode::Workspace => "workspace", + } +} + +/// Format a list of filter modes as `"(global, workspace)"`, or an empty +/// string if the list is empty. +fn format_filter_modes(modes: &[HistorySearchFilterMode]) -> String { + if modes.is_empty() { + return String::new(); + } + let parts: Vec<&'static str> = modes.iter().map(filter_mode_label).collect(); + format!("({})", parts.join(", ")) +} + +/// Tree-connector marker for a row in a grouped list: `└ ` for the first +/// visible row, two spaces for subsequent rows. +fn tree_marker(is_first: bool) -> &'static str { + if is_first { "└ " } else { " " } +} + +/// 2-char status marker column: ✓ / ✗ / blank. +fn status_marker_view(status: &turn::ToolResultStatus) -> Elements { + match status { + turn::ToolResultStatus::Pending => element! { + Text { Span(text: " ") } + }, + turn::ToolResultStatus::Success => element! { + Text { Span(text: "✓ ", style: Style::default().fg(Color::Green)) } + }, + turn::ToolResultStatus::Error => element! { + Text { Span(text: "✗ ", style: Style::default().fg(Color::Red)) } + }, + } +} + +/// Compute the slice of calls to show — the most recent `MAX_GROUP_ENTRIES`. +fn visible_group_calls(group: &turn::ToolGroup) -> &[turn::ToolCallDetails] { + let start = group.calls.len().saturating_sub(MAX_GROUP_ENTRIES); + &group.calls[start..] +} + +/// Render a single row in a grouped list: [tree marker][status][content]. +fn group_row_view(is_first: bool, status: &turn::ToolResultStatus, content: Elements) -> Elements { + element! { + HStack { + View(width: WidthConstraint::Fixed(2)) { + Text { Span(text: tree_marker(is_first)) } + } + View(width: WidthConstraint::Fixed(2)) { + #(status_marker_view(status)) + } + Column { + #(content) + } + } + } +} + +/// Render a group of consecutive `read_file` tool calls. +fn file_read_group_view(group: &turn::ToolGroup) -> Elements { + let count = group.calls.len(); + let label = if count == 1 { + "Read 1 file".to_string() + } else { + format!("Read {count} files") + }; + let done = !group.any_pending(); + let visible = visible_group_calls(group); + + element! { + Spinner(label: label, done: done, hide_checkmark: true) + #(for (i, details) in visible.iter().enumerate() { + #(file_read_row(i == 0, details)) + }) + } +} + +fn file_read_row(is_first: bool, details: &turn::ToolCallDetails) -> Elements { + let path_str = match &details.render_data { + turn::ToolRenderData::FileRead { path } => format_path_for_display(path), + _ => String::new(), + }; + + let content = element! { + Text { Span(text: path_str) } + }; + + group_row_view(is_first, &details.status, content) +} + +/// Render a group of consecutive `atuin_history` tool calls. +fn history_search_group_view(group: &turn::ToolGroup) -> Elements { + let done = !group.any_pending(); + let visible = visible_group_calls(group); + element! { - Spinner( - label: if done { format!("Ran: {name}") } else { format!("Running: {name}") }, - label_style: Style::default().fg(Color::Yellow), - done: done, - ) + Spinner(label: "Searched Atuin history:", done: done, hide_checkmark: true) + #(for (i, details) in visible.iter().enumerate() { + #(history_search_row(i == 0, details)) + }) } } +fn history_search_row(is_first: bool, details: &turn::ToolCallDetails) -> Elements { + let (query, filter_modes) = match &details.render_data { + turn::ToolRenderData::HistorySearch { + query, + filter_modes, + } => (query.as_str(), filter_modes.as_slice()), + _ => ("", [].as_slice()), + }; + + let is_empty_query = query.trim().is_empty(); + let filter_label = format_filter_modes(filter_modes); + + let content = if is_empty_query { + element! { + Text { + Span( + text: "recent commands", + style: Style::default().fg(Color::Gray).add_modifier(Modifier::ITALIC), + ) + #(if !filter_label.is_empty() { + Span(text: " ") + Span(text: filter_label, style: Style::default().fg(Color::DarkGray)) + }) + } + } + } else { + element! { + Text { + Span(text: query.to_string()) + #(if !filter_label.is_empty() { + Span(text: " ") + Span(text: filter_label, style: Style::default().fg(Color::DarkGray)) + }) + } + } + }; + + group_row_view(is_first, &details.status, content) +} + fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements { let is_dangerous = matches!( details.danger_level, @@ -413,9 +820,6 @@ fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements { element! { View { - #(if !details.first_event_in_turn { - Text { Span(text: "") } - }) Text { Span(text: " Suggested command:", style: Style::default().fg(Color::Cyan)) } diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index a2555dc6..1c19a6b2 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -1,5 +1,7 @@ +use std::path::PathBuf; + use crate::tools::descriptor; -use crate::tools::{ToolPreview, ToolTracker}; +use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview, ToolTracker}; use crate::tui::ConversationEvent; /// Server-sent danger level for a suggested command @@ -80,20 +82,99 @@ impl From<(&String, &String)> for ConfidenceLevel { #[derive(Debug)] pub(crate) enum UiEvent { - Text { content: String }, + Text { + content: String, + }, ToolCall(ToolCallDetails), + /// Consecutive client-side tool calls of the same groupable kind, collapsed + /// into one unit so the view can render a shared status line + a list of + /// individual entries. + ToolGroup(ToolGroup), ToolSummary(ToolSummary), SuggestedCommand(SuggestedCommandDetails), OutOfBandOutput(OutOfBandOutputDetails), } +/// A run of consecutive client-side tool calls of the same groupable kind. +#[derive(Debug)] +pub(crate) struct ToolGroup { + pub(crate) kind: ToolGroupKind, + pub(crate) calls: Vec<ToolCallDetails>, +} + +impl ToolGroup { + /// True if any call in the group is still pending. + pub(crate) fn any_pending(&self) -> bool { + self.calls + .iter() + .any(|c| c.status == ToolResultStatus::Pending) + } +} + +/// Which kind of client-side tools this group holds. +/// +/// Only tool types that benefit from grouped presentation appear here. +/// Shell (needs its own viewport) and FileWrite (wants diffs/contents) are +/// intentionally absent — those render as individual `UiEvent::ToolCall`s. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub(crate) enum ToolGroupKind { + FileRead, + HistorySearch, +} + +/// Tool-type-specific data for rendering in the view layer. +/// +/// Each variant carries the data a per-tool renderer component needs. +/// Built by TurnBuilder from ToolTracker + ConversationEvent data. +#[derive(Debug)] +pub(crate) enum ToolRenderData { + /// Shell command with live/cached VT100 output preview. + Shell { + command: String, + preview: Option<ToolPreview>, + }, + /// File read operation. + FileRead { path: PathBuf }, + /// File edit (str_replace) operation. + FileEdit { + path: PathBuf, + preview: Option<crate::diff::EditPreview>, + }, + /// File write/create operation. + FileWrite { path: PathBuf }, + /// Atuin history search. + HistorySearch { + query: String, + filter_modes: Vec<HistorySearchFilterMode>, + }, + /// Server-side tool — no client rendering data available. + Remote, +} + +impl ToolRenderData { + pub(crate) fn is_remote(&self) -> bool { + matches!(self, ToolRenderData::Remote) + } + + /// The group kind this tool should collapse into, if any. + /// + /// Returns `None` for tools that render as individual `UiEvent::ToolCall`s + /// (shell, file writes, remote). + pub(crate) fn group_kind(&self) -> Option<ToolGroupKind> { + match self { + ToolRenderData::FileRead { .. } => Some(ToolGroupKind::FileRead), + ToolRenderData::HistorySearch { .. } => Some(ToolGroupKind::HistorySearch), + _ => None, + } + } +} + #[derive(Debug)] pub(crate) struct ToolCallDetails { pub(crate) tool_use_id: String, pub(crate) name: String, pub(crate) status: ToolResultStatus, - pub(crate) is_client: bool, - pub(crate) preview: Option<ToolPreview>, + pub(crate) render_data: ToolRenderData, } #[derive(Debug)] @@ -101,7 +182,6 @@ pub(crate) struct SuggestedCommandDetails { pub(crate) command: String, pub(crate) danger_level: DangerLevel, pub(crate) confidence_level: ConfidenceLevel, - pub(crate) first_event_in_turn: bool, } #[derive(Debug)] @@ -179,33 +259,49 @@ impl<'a> TurnBuilder<'a> { pub(crate) fn build(&mut self) -> Vec<UiTurn> { self.commit_turn(); - // Collapse consecutive tool calls within each agent turn into ToolSummary + // Within each agent turn: + // - Consecutive remote tool calls collapse into a ToolSummary + // - Consecutive client-side tool calls of the same group kind collapse + // into a ToolGroup (e.g. N file reads → one group) + // - All other events pass through unchanged for turn in &mut self.turns { if let UiTurn::Agent { events } = turn { let mut new_events: Vec<UiEvent> = Vec::new(); - let mut pending_tools: Vec<ToolCallDetails> = Vec::new(); + let mut pending_remote: Vec<ToolCallDetails> = Vec::new(); + let mut pending_group: Option<(ToolGroupKind, Vec<ToolCallDetails>)> = None; for event in events.drain(..) { match event { - UiEvent::ToolCall(details) if !details.is_client => { - pending_tools.push(details); + UiEvent::ToolCall(details) if details.render_data.is_remote() => { + flush_group(&mut pending_group, &mut new_events); + pending_remote.push(details); } - other => { - if !pending_tools.is_empty() { - new_events.push(UiEvent::ToolSummary(ToolSummary { - tool_calls: std::mem::take(&mut pending_tools), - })); + UiEvent::ToolCall(details) + if details.render_data.group_kind().is_some() => + { + flush_remote(&mut pending_remote, &mut new_events); + + let kind = details.render_data.group_kind().unwrap(); + match pending_group.as_mut() { + Some((current_kind, calls)) if *current_kind == kind => { + calls.push(details); + } + _ => { + flush_group(&mut pending_group, &mut new_events); + pending_group = Some((kind, vec![details])); + } } + } + other => { + flush_remote(&mut pending_remote, &mut new_events); + flush_group(&mut pending_group, &mut new_events); new_events.push(other); } } } - if !pending_tools.is_empty() { - new_events.push(UiEvent::ToolSummary(ToolSummary { - tool_calls: pending_tools, - })); - } + flush_remote(&mut pending_remote, &mut new_events); + flush_group(&mut pending_group, &mut new_events); *events = new_events; } @@ -255,6 +351,9 @@ impl<'a> TurnBuilder<'a> { } fn add_agent_text(&mut self, content: &str) { + if content.trim().is_empty() { + return; + } self.start_agent_turn(); if let UiTurn::Agent { events } = self.turn_mut_unsafe() { events.push(UiEvent::Text { @@ -303,8 +402,6 @@ impl<'a> TurnBuilder<'a> { let danger = DangerLevel::from((&danger_level, &danger_notes)); let confidence = ConfidenceLevel::from((&confidence_level, &confidence_notes)); - let first_event_in_turn = events.is_empty(); - events.push(UiEvent::SuggestedCommand(SuggestedCommandDetails { command: input .get("command") @@ -313,14 +410,12 @@ impl<'a> TurnBuilder<'a> { .to_string(), danger_level: danger, confidence_level: confidence, - first_event_in_turn, })); } } fn add_tool_call(&mut self, id: &str, name: &str, _input: &serde_json::Value) { - let is_client = descriptor::by_name(name).is_some_and(|d| d.is_client); - let preview = self.tracker.preview_for(id); + let render_data = self.build_render_data(id, name); self.start_agent_turn(); if let UiTurn::Agent { events } = self.turn_mut_unsafe() { @@ -328,12 +423,44 @@ impl<'a> TurnBuilder<'a> { tool_use_id: id.to_string(), name: name.to_string(), status: ToolResultStatus::Pending, - is_client, - preview, + render_data, })); } } + /// Build tool-type-specific render data from the ToolTracker. + /// + /// For client-side tools, the tracker holds the typed `ClientToolCall` and + /// any live/cached preview data. For server-side (or unknown) tools, we + /// fall back to `ToolRenderData::Remote`. + fn build_render_data(&self, id: &str, _name: &str) -> ToolRenderData { + if let Some(tracked) = self.tracker.get(id) { + match &tracked.tool { + ClientToolCall::Shell(shell) => ToolRenderData::Shell { + command: shell.command.clone(), + preview: tracked.preview(), + }, + ClientToolCall::Read(read) => ToolRenderData::FileRead { + path: read.path.clone(), + }, + ClientToolCall::Edit(edit) => ToolRenderData::FileEdit { + path: edit.path.clone(), + preview: tracked.edit_preview.clone(), + }, + ClientToolCall::Write(write) => ToolRenderData::FileWrite { + path: write.path.clone(), + }, + ClientToolCall::AtuinHistory(history) => ToolRenderData::HistorySearch { + query: history.query.clone(), + filter_modes: history.filter_modes.clone(), + }, + } + } else { + // Not in tracker → server-side tool + ToolRenderData::Remote + } + } + fn add_tool_result(&mut self, tool_use_id: &str, _content: &str, is_error: bool) { self.start_agent_turn(); if let UiTurn::Agent { events } = self.turn_mut_unsafe() { @@ -364,6 +491,25 @@ impl<'a> TurnBuilder<'a> { } } +/// Drain pending remote tool calls into a `ToolSummary`. +fn flush_remote(pending: &mut Vec<ToolCallDetails>, out: &mut Vec<UiEvent>) { + if !pending.is_empty() { + out.push(UiEvent::ToolSummary(ToolSummary { + tool_calls: std::mem::take(pending), + })); + } +} + +/// Drain a pending client-side tool group into a `ToolGroup`. +fn flush_group( + pending: &mut Option<(ToolGroupKind, Vec<ToolCallDetails>)>, + out: &mut Vec<UiEvent>, +) { + if let Some((kind, calls)) = pending.take() { + out.push(UiEvent::ToolGroup(ToolGroup { kind, calls })); + } +} + #[derive(Debug)] pub(crate) struct ToolSummary { tool_calls: Vec<ToolCallDetails>, |
