aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-scripts/src/execution.rs
blob: bb7f42279fdfb6d0035f4ac74095a82463bdaef2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
use crate::store::script::Script;
use eyre::Result;
use std::collections::{HashMap, HashSet};
use std::process::Stdio;
use tempfile::NamedTempFile;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::sync::mpsc;
use tokio::task;
use tracing::debug;

// Helper function to build a complete script with shebang
pub fn build_executable_script(script: String, shebang: String) -> String {
    if shebang.is_empty() {
        // Default to bash if no shebang is provided
        format!("#!/usr/bin/env bash\n{script}")
    } else if script.starts_with("#!") {
        format!("{shebang}\n{script}")
    } else {
        format!("#!{shebang}\n{script}")
    }
}

/// Represents the communication channels for an interactive script
pub struct ScriptSession {
    /// Channel to send input to the script
    pub stdin_tx: mpsc::Sender<String>,
    /// Exit code of the process once it completes
    pub exit_code_rx: mpsc::Receiver<i32>,
}

impl ScriptSession {
    /// Send input to the running script
    pub async fn send_input(&self, input: String) -> Result<(), mpsc::error::SendError<String>> {
        self.stdin_tx.send(input).await
    }

    /// Wait for the script to complete and get the exit code
    pub async fn wait_for_exit(&mut self) -> Option<i32> {
        self.exit_code_rx.recv().await
    }
}

fn setup_template(script: &Script) -> Result<minijinja::Environment> {
    let mut env = minijinja::Environment::new();
    env.set_trim_blocks(true);
    env.add_template("script", script.script.as_str())?;

    Ok(env)
}

/// Template a script with the given context
pub fn template_script(
    script: &Script,
    context: &HashMap<String, serde_json::Value>,
) -> Result<String> {
    let env = setup_template(script)?;
    let template = env.get_template("script")?;
    let rendered = template.render(context)?;

    Ok(rendered)
}

/// Get the variables that need to be templated in a script
pub fn template_variables(script: &Script) -> Result<HashSet<String>> {
    let env = setup_template(script)?;
    let template = env.get_template("script")?;

    Ok(template.undeclared_variables(true))
}

