aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src/user_context/interpolate.rs
blob: 91e34ab440e7545c61dbfc1594eca8f335ecec3a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
//! Parse `.atuin/ai-context.md` files and execute embedded commands.
//!
//! Two interpolation syntaxes are supported:
//!
//! **Inline:** `!`command`` — the `!` immediately before a code span triggers
//! execution. The entire `!`...`` span is replaced with the command's stdout.
//!
//! **Block:**
//! ````markdown
//! ```!
//! command
//! ```
//! ````
//! A fenced code block with `!` as the info string. The block body is executed
//! as a script and the entire fenced block is replaced with stdout.
//!
//! Regular code spans and fenced code blocks (without `!`) are left untouched.

use std::ops::Range;
use std::time::Duration;

use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag, TagEnd};

/// One interpolation command found in a context file: what to run, and which
/// slice of the source its output replaces.
#[derive(Debug)]
struct Command {
    /// Byte range in the source to replace. For the inline form this starts
    /// at the `!` that precedes the code span; for the block form it runs
    /// from the start of the fenced code block to its end.
    range: Range<usize>,
    /// The command text handed to the shell. Inline bodies are the code
    /// span's text as reported by the Markdown parser; block bodies are the
    /// fence contents with surrounding whitespace trimmed.
    body: String,
}

/// Time limit applied to each command individually. A command that exceeds
/// it contributes a "timed out" error marker instead of its output.
const COMMAND_TIMEOUT: Duration = Duration::from_secs(5);

/// Maximum number of stdout bytes from one successful command that are kept
/// in the interpolated result. Longer output is cut to at most this size and
/// a truncation marker is appended.
const MAX_OUTPUT_BYTES: usize = 64 * 1024;

/// Scan `source` as Markdown and collect every interpolation command in it.
///
/// Two forms are recognised: an inline code span whose preceding byte is `!`,
/// and a fenced code block whose info string is exactly `!`. Commands are
/// returned in the order the parser reports them. A `!` block whose body is
/// empty after trimming produces no command.
fn parse_commands(source: &str) -> Vec<Command> {
    let mut found = Vec::new();

    // The `!` fence currently being read: its start offset plus the script
    // text gathered so far. A block body may arrive as several `Text` events,
    // so it is accumulated here and only finalised on the closing event.
    let mut open_block: Option<(usize, String)> = None;

    let events = Parser::new_ext(source, Options::empty()).into_offset_iter();
    for (event, span) in events {
        match event {
            // Inline form: !`command`
            Event::Code(code) => {
                let preceded_by_bang = source.as_bytes()[..span.start].last() == Some(&b'!');
                if preceded_by_bang {
                    found.push(Command {
                        // Widen the span by one byte so the `!` is replaced too.
                        range: (span.start - 1)..span.end,
                        body: code.to_string(),
                    });
                }
            }

            // Block form: ```! ... ```
            Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced(info))) => {
                if info.as_ref() == "!" {
                    open_block = Some((span.start, String::new()));
                }
            }
            Event::Text(text) => {
                if let Some((_, script)) = open_block.as_mut() {
                    script.push_str(&text);
                }
            }
            Event::End(TagEnd::CodeBlock) => {
                if let Some((start, script)) = open_block.take() {
                    let script = script.trim();
                    if !script.is_empty() {
                        found.push(Command {
                            range: start..span.end,
                            body: script.to_string(),
                        });
                    }
                }
            }

            _ => {}
        }
    }

    found
}

