From b121b73d07df389d324b3a8f27066661a6609618 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Thu, 23 Apr 2026 13:29:58 -0700 Subject: feat: Send user-defined context with `TERMINAL.md` (#3443) This PR adds the ability to inject user-defined content into Atuin AI requests, a la `AGENTS.md` or `CLAUDE.md`. * `.atuin/TERMINAL.md` (or alternatively just `TERMINAL.md`) is checked in every directory from the cwd up to the root * `~/.config/atuin/TERMINAL.md` (or equivalent config dir) is also checked * Supports Claude-style ``` !`command` ``` and ```` ```!...``` ```` style shell interpolation --- crates/atuin-ai/src/driver.rs | 8 + crates/atuin-ai/src/lib.rs | 1 + crates/atuin-ai/src/stream.rs | 11 +- crates/atuin-ai/src/user_context/interpolate.rs | 279 ++++++++++++++++++++++++ crates/atuin-ai/src/user_context/mod.rs | 68 ++++++ crates/atuin-ai/src/user_context/walker.rs | 90 ++++++++ 6 files changed, 456 insertions(+), 1 deletion(-) create mode 100644 crates/atuin-ai/src/user_context/interpolate.rs create mode 100644 crates/atuin-ai/src/user_context/mod.rs create mode 100644 crates/atuin-ai/src/user_context/walker.rs (limited to 'crates') diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs index 3acb9798..b5e1c275 100644 --- a/crates/atuin-ai/src/driver.rs +++ b/crates/atuin-ai/src/driver.rs @@ -775,6 +775,13 @@ async fn run_stream_bridge( use crate::stream::{StreamContent, StreamControl, StreamFrame, create_chat_stream}; use futures::StreamExt; + // Gather user context files (TERMINAL.md) and interpolate commands. 
+ let shell = client_ctx.shell.as_deref().unwrap_or("sh"); + let start_dir = std::env::current_dir().unwrap_or_default(); + let global_ctx_path = crate::user_context::global_context_path(); + let user_contexts = + crate::user_context::gather(&start_dir, Some(&global_ctx_path), shell).await; + let stream = create_chat_stream( app_ctx.endpoint.clone(), app_ctx.token.clone(), @@ -782,6 +789,7 @@ async fn run_stream_bridge( client_ctx, app_ctx.send_cwd, app_ctx.last_command.clone(), + user_contexts, ); futures::pin_mut!(stream); diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index 540aece3..289f6ea2 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -14,3 +14,4 @@ pub(crate) mod store; pub(crate) mod stream; pub(crate) mod tools; pub(crate) mod tui; +pub(crate) mod user_context; diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index 19d287e7..e7155a08 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -100,6 +100,7 @@ pub(crate) fn create_chat_stream( client_ctx: ClientContext, send_cwd: bool, last_command: Option, + user_contexts: Vec, ) -> std::pin::Pin> + Send>> { Box::pin(async_stream::stream! { ensure_crypto_provider(); @@ -115,10 +116,18 @@ pub(crate) fn create_chat_stream( let context = client_ctx.to_json(send_cwd, last_command.as_deref()); + let mut config = serde_json::json!({ + "capabilities": request.capabilities, + }); + + if !user_contexts.is_empty() { + config["user_contexts"] = serde_json::json!(user_contexts); + } + let mut request_body = serde_json::json!({ "messages": request.messages, "context": context, - "capabilities": request.capabilities, + "config": config, "invocation_id": request.invocation_id }); diff --git a/crates/atuin-ai/src/user_context/interpolate.rs b/crates/atuin-ai/src/user_context/interpolate.rs new file mode 100644 index 00000000..91e34ab4 --- /dev/null +++ b/crates/atuin-ai/src/user_context/interpolate.rs @@ -0,0 +1,279 @@ +//! 
Parse `TERMINAL.md` context files and execute embedded commands. +//! +//! Two interpolation syntaxes are supported: +//! +//! **Inline:** `!`command`` — the `!` immediately before a code span triggers +//! execution. The entire `!`...`` span is replaced with the command's stdout. +//! +//! **Block:** +//! ````markdown +//! ```! +//! command +//! ``` +//! ```` +//! A fenced code block with `!` as the info string. The block body is executed +//! as a script and the entire fenced block is replaced with stdout. +//! +//! Regular code spans and fenced code blocks (without `!`) are left untouched. + +use std::ops::Range; +use std::time::Duration; + +use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag, TagEnd}; + +/// A command to execute, with its byte range in the source for replacement. +#[derive(Debug)] +struct Command { + /// Byte range in the source to replace (includes the `!` for inline, or + /// the full ``` fence for blocks). + range: Range<usize>, + /// The command string to execute. + body: String, +} + +/// Maximum time for a single command. +const COMMAND_TIMEOUT: Duration = Duration::from_secs(5); + +/// Maximum bytes of stdout to capture from a single command. +const MAX_OUTPUT_BYTES: usize = 64 * 1024; + +/// Parse a context file for interpolation commands. +fn parse_commands(source: &str) -> Vec<Command> { + let parser = Parser::new_ext(source, Options::empty()); + let mut commands = Vec::new(); + + // Block state: accumulate text across multiple Text events, finalize on End. + let mut block_start: Option<usize> = None; + let mut block_body = String::new(); + + for (event, range) in parser.into_offset_iter() { + match event { + // Inline: !`command` + Event::Code(code) if range.start > 0 && source.as_bytes()[range.start - 1] == b'!' => { + commands.push(Command { + range: (range.start - 1)..range.end, + body: code.to_string(), + }); + } + + // Block: ```! ... ``` + Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced(info))) if info.as_ref() == "!" 
=> { + block_start = Some(range.start); + block_body.clear(); + } + Event::Text(text) if block_start.is_some() => { + block_body.push_str(&text); + } + Event::End(TagEnd::CodeBlock) if block_start.is_some() => { + let start = block_start.take().unwrap(); + let trimmed = block_body.trim(); + if !trimmed.is_empty() { + commands.push(Command { + range: start..range.end, + body: trimmed.to_string(), + }); + } + block_body.clear(); + } + + _ => {} + } + } + + commands +} + +/// Execute all commands in a context file and return the interpolated content. +/// +/// Commands are executed in parallel. Failed commands are replaced with an +/// error marker so the AI has visibility into what went wrong. +pub(crate) async fn interpolate(source: &str, shell: &str) -> String { + let commands = parse_commands(source); + if commands.is_empty() { + return source.to_string(); + } + + // Execute all commands in parallel. + let mut handles = Vec::with_capacity(commands.len()); + for cmd in &commands { + let shell = shell.to_string(); + let body = cmd.body.clone(); + handles.push(tokio::spawn( + async move { run_command(&shell, &body).await }, + )); + } + + // Collect results. + let mut results = Vec::with_capacity(handles.len()); + for handle in handles { + let output = match handle.await { + Ok(output) => output, + Err(e) => format!("[error: task panicked: {e}]"), + }; + results.push(output); + } + + // Rebuild the source, replacing command ranges with their output. + // Commands are in source order from the parser, but let's sort to be safe. 
+ let mut replacements: Vec<(Range<usize>, &str)> = commands + .iter() + .zip(results.iter()) + .map(|(cmd, output)| (cmd.range.clone(), output.as_str())) + .collect(); + replacements.sort_by_key(|(range, _)| range.start); + + let mut out = String::with_capacity(source.len()); + let mut cursor = 0; + for (range, output) in &replacements { + out.push_str(&source[cursor..range.start]); + out.push_str(output); + cursor = range.end; + } + out.push_str(&source[cursor..]); + + out +} + +async fn run_command(shell: &str, body: &str) -> String { + let result = tokio::time::timeout( + COMMAND_TIMEOUT, + tokio::process::Command::new(shell) + .arg("-c") + .arg(body) + .output(), + ) + .await; + + match result { + Ok(Ok(output)) => { + if output.status.success() { + if output.stdout.len() > MAX_OUTPUT_BYTES { + let truncated = String::from_utf8_lossy(&output.stdout[..MAX_OUTPUT_BYTES]); + format!( + "{}\n[output truncated at {}KB]", + truncated.trim(), + MAX_OUTPUT_BYTES / 1024 + ) + } else { + String::from_utf8_lossy(&output.stdout).trim().to_string() + } + } else { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + let code = output.status.code().unwrap_or(-1); + format!("[error: exit code {code}: {stderr}]") + } + } + Ok(Err(e)) => format!("[error: {e}]"), + Err(_) => format!( + "[error: command timed out after {}s]", + COMMAND_TIMEOUT.as_secs() + ), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_inline_command() { + let source = "Branch: !`git branch --show-current`"; + let cmds = parse_commands(source); + assert_eq!(cmds.len(), 1); + assert_eq!(cmds[0].body, "git branch --show-current"); + assert_eq!( + &source[cmds[0].range.clone()], + "!`git branch --show-current`" + ); + } + + #[test] + fn parse_inline_double_backtick() { + let source = r#"Host: !``echo `hostname` ``"#; + let cmds = parse_commands(source); + assert_eq!(cmds.len(), 1); + assert_eq!(cmds[0].body, "echo `hostname` "); + } + + #[test] + fn 
parse_block_command() { + let source = "Before\n\n```!\necho hello\npython3 --version\n```\n\nAfter"; + let cmds = parse_commands(source); + assert_eq!(cmds.len(), 1); + assert_eq!(cmds[0].body, "echo hello\npython3 --version"); + } + + #[test] + fn regular_code_not_matched() { + let source = "Normal `code span` and ```bash\necho hi\n```"; + let cmds = parse_commands(source); + assert_eq!(cmds.len(), 0); + } + + #[test] + fn bang_not_adjacent_not_matched() { + let source = "Exclaim! Then `code` here."; + let cmds = parse_commands(source); + // The `!` and backtick are separated by " Then ", not adjacent. + assert_eq!(cmds.len(), 0); + } + + #[test] + fn mixed_content() { + let source = "\ +# Project Context + +Branch: !`git branch --show-current` + +Regular code: `not a command` + +```! +echo $VIRTUAL_ENV +``` + +```bash +echo not interpolated +``` + +End."; + let cmds = parse_commands(source); + assert_eq!(cmds.len(), 2); + assert_eq!(cmds[0].body, "git branch --show-current"); + assert_eq!(cmds[1].body, "echo $VIRTUAL_ENV"); + } + + #[tokio::test] + async fn interpolate_replaces_inline_command() { + let source = "Branch: !`echo main`"; + let result = interpolate(source, "sh").await; + assert_eq!(result, "Branch: main"); + } + + #[tokio::test] + async fn interpolate_replaces_block_command() { + let source = "Before\n\n```!\necho hello world\n```\n\nAfter"; + let result = interpolate(source, "sh").await; + assert_eq!(result, "Before\n\nhello world\n\nAfter"); + } + + #[tokio::test] + async fn interpolate_preserves_non_command_content() { + let source = "Just plain markdown with `code` and no bangs."; + let result = interpolate(source, "sh").await; + assert_eq!(result, source); + } + + #[tokio::test] + async fn interpolate_failed_command_shows_error() { + let source = "Result: !`exit 1`"; + let result = interpolate(source, "sh").await; + assert!(result.starts_with("Result: [error:")); + } + + #[tokio::test] + async fn interpolate_multiple_commands() { + let source = 
"A: !`echo one` B: !`echo two`"; + let result = interpolate(source, "sh").await; + assert_eq!(result, "A: one B: two"); + } +} diff --git a/crates/atuin-ai/src/user_context/mod.rs b/crates/atuin-ai/src/user_context/mod.rs new file mode 100644 index 00000000..295efdec --- /dev/null +++ b/crates/atuin-ai/src/user_context/mod.rs @@ -0,0 +1,68 @@ +//! User-authored context files (`TERMINAL.md`). +//! +//! Context files are markdown documents that can embed shell commands for +//! dynamic content. Before each API request, context files are discovered +//! by walking the filesystem, commands are executed, and the interpolated +//! content is sent to the server as `config.user_contexts`. + +mod interpolate; +mod walker; + +use std::path::Path; + +pub(crate) use walker::global_context_path; + +/// A fully resolved user context, ready to include in an API request. +#[derive(Debug, Clone, serde::Serialize)] +pub(crate) struct UserContext { + /// The path to the context file on disk. + pub path: String, + /// The interpolated content. + pub data: String, +} + +/// Discover context files and interpolate embedded commands. +/// +/// Walks from `start` up to the filesystem root looking for +/// `TERMINAL.md` files, then checks `global_path`. Returns contexts +/// ordered from most general (global/root) to most specific (deepest). +pub(crate) async fn gather( + start: &Path, + global_path: Option<&Path>, + shell: &str, +) -> Vec<UserContext> { + let raw_files = match walker::walk(start, global_path).await { + Ok(files) => files, + Err(e) => { + tracing::warn!("Failed to walk for context files: {e}"); + return Vec::new(); + } + }; + + if raw_files.is_empty() { + return Vec::new(); + } + + // Interpolate all files in parallel. 
+ let mut handles = Vec::with_capacity(raw_files.len()); + for file in raw_files { + let shell = shell.to_string(); + handles.push(tokio::spawn(async move { + let data = interpolate::interpolate(&file.content, &shell).await; + UserContext { + path: file.path.to_string_lossy().to_string(), + data, + } + })); + } + + let mut contexts = Vec::with_capacity(handles.len()); + for handle in handles { + match handle.await { + Ok(ctx) => contexts.push(ctx), + Err(e) => tracing::warn!("Context interpolation task failed: {e}"), + } + } + + contexts +} diff --git a/crates/atuin-ai/src/user_context/walker.rs b/crates/atuin-ai/src/user_context/walker.rs new file mode 100644 index 00000000..117bbd33 --- /dev/null +++ b/crates/atuin-ai/src/user_context/walker.rs @@ -0,0 +1,90 @@ +//! Filesystem traversal for `TERMINAL.md` context files. +//! +//! Walks from the starting directory up to the filesystem root, checking for +//! `.atuin/TERMINAL.md` and `TERMINAL.md` at each level. Then checks the global +//! config directory. Returns files ordered from shallowest (global/root) to +//! deepest (most project-specific), so that context layers naturally from +//! general to specific. + +use std::path::{Path, PathBuf}; + +use eyre::Result; +use tokio::task::JoinSet; + +const CONTEXT_FILENAME: &str = "TERMINAL.md"; + +/// A context file found on disk, before interpolation. +#[derive(Debug)] +pub(crate) struct RawContextFile { + pub path: PathBuf, + pub content: String, +} + +struct FoundFile { + depth: usize, + file: RawContextFile, +} + +/// Walk from `start` up to the filesystem root collecting `TERMINAL.md` +/// context files, then check the global path. Returns files shallowest-first. 
+/// +/// At each ancestor directory, checks two locations: +/// - `.atuin/TERMINAL.md` (dotdir-scoped) +/// - `TERMINAL.md` (project root) +pub(crate) async fn walk(start: &Path, global_path: Option<&Path>) -> Result<Vec<RawContextFile>> { + let dirs: Vec<PathBuf> = start.ancestors().map(PathBuf::from).collect(); + let dir_count = dirs.len(); + + let mut set: JoinSet<Result<Option<FoundFile>>> = JoinSet::new(); + + for (index, dir) in dirs.into_iter().enumerate() { + let dir2 = dir.clone(); + set.spawn(async move { + load_context_file(&dir.join(".atuin").join(CONTEXT_FILENAME), index).await + }); + set.spawn(async move { load_context_file(&dir2.join(CONTEXT_FILENAME), index).await }); + } + + if let Some(global) = global_path { + let global = global.to_path_buf(); + let depth = dir_count; + set.spawn(async move { load_context_file(&global, depth).await }); + } + + let mut found = Vec::new(); + while let Some(result) = set.join_next().await { + match result? { + Ok(Some(f)) => found.push(f), + Ok(None) => {} + Err(e) => { + tracing::warn!("Error reading context file, skipping: {e}"); + } + } + } + + // `depth` counts ancestors from the cwd (0) upward; the global file gets the + // largest value. Sorting by descending depth yields global first, then + // root → cwd. Files at the same depth keep their completion order. + found.sort_by_key(|b| std::cmp::Reverse(b.depth)); + + Ok(found.into_iter().map(|f| f.file).collect()) +} + +/// The default global context file path (`~/.config/atuin/TERMINAL.md`). +pub(crate) fn global_context_path() -> PathBuf { + atuin_common::utils::config_dir().join(CONTEXT_FILENAME) +} + +async fn load_context_file(path: &Path, depth: usize) -> Result<Option<FoundFile>> { + match tokio::fs::read_to_string(path).await { + Ok(content) => Ok(Some(FoundFile { + depth, + file: RawContextFile { + path: path.to_path_buf(), + content, + }, + })), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None), + Err(e) => Err(e.into()), + } +} -- cgit v1.3.1