aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src/user_context/interpolate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-ai/src/user_context/interpolate.rs')
-rw-r--r-- crates/atuin-ai/src/user_context/interpolate.rs 279
1 files changed, 279 insertions, 0 deletions
diff --git a/crates/atuin-ai/src/user_context/interpolate.rs b/crates/atuin-ai/src/user_context/interpolate.rs
new file mode 100644
index 00000000..91e34ab4
--- /dev/null
+++ b/crates/atuin-ai/src/user_context/interpolate.rs
@@ -0,0 +1,279 @@
+//! Parse `.atuin/ai-context.md` files and execute embedded commands.
+//!
+//! Two interpolation syntaxes are supported:
+//!
+//! **Inline:** `!`command`` — the `!` immediately before a code span triggers
+//! execution. The entire `!`...`` span is replaced with the command's stdout.
+//!
+//! **Block:**
+//! ````markdown
+//! ```!
+//! command
+//! ```
+//! ````
+//! A fenced code block with `!` as the info string. The block body is executed
+//! as a script and the entire fenced block is replaced with stdout.
+//!
+//! Regular code spans and fenced code blocks (without `!`) are left untouched.
+
+use std::ops::Range;
+use std::time::Duration;
+
+use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag, TagEnd};
+
+/// One embedded command found by `parse_commands`: the text to run plus the source span it replaces.
+#[derive(Debug)]
+struct Command {
+ /// Byte range of the source replaced by the command's output: for inline commands it
+ /// starts at the leading `!`; for blocks it covers the fenced block reported by the parser.
+ range: Range<usize>,
+ /// Text passed to the shell: the code-span contents (inline) or the trimmed block body.
+ body: String,
+}
+
+/// Time limit for one command; `run_command` returns a timeout error marker once it elapses.
+const COMMAND_TIMEOUT: Duration = Duration::from_secs(5);
+
+/// Cap on stdout kept per command (64 KiB); longer output is cut and marked as truncated.
+const MAX_OUTPUT_BYTES: usize = 64 * 1024;
+
+/// Scan a context file's markdown and collect every embedded command.
+///
+/// Two forms are recognised:
+/// - an inline code span whose opening backtick is directly preceded by `!`
+/// - a fenced code block whose info string is exactly `!`
+///
+/// Commands are returned in the order the parser reports them. A `!` block
+/// whose body is empty after trimming yields no command.
+fn parse_commands(source: &str) -> Vec<Command> {
+    let bytes = source.as_bytes();
+    let mut found = Vec::new();
+
+    // The `!` block currently being read: (start offset of the block, text so far).
+    // The body is accumulated across `Text` events and only turned into a
+    // command once the closing `End` event arrives.
+    let mut open_block: Option<(usize, String)> = None;
+
+    let events = Parser::new_ext(source, Options::empty()).into_offset_iter();
+    for (event, span) in events {
+        match event {
+            // Inline: !`command` — the byte just before the span must be `!`.
+            Event::Code(code) if bytes[..span.start].last() == Some(&b'!') => {
+                found.push(Command {
+                    range: (span.start - 1)..span.end,
+                    body: code.to_string(),
+                });
+            }
+
+            // Block: ```! ... ```
+            Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced(info))) if info.as_ref() == "!" => {
+                open_block = Some((span.start, String::new()));
+            }
+            Event::Text(text) => {
+                // Text outside a `!` block is ignored.
+                if let Some((_, script)) = open_block.as_mut() {
+                    script.push_str(&text);
+                }
+            }
+            Event::End(TagEnd::CodeBlock) => {
+                if let Some((start, script)) = open_block.take() {
+                    let script = script.trim();
+                    if !script.is_empty() {
+                        found.push(Command {
+                            range: start..span.end,
+                            body: script.to_string(),
+                        });
+                    }
+                }
+            }
+
+            _ => {}
+        }
+    }
+
+    found
+}
+
+/// Run every command embedded in `source` and return the text with each
+/// command replaced by its output.
+///
+/// `shell` is the program each command is handed to (via `run_command`).
+/// Every command gets its own spawned task, so they run concurrently. A
+/// command that fails is replaced by an `[error: ...]` marker instead of
+/// aborting, so the AI can see what went wrong. When no commands are found
+/// the source is returned unchanged.
+pub(crate) async fn interpolate(source: &str, shell: &str) -> String {
+    let commands = parse_commands(source);
+    if commands.is_empty() {
+        return source.to_string();
+    }
+
+    // Spawn everything up front; `collect` is eager, so all tasks are running
+    // before the first one is awaited.
+    let tasks: Vec<_> = commands
+        .iter()
+        .map(|cmd| {
+            let shell = shell.to_string();
+            let script = cmd.body.clone();
+            tokio::spawn(async move { run_command(&shell, &script).await })
+        })
+        .collect();
+
+    // Await in spawn order so outputs line up with `commands` by index.
+    let mut outputs = Vec::with_capacity(tasks.len());
+    for task in tasks {
+        let text = task
+            .await
+            .unwrap_or_else(|e| format!("[error: task panicked: {e}]"));
+        outputs.push(text);
+    }
+
+    // Pair each source span with its output. The parser already yields spans
+    // in source order; the sort is a safeguard for the splice below.
+    let mut pieces: Vec<(&Range<usize>, &String)> =
+        commands.iter().map(|cmd| &cmd.range).zip(&outputs).collect();
+    pieces.sort_by_key(|(span, _)| span.start);
+
+    // Splice: copy the untouched text before each span, then the output.
+    let mut rendered = String::with_capacity(source.len());
+    let mut consumed = 0;
+    for (span, text) in pieces {
+        rendered.push_str(&source[consumed..span.start]);
+        rendered.push_str(text);
+        consumed = span.end;
+    }
+    rendered.push_str(&source[consumed..]);
+
+    rendered
+}
+
+/// Run `body` through `shell -c` and return the text that replaces the
+/// command in the context file.
+///
+/// This never fails; every outcome is rendered as a string:
+/// - successful exit: trimmed stdout, cut to `MAX_OUTPUT_BYTES` with a
+///   truncation note appended when it is longer
+/// - unsuccessful exit: `[error: exit code N: <stderr>]`, with `N` = `-1`
+///   when the status carries no exit code
+/// - failure to spawn or wait: `[error: <io error>]`
+/// - running longer than `COMMAND_TIMEOUT`: a timeout marker; the child
+///   process is killed rather than left running
+///
+/// Note: stdout is collected in full before it is cut, so the cap bounds the
+/// returned text, not the memory used while the command runs.
+async fn run_command(shell: &str, body: &str) -> String {
+    let result = tokio::time::timeout(
+        COMMAND_TIMEOUT,
+        tokio::process::Command::new(shell)
+            .arg("-c")
+            .arg(body)
+            // On timeout the `output()` future is dropped. Without this flag
+            // the child would keep running in the background.
+            .kill_on_drop(true)
+            .output(),
+    )
+    .await;
+
+    match result {
+        Ok(Ok(output)) if output.status.success() => {
+            let stdout = &output.stdout;
+            if stdout.len() > MAX_OUTPUT_BYTES {
+                // Cutting at a fixed byte offset can land inside a multi-byte
+                // UTF-8 character and leave a stray U+FFFD at the end. Step
+                // back over continuation bytes (at most three) so the cut
+                // falls on a character start. `end < stdout.len()` holds
+                // here, so the index is in bounds.
+                let mut end = MAX_OUTPUT_BYTES;
+                while end > MAX_OUTPUT_BYTES - 3 && (stdout[end] & 0xC0) == 0x80 {
+                    end -= 1;
+                }
+                let truncated = String::from_utf8_lossy(&stdout[..end]);
+                format!(
+                    "{}\n[output truncated at {}KB]",
+                    truncated.trim(),
+                    MAX_OUTPUT_BYTES / 1024
+                )
+            } else {
+                String::from_utf8_lossy(stdout).trim().to_string()
+            }
+        }
+        Ok(Ok(output)) => {
+            let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
+            // `code()` is `None` when there is no exit code to report.
+            let code = output.status.code().unwrap_or(-1);
+            format!("[error: exit code {code}: {stderr}]")
+        }
+        Ok(Err(e)) => format!("[error: {e}]"),
+        Err(_) => format!(
+            "[error: command timed out after {}s]",
+            COMMAND_TIMEOUT.as_secs()
+        ),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Command bodies found in `source`, in parser order.
+    fn bodies(source: &str) -> Vec<String> {
+        parse_commands(source)
+            .into_iter()
+            .map(|cmd| cmd.body)
+            .collect()
+    }
+
+    // ---- parsing ----
+
+    #[test]
+    fn parse_inline_command() {
+        let source = "Branch: !`git branch --show-current`";
+        let cmds = parse_commands(source);
+        assert_eq!(cmds.len(), 1);
+
+        let only = &cmds[0];
+        assert_eq!(only.body, "git branch --show-current");
+        // The replaced span includes the leading `!`.
+        assert_eq!(&source[only.range.clone()], "!`git branch --show-current`");
+    }
+
+    #[test]
+    fn parse_inline_double_backtick() {
+        let source = r#"Host: !``echo `hostname` ``"#;
+        assert_eq!(bodies(source), ["echo `hostname` "]);
+    }
+
+    #[test]
+    fn parse_block_command() {
+        let source = "Before\n\n```!\necho hello\npython3 --version\n```\n\nAfter";
+        assert_eq!(bodies(source), ["echo hello\npython3 --version"]);
+    }
+
+    #[test]
+    fn regular_code_not_matched() {
+        let source = "Normal `code span` and ```bash\necho hi\n```";
+        assert!(bodies(source).is_empty());
+    }
+
+    #[test]
+    fn bang_not_adjacent_not_matched() {
+        // The `!` and the backtick are separated by " Then ", so no command.
+        let source = "Exclaim! Then `code` here.";
+        assert!(bodies(source).is_empty());
+    }
+
+    #[test]
+    fn mixed_content() {
+        let source = "\
+# Project Context
+
+Branch: !`git branch --show-current`
+
+Regular code: `not a command`
+
+```!
+echo $VIRTUAL_ENV
+```
+
+```bash
+echo not interpolated
+```
+
+End.";
+        assert_eq!(
+            bodies(source),
+            ["git branch --show-current", "echo $VIRTUAL_ENV"]
+        );
+    }
+
+    // ---- end-to-end interpolation (runs real `sh`) ----
+
+    #[tokio::test]
+    async fn interpolate_replaces_inline_command() {
+        let rendered = interpolate("Branch: !`echo main`", "sh").await;
+        assert_eq!(rendered, "Branch: main");
+    }
+
+    #[tokio::test]
+    async fn interpolate_replaces_block_command() {
+        let rendered = interpolate("Before\n\n```!\necho hello world\n```\n\nAfter", "sh").await;
+        assert_eq!(rendered, "Before\n\nhello world\n\nAfter");
+    }
+
+    #[tokio::test]
+    async fn interpolate_preserves_non_command_content() {
+        let source = "Just plain markdown with `code` and no bangs.";
+        let rendered = interpolate(source, "sh").await;
+        assert_eq!(rendered, source);
+    }
+
+    #[tokio::test]
+    async fn interpolate_failed_command_shows_error() {
+        let rendered = interpolate("Result: !`exit 1`", "sh").await;
+        assert!(rendered.starts_with("Result: [error:"));
+    }
+
+    #[tokio::test]
+    async fn interpolate_multiple_commands() {
+        let rendered = interpolate("A: !`echo one` B: !`echo two`", "sh").await;
+        assert_eq!(rendered, "A: one B: two");
+    }
+}