/// Execute all commands in a context file and return the interpolated content.
///
/// Commands are executed in parallel. Failed commands are replaced with an
/// error marker so the AI has visibility into what went wrong.
pub(crate) async fn interpolate(source: &str, shell: &str) -> String {
    let commands = parse_commands(source);
    if commands.is_empty() {
        return source.to_string();
    }

    // Execute all commands in parallel.
    let mut handles = Vec::with_capacity(commands.len());
    for cmd in &commands {
        let shell = shell.to_string();
        let body = cmd.body.clone();
        handles.push(tokio::spawn(
            async move { run_command(&shell, &body).await },
        ));
    }

    // Collect results.
    let mut results = Vec::with_capacity(handles.len());
    for handle in handles {
        let output = match handle.await {
            Ok(output) => output,
            Err(e) => format!("[error: task panicked: {e}]"),
        };
        results.push(output);
    }

    // Rebuild the source, replacing command ranges with their output.
    // Commands are in source order from the parser, but let's sort to be safe.
    let mut replacements: Vec<(Range<usize>, &str)> = commands
        .iter()
        .zip(results.iter())
        .map(|(cmd, output)| (cmd.range.clone(), output.as_str()))
        .collect();
    replacements.sort_by_key(|(range, _)| range.start);

    let mut out = String::with_capacity(source.len());
    let mut cursor = 0;
    for (range, output) in &replacements {
        out.push_str(&source[cursor..range.start]);
        out.push_str(output);
        cursor = range.end;
    }
    out.push_str(&source[cursor..]);

    out
}

/// Return the longest prefix of `bytes` that is at most `max` bytes long and
/// does not end in the middle of a UTF-8 sequence.
///
/// A UTF-8 sequence has at most three continuation bytes, so the cut backs up
/// by at most three bytes; data that is not UTF-8 is therefore never trimmed
/// by more than that.
fn utf8_safe_prefix(bytes: &[u8], max: usize) -> &[u8] {
    if bytes.len() <= max {
        return bytes;
    }
    let floor = max.saturating_sub(3);
    let mut end = max;
    // `bytes[end]` is the first byte that would be dropped. If it is a
    // continuation byte (0b10xx_xxxx), the character it belongs to started
    // earlier and would be split, so move the cut in front of it.
    while end > floor && (bytes[end] & 0xC0) == 0x80 {
        end -= 1;
    }
    &bytes[..end]
}

/// Run `body` via `shell -c` and return the text to splice into the context.
///
/// * On success: the command's stdout, lossily decoded as UTF-8 and trimmed.
///   Output longer than `MAX_OUTPUT_BYTES` is cut (on a character boundary)
///   and followed by a truncation marker.
/// * On a non-zero exit: `[error: exit code N: <stderr>]`, with `-1` standing
///   in when the process has no exit code.
/// * If the process cannot be run, or exceeds `COMMAND_TIMEOUT`: an
///   `[error: …]` marker describing the failure.
///
/// This function never fails; every outcome is rendered as a string.
async fn run_command(shell: &str, body: &str) -> String {
    let result = tokio::time::timeout(
        COMMAND_TIMEOUT,
        tokio::process::Command::new(shell)
            .arg("-c")
            .arg(body)
            // Context commands are non-interactive: give them an empty stdin
            // so they cannot block on, or consume, the caller's terminal input.
            .stdin(std::process::Stdio::null())
            // When the timeout fires the `output()` future is dropped; without
            // this flag the child would keep running in the background.
            // NOTE(review): this targets the shell process itself; processes
            // it spawned may outlive it — confirm whether that matters here.
            .kill_on_drop(true)
            .output(),
    )
    .await;

    match result {
        Ok(Ok(output)) => {
            if output.status.success() {
                // NOTE(review): `output()` has already buffered all of stdout
                // by this point, so the cap bounds the result size, not the
                // memory used while the command runs.
                if output.stdout.len() > MAX_OUTPUT_BYTES {
                    let kept = utf8_safe_prefix(&output.stdout, MAX_OUTPUT_BYTES);
                    let truncated = String::from_utf8_lossy(kept);
                    format!(
                        "{}\n[output truncated at {}KB]",
                        truncated.trim(),
                        MAX_OUTPUT_BYTES / 1024
                    )
                } else {
                    String::from_utf8_lossy(&output.stdout).trim().to_string()
                }
            } else {
                let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
                let code = output.status.code().unwrap_or(-1);
                format!("[error: exit code {code}: {stderr}]")
            }
        }
        // The process could not be spawned or waited on.
        Ok(Err(e)) => format!("[error: {e}]"),
        // The timeout elapsed before the command finished.
        Err(_) => format!(
            "[error: command timed out after {}s]",
            COMMAND_TIMEOUT.as_secs()
        ),
    }
}

