aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-scripts/src/store
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--crates/atuin-scripts/src/store.rs109
-rw-r--r--crates/atuin-scripts/src/store/record.rs215
-rw-r--r--crates/atuin-scripts/src/store/script.rs151
3 files changed, 475 insertions, 0 deletions
diff --git a/crates/atuin-scripts/src/store.rs b/crates/atuin-scripts/src/store.rs
new file mode 100644
index 00000000..ba7a1ca1
--- /dev/null
+++ b/crates/atuin-scripts/src/store.rs
@@ -0,0 +1,109 @@
+use eyre::{Result, bail};
+
+use atuin_client::record::sqlite_store::SqliteStore;
+use atuin_client::record::{encryption::PASETO_V4, store::Store};
+use atuin_common::record::{Host, HostId, Record, RecordId, RecordIdx};
+use record::ScriptRecord;
+use script::{SCRIPT_TAG, SCRIPT_VERSION, Script};
+
+use crate::database::Database;
+
+pub mod record;
+pub mod script;
+
+#[derive(Debug, Clone)]
+pub struct ScriptStore {
+ pub store: SqliteStore,
+ pub host_id: HostId,
+ pub encryption_key: [u8; 32],
+}
+
+impl ScriptStore {
+ pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self {
+ ScriptStore {
+ store,
+ host_id,
+ encryption_key,
+ }
+ }
+
+ async fn push_record(&self, record: ScriptRecord) -> Result<(RecordId, RecordIdx)> {
+ let bytes = record.serialize()?;
+ let idx = self
+ .store
+ .last(self.host_id, SCRIPT_TAG)
+ .await?
+ .map_or(0, |p| p.idx + 1);
+
+ let record = Record::builder()
+ .host(Host::new(self.host_id))
+ .version(SCRIPT_VERSION.to_string())
+ .tag(SCRIPT_TAG.to_string())
+ .idx(idx)
+ .data(bytes)
+ .build();
+
+ let id = record.id;
+
+ self.store
+ .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
+ .await?;
+
+ Ok((id, idx))
+ }
+
+ pub async fn create(&self, script: Script) -> Result<()> {
+ let record = ScriptRecord::Create(script);
+ self.push_record(record).await?;
+ Ok(())
+ }
+
+ pub async fn update(&self, script: Script) -> Result<()> {
+ let record = ScriptRecord::Update(script);
+ self.push_record(record).await?;
+ Ok(())
+ }
+
+ pub async fn delete(&self, script_id: uuid::Uuid) -> Result<()> {
+ let record = ScriptRecord::Delete(script_id);
+ self.push_record(record).await?;
+ Ok(())
+ }
+
+ pub async fn scripts(&self) -> Result<Vec<ScriptRecord>> {
+ let records = self.store.all_tagged(SCRIPT_TAG).await?;
+ let mut ret = Vec::with_capacity(records.len());
+
+ for record in records.into_iter() {
+ let script = match record.version.as_str() {
+ SCRIPT_VERSION => {
+ let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?;
+
+ ScriptRecord::deserialize(&decrypted.data, SCRIPT_VERSION)
+ }
+ version => bail!("unknown history version {version:?}"),
+ }?;
+
+ ret.push(script);
+ }
+
+ Ok(ret)
+ }
+
+ pub async fn build(&self, database: Database) -> Result<()> {
+ // Get all the scripts from the database - they are already sorted by timestamp
+ let scripts = self.scripts().await?;
+
+ for script in scripts {
+ match script {
+ ScriptRecord::Create(script) => {
+ database.save(&script).await?;
+ }
+ ScriptRecord::Update(script) => database.update(&script).await?,
+ ScriptRecord::Delete(id) => database.delete(&id.to_string()).await?,
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/crates/atuin-scripts/src/store/record.rs b/crates/atuin-scripts/src/store/record.rs
new file mode 100644
index 00000000..4c925be3
--- /dev/null
+++ b/crates/atuin-scripts/src/store/record.rs
@@ -0,0 +1,215 @@
+use atuin_common::record::DecryptedData;
+use eyre::{Result, eyre};
+use uuid::Uuid;
+
+use crate::store::script::SCRIPT_VERSION;
+
+use super::script::Script;
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum ScriptRecord {
+ Create(Script),
+ Update(Script),
+ Delete(Uuid),
+}
+
+impl ScriptRecord {
+ pub fn serialize(&self) -> Result<DecryptedData> {
+ use rmp::encode;
+
+ let mut output = vec![];
+
+ match self {
+ ScriptRecord::Create(script) => {
+ // 0 -> a script create
+ encode::write_u8(&mut output, 0)?;
+
+ let bytes = script.serialize()?;
+
+ encode::write_bin(&mut output, &bytes.0)?;
+ }
+
+ ScriptRecord::Delete(id) => {
+ // 1 -> a script delete
+ encode::write_u8(&mut output, 1)?;
+ encode::write_str(&mut output, id.to_string().as_str())?;
+ }
+
+ ScriptRecord::Update(script) => {
+ // 2 -> a script update
+ encode::write_u8(&mut output, 2)?;
+ let bytes = script.serialize()?;
+ encode::write_bin(&mut output, &bytes.0)?;
+ }
+ };
+
+ Ok(DecryptedData(output))
+ }
+
+ pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
+ use rmp::decode;
+
+ fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
+ eyre!("{err:?}")
+ }
+
+ match version {
+ SCRIPT_VERSION => {
+ let mut bytes = decode::Bytes::new(&data.0);
+
+ let record_type = decode::read_u8(&mut bytes).map_err(error_report)?;
+
+ match record_type {
+ // create
+ 0 => {
+ // written by encode::write_bin above
+ let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?;
+ let script = Script::deserialize(bytes.remaining_slice())?;
+ Ok(ScriptRecord::Create(script))
+ }
+
+ // delete
+ 1 => {
+ let bytes = bytes.remaining_slice();
+ let (id, _) = decode::read_str_from_slice(bytes).map_err(error_report)?;
+ Ok(ScriptRecord::Delete(Uuid::parse_str(id)?))
+ }
+
+ // update
+ 2 => {
+ // written by encode::write_bin above
+ let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?;
+ let script = Script::deserialize(bytes.remaining_slice())?;
+ Ok(ScriptRecord::Update(script))
+ }
+
+ _ => Err(eyre!("unknown script record type {record_type}")),
+ }
+ }
+ _ => Err(eyre!("unknown version {version:?}")),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_serialize_create() {
+ let script = Script::builder()
+ .id(uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap())
+ .name("test".to_string())
+ .description("test".to_string())
+ .shebang("test".to_string())
+ .tags(vec!["test".to_string()])
+ .script("test".to_string())
+ .build();
+
+ let record = ScriptRecord::Create(script);
+
+ let serialized = record.serialize().unwrap();
+
+ assert_eq!(
+ serialized.0,
+ vec![
+ 204, 0, 196, 65, 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102,
+ 45, 55, 57, 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99,
+ 98, 99, 54, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115,
+ 116, 145, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116
+ ]
+ );
+ }
+
+ #[test]
+ fn test_serialize_delete() {
+ let record = ScriptRecord::Delete(
+ uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(),
+ );
+
+ let serialized = record.serialize().unwrap();
+
+ assert_eq!(
+ serialized.0,
+ vec![
+ 204, 1, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, 45, 55, 57,
+ 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, 98, 99, 54
+ ]
+ );
+ }
+
+ #[test]
+ fn test_serialize_update() {
+ let script = Script::builder()
+ .id(uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap())
+ .name(String::from("test"))
+ .description(String::from("test"))
+ .shebang(String::from("test"))
+ .tags(vec![String::from("test"), String::from("test2")])
+ .script(String::from("test"))
+ .build();
+
+ let record = ScriptRecord::Update(script);
+
+ let serialized = record.serialize().unwrap();
+
+ assert_eq!(
+ serialized.0,
+ vec![
+ 204, 2, 196, 71, 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102,
+ 45, 55, 57, 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99,
+ 98, 99, 54, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115,
+ 116, 146, 164, 116, 101, 115, 116, 165, 116, 101, 115, 116, 50, 164, 116, 101, 115,
+ 116
+ ],
+ );
+ }
+
+ #[test]
+ fn test_serialize_deserialize_create() {
+ let script = Script::builder()
+ .name("test".to_string())
+ .description("test".to_string())
+ .shebang("test".to_string())
+ .tags(vec!["test".to_string()])
+ .script("test".to_string())
+ .build();
+
+ let record = ScriptRecord::Create(script);
+
+ let serialized = record.serialize().unwrap();
+ let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap();
+
+ assert_eq!(record, deserialized);
+ }
+
+ #[test]
+ fn test_serialize_deserialize_delete() {
+ let record = ScriptRecord::Delete(
+ uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(),
+ );
+
+ let serialized = record.serialize().unwrap();
+ let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap();
+
+ assert_eq!(record, deserialized);
+ }
+
+ #[test]
+ fn test_serialize_deserialize_update() {
+ let script = Script::builder()
+ .name("test".to_string())
+ .description("test".to_string())
+ .shebang("test".to_string())
+ .tags(vec!["test".to_string()])
+ .script("test".to_string())
+ .build();
+
+ let record = ScriptRecord::Update(script);
+
+ let serialized = record.serialize().unwrap();
+ let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap();
+
+ assert_eq!(record, deserialized);
+ }
+}
diff --git a/crates/atuin-scripts/src/store/script.rs b/crates/atuin-scripts/src/store/script.rs
new file mode 100644
index 00000000..af180320
--- /dev/null
+++ b/crates/atuin-scripts/src/store/script.rs
@@ -0,0 +1,151 @@
+use atuin_common::record::DecryptedData;
+use eyre::{Result, bail, ensure};
+use uuid::Uuid;
+
+use rmp::{
+ decode::{self, Bytes},
+ encode,
+};
+use typed_builder::TypedBuilder;
+
+pub const SCRIPT_VERSION: &str = "v0";
+pub const SCRIPT_TAG: &str = "script";
+pub const SCRIPT_LEN: usize = 20000; // 20kb max total len
+
+#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)]
+/// A script is a set of commands that can be run, with the specified shebang
+pub struct Script {
+ /// The id of the script
+ #[builder(default = uuid::Uuid::new_v4())]
+ pub id: Uuid,
+
+ /// The name of the script
+ pub name: String,
+
+ /// The description of the script
+ #[builder(default = String::new())]
+ pub description: String,
+
+ /// The interpreter of the script
+ #[builder(default = String::new())]
+ pub shebang: String,
+
+ /// The tags of the script
+ #[builder(default = Vec::new())]
+ pub tags: Vec<String>,
+
+ /// The script content
+ pub script: String,
+}
+
+impl Script {
+ pub fn serialize(&self) -> Result<DecryptedData> {
+ // sort the tags first, to ensure consistent ordering
+ let mut tags = self.tags.clone();
+ tags.sort();
+
+ let mut output = vec![];
+
+ encode::write_array_len(&mut output, 6)?;
+ encode::write_str(&mut output, &self.id.to_string())?;
+ encode::write_str(&mut output, &self.name)?;
+ encode::write_str(&mut output, &self.description)?;
+ encode::write_str(&mut output, &self.shebang)?;
+ encode::write_array_len(&mut output, self.tags.len() as u32)?;
+
+ for tag in &tags {
+ encode::write_str(&mut output, tag)?;
+ }
+
+ encode::write_str(&mut output, &self.script)?;
+
+ Ok(DecryptedData(output))
+ }
+
+ pub fn deserialize(bytes: &[u8]) -> Result<Self> {
+ let mut bytes = decode::Bytes::new(bytes);
+ let nfields = decode::read_array_len(&mut bytes).unwrap();
+
+ ensure!(nfields == 6, "too many entries in v0 script record");
+
+ let bytes = bytes.remaining_slice();
+
+ let (id, bytes) = decode::read_str_from_slice(bytes).unwrap();
+ let (name, bytes) = decode::read_str_from_slice(bytes).unwrap();
+ let (description, bytes) = decode::read_str_from_slice(bytes).unwrap();
+ let (shebang, bytes) = decode::read_str_from_slice(bytes).unwrap();
+
+ let mut bytes = Bytes::new(bytes);
+ let tags_len = decode::read_array_len(&mut bytes).unwrap();
+
+ let mut bytes = bytes.remaining_slice();
+
+ let mut tags = Vec::new();
+ for _ in 0..tags_len {
+ let (tag, remaining) = decode::read_str_from_slice(bytes).unwrap();
+ tags.push(tag.to_owned());
+ bytes = remaining;
+ }
+
+ let (script, bytes) = decode::read_str_from_slice(bytes).unwrap();
+
+ if !bytes.is_empty() {
+ bail!("trailing bytes in encoded script record. malformed")
+ }
+
+ Ok(Script {
+ id: Uuid::parse_str(id).unwrap(),
+ name: name.to_owned(),
+ description: description.to_owned(),
+ shebang: shebang.to_owned(),
+ tags,
+ script: script.to_owned(),
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_serialize() {
+ let script = Script {
+ id: uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(),
+ name: "test".to_string(),
+ description: "test".to_string(),
+ shebang: "test".to_string(),
+ tags: vec!["test".to_string()],
+ script: "test".to_string(),
+ };
+
+ let serialized = script.serialize().unwrap();
+ assert_eq!(
+ serialized.0,
+ vec![
+ 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, 45, 55, 57, 56,
+ 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, 98, 99, 54, 164,
+ 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 145, 164,
+ 116, 101, 115, 116, 164, 116, 101, 115, 116
+ ]
+ );
+ }
+
+ #[test]
+ fn test_serialize_deserialize() {
+ let script = Script {
+ id: uuid::Uuid::new_v4(),
+ name: "test".to_string(),
+ description: "test".to_string(),
+ shebang: "test".to_string(),
+ tags: vec!["test".to_string()],
+ script: "test".to_string(),
+ };
+
+ let serialized = script.serialize().unwrap();
+
+ let deserialized = Script::deserialize(&serialized.0).unwrap();
+
+ assert_eq!(script, deserialized);
+ }
+}