aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-scripts/src
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-scripts/src')
-rw-r--r--crates/atuin-scripts/src/database.rs358
-rw-r--r--crates/atuin-scripts/src/execution.rs287
-rw-r--r--crates/atuin-scripts/src/lib.rs4
-rw-r--r--crates/atuin-scripts/src/settings.rs1
-rw-r--r--crates/atuin-scripts/src/store.rs109
-rw-r--r--crates/atuin-scripts/src/store/record.rs215
-rw-r--r--crates/atuin-scripts/src/store/script.rs151
7 files changed, 1125 insertions, 0 deletions
diff --git a/crates/atuin-scripts/src/database.rs b/crates/atuin-scripts/src/database.rs
new file mode 100644
index 00000000..71da69ff
--- /dev/null
+++ b/crates/atuin-scripts/src/database.rs
@@ -0,0 +1,358 @@
+use std::{path::Path, str::FromStr, time::Duration};
+
+use atuin_common::utils;
+use sqlx::{
+ Result, Row,
+ sqlite::{
+ SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
+ SqliteSynchronous,
+ },
+};
+use tokio::fs;
+use tracing::debug;
+use uuid::Uuid;
+
+use crate::store::script::Script;
+
+#[derive(Debug, Clone)]
+pub struct Database {
+ pub pool: SqlitePool,
+}
+
+impl Database {
+ pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
+ let path = path.as_ref();
+ debug!("opening script sqlite database at {:?}", path);
+
+ if utils::broken_symlink(path) {
+ eprintln!(
+ "Atuin: Script sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
+ );
+ std::process::exit(1);
+ }
+
+ if !path.exists() {
+ if let Some(dir) = path.parent() {
+ fs::create_dir_all(dir).await?;
+ }
+ }
+
+ let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
+ .journal_mode(SqliteJournalMode::Wal)
+ .optimize_on_close(true, None)
+ .synchronous(SqliteSynchronous::Normal)
+ .with_regexp()
+ .foreign_keys(true)
+ .create_if_missing(true);
+
+ let pool = SqlitePoolOptions::new()
+ .acquire_timeout(Duration::from_secs_f64(timeout))
+ .connect_with(opts)
+ .await?;
+
+ Self::setup_db(&pool).await?;
+ Ok(Self { pool })
+ }
+
+ pub async fn sqlite_version(&self) -> Result<String> {
+ sqlx::query_scalar("SELECT sqlite_version()")
+ .fetch_one(&self.pool)
+ .await
+ }
+
+ async fn setup_db(pool: &SqlitePool) -> Result<()> {
+ debug!("running sqlite database setup");
+
+ sqlx::migrate!("./migrations").run(pool).await?;
+
+ Ok(())
+ }
+
+ async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, s: &Script) -> Result<()> {
+ sqlx::query(
+ "insert or ignore into scripts(id, name, description, shebang, script)
+ values(?1, ?2, ?3, ?4, ?5)",
+ )
+ .bind(s.id.to_string())
+ .bind(s.name.as_str())
+ .bind(s.description.as_str())
+ .bind(s.shebang.as_str())
+ .bind(s.script.as_str())
+ .execute(&mut **tx)
+ .await?;
+
+ for tag in s.tags.iter() {
+ sqlx::query(
+ "insert or ignore into script_tags(script_id, tag)
+ values(?1, ?2)",
+ )
+ .bind(s.id.to_string())
+ .bind(tag)
+ .execute(&mut **tx)
+ .await?;
+ }
+
+ Ok(())
+ }
+
+ pub async fn save(&self, s: &Script) -> Result<()> {
+ debug!("saving script to sqlite");
+ let mut tx = self.pool.begin().await?;
+ Self::save_raw(&mut tx, s).await?;
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ pub async fn save_bulk(&self, s: &[Script]) -> Result<()> {
+ debug!("saving scripts to sqlite");
+
+ let mut tx = self.pool.begin().await?;
+
+ for i in s {
+ Self::save_raw(&mut tx, i).await?;
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ fn query_script(row: SqliteRow) -> Script {
+ let id = row.get("id");
+ let name = row.get("name");
+ let description = row.get("description");
+ let shebang = row.get("shebang");
+ let script = row.get("script");
+
+ let id = Uuid::parse_str(id).unwrap();
+
+ Script {
+ id,
+ name,
+ description,
+ shebang,
+ script,
+ tags: vec![],
+ }
+ }
+
+ fn query_script_tags(row: SqliteRow) -> String {
+ row.get("tag")
+ }
+
+ #[allow(dead_code)]
+ async fn load(&self, id: &str) -> Result<Option<Script>> {
+ debug!("loading script item {}", id);
+
+ let res = sqlx::query("select * from scripts where id = ?1")
+ .bind(id)
+ .map(Self::query_script)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ // intentionally not joining, don't want to duplicate the script data in memory a whole bunch.
+ if let Some(mut script) = res {
+ let tags = sqlx::query("select tag from script_tags where script_id = ?1")
+ .bind(id)
+ .map(Self::query_script_tags)
+ .fetch_all(&self.pool)
+ .await?;
+
+ script.tags = tags;
+ Ok(Some(script))
+ } else {
+ Ok(None)
+ }
+ }
+
+ pub async fn list(&self) -> Result<Vec<Script>> {
+ debug!("listing scripts");
+
+ let mut res = sqlx::query("select * from scripts")
+ .map(Self::query_script)
+ .fetch_all(&self.pool)
+ .await?;
+
+ // Fetch all the tags for each script
+ for script in res.iter_mut() {
+ let tags = sqlx::query("select tag from script_tags where script_id = ?1")
+ .bind(script.id.to_string())
+ .map(Self::query_script_tags)
+ .fetch_all(&self.pool)
+ .await?;
+
+ script.tags = tags;
+ }
+
+ Ok(res)
+ }
+
+ pub async fn delete(&self, id: &str) -> Result<()> {
+ debug!("deleting script {}", id);
+
+ sqlx::query("delete from scripts where id = ?1")
+ .bind(id)
+ .execute(&self.pool)
+ .await?;
+
+ // delete all the tags for the script
+ sqlx::query("delete from script_tags where script_id = ?1")
+ .bind(id)
+ .execute(&self.pool)
+ .await?;
+
+ Ok(())
+ }
+
+ pub async fn update(&self, s: &Script) -> Result<()> {
+ debug!("updating script {:?}", s);
+
+ let mut tx = self.pool.begin().await?;
+
+ // Update the script's base fields
+ sqlx::query("update scripts set name = ?1, description = ?2, shebang = ?3, script = ?4 where id = ?5")
+ .bind(s.name.as_str())
+ .bind(s.description.as_str())
+ .bind(s.shebang.as_str())
+ .bind(s.script.as_str())
+ .bind(s.id.to_string())
+ .execute(&mut *tx)
+ .await?;
+
+ // Delete all existing tags for this script
+ sqlx::query("delete from script_tags where script_id = ?1")
+ .bind(s.id.to_string())
+ .execute(&mut *tx)
+ .await?;
+
+ // Insert new tags
+ for tag in s.tags.iter() {
+ sqlx::query(
+ "insert or ignore into script_tags(script_id, tag)
+ values(?1, ?2)",
+ )
+ .bind(s.id.to_string())
+ .bind(tag)
+ .execute(&mut *tx)
+ .await?;
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ pub async fn get_by_name(&self, name: &str) -> Result<Option<Script>> {
+ let res = sqlx::query("select * from scripts where name = ?1")
+ .bind(name)
+ .map(Self::query_script)
+ .fetch_optional(&self.pool)
+ .await?;
+
+ let script = if let Some(mut script) = res {
+ let tags = sqlx::query("select tag from script_tags where script_id = ?1")
+ .bind(script.id.to_string())
+ .map(Self::query_script_tags)
+ .fetch_all(&self.pool)
+ .await?;
+
+ script.tags = tags;
+ Some(script)
+ } else {
+ None
+ };
+
+ Ok(script)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[tokio::test]
+ async fn test_list() {
+ let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
+ let scripts = db.list().await.unwrap();
+ assert_eq!(scripts.len(), 0);
+
+ let script = Script::builder()
+ .name("test".to_string())
+ .description("test".to_string())
+ .shebang("test".to_string())
+ .script("test".to_string())
+ .build();
+
+ db.save(&script).await.unwrap();
+
+ let scripts = db.list().await.unwrap();
+ assert_eq!(scripts.len(), 1);
+ assert_eq!(scripts[0].name, "test");
+ }
+
+ #[tokio::test]
+ async fn test_save_load() {
+ let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
+
+ let script = Script::builder()
+ .name("test name".to_string())
+ .description("test description".to_string())
+ .shebang("test shebang".to_string())
+ .script("test script".to_string())
+ .build();
+
+ db.save(&script).await.unwrap();
+
+ let loaded = db.load(&script.id.to_string()).await.unwrap().unwrap();
+
+ assert_eq!(loaded, script);
+ }
+
+ #[tokio::test]
+ async fn test_save_bulk() {
+ let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
+
+ let scripts = vec![
+ Script::builder()
+ .name("test name".to_string())
+ .description("test description".to_string())
+ .shebang("test shebang".to_string())
+ .script("test script".to_string())
+ .build(),
+ Script::builder()
+ .name("test name 2".to_string())
+ .description("test description 2".to_string())
+ .shebang("test shebang 2".to_string())
+ .script("test script 2".to_string())
+ .build(),
+ ];
+
+ db.save_bulk(&scripts).await.unwrap();
+
+ let loaded = db.list().await.unwrap();
+ assert_eq!(loaded.len(), 2);
+ assert_eq!(loaded[0].name, "test name");
+ assert_eq!(loaded[1].name, "test name 2");
+ }
+
+ #[tokio::test]
+ async fn test_delete() {
+ let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
+
+ let script = Script::builder()
+ .name("test name".to_string())
+ .description("test description".to_string())
+ .shebang("test shebang".to_string())
+ .script("test script".to_string())
+ .build();
+
+ db.save(&script).await.unwrap();
+
+ assert_eq!(db.list().await.unwrap().len(), 1);
+ db.delete(&script.id.to_string()).await.unwrap();
+
+ let loaded = db.list().await.unwrap();
+ assert_eq!(loaded.len(), 0);
+ }
+}
diff --git a/crates/atuin-scripts/src/execution.rs b/crates/atuin-scripts/src/execution.rs
new file mode 100644
index 00000000..90f7c4eb
--- /dev/null
+++ b/crates/atuin-scripts/src/execution.rs
@@ -0,0 +1,287 @@
+use crate::store::script::Script;
+use eyre::Result;
+use std::collections::{HashMap, HashSet};
+use std::fs;
+use std::process::Stdio;
+use tempfile::NamedTempFile;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
+use tokio::sync::mpsc;
+use tokio::task;
+use tracing::debug;
+
+// Helper function to build a complete script with shebang
+pub fn build_executable_script(script: String, shebang: String) -> String {
+ if shebang.is_empty() {
+ // Default to bash if no shebang is provided
+ format!("#!/usr/bin/env bash\n{}", script)
+ } else if script.starts_with("#!") {
+ format!("{}\n{}", shebang, script)
+ } else {
+ format!("#!{}\n{}", shebang, script)
+ }
+}
+
+/// Represents the communication channels for an interactive script
+pub struct ScriptSession {
+ /// Channel to send input to the script
+ pub stdin_tx: mpsc::Sender<String>,
+ /// Exit code of the process once it completes
+ pub exit_code_rx: mpsc::Receiver<i32>,
+}
+
+impl ScriptSession {
+ /// Send input to the running script
+ pub async fn send_input(&self, input: String) -> Result<(), mpsc::error::SendError<String>> {
+ self.stdin_tx.send(input).await
+ }
+
+ /// Wait for the script to complete and get the exit code
+ pub async fn wait_for_exit(&mut self) -> Option<i32> {
+ self.exit_code_rx.recv().await
+ }
+}
+
+fn setup_template(script: &Script) -> Result<minijinja::Environment> {
+ let mut env = minijinja::Environment::new();
+ env.set_trim_blocks(true);
+ env.add_template("script", script.script.as_str())?;
+
+ Ok(env)
+}
+
+/// Template a script with the given context
+pub fn template_script(
+ script: &Script,
+ context: &HashMap<String, serde_json::Value>,
+) -> Result<String> {
+ let env = setup_template(script)?;
+ let template = env.get_template("script")?;
+ let rendered = template.render(context)?;
+
+ Ok(rendered)
+}
+
+/// Get the variables that need to be templated in a script
+pub fn template_variables(script: &Script) -> Result<HashSet<String>> {
+ let env = setup_template(script)?;
+ let template = env.get_template("script")?;
+
+ Ok(template.undeclared_variables(true))
+}
+
+/// Execute a script interactively, allowing for ongoing stdin/stdout interaction
+pub async fn execute_script_interactive(
+ script: String,
+ shebang: String,
+) -> Result<ScriptSession, Box<dyn std::error::Error + Send + Sync>> {
+ // Create a temporary file for the script
+ let temp_file = NamedTempFile::new()?;
+ let temp_path = temp_file.path().to_path_buf();
+
+ debug!("creating temp file at {}", temp_path.display());
+
+ // Extract interpreter from shebang for fallback execution
+ let interpreter = if !shebang.is_empty() {
+ shebang.trim_start_matches("#!").trim().to_string()
+ } else {
+ "/usr/bin/env bash".to_string()
+ };
+
+ // Write script content to the temp file, including the shebang
+ let full_script_content = build_executable_script(script.clone(), shebang.clone());
+
+ debug!("writing script content to temp file");
+ tokio::fs::write(&temp_path, &full_script_content).await?;
+
+ // Make it executable on Unix systems
+ #[cfg(unix)]
+ {
+ debug!("making script executable");
+ use std::os::unix::fs::PermissionsExt;
+ let mut perms = fs::metadata(&temp_path)?.permissions();
+ perms.set_mode(0o755);
+ fs::set_permissions(&temp_path, perms)?;
+ }
+
+ // Store the temp_file to prevent it from being dropped
+ // This ensures it won't be deleted while the script is running
+ let _keep_temp_file = temp_file;
+
+ debug!("attempting direct script execution");
+ let mut child_result = tokio::process::Command::new(temp_path.to_str().unwrap())
+ .stdin(Stdio::piped())
+ .stdout(Stdio::piped())
+ .stderr(Stdio::piped())
+ .spawn();
+
+ // If direct execution fails, try using the interpreter
+ if let Err(e) = &child_result {
+ debug!("direct execution failed: {}, trying with interpreter", e);
+
+ // When falling back to interpreter, remove the shebang from the file
+ // Some interpreters don't handle scripts with shebangs well
+ debug!("writing script content without shebang for interpreter execution");
+ tokio::fs::write(&temp_path, &script).await?;
+
+ // Parse the interpreter command
+ let parts: Vec<&str> = interpreter.split_whitespace().collect();
+ if !parts.is_empty() {
+ let mut cmd = tokio::process::Command::new(parts[0]);
+
+ // Add any interpreter args
+ for i in parts.iter().skip(1) {
+ cmd.arg(i);
+ }
+
+ // Add the script path
+ cmd.arg(temp_path.to_str().unwrap());
+
+ // Try with the interpreter
+ child_result = cmd
+ .stdin(Stdio::piped())
+ .stdout(Stdio::piped())
+ .stderr(Stdio::piped())
+ .spawn();
+ }
+ }
+
+ // If it still fails, return the error
+ let mut child = match child_result {
+ Ok(child) => child,
+ Err(e) => {
+ return Err(format!("Failed to execute script: {}", e).into());
+ }
+ };
+
+ // Get handles to stdin, stdout, stderr
+ let mut stdin = child
+ .stdin
+ .take()
+ .ok_or_else(|| "Failed to open child process stdin".to_string())?;
+ let stdout = child
+ .stdout
+ .take()
+ .ok_or_else(|| "Failed to open child process stdout".to_string())?;
+ let stderr = child
+ .stderr
+ .take()
+ .ok_or_else(|| "Failed to open child process stderr".to_string())?;
+
+ // Create channels for the interactive session
+ let (stdin_tx, mut stdin_rx) = mpsc::channel::<String>(32);
+ let (exit_code_tx, exit_code_rx) = mpsc::channel::<i32>(1);
+
+ // handle user stdin
+ debug!("spawning stdin handler");
+ tokio::spawn(async move {
+ while let Some(input) = stdin_rx.recv().await {
+ if let Err(e) = stdin.write_all(input.as_bytes()).await {
+ eprintln!("Error writing to stdin: {}", e);
+ break;
+ }
+ if let Err(e) = stdin.flush().await {
+ eprintln!("Error flushing stdin: {}", e);
+ break;
+ }
+ }
+ // when the channel closes (sender dropped), we let stdin close naturally
+ });
+
+ // handle stdout
+ debug!("spawning stdout handler");
+ let stdout_handle = task::spawn(async move {
+ let mut stdout_reader = BufReader::new(stdout);
+ let mut buffer = [0u8; 1024];
+ let mut stdout_writer = tokio::io::stdout();
+
+ loop {
+ match stdout_reader.read(&mut buffer).await {
+ Ok(0) => break, // End of stdout
+ Ok(n) => {
+ if let Err(e) = stdout_writer.write_all(&buffer[0..n]).await {
+ eprintln!("Error writing to stdout: {}", e);
+ break;
+ }
+ if let Err(e) = stdout_writer.flush().await {
+ eprintln!("Error flushing stdout: {}", e);
+ break;
+ }
+ }
+ Err(e) => {
+ eprintln!("Error reading from process stdout: {}", e);
+ break;
+ }
+ }
+ }
+ });
+
+ // Process stderr in a separate task
+ debug!("spawning stderr handler");
+ let stderr_handle = task::spawn(async move {
+ let mut stderr_reader = BufReader::new(stderr);
+ let mut buffer = [0u8; 1024];
+ let mut stderr_writer = tokio::io::stderr();
+
+ loop {
+ match stderr_reader.read(&mut buffer).await {
+ Ok(0) => break, // End of stderr
+ Ok(n) => {
+ if let Err(e) = stderr_writer.write_all(&buffer[0..n]).await {
+ eprintln!("Error writing to stderr: {}", e);
+ break;
+ }
+ if let Err(e) = stderr_writer.flush().await {
+ eprintln!("Error flushing stderr: {}", e);
+ break;
+ }
+ }
+ Err(e) => {
+ eprintln!("Error reading from process stderr: {}", e);
+ break;
+ }
+ }
+ }
+ });
+
+ // Spawn a task to wait for the child process to complete
+ debug!("spawning exit code handler");
+ let _keep_temp_file_clone = _keep_temp_file;
+ tokio::spawn(async move {
+ // Keep the temp file alive until the process completes
+ let _temp_file_ref = _keep_temp_file_clone;
+
+ // Wait for the child process to complete
+ let status = match child.wait().await {
+ Ok(status) => {
+ debug!("Process exited with status: {:?}", status);
+ status
+ }
+ Err(e) => {
+ eprintln!("Error waiting for child process: {}", e);
+ // Send a default error code
+ let _ = exit_code_tx.send(-1).await;
+ return;
+ }
+ };
+
+ // Wait for stdout/stderr tasks to complete
+ if let Err(e) = stdout_handle.await {
+ eprintln!("Error joining stdout task: {}", e);
+ }
+
+ if let Err(e) = stderr_handle.await {
+ eprintln!("Error joining stderr task: {}", e);
+ }
+
+ // Send the exit code
+ let exit_code = status.code().unwrap_or(-1);
+ debug!("Sending exit code: {}", exit_code);
+ let _ = exit_code_tx.send(exit_code).await;
+ });
+
+ // Return the communication channels as a ScriptSession
+ Ok(ScriptSession {
+ stdin_tx,
+ exit_code_rx,
+ })
+}
diff --git a/crates/atuin-scripts/src/lib.rs b/crates/atuin-scripts/src/lib.rs
new file mode 100644
index 00000000..c79c7089
--- /dev/null
+++ b/crates/atuin-scripts/src/lib.rs
@@ -0,0 +1,4 @@
+pub mod database;
+pub mod execution;
+pub mod settings;
+pub mod store;
diff --git a/crates/atuin-scripts/src/settings.rs b/crates/atuin-scripts/src/settings.rs
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/crates/atuin-scripts/src/settings.rs
@@ -0,0 +1 @@
+
diff --git a/crates/atuin-scripts/src/store.rs b/crates/atuin-scripts/src/store.rs
new file mode 100644
index 00000000..ba7a1ca1
--- /dev/null
+++ b/crates/atuin-scripts/src/store.rs
@@ -0,0 +1,109 @@
+use eyre::{Result, bail};
+
+use atuin_client::record::sqlite_store::SqliteStore;
+use atuin_client::record::{encryption::PASETO_V4, store::Store};
+use atuin_common::record::{Host, HostId, Record, RecordId, RecordIdx};
+use record::ScriptRecord;
+use script::{SCRIPT_TAG, SCRIPT_VERSION, Script};
+
+use crate::database::Database;
+
+pub mod record;
+pub mod script;
+
+#[derive(Debug, Clone)]
+pub struct ScriptStore {
+ pub store: SqliteStore,
+ pub host_id: HostId,
+ pub encryption_key: [u8; 32],
+}
+
+impl ScriptStore {
+ pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self {
+ ScriptStore {
+ store,
+ host_id,
+ encryption_key,
+ }
+ }
+
+ async fn push_record(&self, record: ScriptRecord) -> Result<(RecordId, RecordIdx)> {
+ let bytes = record.serialize()?;
+ let idx = self
+ .store
+ .last(self.host_id, SCRIPT_TAG)
+ .await?
+ .map_or(0, |p| p.idx + 1);
+
+ let record = Record::builder()
+ .host(Host::new(self.host_id))
+ .version(SCRIPT_VERSION.to_string())
+ .tag(SCRIPT_TAG.to_string())
+ .idx(idx)
+ .data(bytes)
+ .build();
+
+ let id = record.id;
+
+ self.store
+ .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
+ .await?;
+
+ Ok((id, idx))
+ }
+
+ pub async fn create(&self, script: Script) -> Result<()> {
+ let record = ScriptRecord::Create(script);
+ self.push_record(record).await?;
+ Ok(())
+ }
+
+ pub async fn update(&self, script: Script) -> Result<()> {
+ let record = ScriptRecord::Update(script);
+ self.push_record(record).await?;
+ Ok(())
+ }
+
+ pub async fn delete(&self, script_id: uuid::Uuid) -> Result<()> {
+ let record = ScriptRecord::Delete(script_id);
+ self.push_record(record).await?;
+ Ok(())
+ }
+
+ pub async fn scripts(&self) -> Result<Vec<ScriptRecord>> {
+ let records = self.store.all_tagged(SCRIPT_TAG).await?;
+ let mut ret = Vec::with_capacity(records.len());
+
+ for record in records.into_iter() {
+ let script = match record.version.as_str() {
+ SCRIPT_VERSION => {
+ let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?;
+
+ ScriptRecord::deserialize(&decrypted.data, SCRIPT_VERSION)
+ }
+ version => bail!("unknown history version {version:?}"),
+ }?;
+
+ ret.push(script);
+ }
+
+ Ok(ret)
+ }
+
+ pub async fn build(&self, database: Database) -> Result<()> {
+ // Get all the scripts from the database - they are already sorted by timestamp
+ let scripts = self.scripts().await?;
+
+ for script in scripts {
+ match script {
+ ScriptRecord::Create(script) => {
+ database.save(&script).await?;
+ }
+ ScriptRecord::Update(script) => database.update(&script).await?,
+ ScriptRecord::Delete(id) => database.delete(&id.to_string()).await?,
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/crates/atuin-scripts/src/store/record.rs b/crates/atuin-scripts/src/store/record.rs
new file mode 100644
index 00000000..4c925be3
--- /dev/null
+++ b/crates/atuin-scripts/src/store/record.rs
@@ -0,0 +1,215 @@
+use atuin_common::record::DecryptedData;
+use eyre::{Result, eyre};
+use uuid::Uuid;
+
+use crate::store::script::SCRIPT_VERSION;
+
+use super::script::Script;
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum ScriptRecord {
+ Create(Script),
+ Update(Script),
+ Delete(Uuid),
+}
+
+impl ScriptRecord {
+ pub fn serialize(&self) -> Result<DecryptedData> {
+ use rmp::encode;
+
+ let mut output = vec![];
+
+ match self {
+ ScriptRecord::Create(script) => {
+ // 0 -> a script create
+ encode::write_u8(&mut output, 0)?;
+
+ let bytes = script.serialize()?;
+
+ encode::write_bin(&mut output, &bytes.0)?;
+ }
+
+ ScriptRecord::Delete(id) => {
+ // 1 -> a script delete
+ encode::write_u8(&mut output, 1)?;
+ encode::write_str(&mut output, id.to_string().as_str())?;
+ }
+
+ ScriptRecord::Update(script) => {
+ // 2 -> a script update
+ encode::write_u8(&mut output, 2)?;
+ let bytes = script.serialize()?;
+ encode::write_bin(&mut output, &bytes.0)?;
+ }
+ };
+
+ Ok(DecryptedData(output))
+ }
+
+ pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
+ use rmp::decode;
+
+ fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
+ eyre!("{err:?}")
+ }
+
+ match version {
+ SCRIPT_VERSION => {
+ let mut bytes = decode::Bytes::new(&data.0);
+
+ let record_type = decode::read_u8(&mut bytes).map_err(error_report)?;
+
+ match record_type {
+ // create
+ 0 => {
+ // written by encode::write_bin above
+ let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?;
+ let script = Script::deserialize(bytes.remaining_slice())?;
+ Ok(ScriptRecord::Create(script))
+ }
+
+ // delete
+ 1 => {
+ let bytes = bytes.remaining_slice();
+ let (id, _) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+ Ok(ScriptRecord::Delete(Uuid::parse_str(id)?))
+ }
+
+ // update
+ 2 => {
+ // written by encode::write_bin above
+ let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?;
+ let script = Script::deserialize(bytes.remaining_slice())?;
+ Ok(ScriptRecord::Update(script))
+ }
+
+ _ => Err(eyre!("unknown script record type {record_type}")),
+ }
+ }
+ _ => Err(eyre!("unknown version {version:?}")),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_serialize_create() {
+ let script = Script::builder()
+ .id(uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap())
+ .name("test".to_string())
+ .description("test".to_string())
+ .shebang("test".to_string())
+ .tags(vec!["test".to_string()])
+ .script("test".to_string())
+ .build();
+
+ let record = ScriptRecord::Create(script);
+
+ let serialized = record.serialize().unwrap();
+
+ assert_eq!(
+ serialized.0,
+ vec![
+ 204, 0, 196, 65, 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102,
+ 45, 55, 57, 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99,
+ 98, 99, 54, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115,
+ 116, 145, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116
+ ]
+ );
+ }
+
+ #[test]
+ fn test_serialize_delete() {
+ let record = ScriptRecord::Delete(
+ uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(),
+ );
+
+ let serialized = record.serialize().unwrap();
+
+ assert_eq!(
+ serialized.0,
+ vec![
+ 204, 1, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, 45, 55, 57,
+ 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, 98, 99, 54
+ ]
+ );
+ }
+
+ #[test]
+ fn test_serialize_update() {
+ let script = Script::builder()
+ .id(uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap())
+ .name(String::from("test"))
+ .description(String::from("test"))
+ .shebang(String::from("test"))
+ .tags(vec![String::from("test"), String::from("test2")])
+ .script(String::from("test"))
+ .build();
+
+ let record = ScriptRecord::Update(script);
+
+ let serialized = record.serialize().unwrap();
+
+ assert_eq!(
+ serialized.0,
+ vec![
+ 204, 2, 196, 71, 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102,
+ 45, 55, 57, 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99,
+ 98, 99, 54, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115,
+ 116, 146, 164, 116, 101, 115, 116, 165, 116, 101, 115, 116, 50, 164, 116, 101, 115,
+ 116
+ ],
+ );
+ }
+
+ #[test]
+ fn test_serialize_deserialize_create() {
+ let script = Script::builder()
+ .name("test".to_string())
+ .description("test".to_string())
+ .shebang("test".to_string())
+ .tags(vec!["test".to_string()])
+ .script("test".to_string())
+ .build();
+
+ let record = ScriptRecord::Create(script);
+
+ let serialized = record.serialize().unwrap();
+ let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap();
+
+ assert_eq!(record, deserialized);
+ }
+
+ #[test]
+ fn test_serialize_deserialize_delete() {
+ let record = ScriptRecord::Delete(
+ uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(),
+ );
+
+ let serialized = record.serialize().unwrap();
+ let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap();
+
+ assert_eq!(record, deserialized);
+ }
+
+ #[test]
+ fn test_serialize_deserialize_update() {
+ let script = Script::builder()
+ .name("test".to_string())
+ .description("test".to_string())
+ .shebang("test".to_string())
+ .tags(vec!["test".to_string()])
+ .script("test".to_string())
+ .build();
+
+ let record = ScriptRecord::Update(script);
+
+ let serialized = record.serialize().unwrap();
+ let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap();
+
+ assert_eq!(record, deserialized);
+ }
+}
diff --git a/crates/atuin-scripts/src/store/script.rs b/crates/atuin-scripts/src/store/script.rs
new file mode 100644
index 00000000..af180320
--- /dev/null
+++ b/crates/atuin-scripts/src/store/script.rs
@@ -0,0 +1,151 @@
+use atuin_common::record::DecryptedData;
+use eyre::{Result, bail, ensure};
+use uuid::Uuid;
+
+use rmp::{
+ decode::{self, Bytes},
+ encode,
+};
+use typed_builder::TypedBuilder;
+
+pub const SCRIPT_VERSION: &str = "v0";
+pub const SCRIPT_TAG: &str = "script";
+pub const SCRIPT_LEN: usize = 20000; // 20kb max total len
+
+#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)]
+/// A script is a set of commands that can be run, with the specified shebang
+pub struct Script {
+ /// The id of the script
+ #[builder(default = uuid::Uuid::new_v4())]
+ pub id: Uuid,
+
+ /// The name of the script
+ pub name: String,
+
+ /// The description of the script
+ #[builder(default = String::new())]
+ pub description: String,
+
+ /// The interpreter of the script
+ #[builder(default = String::new())]
+ pub shebang: String,
+
+ /// The tags of the script
+ #[builder(default = Vec::new())]
+ pub tags: Vec<String>,
+
+ /// The script content
+ pub script: String,
+}
+
+impl Script {
+ pub fn serialize(&self) -> Result<DecryptedData> {
+ // sort the tags first, to ensure consistent ordering
+ let mut tags = self.tags.clone();
+ tags.sort();
+
+ let mut output = vec![];
+
+ encode::write_array_len(&mut output, 6)?;
+ encode::write_str(&mut output, &self.id.to_string())?;
+ encode::write_str(&mut output, &self.name)?;
+ encode::write_str(&mut output, &self.description)?;
+ encode::write_str(&mut output, &self.shebang)?;
+ encode::write_array_len(&mut output, self.tags.len() as u32)?;
+
+ for tag in &tags {
+ encode::write_str(&mut output, tag)?;
+ }
+
+ encode::write_str(&mut output, &self.script)?;
+
+ Ok(DecryptedData(output))
+ }
+
+ pub fn deserialize(bytes: &[u8]) -> Result<Self> {
+ let mut bytes = decode::Bytes::new(bytes);
+ let nfields = decode::read_array_len(&mut bytes).unwrap();
+
+ ensure!(nfields == 6, "too many entries in v0 script record");
+
+ let bytes = bytes.remaining_slice();
+
+ let (id, bytes) = decode::read_str_from_slice(bytes).unwrap();
+ let (name, bytes) = decode::read_str_from_slice(bytes).unwrap();
+ let (description, bytes) = decode::read_str_from_slice(bytes).unwrap();
+ let (shebang, bytes) = decode::read_str_from_slice(bytes).unwrap();
+
+ let mut bytes = Bytes::new(bytes);
+ let tags_len = decode::read_array_len(&mut bytes).unwrap();
+
+ let mut bytes = bytes.remaining_slice();
+
+ let mut tags = Vec::new();
+ for _ in 0..tags_len {
+ let (tag, remaining) = decode::read_str_from_slice(bytes).unwrap();
+ tags.push(tag.to_owned());
+ bytes = remaining;
+ }
+
+ let (script, bytes) = decode::read_str_from_slice(bytes).unwrap();
+
+ if !bytes.is_empty() {
+ bail!("trailing bytes in encoded script record. malformed")
+ }
+
+ Ok(Script {
+ id: Uuid::parse_str(id).unwrap(),
+ name: name.to_owned(),
+ description: description.to_owned(),
+ shebang: shebang.to_owned(),
+ tags,
+ script: script.to_owned(),
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_serialize() {
+ let script = Script {
+ id: uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(),
+ name: "test".to_string(),
+ description: "test".to_string(),
+ shebang: "test".to_string(),
+ tags: vec!["test".to_string()],
+ script: "test".to_string(),
+ };
+
+ let serialized = script.serialize().unwrap();
+ assert_eq!(
+ serialized.0,
+ vec![
+ 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, 45, 55, 57, 56,
+ 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, 98, 99, 54, 164,
+ 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 145, 164,
+ 116, 101, 115, 116, 164, 116, 101, 115, 116
+ ]
+ );
+ }
+
+ #[test]
+ fn test_serialize_deserialize() {
+ let script = Script {
+ id: uuid::Uuid::new_v4(),
+ name: "test".to_string(),
+ description: "test".to_string(),
+ shebang: "test".to_string(),
+ tags: vec!["test".to_string()],
+ script: "test".to_string(),
+ };
+
+ let serialized = script.serialize().unwrap();
+
+ let deserialized = Script::deserialize(&serialized.0).unwrap();
+
+ assert_eq!(script, deserialized);
+ }
+}