/// Execute a script interactively, allowing for ongoing stdin/stdout interaction
pub async fn execute_script_interactive(
    script: String,
    shebang: String,
) -> Result<ScriptSession, Box<dyn std::error::Error + Send + Sync>> {
    // Create a temporary file for the script
    let temp_file = NamedTempFile::new()?;
    let temp_path = temp_file.path().to_path_buf();

    debug!("creating temp file at {}", temp_path.display());

    // Extract interpreter from shebang for fallback execution
    let interpreter = if !shebang.is_empty() {
        shebang.trim_start_matches("#!").trim().to_string()
    } else {
        "/usr/bin/env bash".to_string()
    };

    // Write script content to the temp file, including the shebang
    let full_script_content = build_executable_script(script.clone(), shebang.clone());

    debug!("writing script content to temp file");
    tokio::fs::write(&temp_path, &full_script_content).await?;

    // Make it executable on Unix systems
    #[cfg(unix)]
    {
        debug!("making script executable");
        use std::os::unix::fs::PermissionsExt;
        let mut perms = std::fs::metadata(&temp_path)?.permissions();
        perms.set_mode(0o755);
        std::fs::set_permissions(&temp_path, perms)?;
    }

    // Store the temp_file to prevent it from being dropped
    // This ensures it won't be deleted while the script is running
    let _keep_temp_file = temp_file;

    debug!("attempting direct script execution");
    let mut child_result = tokio::process::Command::new(temp_path.to_str().unwrap())
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn();

    // If direct execution fails, try using the interpreter
    if let Err(e) = &child_result {
        debug!("direct execution failed: {}, trying with interpreter", e);

        // When falling back to interpreter, remove the shebang from the file
        // Some interpreters don't handle scripts with shebangs well
        debug!("writing script content without shebang for interpreter execution");
        tokio::fs::write(&temp_path, &script).await?;

        // Parse the interpreter command
        let parts: Vec<&str> = interpreter.split_whitespace().collect();
        if !parts.is_empty() {
            let mut cmd = tokio::process::Command::new(parts[0]);

            // Add any interpreter args
            for i in parts.iter().skip(1) {
                cmd.arg(i);
            }

            // Add the script path
            cmd.arg(temp_path.to_str().unwrap());

            // Try with the interpreter
            child_result = cmd
                .stdin(Stdio::piped())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .spawn();
        }
    }

    // If it still fails, return the error
    let mut child = match child_result {
        Ok(child) => child,
        Err(e) => {
            return Err(format!("Failed to execute script: {e}").into());
        }
    };

    // Get handles to stdin, stdout, stderr
    let mut stdin = child
        .stdin
        .take()
        .ok_or_else(|| "Failed to open child process stdin".to_string())?;
    let stdout = child
        .stdout
        .take()
        .ok_or_else(|| "Failed to open child process stdout".to_string())?;
    let stderr = child
        .stderr
        .take()
        .ok_or_else(|| "Failed to open child process stderr".to_string())?;

    // Create channels for the interactive session
    let (stdin_tx, mut stdin_rx) = mpsc::channel::<String>(32);
    let (exit_code_tx, exit_code_rx) = mpsc::channel::<i32>(1);

    // handle user stdin
    debug!("spawning stdin handler");
    tokio::spawn(async move {
        while let Some(input) = stdin_rx.recv().await {
            if let Err(e) = stdin.write_all(input.as_bytes()).await {
                eprintln!("Error writing to stdin: {e}");
                break;
            }
            if let Err(e) = stdin.flush().await {
                eprintln!("Error flushing stdin: {e}");
                break;
            }
        }
        // when the channel closes (sender dropped), we let stdin close naturally
    });

    // handle stdout
    debug!("spawning stdout handler");
    let stdout_handle = task::spawn(async move {
        let mut stdout_reader = BufReader::new(stdout);
        let mut buffer = [0u8; 1024];
        let mut stdout_writer = tokio::io::stdout();

        loop {
            match stdout_reader.read(&mut buffer).await {
                Ok(0) => break, // End of stdout
                Ok(n) => {
                    if let Err(e) = stdout_writer.write_all(&buffer[0..n]).await {
                        eprintln!("Error writing to stdout: {e}");
                        break;
                    }
                    if let Err(e) = stdout_writer.flush().await {
                        eprintln!("Error flushing stdout: {e}");
                        break;
                    }
                }
                Err(e) => {
                    eprintln!("Error reading from process stdout: {e}");
                    break;
                }
            }
        }
    });

    // Process stderr in a separate task
    debug!("spawning stderr handler");
    let stderr_handle = task::spawn(async move {
        let mut stderr_reader = BufReader::new(stderr);
        let mut buffer = [0u8; 1024];
        let mut stderr_writer = tokio::io::stderr();

        loop {
            match stderr_reader.read(&mut buffer).await {
                Ok(0) => break, // End of stderr
                Ok(n) => {
                    if let Err(e) = stderr_writer.write_all(&buffer[0..n]).await {
                        eprintln!("Error writing to stderr: {e}");
                        break;
                    }
                    if let Err(e) = stderr_writer.flush().await {
                        eprintln!("Error flushing stderr: {e}");
                        break;
                    }
                }
                Err(e) => {
                    eprintln!("Error reading from process stderr: {e}");
                    break;
                }
            }
        }
    });

    // Spawn a task to wait for the child process to complete
    debug!("spawning exit code handler");
    let _keep_temp_file_clone = _keep_temp_file;
    tokio::spawn(async move {
        // Keep the temp file alive until the process completes
        let _temp_file_ref = _keep_temp_file_clone;

        // Wait for the child process to complete
        let status = match child.wait().await {
            Ok(status) => {
                debug!("Process exited with status: {:?}", status);
                status
            }
            Err(e) => {
                eprintln!("Error waiting for child process: {e}");
                // Send a default error code
                let _ = exit_code_tx.send(-1).await;
                return;
            }
        };

        // Wait for stdout/stderr tasks to complete
        if let Err(e) = stdout_handle.await {
            eprintln!("Error joining stdout task: {e}");
        }

        if let Err(e) = stderr_handle.await {
            eprintln!("Error joining stderr task: {e}");
        }

        // Send the exit code
        let exit_code = status.code().unwrap_or(-1);
        debug!("Sending exit code: {}", exit_code);
        let _ = exit_code_tx.send(exit_code).await;
    });

    // Return the communication channels as a ScriptSession
    Ok(ScriptSession {
        stdin_tx,
        exit_code_rx,
    })
}