// Unit tests. The `parse_*` tests exercise `parse_commands` only and run no
// processes; the `interpolate_*` tests run real commands through `sh`, so
// they need an `sh` executable to be available.
#[cfg(test)]
mod tests {
    use super::*;

    // An inline command is detected, and its range covers the `!` as well as
    // the backtick-delimited span.
    #[test]
    fn parse_inline_command() {
        let source = "Branch: !`git branch --show-current`";
        let cmds = parse_commands(source);
        assert_eq!(cmds.len(), 1);
        assert_eq!(cmds[0].body, "git branch --show-current");
        assert_eq!(
            &source[cmds[0].range.clone()],
            "!`git branch --show-current`"
        );
    }

    // A double-backtick span lets the command itself contain single
    // backticks. The inline body is not trimmed: the trailing space stays.
    #[test]
    fn parse_inline_double_backtick() {
        let source = r#"Host: !``echo `hostname` ``"#;
        let cmds = parse_commands(source);
        assert_eq!(cmds.len(), 1);
        assert_eq!(cmds[0].body, "echo `hostname` ");
    }

    // A fenced block with `!` as its info string yields one multi-line
    // command, with the surrounding newlines trimmed from the body.
    #[test]
    fn parse_block_command() {
        let source = "Before\n\n```!\necho hello\npython3 --version\n```\n\nAfter";
        let cmds = parse_commands(source);
        assert_eq!(cmds.len(), 1);
        assert_eq!(cmds[0].body, "echo hello\npython3 --version");
    }

    // Ordinary code spans and fences with another info string are ignored.
    #[test]
    fn regular_code_not_matched() {
        let source = "Normal `code span` and ```bash\necho hi\n```";
        let cmds = parse_commands(source);
        assert_eq!(cmds.len(), 0);
    }

    // A `!` elsewhere in the sentence must not turn a later code span into a
    // command.
    #[test]
    fn bang_not_adjacent_not_matched() {
        let source = "Exclaim! Then `code` here.";
        let cmds = parse_commands(source);
        // The `!` and backtick are separated by " Then ", not adjacent.
        assert_eq!(cmds.len(), 0);
    }

    // A document mixing both command forms with regular code: only the two
    // `!` forms are picked up, in source order.
    #[test]
    fn mixed_content() {
        let source = "\
# Project Context

Branch: !`git branch --show-current`

Regular code: `not a command`

```!
echo $VIRTUAL_ENV
```

```bash
echo not interpolated
```

End.";
        let cmds = parse_commands(source);
        assert_eq!(cmds.len(), 2);
        assert_eq!(cmds[0].body, "git branch --show-current");
        assert_eq!(cmds[1].body, "echo $VIRTUAL_ENV");
    }

    // The whole `!`...`` span is replaced by the command's trimmed stdout.
    #[tokio::test]
    async fn interpolate_replaces_inline_command() {
        let source = "Branch: !`echo main`";
        let result = interpolate(source, "sh").await;
        assert_eq!(result, "Branch: main");
    }

    // The whole fenced block is replaced; the blank lines around it are kept.
    #[tokio::test]
    async fn interpolate_replaces_block_command() {
        let source = "Before\n\n```!\necho hello world\n```\n\nAfter";
        let result = interpolate(source, "sh").await;
        assert_eq!(result, "Before\n\nhello world\n\nAfter");
    }

    // With no commands present the source comes back unchanged.
    #[tokio::test]
    async fn interpolate_preserves_non_command_content() {
        let source = "Just plain markdown with `code` and no bangs.";
        let result = interpolate(source, "sh").await;
        assert_eq!(result, source);
    }

    // A command that exits non-zero is replaced by an `[error: …]` marker.
    #[tokio::test]
    async fn interpolate_failed_command_shows_error() {
        let source = "Result: !`exit 1`";
        let result = interpolate(source, "sh").await;
        assert!(result.starts_with("Result: [error:"));
    }

    // Several commands on one line are each replaced at their own position.
    #[tokio::test]
    async fn interpolate_multiple_commands() {
        let source = "A: !`echo one` B: !`echo two`";
        let result = interpolate(source, "sh").await;
        assert_eq!(result, "A: one B: two");
    }
}