diff options
Diffstat (limited to 'crates/turtle/src/atuin_common')
| -rw-r--r-- | crates/turtle/src/atuin_common/api.rs | 144 | ||||
| -rw-r--r-- | crates/turtle/src/atuin_common/calendar.rs | 16 | ||||
| -rw-r--r-- | crates/turtle/src/atuin_common/mod.rs | 58 | ||||
| -rw-r--r-- | crates/turtle/src/atuin_common/record.rs | 426 | ||||
| -rw-r--r-- | crates/turtle/src/atuin_common/shell.rs | 183 | ||||
| -rw-r--r-- | crates/turtle/src/atuin_common/tls.rs | 15 | ||||
| -rw-r--r-- | crates/turtle/src/atuin_common/utils.rs | 383 |
7 files changed, 1225 insertions, 0 deletions
diff --git a/crates/turtle/src/atuin_common/api.rs b/crates/turtle/src/atuin_common/api.rs new file mode 100644 index 00000000..1a9f348c --- /dev/null +++ b/crates/turtle/src/atuin_common/api.rs @@ -0,0 +1,144 @@ +use semver::Version; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use std::sync::LazyLock; +use time::OffsetDateTime; + +// the usage of X- has been deprecated for quite along time, it turns out +pub static ATUIN_HEADER_VERSION: &str = "Atuin-Version"; +pub static ATUIN_CARGO_VERSION: &str = env!("CARGO_PKG_VERSION"); + +pub static ATUIN_VERSION: LazyLock<Version> = + LazyLock::new(|| Version::parse(ATUIN_CARGO_VERSION).expect("failed to parse self semver")); + +#[derive(Debug, Serialize, Deserialize)] +pub struct UserResponse { + pub username: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterRequest { + pub email: String, + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterResponse { + pub session: String, + /// Auth type: "hub" for Hub API tokens, "cli" for legacy CLI session tokens. + /// Old servers that don't return this field will deserialize as None. + #[serde(default)] + pub auth: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DeleteUserResponse {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChangePasswordRequest { + pub current_password: String, + pub new_password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChangePasswordResponse {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginRequest { + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginResponse { + pub session: String, + /// Auth type: "hub" for Hub API tokens, "cli" for legacy CLI session tokens. + /// Old servers that don't return this field will deserialize as None. + #[serde(default)] + pub auth: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AddHistoryRequest { + pub id: String, + #[serde(with = "time::serde::rfc3339")] + pub timestamp: OffsetDateTime, + pub data: String, + pub hostname: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CountResponse { + pub count: i64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SyncHistoryRequest { + #[serde(with = "time::serde::rfc3339")] + pub sync_ts: OffsetDateTime, + #[serde(with = "time::serde::rfc3339")] + pub history_ts: OffsetDateTime, + pub host: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SyncHistoryResponse { + pub history: Vec<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorResponse<'a> { + pub reason: Cow<'a, str>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IndexResponse { + pub homage: String, + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StatusResponse { + pub count: i64, + pub username: String, + pub deleted: Vec<String>, + + // These could/should also go on the index of the server + // However, we do not request the server index as a part of normal sync + // I'd rather slightly increase the size of this response, than add an extra HTTP request + pub page_size: i64, // max page size supported by the server + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DeleteHistoryRequest { + pub client_id: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageResponse { + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MeResponse { + pub username: String, +} + +// Hub CLI authentication types + +/// Response from POST /auth/cli/code - generates a code for CLI auth +#[derive(Debug, Serialize, Deserialize)] +pub struct CliCodeResponse { + pub code: String, +} + +/// Response from GET /auth/cli/verify?code=<code> - polls for authorization +#[derive(Debug, Serialize, Deserialize)] +pub struct CliVerifyResponse { + /// Session token, present only when authorization is complete + pub token: Option<String>, + pub success: Option<bool>, + pub error: Option<String>, +} diff --git a/crates/turtle/src/atuin_common/calendar.rs b/crates/turtle/src/atuin_common/calendar.rs new file mode 100644 index 00000000..d3b1d921 --- /dev/null +++ b/crates/turtle/src/atuin_common/calendar.rs @@ -0,0 +1,16 @@ +// Calendar data +use serde::{Serialize, Deserialize}; + +pub enum TimePeriod { + YEAR, + MONTH, + DAY, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TimePeriodInfo { + pub count: u64, + + // TODO: Use this for merkle tree magic + pub hash: String, +} diff --git a/crates/turtle/src/atuin_common/mod.rs b/crates/turtle/src/atuin_common/mod.rs new file mode 100644 index 00000000..d886520d --- /dev/null +++ b/crates/turtle/src/atuin_common/mod.rs @@ -0,0 +1,58 @@ +/// Defines a new UUID type wrapper +macro_rules! new_uuid { + ($name:ident) => { + #[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + serde::Serialize, + serde::Deserialize, + )] + #[serde(transparent)] + pub struct $name(pub Uuid); + + impl<DB: sqlx::Database> sqlx::Type<DB> for $name + where + Uuid: sqlx::Type<DB>, + { + fn type_info() -> <DB as sqlx::Database>::TypeInfo { + Uuid::type_info() + } + } + + impl<'r, DB: sqlx::Database> sqlx::Decode<'r, DB> for $name + where + Uuid: sqlx::Decode<'r, DB>, + { + fn decode( + value: DB::ValueRef<'r>, + ) -> std::result::Result<Self, sqlx::error::BoxDynError> { + Uuid::decode(value).map(Self) + } + } + + impl<'q, DB: sqlx::Database> sqlx::Encode<'q, DB> for $name + where + Uuid: sqlx::Encode<'q, DB>, + { + fn encode_by_ref( + &self, + buf: &mut DB::ArgumentBuffer<'q>, + ) -> Result<sqlx::encode::IsNull, Box<dyn std::error::Error + Send + Sync + 'static>> + { + self.0.encode_by_ref(buf) + } + } + }; +} + +pub mod api; +pub mod record; +pub mod shell; +pub mod tls; +pub mod utils; diff --git a/crates/turtle/src/atuin_common/record.rs b/crates/turtle/src/atuin_common/record.rs new file mode 100644 index 00000000..05c29338 --- /dev/null +++ b/crates/turtle/src/atuin_common/record.rs @@ -0,0 +1,426 @@ +use std::collections::HashMap; + +use eyre::Result; +use serde::{Deserialize, Serialize}; +use typed_builder::TypedBuilder; +use uuid::Uuid; + +#[derive(Clone, Debug, PartialEq)] +pub struct DecryptedData(pub Vec<u8>); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct EncryptedData { + pub data: String, + pub content_encryption_key: String, +} + +#[derive(Debug, PartialEq, PartialOrd, Ord, Eq)] +pub struct Diff { + pub host: HostId, + pub tag: String, + pub local: Option<RecordIdx>, + pub remote: Option<RecordIdx>, +} + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub struct Host { + pub id: HostId, + pub name: String, +} + +impl Host { + pub fn new(id: HostId) -> Self { + Host { + id, + name: String::new(), + } + } +} + +new_uuid!(RecordId); +new_uuid!(HostId); + +pub type RecordIdx = u64; + +/// A single record stored inside of our local database +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)] +pub struct Record<Data> { + /// a unique ID + #[builder(default = RecordId(crate::atuin_common::utils::uuid_v7()))] + pub id: RecordId, + + /// The integer record ID. This is only unique per (host, tag). + pub idx: RecordIdx, + + /// The unique ID of the host. + // TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store + // as strings. I would rather avoid normalization, so store as UUID binary instead of + // encoding to a string and wasting much more storage. + pub host: Host, + + /// The creation time in nanoseconds since unix epoch + #[builder(default = time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64)] + pub timestamp: u64, + + /// The version the data in the entry conforms to + // However we want to track versions for this tag, eg v2 + pub version: String, + + /// The type of data we are storing here. Eg, "history" + pub tag: String, + + /// Some data. This can be anything you wish to store. Use the tag field to know how to handle it. + pub data: Data, +} + +/// Extra data from the record that should be encoded in the data +#[derive(Debug, Copy, Clone)] +pub struct AdditionalData<'a> { + pub id: &'a RecordId, + pub idx: &'a u64, + pub version: &'a str, + pub tag: &'a str, + pub host: &'a HostId, +} + +impl<Data> Record<Data> { + pub fn append(&self, data: Vec<u8>) -> Record<DecryptedData> { + Record::builder() + .host(self.host.clone()) + .version(self.version.clone()) + .idx(self.idx + 1) + .tag(self.tag.clone()) + .data(DecryptedData(data)) + .build() + } +} + +/// An index representing the current state of the record stores +/// This can be both remote, or local, and compared in either direction +#[derive(Debug, Serialize, Deserialize)] +pub struct RecordStatus { + // A map of host -> tag -> max(idx) + pub hosts: HashMap<HostId, HashMap<String, RecordIdx>>, +} + +impl Default for RecordStatus { + fn default() -> Self { + Self::new() + } +} + +impl Extend<(HostId, String, RecordIdx)> for RecordStatus { + fn extend<T: IntoIterator<Item = (HostId, String, RecordIdx)>>(&mut self, iter: T) { + for (host, tag, tail_idx) in iter { + self.set_raw(host, tag, tail_idx); + } + } +} + +impl RecordStatus { + pub fn new() -> RecordStatus { + RecordStatus { + hosts: HashMap::new(), + } + } + + /// Insert a new tail record into the store + pub fn set(&mut self, tail: Record<DecryptedData>) { + self.set_raw(tail.host.id, tail.tag, tail.idx) + } + + pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordIdx) { + self.hosts.entry(host).or_default().insert(tag, tail_id); + } + + pub fn get(&self, host: HostId, tag: String) -> Option<RecordIdx> { + self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned() + } + + /// Diff this index with another, likely remote index. + /// The two diffs can then be reconciled, and the optimal change set calculated + /// Returns a tuple, with (host, tag, Option(OTHER)) + /// OTHER is set to the value of the idx on the other machine. If it is greater than our index, + /// then we need to do some downloading. If it is smaller, then we need to do some uploading + /// Note that we cannot upload if we are not the owner of the record store - hosts can only + /// write to their own store. + pub fn diff(&self, other: &Self) -> Vec<Diff> { + let mut ret = Vec::new(); + + // First, we check if other has everything that self has + for (host, tag_map) in self.hosts.iter() { + for (tag, idx) in tag_map.iter() { + match other.get(*host, tag.clone()) { + // The other store is all up to date! No diff. + Some(t) if t.eq(idx) => continue, + + // The other store does exist, and it is either ahead or behind us. A diff regardless + Some(t) => ret.push(Diff { + host: *host, + tag: tag.clone(), + local: Some(*idx), + remote: Some(t), + }), + + // The other store does not exist :O + None => ret.push(Diff { + host: *host, + tag: tag.clone(), + local: Some(*idx), + remote: None, + }), + }; + } + } + + // At this point, there is a single case we have not yet considered. + // If the other store knows of a tag that we are not yet aware of, then the diff will be missed + + // account for that! + for (host, tag_map) in other.hosts.iter() { + for (tag, idx) in tag_map.iter() { + match self.get(*host, tag.clone()) { + // If we have this host/tag combo, the comparison and diff will have already happened above + Some(_) => continue, + + None => ret.push(Diff { + host: *host, + tag: tag.clone(), + remote: Some(*idx), + local: None, + }), + }; + } + } + + // Stability is a nice property to have + ret.sort(); + ret + } +} + +pub trait Encryption { + fn re_encrypt( + data: EncryptedData, + ad: AdditionalData, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result<EncryptedData> { + let data = Self::decrypt(data, ad, old_key)?; + Ok(Self::encrypt(data, ad, new_key)) + } + fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData; + fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result<DecryptedData>; +} + +impl Record<DecryptedData> { + pub fn encrypt<E: Encryption>(self, key: &[u8; 32]) -> Record<EncryptedData> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Record { + data: E::encrypt(self.data, ad, key), + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + } + } +} + +impl Record<EncryptedData> { + pub fn decrypt<E: Encryption>(self, key: &[u8; 32]) -> Result<Record<DecryptedData>> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Ok(Record { + data: E::decrypt(self.data, ad, key)?, + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + }) + } + + pub fn re_encrypt<E: Encryption>( + self, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result<Record<EncryptedData>> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Ok(Record { + data: E::re_encrypt(self.data, ad, old_key, new_key)?, + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::record::{Host, HostId}; + + use super::{DecryptedData, Diff, Record, RecordStatus}; + use pretty_assertions::assert_eq; + + fn test_record() -> Record<DecryptedData> { + Record::builder() + .host(Host::new(HostId(crate::atuin_common::utils::uuid_v7()))) + .version("v1".into()) + .tag(crate::atuin_common::utils::uuid_v7().simple().to_string()) + .data(DecryptedData(vec![0, 1, 2, 3])) + .idx(0) + .build() + } + + #[test] + fn record_index() { + let mut index = RecordStatus::new(); + let record = test_record(); + + index.set(record.clone()); + + let tail = index.get(record.host.id, record.tag); + + assert_eq!( + record.idx, + tail.expect("tail not in store"), + "tail in store did not match" + ); + } + + #[test] + fn record_index_overwrite() { + let mut index = RecordStatus::new(); + let record = test_record(); + let child = record.append(vec![1, 2, 3]); + + index.set(record.clone()); + index.set(child.clone()); + + let tail = index.get(record.host.id, record.tag); + + assert_eq!( + child.idx, + tail.expect("tail not in store"), + "tail in store did not match" + ); + } + + #[test] + fn record_index_no_diff() { + // Here, they both have the same version and should have no diff + + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let record1 = test_record(); + + index1.set(record1.clone()); + index2.set(record1); + + let diff = index1.diff(&index2); + + assert_eq!(0, diff.len(), "expected empty diff"); + } + + #[test] + fn record_index_single_diff() { + // Here, they both have the same stores, but one is ahead by a single record + + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let record1 = test_record(); + let record2 = record1.append(vec![1, 2, 3]); + + index1.set(record1); + index2.set(record2.clone()); + + let diff = index1.diff(&index2); + + assert_eq!(1, diff.len(), "expected single diff"); + assert_eq!( + diff[0], + Diff { + host: record2.host.id, + tag: record2.tag, + remote: Some(1), + local: Some(0) + } + ); + } + + #[test] + fn record_index_multi_diff() { + // A much more complex case, with a bunch more checks + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let store1record1 = test_record(); + let store1record2 = store1record1.append(vec![1, 2, 3]); + + let store2record1 = test_record(); + let store2record2 = store2record1.append(vec![1, 2, 3]); + + let store3record1 = test_record(); + + let store4record1 = test_record(); + + // index1 only knows about the first two entries of the first two stores + index1.set(store1record1); + index1.set(store2record1); + + // index2 is fully up to date with the first two stores, and knows of a third + index2.set(store1record2); + index2.set(store2record2); + index2.set(store3record1); + + // index1 knows of a 4th store + index1.set(store4record1); + + let diff1 = index1.diff(&index2); + let diff2 = index2.diff(&index1); + + // both diffs the same length + assert_eq!(4, diff1.len()); + assert_eq!(4, diff2.len()); + + dbg!(&diff1, &diff2); + + // both diffs should be ALMOST the same. They will agree on which hosts and tags + // require updating, but the "other" value will not be the same. + let smol_diff_1: Vec<(HostId, String)> = + diff1.iter().map(|v| (v.host, v.tag.clone())).collect(); + let smol_diff_2: Vec<(HostId, String)> = + diff1.iter().map(|v| (v.host, v.tag.clone())).collect(); + + assert_eq!(smol_diff_1, smol_diff_2); + + // diffing with yourself = no diff + assert_eq!(index1.diff(&index1).len(), 0); + assert_eq!(index2.diff(&index2).len(), 0); + } +} diff --git a/crates/turtle/src/atuin_common/shell.rs b/crates/turtle/src/atuin_common/shell.rs new file mode 100644 index 00000000..7f9a7b8f --- /dev/null +++ b/crates/turtle/src/atuin_common/shell.rs @@ -0,0 +1,183 @@ +use std::{ffi::OsStr, path::Path, process::Command}; + +use serde::Serialize; +use sysinfo::{Process, System, get_current_pid}; +use thiserror::Error; + +#[derive(PartialEq)] +pub enum Shell { + Sh, + Bash, + Fish, + Zsh, + Xonsh, + Nu, + Powershell, + + Unknown, +} + +impl std::fmt::Display for Shell { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let shell = match self { + Shell::Bash => "bash", + Shell::Fish => "fish", + Shell::Zsh => "zsh", + Shell::Nu => "nu", + Shell::Xonsh => "xonsh", + Shell::Sh => "sh", + Shell::Powershell => "powershell", + + Shell::Unknown => "unknown", + }; + + write!(f, "{shell}") + } +} + +#[derive(Debug, Error, Serialize)] +pub enum ShellError { + #[error("shell not supported")] + NotSupported, + + #[error("failed to execute shell command: {0}")] + ExecError(String), +} + +impl Shell { + pub fn current() -> Shell { + let sys = System::new_all(); + + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + let parent = sys + .process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist"); + + let shell = parent.name().trim().to_lowercase(); + let shell = shell.strip_prefix('-').unwrap_or(&shell); + + Shell::from_string(shell.to_string()) + } + + pub fn from_env() -> Shell { + std::env::var("ATUIN_SHELL").map_or(Shell::Unknown, |shell| { + Shell::from_string(shell.trim().to_lowercase()) + }) + } + + pub fn config_file(&self) -> Option<std::path::PathBuf> { + let mut path = if let Some(base) = directories::BaseDirs::new() { + base.home_dir().to_owned() + } else { + return None; + }; + + // TODO: handle all shells + match self { + Shell::Bash => path.push(".bashrc"), + Shell::Zsh => path.push(".zshrc"), + Shell::Fish => path.push(".config/fish/config.fish"), + + _ => return None, + }; + + Some(path) + } + + /// Best-effort attempt to determine the default shell + /// This implementation will be different across different platforms + /// Caller should ensure to handle Shell::Unknown correctly + pub fn default_shell() -> Result<Shell, ShellError> { + let sys = System::name().unwrap_or("".to_string()).to_lowercase(); + + // TODO: Support Linux + // I'm pretty sure we can use /etc/passwd there, though there will probably be some issues + let path = if sys.contains("darwin") { + // This works in my testing so far + Shell::Sh.run_interactive([ + "dscl localhost -read \"/Local/Default/Users/$USER\" shell | awk '{print $2}'", + ])? + } else if cfg!(windows) { + return Ok(Shell::Powershell); + } else { + Shell::Sh.run_interactive(["getent passwd $LOGNAME | cut -d: -f7"])? + }; + + let path = Path::new(path.trim()); + let shell = path.file_name(); + + if shell.is_none() { + return Err(ShellError::NotSupported); + } + + Ok(Shell::from_string( + shell.unwrap().to_string_lossy().to_string(), + )) + } + + pub fn from_string(name: String) -> Shell { + match name.as_str() { + "bash" => Shell::Bash, + "fish" => Shell::Fish, + "zsh" => Shell::Zsh, + "xonsh" => Shell::Xonsh, + "nu" => Shell::Nu, + "sh" => Shell::Sh, + "powershell" => Shell::Powershell, + + _ => Shell::Unknown, + } + } + + /// Returns true if the shell is posix-like + /// Note that while fish is not posix compliant, it behaves well enough for our current + /// featureset that this does not matter. + pub fn is_posixish(&self) -> bool { + matches!(self, Shell::Bash | Shell::Fish | Shell::Zsh) + } + + pub fn run_interactive<I, S>(&self, args: I) -> Result<String, ShellError> + where + I: IntoIterator<Item = S>, + S: AsRef<OsStr>, + { + let shell = self.to_string(); + let output = if self == &Self::Powershell { + Command::new(shell) + .args(args) + .output() + .map_err(|e| ShellError::ExecError(e.to_string()))? + } else { + Command::new(shell) + .arg("-ic") + .args(args) + .output() + .map_err(|e| ShellError::ExecError(e.to_string()))? + }; + + Ok(String::from_utf8(output.stdout).unwrap()) + } +} + +pub fn shell_name(parent: Option<&Process>) -> String { + let sys = System::new_all(); + + let parent = if let Some(parent) = parent { + parent + } else { + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + sys.process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist") + }; + + let shell = parent.name().trim().to_lowercase(); + let shell = shell.strip_prefix('-').unwrap_or(&shell); + + shell.to_string() +} diff --git a/crates/turtle/src/atuin_common/tls.rs b/crates/turtle/src/atuin_common/tls.rs new file mode 100644 index 00000000..e8c840e0 --- /dev/null +++ b/crates/turtle/src/atuin_common/tls.rs @@ -0,0 +1,15 @@ +use std::sync::Once; + +static INIT: Once = Once::new(); + +/// Ensure the rustls crypto provider (ring) is installed. +/// +/// Must be called before creating any reqwest clients. Safe to call +/// multiple times — only the first call installs the provider. +pub fn ensure_crypto_provider() { + INIT.call_once(|| { + rustls::crypto::ring::default_provider() + .install_default() + .expect("Failed to install rustls crypto provider"); + }); +} diff --git a/crates/turtle/src/atuin_common/utils.rs b/crates/turtle/src/atuin_common/utils.rs new file mode 100644 index 00000000..d7382fb2 --- /dev/null +++ b/crates/turtle/src/atuin_common/utils.rs @@ -0,0 +1,383 @@ +use std::borrow::Cow; +use std::env; +use std::path::{Path, PathBuf}; + +use eyre::{Result, eyre}; + +use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine}; +use getrandom::getrandom; +use uuid::Uuid; + +/// Generate N random bytes, using a cryptographically secure source +pub fn crypto_random_bytes<const N: usize>() -> [u8; N] { + // rand say they are in principle safe for crypto purposes, but that it is perhaps a better + // idea to use getrandom for things such as passwords. + let mut ret = [0u8; N]; + + getrandom(&mut ret).expect("Failed to generate random bytes!"); + + ret +} + +/// Generate N random bytes using a cryptographically secure source, return encoded as a string +pub fn crypto_random_string<const N: usize>() -> String { + let bytes = crypto_random_bytes::<N>(); + + // We only use this to create a random string, and won't be reversing it to find the original + // data - no padding is OK there. It may be in URLs. + BASE64_URL_SAFE_NO_PAD.encode(bytes) +} + +pub fn uuid_v7() -> Uuid { + Uuid::now_v7() +} + +pub fn uuid_v4() -> String { + Uuid::new_v4().as_simple().to_string() +} + +pub fn has_git_dir(path: &str) -> bool { + let mut gitdir = PathBuf::from(path); + gitdir.push(".git"); + + gitdir.exists() +} + +// in a git worktree, .git is a file containing "gitdir: <path>" pointing +// to the main repo's .git/worktrees/<name> directory. follow the pointer +// back to the main repo root so all worktrees share a workspace. +fn resolve_git_worktree(path: &Path) -> Option<PathBuf> { + let git_path = path.join(".git"); + + if !git_path.is_file() { + return None; + } + + let contents = std::fs::read_to_string(&git_path).ok()?; + let gitdir_str = contents.strip_prefix("gitdir: ")?.trim(); + + let gitdir = PathBuf::from(gitdir_str); + let gitdir = if gitdir.is_absolute() { + gitdir + } else { + path.join(gitdir_str) + }; + + // walk up from e.g. /repo/.git/worktrees/feature to find /repo + let mut candidate = gitdir.as_path(); + while let Some(parent) = candidate.parent() { + if parent.join(".git").is_dir() { + return Some(parent.to_path_buf()); + } + candidate = parent; + } + + None +} + +// detect if any parent dir has a git repo in it +// I really don't want to bring in libgit for something simple like this +// If we start to do anything more advanced, then perhaps +pub fn in_git_repo(path: &str) -> Option<PathBuf> { + let mut gitdir = PathBuf::from(path); + + while gitdir.parent().is_some() && !has_git_dir(gitdir.to_str().unwrap()) { + gitdir.pop(); + } + + // No parent? then we hit root, finding no git + if gitdir.parent().is_some() { + // if .git is a file (worktree), resolve to the main repo root + if let Some(main_repo) = resolve_git_worktree(&gitdir) { + return Some(main_repo); + } + return Some(gitdir); + } + + None +} + +// TODO: more reliable, more tested +// I don't want to use ProjectDirs, it puts config in awkward places on +// mac. Data too. Seems to be more intended for GUI apps. + +pub fn home_dir() -> PathBuf { + directories::BaseDirs::new() + .map(|d| d.home_dir().to_path_buf()) + .expect("could not determine home directory") +} + +pub fn config_dir() -> PathBuf { + let config_dir = + std::env::var("XDG_CONFIG_HOME").map_or_else(|_| home_dir().join(".config"), PathBuf::from); + config_dir.join("atuin") +} + +pub fn data_dir() -> PathBuf { + let data_dir = std::env::var("XDG_DATA_HOME") + .map_or_else(|_| home_dir().join(".local").join("share"), PathBuf::from); + + data_dir.join("atuin") +} + +pub fn runtime_dir() -> PathBuf { + std::env::var("XDG_RUNTIME_DIR").map_or_else(|_| data_dir(), PathBuf::from) +} + +pub fn logs_dir() -> PathBuf { + home_dir().join(".atuin").join("logs") +} + +pub fn dotfiles_cache_dir() -> PathBuf { + // In most cases, this will be ~/.local/share/atuin/dotfiles/cache + let data_dir = std::env::var("XDG_DATA_HOME") + .map_or_else(|_| home_dir().join(".local").join("share"), PathBuf::from); + + data_dir.join("atuin").join("dotfiles").join("cache") +} + +pub fn get_current_dir() -> String { + // Prefer PWD environment variable over cwd if available to better support symbolic links + match env::var("PWD") { + Ok(v) => v, + Err(_) => match env::current_dir() { + Ok(dir) => dir.display().to_string(), + Err(_) => String::from(""), + }, + } +} + +pub fn broken_symlink<P: Into<PathBuf>>(path: P) -> bool { + let path = path.into(); + path.is_symlink() && !path.exists() +} + +/// Extension trait for anything that can behave like a string to make it easy to escape control +/// characters. +/// +/// Intended to help prevent control characters being printed and interpreted by the terminal when +/// printing history as well as to ensure the commands that appear in the interactive search +/// reflect the actual command run rather than just the printable characters. +pub trait Escapable: AsRef<str> { + fn escape_control(&self) -> Cow<'_, str> { + if !self.as_ref().contains(|c: char| c.is_ascii_control()) { + self.as_ref().into() + } else { + let mut remaining = self.as_ref(); + // Not a perfect way to reserve space but should reduce the allocations + let mut buf = String::with_capacity(remaining.len()); + while let Some(i) = remaining.find(|c: char| c.is_ascii_control()) { + // safe to index with `..i`, `i` and `i+1..` as part[i] is a single byte ascii char + buf.push_str(&remaining[..i]); + buf.push('^'); + buf.push(match remaining.as_bytes()[i] { + 0x7F => '?', + code => char::from_u32(u32::from(code) + 64).unwrap(), + }); + remaining = &remaining[i + 1..]; + } + buf.push_str(remaining); + buf.into() + } + } +} + +pub fn unquote(s: &str) -> Result<String> { + if s.chars().count() < 2 { + return Err(eyre!("not enough chars")); + } + + let quote = s.chars().next().unwrap(); + + // not quoted, do nothing + if quote != '"' && quote != '\'' && quote != '`' { + return Ok(s.to_string()); + } + + if s.chars().last().unwrap() != quote { + return Err(eyre!("unexpected eof, quotes do not match")); + } + + // removes quote characters + // the sanity checks performed above ensure that the quotes will be ASCII and this will not + // panic + let s = &s[1..s.len() - 1]; + + Ok(s.to_string()) +} + +impl<T: AsRef<str>> Escapable for T {} + +#[expect(unsafe_code)] +#[cfg(test)] +mod tests { + use pretty_assertions::assert_ne; + + use super::*; + + use std::collections::HashSet; + + #[cfg(not(windows))] + #[test] + fn test_dirs() { + // these tests need to be run sequentially to prevent race condition + test_config_dir_xdg(); + test_config_dir(); + test_data_dir_xdg(); + test_data_dir(); + } + + #[cfg(not(windows))] + fn test_config_dir_xdg() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("XDG_CONFIG_HOME", "/home/user/custom_config") }; + assert_eq!( + config_dir(), + PathBuf::from("/home/user/custom_config/atuin") + ); + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_CONFIG_HOME") }; + } + + #[cfg(not(windows))] + fn test_config_dir() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("HOME", "/home/user") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_CONFIG_HOME") }; + + assert_eq!(config_dir(), PathBuf::from("/home/user/.config/atuin")); + + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + } + + #[cfg(not(windows))] + fn test_data_dir_xdg() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("XDG_DATA_HOME", "/home/user/custom_data") }; + assert_eq!(data_dir(), PathBuf::from("/home/user/custom_data/atuin")); + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_DATA_HOME") }; + } + + #[cfg(not(windows))] + fn test_data_dir() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("HOME", "/home/user") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_DATA_HOME") }; + assert_eq!(data_dir(), PathBuf::from("/home/user/.local/share/atuin")); + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + } + + #[test] + fn uuid_is_unique() { + let how_many: usize = 1000000; + + // for peace of mind + let mut uuids: HashSet<Uuid> = HashSet::with_capacity(how_many); + + // there will be many in the same millisecond + for _ in 0..how_many { + let uuid = uuid_v7(); + uuids.insert(uuid); + } + + assert_eq!(uuids.len(), how_many); + } + + #[test] + fn escape_control_characters() { + use super::Escapable; + // CSI colour sequence + assert_eq!("\x1b[31mfoo".escape_control(), "^[[31mfoo"); + + // Tabs count as control chars + assert_eq!("foo\tbar".escape_control(), "foo^Ibar"); + + // space is in control char range but should be excluded + assert_eq!("two words".escape_control(), "two words"); + + // unicode multi-byte characters + let s = "🐢\x1b[32m🦀"; + assert_eq!(s.escape_control(), s.replace("\x1b", "^[")); + } + + #[test] + fn escape_no_control_characters() { + use super::Escapable as _; + assert!(matches!( + "no control characters".escape_control(), + Cow::Borrowed(_) + )); + assert!(matches!( + "with \x1b[31mcontrol\x1b[0m characters".escape_control(), + Cow::Owned(_) + )); + } + + #[cfg(not(windows))] + #[test] + fn in_git_repo_regular() { + // regular git repo should resolve to the directory containing .git + let tmp = std::env::temp_dir().join("atuin-test-regular-git"); + let _ = std::fs::remove_dir_all(&tmp); + let subdir = tmp.join("src").join("deep"); + std::fs::create_dir_all(&subdir).unwrap(); + std::fs::create_dir_all(tmp.join(".git")).unwrap(); + + let result = in_git_repo(subdir.to_str().unwrap()); + assert_eq!(result, Some(tmp.clone())); + + std::fs::remove_dir_all(&tmp).unwrap(); + } + + #[cfg(not(windows))] + #[test] + fn in_git_repo_worktree_resolves_to_main_repo() { + // worktree .git is a file pointing back to the main repo — + // in_git_repo should follow it so all worktrees share a workspace + let tmp = std::env::temp_dir().join("atuin-test-worktree-git"); + let _ = std::fs::remove_dir_all(&tmp); + + // main repo at tmp/main with a real .git directory + let main_repo = tmp.join("main"); + let worktree_git_dir = main_repo.join(".git").join("worktrees").join("feature"); + std::fs::create_dir_all(&worktree_git_dir).unwrap(); + + // worktree at tmp/worktree with a .git file + let worktree = tmp.join("worktree"); + let worktree_subdir = worktree.join("src"); + std::fs::create_dir_all(&worktree_subdir).unwrap(); + std::fs::write( + worktree.join(".git"), + format!("gitdir: {}", worktree_git_dir.to_str().unwrap()), + ) + .unwrap(); + + // should resolve to the main repo root, not the worktree root + let result = in_git_repo(worktree_subdir.to_str().unwrap()); + assert_eq!(result, Some(main_repo.clone())); + + std::fs::remove_dir_all(&tmp).unwrap(); + } + + #[test] + fn dumb_random_test() { + // Obviously not a test of randomness, but make sure we haven't made some + // catastrophic error + + assert_ne!(crypto_random_string::<1>(), crypto_random_string::<1>()); + assert_ne!(crypto_random_string::<2>(), crypto_random_string::<2>()); + assert_ne!(crypto_random_string::<4>(), crypto_random_string::<4>()); + assert_ne!(crypto_random_string::<8>(), crypto_random_string::<8>()); + assert_ne!(crypto_random_string::<16>(), crypto_random_string::<16>()); + assert_ne!(crypto_random_string::<32>(), crypto_random_string::<32>()); + } +} |
