aboutsummaryrefslogtreecommitdiffstats
path: root/crates
diff options
context:
space:
mode:
Diffstat (limited to 'crates')
-rw-r--r--crates/atuin-ai/src/driver.rs8
-rw-r--r--crates/atuin-ai/src/lib.rs1
-rw-r--r--crates/atuin-ai/src/stream.rs11
-rw-r--r--crates/atuin-ai/src/user_context/interpolate.rs279
-rw-r--r--crates/atuin-ai/src/user_context/mod.rs68
-rw-r--r--crates/atuin-ai/src/user_context/walker.rs90
6 files changed, 456 insertions, 1 deletions
diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs
index 3acb9798..b5e1c275 100644
--- a/crates/atuin-ai/src/driver.rs
+++ b/crates/atuin-ai/src/driver.rs
@@ -775,6 +775,13 @@ async fn run_stream_bridge(
use crate::stream::{StreamContent, StreamControl, StreamFrame, create_chat_stream};
use futures::StreamExt;
+ // Gather user context files (TERMINAL.md) and interpolate commands.
+ let shell = client_ctx.shell.as_deref().unwrap_or("sh");
+ let start_dir = std::env::current_dir().unwrap_or_default();
+ let global_ctx_path = crate::user_context::global_context_path();
+ let user_contexts =
+ crate::user_context::gather(&start_dir, Some(&global_ctx_path), shell).await;
+
let stream = create_chat_stream(
app_ctx.endpoint.clone(),
app_ctx.token.clone(),
@@ -782,6 +789,7 @@ async fn run_stream_bridge(
client_ctx,
app_ctx.send_cwd,
app_ctx.last_command.clone(),
+ user_contexts,
);
futures::pin_mut!(stream);
diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs
index 540aece3..289f6ea2 100644
--- a/crates/atuin-ai/src/lib.rs
+++ b/crates/atuin-ai/src/lib.rs
@@ -14,3 +14,4 @@ pub(crate) mod store;
pub(crate) mod stream;
pub(crate) mod tools;
pub(crate) mod tui;
+pub(crate) mod user_context;
diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs
index 19d287e7..e7155a08 100644
--- a/crates/atuin-ai/src/stream.rs
+++ b/crates/atuin-ai/src/stream.rs
@@ -100,6 +100,7 @@ pub(crate) fn create_chat_stream(
client_ctx: ClientContext,
send_cwd: bool,
last_command: Option<String>,
+ user_contexts: Vec<crate::user_context::UserContext>,
) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<StreamFrame>> + Send>> {
Box::pin(async_stream::stream! {
ensure_crypto_provider();
@@ -115,10 +116,18 @@ pub(crate) fn create_chat_stream(
let context = client_ctx.to_json(send_cwd, last_command.as_deref());
+ let mut config = serde_json::json!({
+ "capabilities": request.capabilities,
+ });
+
+ if !user_contexts.is_empty() {
+ config["user_contexts"] = serde_json::json!(user_contexts);
+ }
+
let mut request_body = serde_json::json!({
"messages": request.messages,
"context": context,
- "capabilities": request.capabilities,
+ "config": config,
"invocation_id": request.invocation_id
});
diff --git a/crates/atuin-ai/src/user_context/interpolate.rs b/crates/atuin-ai/src/user_context/interpolate.rs
new file mode 100644
index 00000000..91e34ab4
--- /dev/null
+++ b/crates/atuin-ai/src/user_context/interpolate.rs
@@ -0,0 +1,279 @@
+//! Parse `.atuin/ai-context.md` files and execute embedded commands.
+//!
+//! Two interpolation syntaxes are supported:
+//!
+//! **Inline:** `!`command`` — the `!` immediately before a code span triggers
+//! execution. The entire `!`...`` span is replaced with the command's stdout.
+//!
+//! **Block:**
+//! ````markdown
+//! ```!
+//! command
+//! ```
+//! ````
+//! A fenced code block with `!` as the info string. The block body is executed
+//! as a script and the entire fenced block is replaced with stdout.
+//!
+//! Regular code spans and fenced code blocks (without `!`) are left untouched.
+
+use std::ops::Range;
+use std::time::Duration;
+
+use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag, TagEnd};
+
+/// A command to execute, with its byte range in the source for replacement.
+#[derive(Debug)]
+struct Command {
+ /// Byte range in the source to replace (includes the `!` for inline, or
+ /// the full ``` fence for blocks).
+ range: Range<usize>,
+ /// The command string to execute.
+ body: String,
+}
+
+/// Maximum time for a single command.
+const COMMAND_TIMEOUT: Duration = Duration::from_secs(5);
+
+/// Maximum bytes of stdout to capture from a single command.
+const MAX_OUTPUT_BYTES: usize = 64 * 1024;
+
+/// Parse a context file for interpolation commands.
+fn parse_commands(source: &str) -> Vec<Command> {
+ let parser = Parser::new_ext(source, Options::empty());
+ let mut commands = Vec::new();
+
+ // Block state: accumulate text across multiple Text events, finalize on End.
+ let mut block_start: Option<usize> = None;
+ let mut block_body = String::new();
+
+ for (event, range) in parser.into_offset_iter() {
+ match event {
+ // Inline: !`command`
+ Event::Code(code) if range.start > 0 && source.as_bytes()[range.start - 1] == b'!' => {
+ commands.push(Command {
+ range: (range.start - 1)..range.end,
+ body: code.to_string(),
+ });
+ }
+
+ // Block: ```! ... ```
+ Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced(info))) if info.as_ref() == "!" => {
+ block_start = Some(range.start);
+ block_body.clear();
+ }
+ Event::Text(text) if block_start.is_some() => {
+ block_body.push_str(&text);
+ }
+ Event::End(TagEnd::CodeBlock) if block_start.is_some() => {
+ let start = block_start.take().unwrap();
+ let trimmed = block_body.trim();
+ if !trimmed.is_empty() {
+ commands.push(Command {
+ range: start..range.end,
+ body: trimmed.to_string(),
+ });
+ }
+ block_body.clear();
+ }
+
+ _ => {}
+ }
+ }
+
+ commands
+}
+
+/// Execute all commands in a context file and return the interpolated content.
+///
+/// Commands are executed in parallel. Failed commands are replaced with an
+/// error marker so the AI has visibility into what went wrong.
+pub(crate) async fn interpolate(source: &str, shell: &str) -> String {
+ let commands = parse_commands(source);
+ if commands.is_empty() {
+ return source.to_string();
+ }
+
+ // Execute all commands in parallel.
+ let mut handles = Vec::with_capacity(commands.len());
+ for cmd in &commands {
+ let shell = shell.to_string();
+ let body = cmd.body.clone();
+ handles.push(tokio::spawn(
+ async move { run_command(&shell, &body).await },
+ ));
+ }
+
+ // Collect results.
+ let mut results = Vec::with_capacity(handles.len());
+ for handle in handles {
+ let output = match handle.await {
+ Ok(output) => output,
+ Err(e) => format!("[error: task panicked: {e}]"),
+ };
+ results.push(output);
+ }
+
+ // Rebuild the source, replacing command ranges with their output.
+ // Commands are in source order from the parser, but let's sort to be safe.
+ let mut replacements: Vec<(Range<usize>, &str)> = commands
+ .iter()
+ .zip(results.iter())
+ .map(|(cmd, output)| (cmd.range.clone(), output.as_str()))
+ .collect();
+ replacements.sort_by_key(|(range, _)| range.start);
+
+ let mut out = String::with_capacity(source.len());
+ let mut cursor = 0;
+ for (range, output) in &replacements {
+ out.push_str(&source[cursor..range.start]);
+ out.push_str(output);
+ cursor = range.end;
+ }
+ out.push_str(&source[cursor..]);
+
+ out
+}
+
+async fn run_command(shell: &str, body: &str) -> String {
+ let result = tokio::time::timeout(
+ COMMAND_TIMEOUT,
+ tokio::process::Command::new(shell)
+ .arg("-c")
+ .arg(body)
+ .output(),
+ )
+ .await;
+
+ match result {
+ Ok(Ok(output)) => {
+ if output.status.success() {
+ if output.stdout.len() > MAX_OUTPUT_BYTES {
+ let truncated = String::from_utf8_lossy(&output.stdout[..MAX_OUTPUT_BYTES]);
+ format!(
+ "{}\n[output truncated at {}KB]",
+ truncated.trim(),
+ MAX_OUTPUT_BYTES / 1024
+ )
+ } else {
+ String::from_utf8_lossy(&output.stdout).trim().to_string()
+ }
+ } else {
+ let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
+ let code = output.status.code().unwrap_or(-1);
+ format!("[error: exit code {code}: {stderr}]")
+ }
+ }
+ Ok(Err(e)) => format!("[error: {e}]"),
+ Err(_) => format!(
+ "[error: command timed out after {}s]",
+ COMMAND_TIMEOUT.as_secs()
+ ),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn parse_inline_command() {
+ let source = "Branch: !`git branch --show-current`";
+ let cmds = parse_commands(source);
+ assert_eq!(cmds.len(), 1);
+ assert_eq!(cmds[0].body, "git branch --show-current");
+ assert_eq!(
+ &source[cmds[0].range.clone()],
+ "!`git branch --show-current`"
+ );
+ }
+
+ #[test]
+ fn parse_inline_double_backtick() {
+ let source = r#"Host: !``echo `hostname` ``"#;
+ let cmds = parse_commands(source);
+ assert_eq!(cmds.len(), 1);
+ assert_eq!(cmds[0].body, "echo `hostname` ");
+ }
+
+ #[test]
+ fn parse_block_command() {
+ let source = "Before\n\n```!\necho hello\npython3 --version\n```\n\nAfter";
+ let cmds = parse_commands(source);
+ assert_eq!(cmds.len(), 1);
+ assert_eq!(cmds[0].body, "echo hello\npython3 --version");
+ }
+
+ #[test]
+ fn regular_code_not_matched() {
+ let source = "Normal `code span` and ```bash\necho hi\n```";
+ let cmds = parse_commands(source);
+ assert_eq!(cmds.len(), 0);
+ }
+
+ #[test]
+ fn bang_not_adjacent_not_matched() {
+ let source = "Exclaim! Then `code` here.";
+ let cmds = parse_commands(source);
+ // The `!` and backtick are separated by " Then ", not adjacent.
+ assert_eq!(cmds.len(), 0);
+ }
+
+ #[test]
+ fn mixed_content() {
+ let source = "\
+# Project Context
+
+Branch: !`git branch --show-current`
+
+Regular code: `not a command`
+
+```!
+echo $VIRTUAL_ENV
+```
+
+```bash
+echo not interpolated
+```
+
+End.";
+ let cmds = parse_commands(source);
+ assert_eq!(cmds.len(), 2);
+ assert_eq!(cmds[0].body, "git branch --show-current");
+ assert_eq!(cmds[1].body, "echo $VIRTUAL_ENV");
+ }
+
+ #[tokio::test]
+ async fn interpolate_replaces_inline_command() {
+ let source = "Branch: !`echo main`";
+ let result = interpolate(source, "sh").await;
+ assert_eq!(result, "Branch: main");
+ }
+
+ #[tokio::test]
+ async fn interpolate_replaces_block_command() {
+ let source = "Before\n\n```!\necho hello world\n```\n\nAfter";
+ let result = interpolate(source, "sh").await;
+ assert_eq!(result, "Before\n\nhello world\n\nAfter");
+ }
+
+ #[tokio::test]
+ async fn interpolate_preserves_non_command_content() {
+ let source = "Just plain markdown with `code` and no bangs.";
+ let result = interpolate(source, "sh").await;
+ assert_eq!(result, source);
+ }
+
+ #[tokio::test]
+ async fn interpolate_failed_command_shows_error() {
+ let source = "Result: !`exit 1`";
+ let result = interpolate(source, "sh").await;
+ assert!(result.starts_with("Result: [error:"));
+ }
+
+ #[tokio::test]
+ async fn interpolate_multiple_commands() {
+ let source = "A: !`echo one` B: !`echo two`";
+ let result = interpolate(source, "sh").await;
+ assert_eq!(result, "A: one B: two");
+ }
+}
diff --git a/crates/atuin-ai/src/user_context/mod.rs b/crates/atuin-ai/src/user_context/mod.rs
new file mode 100644
index 00000000..295efdec
--- /dev/null
+++ b/crates/atuin-ai/src/user_context/mod.rs
@@ -0,0 +1,68 @@
+//! User-authored context files (`TERMINAL.md`).
+//!
+//! Context files are markdown documents that can embed shell commands for
+//! dynamic content. Before each API request, context files are discovered
+//! by walking the filesystem, commands are executed, and the interpolated
+//! content is sent to the server as `config.user_contexts`.
+
+mod interpolate;
+mod walker;
+
+use std::path::Path;
+
+pub(crate) use walker::global_context_path;
+
+/// A fully resolved user context, ready to include in an API request.
+#[derive(Debug, Clone, serde::Serialize)]
+pub(crate) struct UserContext {
+ /// The path to the context file on disk.
+ pub path: String,
+ /// The interpolated content.
+ pub data: String,
+}
+
+/// Discover context files and interpolate embedded commands.
+///
+/// Walks from `start` up to the filesystem root looking for
+/// `.atuin/ai-context.md`, then checks `global_path`. Returns contexts
+/// ordered from most general (global/root) to most specific (deepest).
+pub(crate) async fn gather(
+ start: &Path,
+ global_path: Option<&Path>,
+ shell: &str,
+) -> Vec<UserContext> {
+ let raw_files = match walker::walk(start, global_path).await {
+ Ok(files) => files,
+ Err(e) => {
+ tracing::warn!("Failed to walk for context files: {e}");
+ return Vec::new();
+ }
+ };
+
+ if raw_files.is_empty() {
+ return Vec::new();
+ }
+
+ // Interpolate all files in parallel.
+ let mut handles = Vec::with_capacity(raw_files.len());
+ for file in raw_files {
+ let shell = shell.to_string();
+ handles.push(tokio::spawn(async move {
+ let data = interpolate::interpolate(&file.content, &shell).await;
+ UserContext {
+ path: file.path.to_string_lossy().to_string(),
+ data,
+ }
+ }));
+ }
+
+ let mut contexts = Vec::with_capacity(handles.len());
+ for handle in handles {
+ match handle.await {
+ Ok(ctx) => contexts.push(ctx),
+ Err(e) => tracing::warn!("Context interpolation task failed: {e}"),
+ }
+ }
+
+ contexts
+}
diff --git a/crates/atuin-ai/src/user_context/walker.rs b/crates/atuin-ai/src/user_context/walker.rs
new file mode 100644
index 00000000..117bbd33
--- /dev/null
+++ b/crates/atuin-ai/src/user_context/walker.rs
@@ -0,0 +1,90 @@
+//! Filesystem traversal for `TERMINAL.md` context files.
+//!
+//! Walks from the starting directory up to the filesystem root, checking for
+//! `.atuin/TERMINAL.md` and `TERMINAL.md` at each level. Then checks the global
+//! config directory. Returns files ordered from shallowest (global/root) to
+//! deepest (most project-specific), so that context layers naturally from
+//! general to specific.
+
+use std::path::{Path, PathBuf};
+
+use eyre::Result;
+use tokio::task::JoinSet;
+
+const CONTEXT_FILENAME: &str = "TERMINAL.md";
+
+/// A context file found on disk, before interpolation.
+#[derive(Debug)]
+pub(crate) struct RawContextFile {
+ pub path: PathBuf,
+ pub content: String,
+}
+
+struct FoundFile {
+ depth: usize,
+ file: RawContextFile,
+}
+
+/// Walk from `start` up to the filesystem root collecting `TERMINAL.md`
+/// context files, then check the global path. Returns files shallowest-first.
+///
+/// At each ancestor directory, checks two locations:
+/// - `.atuin/TERMINAL.md` (dotdir-scoped)
+/// - `TERMINAL.md` (project root)
+pub(crate) async fn walk(start: &Path, global_path: Option<&Path>) -> Result<Vec<RawContextFile>> {
+ let dirs: Vec<PathBuf> = start.ancestors().map(PathBuf::from).collect();
+ let dir_count = dirs.len();
+
+ let mut set: JoinSet<Result<Option<FoundFile>>> = JoinSet::new();
+
+ for (index, dir) in dirs.into_iter().enumerate() {
+ let dir2 = dir.clone();
+ set.spawn(async move {
+ load_context_file(&dir.join(".atuin").join(CONTEXT_FILENAME), index).await
+ });
+ set.spawn(async move { load_context_file(&dir2.join(CONTEXT_FILENAME), index).await });
+ }
+
+ if let Some(global) = global_path {
+ let global = global.to_path_buf();
+ let depth = dir_count;
+ set.spawn(async move { load_context_file(&global, depth).await });
+ }
+
+ let mut found = Vec::new();
+ while let Some(result) = set.join_next().await {
+ match result? {
+ Ok(Some(f)) => found.push(f),
+ Ok(None) => {}
+ Err(e) => {
+ tracing::warn!("Error reading context file, skipping: {e}");
+ }
+ }
+ }
+
+ // Sort shallowest-first (highest depth index = shallowest ancestor).
+ // The global file has the highest depth index so it sorts last... but we
+ // actually want global first, then root → cwd. Reverse the depth ordering.
+ found.sort_by_key(|b| std::cmp::Reverse(b.depth));
+
+ Ok(found.into_iter().map(|f| f.file).collect())
+}
+
+/// The default global context file path (`~/.config/atuin/TERMINAL.md`).
+pub(crate) fn global_context_path() -> PathBuf {
+ atuin_common::utils::config_dir().join(CONTEXT_FILENAME)
+}
+
+async fn load_context_file(path: &Path, depth: usize) -> Result<Option<FoundFile>> {
+ match tokio::fs::read_to_string(path).await {
+ Ok(content) => Ok(Some(FoundFile {
+ depth,
+ file: RawContextFile {
+ path: path.to_path_buf(),
+ content,
+ },
+ })),
+ Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
+ Err(e) => Err(e.into()),
+ }
+}