diff options
| author | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-11 00:54:30 +0200 |
|---|---|---|
| committer | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-11 00:54:30 +0200 |
| commit | 5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8 (patch) | |
| tree | c64baa8d5866c8e339eaf660dd3f94f30a3f7d8a /crates/turtle/src/atuin_client | |
| parent | chore: Somewhat simplify sync code (diff) | |
| download | atuin-5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8.zip | |
chore: Move everything into one big crate
That helps remove duplicated code and rustc/cargo will now also show
dead code correctly.
Diffstat (limited to 'crates/turtle/src/atuin_client')
39 files changed, 11991 insertions, 0 deletions
diff --git a/crates/turtle/src/atuin_client/api_client.rs b/crates/turtle/src/atuin_client/api_client.rs new file mode 100644 index 00000000..7955c2da --- /dev/null +++ b/crates/turtle/src/atuin_client/api_client.rs @@ -0,0 +1,438 @@ +use std::collections::HashMap; +use std::env; +use std::time::Duration; + +use eyre::{Result, bail, eyre}; +use reqwest::{ + Response, StatusCode, Url, + header::{AUTHORIZATION, HeaderMap, USER_AGENT}, +}; +use tracing::debug; + +use crate::atuin_common::{ + api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION}, + record::{EncryptedData, HostId, Record, RecordIdx}, + tls::ensure_crypto_provider, +}; +use crate::atuin_common::{ + api::{ + AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest, + ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse, StatusResponse, + SyncHistoryResponse, + }, + record::RecordStatus, +}; + +use semver::Version; +use time::OffsetDateTime; +use time::format_description::well_known::Rfc3339; + +use crate::atuin_client::{history::History, sync::hash_str, utils::get_host_user}; + +static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),); + +/// Authentication token for sync API requests. +/// +/// The sync API supports two authentication methods: +/// - `Bearer`: Hub API tokens (for users authenticated via Atuin Hub) +/// - `Token`: Legacy CLI session tokens (for users registered via CLI or self-hosted) +/// +/// When both are available, Hub tokens are preferred as they provide unified +/// authentication across CLI and Hub features. +#[derive(Debug, Clone)] +pub enum AuthToken { + /// Legacy CLI session token, used with "Token {token}" header + Token(String), +} + +impl AuthToken { + /// Format the token as an Authorization header value + fn to_header_value(&self) -> String { + match self { + AuthToken::Token(token) => format!("Token {token}"), + } + } +} + +pub struct Client<'a> { + sync_addr: &'a str, + client: reqwest::Client, +} + +fn make_url(address: &str, path: &str) -> Result<String> { + // `join()` expects a trailing `/` in order to join paths + // e.g. it treats `http://host:port/subdir` as a file called `subdir` + let address = if address.ends_with("/") { + address + } else { + &format!("{address}/") + }; + + // passing a path with a leading `/` will cause `join()` to replace the entire URL path + let path = path.strip_prefix("/").unwrap_or(path); + + let url = Url::parse(address) + .map(|url| url.join(path))? + .map_err(|_| eyre!("invalid address"))?; + + Ok(url.to_string()) +} + +pub async fn register( + address: &str, + username: &str, + email: &str, + password: &str, +) -> Result<RegisterResponse> { + ensure_crypto_provider(); + let mut map = HashMap::new(); + map.insert("username", username); + map.insert("email", email); + map.insert("password", password); + + let url = make_url(address, &format!("/user/{username}"))?; + let resp = reqwest::get(url).await?; + + if resp.status().is_success() { + bail!("username already in use"); + } + + let url = make_url(address, "/register")?; + let client = reqwest::Client::new(); + let resp = client + .post(url) + .header(USER_AGENT, APP_USER_AGENT) + .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION) + .json(&map) + .send() + .await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not register user due to version mismatch"); + } + + let session = resp.json::<RegisterResponse>().await?; + Ok(session) +} + +pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> { + ensure_crypto_provider(); + let url = make_url(address, "/login")?; + let client = reqwest::Client::new(); + + let resp = client + .post(url) + .header(USER_AGENT, APP_USER_AGENT) + .json(&req) + .send() + .await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("Could not login due to version mismatch"); + } + + let session = resp.json::<LoginResponse>().await?; + Ok(session) +} + +pub fn ensure_version(response: &Response) -> Result<bool> { + let version = response.headers().get(ATUIN_HEADER_VERSION); + + let version = if let Some(version) = version { + match version.to_str() { + Ok(v) => Version::parse(v), + Err(e) => bail!("failed to parse server version: {:?}", e), + } + } else { + bail!("Server not reporting its version: it is either too old or unhealthy"); + }?; + + // If the client is newer than the server + if version.major < ATUIN_VERSION.major { + println!( + "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin" + ); + println!("Client: {ATUIN_CARGO_VERSION}"); + println!("Server: {version}"); + + return Ok(false); + } + + Ok(true) +} + +async fn handle_resp_error(resp: Response) -> Result<Response> { + let status = resp.status(); + let url = resp.url().to_string(); + + if status == StatusCode::SERVICE_UNAVAILABLE { + bail!( + "Service unavailable: check https://status.atuin.sh (or get in touch with your host)" + ); + } + + if status == StatusCode::TOO_MANY_REQUESTS { + bail!("Rate limited; please wait before doing that again"); + } + + if !status.is_success() { + if let Ok(error) = resp.json::<ErrorResponse>().await { + let reason = error.reason; + + if status.is_client_error() { + bail!("Invalid request to the service at {url}, {status} - {reason}.") + } + + bail!( + "There was an error with the atuin sync service at {url}, server error {status}: {reason}.\nIf the problem persists, contact the host" + ) + } + + bail!( + "There was an error with the atuin sync service at {url}, Status {status:?}.\nIf the problem persists, contact the host" + ) + } + + Ok(resp) +} + +impl<'a> Client<'a> { + pub fn new( + sync_addr: &'a str, + auth: AuthToken, + connect_timeout: u64, + timeout: u64, + ) -> Result<Self> { + ensure_crypto_provider(); + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, auth.to_header_value().parse()?); + + // used for semver server check + headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); + + Ok(Client { + sync_addr, + client: reqwest::Client::builder() + .user_agent(APP_USER_AGENT) + .default_headers(headers) + .connect_timeout(Duration::new(connect_timeout, 0)) + .timeout(Duration::new(timeout, 0)) + .build()?, + }) + } + + pub async fn count(&self) -> Result<i64> { + let url = make_url(self.sync_addr, "/sync/count")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync due to version mismatch"); + } + + if resp.status() != StatusCode::OK { + bail!("failed to get count (are you logged in?)"); + } + + let count = resp.json::<CountResponse>().await?; + + Ok(count.count) + } + + pub async fn status(&self) -> Result<StatusResponse> { + let url = make_url(self.sync_addr, "/sync/status")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync due to version mismatch"); + } + + let status = resp.json::<StatusResponse>().await?; + + Ok(status) + } + + pub async fn me(&self) -> Result<MeResponse> { + let url = make_url(self.sync_addr, "/api/v0/me")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let status = resp.json::<MeResponse>().await?; + + Ok(status) + } + + pub async fn get_history( + &self, + sync_ts: OffsetDateTime, + history_ts: OffsetDateTime, + host: Option<String>, + ) -> Result<SyncHistoryResponse> { + let host = host.unwrap_or_else(|| hash_str(&get_host_user())); + + let url = make_url( + self.sync_addr, + &format!( + "/sync/history?sync_ts={}&history_ts={}&host={}", + urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()), + urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()), + host, + ), + )?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let history = resp.json::<SyncHistoryResponse>().await?; + Ok(history) + } + + pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { + let url = make_url(self.sync_addr, "/history")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.post(url).json(history).send().await?; + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn delete_history(&self, h: History) -> Result<()> { + let url = make_url(self.sync_addr, "/history")?; + let url = Url::parse(url.as_str())?; + + let resp = self + .client + .delete(url) + .json(&DeleteHistoryRequest { + client_id: h.id.to_string(), + }) + .send() + .await?; + + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn delete_store(&self) -> Result<()> { + let url = make_url(self.sync_addr, "/api/v0/store")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.delete(url).send().await?; + + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> { + let url = make_url(self.sync_addr, "/api/v0/record")?; + let url = Url::parse(url.as_str())?; + + debug!("uploading {} records to {url}", records.len()); + + let resp = self.client.post(url).json(records).send().await?; + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn next_records( + &self, + host: HostId, + tag: String, + start: RecordIdx, + count: u64, + ) -> Result<Vec<Record<EncryptedData>>> { + debug!("fetching record/s from host {}/{}/{}", host.0, tag, start); + + let url = make_url( + self.sync_addr, + &format!( + "/api/v0/record/next?host={}&tag={}&count={}&start={}", + host.0, tag, count, start + ), + )?; + + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let records = resp.json::<Vec<Record<EncryptedData>>>().await?; + + Ok(records) + } + + pub async fn record_status(&self) -> Result<RecordStatus> { + let url = make_url(self.sync_addr, "/api/v0/record")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync records due to version mismatch"); + } + + let index = resp.json().await?; + + debug!("got remote index {index:?}"); + + Ok(index) + } + + pub async fn delete(&self) -> Result<()> { + let url = make_url(self.sync_addr, "/account")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.delete(url).send().await?; + + if resp.status() == 403 { + bail!("invalid login details"); + } else if resp.status() == 200 { + Ok(()) + } else { + bail!("Unknown error"); + } + } + + pub async fn change_password( + &self, + current_password: String, + new_password: String, + ) -> Result<()> { + let url = make_url(self.sync_addr, "/account/password")?; + let url = Url::parse(url.as_str())?; + + let resp = self + .client + .patch(url) + .json(&ChangePasswordRequest { + current_password, + new_password, + }) + .send() + .await?; + + if resp.status() == 401 { + bail!("current password is incorrect") + } else if resp.status() == 403 { + bail!("invalid login details"); + } else if resp.status() == 200 { + Ok(()) + } else { + bail!("Unknown error"); + } + } +} diff --git a/crates/turtle/src/atuin_client/auth.rs b/crates/turtle/src/atuin_client/auth.rs new file mode 100644 index 00000000..b770c488 --- /dev/null +++ b/crates/turtle/src/atuin_client/auth.rs @@ -0,0 +1,223 @@ +use async_trait::async_trait; +use eyre::{Context, Result, bail}; +use reqwest::{Url, header::USER_AGENT}; + +use crate::{ + atuin_client::api_client, + atuin_common::{ + api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ChangePasswordRequest, LoginRequest}, + tls::ensure_crypto_provider, + }, +}; + +use crate::atuin_client::settings::Settings; + +static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION")); + +/// Result of an auth operation that may require 2FA. +pub enum AuthResponse { + /// Operation succeeded; for login/register, contains the session token. + /// `auth_type` indicates the kind of token: `Some("hub")` for Hub API + /// tokens (prefixed `atapi_`), `Some("cli")` for legacy CLI session + /// tokens. `None` when the server didn't include the field (old servers). + Success { + session: String, + auth_type: Option<String>, + }, + /// Two-factor authentication is required; the caller should prompt for a + /// TOTP code and retry with it. + TwoFactorRequired, +} + +/// Result of a mutating account operation that may require 2FA. +pub enum MutateResponse { + /// Operation completed successfully. + Success, + /// Two-factor authentication is required; the caller should prompt for a + /// TOTP code and retry. + TwoFactorRequired, +} + +/// Abstraction over the legacy (Rust sync server) and Hub auth APIs. +/// +/// CLI commands use this trait so they don't need to know which backend is +/// active — they just prompt for input and call these methods. +#[async_trait] +pub trait AuthClient: Send + Sync { + /// Log in with username + password, optionally providing a TOTP code. + async fn login(&self, username: &str, password: &str) -> Result<AuthResponse>; + + /// Register a new account. + async fn register(&self, username: &str, email: &str, password: &str) -> Result<AuthResponse>; + + /// Change the account password, optionally providing a TOTP code. + async fn change_password( + &self, + current_password: &str, + new_password: &str, + totp_code: Option<&str>, + ) -> Result<MutateResponse>; + + /// Delete the account, requiring the current password and optionally a TOTP code. + async fn delete_account( + &self, + password: &str, + totp_code: Option<&str>, + ) -> Result<MutateResponse>; +} + +/// Resolve the appropriate [`AuthClient`] for the current settings. +pub async fn auth_client(settings: &Settings) -> Box<dyn AuthClient> { + Box::new(LegacyAuthClient::new( + &settings.sync_address, + settings.session_token().await.ok(), + settings.network_connect_timeout, + settings.network_timeout, + )) as Box<dyn AuthClient> +} + +// --------------------------------------------------------------------------- +// Legacy backend — talks to the Rust sync server +// --------------------------------------------------------------------------- + +pub struct LegacyAuthClient { + address: String, + session_token: Option<String>, + connect_timeout: u64, + timeout: u64, +} + +impl LegacyAuthClient { + pub fn new( + address: &str, + session_token: Option<String>, + connect_timeout: u64, + timeout: u64, + ) -> Self { + Self { + address: address.to_string(), + session_token, + connect_timeout, + timeout, + } + } + + fn authenticated_client(&self) -> Result<reqwest::Client> { + let token = self + .session_token + .as_deref() + .ok_or_else(|| eyre::eyre!("Not logged in"))?; + + ensure_crypto_provider(); + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::AUTHORIZATION, + format!("Token {token}").parse()?, + ); + headers.insert(USER_AGENT, APP_USER_AGENT.parse()?); + headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); + + Ok(reqwest::Client::builder() + .default_headers(headers) + .connect_timeout(std::time::Duration::new(self.connect_timeout, 0)) + .timeout(std::time::Duration::new(self.timeout, 0)) + .build()?) + } +} + +#[async_trait] +impl AuthClient for LegacyAuthClient { + async fn login(&self, username: &str, password: &str) -> Result<AuthResponse> { + // The legacy server has no 2FA support; totp_code is ignored. + let resp = api_client::login( + &self.address, + LoginRequest { + username: username.to_string(), + password: password.to_string(), + }, + ) + .await?; + + Ok(AuthResponse::Success { + session: resp.session, + auth_type: resp.auth.or(Some("cli".into())), + }) + } + + async fn register(&self, username: &str, email: &str, password: &str) -> Result<AuthResponse> { + let resp = api_client::register(&self.address, username, email, password).await?; + Ok(AuthResponse::Success { + session: resp.session, + auth_type: resp.auth.or(Some("cli".into())), + }) + } + + async fn change_password( + &self, + current_password: &str, + new_password: &str, + _totp_code: Option<&str>, + ) -> Result<MutateResponse> { + let client = self.authenticated_client()?; + let url = make_url(&self.address, "/account/password")?; + + let resp = client + .patch(&url) + .json(&ChangePasswordRequest { + current_password: current_password.to_string(), + new_password: new_password.to_string(), + }) + .send() + .await?; + + match resp.status().as_u16() { + 200 => Ok(MutateResponse::Success), + 401 => bail!("current password is incorrect"), + 403 => bail!("invalid login details"), + _ => bail!("unknown error"), + } + } + + async fn delete_account( + &self, + password: &str, + _totp_code: Option<&str>, + ) -> Result<MutateResponse> { + let client = self.authenticated_client()?; + let url = make_url(&self.address, "/account")?; + + let resp = client + .delete(&url) + .json(&serde_json::json!({ "password": password })) + .send() + .await?; + + match resp.status().as_u16() { + 200 => Ok(MutateResponse::Success), + 401 => bail!("password is incorrect"), + 403 => bail!("invalid login details"), + _ => bail!("unknown error"), + } + } +} + +// --------------------------------------------------------------------------- +// Shared helpers +// --------------------------------------------------------------------------- + +fn make_url(address: &str, path: &str) -> Result<String> { + let address = if address.ends_with('/') { + address.to_string() + } else { + format!("{address}/") + }; + + let path = path.strip_prefix('/').unwrap_or(path); + + let url = Url::parse(&address) + .context("failed to parse server address")? + .join(path) + .context("failed to join URL path")?; + + Ok(url.to_string()) +} diff --git a/crates/turtle/src/atuin_client/database.rs b/crates/turtle/src/atuin_client/database.rs new file mode 100644 index 00000000..75b1200c --- /dev/null +++ b/crates/turtle/src/atuin_client/database.rs @@ -0,0 +1,1526 @@ +use std::{ + env, + path::{Path, PathBuf}, + str::FromStr, + time::Duration, +}; + +use crate::atuin_client::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS}; +use crate::atuin_common::utils; +use async_trait::async_trait; +use fs_err as fs; +use itertools::Itertools; +use rand::{Rng, distributions::Alphanumeric}; +use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote}; +use sqlx::{ + Result, Row, + sqlite::{ + SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, + SqliteSynchronous, + }, +}; +use time::OffsetDateTime; +use tracing::debug; +use uuid::Uuid; + +use crate::atuin_client::{ + history::{HistoryId, HistoryStats}, + utils::get_host_user, +}; + +use super::{ + history::History, + ordering, + settings::{FilterMode, SearchMode, Settings}, +}; + +#[derive(Clone)] +pub struct Context { + pub session: String, + pub cwd: String, + pub hostname: String, + pub host_id: String, + pub git_root: Option<PathBuf>, +} + +#[derive(Default, Clone)] +pub struct OptFilters { + pub exit: Option<i64>, + pub exclude_exit: Option<i64>, + pub cwd: Option<String>, + pub exclude_cwd: Option<String>, + pub before: Option<String>, + pub after: Option<String>, + pub limit: Option<i64>, + pub offset: Option<i64>, + pub reverse: bool, + pub include_duplicates: bool, + /// Author filter. Supports special values `$all-user` and `$all-agent`. + pub authors: Vec<String>, +} + +pub async fn current_context() -> eyre::Result<Context> { + let session = env::var("ATUIN_SESSION").map_err(|_| { + eyre::eyre!("Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell.") + })?; + let hostname = get_host_user(); + let cwd = utils::get_current_dir(); + let host_id = Settings::host_id().await?; + let git_root = utils::in_git_repo(cwd.as_str()); + + Ok(Context { + session, + hostname, + cwd, + git_root, + host_id: host_id.0.as_simple().to_string(), + }) +} + +impl Context { + pub fn from_history(entry: &History) -> Self { + Context { + session: entry.session.to_string(), + cwd: entry.cwd.to_string(), + hostname: entry.hostname.to_string(), + host_id: String::new(), + git_root: utils::in_git_repo(entry.cwd.as_str()), + } + } +} + +/// Each entry is OR'd: `$all-user` → NOT IN agents, `$all-agent` → IN agents, literal → exact match. +fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) { + let mut conditions: Vec<String> = Vec::new(); + let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", "); + let author_expr = "CASE \ + WHEN author IS NULL OR trim(author) = '' THEN \ + CASE \ + WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \ + ELSE hostname \ + END \ + ELSE author \ + END"; + + for author in authors { + match author.as_str() { + AUTHOR_FILTER_ALL_USER => { + conditions.push(format!("{author_expr} NOT IN ({agent_list})")); + } + AUTHOR_FILTER_ALL_AGENT => { + conditions.push(format!("{author_expr} IN ({agent_list})")); + } + literal => { + conditions.push(format!("{author_expr} = {}", quote(literal))); + } + } + } + + if !conditions.is_empty() { + sql.and_where(format!("({})", conditions.join(" OR "))); + } +} + +fn get_session_start_time(session_id: &str) -> Option<i64> { + if let Ok(uuid) = Uuid::parse_str(session_id) + && let Some(timestamp) = uuid.get_timestamp() + { + let (seconds, nanos) = timestamp.to_unix(); + return Some(seconds as i64 * 1_000_000_000 + nanos as i64); + } + None +} + +#[async_trait] +pub trait Database: Send + Sync + 'static { + async fn save(&self, h: &History) -> Result<()>; + async fn save_bulk(&self, h: &[History]) -> Result<()>; + + async fn load(&self, id: &str) -> Result<Option<History>>; + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option<usize>, + unique: bool, + include_deleted: bool, + ) -> Result<Vec<History>>; + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>; + + async fn update(&self, h: &History) -> Result<()>; + async fn history_count(&self, include_deleted: bool) -> Result<i64>; + + async fn last(&self) -> Result<Option<History>>; + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>; + + async fn delete(&self, h: History) -> Result<()>; + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>; + async fn deleted(&self) -> Result<Vec<History>>; + + // Yes I know, it's a lot. + // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. + // Been debating maybe a DSL for search? eg "before:time limit:1 the query" + #[expect(clippy::too_many_arguments)] + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result<Vec<History>>; + + async fn query_history(&self, query: &str) -> Result<Vec<History>>; + + async fn all_with_count(&self) -> Result<Vec<(History, i32)>>; + + fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged; + + async fn stats(&self, h: &History) -> Result<HistoryStats>; + + async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>; + + fn clone_boxed(&self) -> Box<dyn Database + 'static>; +} + +// Intended for use on a developer machine and not a sync server. +// TODO: implement IntoIterator +#[derive(Debug, Clone)] +pub struct Sqlite { + pub pool: SqlitePool, +} + +impl Sqlite { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + debug!("opening sqlite database at {path:?}"); + + if utils::broken_symlink(path) { + eprintln!( + "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + ); + std::process::exit(1); + } + + if !path.exists() + && let Some(dir) = path.parent() + { + fs::create_dir_all(dir)?; + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .optimize_on_close(true, None) + .synchronous(SqliteSynchronous::Normal) + .with_regexp() + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + Ok(Self { pool }) + } + + pub async fn sqlite_version(&self) -> Result<String> { + sqlx::query_scalar("SELECT sqlite_version()") + .fetch_one(&self.pool) + .await + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { + sqlx::query( + "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, author, intent, deleted_at) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.author.as_str()) + .bind(h.intent.as_deref()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + async fn delete_row_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + id: HistoryId, + ) -> Result<()> { + sqlx::query("delete from history where id = ?1") + .bind(id.0.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_history(row: SqliteRow) -> History { + let deleted_at: Option<i64> = row.get("deleted_at"); + let hostname: String = row.get("hostname"); + let author: Option<String> = row.try_get("author").ok().flatten(); + let author = author + .filter(|author| !author.trim().is_empty()) + .unwrap_or_else(|| History::author_from_hostname(hostname.as_str())); + let intent: Option<String> = row.try_get("intent").ok().flatten(); + let intent = intent.filter(|intent| !intent.trim().is_empty()); + + History::from_db() + .id(row.get("id")) + .timestamp( + OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128) + .unwrap(), + ) + .duration(row.get("duration")) + .exit(row.get("exit")) + .command(row.get("command")) + .cwd(row.get("cwd")) + .session(row.get("session")) + .hostname(hostname) + .author(author) + .intent(intent) + .deleted_at( + deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()), + ) + .build() + .into() + } +} + +#[async_trait] +impl Database for Sqlite { + async fn save(&self, h: &History) -> Result<()> { + debug!("saving history to sqlite"); + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, h).await?; + tx.commit().await?; + + Ok(()) + } + + async fn save_bulk(&self, h: &[History]) -> Result<()> { + debug!("saving history to sqlite"); + + let mut tx = self.pool.begin().await?; + + for i in h { + Self::save_raw(&mut tx, i).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn load(&self, id: &str) -> Result<Option<History>> { + debug!("loading history item {}", id); + + let res = sqlx::query("select * from history where id = ?1") + .bind(id) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn update(&self, h: &History) -> Result<()> { + debug!("updating sqlite history"); + + sqlx::query( + "update history + set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, author = ?9, intent = ?10, deleted_at = ?11 + where id = ?1", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.author.as_str()) + .bind(h.intent.as_deref()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&self.pool) + .await?; + + Ok(()) + } + + // make a unique list, that only shows the *newest* version of things + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option<usize>, + unique: bool, + include_deleted: bool, + ) -> Result<Vec<History>> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + query.field("*").order_desc("timestamp"); + if !include_deleted { + query.and_where_is_null("deleted_at"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + let session_start = get_session_start_time(&context.session); + + for filter in filters { + match filter { + FilterMode::Global => &mut query, + FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)), + FilterMode::Session => query.and_where_eq("session", quote(&context.session)), + FilterMode::SessionPreload => { + query.and_where_eq("session", quote(&context.session)); + if let Some(session_start) = session_start { + query.or_where_lt("timestamp", session_start); + } + &mut query + } + FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => query.and_where_like_left("cwd", &git_root), + }; + } + + if unique { + query.group_by("command").having("max(timestamp)"); + } + + if let Some(max) = max { + query.limit(max); + } + + let query = query.sql().expect("bug in list query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> { + debug!("listing history from {:?} to {:?}", from, to); + + let res = sqlx::query( + "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", + ) + .bind(from.unix_timestamp_nanos() as i64) + .bind(to.unix_timestamp_nanos() as i64) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn last(&self) -> Result<Option<History>> { + let res = sqlx::query( + "select * from history where duration >= 0 order by timestamp desc limit 1", + ) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> { + let res = sqlx::query( + "select * from history where timestamp < ?1 order by timestamp desc limit ?2", + ) + .bind(timestamp.unix_timestamp_nanos() as i64) + .bind(count) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn deleted(&self) -> Result<Vec<History>> { + let res = sqlx::query("select * from history where deleted_at is not null") + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn history_count(&self, include_deleted: bool) -> Result<i64> { + let query = if include_deleted { + "select count(1) from history" + } else { + "select count(1) from history where deleted_at is null" + }; + + let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?; + Ok(res.0) + } + + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result<Vec<History>> { + let mut sql = SqlBuilder::select_from("history"); + + if !filter_options.include_duplicates { + sql.group_by("command").having("max(timestamp)"); + } + + if let Some(limit) = filter_options.limit { + sql.limit(limit); + } + + if let Some(offset) = filter_options.offset { + sql.offset(offset); + } + + if filter_options.reverse { + sql.order_asc("timestamp"); + } else { + sql.order_desc("timestamp"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + let session_start = get_session_start_time(&context.session); + + match filter { + FilterMode::Global => &mut sql, + FilterMode::Host => { + sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase())) + } + FilterMode::Session => sql.and_where_eq("session", quote(&context.session)), + FilterMode::SessionPreload => { + sql.and_where_eq("session", quote(&context.session)); + if let Some(session_start) = session_start { + sql.or_where_lt("timestamp", session_start); + } + &mut sql + } + FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => sql.and_where_like_left("cwd", git_root), + }; + + let orig_query = query; + + let mut regexes = Vec::new(); + match search_mode { + SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")), + _ => { + let mut is_or = false; + for token in QueryTokenizer::new(query) { + // TODO smart case mode could be made configurable like in fzf + let (is_glob, glob) = if token.has_uppercase() { + (true, "*") + } else { + (false, "%") + }; + let param = match token { + QueryToken::Regex(r) => { + regexes.push(String::from(r)); + continue; + } + QueryToken::Or => { + if !is_or { + is_or = true; + continue; + } else { + format!("{glob}|{glob}") + } + } + QueryToken::MatchStart(term, _) => { + format!("{term}{glob}") + } + QueryToken::MatchEnd(term, _) => { + format!("{glob}{term}") + } + QueryToken::MatchFull(term, _) => { + format!("{glob}{term}{glob}") + } + QueryToken::Match(term, _) => { + if search_mode == SearchMode::FullText { + format!("{glob}{term}{glob}") + } else { + term.split("").join(glob) + } + } + }; + + sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or); + is_or = false; + } + + &mut sql + } + }; + + for regex in regexes { + sql.and_where("command regexp ?".bind(®ex)); + } + + filter_options + .exit + .map(|exit| sql.and_where_eq("exit", exit)); + + filter_options + .exclude_exit + .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit)); + + filter_options + .cwd + .map(|cwd| sql.and_where_eq("cwd", quote(cwd))); + + filter_options + .exclude_cwd + .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd))); + + filter_options.before.map(|before| { + interim::parse_date_string( + before.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + ) + .map(|before| { + sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64)) + }) + }); + + filter_options.after.map(|after| { + interim::parse_date_string( + after.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + ) + .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) + }); + + if !filter_options.authors.is_empty() { + apply_author_filter(&mut sql, &filter_options.authors); + } + + sql.and_where_is_null("deleted_at"); + + let query = sql.sql().expect("bug in search query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(ordering::reorder_fuzzy(search_mode, orig_query, res)) + } + + async fn query_history(&self, query: &str) -> Result<Vec<History>> { + let res = sqlx::query(query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn all_with_count(&self) -> Result<Vec<(History, i32)>> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + + query + .fields(&[ + "id", + "max(timestamp) as timestamp", + "max(duration) as duration", + "exit", + "command", + "deleted_at", + "null as author", + "null as intent", + "group_concat(cwd, ':') as cwd", + "group_concat(session) as session", + "group_concat(hostname, ',') as hostname", + "count(*) as count", + ]) + .group_by("command") + .group_by("exit") + .and_where("deleted_at is null") + .order_desc("timestamp"); + + let query = query.sql().expect("bug in list query. please report"); + + let res = sqlx::query(&query) + .map(|row: SqliteRow| { + let count: i32 = row.get("count"); + (Self::query_history(row), count) + }) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { + Paged::new(Box::new(self.clone()), page_size, include_deleted, unique) + } + + // deleted_at doesn't mean the actual time that the user deleted it, + // but the time that the system marks it as deleted + async fn delete(&self, mut h: History) -> Result<()> { + let now = OffsetDateTime::now_utc(); + h.command = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(); // overwrite with random string + h.deleted_at = Some(now); // delete it + + self.update(&h).await?; // save it + + Ok(()) + } + + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for id in ids { + Self::delete_row_raw(&mut tx, id.clone()).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn stats(&self, h: &History) -> Result<HistoryStats> { + // We select the previous in the session by time + let mut prev = SqlBuilder::select_from("history"); + prev.field("*") + .and_where("timestamp < ?1") + .and_where("session = ?2") + .order_by("timestamp", true) + .limit(1); + + let mut next = SqlBuilder::select_from("history"); + next.field("*") + .and_where("timestamp > ?1") + .and_where("session = ?2") + .order_by("timestamp", false) + .limit(1); + + let mut total = SqlBuilder::select_from("history"); + total.field("count(1)").and_where("command = ?1"); + + let mut average = SqlBuilder::select_from("history"); + average.field("avg(duration)").and_where("command = ?1"); + + let mut exits = SqlBuilder::select_from("history"); + exits + .fields(&["exit", "count(1) as count"]) + .and_where("command = ?1") + .group_by("exit"); + + // rewrite the following with sqlbuilder + let mut day_of_week = SqlBuilder::select_from("history"); + day_of_week + .fields(&[ + "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week", + "count(1) as count", + ]) + .and_where("command = ?1") + .group_by("day_of_week"); + + // Intentionally format the string with 01 hardcoded. We want the average runtime for the + // _entire month_, but will later parse it as a datetime for sorting + // Sqlite has no datetime so we cannot do it there, and otherwise sorting will just be a + // string sort, which won't be correct. + let mut duration_over_time = SqlBuilder::select_from("history"); + duration_over_time + .fields(&[ + "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year", + "avg(duration) as duration", + ]) + .and_where("command = ?1") + .group_by("month_year") + .having("duration > 0"); + + let prev = prev.sql().expect("issue in stats previous query"); + let next = next.sql().expect("issue in stats next query"); + let total = total.sql().expect("issue in stats average query"); + let average = average.sql().expect("issue in stats previous query"); + let exits = exits.sql().expect("issue in stats exits query"); + let day_of_week = day_of_week.sql().expect("issue in stats day of week query"); + let duration_over_time = duration_over_time + .sql() + .expect("issue in stats duration over time query"); + + let prev = sqlx::query(&prev) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(&h.session) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + let next = sqlx::query(&next) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(&h.session) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + let total: (i64,) = sqlx::query_as(&total) + .bind(&h.command) + .fetch_one(&self.pool) + .await?; + + let average: (f64,) = sqlx::query_as(&average) + .bind(&h.command) + .fetch_one(&self.pool) + .await?; + + let exits: Vec<(i64, i64)> = sqlx::query_as(&exits) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let duration_over_time = duration_over_time + .iter() + .map(|f| (f.0.clone(), f.1.round() as i64)) + .collect(); + + Ok(HistoryStats { + next, + previous: prev, + total: total.0 as u64, + average_duration: average.0 as u64, + exits, + day_of_week, + duration_over_time, + }) + } + + async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> { + let res = sqlx::query( + "SELECT * FROM ( + SELECT *, ROW_NUMBER() + OVER (PARTITION BY command, cwd, hostname ORDER BY timestamp DESC) + AS rn + FROM history + ) sub + WHERE rn > ?1 and timestamp < ?2; + ", + ) + .bind(dupkeep) + .bind(before) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + fn clone_boxed(&self) -> Box<dyn Database + 'static> { + Box::new(self.clone()) + } +} + +pub struct Paged { + database: Box<dyn Database + 'static>, + page_size: usize, + last_id: Option<String>, + include_deleted: bool, + unique: bool, +} + +impl Paged { + pub fn new( + database: Box<dyn Database + 'static>, + page_size: usize, + include_deleted: bool, + unique: bool, + ) -> Self { + Self { + database, + page_size, + last_id: None, + include_deleted, + unique, + } + } + + pub async fn next(&mut self) -> Result<Option<Vec<History>>> { + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + + query.field("*").order_desc("id"); + + if !self.include_deleted { + query.and_where_is_null("deleted_at"); + } + + if self.unique { + // We want to deduplicate on command, but the user can search via cwd, hostname, and session. + // Without those fields, filter modes won't work right. With those fields, we get duplicates. + // This must be handled upstream. + query + .group_by("command, cwd, hostname, session") + .having("max(timestamp)"); + } + + query.limit(self.page_size); + + if let Some(last_id) = &self.last_id { + query.and_where_lt("id", quote(last_id)); + } + + let query = query.sql().expect("bug in list query. please report"); + let res = self.database.query_history(&query).await?; + + if res.is_empty() { + Ok(None) + } else { + self.last_id = Some(res.last().unwrap().id.0.clone()); + Ok(Some(res)) + } + } +} + +trait SqlBuilderExt { + fn fuzzy_condition<S: ToString, T: ToString>( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self; +} + +impl SqlBuilderExt for SqlBuilder { + /// adapted from the sql-builder *like functions + fn fuzzy_condition<S: ToString, T: ToString>( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self { + let mut cond = field.to_string(); + if inverse { + cond.push_str(" NOT"); + } + if glob { + cond.push_str(" GLOB '"); + } else { + cond.push_str(" LIKE '"); + } + cond.push_str(&esc(mask.to_string())); + cond.push('\''); + if is_or { + self.or_where(cond) + } else { + self.and_where(cond) + } + } +} + +#[cfg(test)] +mod test { + use crate::atuin_client::settings::test_local_timeout; + + use super::*; + use std::time::{Duration, Instant}; + + async fn assert_search_eq( + db: &impl Database, + mode: SearchMode, + filter_mode: FilterMode, + query: &str, + expected: usize, + ) -> Result<Vec<History>> { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), + git_root: None, + }; + + let results = db + .search( + mode, + filter_mode, + &context, + query, + OptFilters { + ..Default::default() + }, + ) + .await?; + + assert_eq!( + results.len(), + expected, + "query \"{}\", commands: {:?}", + query, + results.iter().map(|a| &a.command).collect::<Vec<&String>>() + ); + Ok(results) + } + + async fn assert_search_commands( + db: &impl Database, + mode: SearchMode, + filter_mode: FilterMode, + query: &str, + expected_commands: Vec<&str>, + ) { + let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len()) + .await + .unwrap(); + let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect(); + assert_eq!(commands, expected_commands); + } + + async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> { + let mut captured: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(cmd) + .cwd("/home/ellie") + .build() + .into(); + + captured.exit = 0; + captured.duration = 1; + captured.session = "beep boop".to_string(); + captured.hostname = "booop".to_string(); + + db.save(&captured).await + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_prefix() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_fulltext() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0) + .await + .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / ie$", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / !ie", + 0, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "meow r/ls/", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home//", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home///", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/home.*e", + 1, + ) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_fuzzy() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + new_history_item(&mut db, "ls /home/frank").await.unwrap(); + new_history_item(&mut db, "cd /home/Ellie").await.unwrap(); + new_history_item(&mut db, "/home/ellie/.bin/rustup") + .await + .unwrap(); + + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4) + .await + .unwrap(); + + // single term operators + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2) + .await + .unwrap(); + + // multiple terms + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup", + 2, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup 'ls", + 1, + ) + .await + .unwrap(); + + // case matching + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1) + .await + .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_reordered_fuzzy() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + // test ordering of results: we should choose the first, even though it happened longer ago. + + new_history_item(&mut db, "curl").await.unwrap(); + new_history_item(&mut db, "corburl").await.unwrap(); + + // if fuzzy reordering is on, it should come back in a more sensible order + assert_search_commands( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "curl", + vec!["curl", "corburl"], + ) + .await; + + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_basic() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add 5 history items + for i in 0..5 { + new_history_item(&mut db, &format!("command{}", i)) + .await + .unwrap(); + } + + // Create a paged iterator with page_size of 2 + let mut paged = db.all_paged(2, false, false); + + // First page should have 2 items + let page1 = paged.next().await.unwrap(); + assert!(page1.is_some()); + assert_eq!(page1.unwrap().len(), 2); + + // Second page should have 2 items + let page2 = paged.next().await.unwrap(); + assert!(page2.is_some()); + assert_eq!(page2.unwrap().len(), 2); + + // Third page should have 1 item + let page3 = paged.next().await.unwrap(); + assert!(page3.is_some()); + assert_eq!(page3.unwrap().len(), 1); + + // Fourth page should be None (exhausted) + let page4 = paged.next().await.unwrap(); + assert!(page4.is_none()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_empty() { + let db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Create a paged iterator on empty database + let mut paged = db.all_paged(10, false, false); + + // Should return None immediately + let page = paged.next().await.unwrap(); + assert!(page.is_none()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_unique() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add duplicate commands + new_history_item(&mut db, "duplicate").await.unwrap(); + new_history_item(&mut db, "duplicate").await.unwrap(); + new_history_item(&mut db, "unique1").await.unwrap(); + new_history_item(&mut db, "unique2").await.unwrap(); + + // Without unique flag - should get all 4 + let mut paged = db.all_paged(10, false, false); + let page = paged.next().await.unwrap().unwrap(); + assert_eq!(page.len(), 4); + + // With unique flag - should get 3 (duplicates collapsed) + let mut paged_unique = db.all_paged(10, false, true); + let page_unique = paged_unique.next().await.unwrap().unwrap(); + assert_eq!(page_unique.len(), 3); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_include_deleted() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add items + new_history_item(&mut db, "keep1").await.unwrap(); + new_history_item(&mut db, "keep2").await.unwrap(); + new_history_item(&mut db, "delete_me").await.unwrap(); + + // Delete one item + let all = db + .list( + &[], + &Context { + hostname: "".to_string(), + session: "".to_string(), + cwd: "".to_string(), + host_id: "".to_string(), + git_root: None, + }, + None, + false, + false, + ) + .await + .unwrap(); + + let to_delete = all + .iter() + .find(|h| h.command == "delete_me") + .unwrap() + .clone(); + db.delete(to_delete).await.unwrap(); + + // Without include_deleted - should get 2 + let mut paged = db.all_paged(10, false, false); + let page = paged.next().await.unwrap().unwrap(); + assert_eq!(page.len(), 2); + + // With include_deleted - should get 3 + let mut paged_deleted = db.all_paged(10, true, false); + let page_deleted = paged_deleted.next().await.unwrap().unwrap(); + assert_eq!(page_deleted.len(), 3); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_bench_dupes() { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), + git_root: None, + }; + + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + for _i in 1..10000 { + new_history_item(&mut db, "i am a duplicated command") + .await + .unwrap(); + } + let start = Instant::now(); + let _results = db + .search( + SearchMode::Fuzzy, + FilterMode::Global, + &context, + "", + OptFilters { + ..Default::default() + }, + ) + .await + .unwrap(); + let duration = start.elapsed(); + + assert!(duration < Duration::from_secs(15)); + } +} + +pub struct QueryTokenizer<'a> { + query: &'a str, + last_pos: usize, +} + +pub enum QueryToken<'a> { + Match(&'a str, bool), + MatchStart(&'a str, bool), + MatchEnd(&'a str, bool), + MatchFull(&'a str, bool), + Or, + Regex(&'a str), +} + +impl<'a> QueryToken<'a> { + pub fn has_uppercase(&self) -> bool { + match self { + Self::Match(term, _) + | Self::MatchStart(term, _) + | Self::MatchEnd(term, _) + | Self::MatchFull(term, _) => term.contains(char::is_uppercase), + _ => false, + } + } + + pub fn is_inverse(&self) -> bool { + match self { + Self::Match(_, inv) + | Self::MatchStart(_, inv) + | Self::MatchEnd(_, inv) + | Self::MatchFull(_, inv) => *inv, + _ => false, + } + } +} + +impl<'a> QueryTokenizer<'a> { + pub fn new(query: &'a str) -> Self { + Self { query, last_pos: 0 } + } +} + +impl<'a> Iterator for QueryTokenizer<'a> { + type Item = QueryToken<'a>; + fn next(&mut self) -> Option<Self::Item> { + let remaining = &self.query[self.last_pos..]; + if remaining.is_empty() { + return None; + } + + if let Some(remaining) = remaining.strip_prefix("r/") { + let (regex, next_pos) = if let Some(end) = remaining.find("/ ") { + (&remaining[..end], self.last_pos + 2 + end + 2) + } else if let Some(remaining) = remaining.strip_suffix('/') { + (remaining, self.query.len()) + } else { + (remaining, self.query.len()) + }; + self.last_pos = next_pos; + Some(QueryToken::Regex(regex)) + } else { + let (mut part, next_pos) = if let Some(sp) = remaining.find(' ') { + (&remaining[..sp], self.last_pos + sp + 1) + } else { + (remaining, self.query.len()) + }; + self.last_pos = next_pos; + + if part == "|" { + return Some(QueryToken::Or); + } + + let mut is_inverse = false; + if let Some(s) = part.strip_prefix('!') { + part = s; + is_inverse = true; + } + let token = if let Some(s) = part.strip_prefix('^') { + QueryToken::MatchStart(s, is_inverse) + } else if let Some(s) = part.strip_suffix('$') { + QueryToken::MatchEnd(s, is_inverse) + } else if let Some(s) = part.strip_prefix('\'') { + QueryToken::MatchFull(s, is_inverse) + } else { + QueryToken::Match(part, is_inverse) + }; + Some(token) + } + } +} diff --git a/crates/turtle/src/atuin_client/distro.rs b/crates/turtle/src/atuin_client/distro.rs new file mode 100644 index 00000000..dead8355 --- /dev/null +++ b/crates/turtle/src/atuin_client/distro.rs @@ -0,0 +1,89 @@ +use std::process::Command; + +/// Detect the Linux distribution from the system, +/// using system-specific release files and falling +/// back to lsb_release. +pub fn detect_linux_distribution() -> String { + detect_from_os_release() + .or_else(detect_from_debian_version) + .or_else(detect_from_centos_release) + .or_else(detect_from_redhat_release) + .or_else(detect_from_fedora_release) + .or_else(detect_from_arch_release) + .or_else(detect_from_alpine_release) + .or_else(detect_from_suse_release) + .or_else(detect_from_lsb_release) + .unwrap_or_else(|| "Unknown".to_string()) +} + +fn detect_from_os_release() -> Option<String> { + let content = std::fs::read_to_string("/etc/os-release").ok()?; + + content + .lines() + .find(|l| l.starts_with("PRETTY_NAME=")) + .and_then(|l| l.split_once('=').map(|s| s.1)) + .map(|s| s.trim_matches('"').to_string()) +} + +fn detect_from_debian_version() -> Option<String> { + std::fs::read_to_string("/etc/debian_version") + .ok() + .map(|v| format!("Debian {}", v.trim())) +} + +fn detect_from_centos_release() -> Option<String> { + std::fs::read_to_string("/etc/centos-release") + .ok() + .map(|v| v.trim().to_string()) +} + +fn detect_from_redhat_release() -> Option<String> { + std::fs::read_to_string("/etc/redhat-release") + .ok() + .map(|v| v.trim().to_string()) +} + +fn detect_from_fedora_release() -> Option<String> { + std::fs::read_to_string("/etc/fedora-release") + .ok() + .map(|v| v.trim().to_string()) +} + +fn detect_from_arch_release() -> Option<String> { + std::fs::read_to_string("/etc/arch-release") + .ok() + .filter(|v| !v.trim().is_empty()) + .map(|_| "Arch Linux".to_string()) +} + +fn detect_from_alpine_release() -> Option<String> { + std::fs::read_to_string("/etc/alpine-release") + .ok() + .map(|v| format!("Alpine {}", v.trim())) +} + +fn detect_from_suse_release() -> Option<String> { + std::fs::read_to_string("/etc/SuSE-release") + .ok() + .and_then(|content| content.lines().next().map(|l| l.trim().to_string())) +} + +fn detect_from_lsb_release() -> Option<String> { + let output = Command::new("lsb_release").arg("-a").output().ok()?; + + if !output.status.success() { + return None; + } + + let output = String::from_utf8(output.stdout).ok()?; + linux_distro_from_lsb_release(&output) +} + +fn linux_distro_from_lsb_release(output: &str) -> Option<String> { + output + .lines() + .find(|line| line.starts_with("Description:")) + .and_then(|line| line.split_once(':').map(|s| s.1)) + .map(|s| s.trim().to_string()) +} diff --git a/crates/turtle/src/atuin_client/encryption.rs b/crates/turtle/src/atuin_client/encryption.rs new file mode 100644 index 00000000..20a0cd90 --- /dev/null +++ b/crates/turtle/src/atuin_client/encryption.rs @@ -0,0 +1,440 @@ +// The general idea is that we NEVER send cleartext history to the server +// This way the odds of anything private ending up where it should not are +// very low +// The server authenticates via the usual username and password. This has +// nothing to do with the encryption, and is purely authentication! The client +// generates its own secret key, and encrypts all shell history with libsodium's +// secretbox. The data is then sent to the server, where it is stored. All +// clients must share the secret in order to be able to sync, as it is needed +// to decrypt + +use std::{io::prelude::*, path::PathBuf}; + +use base64::prelude::{BASE64_STANDARD, Engine}; +pub use crypto_secretbox::Key; +use crypto_secretbox::{ + AeadCore, AeadInPlace, KeyInit, XSalsa20Poly1305, + aead::{Nonce, OsRng}, +}; +use eyre::{Context, Result, bail, ensure, eyre}; +use fs_err as fs; +use rmp::{Marker, decode::Bytes}; +use serde::{Deserialize, Serialize}; +use time::{OffsetDateTime, format_description::well_known::Rfc3339, macros::format_description}; + +use crate::atuin_client::{history::History, settings::Settings}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct EncryptedHistory { + pub ciphertext: Vec<u8>, + pub nonce: Nonce<XSalsa20Poly1305>, +} + +pub fn generate_encoded_key() -> Result<(Key, String)> { + let key = XSalsa20Poly1305::generate_key(&mut OsRng); + let encoded = encode_key(&key)?; + + Ok((key, encoded)) +} + +pub fn new_key(settings: &Settings) -> Result<Key> { + let path = settings.key_path.as_str(); + let path = PathBuf::from(path); + + if path.exists() { + bail!("key already exists! cannot overwrite"); + } + + let (key, encoded) = generate_encoded_key()?; + + let mut file = fs::File::create(path)?; + file.write_all(encoded.as_bytes())?; + + Ok(key) +} + +// Loads the secret key, will create + save if it doesn't exist +pub fn load_key(settings: &Settings) -> Result<Key> { + let path = settings.key_path.as_str(); + + let key = if PathBuf::from(path).exists() { + let key = fs_err::read_to_string(path)?; + decode_key(key)? + } else { + new_key(settings)? + }; + + Ok(key) +} + +pub fn encode_key(key: &Key) -> Result<String> { + let mut buf = vec![]; + rmp::encode::write_array_len(&mut buf, key.len() as u32) + .wrap_err("could not encode key to message pack")?; + for b in key { + rmp::encode::write_uint(&mut buf, *b as u64) + .wrap_err("could not encode key to message pack")?; + } + let buf = BASE64_STANDARD.encode(buf); + + Ok(buf) +} + +pub fn decode_key(key: String) -> Result<Key> { + use rmp::decode; + + let buf = BASE64_STANDARD + .decode(key.trim_end()) + .wrap_err("encryption key is not a valid base64 encoding")?; + + // old code wrote the key as a fixed length array of 32 bytes + // new code writes the key with a length prefix + match <[u8; 32]>::try_from(&*buf) { + Ok(key) => Ok(key.into()), + Err(_) => { + let mut bytes = rmp::decode::Bytes::new(&buf); + + match Marker::from_u8(buf[0]) { + Marker::Bin8 => { + let len = decode::read_bin_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + ensure!(len == 32, "encryption key is not the correct size"); + let key = <[u8; 32]>::try_from(bytes.remaining_slice()) + .context("could not decode encryption key")?; + Ok(key.into()) + } + Marker::Array16 => { + let len = decode::read_array_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + ensure!(len == 32, "encryption key is not the correct size"); + + let mut key = Key::default(); + for i in &mut key { + *i = rmp::decode::read_int(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + } + Ok(key) + } + _ => bail!("could not decode encryption key"), + } + } + } +} + +pub fn encrypt(history: &History, key: &Key) -> Result<EncryptedHistory> { + // serialize with msgpack + let mut buf = encode(history)?; + + let nonce = XSalsa20Poly1305::generate_nonce(&mut OsRng); + XSalsa20Poly1305::new(key) + .encrypt_in_place(&nonce, &[], &mut buf) + .map_err(|_| eyre!("could not encrypt"))?; + + Ok(EncryptedHistory { + ciphertext: buf, + nonce, + }) +} + +pub fn decrypt(mut encrypted_history: EncryptedHistory, key: &Key) -> Result<History> { + XSalsa20Poly1305::new(key) + .decrypt_in_place( + &encrypted_history.nonce, + &[], + &mut encrypted_history.ciphertext, + ) + .map_err(|_| eyre!("could not decrypt history"))?; + let plaintext = encrypted_history.ciphertext; + + let history = decode(&plaintext)?; + + Ok(history) +} + +fn format_rfc3339(ts: OffsetDateTime) -> Result<String> { + // horrible hack. chrono AutoSI limits to 0, 3, 6, or 9 decimal places for nanoseconds. + // time does not have this functionality. + static PARTIAL_RFC3339_0: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z"); + static PARTIAL_RFC3339_3: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"); + static PARTIAL_RFC3339_6: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:6]Z"); + static PARTIAL_RFC3339_9: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:9]Z"); + + let fmt = match ts.nanosecond() { + 0 => PARTIAL_RFC3339_0, + ns if ns % 1_000_000 == 0 => PARTIAL_RFC3339_3, + ns if ns % 1_000 == 0 => PARTIAL_RFC3339_6, + _ => PARTIAL_RFC3339_9, + }; + + Ok(ts.format(fmt)?) +} + +fn encode(h: &History) -> Result<Vec<u8>> { + use rmp::encode; + + let mut output = vec![]; + // INFO: ensure this is updated when adding new fields + encode::write_array_len(&mut output, 9)?; + + encode::write_str(&mut output, &h.id.0)?; + encode::write_str(&mut output, &(format_rfc3339(h.timestamp)?))?; + encode::write_sint(&mut output, h.duration)?; + encode::write_sint(&mut output, h.exit)?; + encode::write_str(&mut output, &h.command)?; + encode::write_str(&mut output, &h.cwd)?; + encode::write_str(&mut output, &h.session)?; + encode::write_str(&mut output, &h.hostname)?; + match h.deleted_at { + Some(d) => encode::write_str(&mut output, &format_rfc3339(d)?)?, + None => encode::write_nil(&mut output)?, + } + + Ok(output) +} + +fn decode(bytes: &[u8]) -> Result<History> { + use rmp::decode::{self, DecodeStringError}; + + let mut bytes = Bytes::new(bytes); + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + if nfields < 8 { + bail!("malformed decrypted history") + } + if nfields > 9 { + bail!("cannot decrypt history from a newer version of atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (timestamp, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + // if we have more fields, try and get the deleted_at + let mut deleted_at = None; + let mut bytes = bytes; + if nfields > 8 { + bytes = match decode::read_str_from_slice(bytes) { + Ok((d, b)) => { + deleted_at = Some(d); + b + } + // we accept null here + Err(DecodeStringError::TypeMismatch(Marker::Null)) => { + // consume the null marker + let mut c = Bytes::new(bytes); + decode::read_nil(&mut c).map_err(error_report)?; + c.remaining_slice() + } + Err(err) => return Err(error_report(err)), + }; + } + + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::parse(timestamp, &Rfc3339)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + author: History::author_from_hostname(hostname), + intent: None, + deleted_at: deleted_at + .map(|t| OffsetDateTime::parse(t, &Rfc3339)) + .transpose()?, + }) +} + +fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") +} + +#[cfg(test)] +mod test { + use crypto_secretbox::{KeyInit, XSalsa20Poly1305, aead::OsRng}; + use pretty_assertions::assert_eq; + use time::{OffsetDateTime, macros::datetime}; + + use crate::history::History; + + use super::{decode, decrypt, encode, encrypt}; + + #[test] + fn test_encrypt_decrypt() { + let key1 = XSalsa20Poly1305::generate_key(&mut OsRng); + let key2 = XSalsa20Poly1305::generate_key(&mut OsRng); + + let history = History::from_db() + .id("1".into()) + .timestamp(OffsetDateTime::now_utc()) + .command("ls".into()) + .cwd("/home/ellie".into()) + .exit(0) + .duration(1) + .session("beep boop".into()) + .hostname("booop".into()) + .author("booop".into()) + .intent(None) + .deleted_at(None) + .build() + .into(); + + let e1 = encrypt(&history, &key1).unwrap(); + let e2 = encrypt(&history, &key2).unwrap(); + + assert_ne!(e1.ciphertext, e2.ciphertext); + assert_ne!(e1.nonce, e2.nonce); + + // test decryption works + // this should pass + match decrypt(e1, &key1) { + Err(e) => panic!("failed to decrypt, got {e}"), + Ok(h) => assert_eq!(h, history), + }; + + // this should err + let _ = decrypt(e2, &key1).expect_err("expected an error decrypting with invalid key"); + } + + #[test] + fn test_decode() { + let bytes = [ + 0x99, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, + 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, + 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, + 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, + 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, + 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, + 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, + 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, + 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, + 108, 117, 100, 103, 97, 116, 101, 192, + ]; + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let h = decode(&bytes).unwrap(); + assert_eq!(history, h); + + let b = encode(&h).unwrap(); + assert_eq!(&bytes, &*b); + } + + #[test] + fn test_decode_deleted() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: Some(datetime!(2023-05-28 18:35:40.633872 +00:00)), + }; + + let b = encode(&history).unwrap(); + let h = decode(&b).unwrap(); + assert_eq!(history, h); + } + + #[test] + fn test_decode_old() { + let bytes = [ + 0x98, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, + 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, + 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, + 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, + 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, + 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, + 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, + 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, + 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, + 108, 117, 100, 103, 97, 116, 101, + ]; + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let h = decode(&bytes).unwrap(); + assert_eq!(history, h); + } + + #[test] + fn key_encodings() { + use super::{Key, decode_key, encode_key}; + + // a history of our key encodings. + // v11.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v12.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v13.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v13.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v14.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v14.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // c7d89c1 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/805) + // b53ca35 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/974) + // v15.0.0 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== + // b8b57c8 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== (https://github.com/ellie/atuin/pull/1057) + // 8c94d79 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/1089) + + let key = Key::from([ + 27, 91, 42, 91, 210, 107, 9, 216, 170, 190, 242, 62, 6, 84, 69, 148, 148, 53, 251, 117, + 226, 167, 173, 52, 82, 34, 138, 110, 169, 124, 92, 229, + ]); + + assert_eq!( + encode_key(&key).unwrap(), + "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==" + ); + + // key encodings we have to support + let valid_encodings = [ + "xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q==", + "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==", + ]; + + for k in valid_encodings { + assert_eq!(decode_key(k.to_owned()).expect(k), key); + } + } +} diff --git a/crates/turtle/src/atuin_client/history.rs b/crates/turtle/src/atuin_client/history.rs new file mode 100644 index 00000000..cef65115 --- /dev/null +++ b/crates/turtle/src/atuin_client/history.rs @@ -0,0 +1,756 @@ +use core::fmt::Formatter; +use rmp::decode::DecodeStringError; +use rmp::decode::ValueReadError; +use rmp::{Marker, decode::Bytes}; +use std::env; +use std::fmt::Display; + +use crate::atuin_common::record::DecryptedData; +use crate::atuin_common::utils::uuid_v7; + +use eyre::{Result, bail, eyre}; + +use crate::atuin_client::secrets::SECRET_PATTERNS_RE; +use crate::atuin_client::settings::Settings; +use crate::atuin_client::utils::get_host_user; +use time::OffsetDateTime; + +mod builder; +pub mod store; + +/// Known AI agent author values. Used to expand `$all-agent` and `$all-user` filters. +pub const KNOWN_AGENTS: &[&str] = &["claude-code", "codex", "copilot", "pi"]; +pub const AUTHOR_FILTER_ALL_USER: &str = "$all-user"; +pub const AUTHOR_FILTER_ALL_AGENT: &str = "$all-agent"; + +pub fn is_known_agent(author: &str) -> bool { + KNOWN_AGENTS.contains(&author) +} + +pub fn author_matches_filters(author: &str, filters: &[String]) -> bool { + filters.is_empty() + || filters.iter().any(|filter| match filter.as_str() { + AUTHOR_FILTER_ALL_USER => !is_known_agent(author), + AUTHOR_FILTER_ALL_AGENT => is_known_agent(author), + literal => author == literal, + }) +} + +pub(crate) const HISTORY_VERSION_V0: &str = "v0"; +pub(crate) const HISTORY_VERSION_V1: &str = "v1"; +const HISTORY_RECORD_VERSION_V0: u16 = 0; +const HISTORY_RECORD_VERSION_V1: u16 = 1; +pub(crate) const HISTORY_VERSION: &str = HISTORY_VERSION_V1; +pub const HISTORY_TAG: &str = "history"; +const HISTORY_AUTHOR_ENV: &str = "ATUIN_HISTORY_AUTHOR"; +const HISTORY_INTENT_ENV: &str = "ATUIN_HISTORY_INTENT"; + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct HistoryId(pub String); + +impl Display for HistoryId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<String> for HistoryId { + fn from(s: String) -> Self { + Self(s) + } +} + +/// Client-side history entry. +/// +/// Client stores data unencrypted, and only encrypts it before sending to the server. +/// +/// To create a new history entry, use one of the builders: +/// - [`History::import()`] to import an entry from the shell history file +/// - [`History::capture()`] to capture an entry via hook +/// - [`History::from_db()`] to create an instance from the database entry +// +// ## Implementation Notes +// +// New fields must be added to `History::{serialize,deserialize}` in a backwards +// compatible way (sensible defaults and careful `nfields` handling). +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +pub struct History { + /// A client-generated ID, used to identify the entry when syncing. + /// + /// Stored as `client_id` in the database. + pub id: HistoryId, + /// When the command was run. + pub timestamp: OffsetDateTime, + /// How long the command took to run. + pub duration: i64, + /// The exit code of the command. + pub exit: i64, + /// The command that was run. + pub command: String, + /// The current working directory when the command was run. + pub cwd: String, + /// The session ID, associated with a terminal session. + pub session: String, + /// The hostname of the machine the command was run on. + pub hostname: String, + /// Who wrote this command (human user or automation/agent identity). + pub author: String, + /// Optional rationale for why the command was executed. + pub intent: Option<String>, + /// Timestamp, which is set when the entry is deleted, allowing a soft delete. + pub deleted_at: Option<OffsetDateTime>, +} + +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +pub struct HistoryStats { + /// The command that was ran after this one in the session + pub next: Option<History>, + /// + /// The command that was ran before this one in the session + pub previous: Option<History>, + + /// How many times has this command been ran? + pub total: u64, + + pub average_duration: u64, + + pub exits: Vec<(i64, i64)>, + + pub day_of_week: Vec<(String, i64)>, + + pub duration_over_time: Vec<(String, i64)>, +} + +impl History { + pub(crate) fn author_from_hostname(hostname: &str) -> String { + hostname + .split_once(':') + .map_or_else(|| hostname.to_owned(), |(_, user)| user.to_owned()) + } + + fn normalize_optional_field(field: Option<String>) -> Option<String> { + field.and_then(|value| { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_owned()) + } + }) + } + + #[expect(clippy::too_many_arguments)] + fn new( + timestamp: OffsetDateTime, + command: String, + cwd: String, + exit: i64, + duration: i64, + session: Option<String>, + hostname: Option<String>, + author: Option<String>, + intent: Option<String>, + deleted_at: Option<OffsetDateTime>, + ) -> Self { + let session = session + .or_else(|| env::var("ATUIN_SESSION").ok()) + .unwrap_or_else(|| uuid_v7().as_simple().to_string()); + let hostname = hostname.unwrap_or_else(get_host_user); + let author = Self::normalize_optional_field(author) + .or_else(|| Self::normalize_optional_field(env::var(HISTORY_AUTHOR_ENV).ok())) + .unwrap_or_else(|| Self::author_from_hostname(hostname.as_str())); + let intent = Self::normalize_optional_field(intent) + .or_else(|| Self::normalize_optional_field(env::var(HISTORY_INTENT_ENV).ok())); + + Self { + id: uuid_v7().as_simple().to_string().into(), + timestamp, + command, + cwd, + exit, + duration, + session, + hostname, + author, + intent, + deleted_at, + } + } + + pub fn serialize(&self) -> Result<DecryptedData> { + // This is pretty much the same as what we used for the old history, with one difference - + // it uses integers for timestamps rather than a string format. + + use rmp::encode; + + let mut output = vec![]; + + // write the version + encode::write_u16(&mut output, HISTORY_RECORD_VERSION_V1)?; + let include_intent = self.intent.is_some(); + encode::write_array_len(&mut output, 10 + u32::from(include_intent))?; + + encode::write_str(&mut output, &self.id.0)?; + encode::write_u64(&mut output, self.timestamp.unix_timestamp_nanos() as u64)?; + encode::write_sint(&mut output, self.duration)?; + encode::write_sint(&mut output, self.exit)?; + encode::write_str(&mut output, &self.command)?; + encode::write_str(&mut output, &self.cwd)?; + encode::write_str(&mut output, &self.session)?; + encode::write_str(&mut output, &self.hostname)?; + + match self.deleted_at { + Some(d) => encode::write_u64(&mut output, d.unix_timestamp_nanos() as u64)?, + None => encode::write_nil(&mut output)?, + } + + encode::write_str(&mut output, self.author.as_str())?; + if let Some(intent) = &self.intent { + encode::write_str(&mut output, intent.as_str())?; + } + + Ok(DecryptedData(output)) + } + + fn read_optional_string(bytes: &[u8]) -> Result<(Option<String>, &[u8])> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + match decode::read_str_from_slice(bytes) { + Ok((value, bytes)) => Ok((Some(value.to_owned()), bytes)), + Err(DecodeStringError::TypeMismatch(Marker::Null)) => { + let mut cursor = Bytes::new(bytes); + decode::read_nil(&mut cursor).map_err(error_report)?; + + Ok((None, cursor.remaining_slice())) + } + Err(err) => Err(error_report(err)), + } + } + + fn deserialize_v0(bytes: &[u8]) -> Result<History> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(bytes); + + let version = decode::read_u16(&mut bytes).map_err(error_report)?; + + if version != HISTORY_RECORD_VERSION_V0 { + bail!("expected decoding v0 record, found v{version}"); + } + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + + if nfields != 9 { + bail!("cannot decrypt history from a different version of Atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + + let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { + Ok(unix) => (Some(unix), bytes.remaining_slice()), + // we accept null here + Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), + Err(err) => return Err(error_report(err)), + }; + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + author: Self::author_from_hostname(hostname), + intent: None, + deleted_at: deleted_at + .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) + .transpose()?, + }) + } + + fn deserialize_v1(bytes: &[u8]) -> Result<History> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(bytes); + + let version = decode::read_u16(&mut bytes).map_err(error_report)?; + + if version != HISTORY_RECORD_VERSION_V1 { + bail!("expected decoding v1 record, found v{version}"); + } + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + + if !(10..=11).contains(&nfields) { + bail!("cannot decrypt history from a different version of Atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + + let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { + Ok(unix) => (Some(unix), bytes.remaining_slice()), + // we accept null here + Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), + Err(err) => return Err(error_report(err)), + }; + let (author, bytes) = Self::read_optional_string(bytes)?; + let (intent, bytes) = if nfields > 10 { + Self::read_optional_string(bytes)? + } else { + (None, bytes) + }; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + author: author.unwrap_or_else(|| Self::author_from_hostname(hostname)), + intent, + deleted_at: deleted_at + .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) + .transpose()?, + }) + } + + pub fn deserialize(bytes: &[u8], version: &str) -> Result<History> { + match version { + HISTORY_VERSION_V0 => Self::deserialize_v0(bytes), + HISTORY_VERSION_V1 => Self::deserialize_v1(bytes), + + _ => bail!("unknown version {version:?}"), + } + } + + /// Builder for a history entry that is imported from shell history. + /// + /// The only two required fields are `timestamp` and `command`. + /// + /// ## Examples + /// ``` + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::import() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + /// + /// If shell history contains more information, it can be added to the builder: + /// ``` + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::import() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .exit(0) + /// .duration(100) + /// .build() + /// .into(); + /// ``` + /// + /// Unknown command or command without timestamp cannot be imported, which + /// is forced at compile time: + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because timestamp is missing + /// let history: History = History::import() + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + pub fn import() -> builder::HistoryImportedBuilder { + builder::HistoryImported::builder() + } + + /// Builder for a history entry that is captured via hook. + /// + /// This builder is used only at the `start` step of the hook, + /// so it doesn't have any fields which are known only after + /// the command is finished, such as `exit` or `duration`. + /// + /// ## Examples + /// ```rust + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::capture() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .build() + /// .into(); + /// ``` + /// + /// Command without any required info cannot be captured, which is forced at compile time: + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because `cwd` is missing + /// let history: History = History::capture() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + pub fn capture() -> builder::HistoryCapturedBuilder { + builder::HistoryCaptured::builder() + } + + /// Builder for a history entry that is captured via hook, and sent to the daemon. + /// + /// This builder is used only at the `start` step of the hook, + /// so it doesn't have any fields which are known only after + /// the command is finished, such as `exit` or `duration`. + /// + /// It does, however, include information that can usually be inferred. + /// + /// This is because the daemon we are sending a request to lacks the context of the command + /// + /// ## Examples + /// ```rust + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::daemon() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .session("018deb6e8287781f9973ef40e0fde76b") + /// .hostname("computer:ellie") + /// .build() + /// .into(); + /// ``` + /// + /// Command without any required info cannot be captured, which is forced at compile time: + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because `hostname` is missing + /// let history: History = History::daemon() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .session("018deb6e8287781f9973ef40e0fde76b") + /// .build() + /// .into(); + /// ``` + pub fn daemon() -> builder::HistoryDaemonCaptureBuilder { + builder::HistoryDaemonCapture::builder() + } + + /// Builder for a history entry that is imported from the database. + /// + /// All fields are required, as they are all present in the database. + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because `id` field is missing + /// let history: History = History::from_db() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la".to_string()) + /// .cwd("/home/user".to_string()) + /// .exit(0) + /// .duration(100) + /// .session("somesession".to_string()) + /// .hostname("localhost".to_string()) + /// .author("user".to_string()) + /// .intent(None) + /// .deleted_at(None) + /// .build() + /// .into(); + /// ``` + pub fn from_db() -> builder::HistoryFromDbBuilder { + builder::HistoryFromDb::builder() + } + + pub fn success(&self) -> bool { + self.exit == 0 || self.duration == -1 + } + + pub fn should_save(&self, settings: &Settings) -> bool { + !(self.command.starts_with(' ') + || self.command.is_empty() + || settings.history_filter.is_match(&self.command) + || settings.cwd_filter.is_match(&self.cwd) + || (settings.secrets_filter && SECRET_PATTERNS_RE.is_match(&self.command))) + } +} + +#[cfg(test)] +mod tests { + use regex::RegexSet; + use time::macros::datetime; + + use crate::{ + history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, HISTORY_VERSION}, + settings::Settings, + }; + + use super::{History, author_matches_filters, is_known_agent}; + + // Test that we don't save history where necessary + #[test] + fn privacy_test() { + let settings = Settings { + cwd_filter: RegexSet::new(["^/supasecret"]).unwrap(), + history_filter: RegexSet::new(["^psql"]).unwrap(), + ..Settings::utc() + }; + + let normal_command: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo foo") + .cwd("/") + .build() + .into(); + + let with_space: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command(" echo bar") + .cwd("/") + .build() + .into(); + + let empty: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("") + .cwd("/") + .build() + .into(); + + let stripe_key: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") + .cwd("/") + .build() + .into(); + + let secret_dir: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo ohno") + .cwd("/supasecret") + .build() + .into(); + + let with_psql: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("psql") + .cwd("/supasecret") + .build() + .into(); + + assert!(normal_command.should_save(&settings)); + assert!(!with_space.should_save(&settings)); + assert!(!empty.should_save(&settings)); + assert!(!stripe_key.should_save(&settings)); + assert!(!secret_dir.should_save(&settings)); + assert!(!with_psql.should_save(&settings)); + } + + #[test] + fn known_agents_include_pi() { + assert!(is_known_agent("pi")); + assert!(author_matches_filters( + "pi", + &[AUTHOR_FILTER_ALL_AGENT.to_string()] + )); + assert!(!author_matches_filters( + "pi", + &[AUTHOR_FILTER_ALL_USER.to_string()] + )); + } + + #[test] + fn disable_secrets() { + let settings = Settings { + secrets_filter: false, + ..Settings::utc() + }; + + let stripe_key: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") + .cwd("/") + .build() + .into(); + + assert!(stripe_key.should_save(&settings)); + } + + #[test] + fn test_serialize_deserialize() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let serialized = history.serialize().expect("failed to serialize history"); + assert_eq!( + &serialized.0[0..3], + [205, 0, 1], + "should encode as history v1" + ); + + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_deleted() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: Some(datetime!(2023-11-19 20:18 +00:00)), + }; + + let serialized = history.serialize().expect("failed to serialize history"); + + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_with_author_and_intent() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "claude".to_owned(), + intent: Some("check repository status".to_owned()), + deleted_at: None, + }; + + let serialized = history.serialize().expect("failed to serialize history"); + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_version() { + // v0 + let bytes_v0 = [ + 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, + 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, + 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, + 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, + 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, + 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, + 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, + 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, + 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192, + ]; + + let deserialized = History::deserialize(&bytes_v0, "v0"); + assert!(deserialized.is_ok()); + + let deserialized = History::deserialize(&bytes_v0, HISTORY_VERSION); + assert!(deserialized.is_err()); + + let current = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let bytes_v1 = current.serialize().expect("failed to serialize history"); + let deserialized = History::deserialize(&bytes_v1.0, HISTORY_VERSION); + assert!(deserialized.is_ok()); + + let deserialized = History::deserialize(&bytes_v1.0, "v0"); + assert!(deserialized.is_err()); + } +} diff --git a/crates/turtle/src/atuin_client/history/builder.rs b/crates/turtle/src/atuin_client/history/builder.rs new file mode 100644 index 00000000..72a505fd --- /dev/null +++ b/crates/turtle/src/atuin_client/history/builder.rs @@ -0,0 +1,154 @@ +use typed_builder::TypedBuilder; + +use super::History; + +/// Builder for a history entry that is imported from shell history. +/// +/// The only two required fields are `timestamp` and `command`. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryImported { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(default = "unknown".into(), setter(into))] + cwd: String, + #[builder(default = -1)] + exit: i64, + #[builder(default = -1)] + duration: i64, + #[builder(default, setter(strip_option, into))] + session: Option<String>, + #[builder(default, setter(strip_option, into))] + hostname: Option<String>, + #[builder(default, setter(strip_option, into))] + author: Option<String>, + #[builder(default, setter(strip_option, into))] + intent: Option<String>, +} + +impl From<HistoryImported> for History { + fn from(imported: HistoryImported) -> Self { + History::new( + imported.timestamp, + imported.command, + imported.cwd, + imported.exit, + imported.duration, + imported.session, + imported.hostname, + imported.author, + imported.intent, + None, + ) + } +} + +/// Builder for a history entry that is captured via hook. +/// +/// This builder is used only at the `start` step of the hook, +/// so it doesn't have any fields which are known only after +/// the command is finished, such as `exit` or `duration`. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryCaptured { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(setter(into))] + cwd: String, + #[builder(default, setter(strip_option, into))] + author: Option<String>, + #[builder(default, setter(strip_option, into))] + intent: Option<String>, +} + +impl From<HistoryCaptured> for History { + fn from(captured: HistoryCaptured) -> Self { + History::new( + captured.timestamp, + captured.command, + captured.cwd, + -1, + -1, + None, + None, + captured.author, + captured.intent, + None, + ) + } +} + +/// Builder for a history entry that is loaded from the database. +/// +/// All fields are required, as they are all present in the database. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryFromDb { + id: String, + timestamp: time::OffsetDateTime, + command: String, + cwd: String, + exit: i64, + duration: i64, + session: String, + hostname: String, + author: String, + intent: Option<String>, + deleted_at: Option<time::OffsetDateTime>, +} + +impl From<HistoryFromDb> for History { + fn from(from_db: HistoryFromDb) -> Self { + History { + id: from_db.id.into(), + timestamp: from_db.timestamp, + exit: from_db.exit, + command: from_db.command, + cwd: from_db.cwd, + duration: from_db.duration, + session: from_db.session, + hostname: from_db.hostname, + author: from_db.author, + intent: from_db.intent, + deleted_at: from_db.deleted_at, + } + } +} + +/// Builder for a history entry that is captured via hook and sent to the daemon +/// +/// This builder is similar to Capture, but we just require more information up front. +/// For the old setup, we could just rely on History::new to read some of the missing +/// data. This is no longer the case. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryDaemonCapture { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(setter(into))] + cwd: String, + #[builder(setter(into))] + session: String, + #[builder(setter(into))] + hostname: String, + #[builder(default, setter(strip_option, into))] + author: Option<String>, + #[builder(default, setter(strip_option, into))] + intent: Option<String>, +} + +impl From<HistoryDaemonCapture> for History { + fn from(captured: HistoryDaemonCapture) -> Self { + History::new( + captured.timestamp, + captured.command, + captured.cwd, + -1, + -1, + Some(captured.session), + Some(captured.hostname), + captured.author, + captured.intent, + None, + ) + } +} diff --git a/crates/turtle/src/atuin_client/history/store.rs b/crates/turtle/src/atuin_client/history/store.rs new file mode 100644 index 00000000..66d9db47 --- /dev/null +++ b/crates/turtle/src/atuin_client/history/store.rs @@ -0,0 +1,435 @@ +use std::{collections::HashSet, fmt::Write, time::Duration}; + +use eyre::{Result, bail, eyre}; +use indicatif::{ProgressBar, ProgressState, ProgressStyle}; +use rmp::decode::Bytes; +use tracing::debug; + +use crate::atuin_client::{ + database::{Database, current_context}, + record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store}, +}; +use crate::atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx}; + +use super::{HISTORY_TAG, HISTORY_VERSION, HISTORY_VERSION_V0, History, HistoryId}; + +#[derive(Debug, Clone)] +pub struct HistoryStore { + pub store: SqliteStore, + pub host_id: HostId, + pub encryption_key: [u8; 32], +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum HistoryRecord { + Create(History), // Create a history record + Delete(HistoryId), // Delete a history record, identified by ID +} + +impl HistoryRecord { + /// Serialize a history record, returning DecryptedData + /// The record will be of a certain type + /// We map those like so: + /// + /// HistoryRecord::Create -> 0 + /// HistoryRecord::Delete-> 1 + /// + /// This numeric identifier is then written as the first byte to the buffer. For history, we + /// append the serialized history right afterwards, to avoid having to handle serialization + /// twice. + /// + /// Deletion simply refers to the history by ID + pub fn serialize(&self) -> Result<DecryptedData> { + // probably don't actually need to use rmp here, but if we ever need to extend it, it's a + // nice wrapper around raw byte stuff + use rmp::encode; + + let mut output = vec![]; + + match self { + HistoryRecord::Create(history) => { + // 0 -> a history create + encode::write_u8(&mut output, 0)?; + + let bytes = history.serialize()?; + + encode::write_bin(&mut output, &bytes.0)?; + } + HistoryRecord::Delete(id) => { + // 1 -> a history delete + encode::write_u8(&mut output, 1)?; + encode::write_str(&mut output, id.0.as_str())?; + } + }; + + Ok(DecryptedData(output)) + } + + pub fn deserialize(bytes: &DecryptedData, version: &str) -> Result<Self> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(&bytes.0); + + let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; + + match record_type { + // 0 -> HistoryRecord::Create + 0 => { + // not super useful to us atm, but perhaps in the future + // written by write_bin above + let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; + + let record = History::deserialize(bytes.remaining_slice(), version)?; + + Ok(HistoryRecord::Create(record)) + } + + // 1 -> HistoryRecord::Delete + 1 => { + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!( + "trailing bytes decoding HistoryRecord::Delete - malformed? got {bytes:?}" + ); + } + + Ok(HistoryRecord::Delete(id.to_string().into())) + } + + n => { + bail!("unknown HistoryRecord type {n}") + } + } + } +} + +impl HistoryStore { + pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self { + HistoryStore { + store, + host_id, + encryption_key, + } + } + + async fn push_record(&self, record: HistoryRecord) -> Result<(RecordId, RecordIdx)> { + let bytes = record.serialize()?; + let idx = self + .store + .last(self.host_id, HISTORY_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(HISTORY_VERSION.to_string()) + .tag(HISTORY_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + let id = record.id; + + self.store + .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) + .await?; + + Ok((id, idx)) + } + + async fn push_batch(&self, records: impl Iterator<Item = HistoryRecord>) -> Result<()> { + let mut ret = Vec::new(); + + let idx = self + .store + .last(self.host_id, HISTORY_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + // Could probably _also_ do this as an iterator, but let's see how this is for now. + // optimizing for minimal sqlite transactions, this code can be optimised later + for (n, record) in records.enumerate() { + let bytes = record.serialize()?; + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(HISTORY_VERSION.to_string()) + .tag(HISTORY_TAG.to_string()) + .idx(idx + n as u64) + .data(bytes) + .build(); + + let record = record.encrypt::<PASETO_V4>(&self.encryption_key); + + ret.push(record); + } + + self.store.push_batch(ret.iter()).await?; + + Ok(()) + } + + pub async fn delete(&self, id: HistoryId) -> Result<(RecordId, RecordIdx)> { + let record = HistoryRecord::Delete(id); + + self.push_record(record).await + } + + /// Delete a batch of history entries via the record store. + /// Returns the record IDs so the caller can run incremental_build when ready. + pub async fn delete_entries( + &self, + entries: impl IntoIterator<Item = History>, + ) -> Result<Vec<RecordId>> { + let mut record_ids = Vec::new(); + for entry in entries { + let (id, _) = self.delete(entry.id).await?; + record_ids.push(id); + } + Ok(record_ids) + } + + pub async fn push(&self, history: History) -> Result<(RecordId, RecordIdx)> { + // TODO(ellie): move the history store to its own file + // it's tiny rn so fine as is + let record = HistoryRecord::Create(history); + + self.push_record(record).await + } + + pub async fn history(&self) -> Result<Vec<HistoryRecord>> { + // Atm this loads all history into memory + // Not ideal as that is potentially quite a lot, although history will be small. + let records = self.store.all_tagged(HISTORY_TAG).await?; + let mut ret = Vec::with_capacity(records.len()); + + for record in records.into_iter() { + let hist = match record.version.as_str() { + HISTORY_VERSION_V0 | HISTORY_VERSION => { + let version = record.version.clone(); + let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?; + + HistoryRecord::deserialize(&decrypted.data, version.as_str()) + } + version => bail!("unknown history version {version:?}"), + }?; + + ret.push(hist); + } + + Ok(ret) + } + + pub async fn build(&self, database: &dyn Database) -> Result<()> { + // I'd like to change how we rebuild and not couple this with the database, but need to + // consider the structure more deeply. This will be easy to change. + + // TODO(ellie): page or iterate this + let history = self.history().await?; + + // In theory we could flatten this here + // The current issue is that the database may have history in it already, from the old sync + // This didn't actually delete old history + // If we're sure we have a DB only maintained by the new store, we can flatten + // create/delete before we even get to sqlite + let mut creates = Vec::new(); + let mut deletes = Vec::new(); + + for i in history { + match i { + HistoryRecord::Create(h) => { + creates.push(h); + } + HistoryRecord::Delete(id) => { + deletes.push(id); + } + } + } + + database.save_bulk(&creates).await?; + database.delete_rows(&deletes).await?; + + Ok(()) + } + + pub async fn incremental_build(&self, database: &dyn Database, ids: &[RecordId]) -> Result<()> { + for id in ids { + let record = self.store.get(*id).await; + + let record = match record { + Ok(record) => record, + _ => { + continue; + } + }; + + if record.tag != HISTORY_TAG { + continue; + } + + let version = record.version.clone(); + let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?; + let record = match version.as_str() { + HISTORY_VERSION_V0 | HISTORY_VERSION => { + HistoryRecord::deserialize(&decrypted.data, version.as_str())? + } + version => bail!("unknown history version {version:?}"), + }; + + match record { + HistoryRecord::Create(h) => { + // TODO: benchmark CPU time/memory tradeoff of batch commit vs one at a time + database.save(&h).await?; + } + HistoryRecord::Delete(id) => { + database.delete_rows(&[id]).await?; + } + } + } + + Ok(()) + } + + /// Get a list of history IDs that exist in the store + /// Note: This currently involves loading all history into memory. This is not going to be a + /// large amount in absolute terms, but do not all it in a hot loop. + pub async fn history_ids(&self) -> Result<HashSet<HistoryId>> { + let history = self.history().await?; + + let ret = HashSet::from_iter(history.iter().map(|h| match h { + HistoryRecord::Create(h) => h.id.clone(), + HistoryRecord::Delete(id) => id.clone(), + })); + + Ok(ret) + } + + pub async fn init_store(&self, db: &impl Database) -> Result<()> { + let pb = ProgressBar::new_spinner(); + pb.set_style( + ProgressStyle::with_template("{spinner:.blue} {msg}") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| { + write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() + }) + .progress_chars("#>-"), + ); + pb.enable_steady_tick(Duration::from_millis(500)); + + pb.set_message("Fetching history from old database"); + + let context = current_context().await?; + let history = db.list(&[], &context, None, false, true).await?; + + pb.set_message("Fetching history already in store"); + let store_ids = self.history_ids().await?; + + pb.set_message("Converting old history to new store"); + let mut records = Vec::new(); + + for i in history { + debug!("loaded {}", i.id); + + if store_ids.contains(&i.id) { + debug!("skipping {} - already exists", i.id); + continue; + } + + if i.deleted_at.is_some() { + records.push(HistoryRecord::Delete(i.id)); + } else { + records.push(HistoryRecord::Create(i)); + } + } + + pb.set_message("Writing to db"); + + if !records.is_empty() { + self.push_batch(records.into_iter()).await?; + } + + pb.finish_with_message("Import complete"); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::record::DecryptedData; + use time::macros::datetime; + + use crate::atuin_client::history::{HISTORY_VERSION, store::HistoryRecord}; + + use super::History; + + #[test] + fn test_serialize_deserialize_create() { + let bytes = [ + 204, 0, 196, 147, 205, 0, 1, 154, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, + 55, 53, 55, 99, 100, 50, 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, + 56, 49, 207, 23, 166, 251, 212, 181, 82, 0, 0, 100, 0, 162, 108, 115, 217, 41, 47, 85, + 115, 101, 114, 115, 47, 101, 108, 108, 105, 101, 47, 115, 114, 99, 47, 103, 105, 116, + 104, 117, 98, 46, 99, 111, 109, 47, 97, 116, 117, 105, 110, 115, 104, 47, 97, 116, 117, + 105, 110, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 97, 100, 56, 57, 55, 53, 57, 55, + 56, 53, 50, 53, 50, 55, 97, 51, 49, 99, 57, 57, 56, 48, 53, 57, 170, 98, 111, 111, 112, + 58, 101, 108, 108, 105, 101, 192, 165, 101, 108, 108, 105, 101, + ]; + + let history = History { + id: "018cd4fe81757cd2aee65cd7861f9c81".to_owned().into(), + timestamp: datetime!(2024-01-04 00:00:00.000000 +00:00), + duration: 100, + exit: 0, + command: "ls".to_owned(), + cwd: "/Users/ellie/src/github.com/atuinsh/atuin".to_owned(), + session: "018cd4fead897597852527a31c998059".to_owned(), + hostname: "boop:ellie".to_owned(), + author: "ellie".to_owned(), + intent: None, + deleted_at: None, + }; + + let record = HistoryRecord::Create(history); + + let serialized = record.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + + // check the snapshot too + let deserialized = + HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + } + + #[test] + fn test_serialize_deserialize_delete() { + let bytes = [ + 204, 1, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, 55, 53, 55, 99, 100, 50, + 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, 56, 49, + ]; + let record = HistoryRecord::Delete("018cd4fe81757cd2aee65cd7861f9c81".to_string().into()); + + let serialized = record.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + + let deserialized = + HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + } +} diff --git a/crates/turtle/src/atuin_client/import/bash.rs b/crates/turtle/src/atuin_client/import/bash.rs new file mode 100644 index 00000000..d92fdfa0 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/bash.rs @@ -0,0 +1,221 @@ +use std::{path::PathBuf, str}; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use itertools::Itertools; +use time::{Duration, OffsetDateTime}; +use tracing::warn; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Bash { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".bash_history")) +} + +#[async_trait] +impl Importer for Bash { + const NAME: &'static str = "bash"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + let count = unix_byte_lines(&self.bytes) + .map(LineType::from) + .filter(|line| matches!(line, LineType::Command(_))) + .count(); + Ok(count) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let lines = unix_byte_lines(&self.bytes) + .map(LineType::from) + .filter(|line| !matches!(line, LineType::NotUtf8)) // invalid utf8 are ignored + .collect_vec(); + + let (commands_before_first_timestamp, first_timestamp) = lines + .iter() + .enumerate() + .find_map(|(i, line)| match line { + LineType::Timestamp(t) => Some((i, *t)), + _ => None, + }) + // if no known timestamps, use now as base + .unwrap_or((lines.len(), OffsetDateTime::now_utc())); + + // if no timestamp is recorded, then use this increment to set an arbitrary timestamp + // to preserve ordering + // this increment is deliberately very small to prevent particularly fast fingers + // causing ordering issues; it also helps in handling the "here document" syntax, + // where several lines are recorded in succession without individual timestamps + let timestamp_increment = Duration::milliseconds(1); + + // make sure there is a minimum amount of time before the first known timestamp + // to fit all commands, given the default increment + let mut next_timestamp = + first_timestamp - timestamp_increment * commands_before_first_timestamp as i32; + + for line in lines.into_iter() { + match line { + LineType::NotUtf8 => unreachable!(), // already filtered + LineType::Empty => {} // do nothing + LineType::Timestamp(t) => { + if t < next_timestamp { + warn!( + "Time reversal detected in Bash history! Commands may be ordered incorrectly." + ); + } + next_timestamp = t; + } + LineType::Command(c) => { + let imported = History::import().timestamp(next_timestamp).command(c); + + h.push(imported.build().into()).await?; + next_timestamp += timestamp_increment; + } + } + } + + Ok(()) + } +} + +#[derive(Debug, Clone)] +enum LineType<'a> { + NotUtf8, + /// Can happen when using the "here document" syntax. + Empty, + /// A timestamp line start with a '#', followed immediately by an integer + /// that represents seconds since UNIX epoch. + Timestamp(OffsetDateTime), + /// Anything else. + Command(&'a str), +} +impl<'a> From<&'a [u8]> for LineType<'a> { + fn from(bytes: &'a [u8]) -> Self { + let Ok(line) = str::from_utf8(bytes) else { + return LineType::NotUtf8; + }; + if line.is_empty() { + return LineType::Empty; + } + + match try_parse_line_as_timestamp(line) { + Some(time) => LineType::Timestamp(time), + None => LineType::Command(line), + } + } +} + +fn try_parse_line_as_timestamp(line: &str) -> Option<OffsetDateTime> { + let seconds = line.strip_prefix('#')?.parse().ok()?; + OffsetDateTime::from_unix_timestamp(seconds).ok() +} + +#[cfg(test)] +mod test { + use std::cmp::Ordering; + + use itertools::{Itertools, assert_equal}; + + use crate::atuin_client::import::{Importer, tests::TestLoader}; + + use super::Bash; + + #[tokio::test] + async fn parse_no_timestamps() { + let bytes = r"cargo install atuin +cargo update +cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ +" + .as_bytes() + .to_owned(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + [ + "cargo install atuin", + "cargo update", + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", + ], + ); + assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) + } + + #[tokio::test] + async fn parse_with_timestamps() { + let bytes = b"#1672918999 +git reset +#1672919006 +git clean -dxf +#1672919020 +cd ../ +" + .to_vec(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["git reset", "git clean -dxf", "cd ../"], + ); + assert_equal( + loader.buf.iter().map(|h| h.timestamp.unix_timestamp()), + [1_672_918_999, 1_672_919_006, 1_672_919_020], + ) + } + + #[tokio::test] + async fn parse_with_partial_timestamps() { + let bytes = b"git reset +#1672919006 +git clean -dxf +cd ../ +" + .to_vec(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["git reset", "git clean -dxf", "cd ../"], + ); + assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) + } + + fn is_strictly_sorted<T>(iter: impl IntoIterator<Item = T>) -> bool + where + T: Clone + PartialOrd, + { + iter.into_iter() + .tuple_windows() + .all(|(a, b)| matches!(a.partial_cmp(&b), Some(Ordering::Less))) + } +} diff --git a/crates/turtle/src/atuin_client/import/fish.rs b/crates/turtle/src/atuin_client/import/fish.rs new file mode 100644 index 00000000..1375bdd6 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/fish.rs @@ -0,0 +1,179 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use time::OffsetDateTime; + +use super::{Importer, Loader, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Fish { + bytes: Vec<u8>, +} + +/// see https://fishshell.com/docs/current/interactive.html#searchable-command-history +fn default_histpath() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let data = std::env::var("XDG_DATA_HOME").map_or_else( + |_| base.home_dir().join(".local").join("share"), + PathBuf::from, + ); + + // fish supports multiple history sessions + // If `fish_history` var is missing, or set to `default`, use `fish` as the session + let session = std::env::var("fish_history").unwrap_or_else(|_| String::from("fish")); + let session = if session == "default" { + String::from("fish") + } else { + session + }; + + let mut histpath = data.join("fish"); + histpath.push(format!("{session}_history")); + + if histpath.exists() { + Ok(histpath) + } else { + Err(eyre!("Could not find history file.")) + } +} + +#[async_trait] +impl Importer for Fish { + const NAME: &'static str = "fish"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(default_histpath()?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + let mut time: Option<OffsetDateTime> = None; + let mut cmd: Option<String> = None; + + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + + if let Some(c) = s.strip_prefix("- cmd: ") { + // first, we must deal with the prev cmd + if let Some(cmd) = cmd.take() { + let time = time.unwrap_or(now); + let entry = History::import().timestamp(time).command(cmd); + + loader.push(entry.build().into()).await?; + } + + // using raw strings to avoid needing escaping. + // replaces double backslashes with single backslashes + let c = c.replace(r"\\", r"\"); + // replaces escaped newlines + let c = c.replace(r"\n", "\n"); + // TODO: any other escape characters? + + cmd = Some(c); + } else if let Some(t) = s.strip_prefix(" when: ") { + // if t is not an int, just ignore this line + if let Ok(t) = t.parse::<i64>() { + time = Some(OffsetDateTime::from_unix_timestamp(t)?); + } + } else { + // ... ignore paths lines + } + } + + // we might have a trailing cmd + if let Some(cmd) = cmd.take() { + let time = time.unwrap_or(now); + let entry = History::import().timestamp(time).command(cmd); + + loader.push(entry.build().into()).await?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use crate::import::{Importer, tests::TestLoader}; + + use super::Fish; + + #[tokio::test] + async fn parse_complex() { + // complicated input with varying contents and escaped strings. + let bytes = r#"- cmd: history --help + when: 1639162832 +- cmd: cat ~/.bash_history + when: 1639162851 + paths: + - ~/.bash_history +- cmd: ls ~/.local/share/fish/fish_history + when: 1639162890 + paths: + - ~/.local/share/fish/fish_history +- cmd: cat ~/.local/share/fish/fish_history + when: 1639162893 + paths: + - ~/.local/share/fish/fish_history +ERROR +- CORRUPTED: ENTRY + CONTINUE: + - AS + - NORMAL +- cmd: echo "foo" \\\n'bar' baz + when: 1639162933 +- cmd: cat ~/.local/share/fish/fish_history + when: 1639162939 + paths: + - ~/.local/share/fish/fish_history +- cmd: echo "\\"" \\\\ "\\\\" + when: 1639163063 +- cmd: cat ~/.local/share/fish/fish_history + when: 1639163066 + paths: + - ~/.local/share/fish/fish_history +"# + .as_bytes() + .to_owned(); + + let fish = Fish { bytes }; + + let mut loader = TestLoader::default(); + fish.load(&mut loader).await.unwrap(); + let mut history = loader.buf.into_iter(); + + // simple wrapper for fish history entry + macro_rules! fishtory { + ($timestamp:expr_2021, $command:expr_2021) => { + let h = history.next().expect("missing entry in history"); + assert_eq!(h.command.as_str(), $command); + assert_eq!(h.timestamp.unix_timestamp(), $timestamp); + }; + } + + fishtory!(1639162832, "history --help"); + fishtory!(1639162851, "cat ~/.bash_history"); + fishtory!(1639162890, "ls ~/.local/share/fish/fish_history"); + fishtory!(1639162893, "cat ~/.local/share/fish/fish_history"); + fishtory!(1639162933, "echo \"foo\" \\\n'bar' baz"); + fishtory!(1639162939, "cat ~/.local/share/fish/fish_history"); + fishtory!(1639163063, r#"echo "\"" \\ "\\""#); + fishtory!(1639163066, "cat ~/.local/share/fish/fish_history"); + } +} diff --git a/crates/turtle/src/atuin_client/import/mod.rs b/crates/turtle/src/atuin_client/import/mod.rs new file mode 100644 index 00000000..7726ead7 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/mod.rs @@ -0,0 +1,140 @@ +use std::fs::File; +use std::io::Read; +use std::path::PathBuf; + +use async_trait::async_trait; +use eyre::{Result, bail}; +use memchr::Memchr; + +use crate::atuin_client::history::History; + +pub mod bash; +pub mod fish; +pub mod nu; +pub mod nu_histdb; +pub mod powershell; +pub mod replxx; +pub mod resh; +pub mod xonsh; +pub mod xonsh_sqlite; +pub mod zsh; +pub mod zsh_histdb; + +#[async_trait] +pub trait Importer: Sized { + const NAME: &'static str; + async fn new() -> Result<Self>; + async fn entries(&mut self) -> Result<usize>; + async fn load(self, loader: &mut impl Loader) -> Result<()>; +} + +#[async_trait] +pub trait Loader: Sync + Send { + async fn push(&mut self, hist: History) -> eyre::Result<()>; +} + +fn unix_byte_lines(input: &[u8]) -> impl Iterator<Item = &[u8]> { + UnixByteLines { + iter: memchr::memchr_iter(b'\n', input), + bytes: input, + i: 0, + } +} + +struct UnixByteLines<'a> { + iter: Memchr<'a>, + bytes: &'a [u8], + i: usize, +} + +impl<'a> Iterator for UnixByteLines<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option<Self::Item> { + let j = self.iter.next()?; + let out = &self.bytes[self.i..j]; + self.i = j + 1; + Some(out) + } + + fn count(self) -> usize + where + Self: Sized, + { + self.iter.count() + } +} + +fn count_lines(input: &[u8]) -> usize { + unix_byte_lines(input).count() +} + +fn get_histpath<D>(def: D) -> Result<PathBuf> +where + D: FnOnce() -> Result<PathBuf>, +{ + if let Ok(p) = std::env::var("HISTFILE") { + Ok(PathBuf::from(p)) + } else { + def() + } +} + +fn get_histfile_path<D>(def: D) -> Result<PathBuf> +where + D: FnOnce() -> Result<PathBuf>, +{ + get_histpath(def).and_then(is_file) +} + +fn get_histdir_path<D>(def: D) -> Result<PathBuf> +where + D: FnOnce() -> Result<PathBuf>, +{ + get_histpath(def).and_then(is_dir) +} + +fn read_to_end(path: PathBuf) -> Result<Vec<u8>> { + let mut bytes = Vec::new(); + let mut f = File::open(path)?; + f.read_to_end(&mut bytes)?; + Ok(bytes) +} +fn is_file(p: PathBuf) -> Result<PathBuf> { + if p.is_file() { + Ok(p) + } else { + bail!( + "Could not find history file {:?}. Try setting and exporting $HISTFILE", + p + ) + } +} +fn is_dir(p: PathBuf) -> Result<PathBuf> { + if p.is_dir() { + Ok(p) + } else { + bail!( + "Could not find history directory {:?}. Try setting and exporting $HISTFILE", + p + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Default)] + pub struct TestLoader { + pub buf: Vec<History>, + } + + #[async_trait] + impl Loader for TestLoader { + async fn push(&mut self, hist: History) -> Result<()> { + self.buf.push(hist); + Ok(()) + } + } +} diff --git a/crates/turtle/src/atuin_client/import/nu.rs b/crates/turtle/src/atuin_client/import/nu.rs new file mode 100644 index 00000000..c93789b8 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/nu.rs @@ -0,0 +1,67 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use time::OffsetDateTime; + +use super::{Importer, Loader, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Nu { + bytes: Vec<u8>, +} + +fn get_histpath() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let config_dir = base.config_dir().join("nushell"); + + let histpath = config_dir.join("history.txt"); + if histpath.exists() { + Ok(histpath) + } else { + Err(eyre!("Could not find history file.")) + } +} + +#[async_trait] +impl Importer for Nu { + const NAME: &'static str = "nu"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histpath()?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + + let mut counter = 0; + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + + let cmd: String = s.replace("<\\n>", "\n"); + + let offset = time::Duration::nanoseconds(counter); + counter += 1; + + let entry = History::import().timestamp(now - offset).command(cmd); + + h.push(entry.build().into()).await?; + } + + Ok(()) + } +} diff --git a/crates/turtle/src/atuin_client/import/nu_histdb.rs b/crates/turtle/src/atuin_client/import/nu_histdb.rs new file mode 100644 index 00000000..7de18369 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/nu_histdb.rs @@ -0,0 +1,113 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use sqlx::{Pool, sqlite::SqlitePool}; +use time::{Duration, OffsetDateTime}; + +use super::Importer; +use crate::atuin_client::history::History; +use crate::atuin_client::import::Loader; + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntry { + pub id: i64, + pub command_line: Vec<u8>, + pub start_timestamp: i64, + pub session_id: i64, + pub hostname: Vec<u8>, + pub cwd: Vec<u8>, + pub duration_ms: i64, + pub exit_status: i64, + pub more_info: Vec<u8>, +} + +impl From<HistDbEntry> for History { + fn from(histdb_item: HistDbEntry) -> Self { + let ts_secs = histdb_item.start_timestamp / 1000; + let ts_ns = (histdb_item.start_timestamp % 1000) * 1_000_000; + let imported = History::import() + .timestamp( + OffsetDateTime::from_unix_timestamp(ts_secs).unwrap() + + Duration::nanoseconds(ts_ns), + ) + .command(String::from_utf8(histdb_item.command_line).unwrap()) + .cwd(String::from_utf8(histdb_item.cwd).unwrap()) + .exit(histdb_item.exit_status) + .duration(histdb_item.duration_ms) + .session(format!("{:x}", histdb_item.session_id)) + .hostname(String::from_utf8(histdb_item.hostname).unwrap()); + + imported.build().into() + } +} + +#[derive(Debug)] +pub struct NuHistDb { + histdb: Vec<HistDbEntry>, +} + +/// Read db at given file, return vector of entries. +async fn hist_from_db(dbpath: PathBuf) -> Result<Vec<HistDbEntry>> { + let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; + hist_from_db_conn(pool).await +} + +async fn hist_from_db_conn(pool: Pool<sqlx::Sqlite>) -> Result<Vec<HistDbEntry>> { + let query = r#" + SELECT + id, command_line, start_timestamp, session_id, hostname, cwd, duration_ms, exit_status, + more_info + FROM history + ORDER BY start_timestamp + "#; + let histdb_vec: Vec<HistDbEntry> = sqlx::query_as::<_, HistDbEntry>(query) + .fetch_all(&pool) + .await?; + Ok(histdb_vec) +} + +impl NuHistDb { + pub fn histpath() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let config_dir = base.config_dir().join("nushell"); + + let histdb_path = config_dir.join("history.sqlite3"); + if histdb_path.exists() { + Ok(histdb_path) + } else { + Err(eyre!("Could not find history file.")) + } + } +} + +#[async_trait] +impl Importer for NuHistDb { + // Not sure how this is used + const NAME: &'static str = "nu_histdb"; + + /// Creates a new NuHistDb and populates the history based on the pre-populated data + /// structure. + async fn new() -> Result<Self> { + let dbpath = NuHistDb::histpath()?; + let histdb_entry_vec = hist_from_db(dbpath).await?; + Ok(Self { + histdb: histdb_entry_vec, + }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(self.histdb.len()) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + for i in self.histdb { + h.push(i.into()).await?; + } + Ok(()) + } +} diff --git a/crates/turtle/src/atuin_client/import/powershell.rs b/crates/turtle/src/atuin_client/import/powershell.rs new file mode 100644 index 00000000..8adcc850 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/powershell.rs @@ -0,0 +1,202 @@ +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use std::path::PathBuf; +use time::{Duration, OffsetDateTime}; + +use super::{Importer, Loader, count_lines, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct PowerShell { + bytes: Vec<u8>, + line_count: Option<usize>, +} + +fn get_history_path() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + + // The command line history in PowerShell is maintained by the PSReadLine module: + // https://learn.microsoft.com/en-us/powershell/module/psreadline/about/about_psreadline#command-history + // + // > PSReadLine maintains a history file containing all the commands and data you've entered from the command line. + // > The history files are a file named `$($Host.Name)_history.txt`. + // > On Windows systems the history file is stored at `$Env:APPDATA\Microsoft\Windows\PowerShell\PSReadLine`. + // > On non-Windows systems, the history files are stored at `$Env:XDG_DATA_HOME/powershell/PSReadLine` + // > or `$Env:HOME/.local/share/powershell/PSReadLine`. + + let dir = if cfg!(windows) { + base.data_dir() + .join("Microsoft") + .join("Windows") + .join("PowerShell") + .join("PSReadLine") + } else { + std::env::var("XDG_DATA_HOME") + .map_or_else( + |_| base.home_dir().join(".local").join("share"), + PathBuf::from, + ) + .join("powershell") + .join("PSReadLine") + }; + + // The history is stored in a file named `$($Host.Name)_history.txt`. + // For the default console host shipped by Microsoft,`$Host.Name` is `ConsoleHost`: + // https://learn.microsoft.com/en-us/dotnet/api/system.management.automation.host.pshost.name#remarks + + let file = dir.join("ConsoleHost_history.txt"); + + if file.is_file() { + Ok(file) + } else { + Err(eyre!("Could not find history file: {}", file.display())) + } +} + +#[async_trait] +impl Importer for PowerShell { + const NAME: &'static str = "PowerShell"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_history_path()?)?; + Ok(Self { + bytes, + line_count: None, + }) + } + + async fn entries(&mut self) -> Result<usize> { + // Commands can be split over multiple lines, + // but this is only used for a progress bar, and multi-line commands + // should be quite rare, so this is not an issue in practice. + if self.line_count.is_none() { + self.line_count = Some(count_lines(&self.bytes)); + } + Ok(self.line_count.unwrap()) + } + + async fn load(mut self, h: &mut impl Loader) -> Result<()> { + let line_count = self.entries().await?; + let start = OffsetDateTime::now_utc() - Duration::milliseconds(line_count as i64); + + let mut counter = 0; + let mut iter = unix_byte_lines(&self.bytes); + + while let Some(s) = iter.next() { + let Ok(s) = read_line(s) else { + continue; // We can skip past things like invalid utf8 + }; + + let mut cmd = s.to_string(); + + // Multi-line commands end with a backtick, append the following lines. + while cmd.ends_with('`') { + cmd.pop(); + + let Some(next) = iter.next() else { + break; + }; + let Ok(next) = read_line(next) else { + break; + }; + + cmd.push('\n'); + cmd.push_str(next); + } + + if cmd.is_empty() { + continue; + } + + let offset = Duration::milliseconds(counter); + counter += 1; + + let entry = History::import().timestamp(start + offset).command(cmd); + h.push(entry.build().into()).await?; + } + + Ok(()) + } +} + +fn read_line(s: &[u8]) -> Result<&str> { + let s = str::from_utf8(s)?; + + // History is stored in CRLF on Windows, normalize the input to LF on all platforms. + let s = s.strip_suffix('\r').unwrap_or(s); + + Ok(s) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::import::tests::TestLoader; + use itertools::assert_equal; + + const INPUT: &str = r#"cargo install atuin +cargo update +echo "first line` +second line` +` +last line" +echo foo + +echo bar +echo baz +"#; + + const EXPECTED: &[&str] = &[ + "cargo install atuin", + "cargo update", + "echo \"first line\nsecond line\n\nlast line\"", + "echo foo", + "echo bar", + "echo baz", + ]; + + #[tokio::test] + async fn test_import() { + let loader = import(INPUT).await; + + let actual = loader.buf.iter().map(|h| h.command.clone()); + let expected = EXPECTED.iter().map(|s| s.to_string()); + + assert_equal(actual, expected); + } + + #[tokio::test] + async fn test_crlf() { + let input = INPUT.replace("\n", "\r\n"); + let loader = import(input.as_str()).await; + + let actual = loader.buf.iter().map(|h| h.command.clone()); + let expected = EXPECTED.iter().map(|s| s.to_string()); + + assert_equal(actual, expected); + } + + #[tokio::test] + async fn test_timestamps() { + let loader = import(INPUT).await; + + let mut prev = loader.buf.first().unwrap().timestamp; + for current in loader.buf.iter().skip(1).map(|h| h.timestamp) { + assert!(current > prev); + prev = current; + } + } + + async fn import(input: &str) -> TestLoader { + let powershell = PowerShell { + bytes: input.as_bytes().to_vec(), + line_count: None, + }; + + let mut loader = TestLoader::default(); + powershell.load(&mut loader).await.unwrap(); + loader + } +} diff --git a/crates/turtle/src/atuin_client/import/replxx.rs b/crates/turtle/src/atuin_client/import/replxx.rs new file mode 100644 index 00000000..42f84df5 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/replxx.rs @@ -0,0 +1,137 @@ +use std::{path::PathBuf, str}; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use time::{OffsetDateTime, PrimitiveDateTime, macros::format_description}; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Replxx { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + // There is no default histfile for replxx. + // Here we try a couple of common names. + let mut candidates = ["replxx_history.txt", ".histfile"].iter(); + loop { + match candidates.next() { + Some(candidate) => { + let histpath = home_dir.join(candidate); + if histpath.exists() { + break Ok(histpath); + } + } + None => { + break Err(eyre!( + "Could not find history file. Try setting and exporting $HISTFILE" + )); + } + } + } +} + +#[async_trait] +impl Importer for Replxx { + const NAME: &'static str = "replxx"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes) / 2) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let mut timestamp = OffsetDateTime::UNIX_EPOCH; + + for b in unix_byte_lines(&self.bytes) { + let s = std::str::from_utf8(b)?; + match try_parse_line_as_timestamp(s) { + Some(t) => timestamp = t, + None => { + // replxx uses ETB character (0x17) as line breaker + let cmd = s.replace('\u{0017}', "\n"); + let imported = History::import().timestamp(timestamp).command(cmd); + + h.push(imported.build().into()).await?; + } + } + } + + Ok(()) + } +} + +fn try_parse_line_as_timestamp(line: &str) -> Option<OffsetDateTime> { + // replxx history date time format: ### yyyy-mm-dd hh:mm:ss.xxx + let date_time_str = line.strip_prefix("### ")?; + let format = + format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]"); + + let primitive_date_time = PrimitiveDateTime::parse(date_time_str, format).ok()?; + // There is no safe way to get local time offset. + // For simplicity let's just assume UTC. + Some(primitive_date_time.assume_utc()) +} + +#[cfg(test)] +mod test { + + use crate::import::{Importer, tests::TestLoader}; + + use super::Replxx; + + #[tokio::test] + async fn parse_complex() { + let bytes = r#"### 2024-02-10 22:16:28.302 +select * from remote('127.0.0.1:20222', view(select 1)) +### 2024-02-10 22:16:36.919 +select * from numbers(10) +### 2024-02-10 22:16:41.710 +select * from system.numbers +### 2024-02-10 22:19:28.655 +select 1 +### 2024-02-22 11:15:33.046 +CREATE TABLE test( stamp DateTime('UTC'))ENGINE = MergeTreePARTITION BY toDate(stamp)order by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000); +"# + .as_bytes() + .to_owned(); + + let replxx = Replxx { bytes }; + + let mut loader = TestLoader::default(); + replxx.load(&mut loader).await.unwrap(); + let mut history = loader.buf.into_iter(); + + // simple wrapper for replxx history entry + macro_rules! history { + ($timestamp:expr_2021, $command:expr_2021) => { + let h = history.next().expect("missing entry in history"); + assert_eq!(h.command.as_str(), $command); + assert_eq!(h.timestamp.unix_timestamp(), $timestamp); + }; + } + + history!( + 1707603388, + "select * from remote('127.0.0.1:20222', view(select 1))" + ); + history!(1707603396, "select * from numbers(10)"); + history!(1707603401, "select * from system.numbers"); + history!(1707603568, "select 1"); + history!( + 1708600533, + "CREATE TABLE test\n( stamp DateTime('UTC'))\nENGINE = MergeTree\nPARTITION BY toDate(stamp)\norder by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000);" + ); + } +} diff --git a/crates/turtle/src/atuin_client/import/resh.rs b/crates/turtle/src/atuin_client/import/resh.rs new file mode 100644 index 00000000..c5980c44 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/resh.rs @@ -0,0 +1,140 @@ +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use serde::Deserialize; + +use crate::atuin_common::utils::uuid_v7; +use time::OffsetDateTime; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct ReshEntry { + pub cmd_line: String, + pub exit_code: i64, + pub shell: String, + pub uname: String, + pub session_id: String, + pub home: String, + pub lang: String, + pub lc_all: String, + pub login: String, + pub pwd: String, + pub pwd_after: String, + pub shell_env: String, + pub term: String, + pub real_pwd: String, + pub real_pwd_after: String, + pub pid: i64, + pub session_pid: i64, + pub host: String, + pub hosttype: String, + pub ostype: String, + pub machtype: String, + pub shlvl: i64, + pub timezone_before: String, + pub timezone_after: String, + pub realtime_before: f64, + pub realtime_after: f64, + pub realtime_before_local: f64, + pub realtime_after_local: f64, + pub realtime_duration: f64, + pub realtime_since_session_start: f64, + pub realtime_since_boot: f64, + pub git_dir: String, + pub git_real_dir: String, + pub git_origin_remote: String, + pub git_dir_after: String, + pub git_real_dir_after: String, + pub git_origin_remote_after: String, + pub machine_id: String, + pub os_release_id: String, + pub os_release_version_id: String, + pub os_release_id_like: String, + pub os_release_name: String, + pub os_release_pretty_name: String, + pub resh_uuid: String, + pub resh_version: String, + pub resh_revision: String, + pub parts_merged: bool, + pub recalled: bool, + pub recall_last_cmd_line: String, + pub cols: String, + pub lines: String, +} + +#[derive(Debug)] +pub struct Resh { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".resh_history.json")) +} + +#[async_trait] +impl Importer for Resh { + const NAME: &'static str = "resh"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let entry = match serde_json::from_str::<ReshEntry>(s) { + Ok(e) => e, + Err(_) => continue, // skip invalid json :shrug: + }; + + #[expect(clippy::cast_possible_truncation)] + #[expect(clippy::cast_sign_loss)] + let timestamp = { + let secs = entry.realtime_before.floor() as i64; + let nanosecs = (entry.realtime_before.fract() * 1_000_000_000_f64).round() as i64; + OffsetDateTime::from_unix_timestamp(secs)? + time::Duration::nanoseconds(nanosecs) + }; + #[expect(clippy::cast_possible_truncation)] + #[expect(clippy::cast_sign_loss)] + let duration = { + let secs = entry.realtime_after.floor() as i64; + let nanosecs = (entry.realtime_after.fract() * 1_000_000_000_f64).round() as i64; + let base = OffsetDateTime::from_unix_timestamp(secs)? + + time::Duration::nanoseconds(nanosecs); + let difference = base - timestamp; + difference.whole_nanoseconds() as i64 + }; + + let imported = History::import() + .command(entry.cmd_line) + .timestamp(timestamp) + .duration(duration) + .exit(entry.exit_code) + .cwd(entry.pwd) + .hostname(entry.host) + // CHECK: should we add uuid here? It's not set in the other importers + .session(uuid_v7().as_simple().to_string()); + + h.push(imported.build().into()).await?; + } + + Ok(()) + } +} diff --git a/crates/turtle/src/atuin_client/import/xonsh.rs b/crates/turtle/src/atuin_client/import/xonsh.rs new file mode 100644 index 00000000..a7217826 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/xonsh.rs @@ -0,0 +1,234 @@ +use std::env; +use std::fs::{self, File}; +use std::path::{Path, PathBuf}; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use serde::Deserialize; +use time::OffsetDateTime; +use uuid::Uuid; +use uuid::timestamp::{Timestamp, context::NoContext}; + +use super::{Importer, Loader, get_histdir_path}; +use crate::atuin_client::history::History; +use crate::atuin_client::utils::get_host_user; + +// Note: both HistoryFile and HistoryData have other keys present in the JSON, we don't +// care about them so we leave them unspecified so as to avoid deserializing unnecessarily. +#[derive(Debug, Deserialize)] +struct HistoryFile { + data: HistoryData, +} + +#[derive(Debug, Deserialize)] +struct HistoryData { + sessionid: String, + cmds: Vec<HistoryCmd>, +} + +#[derive(Debug, Deserialize)] +struct HistoryCmd { + cwd: String, + inp: String, + rtn: Option<i64>, + ts: (f64, f64), +} + +#[derive(Debug)] +pub struct Xonsh { + // history is stored as a bunch of json files, one per session + sessions: Vec<HistoryData>, + hostname: String, +} + +fn xonsh_hist_dir(xonsh_data_dir: Option<String>) -> Result<PathBuf> { + // if running within xonsh, this will be available + if let Some(d) = xonsh_data_dir { + let mut path = PathBuf::from(d); + path.push("history_json"); + return Ok(path); + } + + // otherwise, fall back to default + let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; + + let hist_dir = base.data_dir().join("xonsh/history_json"); + if hist_dir.exists() || cfg!(test) { + Ok(hist_dir) + } else { + Err(eyre!("Could not find xonsh history files")) + } +} + +fn load_sessions(hist_dir: &Path) -> Result<Vec<HistoryData>> { + let mut sessions = vec![]; + for entry in fs::read_dir(hist_dir)? { + let p = entry?.path(); + let ext = p.extension().and_then(|e| e.to_str()); + if p.is_file() + && ext == Some("json") + && let Some(data) = load_session(&p)? + { + sessions.push(data); + } + } + Ok(sessions) +} + +fn load_session(path: &Path) -> Result<Option<HistoryData>> { + let file = File::open(path)?; + // empty files are not valid json, so we can't deserialize them + if file.metadata()?.len() == 0 { + return Ok(None); + } + + let mut hist_file: HistoryFile = serde_json::from_reader(file)?; + + // if there are commands in this session, replace the existing UUIDv4 + // with a UUIDv7 generated from the timestamp of the first command + if let Some(cmd) = hist_file.data.cmds.first() { + let seconds = cmd.ts.0.trunc() as u64; + let nanos = (cmd.ts.0.fract() * 1_000_000_000_f64) as u32; + let ts = Timestamp::from_unix(NoContext, seconds, nanos); + hist_file.data.sessionid = Uuid::new_v7(ts).to_string(); + } + Ok(Some(hist_file.data)) +} + +#[async_trait] +impl Importer for Xonsh { + const NAME: &'static str = "xonsh"; + + async fn new() -> Result<Self> { + // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH + let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); + let hist_dir = get_histdir_path(|| xonsh_hist_dir(xonsh_data_dir))?; + let sessions = load_sessions(&hist_dir)?; + let hostname = get_host_user(); + Ok(Xonsh { sessions, hostname }) + } + + async fn entries(&mut self) -> Result<usize> { + let total = self.sessions.iter().map(|s| s.cmds.len()).sum(); + Ok(total) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + for session in self.sessions { + for cmd in session.cmds { + let (start, end) = cmd.ts; + let ts_nanos = (start * 1_000_000_000_f64) as i128; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos)?; + + let duration = (end - start) * 1_000_000_000_f64; + + match cmd.rtn { + Some(exit) => { + let entry = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .exit(exit) + .command(cmd.inp.trim()) + .cwd(cmd.cwd) + .session(session.sessionid.clone()) + .hostname(self.hostname.clone()); + loader.push(entry.build().into()).await?; + } + None => { + let entry = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .command(cmd.inp.trim()) + .cwd(cmd.cwd) + .session(session.sessionid.clone()) + .hostname(self.hostname.clone()); + loader.push(entry.build().into()).await?; + } + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use super::*; + + use crate::history::History; + use crate::import::tests::TestLoader; + + #[test] + fn test_hist_dir_xonsh() { + let hist_dir = xonsh_hist_dir(Some("/home/user/xonsh_data".to_string())).unwrap(); + assert_eq!( + hist_dir, + PathBuf::from("/home/user/xonsh_data/history_json") + ); + } + + #[tokio::test] + async fn test_import() { + let dir = PathBuf::from("tests/data/xonsh"); + let sessions = load_sessions(&dir).unwrap(); + let hostname = "box:user".to_string(); + let xonsh = Xonsh { sessions, hostname }; + + let mut loader = TestLoader::default(); + xonsh.load(&mut loader).await.unwrap(); + // order in buf will depend on filenames, so sort by timestamp for consistency + loader.buf.sort_by_key(|h| h.timestamp); + for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { + assert_eq!(actual.timestamp, expected.timestamp); + assert_eq!(actual.command, expected.command); + assert_eq!(actual.cwd, expected.cwd); + assert_eq!(actual.exit, expected.exit); + assert_eq!(actual.duration, expected.duration); + assert_eq!(actual.hostname, expected.hostname); + } + } + + fn expected_hist_entries() -> [History; 4] { + [ + History::import() + .timestamp(datetime!(2024-02-6 04:17:59.478272256 +00:00:00)) + .command("echo hello world!".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(4651069) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 04:18:01.70632832 +00:00:00)) + .command("ls -l".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(21288633) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:41:31.142515968 +00:00:00)) + .command("false".to_string()) + .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) + .exit(1) + .duration(10269403) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:41:32.271584 +00:00:00)) + .command("exit".to_string()) + .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) + .exit(0) + .duration(4259347) + .hostname("box:user".to_string()) + .build() + .into(), + ] + } +} diff --git a/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs b/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs new file mode 100644 index 00000000..ceedf7e9 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs @@ -0,0 +1,217 @@ +use std::env; +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use futures::TryStreamExt; +use sqlx::{FromRow, Row, sqlite::SqlitePool}; +use time::OffsetDateTime; +use uuid::Uuid; +use uuid::timestamp::{Timestamp, context::NoContext}; + +use super::{Importer, Loader, get_histfile_path}; +use crate::atuin_client::history::History; +use crate::atuin_client::utils::get_host_user; + +#[derive(Debug, FromRow)] +struct HistDbEntry { + inp: String, + rtn: Option<i64>, + tsb: f64, + tse: f64, + cwd: String, + session_start: f64, +} + +impl HistDbEntry { + fn into_hist_with_hostname(self, hostname: String) -> History { + let ts_nanos = (self.tsb * 1_000_000_000_f64) as i128; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos).unwrap(); + + let session_ts_seconds = self.session_start.trunc() as u64; + let session_ts_nanos = (self.session_start.fract() * 1_000_000_000_f64) as u32; + let session_ts = Timestamp::from_unix(NoContext, session_ts_seconds, session_ts_nanos); + let session_id = Uuid::new_v7(session_ts).to_string(); + let duration = (self.tse - self.tsb) * 1_000_000_000_f64; + + if let Some(exit) = self.rtn { + let imported = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .exit(exit) + .command(self.inp) + .cwd(self.cwd) + .session(session_id) + .hostname(hostname); + imported.build().into() + } else { + let imported = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .command(self.inp) + .cwd(self.cwd) + .session(session_id) + .hostname(hostname); + imported.build().into() + } + } +} + +fn xonsh_db_path(xonsh_data_dir: Option<String>) -> Result<PathBuf> { + // if running within xonsh, this will be available + if let Some(d) = xonsh_data_dir { + let mut path = PathBuf::from(d); + path.push("xonsh-history.sqlite"); + return Ok(path); + } + + // otherwise, fall back to default + let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; + + let hist_file = base.data_dir().join("xonsh/xonsh-history.sqlite"); + if hist_file.exists() || cfg!(test) { + Ok(hist_file) + } else { + Err(eyre!( + "Could not find xonsh history db at: {}", + hist_file.to_string_lossy() + )) + } +} + +#[derive(Debug)] +pub struct XonshSqlite { + pool: SqlitePool, + hostname: String, +} + +#[async_trait] +impl Importer for XonshSqlite { + const NAME: &'static str = "xonsh_sqlite"; + + async fn new() -> Result<Self> { + // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH + let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); + let db_path = get_histfile_path(|| xonsh_db_path(xonsh_data_dir))?; + let connection_str = db_path.to_str().ok_or_else(|| { + eyre!( + "Invalid path for SQLite database: {}", + db_path.to_string_lossy() + ) + })?; + + let pool = SqlitePool::connect(connection_str).await?; + let hostname = get_host_user(); + Ok(XonshSqlite { pool, hostname }) + } + + async fn entries(&mut self) -> Result<usize> { + let query = "SELECT COUNT(*) FROM xonsh_history"; + let row = sqlx::query(query).fetch_one(&self.pool).await?; + let count: u32 = row.get(0); + Ok(count as usize) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + let query = r#" + SELECT inp, rtn, tsb, tse, cwd, + MIN(tsb) OVER (PARTITION BY sessionid) AS session_start + FROM xonsh_history + ORDER BY rowid + "#; + + let mut entries = sqlx::query_as::<_, HistDbEntry>(query).fetch(&self.pool); + + let mut count = 0; + while let Some(entry) = entries.try_next().await? { + let hist = entry.into_hist_with_hostname(self.hostname.clone()); + loader.push(hist).await?; + count += 1; + } + + println!("Loaded: {count}"); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use super::*; + + use crate::history::History; + use crate::import::tests::TestLoader; + + #[test] + fn test_db_path_xonsh() { + let db_path = xonsh_db_path(Some("/home/user/xonsh_data".to_string())).unwrap(); + assert_eq!( + db_path, + PathBuf::from("/home/user/xonsh_data/xonsh-history.sqlite") + ); + } + + #[tokio::test] + async fn test_import() { + let connection_str = "tests/data/xonsh-history.sqlite"; + let xonsh_sqlite = XonshSqlite { + pool: SqlitePool::connect(connection_str).await.unwrap(), + hostname: "box:user".to_string(), + }; + + let mut loader = TestLoader::default(); + xonsh_sqlite.load(&mut loader).await.unwrap(); + + for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { + assert_eq!(actual.timestamp, expected.timestamp); + assert_eq!(actual.command, expected.command); + assert_eq!(actual.cwd, expected.cwd); + assert_eq!(actual.exit, expected.exit); + assert_eq!(actual.duration, expected.duration); + assert_eq!(actual.hostname, expected.hostname); + } + } + + fn expected_hist_entries() -> [History; 4] { + [ + History::import() + .timestamp(datetime!(2024-02-6 17:56:21.130956288 +00:00:00)) + .command("echo hello world!".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(2628564) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:28.190406144 +00:00:00)) + .command("ls -l".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(9371519) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:46.989020928 +00:00:00)) + .command("false".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(1) + .duration(17337560) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:48.218384128 +00:00:00)) + .command("exit".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(4599094) + .hostname("box:user".to_string()) + .build() + .into(), + ] + } +} diff --git a/crates/turtle/src/atuin_client/import/zsh.rs b/crates/turtle/src/atuin_client/import/zsh.rs new file mode 100644 index 00000000..e1fd813a --- /dev/null +++ b/crates/turtle/src/atuin_client/import/zsh.rs @@ -0,0 +1,230 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::borrow::Cow; +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use time::OffsetDateTime; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Zsh { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + // oh-my-zsh sets HISTFILE=~/.zhistory + // zsh has no default value for this var, but uses ~/.zhistory. + // zsh-newuser-install propose as default .histfile https://github.com/zsh-users/zsh/blob/master/Functions/Newuser/zsh-newuser-install#L794 + // we could maybe be smarter about this in the future :) + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + let mut candidates = [".zhistory", ".zsh_history", ".histfile"].iter(); + loop { + match candidates.next() { + Some(candidate) => { + let histpath = home_dir.join(candidate); + if histpath.exists() { + break Ok(histpath); + } + } + None => { + break Err(eyre!( + "Could not find history file. Try setting and exporting $HISTFILE" + )); + } + } + } +} + +#[async_trait] +impl Importer for Zsh { + const NAME: &'static str = "zsh"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + let mut line = String::new(); + + let mut counter = 0; + for b in unix_byte_lines(&self.bytes) { + let s = match unmetafy(b) { + Some(s) => s, + _ => continue, // we can skip past things like invalid utf8 + }; + + if let Some(s) = s.strip_suffix('\\') { + line.push_str(s); + line.push('\n'); + } else { + line.push_str(&s); + let command = std::mem::take(&mut line); + + if let Some(command) = command.strip_prefix(": ") { + counter += 1; + h.push(parse_extended(command, counter)).await?; + } else { + let offset = time::Duration::seconds(counter); + counter += 1; + + let imported = History::import() + // preserve ordering + .timestamp(now - offset) + .command(command.trim_end().to_string()); + + h.push(imported.build().into()).await?; + } + } + } + + Ok(()) + } +} + +fn parse_extended(line: &str, counter: i64) -> History { + let (time, duration) = line.split_once(':').unwrap(); + let (duration, command) = duration.split_once(';').unwrap(); + + let time = time + .parse::<i64>() + .ok() + .and_then(|t| OffsetDateTime::from_unix_timestamp(t).ok()) + .unwrap_or_else(OffsetDateTime::now_utc) + + time::Duration::milliseconds(counter); + + // use nanos, because why the hell not? we won't display them. + let duration = duration.parse::<i64>().map_or(-1, |t| t * 1_000_000_000); + + let imported = History::import() + .timestamp(time) + .command(command.trim_end().to_string()) + .duration(duration); + + imported.build().into() +} + +fn unmetafy(line: &[u8]) -> Option<Cow<'_, str>> { + if line.contains(&0x83) { + let mut s = Vec::with_capacity(line.len()); + let mut is_meta = false; + for ch in line { + if *ch == 0x83 { + is_meta = true; + } else if is_meta { + is_meta = false; + s.push(*ch ^ 32); + } else { + s.push(*ch) + } + } + String::from_utf8(s).ok().map(Cow::Owned) + } else { + std::str::from_utf8(line).ok().map(Cow::Borrowed) + } +} + +#[cfg(test)] +mod test { + use itertools::assert_equal; + + use crate::import::tests::TestLoader; + + use super::*; + + #[test] + fn test_parse_extended_simple() { + let parsed = parse_extended("1613322469:0;cargo install atuin", 0); + + assert_eq!(parsed.command, "cargo install atuin"); + assert_eq!(parsed.duration, 0); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo install atuin;cargo update", 0); + + assert_eq!(parsed.command, "cargo install atuin;cargo update"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); + + assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo install \\n atuin\n", 0); + + assert_eq!(parsed.command, "cargo install \\n atuin"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + } + + #[tokio::test] + async fn test_parse_file() { + let bytes = r": 1613322469:0;cargo install atuin +: 1613322469:10;cargo install atuin; \\ +cargo update +: 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ +" + .as_bytes() + .to_owned(); + + let mut zsh = Zsh { bytes }; + assert_eq!(zsh.entries().await.unwrap(), 4); + + let mut loader = TestLoader::default(); + zsh.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + [ + "cargo install atuin", + "cargo install atuin; \\\ncargo update", + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", + ], + ); + } + + #[tokio::test] + async fn test_parse_metafied() { + let bytes = + b"echo \xe4\xbd\x83\x80\xe5\xa5\xbd\nls ~/\xe9\x83\xbf\xb3\xe4\xb9\x83\xb0\n".to_vec(); + + let mut zsh = Zsh { bytes }; + assert_eq!(zsh.entries().await.unwrap(), 2); + + let mut loader = TestLoader::default(); + zsh.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["echo 你好", "ls ~/音乐"], + ); + } +} diff --git a/crates/turtle/src/atuin_client/import/zsh_histdb.rs b/crates/turtle/src/atuin_client/import/zsh_histdb.rs new file mode 100644 index 00000000..f61bb74f --- /dev/null +++ b/crates/turtle/src/atuin_client/import/zsh_histdb.rs @@ -0,0 +1,249 @@ +// import old shell history from zsh-histdb! +// automatically hoover up all that we can find + +// As far as i can tell there are no version numbers in the histdb sqlite DB, so we're going based +// on the schema from 2022-05-01 +// +// I have run into some histories that will not import b/c of non UTF-8 characters. +// + +// +// An Example sqlite query for hsitdb data: +// +//id|session|command_id|place_id|exit_status|start_time|duration|id|argv|id|host|dir +// +// +// select +// history.id, +// history.start_time, +// places.host, +// places.dir, +// commands.argv +// from history +// left join commands on history.command_id = commands.id +// left join places on history.place_id = places.id ; +// +// CREATE TABLE history (id integer primary key autoincrement, +// session int, +// command_id int references commands (id), +// place_id int references places (id), +// exit_status int, +// start_time int, +// duration int); +// + +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + +use async_trait::async_trait; +use crate::atuin_common::utils::uuid_v7; +use directories::UserDirs; +use eyre::{Result, eyre}; +use sqlx::{Pool, sqlite::SqlitePool}; +use time::PrimitiveDateTime; + +use super::Importer; +use crate::atuin_client::history::History; +use crate::atuin_client::import::Loader; +use crate::atuin_client::utils::{get_hostname, get_username}; + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntryCount { + pub count: usize, +} + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntry { + pub id: i64, + pub start_time: PrimitiveDateTime, + pub host: Vec<u8>, + pub dir: Vec<u8>, + pub argv: Vec<u8>, + pub duration: i64, + pub exit_status: i64, + pub session: i64, +} + +#[derive(Debug)] +pub struct ZshHistDb { + histdb: Vec<HistDbEntry>, + username: String, +} + +/// Read db at given file, return vector of entries. +async fn hist_from_db(dbpath: PathBuf) -> Result<Vec<HistDbEntry>> { + let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; + hist_from_db_conn(pool).await +} + +async fn hist_from_db_conn(pool: Pool<sqlx::Sqlite>) -> Result<Vec<HistDbEntry>> { + let query = r#" + SELECT + history.id, history.start_time, history.duration, places.host, places.dir, + commands.argv, history.exit_status, history.session + FROM history + LEFT JOIN commands ON history.command_id = commands.id + LEFT JOIN places ON history.place_id = places.id + ORDER BY history.start_time + "#; + let histdb_vec: Vec<HistDbEntry> = sqlx::query_as::<_, HistDbEntry>(query) + .fetch_all(&pool) + .await?; + Ok(histdb_vec) +} + +impl ZshHistDb { + pub fn histpath_candidate() -> PathBuf { + // By default histdb database is `${HOME}/.histdb/zsh-history.db` + // This can be modified by ${HISTDB_FILE} + // + // if [[ -z ${HISTDB_FILE} ]]; then + // typeset -g HISTDB_FILE="${HOME}/.histdb/zsh-history.db" + let user_dirs = UserDirs::new().unwrap(); // should catch error here? + let home_dir = user_dirs.home_dir(); + std::env::var("HISTDB_FILE") + .as_ref() + .map(|x| Path::new(x).to_path_buf()) + .unwrap_or_else(|_err| home_dir.join(".histdb/zsh-history.db")) + } + pub fn histpath() -> Result<PathBuf> { + let histdb_path = ZshHistDb::histpath_candidate(); + if histdb_path.exists() { + Ok(histdb_path) + } else { + Err(eyre!( + "Could not find history file. Try setting $HISTDB_FILE" + )) + } + } +} + +#[async_trait] +impl Importer for ZshHistDb { + // Not sure how this is used + const NAME: &'static str = "zsh_histdb"; + + /// Creates a new ZshHistDb and populates the history based on the pre-populated data + /// structure. + async fn new() -> Result<Self> { + let dbpath = ZshHistDb::histpath()?; + let histdb_entry_vec = hist_from_db(dbpath).await?; + Ok(Self { + histdb: histdb_entry_vec, + username: get_username(), + }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(self.histdb.len()) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let mut session_map = HashMap::new(); + for entry in self.histdb { + let command = match std::str::from_utf8(&entry.argv) { + Ok(s) => s.trim_end(), + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let cwd = match std::str::from_utf8(&entry.dir) { + Ok(s) => s.trim_end(), + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let hostname = format!( + "{}:{}", + String::from_utf8(entry.host).unwrap_or_else(|_e| get_hostname()), + self.username + ); + let session = session_map.entry(entry.session).or_insert_with(uuid_v7); + + let imported = History::import() + .timestamp(entry.start_time.assume_utc()) + .command(command) + .cwd(cwd) + .duration(entry.duration * 1_000_000_000) + .exit(entry.exit_status) + .session(session.as_simple().to_string()) + .hostname(hostname) + .build(); + h.push(imported.into()).await?; + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use super::*; + use sqlx::sqlite::SqlitePoolOptions; + use std::env; + #[tokio::test(flavor = "multi_thread")] + #[expect(unsafe_code)] + async fn test_env_vars() { + let test_env_db = "nonstd-zsh-history.db"; + let key = "HISTDB_FILE"; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var(key, test_env_db) }; + + // test the env got set + assert_eq!(env::var(key).unwrap(), test_env_db.to_string()); + + // test histdb returns the proper db from previous step + let histdb_path = ZshHistDb::histpath_candidate(); + assert_eq!(histdb_path.to_str().unwrap(), test_env_db); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_import() { + let pool: SqlitePool = SqlitePoolOptions::new() + .min_connections(2) + .connect(":memory:") + .await + .unwrap(); + + // sql dump directly from a test database. + let db_sql = r#" + PRAGMA foreign_keys=OFF; + BEGIN TRANSACTION; + CREATE TABLE commands (id integer primary key autoincrement, argv text, unique(argv) on conflict ignore); + INSERT INTO commands VALUES(1,'pwd'); + INSERT INTO commands VALUES(2,'curl google.com'); + INSERT INTO commands VALUES(3,'bash'); + CREATE TABLE places (id integer primary key autoincrement, host text, dir text, unique(host, dir) on conflict ignore); + INSERT INTO places VALUES(1,'mbp16.local','/home/noyez'); + CREATE TABLE history (id integer primary key autoincrement, + session int, + command_id int references commands (id), + place_id int references places (id), + exit_status int, + start_time int, + duration int); + INSERT INTO history VALUES(1,0,1,1,0,1651497918,1); + INSERT INTO history VALUES(2,0,2,1,0,1651497923,1); + INSERT INTO history VALUES(3,0,3,1,NULL,1651497930,NULL); + DELETE FROM sqlite_sequence; + INSERT INTO sqlite_sequence VALUES('commands',3); + INSERT INTO sqlite_sequence VALUES('places',3); + INSERT INTO sqlite_sequence VALUES('history',3); + CREATE INDEX hist_time on history(start_time); + CREATE INDEX place_dir on places(dir); + CREATE INDEX place_host on places(host); + CREATE INDEX history_command_place on history(command_id, place_id); + COMMIT; "#; + + sqlx::query(db_sql).execute(&pool).await.unwrap(); + + // test histdb iterator + let histdb_vec = hist_from_db_conn(pool).await.unwrap(); + let histdb = ZshHistDb { + histdb: histdb_vec, + username: get_username(), + }; + + println!("h: {:#?}", histdb.histdb); + println!("counter: {:?}", histdb.histdb.len()); + for i in histdb.histdb { + println!("{i:?}"); + } + } +} diff --git a/crates/turtle/src/atuin_client/login.rs b/crates/turtle/src/atuin_client/login.rs new file mode 100644 index 00000000..ca4e16fe --- /dev/null +++ b/crates/turtle/src/atuin_client/login.rs @@ -0,0 +1,68 @@ +use std::path::PathBuf; + +use crate::atuin_common::api::LoginRequest; +use eyre::{Context, Result, bail}; +use tokio::fs::File; +use tokio::io::AsyncWriteExt; + +use crate::atuin_client::{ + api_client, + encryption::{decode_key, load_key}, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +pub async fn login( + settings: &Settings, + store: &SqliteStore, + username: String, + password: String, + key: String, +) -> Result<String> { + let key_path = settings.key_path.as_str(); + let key_path = PathBuf::from(key_path); + + if !key_path.exists() { + if decode_key(key.clone()).is_err() { + bail!("the specified key was invalid"); + } + + let mut file = File::create(&key_path).await?; + file.write_all(key.as_bytes()).await?; + } else { + // we now know that the user has logged in specifying a key, AND that the key path + // exists + + // 1. check if the saved key and the provided key match. if so, nothing to do. + // 2. if not, re-encrypt the local history and overwrite the key + let current_key: [u8; 32] = load_key(settings)?.into(); + + let encoded = key.clone(); // gonna want to save it in a bit + let new_key: [u8; 32] = decode_key(key) + .context("could not decode provided key - is not valid base64")? + .into(); + + if new_key != current_key { + println!("\nRe-encrypting local store with new key"); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Writing new key"); + let mut file = File::create(&key_path).await?; + file.write_all(encoded.as_bytes()).await?; + } + } + + let session = api_client::login( + settings.sync_address.as_str(), + LoginRequest { username, password }, + ) + .await?; + + Settings::meta_store() + .await? + .save_session(&session.session) + .await?; + + Ok(session.session) +} diff --git a/crates/turtle/src/atuin_client/logout.rs b/crates/turtle/src/atuin_client/logout.rs new file mode 100644 index 00000000..343934b9 --- /dev/null +++ b/crates/turtle/src/atuin_client/logout.rs @@ -0,0 +1,16 @@ +use eyre::Result; + +use crate::atuin_client::settings::Settings; + +pub async fn logout() -> Result<()> { + let meta = Settings::meta_store().await?; + + if meta.logged_in().await? { + meta.delete_session().await?; + println!("You have logged out!"); + } else { + println!("You are not logged in"); + } + + Ok(()) +} diff --git a/crates/turtle/src/atuin_client/meta.rs b/crates/turtle/src/atuin_client/meta.rs new file mode 100644 index 00000000..1eea7061 --- /dev/null +++ b/crates/turtle/src/atuin_client/meta.rs @@ -0,0 +1,366 @@ +use std::path::Path; +use std::str::FromStr; +use std::time::Duration; + +use crate::atuin_common::record::HostId; +use eyre::{Result, eyre}; +use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions}; +use time::{OffsetDateTime, format_description::well_known::Rfc3339}; +use tokio::sync::OnceCell; +use tracing::{debug, warn}; +use uuid::Uuid; + +// Filenames for the legacy plain-text files that we migrate from. +const LEGACY_HOST_ID_FILENAME: &str = "host_id"; +const LEGACY_LAST_SYNC_FILENAME: &str = "last_sync_time"; +const LEGACY_LAST_VERSION_CHECK_FILENAME: &str = "last_version_check_time"; +const LEGACY_LATEST_VERSION_FILENAME: &str = "latest_version"; +const LEGACY_SESSION_FILENAME: &str = "session"; + +const KEY_HOST_ID: &str = "host_id"; +const KEY_LAST_SYNC: &str = "last_sync_time"; +const KEY_LAST_VERSION_CHECK: &str = "last_version_check_time"; +const KEY_LATEST_VERSION: &str = "latest_version"; +const KEY_SESSION: &str = "session"; +const KEY_FILES_MIGRATED: &str = "files_migrated"; + +pub struct MetaStore { + pool: SqlitePool, + cached_host_id: OnceCell<HostId>, +} + +impl MetaStore { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + let path_str = path + .as_os_str() + .to_str() + .ok_or_else(|| eyre!("meta database path is not valid UTF-8: {path:?}"))?; + debug!("opening meta sqlite database at {path:?}"); + + let is_memory = path_str.contains(":memory:"); + + if !is_memory + && !path.exists() + && let Some(dir) = path.parent() + { + fs_err::create_dir_all(dir)?; + } + + // Use DELETE journal mode instead of WAL. This is a small, infrequently- + // written KV store — WAL's concurrency benefits aren't needed, and DELETE + // mode avoids creating auxiliary -wal/-shm files that complicate + // permission handling. + let opts = SqliteConnectOptions::from_str(path_str)? + .journal_mode(SqliteJournalMode::Delete) + .optimize_on_close(true, None) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + sqlx::migrate!("./meta-migrations").run(&pool).await?; + + // Session tokens are stored in this database, so restrict permissions. + #[cfg(unix)] + if !is_memory { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?; + } + + let store = Self { + pool, + cached_host_id: OnceCell::const_new(), + }; + + if !is_memory { + store.migrate_files().await?; + } + + Ok(store) + } + + // Generic key-value operations + + pub async fn get(&self, key: &str) -> Result<Option<String>> { + let row: Option<(String,)> = sqlx::query_as("SELECT value FROM meta WHERE key = ?1") + .bind(key) + .fetch_optional(&self.pool) + .await?; + + Ok(row.map(|r| r.0)) + } + + pub async fn set(&self, key: &str, value: &str) -> Result<()> { + sqlx::query( + "INSERT INTO meta (key, value, updated_at) VALUES (?1, ?2, strftime('%s', 'now')) + ON CONFLICT(key) DO UPDATE SET value = ?2, updated_at = strftime('%s', 'now')", + ) + .bind(key) + .bind(value) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn delete(&self, key: &str) -> Result<()> { + sqlx::query("DELETE FROM meta WHERE key = ?1") + .bind(key) + .execute(&self.pool) + .await?; + + Ok(()) + } + + // Typed accessors + + pub async fn host_id(&self) -> Result<HostId> { + self.cached_host_id + .get_or_try_init(|| async { + if let Some(id) = self.get(KEY_HOST_ID).await? { + let parsed = Uuid::from_str(id.as_str()) + .map_err(|e| eyre!("failed to parse host ID: {e}"))?; + return Ok(HostId(parsed)); + } + + let uuid = crate::atuin_common::utils::uuid_v7(); + self.set(KEY_HOST_ID, uuid.as_simple().to_string().as_ref()) + .await?; + + Ok(HostId(uuid)) + }) + .await + .copied() + } + + pub async fn last_sync(&self) -> Result<OffsetDateTime> { + match self.get(KEY_LAST_SYNC).await? { + Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), + None => Ok(OffsetDateTime::UNIX_EPOCH), + } + } + + pub async fn save_sync_time(&self) -> Result<()> { + self.set( + KEY_LAST_SYNC, + OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), + ) + .await + } + + pub async fn last_version_check(&self) -> Result<OffsetDateTime> { + match self.get(KEY_LAST_VERSION_CHECK).await? { + Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), + None => Ok(OffsetDateTime::UNIX_EPOCH), + } + } + + pub async fn save_version_check_time(&self) -> Result<()> { + self.set( + KEY_LAST_VERSION_CHECK, + OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), + ) + .await + } + + pub async fn latest_version(&self) -> Result<Option<String>> { + self.get(KEY_LATEST_VERSION).await + } + + pub async fn save_latest_version(&self, version: &str) -> Result<()> { + self.set(KEY_LATEST_VERSION, version).await + } + + pub async fn session_token(&self) -> Result<Option<String>> { + self.get(KEY_SESSION).await + } + + pub async fn save_session(&self, token: &str) -> Result<()> { + self.set(KEY_SESSION, token).await + } + + pub async fn delete_session(&self) -> Result<()> { + self.delete(KEY_SESSION).await + } + + pub async fn logged_in(&self) -> Result<bool> { + Ok(self.session_token().await?.is_some()) + } + + // File migration: on first open, migrate old plain-text files into the database. + // Old files are left in place for safe downgrades. + + async fn migrate_files(&self) -> Result<()> { + if self.get(KEY_FILES_MIGRATED).await?.is_some() { + return Ok(()); + } + + let data_dir = crate::atuin_client::settings::Settings::effective_data_dir(); + + // host_id — validate as UUID + let host_id_path = data_dir.join(LEGACY_HOST_ID_FILENAME); + if host_id_path.exists() + && let Ok(value) = fs_err::read_to_string(&host_id_path) + { + let value = value.trim(); + if !value.is_empty() { + if Uuid::from_str(value).is_ok() { + self.set(KEY_HOST_ID, value).await?; + } else { + warn!("skipping migration of host_id: invalid UUID {value:?}"); + } + } + } + + // last_sync_time — validate as RFC3339 + let sync_path = data_dir.join(LEGACY_LAST_SYNC_FILENAME); + if sync_path.exists() + && let Ok(value) = fs_err::read_to_string(&sync_path) + { + let value = value.trim(); + if !value.is_empty() { + if OffsetDateTime::parse(value, &Rfc3339).is_ok() { + self.set(KEY_LAST_SYNC, value).await?; + } else { + warn!("skipping migration of last_sync_time: invalid RFC3339 {value:?}"); + } + } + } + + // last_version_check_time — validate as RFC3339 + let version_check_path = data_dir.join(LEGACY_LAST_VERSION_CHECK_FILENAME); + if version_check_path.exists() + && let Ok(value) = fs_err::read_to_string(&version_check_path) + { + let value = value.trim(); + if !value.is_empty() { + if OffsetDateTime::parse(value, &Rfc3339).is_ok() { + self.set(KEY_LAST_VERSION_CHECK, value).await?; + } else { + warn!( + "skipping migration of last_version_check_time: invalid RFC3339 {value:?}" + ); + } + } + } + + // latest_version — no strict validation, just non-empty + let latest_version_path = data_dir.join(LEGACY_LATEST_VERSION_FILENAME); + if latest_version_path.exists() + && let Ok(value) = fs_err::read_to_string(&latest_version_path) + { + let value = value.trim(); + if !value.is_empty() { + self.set(KEY_LATEST_VERSION, value).await?; + } + } + + // session token — no strict validation, just non-empty + let session_path = data_dir.join(LEGACY_SESSION_FILENAME); + if session_path.exists() + && let Ok(value) = fs_err::read_to_string(&session_path) + { + let value = value.trim(); + if !value.is_empty() { + self.set(KEY_SESSION, value).await?; + } + } + + self.set(KEY_FILES_MIGRATED, "true").await?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + async fn new_test_store() -> MetaStore { + MetaStore::new("sqlite::memory:", 2.0).await.unwrap() + } + + #[tokio::test] + async fn test_get_set_delete() { + let store = new_test_store().await; + + assert_eq!(store.get("foo").await.unwrap(), None); + + store.set("foo", "bar").await.unwrap(); + assert_eq!(store.get("foo").await.unwrap(), Some("bar".to_string())); + + store.set("foo", "baz").await.unwrap(); + assert_eq!(store.get("foo").await.unwrap(), Some("baz".to_string())); + + store.delete("foo").await.unwrap(); + assert_eq!(store.get("foo").await.unwrap(), None); + } + + #[tokio::test] + async fn test_host_id_generation_and_stability() { + let store = new_test_store().await; + + let id1 = store.host_id().await.unwrap(); + let id2 = store.host_id().await.unwrap(); + + assert_eq!(id1, id2, "host_id should be stable across calls"); + } + + #[tokio::test] + async fn test_sync_time() { + let store = new_test_store().await; + + let t = store.last_sync().await.unwrap(); + assert_eq!(t, OffsetDateTime::UNIX_EPOCH); + + store.save_sync_time().await.unwrap(); + let t = store.last_sync().await.unwrap(); + assert!(t > OffsetDateTime::UNIX_EPOCH); + } + + #[tokio::test] + async fn test_version_check_time() { + let store = new_test_store().await; + + let t = store.last_version_check().await.unwrap(); + assert_eq!(t, OffsetDateTime::UNIX_EPOCH); + + store.save_version_check_time().await.unwrap(); + let t = store.last_version_check().await.unwrap(); + assert!(t > OffsetDateTime::UNIX_EPOCH); + } + + #[tokio::test] + async fn test_session_crud() { + let store = new_test_store().await; + + assert!(!store.logged_in().await.unwrap()); + assert_eq!(store.session_token().await.unwrap(), None); + + store.save_session("tok123").await.unwrap(); + assert!(store.logged_in().await.unwrap()); + assert_eq!( + store.session_token().await.unwrap(), + Some("tok123".to_string()) + ); + + store.delete_session().await.unwrap(); + assert!(!store.logged_in().await.unwrap()); + } + + #[tokio::test] + async fn test_latest_version() { + let store = new_test_store().await; + + assert_eq!(store.latest_version().await.unwrap(), None); + + store.save_latest_version("1.2.3").await.unwrap(); + assert_eq!( + store.latest_version().await.unwrap(), + Some("1.2.3".to_string()) + ); + } +} diff --git a/crates/turtle/src/atuin_client/mod.rs b/crates/turtle/src/atuin_client/mod.rs new file mode 100644 index 00000000..7f07f2e2 --- /dev/null +++ b/crates/turtle/src/atuin_client/mod.rs @@ -0,0 +1,26 @@ +#[cfg(feature = "sync")] +pub mod api_client; +#[cfg(feature = "sync")] +pub mod auth; +#[cfg(feature = "sync")] +pub mod login; +#[cfg(feature = "sync")] +pub mod register; +#[cfg(feature = "sync")] +pub mod sync; + +pub mod database; +pub mod distro; +pub mod encryption; +pub mod history; +pub mod import; +pub mod logout; +pub mod meta; +pub mod ordering; +pub mod plugin; +pub mod record; +pub mod secrets; +pub mod settings; +pub mod theme; + +mod utils; diff --git a/crates/turtle/src/atuin_client/ordering.rs b/crates/turtle/src/atuin_client/ordering.rs new file mode 100644 index 00000000..4e5ec84c --- /dev/null +++ b/crates/turtle/src/atuin_client/ordering.rs @@ -0,0 +1,32 @@ +use minspan::minspan; + +use super::{history::History, settings::SearchMode}; + +pub fn reorder_fuzzy(mode: SearchMode, query: &str, res: Vec<History>) -> Vec<History> { + match mode { + SearchMode::Fuzzy => reorder(query, |x| &x.command, res), + _ => res, + } +} + +fn reorder<F, A>(query: &str, f: F, res: Vec<A>) -> Vec<A> +where + F: Fn(&A) -> &String, + A: Clone, +{ + let mut r = res.clone(); + let qvec = &query.chars().collect(); + r.sort_by_cached_key(|h| { + // TODO for fzf search we should sum up scores for each matched term + let (from, to) = match minspan::span(qvec, &(f(h).chars().collect())) { + Some(x) => x, + // this is a little unfortunate: when we are asked to match a query that is found nowhere, + // we don't want to return a None, as the comparison behaviour would put the worst matches + // at the front. therefore, we'll return a set of indices that are one larger than the longest + // possible legitimate match. This is meaningless except as a comparison. + None => (0, res.len()), + }; + 1 + to - from + }); + r +} diff --git a/crates/turtle/src/atuin_client/plugin.rs b/crates/turtle/src/atuin_client/plugin.rs new file mode 100644 index 00000000..6f351bf1 --- /dev/null +++ b/crates/turtle/src/atuin_client/plugin.rs @@ -0,0 +1,150 @@ +use std::collections::HashMap; + +#[derive(Debug, Clone)] +pub struct OfficialPlugin { + pub name: String, + pub description: String, + pub install_message: String, +} + +impl OfficialPlugin { + pub fn new(name: &str, description: &str, install_message: &str) -> Self { + Self { + name: name.to_string(), + description: description.to_string(), + install_message: install_message.to_string(), + } + } +} + +pub struct OfficialPluginRegistry { + plugins: HashMap<String, OfficialPlugin>, +} + +impl OfficialPluginRegistry { + pub fn new() -> Self { + let mut registry = Self { + plugins: HashMap::new(), + }; + + // Register official plugins + registry.register_official_plugins(); + + registry + } + + fn register_official_plugins(&mut self) { + // atuin-update plugin + self.plugins.insert( + "update".to_string(), + OfficialPlugin::new( + "update", + "Update atuin to the latest version", + "The 'atuin update' command is provided by the atuin-update plugin.\n\ + It is only installed if you used the install script\n \ + If you used a package manager (brew, apt, etc), please continue to use it for updates", + ), + ); + } + + pub fn get_plugin(&self, name: &str) -> Option<&OfficialPlugin> { + self.plugins.get(name) + } + + pub fn is_official_plugin(&self, name: &str) -> bool { + self.plugins.contains_key(name) + } + + pub fn get_install_message(&self, name: &str) -> Option<&str> { + self.plugins + .get(name) + .map(|plugin| plugin.install_message.as_str()) + } +} + +impl Default for OfficialPluginRegistry { + fn default() -> Self { + Self::new() + } +} + +pub struct PluginContext { + #[cfg(windows)] + _update_on_windows: Option<UpdateOnWindowsContext>, +} + +impl PluginContext { + pub fn new(_subcommand: &str) -> Self { + PluginContext { + #[cfg(windows)] + _update_on_windows: (_subcommand == "update").then(UpdateOnWindowsContext::new), + } + } +} + +impl Drop for PluginContext { + fn drop(&mut self) {} +} + +#[cfg(windows)] +struct UpdateOnWindowsContext { + initial_exe: Option<std::path::PathBuf>, +} + +#[cfg(windows)] +impl UpdateOnWindowsContext { + const OLD_FILE_NAME: &'static str = "atuin.old"; + + pub fn new() -> Self { + // Windows doesn't let you overwrite a running exe, but it lets you rename it, + // so make some room for atuin-update to install the new version. + let initial_exe = std::env::current_exe().ok().and_then(|exe| { + std::fs::rename(&exe, exe.with_file_name(Self::OLD_FILE_NAME)).ok()?; + Some(exe) + }); + + Self { initial_exe } + } +} + +#[cfg(windows)] +impl Drop for UpdateOnWindowsContext { + fn drop(&mut self) { + if let Some(exe) = &self.initial_exe + && !exe.exists() + { + // The update failed, roll back the current exe to its initial name. + std::fs::rename(exe.with_file_name(Self::OLD_FILE_NAME), exe).unwrap_or_else(|e| { + eprintln!("Failed to roll back the update, you may need to reinstall Atuin: {e}"); + }); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_registry_creation() { + let registry = OfficialPluginRegistry::new(); + assert!(registry.is_official_plugin("update")); + assert!(!registry.is_official_plugin("nonexistent")); + } + + #[test] + fn test_get_plugin() { + let registry = OfficialPluginRegistry::new(); + let plugin = registry.get_plugin("update"); + assert!(plugin.is_some()); + assert_eq!(plugin.unwrap().name, "update"); + } + + #[test] + fn test_get_install_message() { + let registry = OfficialPluginRegistry::new(); + let message = registry.get_install_message("update"); + assert!(message.is_some()); + assert!(message.unwrap().contains("atuin-update")); + } +} diff --git a/crates/turtle/src/atuin_client/record/encryption.rs b/crates/turtle/src/atuin_client/record/encryption.rs new file mode 100644 index 00000000..22dcdec3 --- /dev/null +++ b/crates/turtle/src/atuin_client/record/encryption.rs @@ -0,0 +1,373 @@ +use crate::atuin_common::record::{ + AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx, +}; +use base64::{Engine, engine::general_purpose}; +use eyre::{Context, Result, ensure}; +use rusty_paserk::{Key, KeyId, Local, PieWrappedKey}; +use rusty_paseto::core::{ + ImplicitAssertion, Key as DataKey, Local as LocalPurpose, Paseto, PasetoNonce, Payload, V4, +}; +use serde::{Deserialize, Serialize}; + +/// Use PASETO V4 Local encryption using the additional data as an implicit assertion. +#[expect(non_camel_case_types)] +pub struct PASETO_V4; + +/* +Why do we use a random content-encryption key? +Originally I was planning on using a derived key for encryption based on additional data. +This would be a lot more secure than using the master key directly. + +However, there's an established norm of using a random key. This scheme might be otherwise known as +- client-side encryption +- envelope encryption +- key wrapping + +A HSM (Hardware Security Module) provider, eg: AWS, Azure, GCP, or even a physical device like a YubiKey +will have some keys that they keep to themselves. These keys never leave their physical hardware. +If they never leave the hardware, then encrypting large amounts of data means giving them the data and waiting. +This is not a practical solution. Instead, generate a unique key for your data, encrypt that using your HSM +and then store that with your data. + +See + - <https://docs.aws.amazon.com/wellarchitected/latest/financial-services-industry-lens/use-envelope-encryption-with-customer-master-keys.html> + - <https://cloud.google.com/kms/docs/envelope-encryption> + - <https://learn.microsoft.com/en-us/azure/storage/blobs/client-side-encryption?tabs=dotnet#encryption-and-decryption-via-the-envelope-technique> + - <https://www.yubico.com/gb/product/yubihsm-2-fips/> + - <https://cheatsheetseries.owasp.org/cheatsheets/Cryptographic_Storage_Cheat_Sheet.html#encrypting-stored-keys> + +Why would we care? In the past we have received some requests for company solutions. If in future we can configure a +KMS service with little effort, then that would solve a lot of issues for their security team. + +Even for personal use, if a user is not comfortable with sharing keys between hosts, +GCP HSM costs $1/month and $0.03 per 10,000 key operations. Assuming an active user runs +1000 atuin records a day, that would only cost them $1 and 10 cent a month. + +Additionally, key rotations are much simpler using this scheme. Rotating a key is as simple as re-encrypting the CEK, and not the message contents. +This makes it very fast to rotate a key in bulk. + +For future reference, with asymmetric encryption, you can encrypt the CEK without the HSM's involvement, but decrypting +will need the HSM. This allows the encryption path to still be extremely fast (no network calls) but downloads/decryption +that happens in the background can make the network calls to the HSM +*/ + +impl Encryption for PASETO_V4 { + fn re_encrypt( + mut data: EncryptedData, + _ad: AdditionalData, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result<EncryptedData> { + let cek = Self::decrypt_cek(data.content_encryption_key, old_key)?; + data.content_encryption_key = Self::encrypt_cek(cek, new_key); + Ok(data) + } + + fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData { + // generate a random key for this entry + // aka content-encryption-key (CEK) + let random_key = Key::<V4, Local>::new_os_random(); + + // encode the implicit assertions + let assertions = Assertions::from(ad).encode(); + + // build the payload and encrypt the token + let payload = serde_json::to_string(&AtuinPayload { + data: general_purpose::URL_SAFE_NO_PAD.encode(data.0), + }) + .expect("json encoding can't fail"); + let nonce = DataKey::<32>::try_new_random().expect("could not source from random"); + let nonce = PasetoNonce::<V4, LocalPurpose>::from(&nonce); + + let token = Paseto::<V4, LocalPurpose>::builder() + .set_payload(Payload::from(payload.as_str())) + .set_implicit_assertion(ImplicitAssertion::from(assertions.as_str())) + .try_encrypt(&random_key.into(), &nonce) + .expect("error encrypting atuin data"); + + EncryptedData { + data: token, + content_encryption_key: Self::encrypt_cek(random_key, key), + } + } + + fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result<DecryptedData> { + let token = data.data; + let cek = Self::decrypt_cek(data.content_encryption_key, key)?; + + // encode the implicit assertions + let assertions = Assertions::from(ad).encode(); + + // decrypt the payload with the footer and implicit assertions + let payload = Paseto::<V4, LocalPurpose>::try_decrypt( + &token, + &cek.into(), + None, + ImplicitAssertion::from(&*assertions), + ) + .context("could not decrypt entry")?; + + let payload: AtuinPayload = serde_json::from_str(&payload)?; + let data = general_purpose::URL_SAFE_NO_PAD.decode(payload.data)?; + Ok(DecryptedData(data)) + } +} + +impl PASETO_V4 { + fn decrypt_cek(wrapped_cek: String, key: &[u8; 32]) -> Result<Key<V4, Local>> { + let wrapping_key = Key::<V4, Local>::from_bytes(*key); + + // let wrapping_key = PasetoSymmetricKey::from(Key::from(key)); + + let AtuinFooter { kid, wpk } = serde_json::from_str(&wrapped_cek) + .context("wrapped cek did not contain the correct contents")?; + + // check that the wrapping key matches the required key to decrypt. + // In future, we could support multiple keys and use this key to + // look up the key rather than only allow one key. + // For now though we will only support the one key and key rotation will + // have to be a hard reset + let current_kid = wrapping_key.to_id(); + + ensure!( + current_kid == kid, + "attempting to decrypt with incorrect key. currently using {current_kid}, expecting {kid}" + ); + + // decrypt the random key + Ok(wpk.unwrap_key(&wrapping_key)?) + } + + fn encrypt_cek(cek: Key<V4, Local>, key: &[u8; 32]) -> String { + // aka key-encryption-key (KEK) + let wrapping_key = Key::<V4, Local>::from_bytes(*key); + + // wrap the random key so we can decrypt it later + let wrapped_cek = AtuinFooter { + wpk: cek.wrap_pie(&wrapping_key), + kid: wrapping_key.to_id(), + }; + serde_json::to_string(&wrapped_cek).expect("could not serialize wrapped cek") + } +} + +#[derive(Serialize, Deserialize)] +struct AtuinPayload { + data: String, +} + +#[derive(Serialize, Deserialize)] +/// Well-known footer claims for decrypting. This is not encrypted but is stored in the record. +/// <https://github.com/paseto-standard/paseto-spec/blob/master/docs/02-Implementation-Guide/04-Claims.md#optional-footer-claims> +struct AtuinFooter { + /// Wrapped key + wpk: PieWrappedKey<V4, Local>, + /// ID of the key which was used to wrap + kid: KeyId<V4, Local>, +} + +/// Used in the implicit assertions. This is not encrypted and not stored in the data blob. +// This cannot be changed, otherwise it breaks the authenticated encryption. +#[derive(Debug, Copy, Clone, Serialize)] +struct Assertions<'a> { + id: &'a RecordId, + idx: &'a RecordIdx, + version: &'a str, + tag: &'a str, + host: &'a HostId, +} + +impl<'a> From<AdditionalData<'a>> for Assertions<'a> { + fn from(ad: AdditionalData<'a>) -> Self { + Self { + id: ad.id, + version: ad.version, + tag: ad.tag, + host: ad.host, + idx: ad.idx, + } + } +} + +impl Assertions<'_> { + fn encode(&self) -> String { + serde_json::to_string(self).expect("could not serialize implicit assertions") + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::{ + record::{Host, Record}, + utils::uuid_v7, + }; + + use super::*; + + #[test] + fn round_trip() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); + let decrypted = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap(); + assert_eq!(decrypted, data); + } + + #[test] + fn same_entry_different_output() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); + let encrypted2 = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + + assert_ne!( + encrypted.data, encrypted2.data, + "re-encrypting the same contents should have different output due to key randomization" + ); + } + + #[test] + fn cannot_decrypt_different_key() { + let key = Key::<V4, Local>::new_os_random(); + let fake_key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + let _ = PASETO_V4::decrypt(encrypted, ad, &fake_key.to_bytes()).unwrap_err(); + } + + #[test] + fn cannot_decrypt_different_id() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + ..ad + }; + let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err(); + } + + #[test] + fn re_encrypt_round_trip() { + let key1 = Key::<V4, Local>::new_os_random(); + let key2 = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted1 = PASETO_V4::encrypt(data.clone(), ad, &key1.to_bytes()); + let encrypted2 = + PASETO_V4::re_encrypt(encrypted1.clone(), ad, &key1.to_bytes(), &key2.to_bytes()) + .unwrap(); + + // we only re-encrypt the content keys + assert_eq!(encrypted1.data, encrypted2.data); + assert_ne!( + encrypted1.content_encryption_key, + encrypted2.content_encryption_key + ); + + let decrypted = PASETO_V4::decrypt(encrypted2, ad, &key2.to_bytes()).unwrap(); + + assert_eq!(decrypted, data); + } + + #[test] + fn full_record_round_trip() { + let key = [0x55; 32]; + let record = Record::builder() + .id(RecordId(uuid_v7())) + .version("v0".to_owned()) + .tag("kv".to_owned()) + .host(Host::new(HostId(uuid_v7()))) + .timestamp(1687244806000000) + .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) + .build(); + + let encrypted = record.encrypt::<PASETO_V4>(&key); + + assert!(!encrypted.data.data.is_empty()); + assert!(!encrypted.data.content_encryption_key.is_empty()); + + let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap(); + + assert_eq!(decrypted.data.0, [1, 2, 3, 4]); + } + + #[test] + fn full_record_round_trip_fail() { + let key = [0x55; 32]; + let record = Record::builder() + .id(RecordId(uuid_v7())) + .version("v0".to_owned()) + .tag("kv".to_owned()) + .host(Host::new(HostId(uuid_v7()))) + .timestamp(1687244806000000) + .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) + .build(); + + let encrypted = record.encrypt::<PASETO_V4>(&key); + + let mut enc1 = encrypted.clone(); + enc1.host = Host::new(HostId(uuid_v7())); + let _ = enc1 + .decrypt::<PASETO_V4>(&key) + .expect_err("tampering with the host should result in auth failure"); + + let mut enc2 = encrypted; + enc2.id = RecordId(uuid_v7()); + let _ = enc2 + .decrypt::<PASETO_V4>(&key) + .expect_err("tampering with the id should result in auth failure"); + } +} diff --git a/crates/turtle/src/atuin_client/record/mod.rs b/crates/turtle/src/atuin_client/record/mod.rs new file mode 100644 index 00000000..c40fd395 --- /dev/null +++ b/crates/turtle/src/atuin_client/record/mod.rs @@ -0,0 +1,6 @@ +pub mod encryption; +pub mod sqlite_store; +pub mod store; + +#[cfg(feature = "sync")] +pub mod sync; diff --git a/crates/turtle/src/atuin_client/record/sqlite_store.rs b/crates/turtle/src/atuin_client/record/sqlite_store.rs new file mode 100644 index 00000000..5fab999d --- /dev/null +++ b/crates/turtle/src/atuin_client/record/sqlite_store.rs @@ -0,0 +1,643 @@ +// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries. +// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index +// by tag/host + +use std::str::FromStr; +use std::{path::Path, time::Duration}; + +use async_trait::async_trait; +use eyre::{Result, eyre}; +use fs_err as fs; + +use sqlx::{ + Row, + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, +}; +use tracing::debug; + +use crate::atuin_common::record::{ + EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus, +}; +use crate::atuin_common::utils; +use uuid::Uuid; + +use super::encryption::PASETO_V4; +use super::store::Store; + +#[derive(Debug, Clone)] +pub struct SqliteStore { + pool: SqlitePool, +} + +impl SqliteStore { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + + debug!("opening sqlite database at {path:?}"); + + if utils::broken_symlink(path) { + eprintln!( + "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + ); + std::process::exit(1); + } + + if !path.exists() + && let Some(dir) = path.parent() + { + fs::create_dir_all(dir)?; + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .foreign_keys(true) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + + Ok(Self { pool }) + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./record-migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + r: &Record<EncryptedData>, + ) -> Result<()> { + // In sqlite, we are "limited" to i64. But that is still fine, until 2262. + sqlx::query( + "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + ) + .bind(r.id.0.as_hyphenated().to_string()) + .bind(r.idx as i64) + .bind(r.host.id.0.as_hyphenated().to_string()) + .bind(r.tag.as_str()) + .bind(r.timestamp as i64) + .bind(r.version.as_str()) + .bind(r.data.data.as_str()) + .bind(r.data.content_encryption_key.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_row(row: SqliteRow) -> Record<EncryptedData> { + let idx: i64 = row.get("idx"); + let timestamp: i64 = row.get("timestamp"); + + // tbh at this point things are pretty fucked so just panic + let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); + let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); + + Record { + id: RecordId(id), + idx: idx as u64, + host: Host::new(HostId(host)), + timestamp: timestamp as u64, + tag: row.get("tag"), + version: row.get("version"), + data: EncryptedData { + data: row.get("data"), + content_encryption_key: row.get("cek"), + }, + } + } + + async fn load_all(&self) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query("select * from store ") + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } +} + +#[async_trait] +impl Store for SqliteStore { + async fn push_batch( + &self, + records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, + ) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for record in records { + Self::save_raw(&mut tx, record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> { + let res = sqlx::query("select * from store where store.id = ?1") + .bind(id.0.as_hyphenated().to_string()) + .map(Self::query_row) + .fetch_one(&self.pool) + .await?; + + Ok(res) + } + + async fn delete(&self, id: RecordId) -> Result<()> { + sqlx::query("delete from store where id = ?1") + .bind(id.0.as_hyphenated().to_string()) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn delete_all(&self) -> Result<()> { + sqlx::query("delete from store").execute(&self.pool).await?; + + Ok(()) + } + + async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { + let res = + sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1") + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occurred: {}", e)), + Ok(record) => Ok(Some(record)), + } + } + + async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { + self.idx(host, tag, 0).await + } + + async fn len_all(&self) -> Result<u64> { + let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store") + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + + async fn len_tag(&self, tag: &str) -> Result<u64> { + let res: Result<(i64,), sqlx::Error> = + sqlx::query_as("select count(*) from store where tag=?1") + .bind(tag) + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + + async fn len(&self, host: HostId, tag: &str) -> Result<u64> { + let last = self.last(host, tag).await?; + + if let Some(last) = last { + return Ok(last.idx + 1); + } + + return Ok(0); + } + + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query( + "select * from store where idx >= ?1 and host = ?2 and tag = ?3 order by idx asc limit ?4", + ) + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .bind(limit as i64) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result<Option<Record<EncryptedData>>> { + let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3") + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occurred: {}", e)), + Ok(v) => Ok(Some(v)), + } + } + + async fn status(&self) -> Result<RecordStatus> { + let mut status = RecordStatus::new(); + + let res: Result<Vec<(String, String, i64)>, sqlx::Error> = + sqlx::query_as("select host, tag, max(idx) from store group by host, tag") + .fetch_all(&self.pool) + .await; + + let res = match res { + Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)), + Ok(v) => v, + }; + + for i in res { + let host = HostId( + Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"), + ); + + status.set_raw(host, i.1, i.2 as u64); + } + + Ok(status) + } + + async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc") + .bind(tag) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + /// Reencrypt every single item in this store with a new key + /// Be careful - this may mess with sync. + async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> { + // Load all the records + // In memory like some of the other code here + // This will never be called in a hot loop, and only under the following circumstances + // 1. The user has logged into a new account, with a new key. They are unlikely to have a + // lot of data + // 2. The user has encountered some sort of issue, and runs a maintenance command that + // invokes this + let all = self.load_all().await?; + + let re_encrypted = all + .into_iter() + .map(|record| record.re_encrypt::<PASETO_V4>(old_key, new_key)) + .collect::<Result<Vec<_>>>()?; + + // next up, we delete all the old data and reinsert the new stuff + // do it in one transaction, so if anything fails we rollback OK + + let mut tx = self.pool.begin().await?; + + let res = sqlx::query("delete from store").execute(&mut *tx).await?; + + let rows = res.rows_affected(); + debug!("deleted {rows} rows"); + + // don't call push_batch, as it will start its own transaction + // call the underlying save_raw + + for record in re_encrypted { + Self::save_raw(&mut tx, &record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + /// Verify that every record in this store can be decrypted with the current key + /// Someday maybe also check each tag/record can be deserialized, but not for now. + async fn verify(&self, key: &[u8; 32]) -> Result<()> { + let all = self.load_all().await?; + + all.into_iter() + .map(|record| record.decrypt::<PASETO_V4>(key)) + .collect::<Result<Vec<_>>>()?; + + Ok(()) + } + + /// Verify that every record in this store can be decrypted with the current key + /// Someday maybe also check each tag/record can be deserialized, but not for now. + async fn purge(&self, key: &[u8; 32]) -> Result<()> { + let all = self.load_all().await?; + + for record in all.iter() { + match record.clone().decrypt::<PASETO_V4>(key) { + Ok(_) => continue, + Err(_) => { + println!( + "Failed to decrypt {}, deleting", + record.id.0.as_hyphenated() + ); + + self.delete(record.id).await?; + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::{ + record::{DecryptedData, EncryptedData, Host, HostId, Record}, + utils::uuid_v7, + }; + + use crate::{ + encryption::generate_encoded_key, + record::{encryption::PASETO_V4, store::Store}, + settings::test_local_timeout, + }; + + use super::SqliteStore; + + fn test_record() -> Record<EncryptedData> { + Record::builder() + .host(Host::new(HostId(atuin_common::utils::uuid_v7()))) + .version("v1".into()) + .tag(atuin_common::utils::uuid_v7().simple().to_string()) + .data(EncryptedData { + data: "1234".into(), + content_encryption_key: "1234".into(), + }) + .idx(0) + .build() + } + + #[tokio::test] + async fn create_db() { + let db = SqliteStore::new(":memory:", test_local_timeout()).await; + + assert!( + db.is_ok(), + "db could not be created, {:?}", + db.err().unwrap() + ); + } + + #[tokio::test] + async fn push_record() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + + db.push(&record).await.expect("failed to insert record"); + } + + #[tokio::test] + async fn get_record() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let new_record = db.get(record.id).await.expect("failed to fetch record"); + + assert_eq!(record, new_record, "records are not equal"); + } + + #[tokio::test] + async fn last() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let last = db + .last(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + last.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn first() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let first = db + .first(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + first.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn len() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_tag() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len_tag(record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_different_tags() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + + // these have different tags, so the len should be the same + // we model multiple stores within one database + // new store = new tag = independent length + let first = test_record(); + let second = test_record(); + + db.push(&first).await.unwrap(); + db.push(&second).await.unwrap(); + + let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap(); + let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap(); + + assert_eq!(first_len, 1, "expected length of 1 after insert"); + assert_eq!(second_len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn append_a_bunch() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + + let mut tail = test_record(); + db.push(&tail).await.expect("failed to push record"); + + for _ in 1..100 { + tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]); + db.push(&tail).await.unwrap(); + } + + assert_eq!( + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + + assert_eq!( + db.len_tag(tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + } + + #[tokio::test] + async fn append_a_big_bunch() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + + let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(10000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..10000 { + tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + assert_eq!( + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), + 10000, + "failed to insert 10k records" + ); + } + + #[tokio::test] + async fn re_encrypt() { + let store = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let (key, _) = generate_encoded_key().unwrap(); + let data = vec![0u8, 1u8, 2u8, 3u8]; + let host_id = HostId(uuid_v7()); + + for i in 0..10 { + let record = Record::builder() + .host(Host::new(host_id)) + .version(String::from("test")) + .tag(String::from("test")) + .idx(i) + .data(DecryptedData(data.clone())) + .build(); + + let record = record.encrypt::<PASETO_V4>(&key.into()); + store + .push(&record) + .await + .expect("failed to push encrypted record"); + } + + // first, check that we can decrypt the data with the current key + let all = store.all_tagged("test").await.unwrap(); + + assert_eq!(all.len(), 10, "failed to fetch all records"); + + for record in all { + let decrypted = record.decrypt::<PASETO_V4>(&key.into()).unwrap(); + assert_eq!(decrypted.data.0, data); + } + + // reencrypt the store, then check if + // 1) it cannot be decrypted with the old key + // 2) it can be decrypted with the new key + + let (new_key, _) = generate_encoded_key().unwrap(); + store + .re_encrypt(&key.into(), &new_key.into()) + .await + .expect("failed to re-encrypt store"); + + let all = store.all_tagged("test").await.unwrap(); + + for record in all.iter() { + let decrypted = record.clone().decrypt::<PASETO_V4>(&key.into()); + assert!( + decrypted.is_err(), + "did not get error decrypting with old key after re-encrypt" + ) + } + + for record in all { + let decrypted = record.decrypt::<PASETO_V4>(&new_key.into()).unwrap(); + assert_eq!(decrypted.data.0, data); + } + + assert_eq!(store.len(host_id, "test").await.unwrap(), 10); + } +} diff --git a/crates/turtle/src/atuin_client/record/store.rs b/crates/turtle/src/atuin_client/record/store.rs new file mode 100644 index 00000000..f99085d0 --- /dev/null +++ b/crates/turtle/src/atuin_client/record/store.rs @@ -0,0 +1,60 @@ +use async_trait::async_trait; +use eyre::Result; + +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; + +/// A record store stores records +/// In more detail - we tend to need to process this into _another_ format to actually query it. +/// As is, the record store is intended as the source of truth for arbitrary data, which could +/// be shell history, kvs, etc. +#[async_trait] +pub trait Store { + // Push a record + async fn push(&self, record: &Record<EncryptedData>) -> Result<()> { + self.push_batch(std::iter::once(record)).await + } + + // Push a batch of records, all in one transaction + async fn push_batch( + &self, + records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, + ) -> Result<()>; + + async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>; + + async fn delete(&self, id: RecordId) -> Result<()>; + async fn delete_all(&self) -> Result<()>; + + async fn len_all(&self) -> Result<u64>; + async fn len(&self, host: HostId, tag: &str) -> Result<u64>; + async fn len_tag(&self, tag: &str) -> Result<u64>; + + async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + + async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()>; + async fn verify(&self, key: &[u8; 32]) -> Result<()>; + async fn purge(&self, key: &[u8; 32]) -> Result<()>; + + /// Get the next `limit` records, after and including the given index + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result<Vec<Record<EncryptedData>>>; + + /// Get the first record for a given host and tag + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result<Option<Record<EncryptedData>>>; + + async fn status(&self) -> Result<RecordStatus>; + + /// Get all records for a given tag + async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>; +} diff --git a/crates/turtle/src/atuin_client/record/sync.rs b/crates/turtle/src/atuin_client/record/sync.rs new file mode 100644 index 00000000..f831570b --- /dev/null +++ b/crates/turtle/src/atuin_client/record/sync.rs @@ -0,0 +1,664 @@ +// do a sync :O +use std::{cmp::Ordering, fmt::Write}; + +use eyre::Result; +use thiserror::Error; +use tracing::error; + +use super::{encryption::PASETO_V4, store::Store}; +use crate::atuin_client::{api_client::Client, settings::Settings}; + +use crate::atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus}; +use indicatif::{ProgressBar, ProgressState, ProgressStyle}; + +#[derive(Error, Debug)] +pub enum SyncError { + #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] + LocalAheadOtherHost, + + #[error("an issue with the local database occurred: {msg:?}")] + LocalStoreError { msg: String }, + + #[error("something has gone wrong with the sync logic: {msg:?}")] + SyncLogicError { msg: String }, + + #[error("operational error: {msg:?}")] + OperationalError { msg: String }, + + #[error("a request to the sync server failed: {msg:?}")] + RemoteRequestError { msg: String }, + + #[error( + "the encryption key on this machine does not match the data on the server. \ + this usually means a new machine was set up without copying the existing key. \ + to fix: run `atuin key` on a machine that already syncs correctly, then run \ + `atuin store rekey <key>` on this machine with the value from the other machine" + )] + WrongKey, +} + +#[derive(Debug, Eq, PartialEq)] +pub enum Operation { + // Either upload or download until the states matches the below + Upload { + local: RecordIdx, + remote: Option<RecordIdx>, + host: HostId, + tag: String, + }, + Download { + local: Option<RecordIdx>, + remote: RecordIdx, + host: HostId, + tag: String, + }, + Noop { + host: HostId, + tag: String, + }, +} + +pub async fn build_client(settings: &Settings) -> Result<Client<'_>, SyncError> { + Client::new( + &settings.sync_address, + settings + .sync_auth_token() + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?, + settings.network_connect_timeout, + settings.network_timeout, + ) + .map_err(|e| SyncError::OperationalError { msg: e.to_string() }) +} + +pub async fn diff( + client: &Client<'_>, + store: &impl Store, +) -> Result<(Vec<Diff>, RecordStatus), SyncError> { + let local_index = store + .status() + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; + + let remote_index = client + .record_status() + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + let diff = local_index.diff(&remote_index); + + Ok((diff, remote_index)) +} + +// Take a diff, along with a local store, and resolve it into a set of operations. +// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download. +// In theory this could be done as a part of the diffing stage, but it's easier to reason +// about and test this way +pub async fn operations( + diffs: Vec<Diff>, + _store: &impl Store, +) -> Result<Vec<Operation>, SyncError> { + let mut operations = Vec::with_capacity(diffs.len()); + + for diff in diffs { + let op = match (diff.local, diff.remote) { + // We both have it! Could be either. Compare. + (Some(local), Some(remote)) => match local.cmp(&remote) { + Ordering::Equal => Operation::Noop { + host: diff.host, + tag: diff.tag, + }, + Ordering::Greater => Operation::Upload { + local, + remote: Some(remote), + host: diff.host, + tag: diff.tag, + }, + Ordering::Less => Operation::Download { + local: Some(local), + remote, + host: diff.host, + tag: diff.tag, + }, + }, + + // Remote has it, we don't. Gotta be download + (None, Some(remote)) => Operation::Download { + local: None, + remote, + host: diff.host, + tag: diff.tag, + }, + + // We have it, remote doesn't. Gotta be upload. + (Some(local), None) => Operation::Upload { + local, + remote: None, + host: diff.host, + tag: diff.tag, + }, + + // something is pretty fucked. + (None, None) => { + return Err(SyncError::SyncLogicError { + msg: String::from( + "diff has nothing for local or remote - (host, tag) does not exist", + ), + }); + } + }; + + operations.push(op); + } + + // sort them - purely so we have a stable testing order, and can rely on + // same input = same output + // We can sort by ID so long as we continue to use UUIDv7 or something + // with the same properties + + operations.sort_by_key(|op| match op { + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. } => (2, *host, tag.clone()), + }); + + Ok(operations) +} + +async fn sync_upload( + store: &impl Store, + client: &Client<'_>, + host: HostId, + tag: String, + local: RecordIdx, + remote: Option<RecordIdx>, + page_size: u64, +) -> Result<i64, SyncError> { + let remote = remote.unwrap_or(0); + let expected = local - remote; + let mut progress = 0; + + let pb = ProgressBar::new(expected); + pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) + .progress_chars("#>-")); + + println!( + "Uploading {} records to {}/{}", + expected, + host.0.as_simple(), + tag + ); + + loop { + let page = store + .next(host, tag.as_str(), remote + progress, page_size) + .await + .map_err(|e| { + error!("failed to read upload page: {e:?}"); + + SyncError::LocalStoreError { msg: e.to_string() } + })?; + + if page.is_empty() { + break; + } + + client.post_records(&page).await.map_err(|e| { + error!("failed to post records: {e:?}"); + + SyncError::RemoteRequestError { msg: e.to_string() } + })?; + + progress += page.len() as u64; + pb.set_position(progress); + + if progress >= expected { + break; + } + } + + pb.finish_with_message("Uploaded records"); + + Ok(progress as i64) +} + +async fn sync_download( + store: &impl Store, + client: &Client<'_>, + host: HostId, + tag: String, + local: Option<RecordIdx>, + remote: RecordIdx, + page_size: u64, +) -> Result<Vec<RecordId>, SyncError> { + let local = local.unwrap_or(0); + let expected = remote - local; + let mut progress = 0; + let mut ret = Vec::new(); + + println!( + "Downloading {} records from {}/{}", + expected, + host.0.as_simple(), + tag + ); + + let pb = ProgressBar::new(expected); + pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) + .progress_chars("#>-")); + + loop { + let page = client + .next_records(host, tag.clone(), local + progress, page_size) + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + if page.is_empty() { + break; + } + + store + .push_batch(page.iter()) + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; + + ret.extend(page.iter().map(|f| f.id)); + + progress += page.len() as u64; + pb.set_position(progress); + + if progress >= expected { + break; + } + } + + pb.finish_with_message("Downloaded records"); + + Ok(ret) +} + +pub async fn sync_remote( + client: &Client<'_>, + operations: Vec<Operation>, + local_store: &impl Store, + page_size: u64, +) -> Result<(i64, Vec<RecordId>), SyncError> { + let mut uploaded = 0; + let mut downloaded = Vec::new(); + + // this can totally run in parallel, but lets get it working first + for i in operations { + match i { + Operation::Upload { + host, + tag, + local, + remote, + } => { + uploaded += + sync_upload(local_store, client, host, tag, local, remote, page_size).await? + } + + Operation::Download { + host, + tag, + local, + remote, + } => { + let mut d = + sync_download(local_store, client, host, tag, local, remote, page_size).await?; + downloaded.append(&mut d) + } + + Operation::Noop { .. } => continue, + } + } + + Ok((uploaded, downloaded)) +} + +pub async fn check_encryption_key( + client: &Client<'_>, + remote_index: &RecordStatus, + encryption_key: &[u8; 32], +) -> Result<(), SyncError> { + let sample = remote_index + .hosts + .iter() + .flat_map(|(host, tags)| tags.keys().map(move |tag| (*host, tag.clone()))) + .next(); + + let Some((host, tag)) = sample else { + return Ok(()); + }; + + let records = client + .next_records(host, tag, 0, 1) + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + let Some(record) = records.into_iter().next() else { + return Ok(()); + }; + + record + .decrypt::<PASETO_V4>(encryption_key) + .map_err(|_| SyncError::WrongKey)?; + + Ok(()) +} + +pub async fn sync( + settings: &Settings, + store: &impl Store, + encryption_key: &[u8; 32], +) -> Result<(i64, Vec<RecordId>), SyncError> { + let client = build_client(settings).await?; + let (diff, remote_index) = diff(&client, store).await?; + + // Bail before mutating either side if the local key can't read the remote. + check_encryption_key(&client, &remote_index, encryption_key).await?; + + let operations = operations(diff, store).await?; + let (uploaded, downloaded) = sync_remote(&client, operations, store, 100).await?; + + Ok((uploaded, downloaded)) +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::record::{Diff, EncryptedData, HostId, Record}; + use pretty_assertions::assert_eq; + + use crate::atuin_client::{ + record::{ + encryption::PASETO_V4, + sqlite_store::SqliteStore, + store::Store, + sync::{self, Operation}, + }, + settings::test_local_timeout, + }; + + fn test_record() -> Record<EncryptedData> { + Record::builder() + .host(crate::atuin_common::record::Host::new(HostId( + crate::atuin_common::utils::uuid_v7(), + ))) + .version("v1".into()) + .tag(crate::atuin_common::utils::uuid_v7().simple().to_string()) + .data(EncryptedData { + data: String::new(), + content_encryption_key: String::new(), + }) + .idx(0) + .build() + } + + // Take a list of local records, and a list of remote records. + // Return the local database, and a diff of local/remote, ready to build + // ops + async fn build_test_diff( + local_records: Vec<Record<EncryptedData>>, + remote_records: Vec<Record<EncryptedData>>, + ) -> (SqliteStore, Vec<Diff>) { + let local_store = SqliteStore::new(":memory:", test_local_timeout()) + .await + .expect("failed to open in memory sqlite"); + let remote_store = SqliteStore::new(":memory:", test_local_timeout()) + .await + .expect("failed to open in memory sqlite"); // "remote" + + for i in local_records { + local_store.push(&i).await.unwrap(); + } + + for i in remote_records { + remote_store.push(&i).await.unwrap(); + } + + let local_index = local_store.status().await.unwrap(); + let remote_index = remote_store.status().await.unwrap(); + + let diff = local_index.diff(&remote_index); + + (local_store, diff) + } + + #[tokio::test] + async fn test_basic_diff() { + // a diff where local is ahead of remote. nothing else. + + let record = test_record(); + let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await; + + assert_eq!(diff.len(), 1); + + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 1); + + assert_eq!( + operations[0], + Operation::Upload { + host: record.host.id, + tag: record.tag, + local: record.idx, + remote: None, + } + ); + } + + #[tokio::test] + async fn build_two_way_diff() { + // a diff where local is ahead of remote for one, and remote for + // another. One upload, one download + + let shared_record = test_record(); + let remote_ahead = test_record(); + + let local_ahead = shared_record + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + assert_eq!(local_ahead.idx, 1); + + let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store + let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store + + let (store, diff) = build_test_diff(local, remote).await; + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 2); + + assert_eq!( + operations, + vec![ + // Or in otherwords, local is ahead by one + Operation::Upload { + host: local_ahead.host.id, + tag: local_ahead.tag, + local: 1, + remote: Some(0), + }, + // Or in other words, remote knows of a record in an entirely new store (tag) + Operation::Download { + host: remote_ahead.host.id, + tag: remote_ahead.tag, + local: None, + remote: 0, + }, + ] + ); + } + + #[tokio::test] + async fn build_complex_diff() { + // One shared, ahead but known only by remote + // One known only by local + // One known only by remote + + let shared_record = test_record(); + let local_only = test_record(); + + let local_only_20 = test_record(); + let local_only_21 = local_only_20 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let local_only_22 = local_only_21 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let local_only_23 = local_only_22 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let remote_only = test_record(); + + let remote_only_20 = test_record(); + let remote_only_21 = remote_only_20 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_22 = remote_only_21 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_23 = remote_only_22 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_24 = remote_only_23 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let second_shared = test_record(); + let second_shared_remote_ahead = second_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let second_shared_remote_ahead2 = second_shared_remote_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let third_shared = test_record(); + let third_shared_local_ahead = third_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let third_shared_local_ahead2 = third_shared_local_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let fourth_shared = test_record(); + let fourth_shared_remote_ahead = fourth_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let local = vec![ + shared_record.clone(), + second_shared.clone(), + third_shared.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + // single store, only local has it + local_only.clone(), + // bigger store, also only known by local + local_only_20.clone(), + local_only_21.clone(), + local_only_22.clone(), + local_only_23.clone(), + // another shared store, but local is ahead on this one + third_shared_local_ahead.clone(), + third_shared_local_ahead2.clone(), + ]; + + let remote = vec![ + remote_only.clone(), + remote_only_20.clone(), + remote_only_21.clone(), + remote_only_22.clone(), + remote_only_23.clone(), + remote_only_24.clone(), + shared_record.clone(), + second_shared.clone(), + third_shared.clone(), + second_shared_remote_ahead.clone(), + second_shared_remote_ahead2.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + fourth_shared_remote_ahead2.clone(), + ]; // remote knows about the already-synced, and one new record in a new store + + let (store, diff) = build_test_diff(local, remote).await; + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 7); + + let mut result_ops = vec![ + // We started with a shared record, but the remote knows of two newer records in the + // same store + Operation::Download { + local: Some(0), + remote: 2, + host: second_shared_remote_ahead.host.id, + tag: second_shared_remote_ahead.tag, + }, + // We have a shared record, local knows of the first two but not the last + Operation::Download { + local: Some(1), + remote: 2, + host: fourth_shared_remote_ahead2.host.id, + tag: fourth_shared_remote_ahead2.tag, + }, + // Remote knows of a store with a single record that local does not have + Operation::Download { + local: None, + remote: 0, + host: remote_only.host.id, + tag: remote_only.tag, + }, + // Remote knows of a store with a bunch of records that local does not have + Operation::Download { + local: None, + remote: 4, + host: remote_only_20.host.id, + tag: remote_only_20.tag, + }, + // Local knows of a record in a store that remote does not have + Operation::Upload { + local: 0, + remote: None, + host: local_only.host.id, + tag: local_only.tag, + }, + // Local knows of 4 records in a store that remote does not have + Operation::Upload { + local: 3, + remote: None, + host: local_only_20.host.id, + tag: local_only_20.tag, + }, + // Local knows of 2 more records in a shared store that remote only has one of + Operation::Upload { + local: 2, + remote: Some(0), + host: third_shared.host.id, + tag: third_shared.tag, + }, + ]; + + result_ops.sort_by_key(|op| match op { + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. } => (2, *host, tag.clone()), + }); + + assert_eq!(result_ops, operations); + } +} diff --git a/crates/turtle/src/atuin_client/register.rs b/crates/turtle/src/atuin_client/register.rs new file mode 100644 index 00000000..4b14c233 --- /dev/null +++ b/crates/turtle/src/atuin_client/register.rs @@ -0,0 +1,20 @@ +use eyre::Result; + +use crate::atuin_client::{api_client, settings::Settings}; + +pub async fn register_classic( + settings: &Settings, + username: String, + email: String, + password: String, +) -> Result<String> { + let session = + api_client::register(settings.sync_address.as_str(), &username, &email, &password).await?; + + let meta = Settings::meta_store().await?; + meta.save_session(&session.session).await?; + + let _key = crate::atuin_client::encryption::load_key(settings)?; + + Ok(session.session) +} diff --git a/crates/turtle/src/atuin_client/secrets.rs b/crates/turtle/src/atuin_client/secrets.rs new file mode 100644 index 00000000..e8a6ab62 --- /dev/null +++ b/crates/turtle/src/atuin_client/secrets.rs @@ -0,0 +1,194 @@ +// This file will probably trigger a lot of scanners. Sorry. + +use regex::RegexSet; +use std::sync::LazyLock; + +pub enum TestValue<'a> { + Single(&'a str), + Multiple(&'a [&'a str]), +} + +/// A list of `(name, regex, test)`, where `test` should match against `regex`. +pub static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ + ( + "AWS Access Key ID", + "A[KS]IA[0-9A-Z]{16}", + TestValue::Single("AKIAIOSFODNN7EXAMPLE"), + ), + ( + "AWS Secret Access Key env var", + "AWS_SECRET_ACCESS_KEY", + TestValue::Single("AWS_SECRET_ACCESS_KEY=KEYDATA"), + ), + ( + "AWS Session Token env var", + "AWS_SESSION_TOKEN", + TestValue::Single("AWS_SESSION_TOKEN=KEYDATA"), + ), + ( + "Microsoft Azure secret access key env var", + "AZURE_.*_KEY", + TestValue::Single("export AZURE_STORAGE_ACCOUNT_KEY=KEYDATA"), + ), + ( + "Google cloud platform key env var", + "GOOGLE_SERVICE_ACCOUNT_KEY", + TestValue::Single("export GOOGLE_SERVICE_ACCOUNT_KEY=KEYDATA"), + ), + ( + "Atuin login", + r"atuin\s+login", + TestValue::Single( + "atuin login -u mycoolusername -p mycoolpassword -k \"lots of random words\"", + ), + ), + ( + "GitHub PAT (old)", + "ghp_[a-zA-Z0-9]{36}", + TestValue::Single("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"), // legit, I expired it + ), + ( + "GitHub PAT (new)", + "gh1_[A-Za-z0-9]{21}_[A-Za-z0-9]{59}|github_pat_[0-9][A-Za-z0-9]{21}_[A-Za-z0-9]{59}", + TestValue::Multiple(&[ + "gh1_1234567890abcdefghijk_1234567890abcdefghijklmnopqrstuvwxyz1234567890abcdefghijklm", + "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", // also legit, also expired + ]), + ), + ( + "GitHub OAuth Access Token", + "gho_[A-Za-z0-9]{36}", + TestValue::Single("gho_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token + ), + ( + "GitHub OAuth Access Token (user)", + "ghu_[A-Za-z0-9]{36}", + TestValue::Single("ghu_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token + ), + ( + "GitHub App Installation Access Token", + "ghs_[A-Za-z0-9._-]{36,}", + TestValue::Multiple(&[ + "ghs_1234567890abcdefghijklmnopqrstuvwx000", // not a real token + "ghs_abc-def.ghi_jklMNOP0123456789qrstuv-wxyzABCD", // new token format, fake data + ]), + ), + ( + "GitHub Refresh Token", + "ghr_[A-Za-z0-9]{76}", + TestValue::Single( + "ghr_1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx", + ), // not a real token + ), + ( + "GitHub App Installation Access Token v1", + "v1\\.[0-9A-Fa-f]{40}", + TestValue::Single("v1.1234567890abcdef1234567890abcdef12345678"), // not a real token + ), + ( + "GitLab PAT", + "glpat-[a-zA-Z0-9_]{20}", + TestValue::Single("glpat-RkE_BG5p_bbjML21WSfy"), + ), + ( + "Slack OAuth v2 bot", + "xoxb-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", + TestValue::Single("xoxb-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), + ), + ( + "Slack OAuth v2 user token", + "xoxp-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", + TestValue::Single("xoxp-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), + ), + ( + "Slack webhook", + "T[a-zA-Z0-9_]{8}/B[a-zA-Z0-9_]{8}/[a-zA-Z0-9_]{24}", + TestValue::Single( + "https://hooks.slack.com/services/T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX", + ), + ), + ( + "Stripe test key", + "sk_test_[0-9a-zA-Z]{24}", + TestValue::Single("sk_test_1234567890abcdefghijklmnop"), + ), + ( + "Stripe live key", + "sk_live_[0-9a-zA-Z]{24}", + TestValue::Single("sk_live_1234567890abcdefghijklmnop"), + ), + ( + "Netlify authentication token", + "nf[pcoub]_[0-9a-zA-Z]{36}", + TestValue::Single("nfp_nBh7BdJxUwyaBBwFzpyD29MMFT6pZ9wq5634"), + ), + ( + "npm token", + "npm_[A-Za-z0-9]{36}", + TestValue::Single("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"), + ), + ( + "Pulumi personal access token", + "pul-[0-9a-f]{40}", + TestValue::Single("pul-683c2770662c51d960d72ec27613be7653c5cb26"), + ), +]; + +/// The `regex` expressions from [`SECRET_PATTERNS`] compiled into a `RegexSet`. +pub static SECRET_PATTERNS_RE: LazyLock<RegexSet> = LazyLock::new(|| { + let exprs = SECRET_PATTERNS.iter().map(|f| f.1); + RegexSet::new(exprs).expect("Failed to build secrets regex") +}); + +#[cfg(test)] +mod tests { + use regex::Regex; + + use crate::secrets::{SECRET_PATTERNS, TestValue}; + + #[test] + fn test_secrets() { + for (name, regex, test) in SECRET_PATTERNS { + let re = + Regex::new(regex).unwrap_or_else(|_| panic!("Failed to compile regex for {name}")); + + match test { + TestValue::Single(test) => { + assert!(re.is_match(test), "{name} test failed!"); + } + TestValue::Multiple(tests) => { + for test_str in tests.iter() { + assert!( + re.is_match(test_str), + "{name} test with value \"{test_str}\" failed!" + ); + } + } + } + } + } + + #[test] + fn test_secrets_embedded() { + for (name, regex, test) in SECRET_PATTERNS { + let re = + Regex::new(regex).unwrap_or_else(|_| panic!("Failed to compile regex for {name}")); + + match test { + TestValue::Single(test) => { + let embedded = format!("some random text {test} some more random text"); + assert!(re.is_match(&embedded), "{name} embedded test failed!"); + } + TestValue::Multiple(tests) => { + for test_str in tests.iter() { + let embedded = format!("some random text {test_str} some more random text"); + assert!( + re.is_match(&embedded), + "{name} embedded test with value \"{test_str}\" failed!" + ); + } + } + } + } + } +} diff --git a/crates/turtle/src/atuin_client/settings.rs b/crates/turtle/src/atuin_client/settings.rs new file mode 100644 index 00000000..b0ffc7c1 --- /dev/null +++ b/crates/turtle/src/atuin_client/settings.rs @@ -0,0 +1,1851 @@ +use std::{collections::HashMap, fmt, io::prelude::*, path::PathBuf, str::FromStr, sync::OnceLock}; +use tokio::sync::OnceCell; + +use crate::atuin_common::record::HostId; +use crate::atuin_common::utils; +use clap::ValueEnum; +use config::{ + Config, ConfigBuilder, Environment, File as ConfigFile, FileFormat, builder::DefaultState, +}; +use eyre::{Context, Error, Result, bail, eyre}; +use fs_err::{File, create_dir_all}; +use humantime::parse_duration; +use regex::RegexSet; +use serde::{Deserialize, Serialize}; +use serde_with::DeserializeFromStr; +use time::{OffsetDateTime, UtcOffset, format_description::FormatItem, macros::format_description}; + +pub const HISTORY_PAGE_SIZE: i64 = 100; + +static DATA_DIR: OnceLock<PathBuf> = OnceLock::new(); +static META_CONFIG: OnceLock<(String, f64)> = OnceLock::new(); +static META_STORE: OnceCell<crate::atuin_client::meta::MetaStore> = OnceCell::const_new(); + +pub(crate) mod meta; +pub mod watcher; + +/// Default sync address for Atuin's hosted service +pub const DEFAULT_SYNC_ADDRESS: &str = "https://api.atuin.sh"; + +#[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq, Serialize)] +pub enum SearchMode { + #[serde(rename = "prefix")] + Prefix, + + #[serde(rename = "fulltext")] + #[clap(aliases = &["fulltext"])] + FullText, + + #[serde(rename = "fuzzy")] + Fuzzy, + + #[serde(rename = "skim")] + Skim, + + #[serde(rename = "daemon-fuzzy")] + #[clap(aliases = &["daemon-fuzzy"])] + DaemonFuzzy, +} + +impl SearchMode { + pub fn as_str(&self) -> &'static str { + match self { + SearchMode::Prefix => "PREFIX", + SearchMode::FullText => "FULLTXT", + SearchMode::Fuzzy => "FUZZY", + SearchMode::Skim => "SKIM", + SearchMode::DaemonFuzzy => "DAEMON", + } + } + pub fn next(&self, settings: &Settings) -> Self { + match self { + SearchMode::Prefix => SearchMode::FullText, + // if the user is using skim, we go to skim + SearchMode::FullText if settings.search_mode == SearchMode::Skim => SearchMode::Skim, + // if the user is using daemon-fuzzy, we go to daemon-fuzzy + SearchMode::FullText if settings.search_mode == SearchMode::DaemonFuzzy => { + SearchMode::DaemonFuzzy + } + // otherwise fuzzy. + SearchMode::FullText => SearchMode::Fuzzy, + SearchMode::Fuzzy | SearchMode::Skim | SearchMode::DaemonFuzzy => SearchMode::Prefix, + } + } +} + +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum FilterMode { + #[serde(rename = "global")] + Global = 0, + + #[serde(rename = "host")] + Host = 1, + + #[serde(rename = "session")] + Session = 2, + + #[serde(rename = "directory")] + Directory = 3, + + #[serde(rename = "workspace")] + Workspace = 4, + + #[serde(rename = "session-preload")] + SessionPreload = 5, +} + +impl FilterMode { + pub fn as_str(&self) -> &'static str { + match self { + FilterMode::Global => "GLOBAL", + FilterMode::Host => "HOST", + FilterMode::Session => "SESSION", + FilterMode::Directory => "DIRECTORY", + FilterMode::Workspace => "WORKSPACE", + FilterMode::SessionPreload => "SESSION+", + } + } +} + +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum ExitMode { + #[serde(rename = "return-original")] + ReturnOriginal, + + #[serde(rename = "return-query")] + ReturnQuery, +} + +// FIXME: Can use upstream Dialect enum if https://github.com/stevedonovan/chrono-english/pull/16 is merged +// FIXME: Above PR was merged, but dependency was changed to interim (fork of chrono-english) in the ... interim +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum Dialect { + #[serde(rename = "us")] + Us, + + #[serde(rename = "uk")] + Uk, +} + +impl From<Dialect> for interim::Dialect { + fn from(d: Dialect) -> interim::Dialect { + match d { + Dialect::Uk => interim::Dialect::Uk, + Dialect::Us => interim::Dialect::Us, + } + } +} + +/// Type wrapper around `time::UtcOffset` to support a wider variety of timezone formats. +/// +/// Note that the parsing of this struct needs to be done before starting any +/// multithreaded runtime, otherwise it will fail on most Unix systems. +/// +/// See: <https://github.com/atuinsh/atuin/pull/1517#discussion_r1447516426> +#[derive(Clone, Copy, Debug, Eq, PartialEq, DeserializeFromStr, Serialize)] +pub struct Timezone(pub UtcOffset); +impl fmt::Display for Timezone { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} +/// format: <+|-><hour>[:<minute>[:<second>]] +static OFFSET_FMT: &[FormatItem<'_>] = format_description!( + "[offset_hour sign:mandatory padding:none][optional [:[offset_minute padding:none][optional [:[offset_second padding:none]]]]]" +); +impl FromStr for Timezone { + type Err = Error; + + fn from_str(s: &str) -> Result<Self> { + // local timezone + if matches!(s.to_lowercase().as_str(), "l" | "local") { + // There have been some timezone issues, related to errors fetching it on some + // platforms + // Rather than fail to start, fallback to UTC. The user should still be able to specify + // their timezone manually in the config file. + let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + return Ok(Self(offset)); + } + + if matches!(s.to_lowercase().as_str(), "0" | "utc") { + let offset = UtcOffset::UTC; + return Ok(Self(offset)); + } + + // offset from UTC + if let Ok(offset) = UtcOffset::parse(s, OFFSET_FMT) { + return Ok(Self(offset)); + } + + // IDEA: Currently named timezones are not supported, because the well-known crate + // for this is `chrono_tz`, which is not really interoperable with the datetime crate + // that we currently use - `time`. If ever we migrate to using `chrono`, this would + // be a good feature to add. + + bail!(r#""{s}" is not a valid timezone spec"#) + } +} + +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum Style { + #[serde(rename = "auto")] + Auto, + + #[serde(rename = "full")] + Full, + + #[serde(rename = "compact")] + Compact, +} + +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum WordJumpMode { + #[serde(rename = "emacs")] + Emacs, + + #[serde(rename = "subl")] + Subl, +} + +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum KeymapMode { + #[serde(rename = "emacs")] + Emacs, + + #[serde(rename = "vim-normal")] + VimNormal, + + #[serde(rename = "vim-insert")] + VimInsert, + + #[serde(rename = "auto")] + Auto, +} + +impl KeymapMode { + pub fn as_str(&self) -> &'static str { + match self { + KeymapMode::Emacs => "EMACS", + KeymapMode::VimNormal => "VIMNORMAL", + KeymapMode::VimInsert => "VIMINSERT", + KeymapMode::Auto => "AUTO", + } + } +} + +// We want to translate the config to crossterm::cursor::SetCursorStyle, but +// the original type does not implement trait serde::Deserialize unfortunately. +// It seems impossible to implement Deserialize for external types when it is +// used in HashMap (https://stackoverflow.com/questions/67142663). We instead +// define an adapter type. +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum CursorStyle { + #[serde(rename = "default")] + DefaultUserShape, + + #[serde(rename = "blink-block")] + BlinkingBlock, + + #[serde(rename = "steady-block")] + SteadyBlock, + + #[serde(rename = "blink-underline")] + BlinkingUnderScore, + + #[serde(rename = "steady-underline")] + SteadyUnderScore, + + #[serde(rename = "blink-bar")] + BlinkingBar, + + #[serde(rename = "steady-bar")] + SteadyBar, +} + +impl CursorStyle { + pub fn as_str(&self) -> &'static str { + match self { + CursorStyle::DefaultUserShape => "DEFAULT", + CursorStyle::BlinkingBlock => "BLINKBLOCK", + CursorStyle::SteadyBlock => "STEADYBLOCK", + CursorStyle::BlinkingUnderScore => "BLINKUNDERLINE", + CursorStyle::SteadyUnderScore => "STEADYUNDERLINE", + CursorStyle::BlinkingBar => "BLINKBAR", + CursorStyle::SteadyBar => "STEADYBAR", + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Stats { + #[serde(default = "Stats::common_prefix_default")] + pub common_prefix: Vec<String>, // sudo, etc. commands we want to strip off + #[serde(default = "Stats::common_subcommands_default")] + pub common_subcommands: Vec<String>, // kubectl, commands we should consider subcommands for + #[serde(default = "Stats::ignored_commands_default")] + pub ignored_commands: Vec<String>, // cd, ls, etc. commands we want to completely hide from stats +} + +impl Stats { + fn common_prefix_default() -> Vec<String> { + vec!["sudo", "doas"].into_iter().map(String::from).collect() + } + + fn common_subcommands_default() -> Vec<String> { + vec![ + "apt", + "cargo", + "composer", + "dnf", + "docker", + "dotnet", + "git", + "go", + "ip", + "jj", + "kubectl", + "nix", + "nmcli", + "npm", + "pecl", + "pnpm", + "podman", + "port", + "systemctl", + "tmux", + "yarn", + ] + .into_iter() + .map(String::from) + .collect() + } + + fn ignored_commands_default() -> Vec<String> { + vec![] + } +} + +impl Default for Stats { + fn default() -> Self { + Self { + common_prefix: Self::common_prefix_default(), + common_subcommands: Self::common_subcommands_default(), + ignored_commands: Self::ignored_commands_default(), + } + } +} + +/// Sync protocol type for authentication. +/// +/// This setting is primarily for development/testing. When not explicitly set, +/// the protocol is inferred from the sync_address: +/// - Default sync address (api.atuin.sh) → Hub protocol +/// - Custom sync address → Legacy protocol +/// +/// Set explicitly to "hub" to use Hub authentication with a custom sync_address +/// (useful for local development against a Hub instance). +#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum SyncProtocol { + /// Use legacy CLI authentication (Token from CLI register/login) + #[default] + Legacy, +} + +/// Resolved authentication state for sync operations. +/// +/// Determined at runtime by examining which tokens are available and what +/// server the client is configured to talk to. Operations use this to pick +/// the right auth header and endpoint style. +#[cfg(feature = "sync")] +#[derive(Debug, Clone)] +pub enum SyncAuth { + /// Self-hosted Rust server. Uses `Authorization: Token <session>` and + /// legacy endpoints. + Legacy { token: String }, + + /// Not authenticated at all. Contains an actionable user-facing message. + NotLoggedIn { reason: String }, +} + +#[cfg(feature = "sync")] +impl SyncAuth { + /// Convert into the auth token type used by the API client. + /// + /// Returns an error with an actionable message for `NotLoggedIn`. + pub fn into_auth_token(self) -> Result<crate::atuin_client::api_client::AuthToken> { + use crate::atuin_client::api_client::AuthToken; + match self { + SyncAuth::Legacy { token } => Ok(AuthToken::Token(token)), + SyncAuth::NotLoggedIn { reason } => Err(eyre!(reason)), + } + } +} + +#[derive(Clone, Debug, Deserialize, Default, Serialize)] +pub struct Keys { + pub scroll_exits: bool, + pub exit_past_line_start: bool, + pub accept_past_line_end: bool, + pub accept_past_line_start: bool, + pub accept_with_backspace: bool, + pub prefix: String, +} + +impl Keys { + /// The standard default values for all `[keys]` options. + /// These match the config defaults set in `builder_with_data_dir()`. + pub fn standard_defaults() -> Self { + Keys { + scroll_exits: true, + exit_past_line_start: true, + accept_past_line_end: true, + accept_past_line_start: false, + accept_with_backspace: false, + prefix: "a".to_string(), + } + } + + /// Returns true if any value differs from the standard defaults. + pub fn has_non_default_values(&self) -> bool { + let d = Self::standard_defaults(); + self.scroll_exits != d.scroll_exits + || self.exit_past_line_start != d.exit_past_line_start + || self.accept_past_line_end != d.accept_past_line_end + || self.accept_past_line_start != d.accept_past_line_start + || self.accept_with_backspace != d.accept_with_backspace + || self.prefix != d.prefix + } +} + +/// A single rule within a conditional keybinding config. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct KeyRuleConfig { + /// Optional condition expression (e.g. "cursor-at-start", "input-empty && no-results"). + /// If absent, the rule always matches. + #[serde(default)] + pub when: Option<String>, + /// The action to perform (e.g. "exit", "cursor-left", "accept"). + pub action: String, +} + +/// A keybinding config value: either a simple action string or an ordered list of conditional rules. +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(untagged)] +pub enum KeyBindingConfig { + /// Simple unconditional binding: `"ctrl-c" = "return-original"` + Simple(String), + /// Conditional binding: `"left" = [{ when = "cursor-at-start", action = "exit" }, { action = "cursor-left" }]` + Rules(Vec<KeyRuleConfig>), +} + +/// User-facing keymap configuration. Each mode maps key strings to bindings. +/// Keys present here override the defaults for that key; unmentioned keys keep defaults. +#[derive(Clone, Debug, Deserialize, Serialize, Default)] +pub struct KeymapConfig { + #[serde(default)] + pub emacs: HashMap<String, KeyBindingConfig>, + #[serde(default, rename = "vim-normal")] + pub vim_normal: HashMap<String, KeyBindingConfig>, + #[serde(default, rename = "vim-insert")] + pub vim_insert: HashMap<String, KeyBindingConfig>, + #[serde(default)] + pub inspector: HashMap<String, KeyBindingConfig>, + #[serde(default)] + pub prefix: HashMap<String, KeyBindingConfig>, +} + +impl KeymapConfig { + /// Returns true if no keybinding overrides are configured in any mode. + pub fn is_empty(&self) -> bool { + self.emacs.is_empty() + && self.vim_normal.is_empty() + && self.vim_insert.is_empty() + && self.inspector.is_empty() + && self.prefix.is_empty() + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Preview { + pub strategy: PreviewStrategy, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Theme { + /// Name of desired theme ("default" for base) + pub name: String, + + /// Whether any available additional theme debug should be shown + pub debug: Option<bool>, + + /// How many levels of parenthood will be traversed if needed + pub max_depth: Option<u8>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Daemon { + /// Use the daemon to sync + /// If enabled, history hooks are routed through the daemon. + #[serde(alias = "enable")] + pub enabled: bool, + + /// Automatically start and manage a local daemon when needed. + pub autostart: bool, + + /// The daemon will handle sync on an interval. How often to sync, in seconds. + pub sync_frequency: u64, + + /// The path to the unix socket used by the daemon + pub socket_path: String, + + /// Path to the daemon pidfile used for process coordination. + pub pidfile_path: String, + + /// Use a socket passed via systemd's socket activation protocol, instead of the path + pub systemd_socket: bool, + + /// The port that should be used for TCP on non unix systems + pub tcp_port: u64, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Search { + /// The list of enabled filter modes, in order of priority. + pub filters: Vec<FilterMode>, + + /// The recency score multiplier for the search index (default: 1.0). + /// Values < 1.0 reduce weight, > 1.0 increase weight, 0.0 disables. + pub recency_score_multiplier: f64, + + /// The frequency score multiplier for the search index (default: 1.0). + /// Values < 1.0 reduce weight, > 1.0 increase weight, 0.0 disables. + pub frequency_score_multiplier: f64, + + /// The overall frecency score multiplier for the search index (default: 1.0). + /// Applied after combining recency and frequency scores. + pub frecency_score_multiplier: f64, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Tmux { + /// Enable using atuin with tmux popup (tmux >= 3.2) + pub enabled: bool, + + /// Width of the tmux popup (percentage) + pub width: String, + + /// Height of the tmux popup (percentage) + pub height: String, +} + +/// Log level for file logging. Maps to tracing's LevelFilter. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + Trace, + Debug, + #[default] + Info, + Warn, + Error, +} + +impl LogLevel { + /// Convert to a tracing directive string for use with EnvFilter. + pub fn as_directive(&self) -> &'static str { + match self { + LogLevel::Trace => "trace", + LogLevel::Debug => "debug", + LogLevel::Info => "info", + LogLevel::Warn => "warn", + LogLevel::Error => "error", + } + } +} + +/// Configuration for a specific log type (search or daemon). +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct LogConfig { + /// Log file name (relative to dir) or absolute path. + pub file: String, + + /// Override global enabled setting for this log type. + pub enabled: Option<bool>, + + /// Override global level setting for this log type. + pub level: Option<LogLevel>, + + /// Override global retention days setting for this log type. + pub retention: Option<u64>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Logs { + /// Enable file logging globally. Defaults to true. + #[serde(default = "Logs::default_enabled")] + pub enabled: bool, + + /// Directory for log files. Defaults to ~/.atuin/logs + pub dir: String, + + /// Default log level for file logging. Defaults to "info". + /// Note: ATUIN_LOG environment variable overrides this. + #[serde(default)] + pub level: LogLevel, + + /// Default retention days for log files. Defaults to 4. + #[serde(default = "Logs::default_retention")] + pub retention: u64, + + /// Search log settings + #[serde(default)] + pub search: LogConfig, + + /// Daemon log settings + #[serde(default)] + pub daemon: LogConfig, + + /// AI log settings + #[serde(default)] + pub ai: LogConfig, +} + +#[derive(Default, Clone, Debug, Deserialize, Serialize)] +pub struct Ai { + /// Whether or not the AI features are enabled. + pub enabled: Option<bool>, + + /// The address of the Atuin AI endpoint. Used for AI features like command generation. + /// Only necessary for custom AI endpoints. + pub endpoint: Option<String>, + + /// The API token for the Atuin AI endpoint. Used for AI features like command generation. + /// Only necessary for custom AI endpoints. + pub api_token: Option<String>, + + /// Path to the AI sessions database. + pub db_path: String, + + /// The maximum time in minutes that an AI session can be automatically resumed. + pub session_continue_minutes: i64, + + /// Deprecated: use opening.send_cwd instead. Kept for backwards compatibility. + #[serde(default)] + pub send_cwd: Option<bool>, + + /// Configuration for what context is sent in the opening AI request. + #[serde(default)] + pub opening: AiOpening, + + /// Tool capability flags. + #[serde(default)] + pub capabilities: AiCapabilities, +} + +#[derive(Default, Clone, Debug, Deserialize, Serialize)] +pub struct AiCapabilities { + /// Whether the AI can request to search Atuin history. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_history_search: Option<bool>, + /// Whether the AI can request to view the stored output, if any, for Atuin history entries. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_history_output: Option<bool>, + /// Whether the AI can request to read and write files. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_file_tools: Option<bool>, + /// Whether the AI can request to execute bash commands. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_command_execution: Option<bool>, +} + +#[derive(Default, Clone, Debug, Deserialize, Serialize)] +pub struct AiOpening { + /// Whether or not to send the current working directory to the AI endpoint. + pub send_cwd: Option<bool>, + + /// Whether or not to send the last command as context in the opening AI request. + pub send_last_command: Option<bool>, +} + +impl Default for Preview { + fn default() -> Self { + Self { + strategy: PreviewStrategy::Auto, + } + } +} + +impl Default for Theme { + fn default() -> Self { + Self { + name: "".to_string(), + debug: None::<bool>, + max_depth: Some(10), + } + } +} + +impl Default for Daemon { + fn default() -> Self { + Self { + enabled: false, + autostart: false, + sync_frequency: 300, + socket_path: "".to_string(), + pidfile_path: "".to_string(), + systemd_socket: false, + tcp_port: 8889, + } + } +} + +impl Default for Logs { + fn default() -> Self { + Self { + enabled: true, + dir: "".to_string(), + level: LogLevel::default(), + retention: Self::default_retention(), + search: LogConfig { + file: "search.log".to_string(), + ..Default::default() + }, + daemon: LogConfig { + file: "daemon.log".to_string(), + ..Default::default() + }, + ai: LogConfig { + file: "ai.log".to_string(), + ..Default::default() + }, + } + } +} + +impl Logs { + fn default_enabled() -> bool { + true + } + + fn default_retention() -> u64 { + 4 + } + + /// Returns whether search logging is enabled. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_enabled(&self) -> bool { + self.search.enabled.unwrap_or(self.enabled) + } + + /// Returns whether daemon logging is enabled. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_enabled(&self) -> bool { + self.daemon.enabled.unwrap_or(self.enabled) + } + + /// Returns whether AI logging is enabled. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_enabled(&self) -> bool { + self.ai.enabled.unwrap_or(self.enabled) + } + + /// Returns the log level for search logging. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_level(&self) -> LogLevel { + self.search.level.unwrap_or(self.level) + } + + /// Returns the log level for daemon logging. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_level(&self) -> LogLevel { + self.daemon.level.unwrap_or(self.level) + } + + /// Returns the log level for AI logging. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_level(&self) -> LogLevel { + self.ai.level.unwrap_or(self.level) + } + + /// Returns the retention days for search logging. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_retention(&self) -> u64 { + self.search.retention.unwrap_or(self.retention) + } + + /// Returns the retention days for daemon logging. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_retention(&self) -> u64 { + self.daemon.retention.unwrap_or(self.retention) + } + + /// Returns the retention days for AI logging. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_retention(&self) -> u64 { + self.ai.retention.unwrap_or(self.retention) + } + + /// Returns the full path for the search log file. + pub fn search_path(&self) -> PathBuf { + let path = PathBuf::from(&self.search.file); + PathBuf::from(&self.dir).join(path) + } + + /// Returns the full path for the daemon log file. + pub fn daemon_path(&self) -> PathBuf { + let path = PathBuf::from(&self.daemon.file); + PathBuf::from(&self.dir).join(path) + } + + /// Returns the full path for the AI log file. + pub fn ai_path(&self) -> PathBuf { + let path = PathBuf::from(&self.ai.file); + PathBuf::from(&self.dir).join(path) + } +} + +impl Default for Search { + fn default() -> Self { + Self { + filters: vec![ + FilterMode::Global, + FilterMode::Host, + FilterMode::Session, + FilterMode::SessionPreload, + FilterMode::Workspace, + FilterMode::Directory, + ], + + recency_score_multiplier: 1.0, + frequency_score_multiplier: 1.0, + frecency_score_multiplier: 1.0, + } + } +} + +impl Default for Tmux { + fn default() -> Self { + Self { + enabled: false, + width: "80%".to_string(), + height: "60%".to_string(), + } + } +} + +// The preview height strategy also takes max_preview_height into account. +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum PreviewStrategy { + // Preview height is calculated for the length of the selected command. + #[serde(rename = "auto")] + Auto, + + // Preview height is calculated for the length of the longest command stored in the history. + #[serde(rename = "static")] + Static, + + // max_preview_height is used as fixed height. + #[serde(rename = "fixed")] + Fixed, +} + +/// Column types available for the interactive search UI. +#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum UiColumnType { + /// Command execution duration (e.g., "123ms") + Duration, + /// Relative time since execution (e.g., "59s ago") + Time, + /// Absolute timestamp (e.g., "2025-01-22 14:35") + Datetime, + /// Working directory + Directory, + /// Hostname + Host, + /// Username + User, + /// Exit code + Exit, + /// The command itself (should be last, expands to fill) + Command, +} + +impl UiColumnType { + /// Returns the default width for this column type (in characters). + /// The Command column returns 0 as it expands to fill remaining space. + pub fn default_width(&self) -> u16 { + match self { + UiColumnType::Duration => 5, // "814ms" + UiColumnType::Time => 9, // "459ms ago" + UiColumnType::Datetime => 16, // "2025-01-22 14:35" + UiColumnType::Directory => 20, + UiColumnType::Host => 15, + UiColumnType::User => 10, + UiColumnType::Exit => { + if cfg!(windows) { + 11 // 32-bit integer on Windows: "-1978335212" + } else { + 3 // Usually a byte on Unix + } + } + UiColumnType::Command => 0, // Expands to fill + } + } +} + +/// A column configuration with type and optional custom width. +/// Can be specified as just a string (uses default width) or as an object with type and width. +#[derive(Clone, Debug, Serialize)] +pub struct UiColumn { + pub column_type: UiColumnType, + pub width: u16, + /// If true, this column expands to fill remaining space. Only one column should expand. + pub expand: bool, +} + +impl UiColumn { + pub fn new(column_type: UiColumnType) -> Self { + Self { + width: column_type.default_width(), + expand: column_type == UiColumnType::Command, + column_type, + } + } + + pub fn with_width(column_type: UiColumnType, width: u16) -> Self { + Self { + column_type, + width, + expand: column_type == UiColumnType::Command, + } + } +} + +// Custom deserialize to handle both string and object formats: +// "duration" or { type = "duration", width = 8, expand = true } +impl<'de> serde::Deserialize<'de> for UiColumn { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + use serde::de::{self, MapAccess, Visitor}; + + struct UiColumnVisitor; + + impl<'de> Visitor<'de> for UiColumnVisitor { + type Value = UiColumn; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str( + "a column type string or an object with 'type' and optional 'width'/'expand'", + ) + } + + fn visit_str<E>(self, value: &str) -> Result<UiColumn, E> + where + E: de::Error, + { + let column_type: UiColumnType = + serde::Deserialize::deserialize(serde::de::value::StrDeserializer::new(value))?; + Ok(UiColumn::new(column_type)) + } + + fn visit_map<M>(self, mut map: M) -> Result<UiColumn, M::Error> + where + M: MapAccess<'de>, + { + let mut column_type: Option<UiColumnType> = None; + let mut width: Option<u16> = None; + let mut expand: Option<bool> = None; + + while let Some(key) = map.next_key::<String>()? { + match key.as_str() { + "type" => { + column_type = Some(map.next_value()?); + } + "width" => { + width = Some(map.next_value()?); + } + "expand" => { + expand = Some(map.next_value()?); + } + _ => { + let _: serde::de::IgnoredAny = map.next_value()?; + } + } + } + + let column_type = column_type.ok_or_else(|| de::Error::missing_field("type"))?; + let width = width.unwrap_or_else(|| column_type.default_width()); + let expand = expand.unwrap_or(column_type == UiColumnType::Command); + Ok(UiColumn { + column_type, + width, + expand, + }) + } + } + + deserializer.deserialize_any(UiColumnVisitor) + } +} + +/// UI-specific settings for the interactive search. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Ui { + /// Columns to display in interactive search, from left to right. + /// The indicator column (" > ") is always shown first implicitly. + /// The "command" column should be last as it expands to fill remaining space. + /// Can be simple strings or objects with type and width. + #[serde(default = "Ui::default_columns")] + pub columns: Vec<UiColumn>, +} + +impl Ui { + fn default_columns() -> Vec<UiColumn> { + vec![ + UiColumn::new(UiColumnType::Duration), + UiColumn::new(UiColumnType::Time), + UiColumn::new(UiColumnType::Command), + ] + } + + /// Validate the UI configuration. + /// Returns an error if more than one column has expand = true. + pub fn validate(&self) -> Result<()> { + let expand_count = self.columns.iter().filter(|c| c.expand).count(); + if expand_count > 1 { + bail!( + "Only one column can have expand = true, but {} columns are set to expand", + expand_count + ); + } + Ok(()) + } +} + +impl Default for Ui { + fn default() -> Self { + Self { + columns: Self::default_columns(), + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Settings { + pub data_dir: Option<String>, + pub dialect: Dialect, + pub timezone: Timezone, + pub style: Style, + pub auto_sync: bool, + + /// The sync address for atuin. + pub sync_address: String, + + #[serde(default)] + pub sync_protocol: SyncProtocol, + + pub sync_frequency: String, + pub db_path: String, + pub record_store_path: String, + pub key_path: String, + pub search_mode: SearchMode, + pub filter_mode: Option<FilterMode>, + pub filter_mode_shell_up_key_binding: Option<FilterMode>, + pub search_mode_shell_up_key_binding: Option<SearchMode>, + pub shell_up_key_binding: bool, + pub inline_height: u16, + pub inline_height_shell_up_key_binding: Option<u16>, + pub invert: bool, + pub show_preview: bool, + pub max_preview_height: u16, + pub show_help: bool, + pub show_tabs: bool, + pub show_numeric_shortcuts: bool, + pub auto_hide_height: u16, + pub exit_mode: ExitMode, + pub keymap_mode: KeymapMode, + pub keymap_mode_shell: KeymapMode, + pub keymap_cursor: HashMap<String, CursorStyle>, + pub word_jump_mode: WordJumpMode, + pub word_chars: String, + pub scroll_context_lines: usize, + pub history_format: String, + pub strip_trailing_whitespace: bool, + pub prefers_reduced_motion: bool, + pub store_failed: bool, + pub no_mouse: bool, + + #[serde(with = "serde_regex", default = "RegexSet::empty", skip_serializing)] + pub history_filter: RegexSet, + + #[serde(with = "serde_regex", default = "RegexSet::empty", skip_serializing)] + pub cwd_filter: RegexSet, + + pub secrets_filter: bool, + pub workspaces: bool, + pub ctrl_n_shortcuts: bool, + + pub network_connect_timeout: u64, + pub network_timeout: u64, + pub local_timeout: f64, + pub enter_accept: bool, + pub smart_sort: bool, + pub command_chaining: bool, + + #[serde(default)] + pub stats: Stats, + + #[serde(default)] + pub keys: Keys, + + #[serde(default)] + pub keymap: KeymapConfig, + + #[serde(default)] + pub preview: Preview, + + #[serde(default)] + pub daemon: Daemon, + + #[serde(default)] + pub search: Search, + + #[serde(default)] + pub theme: Theme, + + #[serde(default)] + pub ui: Ui, + + #[serde(default)] + pub tmux: Tmux, + + #[serde(default)] + pub logs: Logs, + + #[serde(default)] + pub meta: meta::Settings, +} + +impl Settings { + pub fn utc() -> Self { + Self::builder() + .expect("Could not build default") + .set_override("timezone", "0") + .expect("failed to override timezone with UTC") + .build() + .expect("Could not build config") + .try_deserialize() + .expect("Could not deserialize config") + } + + pub(crate) fn effective_data_dir() -> PathBuf { + DATA_DIR + .get() + .cloned() + .unwrap_or_else(crate::atuin_common::utils::data_dir) + } + + // -- Meta store: lazily initialized on first access -- + + pub async fn meta_store() -> Result<&'static crate::atuin_client::meta::MetaStore> { + META_STORE + .get_or_try_init(|| async { + let (db_path, timeout) = META_CONFIG.get().ok_or_else(|| { + eyre!("meta store config not set — Settings::new() has not been called") + })?; + crate::atuin_client::meta::MetaStore::new(db_path, *timeout).await + }) + .await + } + + pub async fn host_id() -> Result<HostId> { + Self::meta_store().await?.host_id().await + } + + pub async fn last_sync() -> Result<OffsetDateTime> { + Self::meta_store().await?.last_sync().await + } + + pub async fn save_sync_time() -> Result<()> { + Self::meta_store().await?.save_sync_time().await + } + + pub async fn last_version_check() -> Result<OffsetDateTime> { + Self::meta_store().await?.last_version_check().await + } + + pub async fn save_version_check_time() -> Result<()> { + Self::meta_store().await?.save_version_check_time().await + } + + pub async fn should_sync(&self) -> Result<bool> { + if !self.auto_sync || !Self::meta_store().await?.logged_in().await? { + return Ok(false); + } + + if self.sync_frequency == "0" { + return Ok(true); + } + + match parse_duration(self.sync_frequency.as_str()) { + Ok(d) => { + let d = time::Duration::try_from(d)?; + Ok(OffsetDateTime::now_utc() - Settings::last_sync().await? >= d) + } + Err(e) => Err(eyre!("failed to check sync: {}", e)), + } + } + + pub async fn logged_in(&self) -> Result<bool> { + Self::meta_store().await?.logged_in().await + } + + pub async fn session_token(&self) -> Result<String> { + match Self::meta_store().await?.session_token().await? { + Some(token) => Ok(token), + None => Err(eyre!("Tried to load session; not logged in")), + } + } + + /// Examines the configured sync target and available tokens to determine + /// the correct auth strategy. Also performs cleanup of mis-stored tokens + /// (e.g. a CLI token incorrectly saved in the Hub session slot). + #[cfg(feature = "sync")] + pub async fn resolve_sync_auth(&self) -> SyncAuth { + let meta = match Self::meta_store().await { + Ok(m) => m, + Err(e) => { + return SyncAuth::NotLoggedIn { + reason: format!("Failed to open meta store: {e}"), + }; + } + }; + + // Self-hosted / legacy server + match meta.session_token().await { + Ok(Some(token)) => SyncAuth::Legacy { token }, + _ => SyncAuth::NotLoggedIn { + reason: "Not logged in. Run 'atuin login' to authenticate \ + with your sync server." + .into(), + }, + } + } + + /// Returns the appropriate auth token for sync operations. + /// + /// Delegates to [`resolve_sync_auth`] and converts the result to an + /// `AuthToken`. Callers that need to distinguish between auth states + /// (e.g. to show different UI) should call `resolve_sync_auth` directly. + #[cfg(feature = "sync")] + pub async fn sync_auth_token(&self) -> Result<crate::atuin_client::api_client::AuthToken> { + self.resolve_sync_auth().await.into_auth_token() + } + + pub fn default_filter_mode(&self, git_root: bool) -> FilterMode { + self.filter_mode + .filter(|x| self.search.filters.contains(x)) + .or_else(|| { + self.search + .filters + .iter() + .find(|x| match (x, git_root, self.workspaces) { + (FilterMode::Workspace, true, true) => true, + (FilterMode::Workspace, _, _) => false, + (_, _, _) => true, + }) + .copied() + }) + .unwrap_or(FilterMode::Global) + } + + pub fn builder() -> Result<ConfigBuilder<DefaultState>> { + Self::builder_with_data_dir(&crate::atuin_common::utils::data_dir()) + } + + fn builder_with_data_dir(data_dir: &std::path::Path) -> Result<ConfigBuilder<DefaultState>> { + let db_path = data_dir.join("history.db"); + let record_store_path = data_dir.join("records.db"); + let kv_path = data_dir.join("kv.db"); + let scripts_path = data_dir.join("scripts.db"); + let ai_sessions_path = data_dir.join("ai_sessions.db"); + let socket_path = crate::atuin_common::utils::runtime_dir().join("atuin.sock"); + let pidfile_path = data_dir.join("atuin-daemon.pid"); + let logs_dir = crate::atuin_common::utils::logs_dir(); + + let key_path = data_dir.join("key"); + let meta_path = data_dir.join("meta.db"); + + Ok(Config::builder() + .set_default("history_format", "{time}\t{command}\t{duration}")? + .set_default("db_path", db_path.to_str())? + .set_default("record_store_path", record_store_path.to_str())? + .set_default("key_path", key_path.to_str())? + .set_default("dialect", "us")? + .set_default("timezone", "local")? + .set_default("auto_sync", true)? + .set_default("sync_address", "https://api.atuin.sh")? + .set_default("sync_frequency", "5m")? + .set_default("search_mode", "fuzzy")? + .set_default("filter_mode", None::<String>)? + .set_default("style", "compact")? + .set_default("inline_height", 40)? + .set_default("show_preview", true)? + .set_default("preview.strategy", "auto")? + .set_default("max_preview_height", 4)? + .set_default("show_help", true)? + .set_default("show_tabs", true)? + .set_default("show_numeric_shortcuts", true)? + .set_default("auto_hide_height", 8)? + .set_default("invert", false)? + .set_default("exit_mode", "return-original")? + .set_default("word_jump_mode", "emacs")? + .set_default( + "word_chars", + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", + )? + .set_default("scroll_context_lines", 1)? + .set_default("shell_up_key_binding", false)? + .set_default("workspaces", false)? + .set_default("ctrl_n_shortcuts", false)? + .set_default("secrets_filter", true)? + .set_default("strip_trailing_whitespace", true)? + .set_default("network_connect_timeout", 5)? + .set_default("network_timeout", 30)? + .set_default("local_timeout", 2.0)? + // enter_accept defaults to false here, but true in the default config file. The dissonance is + // intentional! + // Existing users will get the default "False", so we don't mess with any potential + // muscle memory. + // New users will get the new default, that is more similar to what they are used to. + .set_default("enter_accept", false)? + .set_default("keys.scroll_exits", true)? + .set_default("keys.accept_past_line_end", true)? + .set_default("keys.exit_past_line_start", true)? + .set_default("keys.accept_past_line_start", false)? + .set_default("keys.accept_with_backspace", false)? + .set_default("keys.prefix", "a")? + .set_default("keymap_mode", "emacs")? + .set_default("keymap_mode_shell", "auto")? + .set_default("keymap_cursor", HashMap::<String, String>::new())? + .set_default("smart_sort", false)? + .set_default("command_chaining", false)? + .set_default("store_failed", true)? + .set_default("daemon.sync_frequency", 300)? + .set_default("daemon.enabled", false)? + .set_default("daemon.autostart", false)? + .set_default("daemon.socket_path", socket_path.to_str())? + .set_default("daemon.pidfile_path", pidfile_path.to_str())? + .set_default("daemon.systemd_socket", false)? + .set_default("daemon.tcp_port", 8889)? + .set_default("logs.enabled", true)? + .set_default("logs.dir", logs_dir.to_str())? + .set_default("logs.level", "info")? + .set_default("logs.search.file", "search.log")? + .set_default("logs.daemon.file", "daemon.log")? + .set_default("logs.ai.file", "ai.log")? + .set_default("kv.db_path", kv_path.to_str())? + .set_default("scripts.db_path", scripts_path.to_str())? + .set_default("search.recency_score_multiplier", 1.0)? + .set_default("search.frequency_score_multiplier", 1.0)? + .set_default("search.frecency_score_multiplier", 1.0)? + .set_default("meta.db_path", meta_path.to_str())? + .set_default("ai.db_path", ai_sessions_path.to_str())? + .set_default("ai.session_continue_minutes", 60)? + .set_default("ai.send_cwd", false)? + .set_default("ai.opening.send_cwd", false)? + .set_default("ai.opening.send_last_command", false)? + .set_default( + "search.filters", + vec![ + "global", + "host", + "session", + "workspace", + "directory", + "session-preload", + ], + )? + .set_default("theme.name", "default")? + .set_default("theme.debug", None::<bool>)? + .set_default("tmux.enabled", false)? + .set_default("tmux.width", "80%")? + .set_default("tmux.height", "60%")? + .set_default( + "prefers_reduced_motion", + std::env::var("NO_MOTION") + .ok() + .map(|_| config::Value::new(None, config::ValueKind::Boolean(true))) + .unwrap_or_else(|| config::Value::new(None, config::ValueKind::Boolean(false))), + )? + .set_default("no_mouse", false)? + .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + )) + } + + pub fn get_config_path() -> Result<PathBuf> { + let config_dir = crate::atuin_common::utils::config_dir(); + + create_dir_all(&config_dir) + .wrap_err_with(|| format!("could not create dir {config_dir:?}"))?; + + let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut config_file = PathBuf::new(); + config_file.push(config_dir); + config_file + }; + + config_file.push("config.toml"); + + Ok(config_file) + } + + /// Build a merged `Config` from defaults, config file, and environment. + /// + /// This resolves `data_dir`, initializes the data directory on disk, + /// and layers defaults → config file → env overrides. Both `new()` and + /// `get_config_value()` use this so the resolution logic lives in one place. + fn build_config() -> Result<Config> { + let config_file = Self::get_config_path()?; + + // extract data_dir first so we can use it as the base for other path defaults + let effective_data_dir = if config_file.exists() { + #[derive(Deserialize, Default)] + struct DataDirOnly { + data_dir: Option<String>, + } + + let config_file_str = config_file + .to_str() + .ok_or_else(|| eyre!("config file path is not valid UTF-8"))?; + + let partial_config = Config::builder() + .add_source(ConfigFile::new(config_file_str, FileFormat::Toml)) + .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + ) + .build() + .ok(); + + let custom_data_dir = partial_config + .and_then(|c| c.try_deserialize::<DataDirOnly>().ok()) + .and_then(|d| d.data_dir); + + match custom_data_dir { + Some(dir) => { + let expanded = shellexpand::full(&dir) + .map_err(|e| eyre!("failed to expand data_dir path: {}", e))?; + PathBuf::from(expanded.as_ref()) + } + None => crate::atuin_common::utils::data_dir(), + } + } else { + crate::atuin_common::utils::data_dir() + }; + + DATA_DIR.set(effective_data_dir.clone()).ok(); + + create_dir_all(&effective_data_dir) + .wrap_err_with(|| format!("could not create dir {effective_data_dir:?}"))?; + + let mut config_builder = Self::builder_with_data_dir(&effective_data_dir)?; + + config_builder = if config_file.exists() { + let config_file_str = config_file + .to_str() + .ok_or_else(|| eyre!("config file path is not valid UTF-8"))?; + config_builder.add_source(ConfigFile::new(config_file_str, FileFormat::Toml)) + } else { + let mut file = File::create(config_file).wrap_err("could not create config file")?; + + let config = config_builder.build_cloned()?; + // TODO(@bpeetz): Not so sure about this <2026-06-10> + file.write_all(config.cache.to_string().as_bytes()) + .wrap_err("could not write default config file")?; + + config_builder + }; + + // all paths should be expanded + let built = config_builder.build_cloned()?; + config_builder = [ + "db_path", + "record_store_path", + "key_path", + "daemon.socket_path", + "daemon.pidfile_path", + "logs.dir", + "logs.search.file", + "logs.daemon.file", + ] + .iter() + .map(|key| (key, built.get_string(key).unwrap_or_default())) + .filter_map(|(key, value)| match Self::expand_path(value) { + Ok(expanded) => Some((key, expanded)), + Err(e) => { + log::warn!("failed to expand path for {key}: {e}"); + None + } + }) + .fold(config_builder, |builder, (key, value)| { + builder + .set_override(key, value) + .unwrap_or_else(|_| panic!("failed to set absolute path override for {key}")) + }); + + config_builder.build().map_err(Into::into) + } + + /// Look up a single config value by dotted key (e.g. `"daemon.sync_frequency"`). + /// + /// Returns the effective value after merging defaults, config file, and + /// environment — without the side-effects of full `Settings` construction + /// (meta store init, path expansion, etc.). + pub fn get_config_value(key: &str) -> Result<String> { + let config = Self::build_config()?; + let value: config::Value = config + .get(key) + .map_err(|e| eyre!("failed to get config value '{}': {}", key, e))?; + Ok(Self::format_resolved_value(&value, key)) + } + + fn format_resolved_value(value: &config::Value, prefix: &str) -> String { + use config::ValueKind; + + match &value.kind { + ValueKind::Nil => String::new(), + ValueKind::Boolean(b) => b.to_string(), + ValueKind::I64(i) => i.to_string(), + ValueKind::I128(i) => i.to_string(), + ValueKind::U64(u) => u.to_string(), + ValueKind::U128(u) => u.to_string(), + ValueKind::Float(f) => f.to_string(), + ValueKind::String(s) => s.clone(), + ValueKind::Array(arr) => { + let items: Vec<String> = arr + .iter() + .map(|v| Self::format_resolved_value(v, "")) + .collect(); + format!("[{}]", items.join(", ")) + } + ValueKind::Table(map) => { + let mut lines = Vec::new(); + let mut keys: Vec<_> = map.keys().collect(); + keys.sort(); + + for k in keys { + let v = &map[k]; + let full_key = if prefix.is_empty() { + k.clone() + } else { + format!("{}.{}", prefix, k) + }; + + match &v.kind { + ValueKind::Table(_) => { + lines.push(Self::format_resolved_value(v, &full_key)); + } + _ => { + lines.push(format!( + "{} = {}", + full_key, + Self::format_resolved_value(v, "") + )); + } + } + } + + lines.join("\n") + } + } + } + + pub fn new() -> Result<Self> { + let config = Self::build_config()?; + let settings: Settings = config + .try_deserialize() + .map_err(|e| eyre!("failed to deserialize: {}", e))?; + + // Validate UI settings + settings.ui.validate()?; + + // Register meta store config for lazy initialization on first access + META_CONFIG + .set((settings.meta.db_path.clone(), settings.local_timeout)) + .ok(); + + Ok(settings) + } + + fn expand_path(path: String) -> Result<String> { + shellexpand::full(&path) + .map(|p| p.to_string()) + .map_err(|e| eyre!("failed to expand path: {}", e)) + } + + pub fn paths_ok(&self) -> bool { + let paths = [ + &self.db_path, + &self.record_store_path, + &self.key_path, + &self.meta.db_path, + ]; + paths.iter().all(|p| !utils::broken_symlink(p)) + } +} + +impl Default for Settings { + fn default() -> Self { + // if this panics something is very wrong, as the default config + // does not build or deserialize into the settings struct + Self::builder() + .expect("Could not build default") + .build() + .expect("Could not build config") + .try_deserialize() + .expect("Could not deserialize config") + } +} + +/// Initialize the meta store configuration for testing. +/// +/// This should only be used in tests. It allows tests to bypass the normal +/// Settings::new() flow while still being able to use Settings::host_id() +/// and other meta store dependent functions. +/// +/// # Safety +/// This function is not thread-safe with concurrent calls to Settings::new() +/// or other meta store initialization. Only call from tests. +#[doc(hidden)] +pub fn init_meta_config_for_testing(meta_db_path: impl Into<String>, local_timeout: f64) { + META_CONFIG.set((meta_db_path.into(), local_timeout)).ok(); +} + +#[cfg(test)] +pub(crate) fn test_local_timeout() -> f64 { + std::env::var("ATUIN_TEST_LOCAL_TIMEOUT") + .ok() + .and_then(|x| x.parse().ok()) + // this hardcoded value should be replaced by a simple way to get the + // default local_timeout of Settings if possible + .unwrap_or(2.0) +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use eyre::Result; + + use super::Timezone; + + #[test] + fn can_parse_offset_timezone_spec() -> Result<()> { + assert_eq!(Timezone::from_str("+02")?.0.as_hms(), (2, 0, 0)); + assert_eq!(Timezone::from_str("-04")?.0.as_hms(), (-4, 0, 0)); + assert_eq!(Timezone::from_str("+05:30")?.0.as_hms(), (5, 30, 0)); + assert_eq!(Timezone::from_str("-09:30")?.0.as_hms(), (-9, -30, 0)); + + // single digit hours are allowed + assert_eq!(Timezone::from_str("+2")?.0.as_hms(), (2, 0, 0)); + assert_eq!(Timezone::from_str("-4")?.0.as_hms(), (-4, 0, 0)); + assert_eq!(Timezone::from_str("+5:30")?.0.as_hms(), (5, 30, 0)); + assert_eq!(Timezone::from_str("-9:30")?.0.as_hms(), (-9, -30, 0)); + + // fully qualified form + assert_eq!(Timezone::from_str("+09:30:00")?.0.as_hms(), (9, 30, 0)); + assert_eq!(Timezone::from_str("-09:30:00")?.0.as_hms(), (-9, -30, 0)); + + // these offsets don't really exist but are supported anyway + assert_eq!(Timezone::from_str("+0:5")?.0.as_hms(), (0, 5, 0)); + assert_eq!(Timezone::from_str("-0:5")?.0.as_hms(), (0, -5, 0)); + assert_eq!(Timezone::from_str("+01:23:45")?.0.as_hms(), (1, 23, 45)); + assert_eq!(Timezone::from_str("-01:23:45")?.0.as_hms(), (-1, -23, -45)); + + // require a leading sign for clarity + assert!(Timezone::from_str("5").is_err()); + assert!(Timezone::from_str("10:30").is_err()); + + Ok(()) + } + + #[test] + fn can_choose_workspace_filters_when_in_git_context() -> Result<()> { + let mut settings = super::Settings::default(); + settings.search.filters = vec![ + super::FilterMode::Workspace, + super::FilterMode::Host, + super::FilterMode::Directory, + super::FilterMode::Session, + super::FilterMode::Global, + ]; + settings.workspaces = true; + + assert_eq!( + settings.default_filter_mode(true), + super::FilterMode::Workspace, + ); + + Ok(()) + } + + #[test] + fn wont_choose_workspace_filters_when_not_in_git_context() -> Result<()> { + let mut settings = super::Settings::default(); + settings.search.filters = vec![ + super::FilterMode::Workspace, + super::FilterMode::Host, + super::FilterMode::Directory, + super::FilterMode::Session, + super::FilterMode::Global, + ]; + settings.workspaces = true; + + assert_eq!(settings.default_filter_mode(false), super::FilterMode::Host,); + + Ok(()) + } + + #[test] + fn wont_choose_workspace_filters_when_workspaces_disabled() -> Result<()> { + let mut settings = super::Settings::default(); + settings.search.filters = vec![ + super::FilterMode::Workspace, + super::FilterMode::Host, + super::FilterMode::Directory, + super::FilterMode::Session, + super::FilterMode::Global, + ]; + settings.workspaces = false; + + assert_eq!(settings.default_filter_mode(true), super::FilterMode::Host,); + + Ok(()) + } + + #[test] + fn builder_with_data_dir_uses_custom_paths() -> Result<()> { + use std::path::PathBuf; + + let custom_dir = PathBuf::from("/custom/data/dir"); + let builder = super::Settings::builder_with_data_dir(&custom_dir)?; + let config = builder.build()?; + + let db_path: String = config.get("db_path")?; + let key_path: String = config.get("key_path")?; + let record_store_path: String = config.get("record_store_path")?; + let kv_db_path: String = config.get("kv.db_path")?; + let scripts_db_path: String = config.get("scripts.db_path")?; + let meta_db_path: String = config.get("meta.db_path")?; + let daemon_socket_path: String = config.get("daemon.socket_path")?; + let daemon_pidfile_path: String = config.get("daemon.pidfile_path")?; + let daemon_autostart: bool = config.get("daemon.autostart")?; + + assert_eq!(db_path, custom_dir.join("history.db").to_str().unwrap()); + assert_eq!(key_path, custom_dir.join("key").to_str().unwrap()); + assert_eq!( + record_store_path, + custom_dir.join("records.db").to_str().unwrap() + ); + assert_eq!(kv_db_path, custom_dir.join("kv.db").to_str().unwrap()); + assert_eq!( + scripts_db_path, + custom_dir.join("scripts.db").to_str().unwrap() + ); + assert_eq!(meta_db_path, custom_dir.join("meta.db").to_str().unwrap()); + assert_eq!( + daemon_socket_path, + crate::atuin_common::utils::runtime_dir() + .join("atuin.sock") + .to_str() + .unwrap() + ); + assert_eq!( + daemon_pidfile_path, + custom_dir.join("atuin-daemon.pid").to_str().unwrap() + ); + assert!(!daemon_autostart); + + Ok(()) + } + + #[test] + fn effective_data_dir_returns_default_when_not_set() { + let effective = super::Settings::effective_data_dir(); + let default = crate::atuin_common::utils::data_dir(); + + assert!(effective.to_str().is_some()); + assert!(effective.ends_with("atuin") || effective == default); + } + + #[test] + fn keymap_config_deserializes_simple_binding() { + let json = r#"{"emacs": {"ctrl-c": "exit"}}"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.emacs.len(), 1); + match &config.emacs["ctrl-c"] { + super::KeyBindingConfig::Simple(s) => assert_eq!(s, "exit"), + _ => panic!("expected Simple variant"), + } + } + + #[test] + fn keymap_config_deserializes_conditional_binding() { + let json = r#"{ + "emacs": { + "left": [ + {"when": "cursor-at-start", "action": "exit"}, + {"action": "cursor-left"} + ] + } + }"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + match &config.emacs["left"] { + super::KeyBindingConfig::Rules(rules) => { + assert_eq!(rules.len(), 2); + assert_eq!(rules[0].when.as_deref(), Some("cursor-at-start")); + assert_eq!(rules[0].action, "exit"); + assert!(rules[1].when.is_none()); + assert_eq!(rules[1].action, "cursor-left"); + } + _ => panic!("expected Rules variant"), + } + } + + #[test] + fn keymap_config_deserializes_vim_normal() { + let json = r#"{"vim-normal": {"j": "select-next", "k": "select-previous"}}"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.vim_normal.len(), 2); + assert!(config.emacs.is_empty()); + } + + #[test] + fn keymap_config_is_empty_when_default() { + let config = super::KeymapConfig::default(); + assert!(config.is_empty()); + } + + #[test] + fn keymap_config_mixed_modes() { + let json = r#"{ + "emacs": {"ctrl-c": "exit"}, + "vim-normal": {"q": "exit"}, + "inspector": {"d": "delete"} + }"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + assert!(!config.is_empty()); + assert_eq!(config.emacs.len(), 1); + assert_eq!(config.vim_normal.len(), 1); + assert_eq!(config.inspector.len(), 1); + assert!(config.vim_insert.is_empty()); + assert!(config.prefix.is_empty()); + } +} diff --git a/crates/turtle/src/atuin_client/settings/meta.rs b/crates/turtle/src/atuin_client/settings/meta.rs new file mode 100644 index 00000000..450d0432 --- /dev/null +++ b/crates/turtle/src/atuin_client/settings/meta.rs @@ -0,0 +1,17 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Settings { + pub db_path: String, +} + +impl Default for Settings { + fn default() -> Self { + let dir = crate::atuin_common::utils::data_dir(); + let path = dir.join("meta.db"); + + Self { + db_path: path.to_string_lossy().to_string(), + } + } +} diff --git a/crates/turtle/src/atuin_client/settings/watcher.rs b/crates/turtle/src/atuin_client/settings/watcher.rs new file mode 100644 index 00000000..7548573e --- /dev/null +++ b/crates/turtle/src/atuin_client/settings/watcher.rs @@ -0,0 +1,256 @@ +//! Config file watching for automatic settings reload. +//! +//! This module provides a `SettingsWatcher` that monitors the config file +//! for changes and broadcasts updated settings via a `tokio::sync::watch` channel. +//! +//! # Example +//! +//! ```no_run +//! use crate::atuin_client::settings::watcher::global_settings_watcher; +//! +//! async fn example() -> eyre::Result<()> { +//! let watcher = global_settings_watcher()?; +//! let mut rx = watcher.subscribe(); +//! +//! // React to settings changes +//! while rx.changed().await.is_ok() { +//! let settings = rx.borrow(); +//! println!("Settings updated!"); +//! } +//! Ok(()) +//! } +//! ``` + +use std::{ + path::PathBuf, + sync::{Arc, OnceLock}, + time::Duration, +}; + +use eyre::{Result, WrapErr}; +use log::{debug, error, info, warn}; +use notify::{ + Config as NotifyConfig, RecommendedWatcher, RecursiveMode, Watcher, + event::{EventKind, ModifyKind}, +}; +use tokio::sync::watch; + +use super::Settings; + +/// Global singleton for the settings watcher. +static SETTINGS_WATCHER: OnceLock<Result<SettingsWatcher, String>> = OnceLock::new(); + +/// Get the global settings watcher singleton. +/// +/// Initializes the watcher on first call. Subsequent calls return the same instance. +/// The watcher monitors the config file for changes and broadcasts updates. +pub fn global_settings_watcher() -> Result<&'static SettingsWatcher> { + let result = SETTINGS_WATCHER.get_or_init(|| SettingsWatcher::new().map_err(|e| e.to_string())); + + match result { + Ok(watcher) => Ok(watcher), + Err(e) => Err(eyre::eyre!("{}", e)), + } +} + +/// Watches the config file for changes and broadcasts updated settings. +/// +/// Uses `notify` for cross-platform file watching and `tokio::sync::watch` +/// for efficient broadcast to multiple subscribers. +pub struct SettingsWatcher { + /// Receiver for settings updates. Clone this to subscribe. + rx: watch::Receiver<Arc<Settings>>, + /// Keeps the file watcher alive for the lifetime of this struct. + _watcher: RecommendedWatcher, +} + +impl SettingsWatcher { + /// Create a new settings watcher. + /// + /// Loads initial settings and starts watching the config file for changes. + /// Changes are debounced (500ms) to avoid multiple reloads during saves. + pub fn new() -> Result<Self> { + let initial_settings = Arc::new(Settings::new()?); + let (tx, rx) = watch::channel(initial_settings); + + let config_path = Self::config_path(); + info!("starting config file watcher: {:?}", config_path); + + let watcher = Self::create_watcher(tx, config_path)?; + + Ok(Self { + rx, + _watcher: watcher, + }) + } + + /// Subscribe to settings updates. + /// + /// Returns a receiver that will be notified when settings change. + /// Use `changed().await` to wait for the next update, then `borrow()` + /// to access the current settings. + pub fn subscribe(&self) -> watch::Receiver<Arc<Settings>> { + self.rx.clone() + } + + /// Get the current settings without subscribing to updates. + pub fn current(&self) -> Arc<Settings> { + self.rx.borrow().clone() + } + + /// Get the config file path. + fn config_path() -> PathBuf { + let config_dir = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + crate::atuin_common::utils::config_dir() + }; + config_dir.join("config.toml") + } + + /// Create the file watcher with debouncing. + fn create_watcher( + tx: watch::Sender<Arc<Settings>>, + config_path: PathBuf, + ) -> Result<RecommendedWatcher> { + // Channel for debouncing file events + let (debounce_tx, debounce_rx) = std::sync::mpsc::channel::<()>(); + + // Spawn debounce thread + let config_path_clone = config_path.clone(); + std::thread::spawn(move || { + Self::debounce_loop(debounce_rx, tx, config_path_clone); + }); + + // Clone config_path for use in the watcher callback + let config_path_for_watcher = config_path.clone(); + + // Canonicalize config path for reliable comparison on macOS + // (handles symlinks like /var -> /private/var) + let canonical_config_path = config_path_for_watcher + .canonicalize() + .unwrap_or_else(|_| config_path_for_watcher.clone()); + + // Create file watcher + let mut watcher = RecommendedWatcher::new( + move |res: Result<notify::Event, notify::Error>| { + match res { + Ok(event) => { + // Defensive: if paths is empty, we can't filter, so assume + // it might be our config file and trigger a reload to be safe + if event.paths.is_empty() { + warn!( + "config watcher: event has no paths, triggering reload to be safe" + ); + let _ = debounce_tx.send(()); + return; + } + + // Only react to events for our specific config file + // (filter out editor temp files, backups, etc.) + let is_config_file = event.paths.iter().any(|path| { + // Canonicalize for reliable comparison (handles macOS symlinks) + let canonical_event_path = + path.canonicalize().unwrap_or_else(|_| path.clone()); + + // Check if this event is for our config file + // (either exact match or the file was renamed to our config) + canonical_event_path == canonical_config_path + || path.file_name() == config_path_for_watcher.file_name() + }); + + if !is_config_file { + return; + } + + // Only react to modify events (content changes) or creates + if matches!( + event.kind, + EventKind::Modify(ModifyKind::Data(_) | ModifyKind::Any) + | EventKind::Create(_) + ) { + debug!("config file event detected: {:?}", event); + // Send to debounce channel (ignore send errors - receiver might be gone) + let _ = debounce_tx.send(()); + } + } + Err(e) => { + error!("file watcher error: {}", e); + } + } + }, + NotifyConfig::default(), + ) + .wrap_err("failed to create file watcher")?; + + // Watch the config file's parent directory (some editors create new files) + let watch_path = config_path.parent().unwrap_or(&config_path); + + // Defensive: ensure watch path exists before trying to watch + if !watch_path.exists() { + warn!( + "config directory does not exist, creating it: {:?}", + watch_path + ); + std::fs::create_dir_all(watch_path) + .wrap_err_with(|| format!("failed to create config directory: {:?}", watch_path))?; + } + + watcher + .watch(watch_path, RecursiveMode::NonRecursive) + .wrap_err_with(|| format!("failed to watch config directory: {:?}", watch_path))?; + + info!("config file watcher initialized for: {:?}", watch_path); + Ok(watcher) + } + + /// Debounce loop that batches file events and reloads settings. + fn debounce_loop( + rx: std::sync::mpsc::Receiver<()>, + tx: watch::Sender<Arc<Settings>>, + config_path: PathBuf, + ) { + const DEBOUNCE_DURATION: Duration = Duration::from_millis(500); + + loop { + // Wait for first event + if rx.recv().is_err() { + // Channel closed, watcher was dropped + debug!("config watcher debounce loop exiting"); + return; + } + + // Drain any additional events within debounce window + while rx.recv_timeout(DEBOUNCE_DURATION).is_ok() { + // Keep draining + } + + // Defensive: check if config file exists before reloading + // (handles case where file was deleted - we'll get notified when it's recreated) + if !config_path.exists() { + debug!( + "config file does not exist, skipping reload: {:?}", + config_path + ); + continue; + } + + // Now reload settings + info!("config file changed, reloading settings: {:?}", config_path); + match Settings::new() { + Ok(settings) => { + if tx.send(Arc::new(settings)).is_err() { + // All receivers dropped + debug!("all settings subscribers dropped, exiting"); + return; + } + info!("settings reloaded successfully"); + } + Err(e) => { + warn!("failed to reload settings: {}", e); + // Keep the old settings, don't broadcast the error + } + } + } + } +} diff --git a/crates/turtle/src/atuin_client/sync.rs b/crates/turtle/src/atuin_client/sync.rs new file mode 100644 index 00000000..361b6636 --- /dev/null +++ b/crates/turtle/src/atuin_client/sync.rs @@ -0,0 +1,214 @@ +use std::collections::HashSet; +use std::iter::FromIterator; + +use eyre::Result; +use tracing::{debug, info}; + +use crate::atuin_common::api::AddHistoryRequest; +use crypto_secretbox::Key; +use time::OffsetDateTime; + +use crate::atuin_client::{ + api_client, + database::Database, + encryption::{decrypt, encrypt, load_key}, + settings::Settings, +}; + +pub fn hash_str(string: &str) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(string.as_bytes()); + hex::encode(hasher.finalize()) +} + +// Currently sync is kinda naive, and basically just pages backwards through +// history. This means newly added stuff shows up properly! We also just use +// the total count in each database to indicate whether a sync is needed. +// I think this could be massively improved! If we had a way of easily +// indicating count per time period (hour, day, week, year, etc) then we can +// easily pinpoint where we are missing data and what needs downloading. Start +// with year, then find the week, then the day, then the hour, then download it +// all! The current naive approach will do for now. + +// Check if remote has things we don't, and if so, download them. +// Returns (num downloaded, total local) +async fn sync_download( + key: &Key, + force: bool, + client: &api_client::Client<'_>, + db: &impl Database, +) -> Result<(i64, i64)> { + debug!("starting sync download"); + + let remote_status = client.status().await?; + let remote_count = remote_status.count; + + // useful to ensure we don't even save something that hasn't yet been synced + deleted + let remote_deleted = + HashSet::<&str>::from_iter(remote_status.deleted.iter().map(String::as_str)); + + let initial_local = db.history_count(true).await?; + let mut local_count = initial_local; + + let mut last_sync = if force { + OffsetDateTime::UNIX_EPOCH + } else { + Settings::last_sync().await? + }; + + let mut last_timestamp = OffsetDateTime::UNIX_EPOCH; + + let host = if force { Some(String::from("")) } else { None }; + + while remote_count > local_count { + let page = client + .get_history(last_sync, last_timestamp, host.clone()) + .await?; + + let history: Vec<_> = page + .history + .iter() + // TODO: handle deletion earlier in this chain + .map(|h| serde_json::from_str(h).expect("invalid base64")) + .map(|h| decrypt(h, key).expect("failed to decrypt history! check your key")) + .map(|mut h| { + if remote_deleted.contains(h.id.0.as_str()) { + h.deleted_at = Some(time::OffsetDateTime::now_utc()); + h.command = String::from(""); + } + + h + }) + .collect(); + + db.save_bulk(&history).await?; + + local_count = db.history_count(true).await?; + let remote_page_size = std::cmp::max(remote_status.page_size, 0) as usize; + + if history.len() < remote_page_size { + break; + } + + let page_last = history + .last() + .expect("could not get last element of page") + .timestamp; + + // in the case of a small sync frequency, it's possible for history to + // be "lost" between syncs. In this case we need to rewind the sync + // timestamps + if page_last == last_timestamp { + last_timestamp = OffsetDateTime::UNIX_EPOCH; + last_sync -= time::Duration::hours(1); + } else { + last_timestamp = page_last; + } + } + + for i in remote_status.deleted { + // we will update the stored history to have this data + // pretty much everything can be nullified + match db.load(i.as_str()).await? { + Some(h) => { + db.delete(h).await?; + } + _ => { + info!( + "could not delete history with id {}, not found locally", + i.as_str() + ); + } + } + } + + Ok((local_count - initial_local, local_count)) +} + +// Check if we have things remote doesn't, and if so, upload them +async fn sync_upload( + key: &Key, + _force: bool, + client: &api_client::Client<'_>, + db: &impl Database, +) -> Result<()> { + debug!("starting sync upload"); + + let remote_status = client.status().await?; + let remote_deleted: HashSet<String> = HashSet::from_iter(remote_status.deleted.clone()); + + let initial_remote_count = client.count().await?; + let mut remote_count = initial_remote_count; + + let local_count = db.history_count(true).await?; + + debug!("remote has {remote_count}, we have {local_count}"); + + // first just try the most recent set + let mut cursor = OffsetDateTime::now_utc(); + + while local_count > remote_count { + let last = db.before(cursor, remote_status.page_size).await?; + let mut buffer = Vec::new(); + + if last.is_empty() { + break; + } + + for i in last { + let data = encrypt(&i, key)?; + let data = serde_json::to_string(&data)?; + + let add_hist = AddHistoryRequest { + id: i.id.to_string(), + timestamp: i.timestamp, + data, + hostname: hash_str(&i.hostname), + }; + + buffer.push(add_hist); + } + + // anything left over outside of the 100 block size + client.post_history(&buffer).await?; + cursor = buffer.last().unwrap().timestamp; + remote_count = client.count().await?; + + debug!("upload cursor: {cursor:?}"); + } + + let deleted = db.deleted().await?; + + for i in deleted { + if remote_deleted.contains(&i.id.to_string()) { + continue; + } + + info!("deleting {} on remote", i.id); + client.delete_history(i).await?; + } + + Ok(()) +} + +pub async fn sync(settings: &Settings, force: bool, db: &impl Database) -> Result<()> { + let client = api_client::Client::new( + &settings.sync_address, + settings.sync_auth_token().await?, + settings.network_connect_timeout, + settings.network_timeout, + )?; + + Settings::save_sync_time().await?; + + let key = load_key(settings)?; // encryption key + + sync_upload(&key, force, &client, db).await?; + + let download = sync_download(&key, force, &client, db).await?; + + debug!("sync downloaded {}", download.0); + + Ok(()) +} diff --git a/crates/turtle/src/atuin_client/theme.rs b/crates/turtle/src/atuin_client/theme.rs new file mode 100644 index 00000000..1d9c0b9e --- /dev/null +++ b/crates/turtle/src/atuin_client/theme.rs @@ -0,0 +1,831 @@ +use config::{Config, File as ConfigFile, FileFormat}; +use log; +use palette::named; +use serde::{Deserialize, Serialize}; +use serde_json; +use std::collections::HashMap; +use std::error; +use std::io::{Error, ErrorKind}; +use std::path::PathBuf; +use std::sync::LazyLock; +use strum_macros; + +static DEFAULT_MAX_DEPTH: u8 = 10; + +// Collection of settable "meanings" that can have colors set. +// NOTE: You can add a new meaning here without breaking backwards compatibility but please: +// - update the atuin/docs repository, which has a list of available meanings +// - add a fallback in the MEANING_FALLBACKS below, so that themes which do not have it +// get a sensible fallback (see Title as an example) +#[derive( + Serialize, Deserialize, Copy, Clone, Hash, Debug, Eq, PartialEq, strum_macros::Display, +)] +#[strum(serialize_all = "camel_case")] +pub enum Meaning { + AlertInfo, + AlertWarn, + AlertError, + Annotation, + Base, + Guidance, + Important, + Title, + Muted, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ThemeConfig { + // Definition of the theme + pub theme: ThemeDefinitionConfigBlock, + + // Colors + pub colors: HashMap<Meaning, String>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ThemeDefinitionConfigBlock { + /// Name of theme ("default" for base) + pub name: String, + + /// Whether any theme should be treated as a parent _if available_ + pub parent: Option<String>, +} + +use crossterm::style::{Attribute, Attributes, Color, ContentStyle}; + +// For now, a theme is loaded as a mapping of meanings to colors, but it may be desirable to +// expand that in the future to general styles, so we populate a Meaning->ContentStyle hashmap. +pub struct Theme { + pub name: String, + pub parent: Option<String>, + pub styles: HashMap<Meaning, ContentStyle>, +} + +// Themes have a number of convenience functions for the most commonly used meanings. +// The general purpose `as_style` routine gives back a style, but for ease-of-use and to keep +// theme-related boilerplate minimal, the convenience functions give a color. +impl Theme { + // This is the base "default" color, for general text + pub fn get_base(&self) -> ContentStyle { + self.styles[&Meaning::Base] + } + + pub fn get_info(&self) -> ContentStyle { + self.get_alert(log::Level::Info) + } + + pub fn get_warning(&self) -> ContentStyle { + self.get_alert(log::Level::Warn) + } + + pub fn get_error(&self) -> ContentStyle { + self.get_alert(log::Level::Error) + } + + // The alert meanings may be chosen by the Level enum, rather than the methods above + // or the full Meaning enum, to simplify programmatic selection of a log-level. + pub fn get_alert(&self, severity: log::Level) -> ContentStyle { + self.styles[ALERT_TYPES.get(&severity).unwrap()] + } + + pub fn new( + name: String, + parent: Option<String>, + styles: HashMap<Meaning, ContentStyle>, + ) -> Theme { + Theme { + name, + parent, + styles, + } + } + + pub fn closest_meaning<'a>(&self, meaning: &'a Meaning) -> &'a Meaning { + if self.styles.contains_key(meaning) { + meaning + } else if MEANING_FALLBACKS.contains_key(meaning) { + self.closest_meaning(&MEANING_FALLBACKS[meaning]) + } else { + &Meaning::Base + } + } + + // General access - if you have a meaning, this will give you a (crossterm) style + pub fn as_style(&self, meaning: Meaning) -> ContentStyle { + self.styles[self.closest_meaning(&meaning)] + } + + // Turns a map of meanings to colornames into a theme + // If theme-debug is on, then we will print any colornames that we cannot load, + // but we do not have this on in general, as it could print unfiltered text to the terminal + // from a theme TOML file. However, it will always return a theme, falling back to + // defaults on error, so that a TOML file does not break loading + pub fn from_foreground_colors( + name: String, + parent: Option<&Theme>, + foreground_colors: HashMap<Meaning, String>, + debug: bool, + ) -> Theme { + let styles: HashMap<Meaning, ContentStyle> = foreground_colors + .iter() + .map(|(name, color)| { + ( + *name, + StyleFactory::from_fg_string(color).unwrap_or_else(|err| { + if debug { + log::warn!("Tried to load string as a color unsuccessfully: ({name}={color}) {err}"); + } + ContentStyle::default() + }), + ) + }) + .collect(); + Theme::from_map(name, parent, &styles) + } + + // Boil down a meaning-color hashmap into a theme, by taking the defaults + // for any unknown colors + fn from_map( + name: String, + parent: Option<&Theme>, + overrides: &HashMap<Meaning, ContentStyle>, + ) -> Theme { + let styles = match parent { + Some(theme) => Box::new(theme.styles.clone()), + None => Box::new(DEFAULT_THEME.styles.clone()), + } + .iter() + .map(|(name, color)| match overrides.get(name) { + Some(value) => (*name, *value), + None => (*name, *color), + }) + .collect(); + Theme::new(name, parent.map(|p| p.name.clone()), styles) + } +} + +// Use palette to get a color from a string name, if possible +fn from_string(name: &str) -> Result<Color, String> { + if name.is_empty() { + return Err("Empty string".into()); + } + let first_char = name.chars().next().unwrap(); + match first_char { + '#' => { + let hexcode = &name[1..]; + let vec: Vec<u8> = hexcode + .chars() + .collect::<Vec<char>>() + .chunks(2) + .map(|pair| u8::from_str_radix(pair.iter().collect::<String>().as_str(), 16)) + .filter_map(|n| n.ok()) + .collect(); + if vec.len() != 3 { + return Err("Could not parse 3 hex values from string".into()); + } + Ok(Color::Rgb { + r: vec[0], + g: vec[1], + b: vec[2], + }) + } + '@' => { + // For full flexibility, we need to use serde_json, given + // crossterm's approach. + serde_json::from_str::<Color>(format!("\"{}\"", &name[1..]).as_str()) + .map_err(|_| format!("Could not convert color name {name} to Crossterm color")) + } + _ => { + let srgb = named::from_str(name).ok_or("No such color in palette")?; + Ok(Color::Rgb { + r: srgb.red, + g: srgb.green, + b: srgb.blue, + }) + } + } +} + +pub struct StyleFactory {} + +impl StyleFactory { + fn from_fg_string(name: &str) -> Result<ContentStyle, String> { + match from_string(name) { + Ok(color) => Ok(Self::from_fg_color(color)), + Err(err) => Err(err), + } + } + + // For succinctness, if we are confident that the name will be known, + // this routine is available to keep the code readable + fn known_fg_string(name: &str) -> ContentStyle { + Self::from_fg_string(name).unwrap() + } + + fn from_fg_color(color: Color) -> ContentStyle { + ContentStyle { + foreground_color: Some(color), + ..ContentStyle::default() + } + } + + fn from_fg_color_and_attributes(color: Color, attributes: Attributes) -> ContentStyle { + ContentStyle { + foreground_color: Some(color), + attributes, + ..ContentStyle::default() + } + } +} + +// Built-in themes. Rather than having extra files added before any theming +// is available, this gives a couple of basic options, demonstrating the use +// of themes: autumn and marine +static ALERT_TYPES: LazyLock<HashMap<log::Level, Meaning>> = LazyLock::new(|| { + HashMap::from([ + (log::Level::Info, Meaning::AlertInfo), + (log::Level::Warn, Meaning::AlertWarn), + (log::Level::Error, Meaning::AlertError), + ]) +}); + +static MEANING_FALLBACKS: LazyLock<HashMap<Meaning, Meaning>> = LazyLock::new(|| { + HashMap::from([ + (Meaning::Guidance, Meaning::AlertInfo), + (Meaning::Annotation, Meaning::AlertInfo), + (Meaning::Title, Meaning::Important), + ]) +}); + +static DEFAULT_THEME: LazyLock<Theme> = LazyLock::new(|| { + Theme::new( + "default".to_string(), + None, + HashMap::from([ + ( + Meaning::AlertError, + StyleFactory::from_fg_color(Color::DarkRed), + ), + ( + Meaning::AlertWarn, + StyleFactory::from_fg_color(Color::DarkYellow), + ), + ( + Meaning::AlertInfo, + StyleFactory::from_fg_color(Color::DarkGreen), + ), + ( + Meaning::Annotation, + StyleFactory::from_fg_color(Color::DarkGrey), + ), + ( + Meaning::Guidance, + StyleFactory::from_fg_color(Color::DarkBlue), + ), + ( + Meaning::Important, + StyleFactory::from_fg_color_and_attributes( + Color::White, + Attributes::from(Attribute::Bold), + ), + ), + (Meaning::Muted, StyleFactory::from_fg_color(Color::Grey)), + (Meaning::Base, ContentStyle::default()), + ]), + ) +}); + +static BUILTIN_THEMES: LazyLock<HashMap<&'static str, Theme>> = LazyLock::new(|| { + HashMap::from([ + ("default", HashMap::new()), + ( + "(none)", + HashMap::from([ + (Meaning::AlertError, ContentStyle::default()), + (Meaning::AlertWarn, ContentStyle::default()), + (Meaning::AlertInfo, ContentStyle::default()), + (Meaning::Annotation, ContentStyle::default()), + (Meaning::Guidance, ContentStyle::default()), + (Meaning::Important, ContentStyle::default()), + (Meaning::Muted, ContentStyle::default()), + (Meaning::Base, ContentStyle::default()), + ]), + ), + ( + "autumn", + HashMap::from([ + ( + Meaning::AlertError, + StyleFactory::known_fg_string("saddlebrown"), + ), + ( + Meaning::AlertWarn, + StyleFactory::known_fg_string("darkorange"), + ), + (Meaning::AlertInfo, StyleFactory::known_fg_string("gold")), + ( + Meaning::Annotation, + StyleFactory::from_fg_color(Color::DarkGrey), + ), + (Meaning::Guidance, StyleFactory::known_fg_string("brown")), + ]), + ), + ( + "marine", + HashMap::from([ + ( + Meaning::AlertError, + StyleFactory::known_fg_string("yellowgreen"), + ), + (Meaning::AlertWarn, StyleFactory::known_fg_string("cyan")), + ( + Meaning::AlertInfo, + StyleFactory::known_fg_string("turquoise"), + ), + ( + Meaning::Annotation, + StyleFactory::known_fg_string("steelblue"), + ), + ( + Meaning::Base, + StyleFactory::known_fg_string("lightsteelblue"), + ), + (Meaning::Guidance, StyleFactory::known_fg_string("teal")), + ]), + ), + ]) + .iter() + .map(|(name, theme)| (*name, Theme::from_map(name.to_string(), None, theme))) + .collect() +}); + +// To avoid themes being repeatedly loaded, we store them in a theme manager +pub struct ThemeManager { + loaded_themes: HashMap<String, Theme>, + debug: bool, + override_theme_dir: Option<String>, +} + +// Theme-loading logic +impl ThemeManager { + pub fn new(debug: Option<bool>, theme_dir: Option<String>) -> Self { + Self { + loaded_themes: HashMap::new(), + debug: debug.unwrap_or(false), + override_theme_dir: match theme_dir { + Some(theme_dir) => Some(theme_dir), + None => std::env::var("ATUIN_THEME_DIR").ok(), + }, + } + } + + // Try to load a theme from a `{name}.toml` file in the theme directory. If an override is set + // for the theme dir (via ATUIN_THEME_DIR env) we should load the theme from there + pub fn load_theme_from_file( + &mut self, + name: &str, + max_depth: u8, + ) -> Result<&Theme, Box<dyn error::Error>> { + let mut theme_file = if let Some(p) = &self.override_theme_dir { + if p.is_empty() { + return Err(Box::new(Error::new( + ErrorKind::NotFound, + "Empty theme directory override and could not find theme elsewhere", + ))); + } + PathBuf::from(p) + } else { + let config_dir = crate::atuin_common::utils::config_dir(); + let mut theme_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut theme_file = PathBuf::new(); + theme_file.push(config_dir); + theme_file + }; + theme_file.push("themes"); + theme_file + }; + + let theme_toml = format!("{name}.toml"); + theme_file.push(theme_toml); + + let mut config_builder = Config::builder(); + + config_builder = config_builder.add_source(ConfigFile::new( + theme_file.to_str().unwrap(), + FileFormat::Toml, + )); + + let config = config_builder.build()?; + self.load_theme_from_config(name, config, max_depth) + } + + pub fn load_theme_from_config( + &mut self, + name: &str, + config: Config, + max_depth: u8, + ) -> Result<&Theme, Box<dyn error::Error>> { + let debug = self.debug; + let theme_config: ThemeConfig = match config.try_deserialize() { + Ok(tc) => tc, + Err(e) => { + return Err(Box::new(Error::new( + ErrorKind::InvalidInput, + format!( + "Failed to deserialize theme: {}", + if debug { + e.to_string() + } else { + "set theme debug on for more info".to_string() + } + ), + ))); + } + }; + let colors: HashMap<Meaning, String> = theme_config.colors; + let parent: Option<&Theme> = match theme_config.theme.parent { + Some(parent_name) => { + if max_depth == 0 { + return Err(Box::new(Error::new( + ErrorKind::InvalidInput, + "Parent requested but we hit the recursion limit", + ))); + } + Some(self.load_theme(parent_name.as_str(), Some(max_depth - 1))) + } + None => Some(self.load_theme("default", Some(max_depth - 1))), + }; + + if debug && name != theme_config.theme.name { + log::warn!( + "Your theme config name is not the name of your loaded theme {} != {}", + name, + theme_config.theme.name + ); + } + + let theme = Theme::from_foreground_colors(theme_config.theme.name, parent, colors, debug); + let name = name.to_string(); + self.loaded_themes.insert(name.clone(), theme); + let theme = self.loaded_themes.get(&name).unwrap(); + Ok(theme) + } + + // Check if the requested theme is loaded and, if not, then attempt to get it + // from the builtins or, if not there, from file + pub fn load_theme(&mut self, name: &str, max_depth: Option<u8>) -> &Theme { + if self.loaded_themes.contains_key(name) { + return self.loaded_themes.get(name).unwrap(); + } + let built_ins = &BUILTIN_THEMES; + match built_ins.get(name) { + Some(theme) => theme, + None => match self.load_theme_from_file(name, max_depth.unwrap_or(DEFAULT_MAX_DEPTH)) { + Ok(theme) => theme, + Err(err) => { + log::warn!("Could not load theme {name}: {err}"); + built_ins.get("(none)").unwrap() + } + }, + } + } +} + +#[cfg(test)] +mod theme_tests { + use super::*; + + #[test] + fn test_can_load_builtin_theme() { + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + let theme = manager.load_theme("autumn", None); + assert_eq!( + theme.as_style(Meaning::Guidance).foreground_color, + from_string("brown").ok() + ); + } + + #[test] + fn test_can_create_theme() { + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + let mytheme = Theme::new( + "mytheme".to_string(), + None, + HashMap::from([( + Meaning::AlertError, + StyleFactory::known_fg_string("yellowgreen"), + )]), + ); + manager.loaded_themes.insert("mytheme".to_string(), mytheme); + let theme = manager.load_theme("mytheme", None); + assert_eq!( + theme.as_style(Meaning::AlertError).foreground_color, + from_string("yellowgreen").ok() + ); + } + + #[test] + fn test_can_fallback_when_meaning_missing() { + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + + // We use title as an example of a meaning that is not defined + // even in the base theme. + assert!(!DEFAULT_THEME.styles.contains_key(&Meaning::Title)); + + let config = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"title_theme\" + + [colors] + Guidance = \"white\" + AlertInfo = \"zomp\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let theme = manager + .load_theme_from_config("config_theme", config, 1) + .unwrap(); + + // Correctly picks overridden color. + assert_eq!( + theme.as_style(Meaning::Guidance).foreground_color, + from_string("white").ok() + ); + + // Does not fall back to any color. + assert_eq!(theme.as_style(Meaning::AlertInfo).foreground_color, None); + + // Even for the base. + assert_eq!(theme.as_style(Meaning::Base).foreground_color, None); + + // Falls back to red as meaning missing from theme, so picks base default. + assert_eq!( + theme.as_style(Meaning::AlertError).foreground_color, + Some(Color::DarkRed) + ); + + // Falls back to Important as Title not available. + assert_eq!( + theme.as_style(Meaning::Title).foreground_color, + theme.as_style(Meaning::Important).foreground_color, + ); + + let title_config = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"title_theme\" + + [colors] + Title = \"white\" + AlertInfo = \"zomp\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let title_theme = manager + .load_theme_from_config("title_theme", title_config, 1) + .unwrap(); + + assert_eq!( + title_theme.as_style(Meaning::Title).foreground_color, + Some(Color::White) + ); + } + + #[test] + fn test_no_fallbacks_are_circular() { + let mytheme = Theme::new("mytheme".to_string(), None, HashMap::from([])); + MEANING_FALLBACKS + .iter() + .for_each(|pair| assert_eq!(mytheme.closest_meaning(pair.0), &Meaning::Base)) + } + + #[test] + fn test_can_get_colors_via_convenience_functions() { + let mut manager = ThemeManager::new(Some(true), Some("".to_string())); + let theme = manager.load_theme("default", None); + assert_eq!(theme.get_error().foreground_color.unwrap(), Color::DarkRed); + assert_eq!( + theme.get_warning().foreground_color.unwrap(), + Color::DarkYellow + ); + assert_eq!(theme.get_info().foreground_color.unwrap(), Color::DarkGreen); + assert_eq!(theme.get_base().foreground_color, None); + assert_eq!( + theme.get_alert(log::Level::Error).foreground_color.unwrap(), + Color::DarkRed + ) + } + + #[test] + fn test_can_use_parent_theme_for_fallbacks() { + testing_logger::setup(); + + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + + // First, we introduce a base theme + let solarized = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"solarized\" + + [colors] + Guidance = \"white\" + AlertInfo = \"pink\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let solarized_theme = manager + .load_theme_from_config("solarized", solarized, 1) + .unwrap(); + + assert_eq!( + solarized_theme + .as_style(Meaning::AlertInfo) + .foreground_color, + from_string("pink").ok() + ); + + // Then we introduce a derived theme + let unsolarized = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"unsolarized\" + parent = \"solarized\" + + [colors] + AlertInfo = \"red\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let unsolarized_theme = manager + .load_theme_from_config("unsolarized", unsolarized, 1) + .unwrap(); + + // It will take its own values + assert_eq!( + unsolarized_theme + .as_style(Meaning::AlertInfo) + .foreground_color, + from_string("red").ok() + ); + + // ...or fall back to the parent + assert_eq!( + unsolarized_theme + .as_style(Meaning::Guidance) + .foreground_color, + from_string("white").ok() + ); + + testing_logger::validate(|captured_logs| assert_eq!(captured_logs.len(), 0)); + + // If the parent is not found, we end up with the no theme colors or styling + // as this is considered a (soft) error state. + let nunsolarized = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"nunsolarized\" + parent = \"nonsolarized\" + + [colors] + AlertInfo = \"red\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let nunsolarized_theme = manager + .load_theme_from_config("nunsolarized", nunsolarized, 1) + .unwrap(); + + assert_eq!( + nunsolarized_theme + .as_style(Meaning::Guidance) + .foreground_color, + None + ); + + testing_logger::validate(|captured_logs| { + assert_eq!(captured_logs.len(), 1); + assert_eq!( + captured_logs[0].body, + "Could not load theme nonsolarized: Empty theme directory override and could not find theme elsewhere" + ); + assert_eq!(captured_logs[0].level, log::Level::Warn) + }); + } + + #[test] + fn test_can_debug_theme() { + testing_logger::setup(); + [true, false].iter().for_each(|debug| { + let mut manager = ThemeManager::new(Some(*debug), Some("".to_string())); + let config = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"mytheme\" + + [colors] + Guidance = \"white\" + AlertInfo = \"xinetic\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + manager + .load_theme_from_config("config_theme", config, 1) + .unwrap(); + testing_logger::validate(|captured_logs| { + if *debug { + assert_eq!(captured_logs.len(), 2); + assert_eq!( + captured_logs[0].body, + "Your theme config name is not the name of your loaded theme config_theme != mytheme" + ); + assert_eq!(captured_logs[0].level, log::Level::Warn); + assert_eq!( + captured_logs[1].body, + "Tried to load string as a color unsuccessfully: (AlertInfo=xinetic) No such color in palette" + ); + assert_eq!(captured_logs[1].level, log::Level::Warn) + } else { + assert_eq!(captured_logs.len(), 0) + } + }) + }) + } + + #[test] + fn test_can_parse_color_strings_correctly() { + assert_eq!( + from_string("brown").unwrap(), + Color::Rgb { + r: 165, + g: 42, + b: 42 + } + ); + + assert_eq!(from_string(""), Err("Empty string".into())); + + ["manatee", "caput mortuum", "123456"] + .iter() + .for_each(|inp| { + assert_eq!(from_string(inp), Err("No such color in palette".into())); + }); + + assert_eq!( + from_string("#ff1122").unwrap(), + Color::Rgb { + r: 255, + g: 17, + b: 34 + } + ); + ["#1122", "#ffaa112", "#brown"].iter().for_each(|inp| { + assert_eq!( + from_string(inp), + Err("Could not parse 3 hex values from string".into()) + ); + }); + + assert_eq!(from_string("@dark_grey").unwrap(), Color::DarkGrey); + assert_eq!( + from_string("@rgb_(255,255,255)").unwrap(), + Color::Rgb { + r: 255, + g: 255, + b: 255 + } + ); + assert_eq!(from_string("@ansi_(255)").unwrap(), Color::AnsiValue(255)); + ["@", "@DarkGray", "@Dark 4ay", "@ansi(256)"] + .iter() + .for_each(|inp| { + assert_eq!( + from_string(inp), + Err(format!( + "Could not convert color name {inp} to Crossterm color" + )) + ); + }); + } +} diff --git a/crates/turtle/src/atuin_client/utils.rs b/crates/turtle/src/atuin_client/utils.rs new file mode 100644 index 00000000..35d7db26 --- /dev/null +++ b/crates/turtle/src/atuin_client/utils.rs @@ -0,0 +1,14 @@ +pub(crate) fn get_hostname() -> String { + std::env::var("ATUIN_HOST_NAME") + .unwrap_or_else(|_| whoami::hostname().unwrap_or_else(|_| "unknown-host".to_string())) +} + +pub(crate) fn get_username() -> String { + std::env::var("ATUIN_HOST_USER") + .unwrap_or_else(|_| whoami::username().unwrap_or_else(|_| "unknown-user".to_string())) +} + +/// Returns a pair of the hostname and username, separated by a colon. +pub(crate) fn get_host_user() -> String { + format!("{}:{}", get_hostname(), get_username()) +} |
