diff options
| author | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-11 00:54:30 +0200 |
|---|---|---|
| committer | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-11 00:54:30 +0200 |
| commit | 5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8 (patch) | |
| tree | c64baa8d5866c8e339eaf660dd3f94f30a3f7d8a /crates/atuin-client/src | |
| parent | chore: Somewhat simplify sync code (diff) | |
| download | atuin-5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8.zip | |
chore: Move everything into one big crate
That helps remove duplicated code and rustc/cargo will now also show
dead code correctly.
Diffstat (limited to 'crates/atuin-client/src')
39 files changed, 0 insertions, 11999 deletions
diff --git a/crates/atuin-client/src/api_client.rs b/crates/atuin-client/src/api_client.rs deleted file mode 100644 index ca2fc661..00000000 --- a/crates/atuin-client/src/api_client.rs +++ /dev/null @@ -1,437 +0,0 @@ -use std::collections::HashMap; -use std::env; -use std::time::Duration; - -use eyre::{Result, bail, eyre}; -use reqwest::{ - Response, StatusCode, Url, - header::{AUTHORIZATION, HeaderMap, USER_AGENT}, -}; - -use atuin_common::{ - api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION}, - record::{EncryptedData, HostId, Record, RecordIdx}, - tls::ensure_crypto_provider, -}; -use atuin_common::{ - api::{ - AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest, - ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse, StatusResponse, - SyncHistoryResponse, - }, - record::RecordStatus, -}; - -use semver::Version; -use time::OffsetDateTime; -use time::format_description::well_known::Rfc3339; - -use crate::{history::History, sync::hash_str, utils::get_host_user}; - -static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),); - -/// Authentication token for sync API requests. -/// -/// The sync API supports two authentication methods: -/// - `Bearer`: Hub API tokens (for users authenticated via Atuin Hub) -/// - `Token`: Legacy CLI session tokens (for users registered via CLI or self-hosted) -/// -/// When both are available, Hub tokens are preferred as they provide unified -/// authentication across CLI and Hub features. -#[derive(Debug, Clone)] -pub enum AuthToken { - /// Legacy CLI session token, used with "Token {token}" header - Token(String), -} - -impl AuthToken { - /// Format the token as an Authorization header value - fn to_header_value(&self) -> String { - match self { - AuthToken::Token(token) => format!("Token {token}"), - } - } -} - -pub struct Client<'a> { - sync_addr: &'a str, - client: reqwest::Client, -} - -fn make_url(address: &str, path: &str) -> Result<String> { - // `join()` expects a trailing `/` in order to join paths - // e.g. it treats `http://host:port/subdir` as a file called `subdir` - let address = if address.ends_with("/") { - address - } else { - &format!("{address}/") - }; - - // passing a path with a leading `/` will cause `join()` to replace the entire URL path - let path = path.strip_prefix("/").unwrap_or(path); - - let url = Url::parse(address) - .map(|url| url.join(path))? - .map_err(|_| eyre!("invalid address"))?; - - Ok(url.to_string()) -} - -pub async fn register( - address: &str, - username: &str, - email: &str, - password: &str, -) -> Result<RegisterResponse> { - ensure_crypto_provider(); - let mut map = HashMap::new(); - map.insert("username", username); - map.insert("email", email); - map.insert("password", password); - - let url = make_url(address, &format!("/user/{username}"))?; - let resp = reqwest::get(url).await?; - - if resp.status().is_success() { - bail!("username already in use"); - } - - let url = make_url(address, "/register")?; - let client = reqwest::Client::new(); - let resp = client - .post(url) - .header(USER_AGENT, APP_USER_AGENT) - .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION) - .json(&map) - .send() - .await?; - let resp = handle_resp_error(resp).await?; - - if !ensure_version(&resp)? { - bail!("could not register user due to version mismatch"); - } - - let session = resp.json::<RegisterResponse>().await?; - Ok(session) -} - -pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> { - ensure_crypto_provider(); - let url = make_url(address, "/login")?; - let client = reqwest::Client::new(); - - let resp = client - .post(url) - .header(USER_AGENT, APP_USER_AGENT) - .json(&req) - .send() - .await?; - let resp = handle_resp_error(resp).await?; - - if !ensure_version(&resp)? { - bail!("Could not login due to version mismatch"); - } - - let session = resp.json::<LoginResponse>().await?; - Ok(session) -} - -pub fn ensure_version(response: &Response) -> Result<bool> { - let version = response.headers().get(ATUIN_HEADER_VERSION); - - let version = if let Some(version) = version { - match version.to_str() { - Ok(v) => Version::parse(v), - Err(e) => bail!("failed to parse server version: {:?}", e), - } - } else { - bail!("Server not reporting its version: it is either too old or unhealthy"); - }?; - - // If the client is newer than the server - if version.major < ATUIN_VERSION.major { - println!( - "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin" - ); - println!("Client: {ATUIN_CARGO_VERSION}"); - println!("Server: {version}"); - - return Ok(false); - } - - Ok(true) -} - -async fn handle_resp_error(resp: Response) -> Result<Response> { - let status = resp.status(); - let url = resp.url().to_string(); - - if status == StatusCode::SERVICE_UNAVAILABLE { - bail!( - "Service unavailable: check https://status.atuin.sh (or get in touch with your host)" - ); - } - - if status == StatusCode::TOO_MANY_REQUESTS { - bail!("Rate limited; please wait before doing that again"); - } - - if !status.is_success() { - if let Ok(error) = resp.json::<ErrorResponse>().await { - let reason = error.reason; - - if status.is_client_error() { - bail!("Invalid request to the service at {url}, {status} - {reason}.") - } - - bail!( - "There was an error with the atuin sync service at {url}, server error {status}: {reason}.\nIf the problem persists, contact the host" - ) - } - - bail!( - "There was an error with the atuin sync service at {url}, Status {status:?}.\nIf the problem persists, contact the host" - ) - } - - Ok(resp) -} - -impl<'a> Client<'a> { - pub fn new( - sync_addr: &'a str, - auth: AuthToken, - connect_timeout: u64, - timeout: u64, - ) -> Result<Self> { - ensure_crypto_provider(); - let mut headers = HeaderMap::new(); - headers.insert(AUTHORIZATION, auth.to_header_value().parse()?); - - // used for semver server check - headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); - - Ok(Client { - sync_addr, - client: reqwest::Client::builder() - .user_agent(APP_USER_AGENT) - .default_headers(headers) - .connect_timeout(Duration::new(connect_timeout, 0)) - .timeout(Duration::new(timeout, 0)) - .build()?, - }) - } - - pub async fn count(&self) -> Result<i64> { - let url = make_url(self.sync_addr, "/sync/count")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - if !ensure_version(&resp)? { - bail!("could not sync due to version mismatch"); - } - - if resp.status() != StatusCode::OK { - bail!("failed to get count (are you logged in?)"); - } - - let count = resp.json::<CountResponse>().await?; - - Ok(count.count) - } - - pub async fn status(&self) -> Result<StatusResponse> { - let url = make_url(self.sync_addr, "/sync/status")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - if !ensure_version(&resp)? { - bail!("could not sync due to version mismatch"); - } - - let status = resp.json::<StatusResponse>().await?; - - Ok(status) - } - - pub async fn me(&self) -> Result<MeResponse> { - let url = make_url(self.sync_addr, "/api/v0/me")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - let status = resp.json::<MeResponse>().await?; - - Ok(status) - } - - pub async fn get_history( - &self, - sync_ts: OffsetDateTime, - history_ts: OffsetDateTime, - host: Option<String>, - ) -> Result<SyncHistoryResponse> { - let host = host.unwrap_or_else(|| hash_str(&get_host_user())); - - let url = make_url( - self.sync_addr, - &format!( - "/sync/history?sync_ts={}&history_ts={}&host={}", - urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()), - urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()), - host, - ), - )?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - let history = resp.json::<SyncHistoryResponse>().await?; - Ok(history) - } - - pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { - let url = make_url(self.sync_addr, "/history")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.post(url).json(history).send().await?; - handle_resp_error(resp).await?; - - Ok(()) - } - - pub async fn delete_history(&self, h: History) -> Result<()> { - let url = make_url(self.sync_addr, "/history")?; - let url = Url::parse(url.as_str())?; - - let resp = self - .client - .delete(url) - .json(&DeleteHistoryRequest { - client_id: h.id.to_string(), - }) - .send() - .await?; - - handle_resp_error(resp).await?; - - Ok(()) - } - - pub async fn delete_store(&self) -> Result<()> { - let url = make_url(self.sync_addr, "/api/v0/store")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.delete(url).send().await?; - - handle_resp_error(resp).await?; - - Ok(()) - } - - pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> { - let url = make_url(self.sync_addr, "/api/v0/record")?; - let url = Url::parse(url.as_str())?; - - debug!("uploading {} records to {url}", records.len()); - - let resp = self.client.post(url).json(records).send().await?; - handle_resp_error(resp).await?; - - Ok(()) - } - - pub async fn next_records( - &self, - host: HostId, - tag: String, - start: RecordIdx, - count: u64, - ) -> Result<Vec<Record<EncryptedData>>> { - debug!("fetching record/s from host {}/{}/{}", host.0, tag, start); - - let url = make_url( - self.sync_addr, - &format!( - "/api/v0/record/next?host={}&tag={}&count={}&start={}", - host.0, tag, count, start - ), - )?; - - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - let records = resp.json::<Vec<Record<EncryptedData>>>().await?; - - Ok(records) - } - - pub async fn record_status(&self) -> Result<RecordStatus> { - let url = make_url(self.sync_addr, "/api/v0/record")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - if !ensure_version(&resp)? { - bail!("could not sync records due to version mismatch"); - } - - let index = resp.json().await?; - - debug!("got remote index {index:?}"); - - Ok(index) - } - - pub async fn delete(&self) -> Result<()> { - let url = make_url(self.sync_addr, "/account")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.delete(url).send().await?; - - if resp.status() == 403 { - bail!("invalid login details"); - } else if resp.status() == 200 { - Ok(()) - } else { - bail!("Unknown error"); - } - } - - pub async fn change_password( - &self, - current_password: String, - new_password: String, - ) -> Result<()> { - let url = make_url(self.sync_addr, "/account/password")?; - let url = Url::parse(url.as_str())?; - - let resp = self - .client - .patch(url) - .json(&ChangePasswordRequest { - current_password, - new_password, - }) - .send() - .await?; - - if resp.status() == 401 { - bail!("current password is incorrect") - } else if resp.status() == 403 { - bail!("invalid login details"); - } else if resp.status() == 200 { - Ok(()) - } else { - bail!("Unknown error"); - } - } -} diff --git a/crates/atuin-client/src/auth.rs b/crates/atuin-client/src/auth.rs deleted file mode 100644 index 1031c11f..00000000 --- a/crates/atuin-client/src/auth.rs +++ /dev/null @@ -1,230 +0,0 @@ -use async_trait::async_trait; -use eyre::{Context, Result, bail}; -use reqwest::{Url, header::USER_AGENT}; - -use atuin_common::{ - api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ChangePasswordRequest, LoginRequest}, - tls::ensure_crypto_provider, -}; - -use crate::settings::Settings; - -static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION")); - -/// Result of an auth operation that may require 2FA. -pub enum AuthResponse { - /// Operation succeeded; for login/register, contains the session token. - /// `auth_type` indicates the kind of token: `Some("hub")` for Hub API - /// tokens (prefixed `atapi_`), `Some("cli")` for legacy CLI session - /// tokens. `None` when the server didn't include the field (old servers). - Success { - session: String, - auth_type: Option<String>, - }, - /// Two-factor authentication is required; the caller should prompt for a - /// TOTP code and retry with it. - TwoFactorRequired, -} - -/// Result of a mutating account operation that may require 2FA. -pub enum MutateResponse { - /// Operation completed successfully. - Success, - /// Two-factor authentication is required; the caller should prompt for a - /// TOTP code and retry. - TwoFactorRequired, -} - -/// Abstraction over the legacy (Rust sync server) and Hub auth APIs. -/// -/// CLI commands use this trait so they don't need to know which backend is -/// active — they just prompt for input and call these methods. -#[async_trait] -pub trait AuthClient: Send + Sync { - /// Log in with username + password, optionally providing a TOTP code. - async fn login( - &self, - username: &str, - password: &str, - totp_code: Option<&str>, - ) -> Result<AuthResponse>; - - /// Register a new account. - async fn register(&self, username: &str, email: &str, password: &str) -> Result<AuthResponse>; - - /// Change the account password, optionally providing a TOTP code. - async fn change_password( - &self, - current_password: &str, - new_password: &str, - totp_code: Option<&str>, - ) -> Result<MutateResponse>; - - /// Delete the account, requiring the current password and optionally a TOTP code. - async fn delete_account( - &self, - password: &str, - totp_code: Option<&str>, - ) -> Result<MutateResponse>; -} - -/// Resolve the appropriate [`AuthClient`] for the current settings. -pub async fn auth_client(settings: &Settings) -> Box<dyn AuthClient> { - Box::new(LegacyAuthClient::new( - &settings.sync_address, - settings.session_token().await.ok(), - settings.network_connect_timeout, - settings.network_timeout, - )) as Box<dyn AuthClient> -} - -// --------------------------------------------------------------------------- -// Legacy backend — talks to the Rust sync server -// --------------------------------------------------------------------------- - -pub struct LegacyAuthClient { - address: String, - session_token: Option<String>, - connect_timeout: u64, - timeout: u64, -} - -impl LegacyAuthClient { - pub fn new( - address: &str, - session_token: Option<String>, - connect_timeout: u64, - timeout: u64, - ) -> Self { - Self { - address: address.to_string(), - session_token, - connect_timeout, - timeout, - } - } - - fn authenticated_client(&self) -> Result<reqwest::Client> { - let token = self - .session_token - .as_deref() - .ok_or_else(|| eyre::eyre!("Not logged in"))?; - - ensure_crypto_provider(); - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - reqwest::header::AUTHORIZATION, - format!("Token {token}").parse()?, - ); - headers.insert(USER_AGENT, APP_USER_AGENT.parse()?); - headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); - - Ok(reqwest::Client::builder() - .default_headers(headers) - .connect_timeout(std::time::Duration::new(self.connect_timeout, 0)) - .timeout(std::time::Duration::new(self.timeout, 0)) - .build()?) - } -} - -#[async_trait] -impl AuthClient for LegacyAuthClient { - async fn login( - &self, - username: &str, - password: &str, - _totp_code: Option<&str>, - ) -> Result<AuthResponse> { - // The legacy server has no 2FA support; totp_code is ignored. - let resp = crate::api_client::login( - &self.address, - LoginRequest { - username: username.to_string(), - password: password.to_string(), - }, - ) - .await?; - - Ok(AuthResponse::Success { - session: resp.session, - auth_type: resp.auth.or(Some("cli".into())), - }) - } - - async fn register(&self, username: &str, email: &str, password: &str) -> Result<AuthResponse> { - let resp = crate::api_client::register(&self.address, username, email, password).await?; - Ok(AuthResponse::Success { - session: resp.session, - auth_type: resp.auth.or(Some("cli".into())), - }) - } - - async fn change_password( - &self, - current_password: &str, - new_password: &str, - _totp_code: Option<&str>, - ) -> Result<MutateResponse> { - let client = self.authenticated_client()?; - let url = make_url(&self.address, "/account/password")?; - - let resp = client - .patch(&url) - .json(&ChangePasswordRequest { - current_password: current_password.to_string(), - new_password: new_password.to_string(), - }) - .send() - .await?; - - match resp.status().as_u16() { - 200 => Ok(MutateResponse::Success), - 401 => bail!("current password is incorrect"), - 403 => bail!("invalid login details"), - _ => bail!("unknown error"), - } - } - - async fn delete_account( - &self, - password: &str, - _totp_code: Option<&str>, - ) -> Result<MutateResponse> { - let client = self.authenticated_client()?; - let url = make_url(&self.address, "/account")?; - - let resp = client - .delete(&url) - .json(&serde_json::json!({ "password": password })) - .send() - .await?; - - match resp.status().as_u16() { - 200 => Ok(MutateResponse::Success), - 401 => bail!("password is incorrect"), - 403 => bail!("invalid login details"), - _ => bail!("unknown error"), - } - } -} - -// --------------------------------------------------------------------------- -// Shared helpers -// --------------------------------------------------------------------------- - -fn make_url(address: &str, path: &str) -> Result<String> { - let address = if address.ends_with('/') { - address.to_string() - } else { - format!("{address}/") - }; - - let path = path.strip_prefix('/').unwrap_or(path); - - let url = Url::parse(&address) - .context("failed to parse server address")? - .join(path) - .context("failed to join URL path")?; - - Ok(url.to_string()) -} diff --git a/crates/atuin-client/src/database.rs b/crates/atuin-client/src/database.rs deleted file mode 100644 index 946c1eb0..00000000 --- a/crates/atuin-client/src/database.rs +++ /dev/null @@ -1,1525 +0,0 @@ -use std::{ - env, - path::{Path, PathBuf}, - str::FromStr, - time::Duration, -}; - -use crate::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS}; -use async_trait::async_trait; -use atuin_common::utils; -use fs_err as fs; -use itertools::Itertools; -use rand::{Rng, distributions::Alphanumeric}; -use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote}; -use sqlx::{ - Result, Row, - sqlite::{ - SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, - SqliteSynchronous, - }, -}; -use time::OffsetDateTime; -use uuid::Uuid; - -use crate::{ - history::{HistoryId, HistoryStats}, - utils::get_host_user, -}; - -use super::{ - history::History, - ordering, - settings::{FilterMode, SearchMode, Settings}, -}; - -#[derive(Clone)] -pub struct Context { - pub session: String, - pub cwd: String, - pub hostname: String, - pub host_id: String, - pub git_root: Option<PathBuf>, -} - -#[derive(Default, Clone)] -pub struct OptFilters { - pub exit: Option<i64>, - pub exclude_exit: Option<i64>, - pub cwd: Option<String>, - pub exclude_cwd: Option<String>, - pub before: Option<String>, - pub after: Option<String>, - pub limit: Option<i64>, - pub offset: Option<i64>, - pub reverse: bool, - pub include_duplicates: bool, - /// Author filter. Supports special values `$all-user` and `$all-agent`. - pub authors: Vec<String>, -} - -pub async fn current_context() -> eyre::Result<Context> { - let session = env::var("ATUIN_SESSION").map_err(|_| { - eyre::eyre!("Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell.") - })?; - let hostname = get_host_user(); - let cwd = utils::get_current_dir(); - let host_id = Settings::host_id().await?; - let git_root = utils::in_git_repo(cwd.as_str()); - - Ok(Context { - session, - hostname, - cwd, - git_root, - host_id: host_id.0.as_simple().to_string(), - }) -} - -impl Context { - pub fn from_history(entry: &History) -> Self { - Context { - session: entry.session.to_string(), - cwd: entry.cwd.to_string(), - hostname: entry.hostname.to_string(), - host_id: String::new(), - git_root: utils::in_git_repo(entry.cwd.as_str()), - } - } -} - -/// Each entry is OR'd: `$all-user` → NOT IN agents, `$all-agent` → IN agents, literal → exact match. -fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) { - let mut conditions: Vec<String> = Vec::new(); - let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", "); - let author_expr = "CASE \ - WHEN author IS NULL OR trim(author) = '' THEN \ - CASE \ - WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \ - ELSE hostname \ - END \ - ELSE author \ - END"; - - for author in authors { - match author.as_str() { - AUTHOR_FILTER_ALL_USER => { - conditions.push(format!("{author_expr} NOT IN ({agent_list})")); - } - AUTHOR_FILTER_ALL_AGENT => { - conditions.push(format!("{author_expr} IN ({agent_list})")); - } - literal => { - conditions.push(format!("{author_expr} = {}", quote(literal))); - } - } - } - - if !conditions.is_empty() { - sql.and_where(format!("({})", conditions.join(" OR "))); - } -} - -fn get_session_start_time(session_id: &str) -> Option<i64> { - if let Ok(uuid) = Uuid::parse_str(session_id) - && let Some(timestamp) = uuid.get_timestamp() - { - let (seconds, nanos) = timestamp.to_unix(); - return Some(seconds as i64 * 1_000_000_000 + nanos as i64); - } - None -} - -#[async_trait] -pub trait Database: Send + Sync + 'static { - async fn save(&self, h: &History) -> Result<()>; - async fn save_bulk(&self, h: &[History]) -> Result<()>; - - async fn load(&self, id: &str) -> Result<Option<History>>; - async fn list( - &self, - filters: &[FilterMode], - context: &Context, - max: Option<usize>, - unique: bool, - include_deleted: bool, - ) -> Result<Vec<History>>; - async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>; - - async fn update(&self, h: &History) -> Result<()>; - async fn history_count(&self, include_deleted: bool) -> Result<i64>; - - async fn last(&self) -> Result<Option<History>>; - async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>; - - async fn delete(&self, h: History) -> Result<()>; - async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>; - async fn deleted(&self) -> Result<Vec<History>>; - - // Yes I know, it's a lot. - // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. - // Been debating maybe a DSL for search? eg "before:time limit:1 the query" - #[expect(clippy::too_many_arguments)] - async fn search( - &self, - search_mode: SearchMode, - filter: FilterMode, - context: &Context, - query: &str, - filter_options: OptFilters, - ) -> Result<Vec<History>>; - - async fn query_history(&self, query: &str) -> Result<Vec<History>>; - - async fn all_with_count(&self) -> Result<Vec<(History, i32)>>; - - fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged; - - async fn stats(&self, h: &History) -> Result<HistoryStats>; - - async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>; - - fn clone_boxed(&self) -> Box<dyn Database + 'static>; -} - -// Intended for use on a developer machine and not a sync server. -// TODO: implement IntoIterator -#[derive(Debug, Clone)] -pub struct Sqlite { - pub pool: SqlitePool, -} - -impl Sqlite { - pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { - let path = path.as_ref(); - debug!("opening sqlite database at {path:?}"); - - if utils::broken_symlink(path) { - eprintln!( - "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." - ); - std::process::exit(1); - } - - if !path.exists() - && let Some(dir) = path.parent() - { - fs::create_dir_all(dir)?; - } - - let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? - .journal_mode(SqliteJournalMode::Wal) - .optimize_on_close(true, None) - .synchronous(SqliteSynchronous::Normal) - .with_regexp() - .create_if_missing(true); - - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) - .connect_with(opts) - .await?; - - Self::setup_db(&pool).await?; - Ok(Self { pool }) - } - - pub async fn sqlite_version(&self) -> Result<String> { - sqlx::query_scalar("SELECT sqlite_version()") - .fetch_one(&self.pool) - .await - } - - async fn setup_db(pool: &SqlitePool) -> Result<()> { - debug!("running sqlite database setup"); - - sqlx::migrate!("./migrations").run(pool).await?; - - Ok(()) - } - - async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { - sqlx::query( - "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, author, intent, deleted_at) - values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", - ) - .bind(h.id.0.as_str()) - .bind(h.timestamp.unix_timestamp_nanos() as i64) - .bind(h.duration) - .bind(h.exit) - .bind(h.command.as_str()) - .bind(h.cwd.as_str()) - .bind(h.session.as_str()) - .bind(h.hostname.as_str()) - .bind(h.author.as_str()) - .bind(h.intent.as_deref()) - .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) - .execute(&mut **tx) - .await?; - - Ok(()) - } - - async fn delete_row_raw( - tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, - id: HistoryId, - ) -> Result<()> { - sqlx::query("delete from history where id = ?1") - .bind(id.0.as_str()) - .execute(&mut **tx) - .await?; - - Ok(()) - } - - fn query_history(row: SqliteRow) -> History { - let deleted_at: Option<i64> = row.get("deleted_at"); - let hostname: String = row.get("hostname"); - let author: Option<String> = row.try_get("author").ok().flatten(); - let author = author - .filter(|author| !author.trim().is_empty()) - .unwrap_or_else(|| History::author_from_hostname(hostname.as_str())); - let intent: Option<String> = row.try_get("intent").ok().flatten(); - let intent = intent.filter(|intent| !intent.trim().is_empty()); - - History::from_db() - .id(row.get("id")) - .timestamp( - OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128) - .unwrap(), - ) - .duration(row.get("duration")) - .exit(row.get("exit")) - .command(row.get("command")) - .cwd(row.get("cwd")) - .session(row.get("session")) - .hostname(hostname) - .author(author) - .intent(intent) - .deleted_at( - deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()), - ) - .build() - .into() - } -} - -#[async_trait] -impl Database for Sqlite { - async fn save(&self, h: &History) -> Result<()> { - debug!("saving history to sqlite"); - let mut tx = self.pool.begin().await?; - Self::save_raw(&mut tx, h).await?; - tx.commit().await?; - - Ok(()) - } - - async fn save_bulk(&self, h: &[History]) -> Result<()> { - debug!("saving history to sqlite"); - - let mut tx = self.pool.begin().await?; - - for i in h { - Self::save_raw(&mut tx, i).await?; - } - - tx.commit().await?; - - Ok(()) - } - - async fn load(&self, id: &str) -> Result<Option<History>> { - debug!("loading history item {}", id); - - let res = sqlx::query("select * from history where id = ?1") - .bind(id) - .map(Self::query_history) - .fetch_optional(&self.pool) - .await?; - - Ok(res) - } - - async fn update(&self, h: &History) -> Result<()> { - debug!("updating sqlite history"); - - sqlx::query( - "update history - set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, author = ?9, intent = ?10, deleted_at = ?11 - where id = ?1", - ) - .bind(h.id.0.as_str()) - .bind(h.timestamp.unix_timestamp_nanos() as i64) - .bind(h.duration) - .bind(h.exit) - .bind(h.command.as_str()) - .bind(h.cwd.as_str()) - .bind(h.session.as_str()) - .bind(h.hostname.as_str()) - .bind(h.author.as_str()) - .bind(h.intent.as_deref()) - .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) - .execute(&self.pool) - .await?; - - Ok(()) - } - - // make a unique list, that only shows the *newest* version of things - async fn list( - &self, - filters: &[FilterMode], - context: &Context, - max: Option<usize>, - unique: bool, - include_deleted: bool, - ) -> Result<Vec<History>> { - debug!("listing history"); - - let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); - query.field("*").order_desc("timestamp"); - if !include_deleted { - query.and_where_is_null("deleted_at"); - } - - let git_root = if let Some(git_root) = context.git_root.clone() { - git_root.to_str().unwrap_or("/").to_string() - } else { - context.cwd.clone() - }; - - let session_start = get_session_start_time(&context.session); - - for filter in filters { - match filter { - FilterMode::Global => &mut query, - FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)), - FilterMode::Session => query.and_where_eq("session", quote(&context.session)), - FilterMode::SessionPreload => { - query.and_where_eq("session", quote(&context.session)); - if let Some(session_start) = session_start { - query.or_where_lt("timestamp", session_start); - } - &mut query - } - FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)), - FilterMode::Workspace => query.and_where_like_left("cwd", &git_root), - }; - } - - if unique { - query.group_by("command").having("max(timestamp)"); - } - - if let Some(max) = max { - query.limit(max); - } - - let query = query.sql().expect("bug in list query. please report"); - - let res = sqlx::query(&query) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> { - debug!("listing history from {:?} to {:?}", from, to); - - let res = sqlx::query( - "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", - ) - .bind(from.unix_timestamp_nanos() as i64) - .bind(to.unix_timestamp_nanos() as i64) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn last(&self) -> Result<Option<History>> { - let res = sqlx::query( - "select * from history where duration >= 0 order by timestamp desc limit 1", - ) - .map(Self::query_history) - .fetch_optional(&self.pool) - .await?; - - Ok(res) - } - - async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> { - let res = sqlx::query( - "select * from history where timestamp < ?1 order by timestamp desc limit ?2", - ) - .bind(timestamp.unix_timestamp_nanos() as i64) - .bind(count) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn deleted(&self) -> Result<Vec<History>> { - let res = sqlx::query("select * from history where deleted_at is not null") - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn history_count(&self, include_deleted: bool) -> Result<i64> { - let query = if include_deleted { - "select count(1) from history" - } else { - "select count(1) from history where deleted_at is null" - }; - - let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?; - Ok(res.0) - } - - async fn search( - &self, - search_mode: SearchMode, - filter: FilterMode, - context: &Context, - query: &str, - filter_options: OptFilters, - ) -> Result<Vec<History>> { - let mut sql = SqlBuilder::select_from("history"); - - if !filter_options.include_duplicates { - sql.group_by("command").having("max(timestamp)"); - } - - if let Some(limit) = filter_options.limit { - sql.limit(limit); - } - - if let Some(offset) = filter_options.offset { - sql.offset(offset); - } - - if filter_options.reverse { - sql.order_asc("timestamp"); - } else { - sql.order_desc("timestamp"); - } - - let git_root = if let Some(git_root) = context.git_root.clone() { - git_root.to_str().unwrap_or("/").to_string() - } else { - context.cwd.clone() - }; - - let session_start = get_session_start_time(&context.session); - - match filter { - FilterMode::Global => &mut sql, - FilterMode::Host => { - sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase())) - } - FilterMode::Session => sql.and_where_eq("session", quote(&context.session)), - FilterMode::SessionPreload => { - sql.and_where_eq("session", quote(&context.session)); - if let Some(session_start) = session_start { - sql.or_where_lt("timestamp", session_start); - } - &mut sql - } - FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)), - FilterMode::Workspace => sql.and_where_like_left("cwd", git_root), - }; - - let orig_query = query; - - let mut regexes = Vec::new(); - match search_mode { - SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")), - _ => { - let mut is_or = false; - for token in QueryTokenizer::new(query) { - // TODO smart case mode could be made configurable like in fzf - let (is_glob, glob) = if token.has_uppercase() { - (true, "*") - } else { - (false, "%") - }; - let param = match token { - QueryToken::Regex(r) => { - regexes.push(String::from(r)); - continue; - } - QueryToken::Or => { - if !is_or { - is_or = true; - continue; - } else { - format!("{glob}|{glob}") - } - } - QueryToken::MatchStart(term, _) => { - format!("{term}{glob}") - } - QueryToken::MatchEnd(term, _) => { - format!("{glob}{term}") - } - QueryToken::MatchFull(term, _) => { - format!("{glob}{term}{glob}") - } - QueryToken::Match(term, _) => { - if search_mode == SearchMode::FullText { - format!("{glob}{term}{glob}") - } else { - term.split("").join(glob) - } - } - }; - - sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or); - is_or = false; - } - - &mut sql - } - }; - - for regex in regexes { - sql.and_where("command regexp ?".bind(®ex)); - } - - filter_options - .exit - .map(|exit| sql.and_where_eq("exit", exit)); - - filter_options - .exclude_exit - .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit)); - - filter_options - .cwd - .map(|cwd| sql.and_where_eq("cwd", quote(cwd))); - - filter_options - .exclude_cwd - .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd))); - - filter_options.before.map(|before| { - interim::parse_date_string( - before.as_str(), - OffsetDateTime::now_utc(), - interim::Dialect::Uk, - ) - .map(|before| { - sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64)) - }) - }); - - filter_options.after.map(|after| { - interim::parse_date_string( - after.as_str(), - OffsetDateTime::now_utc(), - interim::Dialect::Uk, - ) - .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) - }); - - if !filter_options.authors.is_empty() { - apply_author_filter(&mut sql, &filter_options.authors); - } - - sql.and_where_is_null("deleted_at"); - - let query = sql.sql().expect("bug in search query. please report"); - - let res = sqlx::query(&query) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(ordering::reorder_fuzzy(search_mode, orig_query, res)) - } - - async fn query_history(&self, query: &str) -> Result<Vec<History>> { - let res = sqlx::query(query) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn all_with_count(&self) -> Result<Vec<(History, i32)>> { - debug!("listing history"); - - let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); - - query - .fields(&[ - "id", - "max(timestamp) as timestamp", - "max(duration) as duration", - "exit", - "command", - "deleted_at", - "null as author", - "null as intent", - "group_concat(cwd, ':') as cwd", - "group_concat(session) as session", - "group_concat(hostname, ',') as hostname", - "count(*) as count", - ]) - .group_by("command") - .group_by("exit") - .and_where("deleted_at is null") - .order_desc("timestamp"); - - let query = query.sql().expect("bug in list query. please report"); - - let res = sqlx::query(&query) - .map(|row: SqliteRow| { - let count: i32 = row.get("count"); - (Self::query_history(row), count) - }) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { - Paged::new(Box::new(self.clone()), page_size, include_deleted, unique) - } - - // deleted_at doesn't mean the actual time that the user deleted it, - // but the time that the system marks it as deleted - async fn delete(&self, mut h: History) -> Result<()> { - let now = OffsetDateTime::now_utc(); - h.command = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(32) - .map(char::from) - .collect(); // overwrite with random string - h.deleted_at = Some(now); // delete it - - self.update(&h).await?; // save it - - Ok(()) - } - - async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { - let mut tx = self.pool.begin().await?; - - for id in ids { - Self::delete_row_raw(&mut tx, id.clone()).await?; - } - - tx.commit().await?; - - Ok(()) - } - - async fn stats(&self, h: &History) -> Result<HistoryStats> { - // We select the previous in the session by time - let mut prev = SqlBuilder::select_from("history"); - prev.field("*") - .and_where("timestamp < ?1") - .and_where("session = ?2") - .order_by("timestamp", true) - .limit(1); - - let mut next = SqlBuilder::select_from("history"); - next.field("*") - .and_where("timestamp > ?1") - .and_where("session = ?2") - .order_by("timestamp", false) - .limit(1); - - let mut total = SqlBuilder::select_from("history"); - total.field("count(1)").and_where("command = ?1"); - - let mut average = SqlBuilder::select_from("history"); - average.field("avg(duration)").and_where("command = ?1"); - - let mut exits = SqlBuilder::select_from("history"); - exits - .fields(&["exit", "count(1) as count"]) - .and_where("command = ?1") - .group_by("exit"); - - // rewrite the following with sqlbuilder - let mut day_of_week = SqlBuilder::select_from("history"); - day_of_week - .fields(&[ - "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week", - "count(1) as count", - ]) - .and_where("command = ?1") - .group_by("day_of_week"); - - // Intentionally format the string with 01 hardcoded. We want the average runtime for the - // _entire month_, but will later parse it as a datetime for sorting - // Sqlite has no datetime so we cannot do it there, and otherwise sorting will just be a - // string sort, which won't be correct. - let mut duration_over_time = SqlBuilder::select_from("history"); - duration_over_time - .fields(&[ - "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year", - "avg(duration) as duration", - ]) - .and_where("command = ?1") - .group_by("month_year") - .having("duration > 0"); - - let prev = prev.sql().expect("issue in stats previous query"); - let next = next.sql().expect("issue in stats next query"); - let total = total.sql().expect("issue in stats average query"); - let average = average.sql().expect("issue in stats previous query"); - let exits = exits.sql().expect("issue in stats exits query"); - let day_of_week = day_of_week.sql().expect("issue in stats day of week query"); - let duration_over_time = duration_over_time - .sql() - .expect("issue in stats duration over time query"); - - let prev = sqlx::query(&prev) - .bind(h.timestamp.unix_timestamp_nanos() as i64) - .bind(&h.session) - .map(Self::query_history) - .fetch_optional(&self.pool) - .await?; - - let next = sqlx::query(&next) - .bind(h.timestamp.unix_timestamp_nanos() as i64) - .bind(&h.session) - .map(Self::query_history) - .fetch_optional(&self.pool) - .await?; - - let total: (i64,) = sqlx::query_as(&total) - .bind(&h.command) - .fetch_one(&self.pool) - .await?; - - let average: (f64,) = sqlx::query_as(&average) - .bind(&h.command) - .fetch_one(&self.pool) - .await?; - - let exits: Vec<(i64, i64)> = sqlx::query_as(&exits) - .bind(&h.command) - .fetch_all(&self.pool) - .await?; - - let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week) - .bind(&h.command) - .fetch_all(&self.pool) - .await?; - - let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time) - .bind(&h.command) - .fetch_all(&self.pool) - .await?; - - let duration_over_time = duration_over_time - .iter() - .map(|f| (f.0.clone(), f.1.round() as i64)) - .collect(); - - Ok(HistoryStats { - next, - previous: prev, - total: total.0 as u64, - average_duration: average.0 as u64, - exits, - day_of_week, - duration_over_time, - }) - } - - async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> { - let res = sqlx::query( - "SELECT * FROM ( - SELECT *, ROW_NUMBER() - OVER (PARTITION BY command, cwd, hostname ORDER BY timestamp DESC) - AS rn - FROM history - ) sub - WHERE rn > ?1 and timestamp < ?2; - ", - ) - .bind(dupkeep) - .bind(before) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - fn clone_boxed(&self) -> Box<dyn Database + 'static> { - Box::new(self.clone()) - } -} - -pub struct Paged { - database: Box<dyn Database + 'static>, - page_size: usize, - last_id: Option<String>, - include_deleted: bool, - unique: bool, -} - -impl Paged { - pub fn new( - database: Box<dyn Database + 'static>, - page_size: usize, - include_deleted: bool, - unique: bool, - ) -> Self { - Self { - database, - page_size, - last_id: None, - include_deleted, - unique, - } - } - - pub async fn next(&mut self) -> Result<Option<Vec<History>>> { - let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); - - query.field("*").order_desc("id"); - - if !self.include_deleted { - query.and_where_is_null("deleted_at"); - } - - if self.unique { - // We want to deduplicate on command, but the user can search via cwd, hostname, and session. - // Without those fields, filter modes won't work right. With those fields, we get duplicates. - // This must be handled upstream. - query - .group_by("command, cwd, hostname, session") - .having("max(timestamp)"); - } - - query.limit(self.page_size); - - if let Some(last_id) = &self.last_id { - query.and_where_lt("id", quote(last_id)); - } - - let query = query.sql().expect("bug in list query. please report"); - let res = self.database.query_history(&query).await?; - - if res.is_empty() { - Ok(None) - } else { - self.last_id = Some(res.last().unwrap().id.0.clone()); - Ok(Some(res)) - } - } -} - -trait SqlBuilderExt { - fn fuzzy_condition<S: ToString, T: ToString>( - &mut self, - field: S, - mask: T, - inverse: bool, - glob: bool, - is_or: bool, - ) -> &mut Self; -} - -impl SqlBuilderExt for SqlBuilder { - /// adapted from the sql-builder *like functions - fn fuzzy_condition<S: ToString, T: ToString>( - &mut self, - field: S, - mask: T, - inverse: bool, - glob: bool, - is_or: bool, - ) -> &mut Self { - let mut cond = field.to_string(); - if inverse { - cond.push_str(" NOT"); - } - if glob { - cond.push_str(" GLOB '"); - } else { - cond.push_str(" LIKE '"); - } - cond.push_str(&esc(mask.to_string())); - cond.push('\''); - if is_or { - self.or_where(cond) - } else { - self.and_where(cond) - } - } -} - -#[cfg(test)] -mod test { - use crate::settings::test_local_timeout; - - use super::*; - use std::time::{Duration, Instant}; - - async fn assert_search_eq( - db: &impl Database, - mode: SearchMode, - filter_mode: FilterMode, - query: &str, - expected: usize, - ) -> Result<Vec<History>> { - let context = Context { - hostname: "test:host".to_string(), - session: "beepboopiamasession".to_string(), - cwd: "/home/ellie".to_string(), - host_id: "test-host".to_string(), - git_root: None, - }; - - let results = db - .search( - mode, - filter_mode, - &context, - query, - OptFilters { - ..Default::default() - }, - ) - .await?; - - assert_eq!( - results.len(), - expected, - "query \"{}\", commands: {:?}", - query, - results.iter().map(|a| &a.command).collect::<Vec<&String>>() - ); - Ok(results) - } - - async fn assert_search_commands( - db: &impl Database, - mode: SearchMode, - filter_mode: FilterMode, - query: &str, - expected_commands: Vec<&str>, - ) { - let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len()) - .await - .unwrap(); - let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect(); - assert_eq!(commands, expected_commands); - } - - async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> { - let mut captured: History = History::capture() - .timestamp(OffsetDateTime::now_utc()) - .command(cmd) - .cwd("/home/ellie") - .build() - .into(); - - captured.exit = 0; - captured.duration = 1; - captured.session = "beep boop".to_string(); - captured.hostname = "booop".to_string(); - - db.save(&captured).await - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_search_prefix() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - new_history_item(&mut db, "ls /home/ellie").await.unwrap(); - - assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0) - .await - .unwrap(); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_search_fulltext() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - new_history_item(&mut db, "ls /home/ellie").await.unwrap(); - - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0) - .await - .unwrap(); - - // regex - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "r/ls / ie$", - 1, - ) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "r/ls / !ie", - 0, - ) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "meow r/ls/", - 0, - ) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "r//home//", - 1, - ) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "r//home///", - 0, - ) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "r/home.*e", - 1, - ) - .await - .unwrap(); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_search_fuzzy() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - new_history_item(&mut db, "ls /home/ellie").await.unwrap(); - new_history_item(&mut db, "ls /home/frank").await.unwrap(); - new_history_item(&mut db, "cd /home/Ellie").await.unwrap(); - new_history_item(&mut db, "/home/ellie/.bin/rustup") - .await - .unwrap(); - - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4) - .await - .unwrap(); - - // single term operators - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2) - .await - .unwrap(); - - // multiple terms - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::Fuzzy, - FilterMode::Global, - "'frank | 'rustup", - 2, - ) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::Fuzzy, - FilterMode::Global, - "'frank | 'rustup 'ls", - 1, - ) - .await - .unwrap(); - - // case matching - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1) - .await - .unwrap(); - - // regex - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1) - .await - .unwrap(); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_search_reordered_fuzzy() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - // test ordering of results: we should choose the first, even though it happened longer ago. - - new_history_item(&mut db, "curl").await.unwrap(); - new_history_item(&mut db, "corburl").await.unwrap(); - - // if fuzzy reordering is on, it should come back in a more sensible order - assert_search_commands( - &db, - SearchMode::Fuzzy, - FilterMode::Global, - "curl", - vec!["curl", "corburl"], - ) - .await; - - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2) - .await - .unwrap(); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_paged_basic() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - - // Add 5 history items - for i in 0..5 { - new_history_item(&mut db, &format!("command{}", i)) - .await - .unwrap(); - } - - // Create a paged iterator with page_size of 2 - let mut paged = db.all_paged(2, false, false); - - // First page should have 2 items - let page1 = paged.next().await.unwrap(); - assert!(page1.is_some()); - assert_eq!(page1.unwrap().len(), 2); - - // Second page should have 2 items - let page2 = paged.next().await.unwrap(); - assert!(page2.is_some()); - assert_eq!(page2.unwrap().len(), 2); - - // Third page should have 1 item - let page3 = paged.next().await.unwrap(); - assert!(page3.is_some()); - assert_eq!(page3.unwrap().len(), 1); - - // Fourth page should be None (exhausted) - let page4 = paged.next().await.unwrap(); - assert!(page4.is_none()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_paged_empty() { - let db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - - // Create a paged iterator on empty database - let mut paged = db.all_paged(10, false, false); - - // Should return None immediately - let page = paged.next().await.unwrap(); - assert!(page.is_none()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_paged_unique() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - - // Add duplicate commands - new_history_item(&mut db, "duplicate").await.unwrap(); - new_history_item(&mut db, "duplicate").await.unwrap(); - new_history_item(&mut db, "unique1").await.unwrap(); - new_history_item(&mut db, "unique2").await.unwrap(); - - // Without unique flag - should get all 4 - let mut paged = db.all_paged(10, false, false); - let page = paged.next().await.unwrap().unwrap(); - assert_eq!(page.len(), 4); - - // With unique flag - should get 3 (duplicates collapsed) - let mut paged_unique = db.all_paged(10, false, true); - let page_unique = paged_unique.next().await.unwrap().unwrap(); - assert_eq!(page_unique.len(), 3); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_paged_include_deleted() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - - // Add items - new_history_item(&mut db, "keep1").await.unwrap(); - new_history_item(&mut db, "keep2").await.unwrap(); - new_history_item(&mut db, "delete_me").await.unwrap(); - - // Delete one item - let all = db - .list( - &[], - &Context { - hostname: "".to_string(), - session: "".to_string(), - cwd: "".to_string(), - host_id: "".to_string(), - git_root: None, - }, - None, - false, - false, - ) - .await - .unwrap(); - - let to_delete = all - .iter() - .find(|h| h.command == "delete_me") - .unwrap() - .clone(); - db.delete(to_delete).await.unwrap(); - - // Without include_deleted - should get 2 - let mut paged = db.all_paged(10, false, false); - let page = paged.next().await.unwrap().unwrap(); - assert_eq!(page.len(), 2); - - // With include_deleted - should get 3 - let mut paged_deleted = db.all_paged(10, true, false); - let page_deleted = paged_deleted.next().await.unwrap().unwrap(); - assert_eq!(page_deleted.len(), 3); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_search_bench_dupes() { - let context = Context { - hostname: "test:host".to_string(), - session: "beepboopiamasession".to_string(), - cwd: "/home/ellie".to_string(), - host_id: "test-host".to_string(), - git_root: None, - }; - - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - for _i in 1..10000 { - new_history_item(&mut db, "i am a duplicated command") - .await - .unwrap(); - } - let start = Instant::now(); - let _results = db - .search( - SearchMode::Fuzzy, - FilterMode::Global, - &context, - "", - OptFilters { - ..Default::default() - }, - ) - .await - .unwrap(); - let duration = start.elapsed(); - - assert!(duration < Duration::from_secs(15)); - } -} - -pub struct QueryTokenizer<'a> { - query: &'a str, - last_pos: usize, -} - -pub enum QueryToken<'a> { - Match(&'a str, bool), - MatchStart(&'a str, bool), - MatchEnd(&'a str, bool), - MatchFull(&'a str, bool), - Or, - Regex(&'a str), -} - -impl<'a> QueryToken<'a> { - pub fn has_uppercase(&self) -> bool { - match self { - Self::Match(term, _) - | Self::MatchStart(term, _) - | Self::MatchEnd(term, _) - | Self::MatchFull(term, _) => term.contains(char::is_uppercase), - _ => false, - } - } - - pub fn is_inverse(&self) -> bool { - match self { - Self::Match(_, inv) - | Self::MatchStart(_, inv) - | Self::MatchEnd(_, inv) - | Self::MatchFull(_, inv) => *inv, - _ => false, - } - } -} - -impl<'a> QueryTokenizer<'a> { - pub fn new(query: &'a str) -> Self { - Self { query, last_pos: 0 } - } -} - -impl<'a> Iterator for QueryTokenizer<'a> { - type Item = QueryToken<'a>; - fn next(&mut self) -> Option<Self::Item> { - let remaining = &self.query[self.last_pos..]; - if remaining.is_empty() { - return None; - } - - if let Some(remaining) = remaining.strip_prefix("r/") { - let (regex, next_pos) = if let Some(end) = remaining.find("/ ") { - (&remaining[..end], self.last_pos + 2 + end + 2) - } else if let Some(remaining) = remaining.strip_suffix('/') { - (remaining, self.query.len()) - } else { - (remaining, self.query.len()) - }; - self.last_pos = next_pos; - Some(QueryToken::Regex(regex)) - } else { - let (mut part, next_pos) = if let Some(sp) = remaining.find(' ') { - (&remaining[..sp], self.last_pos + sp + 1) - } else { - (remaining, self.query.len()) - }; - self.last_pos = next_pos; - - if part == "|" { - return Some(QueryToken::Or); - } - - let mut is_inverse = false; - if let Some(s) = part.strip_prefix('!') { - part = s; - is_inverse = true; - } - let token = if let Some(s) = part.strip_prefix('^') { - QueryToken::MatchStart(s, is_inverse) - } else if let Some(s) = part.strip_suffix('$') { - QueryToken::MatchEnd(s, is_inverse) - } else if let Some(s) = part.strip_prefix('\'') { - QueryToken::MatchFull(s, is_inverse) - } else { - QueryToken::Match(part, is_inverse) - }; - Some(token) - } - } -} diff --git a/crates/atuin-client/src/distro.rs b/crates/atuin-client/src/distro.rs deleted file mode 100644 index dead8355..00000000 --- a/crates/atuin-client/src/distro.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::process::Command; - -/// Detect the Linux distribution from the system, -/// using system-specific release files and falling -/// back to lsb_release. -pub fn detect_linux_distribution() -> String { - detect_from_os_release() - .or_else(detect_from_debian_version) - .or_else(detect_from_centos_release) - .or_else(detect_from_redhat_release) - .or_else(detect_from_fedora_release) - .or_else(detect_from_arch_release) - .or_else(detect_from_alpine_release) - .or_else(detect_from_suse_release) - .or_else(detect_from_lsb_release) - .unwrap_or_else(|| "Unknown".to_string()) -} - -fn detect_from_os_release() -> Option<String> { - let content = std::fs::read_to_string("/etc/os-release").ok()?; - - content - .lines() - .find(|l| l.starts_with("PRETTY_NAME=")) - .and_then(|l| l.split_once('=').map(|s| s.1)) - .map(|s| s.trim_matches('"').to_string()) -} - -fn detect_from_debian_version() -> Option<String> { - std::fs::read_to_string("/etc/debian_version") - .ok() - .map(|v| format!("Debian {}", v.trim())) -} - -fn detect_from_centos_release() -> Option<String> { - std::fs::read_to_string("/etc/centos-release") - .ok() - .map(|v| v.trim().to_string()) -} - -fn detect_from_redhat_release() -> Option<String> { - std::fs::read_to_string("/etc/redhat-release") - .ok() - .map(|v| v.trim().to_string()) -} - -fn detect_from_fedora_release() -> Option<String> { - std::fs::read_to_string("/etc/fedora-release") - .ok() - .map(|v| v.trim().to_string()) -} - -fn detect_from_arch_release() -> Option<String> { - std::fs::read_to_string("/etc/arch-release") - .ok() - .filter(|v| !v.trim().is_empty()) - .map(|_| "Arch Linux".to_string()) -} - -fn detect_from_alpine_release() -> Option<String> { - std::fs::read_to_string("/etc/alpine-release") - .ok() - .map(|v| format!("Alpine {}", v.trim())) -} - -fn detect_from_suse_release() -> Option<String> { - std::fs::read_to_string("/etc/SuSE-release") - .ok() - .and_then(|content| content.lines().next().map(|l| l.trim().to_string())) -} - -fn detect_from_lsb_release() -> Option<String> { - let output = Command::new("lsb_release").arg("-a").output().ok()?; - - if !output.status.success() { - return None; - } - - let output = String::from_utf8(output.stdout).ok()?; - linux_distro_from_lsb_release(&output) -} - -fn linux_distro_from_lsb_release(output: &str) -> Option<String> { - output - .lines() - .find(|line| line.starts_with("Description:")) - .and_then(|line| line.split_once(':').map(|s| s.1)) - .map(|s| s.trim().to_string()) -} diff --git a/crates/atuin-client/src/encryption.rs b/crates/atuin-client/src/encryption.rs deleted file mode 100644 index f2032482..00000000 --- a/crates/atuin-client/src/encryption.rs +++ /dev/null @@ -1,440 +0,0 @@ -// The general idea is that we NEVER send cleartext history to the server -// This way the odds of anything private ending up where it should not are -// very low -// The server authenticates via the usual username and password. This has -// nothing to do with the encryption, and is purely authentication! The client -// generates its own secret key, and encrypts all shell history with libsodium's -// secretbox. The data is then sent to the server, where it is stored. All -// clients must share the secret in order to be able to sync, as it is needed -// to decrypt - -use std::{io::prelude::*, path::PathBuf}; - -use base64::prelude::{BASE64_STANDARD, Engine}; -pub use crypto_secretbox::Key; -use crypto_secretbox::{ - AeadCore, AeadInPlace, KeyInit, XSalsa20Poly1305, - aead::{Nonce, OsRng}, -}; -use eyre::{Context, Result, bail, ensure, eyre}; -use fs_err as fs; -use rmp::{Marker, decode::Bytes}; -use serde::{Deserialize, Serialize}; -use time::{OffsetDateTime, format_description::well_known::Rfc3339, macros::format_description}; - -use crate::{history::History, settings::Settings}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct EncryptedHistory { - pub ciphertext: Vec<u8>, - pub nonce: Nonce<XSalsa20Poly1305>, -} - -pub fn generate_encoded_key() -> Result<(Key, String)> { - let key = XSalsa20Poly1305::generate_key(&mut OsRng); - let encoded = encode_key(&key)?; - - Ok((key, encoded)) -} - -pub fn new_key(settings: &Settings) -> Result<Key> { - let path = settings.key_path.as_str(); - let path = PathBuf::from(path); - - if path.exists() { - bail!("key already exists! cannot overwrite"); - } - - let (key, encoded) = generate_encoded_key()?; - - let mut file = fs::File::create(path)?; - file.write_all(encoded.as_bytes())?; - - Ok(key) -} - -// Loads the secret key, will create + save if it doesn't exist -pub fn load_key(settings: &Settings) -> Result<Key> { - let path = settings.key_path.as_str(); - - let key = if PathBuf::from(path).exists() { - let key = fs_err::read_to_string(path)?; - decode_key(key)? - } else { - new_key(settings)? - }; - - Ok(key) -} - -pub fn encode_key(key: &Key) -> Result<String> { - let mut buf = vec![]; - rmp::encode::write_array_len(&mut buf, key.len() as u32) - .wrap_err("could not encode key to message pack")?; - for b in key { - rmp::encode::write_uint(&mut buf, *b as u64) - .wrap_err("could not encode key to message pack")?; - } - let buf = BASE64_STANDARD.encode(buf); - - Ok(buf) -} - -pub fn decode_key(key: String) -> Result<Key> { - use rmp::decode; - - let buf = BASE64_STANDARD - .decode(key.trim_end()) - .wrap_err("encryption key is not a valid base64 encoding")?; - - // old code wrote the key as a fixed length array of 32 bytes - // new code writes the key with a length prefix - match <[u8; 32]>::try_from(&*buf) { - Ok(key) => Ok(key.into()), - Err(_) => { - let mut bytes = rmp::decode::Bytes::new(&buf); - - match Marker::from_u8(buf[0]) { - Marker::Bin8 => { - let len = decode::read_bin_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; - ensure!(len == 32, "encryption key is not the correct size"); - let key = <[u8; 32]>::try_from(bytes.remaining_slice()) - .context("could not decode encryption key")?; - Ok(key.into()) - } - Marker::Array16 => { - let len = decode::read_array_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; - ensure!(len == 32, "encryption key is not the correct size"); - - let mut key = Key::default(); - for i in &mut key { - *i = rmp::decode::read_int(&mut bytes).map_err(|err| eyre!("{err:?}"))?; - } - Ok(key) - } - _ => bail!("could not decode encryption key"), - } - } - } -} - -pub fn encrypt(history: &History, key: &Key) -> Result<EncryptedHistory> { - // serialize with msgpack - let mut buf = encode(history)?; - - let nonce = XSalsa20Poly1305::generate_nonce(&mut OsRng); - XSalsa20Poly1305::new(key) - .encrypt_in_place(&nonce, &[], &mut buf) - .map_err(|_| eyre!("could not encrypt"))?; - - Ok(EncryptedHistory { - ciphertext: buf, - nonce, - }) -} - -pub fn decrypt(mut encrypted_history: EncryptedHistory, key: &Key) -> Result<History> { - XSalsa20Poly1305::new(key) - .decrypt_in_place( - &encrypted_history.nonce, - &[], - &mut encrypted_history.ciphertext, - ) - .map_err(|_| eyre!("could not decrypt history"))?; - let plaintext = encrypted_history.ciphertext; - - let history = decode(&plaintext)?; - - Ok(history) -} - -fn format_rfc3339(ts: OffsetDateTime) -> Result<String> { - // horrible hack. chrono AutoSI limits to 0, 3, 6, or 9 decimal places for nanoseconds. - // time does not have this functionality. - static PARTIAL_RFC3339_0: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z"); - static PARTIAL_RFC3339_3: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"); - static PARTIAL_RFC3339_6: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:6]Z"); - static PARTIAL_RFC3339_9: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:9]Z"); - - let fmt = match ts.nanosecond() { - 0 => PARTIAL_RFC3339_0, - ns if ns % 1_000_000 == 0 => PARTIAL_RFC3339_3, - ns if ns % 1_000 == 0 => PARTIAL_RFC3339_6, - _ => PARTIAL_RFC3339_9, - }; - - Ok(ts.format(fmt)?) -} - -fn encode(h: &History) -> Result<Vec<u8>> { - use rmp::encode; - - let mut output = vec![]; - // INFO: ensure this is updated when adding new fields - encode::write_array_len(&mut output, 9)?; - - encode::write_str(&mut output, &h.id.0)?; - encode::write_str(&mut output, &(format_rfc3339(h.timestamp)?))?; - encode::write_sint(&mut output, h.duration)?; - encode::write_sint(&mut output, h.exit)?; - encode::write_str(&mut output, &h.command)?; - encode::write_str(&mut output, &h.cwd)?; - encode::write_str(&mut output, &h.session)?; - encode::write_str(&mut output, &h.hostname)?; - match h.deleted_at { - Some(d) => encode::write_str(&mut output, &format_rfc3339(d)?)?, - None => encode::write_nil(&mut output)?, - } - - Ok(output) -} - -fn decode(bytes: &[u8]) -> Result<History> { - use rmp::decode::{self, DecodeStringError}; - - let mut bytes = Bytes::new(bytes); - - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - if nfields < 8 { - bail!("malformed decrypted history") - } - if nfields > 9 { - bail!("cannot decrypt history from a newer version of atuin"); - } - - let bytes = bytes.remaining_slice(); - let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (timestamp, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - let duration = decode::read_int(&mut bytes).map_err(error_report)?; - let exit = decode::read_int(&mut bytes).map_err(error_report)?; - - let bytes = bytes.remaining_slice(); - let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - // if we have more fields, try and get the deleted_at - let mut deleted_at = None; - let mut bytes = bytes; - if nfields > 8 { - bytes = match decode::read_str_from_slice(bytes) { - Ok((d, b)) => { - deleted_at = Some(d); - b - } - // we accept null here - Err(DecodeStringError::TypeMismatch(Marker::Null)) => { - // consume the null marker - let mut c = Bytes::new(bytes); - decode::read_nil(&mut c).map_err(error_report)?; - c.remaining_slice() - } - Err(err) => return Err(error_report(err)), - }; - } - - if !bytes.is_empty() { - bail!("trailing bytes in encoded history. malformed") - } - - Ok(History { - id: id.to_owned().into(), - timestamp: OffsetDateTime::parse(timestamp, &Rfc3339)?, - duration, - exit, - command: command.to_owned(), - cwd: cwd.to_owned(), - session: session.to_owned(), - hostname: hostname.to_owned(), - author: History::author_from_hostname(hostname), - intent: None, - deleted_at: deleted_at - .map(|t| OffsetDateTime::parse(t, &Rfc3339)) - .transpose()?, - }) -} - -fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") -} - -#[cfg(test)] -mod test { - use crypto_secretbox::{KeyInit, XSalsa20Poly1305, aead::OsRng}; - use pretty_assertions::assert_eq; - use time::{OffsetDateTime, macros::datetime}; - - use crate::history::History; - - use super::{decode, decrypt, encode, encrypt}; - - #[test] - fn test_encrypt_decrypt() { - let key1 = XSalsa20Poly1305::generate_key(&mut OsRng); - let key2 = XSalsa20Poly1305::generate_key(&mut OsRng); - - let history = History::from_db() - .id("1".into()) - .timestamp(OffsetDateTime::now_utc()) - .command("ls".into()) - .cwd("/home/ellie".into()) - .exit(0) - .duration(1) - .session("beep boop".into()) - .hostname("booop".into()) - .author("booop".into()) - .intent(None) - .deleted_at(None) - .build() - .into(); - - let e1 = encrypt(&history, &key1).unwrap(); - let e2 = encrypt(&history, &key2).unwrap(); - - assert_ne!(e1.ciphertext, e2.ciphertext); - assert_ne!(e1.nonce, e2.nonce); - - // test decryption works - // this should pass - match decrypt(e1, &key1) { - Err(e) => panic!("failed to decrypt, got {e}"), - Ok(h) => assert_eq!(h, history), - }; - - // this should err - let _ = decrypt(e2, &key1).expect_err("expected an error decrypting with invalid key"); - } - - #[test] - fn test_decode() { - let bytes = [ - 0x99, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, - 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, - 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, - 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, - 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, - 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, - 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, - 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, - 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, - 108, 117, 100, 103, 97, 116, 101, 192, - ]; - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: None, - }; - - let h = decode(&bytes).unwrap(); - assert_eq!(history, h); - - let b = encode(&h).unwrap(); - assert_eq!(&bytes, &*b); - } - - #[test] - fn test_decode_deleted() { - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: Some(datetime!(2023-05-28 18:35:40.633872 +00:00)), - }; - - let b = encode(&history).unwrap(); - let h = decode(&b).unwrap(); - assert_eq!(history, h); - } - - #[test] - fn test_decode_old() { - let bytes = [ - 0x98, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, - 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, - 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, - 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, - 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, - 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, - 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, - 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, - 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, - 108, 117, 100, 103, 97, 116, 101, - ]; - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: None, - }; - - let h = decode(&bytes).unwrap(); - assert_eq!(history, h); - } - - #[test] - fn key_encodings() { - use super::{Key, decode_key, encode_key}; - - // a history of our key encodings. - // v11.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // v12.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // v13.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // v13.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // v14.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // v14.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // c7d89c1 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/805) - // b53ca35 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/974) - // v15.0.0 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== - // b8b57c8 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== (https://github.com/ellie/atuin/pull/1057) - // 8c94d79 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/1089) - - let key = Key::from([ - 27, 91, 42, 91, 210, 107, 9, 216, 170, 190, 242, 62, 6, 84, 69, 148, 148, 53, 251, 117, - 226, 167, 173, 52, 82, 34, 138, 110, 169, 124, 92, 229, - ]); - - assert_eq!( - encode_key(&key).unwrap(), - "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==" - ); - - // key encodings we have to support - let valid_encodings = [ - "xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q==", - "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==", - ]; - - for k in valid_encodings { - assert_eq!(decode_key(k.to_owned()).expect(k), key); - } - } -} diff --git a/crates/atuin-client/src/history.rs b/crates/atuin-client/src/history.rs deleted file mode 100644 index aa0d84d5..00000000 --- a/crates/atuin-client/src/history.rs +++ /dev/null @@ -1,756 +0,0 @@ -use core::fmt::Formatter; -use rmp::decode::DecodeStringError; -use rmp::decode::ValueReadError; -use rmp::{Marker, decode::Bytes}; -use std::env; -use std::fmt::Display; - -use atuin_common::record::DecryptedData; -use atuin_common::utils::uuid_v7; - -use eyre::{Result, bail, eyre}; - -use crate::secrets::SECRET_PATTERNS_RE; -use crate::settings::Settings; -use crate::utils::get_host_user; -use time::OffsetDateTime; - -mod builder; -pub mod store; - -/// Known AI agent author values. Used to expand `$all-agent` and `$all-user` filters. -pub const KNOWN_AGENTS: &[&str] = &["claude-code", "codex", "copilot", "pi"]; -pub const AUTHOR_FILTER_ALL_USER: &str = "$all-user"; -pub const AUTHOR_FILTER_ALL_AGENT: &str = "$all-agent"; - -pub fn is_known_agent(author: &str) -> bool { - KNOWN_AGENTS.contains(&author) -} - -pub fn author_matches_filters(author: &str, filters: &[String]) -> bool { - filters.is_empty() - || filters.iter().any(|filter| match filter.as_str() { - AUTHOR_FILTER_ALL_USER => !is_known_agent(author), - AUTHOR_FILTER_ALL_AGENT => is_known_agent(author), - literal => author == literal, - }) -} - -pub(crate) const HISTORY_VERSION_V0: &str = "v0"; -pub(crate) const HISTORY_VERSION_V1: &str = "v1"; -const HISTORY_RECORD_VERSION_V0: u16 = 0; -const HISTORY_RECORD_VERSION_V1: u16 = 1; -pub(crate) const HISTORY_VERSION: &str = HISTORY_VERSION_V1; -pub const HISTORY_TAG: &str = "history"; -const HISTORY_AUTHOR_ENV: &str = "ATUIN_HISTORY_AUTHOR"; -const HISTORY_INTENT_ENV: &str = "ATUIN_HISTORY_INTENT"; - -#[derive(Clone, Debug, Eq, PartialEq, Hash)] -pub struct HistoryId(pub String); - -impl Display for HistoryId { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From<String> for HistoryId { - fn from(s: String) -> Self { - Self(s) - } -} - -/// Client-side history entry. -/// -/// Client stores data unencrypted, and only encrypts it before sending to the server. -/// -/// To create a new history entry, use one of the builders: -/// - [`History::import()`] to import an entry from the shell history file -/// - [`History::capture()`] to capture an entry via hook -/// - [`History::from_db()`] to create an instance from the database entry -// -// ## Implementation Notes -// -// New fields must be added to `History::{serialize,deserialize}` in a backwards -// compatible way (sensible defaults and careful `nfields` handling). -#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] -pub struct History { - /// A client-generated ID, used to identify the entry when syncing. - /// - /// Stored as `client_id` in the database. - pub id: HistoryId, - /// When the command was run. - pub timestamp: OffsetDateTime, - /// How long the command took to run. - pub duration: i64, - /// The exit code of the command. - pub exit: i64, - /// The command that was run. - pub command: String, - /// The current working directory when the command was run. - pub cwd: String, - /// The session ID, associated with a terminal session. - pub session: String, - /// The hostname of the machine the command was run on. - pub hostname: String, - /// Who wrote this command (human user or automation/agent identity). - pub author: String, - /// Optional rationale for why the command was executed. - pub intent: Option<String>, - /// Timestamp, which is set when the entry is deleted, allowing a soft delete. - pub deleted_at: Option<OffsetDateTime>, -} - -#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] -pub struct HistoryStats { - /// The command that was ran after this one in the session - pub next: Option<History>, - /// - /// The command that was ran before this one in the session - pub previous: Option<History>, - - /// How many times has this command been ran? - pub total: u64, - - pub average_duration: u64, - - pub exits: Vec<(i64, i64)>, - - pub day_of_week: Vec<(String, i64)>, - - pub duration_over_time: Vec<(String, i64)>, -} - -impl History { - pub(crate) fn author_from_hostname(hostname: &str) -> String { - hostname - .split_once(':') - .map_or_else(|| hostname.to_owned(), |(_, user)| user.to_owned()) - } - - fn normalize_optional_field(field: Option<String>) -> Option<String> { - field.and_then(|value| { - let trimmed = value.trim(); - if trimmed.is_empty() { - None - } else { - Some(trimmed.to_owned()) - } - }) - } - - #[expect(clippy::too_many_arguments)] - fn new( - timestamp: OffsetDateTime, - command: String, - cwd: String, - exit: i64, - duration: i64, - session: Option<String>, - hostname: Option<String>, - author: Option<String>, - intent: Option<String>, - deleted_at: Option<OffsetDateTime>, - ) -> Self { - let session = session - .or_else(|| env::var("ATUIN_SESSION").ok()) - .unwrap_or_else(|| uuid_v7().as_simple().to_string()); - let hostname = hostname.unwrap_or_else(get_host_user); - let author = Self::normalize_optional_field(author) - .or_else(|| Self::normalize_optional_field(env::var(HISTORY_AUTHOR_ENV).ok())) - .unwrap_or_else(|| Self::author_from_hostname(hostname.as_str())); - let intent = Self::normalize_optional_field(intent) - .or_else(|| Self::normalize_optional_field(env::var(HISTORY_INTENT_ENV).ok())); - - Self { - id: uuid_v7().as_simple().to_string().into(), - timestamp, - command, - cwd, - exit, - duration, - session, - hostname, - author, - intent, - deleted_at, - } - } - - pub fn serialize(&self) -> Result<DecryptedData> { - // This is pretty much the same as what we used for the old history, with one difference - - // it uses integers for timestamps rather than a string format. - - use rmp::encode; - - let mut output = vec![]; - - // write the version - encode::write_u16(&mut output, HISTORY_RECORD_VERSION_V1)?; - let include_intent = self.intent.is_some(); - encode::write_array_len(&mut output, 10 + u32::from(include_intent))?; - - encode::write_str(&mut output, &self.id.0)?; - encode::write_u64(&mut output, self.timestamp.unix_timestamp_nanos() as u64)?; - encode::write_sint(&mut output, self.duration)?; - encode::write_sint(&mut output, self.exit)?; - encode::write_str(&mut output, &self.command)?; - encode::write_str(&mut output, &self.cwd)?; - encode::write_str(&mut output, &self.session)?; - encode::write_str(&mut output, &self.hostname)?; - - match self.deleted_at { - Some(d) => encode::write_u64(&mut output, d.unix_timestamp_nanos() as u64)?, - None => encode::write_nil(&mut output)?, - } - - encode::write_str(&mut output, self.author.as_str())?; - if let Some(intent) = &self.intent { - encode::write_str(&mut output, intent.as_str())?; - } - - Ok(DecryptedData(output)) - } - - fn read_optional_string(bytes: &[u8]) -> Result<(Option<String>, &[u8])> { - use rmp::decode; - - fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - match decode::read_str_from_slice(bytes) { - Ok((value, bytes)) => Ok((Some(value.to_owned()), bytes)), - Err(DecodeStringError::TypeMismatch(Marker::Null)) => { - let mut cursor = Bytes::new(bytes); - decode::read_nil(&mut cursor).map_err(error_report)?; - - Ok((None, cursor.remaining_slice())) - } - Err(err) => Err(error_report(err)), - } - } - - fn deserialize_v0(bytes: &[u8]) -> Result<History> { - use rmp::decode; - - fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - let mut bytes = Bytes::new(bytes); - - let version = decode::read_u16(&mut bytes).map_err(error_report)?; - - if version != HISTORY_RECORD_VERSION_V0 { - bail!("expected decoding v0 record, found v{version}"); - } - - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - - if nfields != 9 { - bail!("cannot decrypt history from a different version of Atuin"); - } - - let bytes = bytes.remaining_slice(); - let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; - let duration = decode::read_int(&mut bytes).map_err(error_report)?; - let exit = decode::read_int(&mut bytes).map_err(error_report)?; - - let bytes = bytes.remaining_slice(); - let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - - let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { - Ok(unix) => (Some(unix), bytes.remaining_slice()), - // we accept null here - Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), - Err(err) => return Err(error_report(err)), - }; - if !bytes.is_empty() { - bail!("trailing bytes in encoded history. malformed") - } - - Ok(History { - id: id.to_owned().into(), - timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, - duration, - exit, - command: command.to_owned(), - cwd: cwd.to_owned(), - session: session.to_owned(), - hostname: hostname.to_owned(), - author: Self::author_from_hostname(hostname), - intent: None, - deleted_at: deleted_at - .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) - .transpose()?, - }) - } - - fn deserialize_v1(bytes: &[u8]) -> Result<History> { - use rmp::decode; - - fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - let mut bytes = Bytes::new(bytes); - - let version = decode::read_u16(&mut bytes).map_err(error_report)?; - - if version != HISTORY_RECORD_VERSION_V1 { - bail!("expected decoding v1 record, found v{version}"); - } - - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - - if !(10..=11).contains(&nfields) { - bail!("cannot decrypt history from a different version of Atuin"); - } - - let bytes = bytes.remaining_slice(); - let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; - let duration = decode::read_int(&mut bytes).map_err(error_report)?; - let exit = decode::read_int(&mut bytes).map_err(error_report)?; - - let bytes = bytes.remaining_slice(); - let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - - let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { - Ok(unix) => (Some(unix), bytes.remaining_slice()), - // we accept null here - Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), - Err(err) => return Err(error_report(err)), - }; - let (author, bytes) = Self::read_optional_string(bytes)?; - let (intent, bytes) = if nfields > 10 { - Self::read_optional_string(bytes)? - } else { - (None, bytes) - }; - - if !bytes.is_empty() { - bail!("trailing bytes in encoded history. malformed") - } - - Ok(History { - id: id.to_owned().into(), - timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, - duration, - exit, - command: command.to_owned(), - cwd: cwd.to_owned(), - session: session.to_owned(), - hostname: hostname.to_owned(), - author: author.unwrap_or_else(|| Self::author_from_hostname(hostname)), - intent, - deleted_at: deleted_at - .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) - .transpose()?, - }) - } - - pub fn deserialize(bytes: &[u8], version: &str) -> Result<History> { - match version { - HISTORY_VERSION_V0 => Self::deserialize_v0(bytes), - HISTORY_VERSION_V1 => Self::deserialize_v1(bytes), - - _ => bail!("unknown version {version:?}"), - } - } - - /// Builder for a history entry that is imported from shell history. - /// - /// The only two required fields are `timestamp` and `command`. - /// - /// ## Examples - /// ``` - /// use atuin_client::history::History; - /// - /// let history: History = History::import() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .build() - /// .into(); - /// ``` - /// - /// If shell history contains more information, it can be added to the builder: - /// ``` - /// use atuin_client::history::History; - /// - /// let history: History = History::import() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .cwd("/home/user") - /// .exit(0) - /// .duration(100) - /// .build() - /// .into(); - /// ``` - /// - /// Unknown command or command without timestamp cannot be imported, which - /// is forced at compile time: - /// - /// ```compile_fail - /// use atuin_client::history::History; - /// - /// // this will not compile because timestamp is missing - /// let history: History = History::import() - /// .command("ls -la") - /// .build() - /// .into(); - /// ``` - pub fn import() -> builder::HistoryImportedBuilder { - builder::HistoryImported::builder() - } - - /// Builder for a history entry that is captured via hook. - /// - /// This builder is used only at the `start` step of the hook, - /// so it doesn't have any fields which are known only after - /// the command is finished, such as `exit` or `duration`. - /// - /// ## Examples - /// ```rust - /// use atuin_client::history::History; - /// - /// let history: History = History::capture() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .cwd("/home/user") - /// .build() - /// .into(); - /// ``` - /// - /// Command without any required info cannot be captured, which is forced at compile time: - /// - /// ```compile_fail - /// use atuin_client::history::History; - /// - /// // this will not compile because `cwd` is missing - /// let history: History = History::capture() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .build() - /// .into(); - /// ``` - pub fn capture() -> builder::HistoryCapturedBuilder { - builder::HistoryCaptured::builder() - } - - /// Builder for a history entry that is captured via hook, and sent to the daemon. - /// - /// This builder is used only at the `start` step of the hook, - /// so it doesn't have any fields which are known only after - /// the command is finished, such as `exit` or `duration`. - /// - /// It does, however, include information that can usually be inferred. - /// - /// This is because the daemon we are sending a request to lacks the context of the command - /// - /// ## Examples - /// ```rust - /// use atuin_client::history::History; - /// - /// let history: History = History::daemon() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .cwd("/home/user") - /// .session("018deb6e8287781f9973ef40e0fde76b") - /// .hostname("computer:ellie") - /// .build() - /// .into(); - /// ``` - /// - /// Command without any required info cannot be captured, which is forced at compile time: - /// - /// ```compile_fail - /// use atuin_client::history::History; - /// - /// // this will not compile because `hostname` is missing - /// let history: History = History::daemon() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .cwd("/home/user") - /// .session("018deb6e8287781f9973ef40e0fde76b") - /// .build() - /// .into(); - /// ``` - pub fn daemon() -> builder::HistoryDaemonCaptureBuilder { - builder::HistoryDaemonCapture::builder() - } - - /// Builder for a history entry that is imported from the database. - /// - /// All fields are required, as they are all present in the database. - /// - /// ```compile_fail - /// use atuin_client::history::History; - /// - /// // this will not compile because `id` field is missing - /// let history: History = History::from_db() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la".to_string()) - /// .cwd("/home/user".to_string()) - /// .exit(0) - /// .duration(100) - /// .session("somesession".to_string()) - /// .hostname("localhost".to_string()) - /// .author("user".to_string()) - /// .intent(None) - /// .deleted_at(None) - /// .build() - /// .into(); - /// ``` - pub fn from_db() -> builder::HistoryFromDbBuilder { - builder::HistoryFromDb::builder() - } - - pub fn success(&self) -> bool { - self.exit == 0 || self.duration == -1 - } - - pub fn should_save(&self, settings: &Settings) -> bool { - !(self.command.starts_with(' ') - || self.command.is_empty() - || settings.history_filter.is_match(&self.command) - || settings.cwd_filter.is_match(&self.cwd) - || (settings.secrets_filter && SECRET_PATTERNS_RE.is_match(&self.command))) - } -} - -#[cfg(test)] -mod tests { - use regex::RegexSet; - use time::macros::datetime; - - use crate::{ - history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, HISTORY_VERSION}, - settings::Settings, - }; - - use super::{History, author_matches_filters, is_known_agent}; - - // Test that we don't save history where necessary - #[test] - fn privacy_test() { - let settings = Settings { - cwd_filter: RegexSet::new(["^/supasecret"]).unwrap(), - history_filter: RegexSet::new(["^psql"]).unwrap(), - ..Settings::utc() - }; - - let normal_command: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("echo foo") - .cwd("/") - .build() - .into(); - - let with_space: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command(" echo bar") - .cwd("/") - .build() - .into(); - - let empty: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("") - .cwd("/") - .build() - .into(); - - let stripe_key: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") - .cwd("/") - .build() - .into(); - - let secret_dir: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("echo ohno") - .cwd("/supasecret") - .build() - .into(); - - let with_psql: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("psql") - .cwd("/supasecret") - .build() - .into(); - - assert!(normal_command.should_save(&settings)); - assert!(!with_space.should_save(&settings)); - assert!(!empty.should_save(&settings)); - assert!(!stripe_key.should_save(&settings)); - assert!(!secret_dir.should_save(&settings)); - assert!(!with_psql.should_save(&settings)); - } - - #[test] - fn known_agents_include_pi() { - assert!(is_known_agent("pi")); - assert!(author_matches_filters( - "pi", - &[AUTHOR_FILTER_ALL_AGENT.to_string()] - )); - assert!(!author_matches_filters( - "pi", - &[AUTHOR_FILTER_ALL_USER.to_string()] - )); - } - - #[test] - fn disable_secrets() { - let settings = Settings { - secrets_filter: false, - ..Settings::utc() - }; - - let stripe_key: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") - .cwd("/") - .build() - .into(); - - assert!(stripe_key.should_save(&settings)); - } - - #[test] - fn test_serialize_deserialize() { - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: None, - }; - - let serialized = history.serialize().expect("failed to serialize history"); - assert_eq!( - &serialized.0[0..3], - [205, 0, 1], - "should encode as history v1" - ); - - let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) - .expect("failed to deserialize history"); - assert_eq!(history, deserialized); - } - - #[test] - fn test_serialize_deserialize_deleted() { - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: Some(datetime!(2023-11-19 20:18 +00:00)), - }; - - let serialized = history.serialize().expect("failed to serialize history"); - - let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) - .expect("failed to deserialize history"); - - assert_eq!(history, deserialized); - } - - #[test] - fn test_serialize_deserialize_with_author_and_intent() { - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "claude".to_owned(), - intent: Some("check repository status".to_owned()), - deleted_at: None, - }; - - let serialized = history.serialize().expect("failed to serialize history"); - let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) - .expect("failed to deserialize history"); - - assert_eq!(history, deserialized); - } - - #[test] - fn test_serialize_deserialize_version() { - // v0 - let bytes_v0 = [ - 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, - 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, - 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, - 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, - 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, - 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, - 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, - 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, - 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192, - ]; - - let deserialized = History::deserialize(&bytes_v0, "v0"); - assert!(deserialized.is_ok()); - - let deserialized = History::deserialize(&bytes_v0, HISTORY_VERSION); - assert!(deserialized.is_err()); - - let current = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: None, - }; - - let bytes_v1 = current.serialize().expect("failed to serialize history"); - let deserialized = History::deserialize(&bytes_v1.0, HISTORY_VERSION); - assert!(deserialized.is_ok()); - - let deserialized = History::deserialize(&bytes_v1.0, "v0"); - assert!(deserialized.is_err()); - } -} diff --git a/crates/atuin-client/src/history/builder.rs b/crates/atuin-client/src/history/builder.rs deleted file mode 100644 index 72a505fd..00000000 --- a/crates/atuin-client/src/history/builder.rs +++ /dev/null @@ -1,154 +0,0 @@ -use typed_builder::TypedBuilder; - -use super::History; - -/// Builder for a history entry that is imported from shell history. -/// -/// The only two required fields are `timestamp` and `command`. -#[derive(Debug, Clone, TypedBuilder)] -pub struct HistoryImported { - timestamp: time::OffsetDateTime, - #[builder(setter(into))] - command: String, - #[builder(default = "unknown".into(), setter(into))] - cwd: String, - #[builder(default = -1)] - exit: i64, - #[builder(default = -1)] - duration: i64, - #[builder(default, setter(strip_option, into))] - session: Option<String>, - #[builder(default, setter(strip_option, into))] - hostname: Option<String>, - #[builder(default, setter(strip_option, into))] - author: Option<String>, - #[builder(default, setter(strip_option, into))] - intent: Option<String>, -} - -impl From<HistoryImported> for History { - fn from(imported: HistoryImported) -> Self { - History::new( - imported.timestamp, - imported.command, - imported.cwd, - imported.exit, - imported.duration, - imported.session, - imported.hostname, - imported.author, - imported.intent, - None, - ) - } -} - -/// Builder for a history entry that is captured via hook. -/// -/// This builder is used only at the `start` step of the hook, -/// so it doesn't have any fields which are known only after -/// the command is finished, such as `exit` or `duration`. -#[derive(Debug, Clone, TypedBuilder)] -pub struct HistoryCaptured { - timestamp: time::OffsetDateTime, - #[builder(setter(into))] - command: String, - #[builder(setter(into))] - cwd: String, - #[builder(default, setter(strip_option, into))] - author: Option<String>, - #[builder(default, setter(strip_option, into))] - intent: Option<String>, -} - -impl From<HistoryCaptured> for History { - fn from(captured: HistoryCaptured) -> Self { - History::new( - captured.timestamp, - captured.command, - captured.cwd, - -1, - -1, - None, - None, - captured.author, - captured.intent, - None, - ) - } -} - -/// Builder for a history entry that is loaded from the database. -/// -/// All fields are required, as they are all present in the database. -#[derive(Debug, Clone, TypedBuilder)] -pub struct HistoryFromDb { - id: String, - timestamp: time::OffsetDateTime, - command: String, - cwd: String, - exit: i64, - duration: i64, - session: String, - hostname: String, - author: String, - intent: Option<String>, - deleted_at: Option<time::OffsetDateTime>, -} - -impl From<HistoryFromDb> for History { - fn from(from_db: HistoryFromDb) -> Self { - History { - id: from_db.id.into(), - timestamp: from_db.timestamp, - exit: from_db.exit, - command: from_db.command, - cwd: from_db.cwd, - duration: from_db.duration, - session: from_db.session, - hostname: from_db.hostname, - author: from_db.author, - intent: from_db.intent, - deleted_at: from_db.deleted_at, - } - } -} - -/// Builder for a history entry that is captured via hook and sent to the daemon -/// -/// This builder is similar to Capture, but we just require more information up front. -/// For the old setup, we could just rely on History::new to read some of the missing -/// data. This is no longer the case. -#[derive(Debug, Clone, TypedBuilder)] -pub struct HistoryDaemonCapture { - timestamp: time::OffsetDateTime, - #[builder(setter(into))] - command: String, - #[builder(setter(into))] - cwd: String, - #[builder(setter(into))] - session: String, - #[builder(setter(into))] - hostname: String, - #[builder(default, setter(strip_option, into))] - author: Option<String>, - #[builder(default, setter(strip_option, into))] - intent: Option<String>, -} - -impl From<HistoryDaemonCapture> for History { - fn from(captured: HistoryDaemonCapture) -> Self { - History::new( - captured.timestamp, - captured.command, - captured.cwd, - -1, - -1, - Some(captured.session), - Some(captured.hostname), - captured.author, - captured.intent, - None, - ) - } -} diff --git a/crates/atuin-client/src/history/store.rs b/crates/atuin-client/src/history/store.rs deleted file mode 100644 index ce7b43a1..00000000 --- a/crates/atuin-client/src/history/store.rs +++ /dev/null @@ -1,434 +0,0 @@ -use std::{collections::HashSet, fmt::Write, time::Duration}; - -use eyre::{Result, bail, eyre}; -use indicatif::{ProgressBar, ProgressState, ProgressStyle}; -use rmp::decode::Bytes; - -use crate::{ - database::{Database, current_context}, - record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store}, -}; -use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx}; - -use super::{HISTORY_TAG, HISTORY_VERSION, HISTORY_VERSION_V0, History, HistoryId}; - -#[derive(Debug, Clone)] -pub struct HistoryStore { - pub store: SqliteStore, - pub host_id: HostId, - pub encryption_key: [u8; 32], -} - -#[derive(Debug, Eq, PartialEq, Clone)] -pub enum HistoryRecord { - Create(History), // Create a history record - Delete(HistoryId), // Delete a history record, identified by ID -} - -impl HistoryRecord { - /// Serialize a history record, returning DecryptedData - /// The record will be of a certain type - /// We map those like so: - /// - /// HistoryRecord::Create -> 0 - /// HistoryRecord::Delete-> 1 - /// - /// This numeric identifier is then written as the first byte to the buffer. For history, we - /// append the serialized history right afterwards, to avoid having to handle serialization - /// twice. - /// - /// Deletion simply refers to the history by ID - pub fn serialize(&self) -> Result<DecryptedData> { - // probably don't actually need to use rmp here, but if we ever need to extend it, it's a - // nice wrapper around raw byte stuff - use rmp::encode; - - let mut output = vec![]; - - match self { - HistoryRecord::Create(history) => { - // 0 -> a history create - encode::write_u8(&mut output, 0)?; - - let bytes = history.serialize()?; - - encode::write_bin(&mut output, &bytes.0)?; - } - HistoryRecord::Delete(id) => { - // 1 -> a history delete - encode::write_u8(&mut output, 1)?; - encode::write_str(&mut output, id.0.as_str())?; - } - }; - - Ok(DecryptedData(output)) - } - - pub fn deserialize(bytes: &DecryptedData, version: &str) -> Result<Self> { - use rmp::decode; - - fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - let mut bytes = Bytes::new(&bytes.0); - - let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; - - match record_type { - // 0 -> HistoryRecord::Create - 0 => { - // not super useful to us atm, but perhaps in the future - // written by write_bin above - let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; - - let record = History::deserialize(bytes.remaining_slice(), version)?; - - Ok(HistoryRecord::Create(record)) - } - - // 1 -> HistoryRecord::Delete - 1 => { - let bytes = bytes.remaining_slice(); - let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - if !bytes.is_empty() { - bail!( - "trailing bytes decoding HistoryRecord::Delete - malformed? got {bytes:?}" - ); - } - - Ok(HistoryRecord::Delete(id.to_string().into())) - } - - n => { - bail!("unknown HistoryRecord type {n}") - } - } - } -} - -impl HistoryStore { - pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self { - HistoryStore { - store, - host_id, - encryption_key, - } - } - - async fn push_record(&self, record: HistoryRecord) -> Result<(RecordId, RecordIdx)> { - let bytes = record.serialize()?; - let idx = self - .store - .last(self.host_id, HISTORY_TAG) - .await? - .map_or(0, |p| p.idx + 1); - - let record = Record::builder() - .host(Host::new(self.host_id)) - .version(HISTORY_VERSION.to_string()) - .tag(HISTORY_TAG.to_string()) - .idx(idx) - .data(bytes) - .build(); - - let id = record.id; - - self.store - .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) - .await?; - - Ok((id, idx)) - } - - async fn push_batch(&self, records: impl Iterator<Item = HistoryRecord>) -> Result<()> { - let mut ret = Vec::new(); - - let idx = self - .store - .last(self.host_id, HISTORY_TAG) - .await? - .map_or(0, |p| p.idx + 1); - - // Could probably _also_ do this as an iterator, but let's see how this is for now. - // optimizing for minimal sqlite transactions, this code can be optimised later - for (n, record) in records.enumerate() { - let bytes = record.serialize()?; - - let record = Record::builder() - .host(Host::new(self.host_id)) - .version(HISTORY_VERSION.to_string()) - .tag(HISTORY_TAG.to_string()) - .idx(idx + n as u64) - .data(bytes) - .build(); - - let record = record.encrypt::<PASETO_V4>(&self.encryption_key); - - ret.push(record); - } - - self.store.push_batch(ret.iter()).await?; - - Ok(()) - } - - pub async fn delete(&self, id: HistoryId) -> Result<(RecordId, RecordIdx)> { - let record = HistoryRecord::Delete(id); - - self.push_record(record).await - } - - /// Delete a batch of history entries via the record store. - /// Returns the record IDs so the caller can run incremental_build when ready. - pub async fn delete_entries( - &self, - entries: impl IntoIterator<Item = History>, - ) -> Result<Vec<RecordId>> { - let mut record_ids = Vec::new(); - for entry in entries { - let (id, _) = self.delete(entry.id).await?; - record_ids.push(id); - } - Ok(record_ids) - } - - pub async fn push(&self, history: History) -> Result<(RecordId, RecordIdx)> { - // TODO(ellie): move the history store to its own file - // it's tiny rn so fine as is - let record = HistoryRecord::Create(history); - - self.push_record(record).await - } - - pub async fn history(&self) -> Result<Vec<HistoryRecord>> { - // Atm this loads all history into memory - // Not ideal as that is potentially quite a lot, although history will be small. - let records = self.store.all_tagged(HISTORY_TAG).await?; - let mut ret = Vec::with_capacity(records.len()); - - for record in records.into_iter() { - let hist = match record.version.as_str() { - HISTORY_VERSION_V0 | HISTORY_VERSION => { - let version = record.version.clone(); - let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?; - - HistoryRecord::deserialize(&decrypted.data, version.as_str()) - } - version => bail!("unknown history version {version:?}"), - }?; - - ret.push(hist); - } - - Ok(ret) - } - - pub async fn build(&self, database: &dyn Database) -> Result<()> { - // I'd like to change how we rebuild and not couple this with the database, but need to - // consider the structure more deeply. This will be easy to change. - - // TODO(ellie): page or iterate this - let history = self.history().await?; - - // In theory we could flatten this here - // The current issue is that the database may have history in it already, from the old sync - // This didn't actually delete old history - // If we're sure we have a DB only maintained by the new store, we can flatten - // create/delete before we even get to sqlite - let mut creates = Vec::new(); - let mut deletes = Vec::new(); - - for i in history { - match i { - HistoryRecord::Create(h) => { - creates.push(h); - } - HistoryRecord::Delete(id) => { - deletes.push(id); - } - } - } - - database.save_bulk(&creates).await?; - database.delete_rows(&deletes).await?; - - Ok(()) - } - - pub async fn incremental_build(&self, database: &dyn Database, ids: &[RecordId]) -> Result<()> { - for id in ids { - let record = self.store.get(*id).await; - - let record = match record { - Ok(record) => record, - _ => { - continue; - } - }; - - if record.tag != HISTORY_TAG { - continue; - } - - let version = record.version.clone(); - let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?; - let record = match version.as_str() { - HISTORY_VERSION_V0 | HISTORY_VERSION => { - HistoryRecord::deserialize(&decrypted.data, version.as_str())? - } - version => bail!("unknown history version {version:?}"), - }; - - match record { - HistoryRecord::Create(h) => { - // TODO: benchmark CPU time/memory tradeoff of batch commit vs one at a time - database.save(&h).await?; - } - HistoryRecord::Delete(id) => { - database.delete_rows(&[id]).await?; - } - } - } - - Ok(()) - } - - /// Get a list of history IDs that exist in the store - /// Note: This currently involves loading all history into memory. This is not going to be a - /// large amount in absolute terms, but do not all it in a hot loop. - pub async fn history_ids(&self) -> Result<HashSet<HistoryId>> { - let history = self.history().await?; - - let ret = HashSet::from_iter(history.iter().map(|h| match h { - HistoryRecord::Create(h) => h.id.clone(), - HistoryRecord::Delete(id) => id.clone(), - })); - - Ok(ret) - } - - pub async fn init_store(&self, db: &impl Database) -> Result<()> { - let pb = ProgressBar::new_spinner(); - pb.set_style( - ProgressStyle::with_template("{spinner:.blue} {msg}") - .unwrap() - .with_key("eta", |state: &ProgressState, w: &mut dyn Write| { - write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() - }) - .progress_chars("#>-"), - ); - pb.enable_steady_tick(Duration::from_millis(500)); - - pb.set_message("Fetching history from old database"); - - let context = current_context().await?; - let history = db.list(&[], &context, None, false, true).await?; - - pb.set_message("Fetching history already in store"); - let store_ids = self.history_ids().await?; - - pb.set_message("Converting old history to new store"); - let mut records = Vec::new(); - - for i in history { - debug!("loaded {}", i.id); - - if store_ids.contains(&i.id) { - debug!("skipping {} - already exists", i.id); - continue; - } - - if i.deleted_at.is_some() { - records.push(HistoryRecord::Delete(i.id)); - } else { - records.push(HistoryRecord::Create(i)); - } - } - - pb.set_message("Writing to db"); - - if !records.is_empty() { - self.push_batch(records.into_iter()).await?; - } - - pb.finish_with_message("Import complete"); - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use atuin_common::record::DecryptedData; - use time::macros::datetime; - - use crate::history::{HISTORY_VERSION, store::HistoryRecord}; - - use super::History; - - #[test] - fn test_serialize_deserialize_create() { - let bytes = [ - 204, 0, 196, 147, 205, 0, 1, 154, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, - 55, 53, 55, 99, 100, 50, 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, - 56, 49, 207, 23, 166, 251, 212, 181, 82, 0, 0, 100, 0, 162, 108, 115, 217, 41, 47, 85, - 115, 101, 114, 115, 47, 101, 108, 108, 105, 101, 47, 115, 114, 99, 47, 103, 105, 116, - 104, 117, 98, 46, 99, 111, 109, 47, 97, 116, 117, 105, 110, 115, 104, 47, 97, 116, 117, - 105, 110, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 97, 100, 56, 57, 55, 53, 57, 55, - 56, 53, 50, 53, 50, 55, 97, 51, 49, 99, 57, 57, 56, 48, 53, 57, 170, 98, 111, 111, 112, - 58, 101, 108, 108, 105, 101, 192, 165, 101, 108, 108, 105, 101, - ]; - - let history = History { - id: "018cd4fe81757cd2aee65cd7861f9c81".to_owned().into(), - timestamp: datetime!(2024-01-04 00:00:00.000000 +00:00), - duration: 100, - exit: 0, - command: "ls".to_owned(), - cwd: "/Users/ellie/src/github.com/atuinsh/atuin".to_owned(), - session: "018cd4fead897597852527a31c998059".to_owned(), - hostname: "boop:ellie".to_owned(), - author: "ellie".to_owned(), - intent: None, - deleted_at: None, - }; - - let record = HistoryRecord::Create(history); - - let serialized = record.serialize().expect("failed to serialize history"); - assert_eq!(serialized.0, bytes); - - let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) - .expect("failed to deserialize HistoryRecord"); - assert_eq!(deserialized, record); - - // check the snapshot too - let deserialized = - HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) - .expect("failed to deserialize HistoryRecord"); - assert_eq!(deserialized, record); - } - - #[test] - fn test_serialize_deserialize_delete() { - let bytes = [ - 204, 1, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, 55, 53, 55, 99, 100, 50, - 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, 56, 49, - ]; - let record = HistoryRecord::Delete("018cd4fe81757cd2aee65cd7861f9c81".to_string().into()); - - let serialized = record.serialize().expect("failed to serialize history"); - assert_eq!(serialized.0, bytes); - - let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) - .expect("failed to deserialize HistoryRecord"); - assert_eq!(deserialized, record); - - let deserialized = - HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) - .expect("failed to deserialize HistoryRecord"); - assert_eq!(deserialized, record); - } -} diff --git a/crates/atuin-client/src/import/bash.rs b/crates/atuin-client/src/import/bash.rs deleted file mode 100644 index 99a44a58..00000000 --- a/crates/atuin-client/src/import/bash.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std::{path::PathBuf, str}; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use itertools::Itertools; -use time::{Duration, OffsetDateTime}; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct Bash { - bytes: Vec<u8>, -} - -fn default_histpath() -> Result<PathBuf> { - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - Ok(home_dir.join(".bash_history")) -} - -#[async_trait] -impl Importer for Bash { - const NAME: &'static str = "bash"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - let count = unix_byte_lines(&self.bytes) - .map(LineType::from) - .filter(|line| matches!(line, LineType::Command(_))) - .count(); - Ok(count) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let lines = unix_byte_lines(&self.bytes) - .map(LineType::from) - .filter(|line| !matches!(line, LineType::NotUtf8)) // invalid utf8 are ignored - .collect_vec(); - - let (commands_before_first_timestamp, first_timestamp) = lines - .iter() - .enumerate() - .find_map(|(i, line)| match line { - LineType::Timestamp(t) => Some((i, *t)), - _ => None, - }) - // if no known timestamps, use now as base - .unwrap_or((lines.len(), OffsetDateTime::now_utc())); - - // if no timestamp is recorded, then use this increment to set an arbitrary timestamp - // to preserve ordering - // this increment is deliberately very small to prevent particularly fast fingers - // causing ordering issues; it also helps in handling the "here document" syntax, - // where several lines are recorded in succession without individual timestamps - let timestamp_increment = Duration::milliseconds(1); - - // make sure there is a minimum amount of time before the first known timestamp - // to fit all commands, given the default increment - let mut next_timestamp = - first_timestamp - timestamp_increment * commands_before_first_timestamp as i32; - - for line in lines.into_iter() { - match line { - LineType::NotUtf8 => unreachable!(), // already filtered - LineType::Empty => {} // do nothing - LineType::Timestamp(t) => { - if t < next_timestamp { - warn!( - "Time reversal detected in Bash history! Commands may be ordered incorrectly." - ); - } - next_timestamp = t; - } - LineType::Command(c) => { - let imported = History::import().timestamp(next_timestamp).command(c); - - h.push(imported.build().into()).await?; - next_timestamp += timestamp_increment; - } - } - } - - Ok(()) - } -} - -#[derive(Debug, Clone)] -enum LineType<'a> { - NotUtf8, - /// Can happen when using the "here document" syntax. - Empty, - /// A timestamp line start with a '#', followed immediately by an integer - /// that represents seconds since UNIX epoch. - Timestamp(OffsetDateTime), - /// Anything else. - Command(&'a str), -} -impl<'a> From<&'a [u8]> for LineType<'a> { - fn from(bytes: &'a [u8]) -> Self { - let Ok(line) = str::from_utf8(bytes) else { - return LineType::NotUtf8; - }; - if line.is_empty() { - return LineType::Empty; - } - - match try_parse_line_as_timestamp(line) { - Some(time) => LineType::Timestamp(time), - None => LineType::Command(line), - } - } -} - -fn try_parse_line_as_timestamp(line: &str) -> Option<OffsetDateTime> { - let seconds = line.strip_prefix('#')?.parse().ok()?; - OffsetDateTime::from_unix_timestamp(seconds).ok() -} - -#[cfg(test)] -mod test { - use std::cmp::Ordering; - - use itertools::{Itertools, assert_equal}; - - use crate::import::{Importer, tests::TestLoader}; - - use super::Bash; - - #[tokio::test] - async fn parse_no_timestamps() { - let bytes = r"cargo install atuin -cargo update -cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ -" - .as_bytes() - .to_owned(); - - let mut bash = Bash { bytes }; - assert_eq!(bash.entries().await.unwrap(), 3); - - let mut loader = TestLoader::default(); - bash.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - [ - "cargo install atuin", - "cargo update", - "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", - ], - ); - assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) - } - - #[tokio::test] - async fn parse_with_timestamps() { - let bytes = b"#1672918999 -git reset -#1672919006 -git clean -dxf -#1672919020 -cd ../ -" - .to_vec(); - - let mut bash = Bash { bytes }; - assert_eq!(bash.entries().await.unwrap(), 3); - - let mut loader = TestLoader::default(); - bash.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - ["git reset", "git clean -dxf", "cd ../"], - ); - assert_equal( - loader.buf.iter().map(|h| h.timestamp.unix_timestamp()), - [1672918999, 1672919006, 1672919020], - ) - } - - #[tokio::test] - async fn parse_with_partial_timestamps() { - let bytes = b"git reset -#1672919006 -git clean -dxf -cd ../ -" - .to_vec(); - - let mut bash = Bash { bytes }; - assert_eq!(bash.entries().await.unwrap(), 3); - - let mut loader = TestLoader::default(); - bash.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - ["git reset", "git clean -dxf", "cd ../"], - ); - assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) - } - - fn is_strictly_sorted<T>(iter: impl IntoIterator<Item = T>) -> bool - where - T: Clone + PartialOrd, - { - iter.into_iter() - .tuple_windows() - .all(|(a, b)| matches!(a.partial_cmp(&b), Some(Ordering::Less))) - } -} diff --git a/crates/atuin-client/src/import/fish.rs b/crates/atuin-client/src/import/fish.rs deleted file mode 100644 index 9fcf624c..00000000 --- a/crates/atuin-client/src/import/fish.rs +++ /dev/null @@ -1,179 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use time::OffsetDateTime; - -use super::{Importer, Loader, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct Fish { - bytes: Vec<u8>, -} - -/// see https://fishshell.com/docs/current/interactive.html#searchable-command-history -fn default_histpath() -> Result<PathBuf> { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - let data = std::env::var("XDG_DATA_HOME").map_or_else( - |_| base.home_dir().join(".local").join("share"), - PathBuf::from, - ); - - // fish supports multiple history sessions - // If `fish_history` var is missing, or set to `default`, use `fish` as the session - let session = std::env::var("fish_history").unwrap_or_else(|_| String::from("fish")); - let session = if session == "default" { - String::from("fish") - } else { - session - }; - - let mut histpath = data.join("fish"); - histpath.push(format!("{session}_history")); - - if histpath.exists() { - Ok(histpath) - } else { - Err(eyre!("Could not find history file.")) - } -} - -#[async_trait] -impl Importer for Fish { - const NAME: &'static str = "fish"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(default_histpath()?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, loader: &mut impl Loader) -> Result<()> { - let now = OffsetDateTime::now_utc(); - let mut time: Option<OffsetDateTime> = None; - let mut cmd: Option<String> = None; - - for b in unix_byte_lines(&self.bytes) { - let s = match std::str::from_utf8(b) { - Ok(s) => s, - Err(_) => continue, // we can skip past things like invalid utf8 - }; - - if let Some(c) = s.strip_prefix("- cmd: ") { - // first, we must deal with the prev cmd - if let Some(cmd) = cmd.take() { - let time = time.unwrap_or(now); - let entry = History::import().timestamp(time).command(cmd); - - loader.push(entry.build().into()).await?; - } - - // using raw strings to avoid needing escaping. - // replaces double backslashes with single backslashes - let c = c.replace(r"\\", r"\"); - // replaces escaped newlines - let c = c.replace(r"\n", "\n"); - // TODO: any other escape characters? - - cmd = Some(c); - } else if let Some(t) = s.strip_prefix(" when: ") { - // if t is not an int, just ignore this line - if let Ok(t) = t.parse::<i64>() { - time = Some(OffsetDateTime::from_unix_timestamp(t)?); - } - } else { - // ... ignore paths lines - } - } - - // we might have a trailing cmd - if let Some(cmd) = cmd.take() { - let time = time.unwrap_or(now); - let entry = History::import().timestamp(time).command(cmd); - - loader.push(entry.build().into()).await?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod test { - - use crate::import::{Importer, tests::TestLoader}; - - use super::Fish; - - #[tokio::test] - async fn parse_complex() { - // complicated input with varying contents and escaped strings. - let bytes = r#"- cmd: history --help - when: 1639162832 -- cmd: cat ~/.bash_history - when: 1639162851 - paths: - - ~/.bash_history -- cmd: ls ~/.local/share/fish/fish_history - when: 1639162890 - paths: - - ~/.local/share/fish/fish_history -- cmd: cat ~/.local/share/fish/fish_history - when: 1639162893 - paths: - - ~/.local/share/fish/fish_history -ERROR -- CORRUPTED: ENTRY - CONTINUE: - - AS - - NORMAL -- cmd: echo "foo" \\\n'bar' baz - when: 1639162933 -- cmd: cat ~/.local/share/fish/fish_history - when: 1639162939 - paths: - - ~/.local/share/fish/fish_history -- cmd: echo "\\"" \\\\ "\\\\" - when: 1639163063 -- cmd: cat ~/.local/share/fish/fish_history - when: 1639163066 - paths: - - ~/.local/share/fish/fish_history -"# - .as_bytes() - .to_owned(); - - let fish = Fish { bytes }; - - let mut loader = TestLoader::default(); - fish.load(&mut loader).await.unwrap(); - let mut history = loader.buf.into_iter(); - - // simple wrapper for fish history entry - macro_rules! fishtory { - ($timestamp:expr_2021, $command:expr_2021) => { - let h = history.next().expect("missing entry in history"); - assert_eq!(h.command.as_str(), $command); - assert_eq!(h.timestamp.unix_timestamp(), $timestamp); - }; - } - - fishtory!(1639162832, "history --help"); - fishtory!(1639162851, "cat ~/.bash_history"); - fishtory!(1639162890, "ls ~/.local/share/fish/fish_history"); - fishtory!(1639162893, "cat ~/.local/share/fish/fish_history"); - fishtory!(1639162933, "echo \"foo\" \\\n'bar' baz"); - fishtory!(1639162939, "cat ~/.local/share/fish/fish_history"); - fishtory!(1639163063, r#"echo "\"" \\ "\\""#); - fishtory!(1639163066, "cat ~/.local/share/fish/fish_history"); - } -} diff --git a/crates/atuin-client/src/import/mod.rs b/crates/atuin-client/src/import/mod.rs deleted file mode 100644 index 4a1c6af6..00000000 --- a/crates/atuin-client/src/import/mod.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::fs::File; -use std::io::Read; -use std::path::PathBuf; - -use async_trait::async_trait; -use eyre::{Result, bail}; -use memchr::Memchr; - -use crate::history::History; - -pub mod bash; -pub mod fish; -pub mod nu; -pub mod nu_histdb; -pub mod powershell; -pub mod replxx; -pub mod resh; -pub mod xonsh; -pub mod xonsh_sqlite; -pub mod zsh; -pub mod zsh_histdb; - -#[async_trait] -pub trait Importer: Sized { - const NAME: &'static str; - async fn new() -> Result<Self>; - async fn entries(&mut self) -> Result<usize>; - async fn load(self, loader: &mut impl Loader) -> Result<()>; -} - -#[async_trait] -pub trait Loader: Sync + Send { - async fn push(&mut self, hist: History) -> eyre::Result<()>; -} - -fn unix_byte_lines(input: &[u8]) -> impl Iterator<Item = &[u8]> { - UnixByteLines { - iter: memchr::memchr_iter(b'\n', input), - bytes: input, - i: 0, - } -} - -struct UnixByteLines<'a> { - iter: Memchr<'a>, - bytes: &'a [u8], - i: usize, -} - -impl<'a> Iterator for UnixByteLines<'a> { - type Item = &'a [u8]; - - fn next(&mut self) -> Option<Self::Item> { - let j = self.iter.next()?; - let out = &self.bytes[self.i..j]; - self.i = j + 1; - Some(out) - } - - fn count(self) -> usize - where - Self: Sized, - { - self.iter.count() - } -} - -fn count_lines(input: &[u8]) -> usize { - unix_byte_lines(input).count() -} - -fn get_histpath<D>(def: D) -> Result<PathBuf> -where - D: FnOnce() -> Result<PathBuf>, -{ - if let Ok(p) = std::env::var("HISTFILE") { - Ok(PathBuf::from(p)) - } else { - def() - } -} - -fn get_histfile_path<D>(def: D) -> Result<PathBuf> -where - D: FnOnce() -> Result<PathBuf>, -{ - get_histpath(def).and_then(is_file) -} - -fn get_histdir_path<D>(def: D) -> Result<PathBuf> -where - D: FnOnce() -> Result<PathBuf>, -{ - get_histpath(def).and_then(is_dir) -} - -fn read_to_end(path: PathBuf) -> Result<Vec<u8>> { - let mut bytes = Vec::new(); - let mut f = File::open(path)?; - f.read_to_end(&mut bytes)?; - Ok(bytes) -} -fn is_file(p: PathBuf) -> Result<PathBuf> { - if p.is_file() { - Ok(p) - } else { - bail!( - "Could not find history file {:?}. Try setting and exporting $HISTFILE", - p - ) - } -} -fn is_dir(p: PathBuf) -> Result<PathBuf> { - if p.is_dir() { - Ok(p) - } else { - bail!( - "Could not find history directory {:?}. Try setting and exporting $HISTFILE", - p - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[derive(Default)] - pub struct TestLoader { - pub buf: Vec<History>, - } - - #[async_trait] - impl Loader for TestLoader { - async fn push(&mut self, hist: History) -> Result<()> { - self.buf.push(hist); - Ok(()) - } - } -} diff --git a/crates/atuin-client/src/import/nu.rs b/crates/atuin-client/src/import/nu.rs deleted file mode 100644 index cae90ac4..00000000 --- a/crates/atuin-client/src/import/nu.rs +++ /dev/null @@ -1,67 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use time::OffsetDateTime; - -use super::{Importer, Loader, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct Nu { - bytes: Vec<u8>, -} - -fn get_histpath() -> Result<PathBuf> { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - let config_dir = base.config_dir().join("nushell"); - - let histpath = config_dir.join("history.txt"); - if histpath.exists() { - Ok(histpath) - } else { - Err(eyre!("Could not find history file.")) - } -} - -#[async_trait] -impl Importer for Nu { - const NAME: &'static str = "nu"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_histpath()?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let now = OffsetDateTime::now_utc(); - - let mut counter = 0; - for b in unix_byte_lines(&self.bytes) { - let s = match std::str::from_utf8(b) { - Ok(s) => s, - Err(_) => continue, // we can skip past things like invalid utf8 - }; - - let cmd: String = s.replace("<\\n>", "\n"); - - let offset = time::Duration::nanoseconds(counter); - counter += 1; - - let entry = History::import().timestamp(now - offset).command(cmd); - - h.push(entry.build().into()).await?; - } - - Ok(()) - } -} diff --git a/crates/atuin-client/src/import/nu_histdb.rs b/crates/atuin-client/src/import/nu_histdb.rs deleted file mode 100644 index a13cb2b4..00000000 --- a/crates/atuin-client/src/import/nu_histdb.rs +++ /dev/null @@ -1,113 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use sqlx::{Pool, sqlite::SqlitePool}; -use time::{Duration, OffsetDateTime}; - -use super::Importer; -use crate::history::History; -use crate::import::Loader; - -#[derive(sqlx::FromRow, Debug)] -pub struct HistDbEntry { - pub id: i64, - pub command_line: Vec<u8>, - pub start_timestamp: i64, - pub session_id: i64, - pub hostname: Vec<u8>, - pub cwd: Vec<u8>, - pub duration_ms: i64, - pub exit_status: i64, - pub more_info: Vec<u8>, -} - -impl From<HistDbEntry> for History { - fn from(histdb_item: HistDbEntry) -> Self { - let ts_secs = histdb_item.start_timestamp / 1000; - let ts_ns = (histdb_item.start_timestamp % 1000) * 1_000_000; - let imported = History::import() - .timestamp( - OffsetDateTime::from_unix_timestamp(ts_secs).unwrap() - + Duration::nanoseconds(ts_ns), - ) - .command(String::from_utf8(histdb_item.command_line).unwrap()) - .cwd(String::from_utf8(histdb_item.cwd).unwrap()) - .exit(histdb_item.exit_status) - .duration(histdb_item.duration_ms) - .session(format!("{:x}", histdb_item.session_id)) - .hostname(String::from_utf8(histdb_item.hostname).unwrap()); - - imported.build().into() - } -} - -#[derive(Debug)] -pub struct NuHistDb { - histdb: Vec<HistDbEntry>, -} - -/// Read db at given file, return vector of entries. -async fn hist_from_db(dbpath: PathBuf) -> Result<Vec<HistDbEntry>> { - let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; - hist_from_db_conn(pool).await -} - -async fn hist_from_db_conn(pool: Pool<sqlx::Sqlite>) -> Result<Vec<HistDbEntry>> { - let query = r#" - SELECT - id, command_line, start_timestamp, session_id, hostname, cwd, duration_ms, exit_status, - more_info - FROM history - ORDER BY start_timestamp - "#; - let histdb_vec: Vec<HistDbEntry> = sqlx::query_as::<_, HistDbEntry>(query) - .fetch_all(&pool) - .await?; - Ok(histdb_vec) -} - -impl NuHistDb { - pub fn histpath() -> Result<PathBuf> { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - let config_dir = base.config_dir().join("nushell"); - - let histdb_path = config_dir.join("history.sqlite3"); - if histdb_path.exists() { - Ok(histdb_path) - } else { - Err(eyre!("Could not find history file.")) - } - } -} - -#[async_trait] -impl Importer for NuHistDb { - // Not sure how this is used - const NAME: &'static str = "nu_histdb"; - - /// Creates a new NuHistDb and populates the history based on the pre-populated data - /// structure. - async fn new() -> Result<Self> { - let dbpath = NuHistDb::histpath()?; - let histdb_entry_vec = hist_from_db(dbpath).await?; - Ok(Self { - histdb: histdb_entry_vec, - }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(self.histdb.len()) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - for i in self.histdb { - h.push(i.into()).await?; - } - Ok(()) - } -} diff --git a/crates/atuin-client/src/import/powershell.rs b/crates/atuin-client/src/import/powershell.rs deleted file mode 100644 index 86fd007d..00000000 --- a/crates/atuin-client/src/import/powershell.rs +++ /dev/null @@ -1,202 +0,0 @@ -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use std::path::PathBuf; -use time::{Duration, OffsetDateTime}; - -use super::{Importer, Loader, count_lines, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct PowerShell { - bytes: Vec<u8>, - line_count: Option<usize>, -} - -fn get_history_path() -> Result<PathBuf> { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - - // The command line history in PowerShell is maintained by the PSReadLine module: - // https://learn.microsoft.com/en-us/powershell/module/psreadline/about/about_psreadline#command-history - // - // > PSReadLine maintains a history file containing all the commands and data you've entered from the command line. - // > The history files are a file named `$($Host.Name)_history.txt`. - // > On Windows systems the history file is stored at `$Env:APPDATA\Microsoft\Windows\PowerShell\PSReadLine`. - // > On non-Windows systems, the history files are stored at `$Env:XDG_DATA_HOME/powershell/PSReadLine` - // > or `$Env:HOME/.local/share/powershell/PSReadLine`. - - let dir = if cfg!(windows) { - base.data_dir() - .join("Microsoft") - .join("Windows") - .join("PowerShell") - .join("PSReadLine") - } else { - std::env::var("XDG_DATA_HOME") - .map_or_else( - |_| base.home_dir().join(".local").join("share"), - PathBuf::from, - ) - .join("powershell") - .join("PSReadLine") - }; - - // The history is stored in a file named `$($Host.Name)_history.txt`. - // For the default console host shipped by Microsoft,`$Host.Name` is `ConsoleHost`: - // https://learn.microsoft.com/en-us/dotnet/api/system.management.automation.host.pshost.name#remarks - - let file = dir.join("ConsoleHost_history.txt"); - - if file.is_file() { - Ok(file) - } else { - Err(eyre!("Could not find history file: {}", file.display())) - } -} - -#[async_trait] -impl Importer for PowerShell { - const NAME: &'static str = "PowerShell"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_history_path()?)?; - Ok(Self { - bytes, - line_count: None, - }) - } - - async fn entries(&mut self) -> Result<usize> { - // Commands can be split over multiple lines, - // but this is only used for a progress bar, and multi-line commands - // should be quite rare, so this is not an issue in practice. - if self.line_count.is_none() { - self.line_count = Some(count_lines(&self.bytes)); - } - Ok(self.line_count.unwrap()) - } - - async fn load(mut self, h: &mut impl Loader) -> Result<()> { - let line_count = self.entries().await?; - let start = OffsetDateTime::now_utc() - Duration::milliseconds(line_count as i64); - - let mut counter = 0; - let mut iter = unix_byte_lines(&self.bytes); - - while let Some(s) = iter.next() { - let Ok(s) = read_line(s) else { - continue; // We can skip past things like invalid utf8 - }; - - let mut cmd = s.to_string(); - - // Multi-line commands end with a backtick, append the following lines. - while cmd.ends_with('`') { - cmd.pop(); - - let Some(next) = iter.next() else { - break; - }; - let Ok(next) = read_line(next) else { - break; - }; - - cmd.push('\n'); - cmd.push_str(next); - } - - if cmd.is_empty() { - continue; - } - - let offset = Duration::milliseconds(counter); - counter += 1; - - let entry = History::import().timestamp(start + offset).command(cmd); - h.push(entry.build().into()).await?; - } - - Ok(()) - } -} - -fn read_line(s: &[u8]) -> Result<&str> { - let s = str::from_utf8(s)?; - - // History is stored in CRLF on Windows, normalize the input to LF on all platforms. - let s = s.strip_suffix('\r').unwrap_or(s); - - Ok(s) -} - -#[cfg(test)] -mod test { - use super::*; - use crate::import::tests::TestLoader; - use itertools::assert_equal; - - const INPUT: &str = r#"cargo install atuin -cargo update -echo "first line` -second line` -` -last line" -echo foo - -echo bar -echo baz -"#; - - const EXPECTED: &[&str] = &[ - "cargo install atuin", - "cargo update", - "echo \"first line\nsecond line\n\nlast line\"", - "echo foo", - "echo bar", - "echo baz", - ]; - - #[tokio::test] - async fn test_import() { - let loader = import(INPUT).await; - - let actual = loader.buf.iter().map(|h| h.command.clone()); - let expected = EXPECTED.iter().map(|s| s.to_string()); - - assert_equal(actual, expected); - } - - #[tokio::test] - async fn test_crlf() { - let input = INPUT.replace("\n", "\r\n"); - let loader = import(input.as_str()).await; - - let actual = loader.buf.iter().map(|h| h.command.clone()); - let expected = EXPECTED.iter().map(|s| s.to_string()); - - assert_equal(actual, expected); - } - - #[tokio::test] - async fn test_timestamps() { - let loader = import(INPUT).await; - - let mut prev = loader.buf.first().unwrap().timestamp; - for current in loader.buf.iter().skip(1).map(|h| h.timestamp) { - assert!(current > prev); - prev = current; - } - } - - async fn import(input: &str) -> TestLoader { - let powershell = PowerShell { - bytes: input.as_bytes().to_vec(), - line_count: None, - }; - - let mut loader = TestLoader::default(); - powershell.load(&mut loader).await.unwrap(); - loader - } -} diff --git a/crates/atuin-client/src/import/replxx.rs b/crates/atuin-client/src/import/replxx.rs deleted file mode 100644 index 47d566cf..00000000 --- a/crates/atuin-client/src/import/replxx.rs +++ /dev/null @@ -1,137 +0,0 @@ -use std::{path::PathBuf, str}; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use time::{OffsetDateTime, PrimitiveDateTime, macros::format_description}; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct Replxx { - bytes: Vec<u8>, -} - -fn default_histpath() -> Result<PathBuf> { - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - // There is no default histfile for replxx. - // Here we try a couple of common names. - let mut candidates = ["replxx_history.txt", ".histfile"].iter(); - loop { - match candidates.next() { - Some(candidate) => { - let histpath = home_dir.join(candidate); - if histpath.exists() { - break Ok(histpath); - } - } - None => { - break Err(eyre!( - "Could not find history file. Try setting and exporting $HISTFILE" - )); - } - } - } -} - -#[async_trait] -impl Importer for Replxx { - const NAME: &'static str = "replxx"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(super::count_lines(&self.bytes) / 2) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let mut timestamp = OffsetDateTime::UNIX_EPOCH; - - for b in unix_byte_lines(&self.bytes) { - let s = std::str::from_utf8(b)?; - match try_parse_line_as_timestamp(s) { - Some(t) => timestamp = t, - None => { - // replxx uses ETB character (0x17) as line breaker - let cmd = s.replace('\u{0017}', "\n"); - let imported = History::import().timestamp(timestamp).command(cmd); - - h.push(imported.build().into()).await?; - } - } - } - - Ok(()) - } -} - -fn try_parse_line_as_timestamp(line: &str) -> Option<OffsetDateTime> { - // replxx history date time format: ### yyyy-mm-dd hh:mm:ss.xxx - let date_time_str = line.strip_prefix("### ")?; - let format = - format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]"); - - let primitive_date_time = PrimitiveDateTime::parse(date_time_str, format).ok()?; - // There is no safe way to get local time offset. - // For simplicity let's just assume UTC. - Some(primitive_date_time.assume_utc()) -} - -#[cfg(test)] -mod test { - - use crate::import::{Importer, tests::TestLoader}; - - use super::Replxx; - - #[tokio::test] - async fn parse_complex() { - let bytes = r#"### 2024-02-10 22:16:28.302 -select * from remote('127.0.0.1:20222', view(select 1)) -### 2024-02-10 22:16:36.919 -select * from numbers(10) -### 2024-02-10 22:16:41.710 -select * from system.numbers -### 2024-02-10 22:19:28.655 -select 1 -### 2024-02-22 11:15:33.046 -CREATE TABLE test( stamp DateTime('UTC'))ENGINE = MergeTreePARTITION BY toDate(stamp)order by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000); -"# - .as_bytes() - .to_owned(); - - let replxx = Replxx { bytes }; - - let mut loader = TestLoader::default(); - replxx.load(&mut loader).await.unwrap(); - let mut history = loader.buf.into_iter(); - - // simple wrapper for replxx history entry - macro_rules! history { - ($timestamp:expr_2021, $command:expr_2021) => { - let h = history.next().expect("missing entry in history"); - assert_eq!(h.command.as_str(), $command); - assert_eq!(h.timestamp.unix_timestamp(), $timestamp); - }; - } - - history!( - 1707603388, - "select * from remote('127.0.0.1:20222', view(select 1))" - ); - history!(1707603396, "select * from numbers(10)"); - history!(1707603401, "select * from system.numbers"); - history!(1707603568, "select 1"); - history!( - 1708600533, - "CREATE TABLE test\n( stamp DateTime('UTC'))\nENGINE = MergeTree\nPARTITION BY toDate(stamp)\norder by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000);" - ); - } -} diff --git a/crates/atuin-client/src/import/resh.rs b/crates/atuin-client/src/import/resh.rs deleted file mode 100644 index df15f5b4..00000000 --- a/crates/atuin-client/src/import/resh.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use serde::Deserialize; - -use atuin_common::utils::uuid_v7; -use time::OffsetDateTime; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Deserialize, Debug)] -#[serde(rename_all = "camelCase")] -pub struct ReshEntry { - pub cmd_line: String, - pub exit_code: i64, - pub shell: String, - pub uname: String, - pub session_id: String, - pub home: String, - pub lang: String, - pub lc_all: String, - pub login: String, - pub pwd: String, - pub pwd_after: String, - pub shell_env: String, - pub term: String, - pub real_pwd: String, - pub real_pwd_after: String, - pub pid: i64, - pub session_pid: i64, - pub host: String, - pub hosttype: String, - pub ostype: String, - pub machtype: String, - pub shlvl: i64, - pub timezone_before: String, - pub timezone_after: String, - pub realtime_before: f64, - pub realtime_after: f64, - pub realtime_before_local: f64, - pub realtime_after_local: f64, - pub realtime_duration: f64, - pub realtime_since_session_start: f64, - pub realtime_since_boot: f64, - pub git_dir: String, - pub git_real_dir: String, - pub git_origin_remote: String, - pub git_dir_after: String, - pub git_real_dir_after: String, - pub git_origin_remote_after: String, - pub machine_id: String, - pub os_release_id: String, - pub os_release_version_id: String, - pub os_release_id_like: String, - pub os_release_name: String, - pub os_release_pretty_name: String, - pub resh_uuid: String, - pub resh_version: String, - pub resh_revision: String, - pub parts_merged: bool, - pub recalled: bool, - pub recall_last_cmd_line: String, - pub cols: String, - pub lines: String, -} - -#[derive(Debug)] -pub struct Resh { - bytes: Vec<u8>, -} - -fn default_histpath() -> Result<PathBuf> { - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - Ok(home_dir.join(".resh_history.json")) -} - -#[async_trait] -impl Importer for Resh { - const NAME: &'static str = "resh"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - for b in unix_byte_lines(&self.bytes) { - let s = match std::str::from_utf8(b) { - Ok(s) => s, - Err(_) => continue, // we can skip past things like invalid utf8 - }; - let entry = match serde_json::from_str::<ReshEntry>(s) { - Ok(e) => e, - Err(_) => continue, // skip invalid json :shrug: - }; - - #[expect(clippy::cast_possible_truncation)] - #[expect(clippy::cast_sign_loss)] - let timestamp = { - let secs = entry.realtime_before.floor() as i64; - let nanosecs = (entry.realtime_before.fract() * 1_000_000_000_f64).round() as i64; - OffsetDateTime::from_unix_timestamp(secs)? + time::Duration::nanoseconds(nanosecs) - }; - #[expect(clippy::cast_possible_truncation)] - #[expect(clippy::cast_sign_loss)] - let duration = { - let secs = entry.realtime_after.floor() as i64; - let nanosecs = (entry.realtime_after.fract() * 1_000_000_000_f64).round() as i64; - let base = OffsetDateTime::from_unix_timestamp(secs)? - + time::Duration::nanoseconds(nanosecs); - let difference = base - timestamp; - difference.whole_nanoseconds() as i64 - }; - - let imported = History::import() - .command(entry.cmd_line) - .timestamp(timestamp) - .duration(duration) - .exit(entry.exit_code) - .cwd(entry.pwd) - .hostname(entry.host) - // CHECK: should we add uuid here? It's not set in the other importers - .session(uuid_v7().as_simple().to_string()); - - h.push(imported.build().into()).await?; - } - - Ok(()) - } -} diff --git a/crates/atuin-client/src/import/xonsh.rs b/crates/atuin-client/src/import/xonsh.rs deleted file mode 100644 index 6f38de68..00000000 --- a/crates/atuin-client/src/import/xonsh.rs +++ /dev/null @@ -1,234 +0,0 @@ -use std::env; -use std::fs::{self, File}; -use std::path::{Path, PathBuf}; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use serde::Deserialize; -use time::OffsetDateTime; -use uuid::Uuid; -use uuid::timestamp::{Timestamp, context::NoContext}; - -use super::{Importer, Loader, get_histdir_path}; -use crate::history::History; -use crate::utils::get_host_user; - -// Note: both HistoryFile and HistoryData have other keys present in the JSON, we don't -// care about them so we leave them unspecified so as to avoid deserializing unnecessarily. -#[derive(Debug, Deserialize)] -struct HistoryFile { - data: HistoryData, -} - -#[derive(Debug, Deserialize)] -struct HistoryData { - sessionid: String, - cmds: Vec<HistoryCmd>, -} - -#[derive(Debug, Deserialize)] -struct HistoryCmd { - cwd: String, - inp: String, - rtn: Option<i64>, - ts: (f64, f64), -} - -#[derive(Debug)] -pub struct Xonsh { - // history is stored as a bunch of json files, one per session - sessions: Vec<HistoryData>, - hostname: String, -} - -fn xonsh_hist_dir(xonsh_data_dir: Option<String>) -> Result<PathBuf> { - // if running within xonsh, this will be available - if let Some(d) = xonsh_data_dir { - let mut path = PathBuf::from(d); - path.push("history_json"); - return Ok(path); - } - - // otherwise, fall back to default - let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; - - let hist_dir = base.data_dir().join("xonsh/history_json"); - if hist_dir.exists() || cfg!(test) { - Ok(hist_dir) - } else { - Err(eyre!("Could not find xonsh history files")) - } -} - -fn load_sessions(hist_dir: &Path) -> Result<Vec<HistoryData>> { - let mut sessions = vec![]; - for entry in fs::read_dir(hist_dir)? { - let p = entry?.path(); - let ext = p.extension().and_then(|e| e.to_str()); - if p.is_file() - && ext == Some("json") - && let Some(data) = load_session(&p)? - { - sessions.push(data); - } - } - Ok(sessions) -} - -fn load_session(path: &Path) -> Result<Option<HistoryData>> { - let file = File::open(path)?; - // empty files are not valid json, so we can't deserialize them - if file.metadata()?.len() == 0 { - return Ok(None); - } - - let mut hist_file: HistoryFile = serde_json::from_reader(file)?; - - // if there are commands in this session, replace the existing UUIDv4 - // with a UUIDv7 generated from the timestamp of the first command - if let Some(cmd) = hist_file.data.cmds.first() { - let seconds = cmd.ts.0.trunc() as u64; - let nanos = (cmd.ts.0.fract() * 1_000_000_000_f64) as u32; - let ts = Timestamp::from_unix(NoContext, seconds, nanos); - hist_file.data.sessionid = Uuid::new_v7(ts).to_string(); - } - Ok(Some(hist_file.data)) -} - -#[async_trait] -impl Importer for Xonsh { - const NAME: &'static str = "xonsh"; - - async fn new() -> Result<Self> { - // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH - let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); - let hist_dir = get_histdir_path(|| xonsh_hist_dir(xonsh_data_dir))?; - let sessions = load_sessions(&hist_dir)?; - let hostname = get_host_user(); - Ok(Xonsh { sessions, hostname }) - } - - async fn entries(&mut self) -> Result<usize> { - let total = self.sessions.iter().map(|s| s.cmds.len()).sum(); - Ok(total) - } - - async fn load(self, loader: &mut impl Loader) -> Result<()> { - for session in self.sessions { - for cmd in session.cmds { - let (start, end) = cmd.ts; - let ts_nanos = (start * 1_000_000_000_f64) as i128; - let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos)?; - - let duration = (end - start) * 1_000_000_000_f64; - - match cmd.rtn { - Some(exit) => { - let entry = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .exit(exit) - .command(cmd.inp.trim()) - .cwd(cmd.cwd) - .session(session.sessionid.clone()) - .hostname(self.hostname.clone()); - loader.push(entry.build().into()).await?; - } - None => { - let entry = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .command(cmd.inp.trim()) - .cwd(cmd.cwd) - .session(session.sessionid.clone()) - .hostname(self.hostname.clone()); - loader.push(entry.build().into()).await?; - } - } - } - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use time::macros::datetime; - - use super::*; - - use crate::history::History; - use crate::import::tests::TestLoader; - - #[test] - fn test_hist_dir_xonsh() { - let hist_dir = xonsh_hist_dir(Some("/home/user/xonsh_data".to_string())).unwrap(); - assert_eq!( - hist_dir, - PathBuf::from("/home/user/xonsh_data/history_json") - ); - } - - #[tokio::test] - async fn test_import() { - let dir = PathBuf::from("tests/data/xonsh"); - let sessions = load_sessions(&dir).unwrap(); - let hostname = "box:user".to_string(); - let xonsh = Xonsh { sessions, hostname }; - - let mut loader = TestLoader::default(); - xonsh.load(&mut loader).await.unwrap(); - // order in buf will depend on filenames, so sort by timestamp for consistency - loader.buf.sort_by_key(|h| h.timestamp); - for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { - assert_eq!(actual.timestamp, expected.timestamp); - assert_eq!(actual.command, expected.command); - assert_eq!(actual.cwd, expected.cwd); - assert_eq!(actual.exit, expected.exit); - assert_eq!(actual.duration, expected.duration); - assert_eq!(actual.hostname, expected.hostname); - } - } - - fn expected_hist_entries() -> [History; 4] { - [ - History::import() - .timestamp(datetime!(2024-02-6 04:17:59.478272256 +00:00:00)) - .command("echo hello world!".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(4651069) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 04:18:01.70632832 +00:00:00)) - .command("ls -l".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(21288633) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:41:31.142515968 +00:00:00)) - .command("false".to_string()) - .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) - .exit(1) - .duration(10269403) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:41:32.271584 +00:00:00)) - .command("exit".to_string()) - .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) - .exit(0) - .duration(4259347) - .hostname("box:user".to_string()) - .build() - .into(), - ] - } -} diff --git a/crates/atuin-client/src/import/xonsh_sqlite.rs b/crates/atuin-client/src/import/xonsh_sqlite.rs deleted file mode 100644 index 7d50ac84..00000000 --- a/crates/atuin-client/src/import/xonsh_sqlite.rs +++ /dev/null @@ -1,217 +0,0 @@ -use std::env; -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use futures::TryStreamExt; -use sqlx::{FromRow, Row, sqlite::SqlitePool}; -use time::OffsetDateTime; -use uuid::Uuid; -use uuid::timestamp::{Timestamp, context::NoContext}; - -use super::{Importer, Loader, get_histfile_path}; -use crate::history::History; -use crate::utils::get_host_user; - -#[derive(Debug, FromRow)] -struct HistDbEntry { - inp: String, - rtn: Option<i64>, - tsb: f64, - tse: f64, - cwd: String, - session_start: f64, -} - -impl HistDbEntry { - fn into_hist_with_hostname(self, hostname: String) -> History { - let ts_nanos = (self.tsb * 1_000_000_000_f64) as i128; - let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos).unwrap(); - - let session_ts_seconds = self.session_start.trunc() as u64; - let session_ts_nanos = (self.session_start.fract() * 1_000_000_000_f64) as u32; - let session_ts = Timestamp::from_unix(NoContext, session_ts_seconds, session_ts_nanos); - let session_id = Uuid::new_v7(session_ts).to_string(); - let duration = (self.tse - self.tsb) * 1_000_000_000_f64; - - if let Some(exit) = self.rtn { - let imported = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .exit(exit) - .command(self.inp) - .cwd(self.cwd) - .session(session_id) - .hostname(hostname); - imported.build().into() - } else { - let imported = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .command(self.inp) - .cwd(self.cwd) - .session(session_id) - .hostname(hostname); - imported.build().into() - } - } -} - -fn xonsh_db_path(xonsh_data_dir: Option<String>) -> Result<PathBuf> { - // if running within xonsh, this will be available - if let Some(d) = xonsh_data_dir { - let mut path = PathBuf::from(d); - path.push("xonsh-history.sqlite"); - return Ok(path); - } - - // otherwise, fall back to default - let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; - - let hist_file = base.data_dir().join("xonsh/xonsh-history.sqlite"); - if hist_file.exists() || cfg!(test) { - Ok(hist_file) - } else { - Err(eyre!( - "Could not find xonsh history db at: {}", - hist_file.to_string_lossy() - )) - } -} - -#[derive(Debug)] -pub struct XonshSqlite { - pool: SqlitePool, - hostname: String, -} - -#[async_trait] -impl Importer for XonshSqlite { - const NAME: &'static str = "xonsh_sqlite"; - - async fn new() -> Result<Self> { - // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH - let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); - let db_path = get_histfile_path(|| xonsh_db_path(xonsh_data_dir))?; - let connection_str = db_path.to_str().ok_or_else(|| { - eyre!( - "Invalid path for SQLite database: {}", - db_path.to_string_lossy() - ) - })?; - - let pool = SqlitePool::connect(connection_str).await?; - let hostname = get_host_user(); - Ok(XonshSqlite { pool, hostname }) - } - - async fn entries(&mut self) -> Result<usize> { - let query = "SELECT COUNT(*) FROM xonsh_history"; - let row = sqlx::query(query).fetch_one(&self.pool).await?; - let count: u32 = row.get(0); - Ok(count as usize) - } - - async fn load(self, loader: &mut impl Loader) -> Result<()> { - let query = r#" - SELECT inp, rtn, tsb, tse, cwd, - MIN(tsb) OVER (PARTITION BY sessionid) AS session_start - FROM xonsh_history - ORDER BY rowid - "#; - - let mut entries = sqlx::query_as::<_, HistDbEntry>(query).fetch(&self.pool); - - let mut count = 0; - while let Some(entry) = entries.try_next().await? { - let hist = entry.into_hist_with_hostname(self.hostname.clone()); - loader.push(hist).await?; - count += 1; - } - - println!("Loaded: {count}"); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use time::macros::datetime; - - use super::*; - - use crate::history::History; - use crate::import::tests::TestLoader; - - #[test] - fn test_db_path_xonsh() { - let db_path = xonsh_db_path(Some("/home/user/xonsh_data".to_string())).unwrap(); - assert_eq!( - db_path, - PathBuf::from("/home/user/xonsh_data/xonsh-history.sqlite") - ); - } - - #[tokio::test] - async fn test_import() { - let connection_str = "tests/data/xonsh-history.sqlite"; - let xonsh_sqlite = XonshSqlite { - pool: SqlitePool::connect(connection_str).await.unwrap(), - hostname: "box:user".to_string(), - }; - - let mut loader = TestLoader::default(); - xonsh_sqlite.load(&mut loader).await.unwrap(); - - for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { - assert_eq!(actual.timestamp, expected.timestamp); - assert_eq!(actual.command, expected.command); - assert_eq!(actual.cwd, expected.cwd); - assert_eq!(actual.exit, expected.exit); - assert_eq!(actual.duration, expected.duration); - assert_eq!(actual.hostname, expected.hostname); - } - } - - fn expected_hist_entries() -> [History; 4] { - [ - History::import() - .timestamp(datetime!(2024-02-6 17:56:21.130956288 +00:00:00)) - .command("echo hello world!".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(2628564) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:56:28.190406144 +00:00:00)) - .command("ls -l".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(9371519) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:56:46.989020928 +00:00:00)) - .command("false".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(1) - .duration(17337560) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:56:48.218384128 +00:00:00)) - .command("exit".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(4599094) - .hostname("box:user".to_string()) - .build() - .into(), - ] - } -} diff --git a/crates/atuin-client/src/import/zsh.rs b/crates/atuin-client/src/import/zsh.rs deleted file mode 100644 index 11e2f371..00000000 --- a/crates/atuin-client/src/import/zsh.rs +++ /dev/null @@ -1,230 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::borrow::Cow; -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use time::OffsetDateTime; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct Zsh { - bytes: Vec<u8>, -} - -fn default_histpath() -> Result<PathBuf> { - // oh-my-zsh sets HISTFILE=~/.zhistory - // zsh has no default value for this var, but uses ~/.zhistory. - // zsh-newuser-install propose as default .histfile https://github.com/zsh-users/zsh/blob/master/Functions/Newuser/zsh-newuser-install#L794 - // we could maybe be smarter about this in the future :) - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - let mut candidates = [".zhistory", ".zsh_history", ".histfile"].iter(); - loop { - match candidates.next() { - Some(candidate) => { - let histpath = home_dir.join(candidate); - if histpath.exists() { - break Ok(histpath); - } - } - None => { - break Err(eyre!( - "Could not find history file. Try setting and exporting $HISTFILE" - )); - } - } - } -} - -#[async_trait] -impl Importer for Zsh { - const NAME: &'static str = "zsh"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let now = OffsetDateTime::now_utc(); - let mut line = String::new(); - - let mut counter = 0; - for b in unix_byte_lines(&self.bytes) { - let s = match unmetafy(b) { - Some(s) => s, - _ => continue, // we can skip past things like invalid utf8 - }; - - if let Some(s) = s.strip_suffix('\\') { - line.push_str(s); - line.push('\n'); - } else { - line.push_str(&s); - let command = std::mem::take(&mut line); - - if let Some(command) = command.strip_prefix(": ") { - counter += 1; - h.push(parse_extended(command, counter)).await?; - } else { - let offset = time::Duration::seconds(counter); - counter += 1; - - let imported = History::import() - // preserve ordering - .timestamp(now - offset) - .command(command.trim_end().to_string()); - - h.push(imported.build().into()).await?; - } - } - } - - Ok(()) - } -} - -fn parse_extended(line: &str, counter: i64) -> History { - let (time, duration) = line.split_once(':').unwrap(); - let (duration, command) = duration.split_once(';').unwrap(); - - let time = time - .parse::<i64>() - .ok() - .and_then(|t| OffsetDateTime::from_unix_timestamp(t).ok()) - .unwrap_or_else(OffsetDateTime::now_utc) - + time::Duration::milliseconds(counter); - - // use nanos, because why the hell not? we won't display them. - let duration = duration.parse::<i64>().map_or(-1, |t| t * 1_000_000_000); - - let imported = History::import() - .timestamp(time) - .command(command.trim_end().to_string()) - .duration(duration); - - imported.build().into() -} - -fn unmetafy(line: &[u8]) -> Option<Cow<'_, str>> { - if line.contains(&0x83) { - let mut s = Vec::with_capacity(line.len()); - let mut is_meta = false; - for ch in line { - if *ch == 0x83 { - is_meta = true; - } else if is_meta { - is_meta = false; - s.push(*ch ^ 32); - } else { - s.push(*ch) - } - } - String::from_utf8(s).ok().map(Cow::Owned) - } else { - std::str::from_utf8(line).ok().map(Cow::Borrowed) - } -} - -#[cfg(test)] -mod test { - use itertools::assert_equal; - - use crate::import::tests::TestLoader; - - use super::*; - - #[test] - fn test_parse_extended_simple() { - let parsed = parse_extended("1613322469:0;cargo install atuin", 0); - - assert_eq!(parsed.command, "cargo install atuin"); - assert_eq!(parsed.duration, 0); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - - let parsed = parse_extended("1613322469:10;cargo install atuin;cargo update", 0); - - assert_eq!(parsed.command, "cargo install atuin;cargo update"); - assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - - let parsed = parse_extended("1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); - - assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); - assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - - let parsed = parse_extended("1613322469:10;cargo install \\n atuin\n", 0); - - assert_eq!(parsed.command, "cargo install \\n atuin"); - assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - } - - #[tokio::test] - async fn test_parse_file() { - let bytes = r": 1613322469:0;cargo install atuin -: 1613322469:10;cargo install atuin; \\ -cargo update -: 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ -" - .as_bytes() - .to_owned(); - - let mut zsh = Zsh { bytes }; - assert_eq!(zsh.entries().await.unwrap(), 4); - - let mut loader = TestLoader::default(); - zsh.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - [ - "cargo install atuin", - "cargo install atuin; \\\ncargo update", - "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", - ], - ); - } - - #[tokio::test] - async fn test_parse_metafied() { - let bytes = - b"echo \xe4\xbd\x83\x80\xe5\xa5\xbd\nls ~/\xe9\x83\xbf\xb3\xe4\xb9\x83\xb0\n".to_vec(); - - let mut zsh = Zsh { bytes }; - assert_eq!(zsh.entries().await.unwrap(), 2); - - let mut loader = TestLoader::default(); - zsh.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - ["echo 你好", "ls ~/音乐"], - ); - } -} diff --git a/crates/atuin-client/src/import/zsh_histdb.rs b/crates/atuin-client/src/import/zsh_histdb.rs deleted file mode 100644 index bf44c3ad..00000000 --- a/crates/atuin-client/src/import/zsh_histdb.rs +++ /dev/null @@ -1,249 +0,0 @@ -// import old shell history from zsh-histdb! -// automatically hoover up all that we can find - -// As far as i can tell there are no version numbers in the histdb sqlite DB, so we're going based -// on the schema from 2022-05-01 -// -// I have run into some histories that will not import b/c of non UTF-8 characters. -// - -// -// An Example sqlite query for hsitdb data: -// -//id|session|command_id|place_id|exit_status|start_time|duration|id|argv|id|host|dir -// -// -// select -// history.id, -// history.start_time, -// places.host, -// places.dir, -// commands.argv -// from history -// left join commands on history.command_id = commands.id -// left join places on history.place_id = places.id ; -// -// CREATE TABLE history (id integer primary key autoincrement, -// session int, -// command_id int references commands (id), -// place_id int references places (id), -// exit_status int, -// start_time int, -// duration int); -// - -use std::collections::HashMap; -use std::path::{Path, PathBuf}; - -use async_trait::async_trait; -use atuin_common::utils::uuid_v7; -use directories::UserDirs; -use eyre::{Result, eyre}; -use sqlx::{Pool, sqlite::SqlitePool}; -use time::PrimitiveDateTime; - -use super::Importer; -use crate::history::History; -use crate::import::Loader; -use crate::utils::{get_hostname, get_username}; - -#[derive(sqlx::FromRow, Debug)] -pub struct HistDbEntryCount { - pub count: usize, -} - -#[derive(sqlx::FromRow, Debug)] -pub struct HistDbEntry { - pub id: i64, - pub start_time: PrimitiveDateTime, - pub host: Vec<u8>, - pub dir: Vec<u8>, - pub argv: Vec<u8>, - pub duration: i64, - pub exit_status: i64, - pub session: i64, -} - -#[derive(Debug)] -pub struct ZshHistDb { - histdb: Vec<HistDbEntry>, - username: String, -} - -/// Read db at given file, return vector of entries. -async fn hist_from_db(dbpath: PathBuf) -> Result<Vec<HistDbEntry>> { - let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; - hist_from_db_conn(pool).await -} - -async fn hist_from_db_conn(pool: Pool<sqlx::Sqlite>) -> Result<Vec<HistDbEntry>> { - let query = r#" - SELECT - history.id, history.start_time, history.duration, places.host, places.dir, - commands.argv, history.exit_status, history.session - FROM history - LEFT JOIN commands ON history.command_id = commands.id - LEFT JOIN places ON history.place_id = places.id - ORDER BY history.start_time - "#; - let histdb_vec: Vec<HistDbEntry> = sqlx::query_as::<_, HistDbEntry>(query) - .fetch_all(&pool) - .await?; - Ok(histdb_vec) -} - -impl ZshHistDb { - pub fn histpath_candidate() -> PathBuf { - // By default histdb database is `${HOME}/.histdb/zsh-history.db` - // This can be modified by ${HISTDB_FILE} - // - // if [[ -z ${HISTDB_FILE} ]]; then - // typeset -g HISTDB_FILE="${HOME}/.histdb/zsh-history.db" - let user_dirs = UserDirs::new().unwrap(); // should catch error here? - let home_dir = user_dirs.home_dir(); - std::env::var("HISTDB_FILE") - .as_ref() - .map(|x| Path::new(x).to_path_buf()) - .unwrap_or_else(|_err| home_dir.join(".histdb/zsh-history.db")) - } - pub fn histpath() -> Result<PathBuf> { - let histdb_path = ZshHistDb::histpath_candidate(); - if histdb_path.exists() { - Ok(histdb_path) - } else { - Err(eyre!( - "Could not find history file. Try setting $HISTDB_FILE" - )) - } - } -} - -#[async_trait] -impl Importer for ZshHistDb { - // Not sure how this is used - const NAME: &'static str = "zsh_histdb"; - - /// Creates a new ZshHistDb and populates the history based on the pre-populated data - /// structure. - async fn new() -> Result<Self> { - let dbpath = ZshHistDb::histpath()?; - let histdb_entry_vec = hist_from_db(dbpath).await?; - Ok(Self { - histdb: histdb_entry_vec, - username: get_username(), - }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(self.histdb.len()) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let mut session_map = HashMap::new(); - for entry in self.histdb { - let command = match std::str::from_utf8(&entry.argv) { - Ok(s) => s.trim_end(), - Err(_) => continue, // we can skip past things like invalid utf8 - }; - let cwd = match std::str::from_utf8(&entry.dir) { - Ok(s) => s.trim_end(), - Err(_) => continue, // we can skip past things like invalid utf8 - }; - let hostname = format!( - "{}:{}", - String::from_utf8(entry.host).unwrap_or_else(|_e| get_hostname()), - self.username - ); - let session = session_map.entry(entry.session).or_insert_with(uuid_v7); - - let imported = History::import() - .timestamp(entry.start_time.assume_utc()) - .command(command) - .cwd(cwd) - .duration(entry.duration * 1_000_000_000) - .exit(entry.exit_status) - .session(session.as_simple().to_string()) - .hostname(hostname) - .build(); - h.push(imported.into()).await?; - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - - use super::*; - use sqlx::sqlite::SqlitePoolOptions; - use std::env; - #[tokio::test(flavor = "multi_thread")] - #[expect(unsafe_code)] - async fn test_env_vars() { - let test_env_db = "nonstd-zsh-history.db"; - let key = "HISTDB_FILE"; - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::set_var(key, test_env_db) }; - - // test the env got set - assert_eq!(env::var(key).unwrap(), test_env_db.to_string()); - - // test histdb returns the proper db from previous step - let histdb_path = ZshHistDb::histpath_candidate(); - assert_eq!(histdb_path.to_str().unwrap(), test_env_db); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_import() { - let pool: SqlitePool = SqlitePoolOptions::new() - .min_connections(2) - .connect(":memory:") - .await - .unwrap(); - - // sql dump directly from a test database. - let db_sql = r#" - PRAGMA foreign_keys=OFF; - BEGIN TRANSACTION; - CREATE TABLE commands (id integer primary key autoincrement, argv text, unique(argv) on conflict ignore); - INSERT INTO commands VALUES(1,'pwd'); - INSERT INTO commands VALUES(2,'curl google.com'); - INSERT INTO commands VALUES(3,'bash'); - CREATE TABLE places (id integer primary key autoincrement, host text, dir text, unique(host, dir) on conflict ignore); - INSERT INTO places VALUES(1,'mbp16.local','/home/noyez'); - CREATE TABLE history (id integer primary key autoincrement, - session int, - command_id int references commands (id), - place_id int references places (id), - exit_status int, - start_time int, - duration int); - INSERT INTO history VALUES(1,0,1,1,0,1651497918,1); - INSERT INTO history VALUES(2,0,2,1,0,1651497923,1); - INSERT INTO history VALUES(3,0,3,1,NULL,1651497930,NULL); - DELETE FROM sqlite_sequence; - INSERT INTO sqlite_sequence VALUES('commands',3); - INSERT INTO sqlite_sequence VALUES('places',3); - INSERT INTO sqlite_sequence VALUES('history',3); - CREATE INDEX hist_time on history(start_time); - CREATE INDEX place_dir on places(dir); - CREATE INDEX place_host on places(host); - CREATE INDEX history_command_place on history(command_id, place_id); - COMMIT; "#; - - sqlx::query(db_sql).execute(&pool).await.unwrap(); - - // test histdb iterator - let histdb_vec = hist_from_db_conn(pool).await.unwrap(); - let histdb = ZshHistDb { - histdb: histdb_vec, - username: get_username(), - }; - - println!("h: {:#?}", histdb.histdb); - println!("counter: {:?}", histdb.histdb.len()); - for i in histdb.histdb { - println!("{i:?}"); - } - } -} diff --git a/crates/atuin-client/src/lib.rs b/crates/atuin-client/src/lib.rs deleted file mode 100644 index cd7785e1..00000000 --- a/crates/atuin-client/src/lib.rs +++ /dev/null @@ -1,31 +0,0 @@ -#![deny(unsafe_code)] - -#[macro_use] -extern crate log; - -#[cfg(feature = "sync")] -pub mod api_client; -#[cfg(feature = "sync")] -pub mod auth; -#[cfg(feature = "sync")] -pub mod login; -#[cfg(feature = "sync")] -pub mod register; -#[cfg(feature = "sync")] -pub mod sync; - -pub mod database; -pub mod distro; -pub mod encryption; -pub mod history; -pub mod import; -pub mod logout; -pub mod meta; -pub mod ordering; -pub mod plugin; -pub mod record; -pub mod secrets; -pub mod settings; -pub mod theme; - -mod utils; diff --git a/crates/atuin-client/src/login.rs b/crates/atuin-client/src/login.rs deleted file mode 100644 index 2545e890..00000000 --- a/crates/atuin-client/src/login.rs +++ /dev/null @@ -1,68 +0,0 @@ -use std::path::PathBuf; - -use atuin_common::api::LoginRequest; -use eyre::{Context, Result, bail}; -use tokio::fs::File; -use tokio::io::AsyncWriteExt; - -use crate::{ - api_client, - encryption::{decode_key, load_key}, - record::{sqlite_store::SqliteStore, store::Store}, - settings::Settings, -}; - -pub async fn login( - settings: &Settings, - store: &SqliteStore, - username: String, - password: String, - key: String, -) -> Result<String> { - let key_path = settings.key_path.as_str(); - let key_path = PathBuf::from(key_path); - - if !key_path.exists() { - if decode_key(key.clone()).is_err() { - bail!("the specified key was invalid"); - } - - let mut file = File::create(&key_path).await?; - file.write_all(key.as_bytes()).await?; - } else { - // we now know that the user has logged in specifying a key, AND that the key path - // exists - - // 1. check if the saved key and the provided key match. if so, nothing to do. - // 2. if not, re-encrypt the local history and overwrite the key - let current_key: [u8; 32] = load_key(settings)?.into(); - - let encoded = key.clone(); // gonna want to save it in a bit - let new_key: [u8; 32] = decode_key(key) - .context("could not decode provided key - is not valid base64")? - .into(); - - if new_key != current_key { - println!("\nRe-encrypting local store with new key"); - - store.re_encrypt(¤t_key, &new_key).await?; - - println!("Writing new key"); - let mut file = File::create(&key_path).await?; - file.write_all(encoded.as_bytes()).await?; - } - } - - let session = api_client::login( - settings.sync_address.as_str(), - LoginRequest { username, password }, - ) - .await?; - - Settings::meta_store() - .await? - .save_session(&session.session) - .await?; - - Ok(session.session) -} diff --git a/crates/atuin-client/src/logout.rs b/crates/atuin-client/src/logout.rs deleted file mode 100644 index f720b302..00000000 --- a/crates/atuin-client/src/logout.rs +++ /dev/null @@ -1,16 +0,0 @@ -use eyre::Result; - -use crate::settings::Settings; - -pub async fn logout() -> Result<()> { - let meta = Settings::meta_store().await?; - - if meta.logged_in().await? { - meta.delete_session().await?; - println!("You have logged out!"); - } else { - println!("You are not logged in"); - } - - Ok(()) -} diff --git a/crates/atuin-client/src/meta.rs b/crates/atuin-client/src/meta.rs deleted file mode 100644 index 870f36d0..00000000 --- a/crates/atuin-client/src/meta.rs +++ /dev/null @@ -1,365 +0,0 @@ -use std::path::Path; -use std::str::FromStr; -use std::time::Duration; - -use atuin_common::record::HostId; -use eyre::{Result, eyre}; -use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions}; -use time::{OffsetDateTime, format_description::well_known::Rfc3339}; -use tokio::sync::OnceCell; -use uuid::Uuid; - -// Filenames for the legacy plain-text files that we migrate from. -const LEGACY_HOST_ID_FILENAME: &str = "host_id"; -const LEGACY_LAST_SYNC_FILENAME: &str = "last_sync_time"; -const LEGACY_LAST_VERSION_CHECK_FILENAME: &str = "last_version_check_time"; -const LEGACY_LATEST_VERSION_FILENAME: &str = "latest_version"; -const LEGACY_SESSION_FILENAME: &str = "session"; - -const KEY_HOST_ID: &str = "host_id"; -const KEY_LAST_SYNC: &str = "last_sync_time"; -const KEY_LAST_VERSION_CHECK: &str = "last_version_check_time"; -const KEY_LATEST_VERSION: &str = "latest_version"; -const KEY_SESSION: &str = "session"; -const KEY_FILES_MIGRATED: &str = "files_migrated"; - -pub struct MetaStore { - pool: SqlitePool, - cached_host_id: OnceCell<HostId>, -} - -impl MetaStore { - pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { - let path = path.as_ref(); - let path_str = path - .as_os_str() - .to_str() - .ok_or_else(|| eyre!("meta database path is not valid UTF-8: {path:?}"))?; - debug!("opening meta sqlite database at {path:?}"); - - let is_memory = path_str.contains(":memory:"); - - if !is_memory - && !path.exists() - && let Some(dir) = path.parent() - { - fs_err::create_dir_all(dir)?; - } - - // Use DELETE journal mode instead of WAL. This is a small, infrequently- - // written KV store — WAL's concurrency benefits aren't needed, and DELETE - // mode avoids creating auxiliary -wal/-shm files that complicate - // permission handling. - let opts = SqliteConnectOptions::from_str(path_str)? - .journal_mode(SqliteJournalMode::Delete) - .optimize_on_close(true, None) - .create_if_missing(true); - - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) - .connect_with(opts) - .await?; - - sqlx::migrate!("./meta-migrations").run(&pool).await?; - - // Session tokens are stored in this database, so restrict permissions. - #[cfg(unix)] - if !is_memory { - use std::os::unix::fs::PermissionsExt; - std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?; - } - - let store = Self { - pool, - cached_host_id: OnceCell::const_new(), - }; - - if !is_memory { - store.migrate_files().await?; - } - - Ok(store) - } - - // Generic key-value operations - - pub async fn get(&self, key: &str) -> Result<Option<String>> { - let row: Option<(String,)> = sqlx::query_as("SELECT value FROM meta WHERE key = ?1") - .bind(key) - .fetch_optional(&self.pool) - .await?; - - Ok(row.map(|r| r.0)) - } - - pub async fn set(&self, key: &str, value: &str) -> Result<()> { - sqlx::query( - "INSERT INTO meta (key, value, updated_at) VALUES (?1, ?2, strftime('%s', 'now')) - ON CONFLICT(key) DO UPDATE SET value = ?2, updated_at = strftime('%s', 'now')", - ) - .bind(key) - .bind(value) - .execute(&self.pool) - .await?; - - Ok(()) - } - - pub async fn delete(&self, key: &str) -> Result<()> { - sqlx::query("DELETE FROM meta WHERE key = ?1") - .bind(key) - .execute(&self.pool) - .await?; - - Ok(()) - } - - // Typed accessors - - pub async fn host_id(&self) -> Result<HostId> { - self.cached_host_id - .get_or_try_init(|| async { - if let Some(id) = self.get(KEY_HOST_ID).await? { - let parsed = Uuid::from_str(id.as_str()) - .map_err(|e| eyre!("failed to parse host ID: {e}"))?; - return Ok(HostId(parsed)); - } - - let uuid = atuin_common::utils::uuid_v7(); - self.set(KEY_HOST_ID, uuid.as_simple().to_string().as_ref()) - .await?; - - Ok(HostId(uuid)) - }) - .await - .copied() - } - - pub async fn last_sync(&self) -> Result<OffsetDateTime> { - match self.get(KEY_LAST_SYNC).await? { - Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), - None => Ok(OffsetDateTime::UNIX_EPOCH), - } - } - - pub async fn save_sync_time(&self) -> Result<()> { - self.set( - KEY_LAST_SYNC, - OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), - ) - .await - } - - pub async fn last_version_check(&self) -> Result<OffsetDateTime> { - match self.get(KEY_LAST_VERSION_CHECK).await? { - Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), - None => Ok(OffsetDateTime::UNIX_EPOCH), - } - } - - pub async fn save_version_check_time(&self) -> Result<()> { - self.set( - KEY_LAST_VERSION_CHECK, - OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), - ) - .await - } - - pub async fn latest_version(&self) -> Result<Option<String>> { - self.get(KEY_LATEST_VERSION).await - } - - pub async fn save_latest_version(&self, version: &str) -> Result<()> { - self.set(KEY_LATEST_VERSION, version).await - } - - pub async fn session_token(&self) -> Result<Option<String>> { - self.get(KEY_SESSION).await - } - - pub async fn save_session(&self, token: &str) -> Result<()> { - self.set(KEY_SESSION, token).await - } - - pub async fn delete_session(&self) -> Result<()> { - self.delete(KEY_SESSION).await - } - - pub async fn logged_in(&self) -> Result<bool> { - Ok(self.session_token().await?.is_some()) - } - - // File migration: on first open, migrate old plain-text files into the database. - // Old files are left in place for safe downgrades. - - async fn migrate_files(&self) -> Result<()> { - if self.get(KEY_FILES_MIGRATED).await?.is_some() { - return Ok(()); - } - - let data_dir = crate::settings::Settings::effective_data_dir(); - - // host_id — validate as UUID - let host_id_path = data_dir.join(LEGACY_HOST_ID_FILENAME); - if host_id_path.exists() - && let Ok(value) = fs_err::read_to_string(&host_id_path) - { - let value = value.trim(); - if !value.is_empty() { - if Uuid::from_str(value).is_ok() { - self.set(KEY_HOST_ID, value).await?; - } else { - warn!("skipping migration of host_id: invalid UUID {value:?}"); - } - } - } - - // last_sync_time — validate as RFC3339 - let sync_path = data_dir.join(LEGACY_LAST_SYNC_FILENAME); - if sync_path.exists() - && let Ok(value) = fs_err::read_to_string(&sync_path) - { - let value = value.trim(); - if !value.is_empty() { - if OffsetDateTime::parse(value, &Rfc3339).is_ok() { - self.set(KEY_LAST_SYNC, value).await?; - } else { - warn!("skipping migration of last_sync_time: invalid RFC3339 {value:?}"); - } - } - } - - // last_version_check_time — validate as RFC3339 - let version_check_path = data_dir.join(LEGACY_LAST_VERSION_CHECK_FILENAME); - if version_check_path.exists() - && let Ok(value) = fs_err::read_to_string(&version_check_path) - { - let value = value.trim(); - if !value.is_empty() { - if OffsetDateTime::parse(value, &Rfc3339).is_ok() { - self.set(KEY_LAST_VERSION_CHECK, value).await?; - } else { - warn!( - "skipping migration of last_version_check_time: invalid RFC3339 {value:?}" - ); - } - } - } - - // latest_version — no strict validation, just non-empty - let latest_version_path = data_dir.join(LEGACY_LATEST_VERSION_FILENAME); - if latest_version_path.exists() - && let Ok(value) = fs_err::read_to_string(&latest_version_path) - { - let value = value.trim(); - if !value.is_empty() { - self.set(KEY_LATEST_VERSION, value).await?; - } - } - - // session token — no strict validation, just non-empty - let session_path = data_dir.join(LEGACY_SESSION_FILENAME); - if session_path.exists() - && let Ok(value) = fs_err::read_to_string(&session_path) - { - let value = value.trim(); - if !value.is_empty() { - self.set(KEY_SESSION, value).await?; - } - } - - self.set(KEY_FILES_MIGRATED, "true").await?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - async fn new_test_store() -> MetaStore { - MetaStore::new("sqlite::memory:", 2.0).await.unwrap() - } - - #[tokio::test] - async fn test_get_set_delete() { - let store = new_test_store().await; - - assert_eq!(store.get("foo").await.unwrap(), None); - - store.set("foo", "bar").await.unwrap(); - assert_eq!(store.get("foo").await.unwrap(), Some("bar".to_string())); - - store.set("foo", "baz").await.unwrap(); - assert_eq!(store.get("foo").await.unwrap(), Some("baz".to_string())); - - store.delete("foo").await.unwrap(); - assert_eq!(store.get("foo").await.unwrap(), None); - } - - #[tokio::test] - async fn test_host_id_generation_and_stability() { - let store = new_test_store().await; - - let id1 = store.host_id().await.unwrap(); - let id2 = store.host_id().await.unwrap(); - - assert_eq!(id1, id2, "host_id should be stable across calls"); - } - - #[tokio::test] - async fn test_sync_time() { - let store = new_test_store().await; - - let t = store.last_sync().await.unwrap(); - assert_eq!(t, OffsetDateTime::UNIX_EPOCH); - - store.save_sync_time().await.unwrap(); - let t = store.last_sync().await.unwrap(); - assert!(t > OffsetDateTime::UNIX_EPOCH); - } - - #[tokio::test] - async fn test_version_check_time() { - let store = new_test_store().await; - - let t = store.last_version_check().await.unwrap(); - assert_eq!(t, OffsetDateTime::UNIX_EPOCH); - - store.save_version_check_time().await.unwrap(); - let t = store.last_version_check().await.unwrap(); - assert!(t > OffsetDateTime::UNIX_EPOCH); - } - - #[tokio::test] - async fn test_session_crud() { - let store = new_test_store().await; - - assert!(!store.logged_in().await.unwrap()); - assert_eq!(store.session_token().await.unwrap(), None); - - store.save_session("tok123").await.unwrap(); - assert!(store.logged_in().await.unwrap()); - assert_eq!( - store.session_token().await.unwrap(), - Some("tok123".to_string()) - ); - - store.delete_session().await.unwrap(); - assert!(!store.logged_in().await.unwrap()); - } - - #[tokio::test] - async fn test_latest_version() { - let store = new_test_store().await; - - assert_eq!(store.latest_version().await.unwrap(), None); - - store.save_latest_version("1.2.3").await.unwrap(); - assert_eq!( - store.latest_version().await.unwrap(), - Some("1.2.3".to_string()) - ); - } -} diff --git a/crates/atuin-client/src/ordering.rs b/crates/atuin-client/src/ordering.rs deleted file mode 100644 index 4e5ec84c..00000000 --- a/crates/atuin-client/src/ordering.rs +++ /dev/null @@ -1,32 +0,0 @@ -use minspan::minspan; - -use super::{history::History, settings::SearchMode}; - -pub fn reorder_fuzzy(mode: SearchMode, query: &str, res: Vec<History>) -> Vec<History> { - match mode { - SearchMode::Fuzzy => reorder(query, |x| &x.command, res), - _ => res, - } -} - -fn reorder<F, A>(query: &str, f: F, res: Vec<A>) -> Vec<A> -where - F: Fn(&A) -> &String, - A: Clone, -{ - let mut r = res.clone(); - let qvec = &query.chars().collect(); - r.sort_by_cached_key(|h| { - // TODO for fzf search we should sum up scores for each matched term - let (from, to) = match minspan::span(qvec, &(f(h).chars().collect())) { - Some(x) => x, - // this is a little unfortunate: when we are asked to match a query that is found nowhere, - // we don't want to return a None, as the comparison behaviour would put the worst matches - // at the front. therefore, we'll return a set of indices that are one larger than the longest - // possible legitimate match. This is meaningless except as a comparison. - None => (0, res.len()), - }; - 1 + to - from - }); - r -} diff --git a/crates/atuin-client/src/plugin.rs b/crates/atuin-client/src/plugin.rs deleted file mode 100644 index 6f351bf1..00000000 --- a/crates/atuin-client/src/plugin.rs +++ /dev/null @@ -1,150 +0,0 @@ -use std::collections::HashMap; - -#[derive(Debug, Clone)] -pub struct OfficialPlugin { - pub name: String, - pub description: String, - pub install_message: String, -} - -impl OfficialPlugin { - pub fn new(name: &str, description: &str, install_message: &str) -> Self { - Self { - name: name.to_string(), - description: description.to_string(), - install_message: install_message.to_string(), - } - } -} - -pub struct OfficialPluginRegistry { - plugins: HashMap<String, OfficialPlugin>, -} - -impl OfficialPluginRegistry { - pub fn new() -> Self { - let mut registry = Self { - plugins: HashMap::new(), - }; - - // Register official plugins - registry.register_official_plugins(); - - registry - } - - fn register_official_plugins(&mut self) { - // atuin-update plugin - self.plugins.insert( - "update".to_string(), - OfficialPlugin::new( - "update", - "Update atuin to the latest version", - "The 'atuin update' command is provided by the atuin-update plugin.\n\ - It is only installed if you used the install script\n \ - If you used a package manager (brew, apt, etc), please continue to use it for updates", - ), - ); - } - - pub fn get_plugin(&self, name: &str) -> Option<&OfficialPlugin> { - self.plugins.get(name) - } - - pub fn is_official_plugin(&self, name: &str) -> bool { - self.plugins.contains_key(name) - } - - pub fn get_install_message(&self, name: &str) -> Option<&str> { - self.plugins - .get(name) - .map(|plugin| plugin.install_message.as_str()) - } -} - -impl Default for OfficialPluginRegistry { - fn default() -> Self { - Self::new() - } -} - -pub struct PluginContext { - #[cfg(windows)] - _update_on_windows: Option<UpdateOnWindowsContext>, -} - -impl PluginContext { - pub fn new(_subcommand: &str) -> Self { - PluginContext { - #[cfg(windows)] - _update_on_windows: (_subcommand == "update").then(UpdateOnWindowsContext::new), - } - } -} - -impl Drop for PluginContext { - fn drop(&mut self) {} -} - -#[cfg(windows)] -struct UpdateOnWindowsContext { - initial_exe: Option<std::path::PathBuf>, -} - -#[cfg(windows)] -impl UpdateOnWindowsContext { - const OLD_FILE_NAME: &'static str = "atuin.old"; - - pub fn new() -> Self { - // Windows doesn't let you overwrite a running exe, but it lets you rename it, - // so make some room for atuin-update to install the new version. - let initial_exe = std::env::current_exe().ok().and_then(|exe| { - std::fs::rename(&exe, exe.with_file_name(Self::OLD_FILE_NAME)).ok()?; - Some(exe) - }); - - Self { initial_exe } - } -} - -#[cfg(windows)] -impl Drop for UpdateOnWindowsContext { - fn drop(&mut self) { - if let Some(exe) = &self.initial_exe - && !exe.exists() - { - // The update failed, roll back the current exe to its initial name. - std::fs::rename(exe.with_file_name(Self::OLD_FILE_NAME), exe).unwrap_or_else(|e| { - eprintln!("Failed to roll back the update, you may need to reinstall Atuin: {e}"); - }); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_registry_creation() { - let registry = OfficialPluginRegistry::new(); - assert!(registry.is_official_plugin("update")); - assert!(!registry.is_official_plugin("nonexistent")); - } - - #[test] - fn test_get_plugin() { - let registry = OfficialPluginRegistry::new(); - let plugin = registry.get_plugin("update"); - assert!(plugin.is_some()); - assert_eq!(plugin.unwrap().name, "update"); - } - - #[test] - fn test_get_install_message() { - let registry = OfficialPluginRegistry::new(); - let message = registry.get_install_message("update"); - assert!(message.is_some()); - assert!(message.unwrap().contains("atuin-update")); - } -} diff --git a/crates/atuin-client/src/record/encryption.rs b/crates/atuin-client/src/record/encryption.rs deleted file mode 100644 index 1e94d967..00000000 --- a/crates/atuin-client/src/record/encryption.rs +++ /dev/null @@ -1,373 +0,0 @@ -use atuin_common::record::{ - AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx, -}; -use base64::{Engine, engine::general_purpose}; -use eyre::{Context, Result, ensure}; -use rusty_paserk::{Key, KeyId, Local, PieWrappedKey}; -use rusty_paseto::core::{ - ImplicitAssertion, Key as DataKey, Local as LocalPurpose, Paseto, PasetoNonce, Payload, V4, -}; -use serde::{Deserialize, Serialize}; - -/// Use PASETO V4 Local encryption using the additional data as an implicit assertion. -#[expect(non_camel_case_types)] -pub struct PASETO_V4; - -/* -Why do we use a random content-encryption key? -Originally I was planning on using a derived key for encryption based on additional data. -This would be a lot more secure than using the master key directly. - -However, there's an established norm of using a random key. This scheme might be otherwise known as -- client-side encryption -- envelope encryption -- key wrapping - -A HSM (Hardware Security Module) provider, eg: AWS, Azure, GCP, or even a physical device like a YubiKey -will have some keys that they keep to themselves. These keys never leave their physical hardware. -If they never leave the hardware, then encrypting large amounts of data means giving them the data and waiting. -This is not a practical solution. Instead, generate a unique key for your data, encrypt that using your HSM -and then store that with your data. - -See - - <https://docs.aws.amazon.com/wellarchitected/latest/financial-services-industry-lens/use-envelope-encryption-with-customer-master-keys.html> - - <https://cloud.google.com/kms/docs/envelope-encryption> - - <https://learn.microsoft.com/en-us/azure/storage/blobs/client-side-encryption?tabs=dotnet#encryption-and-decryption-via-the-envelope-technique> - - <https://www.yubico.com/gb/product/yubihsm-2-fips/> - - <https://cheatsheetseries.owasp.org/cheatsheets/Cryptographic_Storage_Cheat_Sheet.html#encrypting-stored-keys> - -Why would we care? In the past we have received some requests for company solutions. If in future we can configure a -KMS service with little effort, then that would solve a lot of issues for their security team. - -Even for personal use, if a user is not comfortable with sharing keys between hosts, -GCP HSM costs $1/month and $0.03 per 10,000 key operations. Assuming an active user runs -1000 atuin records a day, that would only cost them $1 and 10 cent a month. - -Additionally, key rotations are much simpler using this scheme. Rotating a key is as simple as re-encrypting the CEK, and not the message contents. -This makes it very fast to rotate a key in bulk. - -For future reference, with asymmetric encryption, you can encrypt the CEK without the HSM's involvement, but decrypting -will need the HSM. This allows the encryption path to still be extremely fast (no network calls) but downloads/decryption -that happens in the background can make the network calls to the HSM -*/ - -impl Encryption for PASETO_V4 { - fn re_encrypt( - mut data: EncryptedData, - _ad: AdditionalData, - old_key: &[u8; 32], - new_key: &[u8; 32], - ) -> Result<EncryptedData> { - let cek = Self::decrypt_cek(data.content_encryption_key, old_key)?; - data.content_encryption_key = Self::encrypt_cek(cek, new_key); - Ok(data) - } - - fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData { - // generate a random key for this entry - // aka content-encryption-key (CEK) - let random_key = Key::<V4, Local>::new_os_random(); - - // encode the implicit assertions - let assertions = Assertions::from(ad).encode(); - - // build the payload and encrypt the token - let payload = serde_json::to_string(&AtuinPayload { - data: general_purpose::URL_SAFE_NO_PAD.encode(data.0), - }) - .expect("json encoding can't fail"); - let nonce = DataKey::<32>::try_new_random().expect("could not source from random"); - let nonce = PasetoNonce::<V4, LocalPurpose>::from(&nonce); - - let token = Paseto::<V4, LocalPurpose>::builder() - .set_payload(Payload::from(payload.as_str())) - .set_implicit_assertion(ImplicitAssertion::from(assertions.as_str())) - .try_encrypt(&random_key.into(), &nonce) - .expect("error encrypting atuin data"); - - EncryptedData { - data: token, - content_encryption_key: Self::encrypt_cek(random_key, key), - } - } - - fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result<DecryptedData> { - let token = data.data; - let cek = Self::decrypt_cek(data.content_encryption_key, key)?; - - // encode the implicit assertions - let assertions = Assertions::from(ad).encode(); - - // decrypt the payload with the footer and implicit assertions - let payload = Paseto::<V4, LocalPurpose>::try_decrypt( - &token, - &cek.into(), - None, - ImplicitAssertion::from(&*assertions), - ) - .context("could not decrypt entry")?; - - let payload: AtuinPayload = serde_json::from_str(&payload)?; - let data = general_purpose::URL_SAFE_NO_PAD.decode(payload.data)?; - Ok(DecryptedData(data)) - } -} - -impl PASETO_V4 { - fn decrypt_cek(wrapped_cek: String, key: &[u8; 32]) -> Result<Key<V4, Local>> { - let wrapping_key = Key::<V4, Local>::from_bytes(*key); - - // let wrapping_key = PasetoSymmetricKey::from(Key::from(key)); - - let AtuinFooter { kid, wpk } = serde_json::from_str(&wrapped_cek) - .context("wrapped cek did not contain the correct contents")?; - - // check that the wrapping key matches the required key to decrypt. - // In future, we could support multiple keys and use this key to - // look up the key rather than only allow one key. - // For now though we will only support the one key and key rotation will - // have to be a hard reset - let current_kid = wrapping_key.to_id(); - - ensure!( - current_kid == kid, - "attempting to decrypt with incorrect key. currently using {current_kid}, expecting {kid}" - ); - - // decrypt the random key - Ok(wpk.unwrap_key(&wrapping_key)?) - } - - fn encrypt_cek(cek: Key<V4, Local>, key: &[u8; 32]) -> String { - // aka key-encryption-key (KEK) - let wrapping_key = Key::<V4, Local>::from_bytes(*key); - - // wrap the random key so we can decrypt it later - let wrapped_cek = AtuinFooter { - wpk: cek.wrap_pie(&wrapping_key), - kid: wrapping_key.to_id(), - }; - serde_json::to_string(&wrapped_cek).expect("could not serialize wrapped cek") - } -} - -#[derive(Serialize, Deserialize)] -struct AtuinPayload { - data: String, -} - -#[derive(Serialize, Deserialize)] -/// Well-known footer claims for decrypting. This is not encrypted but is stored in the record. -/// <https://github.com/paseto-standard/paseto-spec/blob/master/docs/02-Implementation-Guide/04-Claims.md#optional-footer-claims> -struct AtuinFooter { - /// Wrapped key - wpk: PieWrappedKey<V4, Local>, - /// ID of the key which was used to wrap - kid: KeyId<V4, Local>, -} - -/// Used in the implicit assertions. This is not encrypted and not stored in the data blob. -// This cannot be changed, otherwise it breaks the authenticated encryption. -#[derive(Debug, Copy, Clone, Serialize)] -struct Assertions<'a> { - id: &'a RecordId, - idx: &'a RecordIdx, - version: &'a str, - tag: &'a str, - host: &'a HostId, -} - -impl<'a> From<AdditionalData<'a>> for Assertions<'a> { - fn from(ad: AdditionalData<'a>) -> Self { - Self { - id: ad.id, - version: ad.version, - tag: ad.tag, - host: ad.host, - idx: ad.idx, - } - } -} - -impl Assertions<'_> { - fn encode(&self) -> String { - serde_json::to_string(self).expect("could not serialize implicit assertions") - } -} - -#[cfg(test)] -mod tests { - use atuin_common::{ - record::{Host, Record}, - utils::uuid_v7, - }; - - use super::*; - - #[test] - fn round_trip() { - let key = Key::<V4, Local>::new_os_random(); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - version: "v0", - tag: "kv", - host: &HostId(uuid_v7()), - idx: &0, - }; - - let data = DecryptedData(vec![1, 2, 3, 4]); - - let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); - let decrypted = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap(); - assert_eq!(decrypted, data); - } - - #[test] - fn same_entry_different_output() { - let key = Key::<V4, Local>::new_os_random(); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - version: "v0", - tag: "kv", - host: &HostId(uuid_v7()), - idx: &0, - }; - - let data = DecryptedData(vec![1, 2, 3, 4]); - - let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); - let encrypted2 = PASETO_V4::encrypt(data, ad, &key.to_bytes()); - - assert_ne!( - encrypted.data, encrypted2.data, - "re-encrypting the same contents should have different output due to key randomization" - ); - } - - #[test] - fn cannot_decrypt_different_key() { - let key = Key::<V4, Local>::new_os_random(); - let fake_key = Key::<V4, Local>::new_os_random(); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - version: "v0", - tag: "kv", - host: &HostId(uuid_v7()), - idx: &0, - }; - - let data = DecryptedData(vec![1, 2, 3, 4]); - - let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); - let _ = PASETO_V4::decrypt(encrypted, ad, &fake_key.to_bytes()).unwrap_err(); - } - - #[test] - fn cannot_decrypt_different_id() { - let key = Key::<V4, Local>::new_os_random(); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - version: "v0", - tag: "kv", - host: &HostId(uuid_v7()), - idx: &0, - }; - - let data = DecryptedData(vec![1, 2, 3, 4]); - - let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - ..ad - }; - let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err(); - } - - #[test] - fn re_encrypt_round_trip() { - let key1 = Key::<V4, Local>::new_os_random(); - let key2 = Key::<V4, Local>::new_os_random(); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - version: "v0", - tag: "kv", - host: &HostId(uuid_v7()), - idx: &0, - }; - - let data = DecryptedData(vec![1, 2, 3, 4]); - - let encrypted1 = PASETO_V4::encrypt(data.clone(), ad, &key1.to_bytes()); - let encrypted2 = - PASETO_V4::re_encrypt(encrypted1.clone(), ad, &key1.to_bytes(), &key2.to_bytes()) - .unwrap(); - - // we only re-encrypt the content keys - assert_eq!(encrypted1.data, encrypted2.data); - assert_ne!( - encrypted1.content_encryption_key, - encrypted2.content_encryption_key - ); - - let decrypted = PASETO_V4::decrypt(encrypted2, ad, &key2.to_bytes()).unwrap(); - - assert_eq!(decrypted, data); - } - - #[test] - fn full_record_round_trip() { - let key = [0x55; 32]; - let record = Record::builder() - .id(RecordId(uuid_v7())) - .version("v0".to_owned()) - .tag("kv".to_owned()) - .host(Host::new(HostId(uuid_v7()))) - .timestamp(1687244806000000) - .data(DecryptedData(vec![1, 2, 3, 4])) - .idx(0) - .build(); - - let encrypted = record.encrypt::<PASETO_V4>(&key); - - assert!(!encrypted.data.data.is_empty()); - assert!(!encrypted.data.content_encryption_key.is_empty()); - - let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap(); - - assert_eq!(decrypted.data.0, [1, 2, 3, 4]); - } - - #[test] - fn full_record_round_trip_fail() { - let key = [0x55; 32]; - let record = Record::builder() - .id(RecordId(uuid_v7())) - .version("v0".to_owned()) - .tag("kv".to_owned()) - .host(Host::new(HostId(uuid_v7()))) - .timestamp(1687244806000000) - .data(DecryptedData(vec![1, 2, 3, 4])) - .idx(0) - .build(); - - let encrypted = record.encrypt::<PASETO_V4>(&key); - - let mut enc1 = encrypted.clone(); - enc1.host = Host::new(HostId(uuid_v7())); - let _ = enc1 - .decrypt::<PASETO_V4>(&key) - .expect_err("tampering with the host should result in auth failure"); - - let mut enc2 = encrypted; - enc2.id = RecordId(uuid_v7()); - let _ = enc2 - .decrypt::<PASETO_V4>(&key) - .expect_err("tampering with the id should result in auth failure"); - } -} diff --git a/crates/atuin-client/src/record/mod.rs b/crates/atuin-client/src/record/mod.rs deleted file mode 100644 index c40fd395..00000000 --- a/crates/atuin-client/src/record/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod encryption; -pub mod sqlite_store; -pub mod store; - -#[cfg(feature = "sync")] -pub mod sync; diff --git a/crates/atuin-client/src/record/sqlite_store.rs b/crates/atuin-client/src/record/sqlite_store.rs deleted file mode 100644 index ed51f3fd..00000000 --- a/crates/atuin-client/src/record/sqlite_store.rs +++ /dev/null @@ -1,642 +0,0 @@ -// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries. -// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index -// by tag/host - -use std::str::FromStr; -use std::{path::Path, time::Duration}; - -use async_trait::async_trait; -use eyre::{Result, eyre}; -use fs_err as fs; - -use sqlx::{ - Row, - sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, -}; - -use atuin_common::record::{ - EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus, -}; -use atuin_common::utils; -use uuid::Uuid; - -use super::encryption::PASETO_V4; -use super::store::Store; - -#[derive(Debug, Clone)] -pub struct SqliteStore { - pool: SqlitePool, -} - -impl SqliteStore { - pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { - let path = path.as_ref(); - - debug!("opening sqlite database at {path:?}"); - - if utils::broken_symlink(path) { - eprintln!( - "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." - ); - std::process::exit(1); - } - - if !path.exists() - && let Some(dir) = path.parent() - { - fs::create_dir_all(dir)?; - } - - let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? - .journal_mode(SqliteJournalMode::Wal) - .foreign_keys(true) - .create_if_missing(true); - - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) - .connect_with(opts) - .await?; - - Self::setup_db(&pool).await?; - - Ok(Self { pool }) - } - - async fn setup_db(pool: &SqlitePool) -> Result<()> { - debug!("running sqlite database setup"); - - sqlx::migrate!("./record-migrations").run(pool).await?; - - Ok(()) - } - - async fn save_raw( - tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, - r: &Record<EncryptedData>, - ) -> Result<()> { - // In sqlite, we are "limited" to i64. But that is still fine, until 2262. - sqlx::query( - "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek) - values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", - ) - .bind(r.id.0.as_hyphenated().to_string()) - .bind(r.idx as i64) - .bind(r.host.id.0.as_hyphenated().to_string()) - .bind(r.tag.as_str()) - .bind(r.timestamp as i64) - .bind(r.version.as_str()) - .bind(r.data.data.as_str()) - .bind(r.data.content_encryption_key.as_str()) - .execute(&mut **tx) - .await?; - - Ok(()) - } - - fn query_row(row: SqliteRow) -> Record<EncryptedData> { - let idx: i64 = row.get("idx"); - let timestamp: i64 = row.get("timestamp"); - - // tbh at this point things are pretty fucked so just panic - let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); - let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); - - Record { - id: RecordId(id), - idx: idx as u64, - host: Host::new(HostId(host)), - timestamp: timestamp as u64, - tag: row.get("tag"), - version: row.get("version"), - data: EncryptedData { - data: row.get("data"), - content_encryption_key: row.get("cek"), - }, - } - } - - async fn load_all(&self) -> Result<Vec<Record<EncryptedData>>> { - let res = sqlx::query("select * from store ") - .map(Self::query_row) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } -} - -#[async_trait] -impl Store for SqliteStore { - async fn push_batch( - &self, - records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, - ) -> Result<()> { - let mut tx = self.pool.begin().await?; - - for record in records { - Self::save_raw(&mut tx, record).await?; - } - - tx.commit().await?; - - Ok(()) - } - - async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> { - let res = sqlx::query("select * from store where store.id = ?1") - .bind(id.0.as_hyphenated().to_string()) - .map(Self::query_row) - .fetch_one(&self.pool) - .await?; - - Ok(res) - } - - async fn delete(&self, id: RecordId) -> Result<()> { - sqlx::query("delete from store where id = ?1") - .bind(id.0.as_hyphenated().to_string()) - .execute(&self.pool) - .await?; - - Ok(()) - } - - async fn delete_all(&self) -> Result<()> { - sqlx::query("delete from store").execute(&self.pool).await?; - - Ok(()) - } - - async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { - let res = - sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1") - .bind(host.0.as_hyphenated().to_string()) - .bind(tag) - .map(Self::query_row) - .fetch_one(&self.pool) - .await; - - match res { - Err(sqlx::Error::RowNotFound) => Ok(None), - Err(e) => Err(eyre!("an error occurred: {}", e)), - Ok(record) => Ok(Some(record)), - } - } - - async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { - self.idx(host, tag, 0).await - } - - async fn len_all(&self) -> Result<u64> { - let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store") - .fetch_one(&self.pool) - .await; - match res { - Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), - Ok(v) => Ok(v.0 as u64), - } - } - - async fn len_tag(&self, tag: &str) -> Result<u64> { - let res: Result<(i64,), sqlx::Error> = - sqlx::query_as("select count(*) from store where tag=?1") - .bind(tag) - .fetch_one(&self.pool) - .await; - match res { - Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), - Ok(v) => Ok(v.0 as u64), - } - } - - async fn len(&self, host: HostId, tag: &str) -> Result<u64> { - let last = self.last(host, tag).await?; - - if let Some(last) = last { - return Ok(last.idx + 1); - } - - return Ok(0); - } - - async fn next( - &self, - host: HostId, - tag: &str, - idx: RecordIdx, - limit: u64, - ) -> Result<Vec<Record<EncryptedData>>> { - let res = sqlx::query( - "select * from store where idx >= ?1 and host = ?2 and tag = ?3 order by idx asc limit ?4", - ) - .bind(idx as i64) - .bind(host.0.as_hyphenated().to_string()) - .bind(tag) - .bind(limit as i64) - .map(Self::query_row) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn idx( - &self, - host: HostId, - tag: &str, - idx: RecordIdx, - ) -> Result<Option<Record<EncryptedData>>> { - let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3") - .bind(idx as i64) - .bind(host.0.as_hyphenated().to_string()) - .bind(tag) - .map(Self::query_row) - .fetch_one(&self.pool) - .await; - - match res { - Err(sqlx::Error::RowNotFound) => Ok(None), - Err(e) => Err(eyre!("an error occurred: {}", e)), - Ok(v) => Ok(Some(v)), - } - } - - async fn status(&self) -> Result<RecordStatus> { - let mut status = RecordStatus::new(); - - let res: Result<Vec<(String, String, i64)>, sqlx::Error> = - sqlx::query_as("select host, tag, max(idx) from store group by host, tag") - .fetch_all(&self.pool) - .await; - - let res = match res { - Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)), - Ok(v) => v, - }; - - for i in res { - let host = HostId( - Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"), - ); - - status.set_raw(host, i.1, i.2 as u64); - } - - Ok(status) - } - - async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> { - let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc") - .bind(tag) - .map(Self::query_row) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - /// Reencrypt every single item in this store with a new key - /// Be careful - this may mess with sync. - async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> { - // Load all the records - // In memory like some of the other code here - // This will never be called in a hot loop, and only under the following circumstances - // 1. The user has logged into a new account, with a new key. They are unlikely to have a - // lot of data - // 2. The user has encountered some sort of issue, and runs a maintenance command that - // invokes this - let all = self.load_all().await?; - - let re_encrypted = all - .into_iter() - .map(|record| record.re_encrypt::<PASETO_V4>(old_key, new_key)) - .collect::<Result<Vec<_>>>()?; - - // next up, we delete all the old data and reinsert the new stuff - // do it in one transaction, so if anything fails we rollback OK - - let mut tx = self.pool.begin().await?; - - let res = sqlx::query("delete from store").execute(&mut *tx).await?; - - let rows = res.rows_affected(); - debug!("deleted {rows} rows"); - - // don't call push_batch, as it will start its own transaction - // call the underlying save_raw - - for record in re_encrypted { - Self::save_raw(&mut tx, &record).await?; - } - - tx.commit().await?; - - Ok(()) - } - - /// Verify that every record in this store can be decrypted with the current key - /// Someday maybe also check each tag/record can be deserialized, but not for now. - async fn verify(&self, key: &[u8; 32]) -> Result<()> { - let all = self.load_all().await?; - - all.into_iter() - .map(|record| record.decrypt::<PASETO_V4>(key)) - .collect::<Result<Vec<_>>>()?; - - Ok(()) - } - - /// Verify that every record in this store can be decrypted with the current key - /// Someday maybe also check each tag/record can be deserialized, but not for now. - async fn purge(&self, key: &[u8; 32]) -> Result<()> { - let all = self.load_all().await?; - - for record in all.iter() { - match record.clone().decrypt::<PASETO_V4>(key) { - Ok(_) => continue, - Err(_) => { - println!( - "Failed to decrypt {}, deleting", - record.id.0.as_hyphenated() - ); - - self.delete(record.id).await?; - } - } - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use atuin_common::{ - record::{DecryptedData, EncryptedData, Host, HostId, Record}, - utils::uuid_v7, - }; - - use crate::{ - encryption::generate_encoded_key, - record::{encryption::PASETO_V4, store::Store}, - settings::test_local_timeout, - }; - - use super::SqliteStore; - - fn test_record() -> Record<EncryptedData> { - Record::builder() - .host(Host::new(HostId(atuin_common::utils::uuid_v7()))) - .version("v1".into()) - .tag(atuin_common::utils::uuid_v7().simple().to_string()) - .data(EncryptedData { - data: "1234".into(), - content_encryption_key: "1234".into(), - }) - .idx(0) - .build() - } - - #[tokio::test] - async fn create_db() { - let db = SqliteStore::new(":memory:", test_local_timeout()).await; - - assert!( - db.is_ok(), - "db could not be created, {:?}", - db.err().unwrap() - ); - } - - #[tokio::test] - async fn push_record() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - - db.push(&record).await.expect("failed to insert record"); - } - - #[tokio::test] - async fn get_record() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - db.push(&record).await.unwrap(); - - let new_record = db.get(record.id).await.expect("failed to fetch record"); - - assert_eq!(record, new_record, "records are not equal"); - } - - #[tokio::test] - async fn last() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - db.push(&record).await.unwrap(); - - let last = db - .last(record.host.id, record.tag.as_str()) - .await - .expect("failed to get store len"); - - assert_eq!( - last.unwrap().id, - record.id, - "expected to get back the same record that was inserted" - ); - } - - #[tokio::test] - async fn first() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - db.push(&record).await.unwrap(); - - let first = db - .first(record.host.id, record.tag.as_str()) - .await - .expect("failed to get store len"); - - assert_eq!( - first.unwrap().id, - record.id, - "expected to get back the same record that was inserted" - ); - } - - #[tokio::test] - async fn len() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - db.push(&record).await.unwrap(); - - let len = db - .len(record.host.id, record.tag.as_str()) - .await - .expect("failed to get store len"); - - assert_eq!(len, 1, "expected length of 1 after insert"); - } - - #[tokio::test] - async fn len_tag() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - db.push(&record).await.unwrap(); - - let len = db - .len_tag(record.tag.as_str()) - .await - .expect("failed to get store len"); - - assert_eq!(len, 1, "expected length of 1 after insert"); - } - - #[tokio::test] - async fn len_different_tags() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - - // these have different tags, so the len should be the same - // we model multiple stores within one database - // new store = new tag = independent length - let first = test_record(); - let second = test_record(); - - db.push(&first).await.unwrap(); - db.push(&second).await.unwrap(); - - let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap(); - let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap(); - - assert_eq!(first_len, 1, "expected length of 1 after insert"); - assert_eq!(second_len, 1, "expected length of 1 after insert"); - } - - #[tokio::test] - async fn append_a_bunch() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - - let mut tail = test_record(); - db.push(&tail).await.expect("failed to push record"); - - for _ in 1..100 { - tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]); - db.push(&tail).await.unwrap(); - } - - assert_eq!( - db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), - 100, - "failed to insert 100 records" - ); - - assert_eq!( - db.len_tag(tail.tag.as_str()).await.unwrap(), - 100, - "failed to insert 100 records" - ); - } - - #[tokio::test] - async fn append_a_big_bunch() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - - let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(10000); - - let mut tail = test_record(); - records.push(tail.clone()); - - for _ in 1..10000 { - tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]); - records.push(tail.clone()); - } - - db.push_batch(records.iter()).await.unwrap(); - - assert_eq!( - db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), - 10000, - "failed to insert 10k records" - ); - } - - #[tokio::test] - async fn re_encrypt() { - let store = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let (key, _) = generate_encoded_key().unwrap(); - let data = vec![0u8, 1u8, 2u8, 3u8]; - let host_id = HostId(uuid_v7()); - - for i in 0..10 { - let record = Record::builder() - .host(Host::new(host_id)) - .version(String::from("test")) - .tag(String::from("test")) - .idx(i) - .data(DecryptedData(data.clone())) - .build(); - - let record = record.encrypt::<PASETO_V4>(&key.into()); - store - .push(&record) - .await - .expect("failed to push encrypted record"); - } - - // first, check that we can decrypt the data with the current key - let all = store.all_tagged("test").await.unwrap(); - - assert_eq!(all.len(), 10, "failed to fetch all records"); - - for record in all { - let decrypted = record.decrypt::<PASETO_V4>(&key.into()).unwrap(); - assert_eq!(decrypted.data.0, data); - } - - // reencrypt the store, then check if - // 1) it cannot be decrypted with the old key - // 2) it can be decrypted with the new key - - let (new_key, _) = generate_encoded_key().unwrap(); - store - .re_encrypt(&key.into(), &new_key.into()) - .await - .expect("failed to re-encrypt store"); - - let all = store.all_tagged("test").await.unwrap(); - - for record in all.iter() { - let decrypted = record.clone().decrypt::<PASETO_V4>(&key.into()); - assert!( - decrypted.is_err(), - "did not get error decrypting with old key after re-encrypt" - ) - } - - for record in all { - let decrypted = record.decrypt::<PASETO_V4>(&new_key.into()).unwrap(); - assert_eq!(decrypted.data.0, data); - } - - assert_eq!(store.len(host_id, "test").await.unwrap(), 10); - } -} diff --git a/crates/atuin-client/src/record/store.rs b/crates/atuin-client/src/record/store.rs deleted file mode 100644 index 49ca4968..00000000 --- a/crates/atuin-client/src/record/store.rs +++ /dev/null @@ -1,60 +0,0 @@ -use async_trait::async_trait; -use eyre::Result; - -use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; - -/// A record store stores records -/// In more detail - we tend to need to process this into _another_ format to actually query it. -/// As is, the record store is intended as the source of truth for arbitrary data, which could -/// be shell history, kvs, etc. -#[async_trait] -pub trait Store { - // Push a record - async fn push(&self, record: &Record<EncryptedData>) -> Result<()> { - self.push_batch(std::iter::once(record)).await - } - - // Push a batch of records, all in one transaction - async fn push_batch( - &self, - records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, - ) -> Result<()>; - - async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>; - - async fn delete(&self, id: RecordId) -> Result<()>; - async fn delete_all(&self) -> Result<()>; - - async fn len_all(&self) -> Result<u64>; - async fn len(&self, host: HostId, tag: &str) -> Result<u64>; - async fn len_tag(&self, tag: &str) -> Result<u64>; - - async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; - async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; - - async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()>; - async fn verify(&self, key: &[u8; 32]) -> Result<()>; - async fn purge(&self, key: &[u8; 32]) -> Result<()>; - - /// Get the next `limit` records, after and including the given index - async fn next( - &self, - host: HostId, - tag: &str, - idx: RecordIdx, - limit: u64, - ) -> Result<Vec<Record<EncryptedData>>>; - - /// Get the first record for a given host and tag - async fn idx( - &self, - host: HostId, - tag: &str, - idx: RecordIdx, - ) -> Result<Option<Record<EncryptedData>>>; - - async fn status(&self) -> Result<RecordStatus>; - - /// Get all records for a given tag - async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>; -} diff --git a/crates/atuin-client/src/record/sync.rs b/crates/atuin-client/src/record/sync.rs deleted file mode 100644 index b785b5dc..00000000 --- a/crates/atuin-client/src/record/sync.rs +++ /dev/null @@ -1,663 +0,0 @@ -// do a sync :O -use std::{cmp::Ordering, fmt::Write}; - -use eyre::Result; -use thiserror::Error; - -use super::{encryption::PASETO_V4, store::Store}; -use crate::{api_client::Client, settings::Settings}; - -use atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus}; -use indicatif::{ProgressBar, ProgressState, ProgressStyle}; - -#[derive(Error, Debug)] -pub enum SyncError { - #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] - LocalAheadOtherHost, - - #[error("an issue with the local database occurred: {msg:?}")] - LocalStoreError { msg: String }, - - #[error("something has gone wrong with the sync logic: {msg:?}")] - SyncLogicError { msg: String }, - - #[error("operational error: {msg:?}")] - OperationalError { msg: String }, - - #[error("a request to the sync server failed: {msg:?}")] - RemoteRequestError { msg: String }, - - #[error( - "the encryption key on this machine does not match the data on the server. \ - this usually means a new machine was set up without copying the existing key. \ - to fix: run `atuin key` on a machine that already syncs correctly, then run \ - `atuin store rekey <key>` on this machine with the value from the other machine" - )] - WrongKey, -} - -#[derive(Debug, Eq, PartialEq)] -pub enum Operation { - // Either upload or download until the states matches the below - Upload { - local: RecordIdx, - remote: Option<RecordIdx>, - host: HostId, - tag: String, - }, - Download { - local: Option<RecordIdx>, - remote: RecordIdx, - host: HostId, - tag: String, - }, - Noop { - host: HostId, - tag: String, - }, -} - -pub async fn build_client(settings: &Settings) -> Result<Client<'_>, SyncError> { - Client::new( - &settings.sync_address, - settings - .sync_auth_token() - .await - .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?, - settings.network_connect_timeout, - settings.network_timeout, - ) - .map_err(|e| SyncError::OperationalError { msg: e.to_string() }) -} - -pub async fn diff( - client: &Client<'_>, - store: &impl Store, -) -> Result<(Vec<Diff>, RecordStatus), SyncError> { - let local_index = store - .status() - .await - .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; - - let remote_index = client - .record_status() - .await - .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; - - let diff = local_index.diff(&remote_index); - - Ok((diff, remote_index)) -} - -// Take a diff, along with a local store, and resolve it into a set of operations. -// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download. -// In theory this could be done as a part of the diffing stage, but it's easier to reason -// about and test this way -pub async fn operations( - diffs: Vec<Diff>, - _store: &impl Store, -) -> Result<Vec<Operation>, SyncError> { - let mut operations = Vec::with_capacity(diffs.len()); - - for diff in diffs { - let op = match (diff.local, diff.remote) { - // We both have it! Could be either. Compare. - (Some(local), Some(remote)) => match local.cmp(&remote) { - Ordering::Equal => Operation::Noop { - host: diff.host, - tag: diff.tag, - }, - Ordering::Greater => Operation::Upload { - local, - remote: Some(remote), - host: diff.host, - tag: diff.tag, - }, - Ordering::Less => Operation::Download { - local: Some(local), - remote, - host: diff.host, - tag: diff.tag, - }, - }, - - // Remote has it, we don't. Gotta be download - (None, Some(remote)) => Operation::Download { - local: None, - remote, - host: diff.host, - tag: diff.tag, - }, - - // We have it, remote doesn't. Gotta be upload. - (Some(local), None) => Operation::Upload { - local, - remote: None, - host: diff.host, - tag: diff.tag, - }, - - // something is pretty fucked. - (None, None) => { - return Err(SyncError::SyncLogicError { - msg: String::from( - "diff has nothing for local or remote - (host, tag) does not exist", - ), - }); - } - }; - - operations.push(op); - } - - // sort them - purely so we have a stable testing order, and can rely on - // same input = same output - // We can sort by ID so long as we continue to use UUIDv7 or something - // with the same properties - - operations.sort_by_key(|op| match op { - Operation::Noop { host, tag } => (0, *host, tag.clone()), - - Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), - - Operation::Download { host, tag, .. } => (2, *host, tag.clone()), - }); - - Ok(operations) -} - -async fn sync_upload( - store: &impl Store, - client: &Client<'_>, - host: HostId, - tag: String, - local: RecordIdx, - remote: Option<RecordIdx>, - page_size: u64, -) -> Result<i64, SyncError> { - let remote = remote.unwrap_or(0); - let expected = local - remote; - let mut progress = 0; - - let pb = ProgressBar::new(expected); - pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") - .unwrap() - .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) - .progress_chars("#>-")); - - println!( - "Uploading {} records to {}/{}", - expected, - host.0.as_simple(), - tag - ); - - loop { - let page = store - .next(host, tag.as_str(), remote + progress, page_size) - .await - .map_err(|e| { - error!("failed to read upload page: {e:?}"); - - SyncError::LocalStoreError { msg: e.to_string() } - })?; - - if page.is_empty() { - break; - } - - client.post_records(&page).await.map_err(|e| { - error!("failed to post records: {e:?}"); - - SyncError::RemoteRequestError { msg: e.to_string() } - })?; - - progress += page.len() as u64; - pb.set_position(progress); - - if progress >= expected { - break; - } - } - - pb.finish_with_message("Uploaded records"); - - Ok(progress as i64) -} - -async fn sync_download( - store: &impl Store, - client: &Client<'_>, - host: HostId, - tag: String, - local: Option<RecordIdx>, - remote: RecordIdx, - page_size: u64, -) -> Result<Vec<RecordId>, SyncError> { - let local = local.unwrap_or(0); - let expected = remote - local; - let mut progress = 0; - let mut ret = Vec::new(); - - println!( - "Downloading {} records from {}/{}", - expected, - host.0.as_simple(), - tag - ); - - let pb = ProgressBar::new(expected); - pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") - .unwrap() - .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) - .progress_chars("#>-")); - - loop { - let page = client - .next_records(host, tag.clone(), local + progress, page_size) - .await - .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; - - if page.is_empty() { - break; - } - - store - .push_batch(page.iter()) - .await - .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; - - ret.extend(page.iter().map(|f| f.id)); - - progress += page.len() as u64; - pb.set_position(progress); - - if progress >= expected { - break; - } - } - - pb.finish_with_message("Downloaded records"); - - Ok(ret) -} - -pub async fn sync_remote( - client: &Client<'_>, - operations: Vec<Operation>, - local_store: &impl Store, - page_size: u64, -) -> Result<(i64, Vec<RecordId>), SyncError> { - let mut uploaded = 0; - let mut downloaded = Vec::new(); - - // this can totally run in parallel, but lets get it working first - for i in operations { - match i { - Operation::Upload { - host, - tag, - local, - remote, - } => { - uploaded += - sync_upload(local_store, client, host, tag, local, remote, page_size).await? - } - - Operation::Download { - host, - tag, - local, - remote, - } => { - let mut d = - sync_download(local_store, client, host, tag, local, remote, page_size).await?; - downloaded.append(&mut d) - } - - Operation::Noop { .. } => continue, - } - } - - Ok((uploaded, downloaded)) -} - -pub async fn check_encryption_key( - client: &Client<'_>, - remote_index: &RecordStatus, - encryption_key: &[u8; 32], -) -> Result<(), SyncError> { - let sample = remote_index - .hosts - .iter() - .flat_map(|(host, tags)| tags.keys().map(move |tag| (*host, tag.clone()))) - .next(); - - let Some((host, tag)) = sample else { - return Ok(()); - }; - - let records = client - .next_records(host, tag, 0, 1) - .await - .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; - - let Some(record) = records.into_iter().next() else { - return Ok(()); - }; - - record - .decrypt::<PASETO_V4>(encryption_key) - .map_err(|_| SyncError::WrongKey)?; - - Ok(()) -} - -pub async fn sync( - settings: &Settings, - store: &impl Store, - encryption_key: &[u8; 32], -) -> Result<(i64, Vec<RecordId>), SyncError> { - let client = build_client(settings).await?; - let (diff, remote_index) = diff(&client, store).await?; - - // Bail before mutating either side if the local key can't read the remote. - check_encryption_key(&client, &remote_index, encryption_key).await?; - - let operations = operations(diff, store).await?; - let (uploaded, downloaded) = sync_remote(&client, operations, store, 100).await?; - - Ok((uploaded, downloaded)) -} - -#[cfg(test)] -mod tests { - use atuin_common::record::{Diff, EncryptedData, HostId, Record}; - use pretty_assertions::assert_eq; - - use crate::{ - record::{ - encryption::PASETO_V4, - sqlite_store::SqliteStore, - store::Store, - sync::{self, Operation}, - }, - settings::test_local_timeout, - }; - - fn test_record() -> Record<EncryptedData> { - Record::builder() - .host(atuin_common::record::Host::new(HostId( - atuin_common::utils::uuid_v7(), - ))) - .version("v1".into()) - .tag(atuin_common::utils::uuid_v7().simple().to_string()) - .data(EncryptedData { - data: String::new(), - content_encryption_key: String::new(), - }) - .idx(0) - .build() - } - - // Take a list of local records, and a list of remote records. - // Return the local database, and a diff of local/remote, ready to build - // ops - async fn build_test_diff( - local_records: Vec<Record<EncryptedData>>, - remote_records: Vec<Record<EncryptedData>>, - ) -> (SqliteStore, Vec<Diff>) { - let local_store = SqliteStore::new(":memory:", test_local_timeout()) - .await - .expect("failed to open in memory sqlite"); - let remote_store = SqliteStore::new(":memory:", test_local_timeout()) - .await - .expect("failed to open in memory sqlite"); // "remote" - - for i in local_records { - local_store.push(&i).await.unwrap(); - } - - for i in remote_records { - remote_store.push(&i).await.unwrap(); - } - - let local_index = local_store.status().await.unwrap(); - let remote_index = remote_store.status().await.unwrap(); - - let diff = local_index.diff(&remote_index); - - (local_store, diff) - } - - #[tokio::test] - async fn test_basic_diff() { - // a diff where local is ahead of remote. nothing else. - - let record = test_record(); - let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await; - - assert_eq!(diff.len(), 1); - - let operations = sync::operations(diff, &store).await.unwrap(); - - assert_eq!(operations.len(), 1); - - assert_eq!( - operations[0], - Operation::Upload { - host: record.host.id, - tag: record.tag, - local: record.idx, - remote: None, - } - ); - } - - #[tokio::test] - async fn build_two_way_diff() { - // a diff where local is ahead of remote for one, and remote for - // another. One upload, one download - - let shared_record = test_record(); - let remote_ahead = test_record(); - - let local_ahead = shared_record - .append(vec![1, 2, 3]) - .encrypt::<PASETO_V4>(&[0; 32]); - - assert_eq!(local_ahead.idx, 1); - - let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store - let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store - - let (store, diff) = build_test_diff(local, remote).await; - let operations = sync::operations(diff, &store).await.unwrap(); - - assert_eq!(operations.len(), 2); - - assert_eq!( - operations, - vec![ - // Or in otherwords, local is ahead by one - Operation::Upload { - host: local_ahead.host.id, - tag: local_ahead.tag, - local: 1, - remote: Some(0), - }, - // Or in other words, remote knows of a record in an entirely new store (tag) - Operation::Download { - host: remote_ahead.host.id, - tag: remote_ahead.tag, - local: None, - remote: 0, - }, - ] - ); - } - - #[tokio::test] - async fn build_complex_diff() { - // One shared, ahead but known only by remote - // One known only by local - // One known only by remote - - let shared_record = test_record(); - let local_only = test_record(); - - let local_only_20 = test_record(); - let local_only_21 = local_only_20 - .append(vec![1, 2, 3]) - .encrypt::<PASETO_V4>(&[0; 32]); - let local_only_22 = local_only_21 - .append(vec![1, 2, 3]) - .encrypt::<PASETO_V4>(&[0; 32]); - let local_only_23 = local_only_22 - .append(vec![1, 2, 3]) - .encrypt::<PASETO_V4>(&[0; 32]); - - let remote_only = test_record(); - - let remote_only_20 = test_record(); - let remote_only_21 = remote_only_20 - .append(vec![2, 3, 2]) - .encrypt::<PASETO_V4>(&[0; 32]); - let remote_only_22 = remote_only_21 - .append(vec![2, 3, 2]) - .encrypt::<PASETO_V4>(&[0; 32]); - let remote_only_23 = remote_only_22 - .append(vec![2, 3, 2]) - .encrypt::<PASETO_V4>(&[0; 32]); - let remote_only_24 = remote_only_23 - .append(vec![2, 3, 2]) - .encrypt::<PASETO_V4>(&[0; 32]); - - let second_shared = test_record(); - let second_shared_remote_ahead = second_shared - .append(vec![1, 2, 3]) - .encrypt::<PASETO_V4>(&[0; 32]); - let second_shared_remote_ahead2 = second_shared_remote_ahead - .append(vec![1, 2, 3]) - .encrypt::<PASETO_V4>(&[0; 32]); - - let third_shared = test_record(); - let third_shared_local_ahead = third_shared - .append(vec![1, 2, 3]) - .encrypt::<PASETO_V4>(&[0; 32]); - let third_shared_local_ahead2 = third_shared_local_ahead - .append(vec![1, 2, 3]) - .encrypt::<PASETO_V4>(&[0; 32]); - - let fourth_shared = test_record(); - let fourth_shared_remote_ahead = fourth_shared - .append(vec![1, 2, 3]) - .encrypt::<PASETO_V4>(&[0; 32]); - let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead - .append(vec![1, 2, 3]) - .encrypt::<PASETO_V4>(&[0; 32]); - - let local = vec![ - shared_record.clone(), - second_shared.clone(), - third_shared.clone(), - fourth_shared.clone(), - fourth_shared_remote_ahead.clone(), - // single store, only local has it - local_only.clone(), - // bigger store, also only known by local - local_only_20.clone(), - local_only_21.clone(), - local_only_22.clone(), - local_only_23.clone(), - // another shared store, but local is ahead on this one - third_shared_local_ahead.clone(), - third_shared_local_ahead2.clone(), - ]; - - let remote = vec![ - remote_only.clone(), - remote_only_20.clone(), - remote_only_21.clone(), - remote_only_22.clone(), - remote_only_23.clone(), - remote_only_24.clone(), - shared_record.clone(), - second_shared.clone(), - third_shared.clone(), - second_shared_remote_ahead.clone(), - second_shared_remote_ahead2.clone(), - fourth_shared.clone(), - fourth_shared_remote_ahead.clone(), - fourth_shared_remote_ahead2.clone(), - ]; // remote knows about the already-synced, and one new record in a new store - - let (store, diff) = build_test_diff(local, remote).await; - let operations = sync::operations(diff, &store).await.unwrap(); - - assert_eq!(operations.len(), 7); - - let mut result_ops = vec![ - // We started with a shared record, but the remote knows of two newer records in the - // same store - Operation::Download { - local: Some(0), - remote: 2, - host: second_shared_remote_ahead.host.id, - tag: second_shared_remote_ahead.tag, - }, - // We have a shared record, local knows of the first two but not the last - Operation::Download { - local: Some(1), - remote: 2, - host: fourth_shared_remote_ahead2.host.id, - tag: fourth_shared_remote_ahead2.tag, - }, - // Remote knows of a store with a single record that local does not have - Operation::Download { - local: None, - remote: 0, - host: remote_only.host.id, - tag: remote_only.tag, - }, - // Remote knows of a store with a bunch of records that local does not have - Operation::Download { - local: None, - remote: 4, - host: remote_only_20.host.id, - tag: remote_only_20.tag, - }, - // Local knows of a record in a store that remote does not have - Operation::Upload { - local: 0, - remote: None, - host: local_only.host.id, - tag: local_only.tag, - }, - // Local knows of 4 records in a store that remote does not have - Operation::Upload { - local: 3, - remote: None, - host: local_only_20.host.id, - tag: local_only_20.tag, - }, - // Local knows of 2 more records in a shared store that remote only has one of - Operation::Upload { - local: 2, - remote: Some(0), - host: third_shared.host.id, - tag: third_shared.tag, - }, - ]; - - result_ops.sort_by_key(|op| match op { - Operation::Noop { host, tag } => (0, *host, tag.clone()), - - Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), - - Operation::Download { host, tag, .. } => (2, *host, tag.clone()), - }); - - assert_eq!(result_ops, operations); - } -} diff --git a/crates/atuin-client/src/register.rs b/crates/atuin-client/src/register.rs deleted file mode 100644 index ad077dd1..00000000 --- a/crates/atuin-client/src/register.rs +++ /dev/null @@ -1,20 +0,0 @@ -use eyre::Result; - -use crate::{api_client, settings::Settings}; - -pub async fn register_classic( - settings: &Settings, - username: String, - email: String, - password: String, -) -> Result<String> { - let session = - api_client::register(settings.sync_address.as_str(), &username, &email, &password).await?; - - let meta = Settings::meta_store().await?; - meta.save_session(&session.session).await?; - - let _key = crate::encryption::load_key(settings)?; - - Ok(session.session) -} diff --git a/crates/atuin-client/src/secrets.rs b/crates/atuin-client/src/secrets.rs deleted file mode 100644 index e8a6ab62..00000000 --- a/crates/atuin-client/src/secrets.rs +++ /dev/null @@ -1,194 +0,0 @@ -// This file will probably trigger a lot of scanners. Sorry. - -use regex::RegexSet; -use std::sync::LazyLock; - -pub enum TestValue<'a> { - Single(&'a str), - Multiple(&'a [&'a str]), -} - -/// A list of `(name, regex, test)`, where `test` should match against `regex`. -pub static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ - ( - "AWS Access Key ID", - "A[KS]IA[0-9A-Z]{16}", - TestValue::Single("AKIAIOSFODNN7EXAMPLE"), - ), - ( - "AWS Secret Access Key env var", - "AWS_SECRET_ACCESS_KEY", - TestValue::Single("AWS_SECRET_ACCESS_KEY=KEYDATA"), - ), - ( - "AWS Session Token env var", - "AWS_SESSION_TOKEN", - TestValue::Single("AWS_SESSION_TOKEN=KEYDATA"), - ), - ( - "Microsoft Azure secret access key env var", - "AZURE_.*_KEY", - TestValue::Single("export AZURE_STORAGE_ACCOUNT_KEY=KEYDATA"), - ), - ( - "Google cloud platform key env var", - "GOOGLE_SERVICE_ACCOUNT_KEY", - TestValue::Single("export GOOGLE_SERVICE_ACCOUNT_KEY=KEYDATA"), - ), - ( - "Atuin login", - r"atuin\s+login", - TestValue::Single( - "atuin login -u mycoolusername -p mycoolpassword -k \"lots of random words\"", - ), - ), - ( - "GitHub PAT (old)", - "ghp_[a-zA-Z0-9]{36}", - TestValue::Single("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"), // legit, I expired it - ), - ( - "GitHub PAT (new)", - "gh1_[A-Za-z0-9]{21}_[A-Za-z0-9]{59}|github_pat_[0-9][A-Za-z0-9]{21}_[A-Za-z0-9]{59}", - TestValue::Multiple(&[ - "gh1_1234567890abcdefghijk_1234567890abcdefghijklmnopqrstuvwxyz1234567890abcdefghijklm", - "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", // also legit, also expired - ]), - ), - ( - "GitHub OAuth Access Token", - "gho_[A-Za-z0-9]{36}", - TestValue::Single("gho_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token - ), - ( - "GitHub OAuth Access Token (user)", - "ghu_[A-Za-z0-9]{36}", - TestValue::Single("ghu_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token - ), - ( - "GitHub App Installation Access Token", - "ghs_[A-Za-z0-9._-]{36,}", - TestValue::Multiple(&[ - "ghs_1234567890abcdefghijklmnopqrstuvwx000", // not a real token - "ghs_abc-def.ghi_jklMNOP0123456789qrstuv-wxyzABCD", // new token format, fake data - ]), - ), - ( - "GitHub Refresh Token", - "ghr_[A-Za-z0-9]{76}", - TestValue::Single( - "ghr_1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx", - ), // not a real token - ), - ( - "GitHub App Installation Access Token v1", - "v1\\.[0-9A-Fa-f]{40}", - TestValue::Single("v1.1234567890abcdef1234567890abcdef12345678"), // not a real token - ), - ( - "GitLab PAT", - "glpat-[a-zA-Z0-9_]{20}", - TestValue::Single("glpat-RkE_BG5p_bbjML21WSfy"), - ), - ( - "Slack OAuth v2 bot", - "xoxb-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", - TestValue::Single("xoxb-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), - ), - ( - "Slack OAuth v2 user token", - "xoxp-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", - TestValue::Single("xoxp-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), - ), - ( - "Slack webhook", - "T[a-zA-Z0-9_]{8}/B[a-zA-Z0-9_]{8}/[a-zA-Z0-9_]{24}", - TestValue::Single( - "https://hooks.slack.com/services/T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX", - ), - ), - ( - "Stripe test key", - "sk_test_[0-9a-zA-Z]{24}", - TestValue::Single("sk_test_1234567890abcdefghijklmnop"), - ), - ( - "Stripe live key", - "sk_live_[0-9a-zA-Z]{24}", - TestValue::Single("sk_live_1234567890abcdefghijklmnop"), - ), - ( - "Netlify authentication token", - "nf[pcoub]_[0-9a-zA-Z]{36}", - TestValue::Single("nfp_nBh7BdJxUwyaBBwFzpyD29MMFT6pZ9wq5634"), - ), - ( - "npm token", - "npm_[A-Za-z0-9]{36}", - TestValue::Single("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"), - ), - ( - "Pulumi personal access token", - "pul-[0-9a-f]{40}", - TestValue::Single("pul-683c2770662c51d960d72ec27613be7653c5cb26"), - ), -]; - -/// The `regex` expressions from [`SECRET_PATTERNS`] compiled into a `RegexSet`. -pub static SECRET_PATTERNS_RE: LazyLock<RegexSet> = LazyLock::new(|| { - let exprs = SECRET_PATTERNS.iter().map(|f| f.1); - RegexSet::new(exprs).expect("Failed to build secrets regex") -}); - -#[cfg(test)] -mod tests { - use regex::Regex; - - use crate::secrets::{SECRET_PATTERNS, TestValue}; - - #[test] - fn test_secrets() { - for (name, regex, test) in SECRET_PATTERNS { - let re = - Regex::new(regex).unwrap_or_else(|_| panic!("Failed to compile regex for {name}")); - - match test { - TestValue::Single(test) => { - assert!(re.is_match(test), "{name} test failed!"); - } - TestValue::Multiple(tests) => { - for test_str in tests.iter() { - assert!( - re.is_match(test_str), - "{name} test with value \"{test_str}\" failed!" - ); - } - } - } - } - } - - #[test] - fn test_secrets_embedded() { - for (name, regex, test) in SECRET_PATTERNS { - let re = - Regex::new(regex).unwrap_or_else(|_| panic!("Failed to compile regex for {name}")); - - match test { - TestValue::Single(test) => { - let embedded = format!("some random text {test} some more random text"); - assert!(re.is_match(&embedded), "{name} embedded test failed!"); - } - TestValue::Multiple(tests) => { - for test_str in tests.iter() { - let embedded = format!("some random text {test_str} some more random text"); - assert!( - re.is_match(&embedded), - "{name} embedded test with value \"{test_str}\" failed!" - ); - } - } - } - } - } -} diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs deleted file mode 100644 index 5fb65c17..00000000 --- a/crates/atuin-client/src/settings.rs +++ /dev/null @@ -1,1855 +0,0 @@ -use std::{collections::HashMap, fmt, io::prelude::*, path::PathBuf, str::FromStr, sync::OnceLock}; -use tokio::sync::OnceCell; - -use atuin_common::record::HostId; -use atuin_common::utils; -use clap::ValueEnum; -use config::{ - Config, ConfigBuilder, Environment, File as ConfigFile, FileFormat, builder::DefaultState, -}; -use eyre::{Context, Error, Result, bail, eyre}; -use fs_err::{File, create_dir_all}; -use humantime::parse_duration; -use regex::RegexSet; -use serde::{Deserialize, Serialize}; -use serde_with::DeserializeFromStr; -use time::{OffsetDateTime, UtcOffset, format_description::FormatItem, macros::format_description}; - -pub const HISTORY_PAGE_SIZE: i64 = 100; -static EXAMPLE_CONFIG: &str = include_str!("../config.toml"); - -static DATA_DIR: OnceLock<PathBuf> = OnceLock::new(); -static META_CONFIG: OnceLock<(String, f64)> = OnceLock::new(); -static META_STORE: OnceCell<crate::meta::MetaStore> = OnceCell::const_new(); - -pub(crate) mod meta; -pub mod watcher; - -/// Default sync address for Atuin's hosted service -pub const DEFAULT_SYNC_ADDRESS: &str = "https://api.atuin.sh"; - -#[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq, Serialize)] -pub enum SearchMode { - #[serde(rename = "prefix")] - Prefix, - - #[serde(rename = "fulltext")] - #[clap(aliases = &["fulltext"])] - FullText, - - #[serde(rename = "fuzzy")] - Fuzzy, - - #[serde(rename = "skim")] - Skim, - - #[serde(rename = "daemon-fuzzy")] - #[clap(aliases = &["daemon-fuzzy"])] - DaemonFuzzy, -} - -impl SearchMode { - pub fn as_str(&self) -> &'static str { - match self { - SearchMode::Prefix => "PREFIX", - SearchMode::FullText => "FULLTXT", - SearchMode::Fuzzy => "FUZZY", - SearchMode::Skim => "SKIM", - SearchMode::DaemonFuzzy => "DAEMON", - } - } - pub fn next(&self, settings: &Settings) -> Self { - match self { - SearchMode::Prefix => SearchMode::FullText, - // if the user is using skim, we go to skim - SearchMode::FullText if settings.search_mode == SearchMode::Skim => SearchMode::Skim, - // if the user is using daemon-fuzzy, we go to daemon-fuzzy - SearchMode::FullText if settings.search_mode == SearchMode::DaemonFuzzy => { - SearchMode::DaemonFuzzy - } - // otherwise fuzzy. - SearchMode::FullText => SearchMode::Fuzzy, - SearchMode::Fuzzy | SearchMode::Skim | SearchMode::DaemonFuzzy => SearchMode::Prefix, - } - } -} - -#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] -pub enum FilterMode { - #[serde(rename = "global")] - Global = 0, - - #[serde(rename = "host")] - Host = 1, - - #[serde(rename = "session")] - Session = 2, - - #[serde(rename = "directory")] - Directory = 3, - - #[serde(rename = "workspace")] - Workspace = 4, - - #[serde(rename = "session-preload")] - SessionPreload = 5, -} - -impl FilterMode { - pub fn as_str(&self) -> &'static str { - match self { - FilterMode::Global => "GLOBAL", - FilterMode::Host => "HOST", - FilterMode::Session => "SESSION", - FilterMode::Directory => "DIRECTORY", - FilterMode::Workspace => "WORKSPACE", - FilterMode::SessionPreload => "SESSION+", - } - } -} - -#[derive(Clone, Debug, Deserialize, Copy, Serialize)] -pub enum ExitMode { - #[serde(rename = "return-original")] - ReturnOriginal, - - #[serde(rename = "return-query")] - ReturnQuery, -} - -// FIXME: Can use upstream Dialect enum if https://github.com/stevedonovan/chrono-english/pull/16 is merged -// FIXME: Above PR was merged, but dependency was changed to interim (fork of chrono-english) in the ... interim -#[derive(Clone, Debug, Deserialize, Copy, Serialize)] -pub enum Dialect { - #[serde(rename = "us")] - Us, - - #[serde(rename = "uk")] - Uk, -} - -impl From<Dialect> for interim::Dialect { - fn from(d: Dialect) -> interim::Dialect { - match d { - Dialect::Uk => interim::Dialect::Uk, - Dialect::Us => interim::Dialect::Us, - } - } -} - -/// Type wrapper around `time::UtcOffset` to support a wider variety of timezone formats. -/// -/// Note that the parsing of this struct needs to be done before starting any -/// multithreaded runtime, otherwise it will fail on most Unix systems. -/// -/// See: <https://github.com/atuinsh/atuin/pull/1517#discussion_r1447516426> -#[derive(Clone, Copy, Debug, Eq, PartialEq, DeserializeFromStr, Serialize)] -pub struct Timezone(pub UtcOffset); -impl fmt::Display for Timezone { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} -/// format: <+|-><hour>[:<minute>[:<second>]] -static OFFSET_FMT: &[FormatItem<'_>] = format_description!( - "[offset_hour sign:mandatory padding:none][optional [:[offset_minute padding:none][optional [:[offset_second padding:none]]]]]" -); -impl FromStr for Timezone { - type Err = Error; - - fn from_str(s: &str) -> Result<Self> { - // local timezone - if matches!(s.to_lowercase().as_str(), "l" | "local") { - // There have been some timezone issues, related to errors fetching it on some - // platforms - // Rather than fail to start, fallback to UTC. The user should still be able to specify - // their timezone manually in the config file. - let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); - return Ok(Self(offset)); - } - - if matches!(s.to_lowercase().as_str(), "0" | "utc") { - let offset = UtcOffset::UTC; - return Ok(Self(offset)); - } - - // offset from UTC - if let Ok(offset) = UtcOffset::parse(s, OFFSET_FMT) { - return Ok(Self(offset)); - } - - // IDEA: Currently named timezones are not supported, because the well-known crate - // for this is `chrono_tz`, which is not really interoperable with the datetime crate - // that we currently use - `time`. If ever we migrate to using `chrono`, this would - // be a good feature to add. - - bail!(r#""{s}" is not a valid timezone spec"#) - } -} - -#[derive(Clone, Debug, Deserialize, Copy, Serialize)] -pub enum Style { - #[serde(rename = "auto")] - Auto, - - #[serde(rename = "full")] - Full, - - #[serde(rename = "compact")] - Compact, -} - -#[derive(Clone, Debug, Deserialize, Copy, Serialize)] -pub enum WordJumpMode { - #[serde(rename = "emacs")] - Emacs, - - #[serde(rename = "subl")] - Subl, -} - -#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] -pub enum KeymapMode { - #[serde(rename = "emacs")] - Emacs, - - #[serde(rename = "vim-normal")] - VimNormal, - - #[serde(rename = "vim-insert")] - VimInsert, - - #[serde(rename = "auto")] - Auto, -} - -impl KeymapMode { - pub fn as_str(&self) -> &'static str { - match self { - KeymapMode::Emacs => "EMACS", - KeymapMode::VimNormal => "VIMNORMAL", - KeymapMode::VimInsert => "VIMINSERT", - KeymapMode::Auto => "AUTO", - } - } -} - -// We want to translate the config to crossterm::cursor::SetCursorStyle, but -// the original type does not implement trait serde::Deserialize unfortunately. -// It seems impossible to implement Deserialize for external types when it is -// used in HashMap (https://stackoverflow.com/questions/67142663). We instead -// define an adapter type. -#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] -pub enum CursorStyle { - #[serde(rename = "default")] - DefaultUserShape, - - #[serde(rename = "blink-block")] - BlinkingBlock, - - #[serde(rename = "steady-block")] - SteadyBlock, - - #[serde(rename = "blink-underline")] - BlinkingUnderScore, - - #[serde(rename = "steady-underline")] - SteadyUnderScore, - - #[serde(rename = "blink-bar")] - BlinkingBar, - - #[serde(rename = "steady-bar")] - SteadyBar, -} - -impl CursorStyle { - pub fn as_str(&self) -> &'static str { - match self { - CursorStyle::DefaultUserShape => "DEFAULT", - CursorStyle::BlinkingBlock => "BLINKBLOCK", - CursorStyle::SteadyBlock => "STEADYBLOCK", - CursorStyle::BlinkingUnderScore => "BLINKUNDERLINE", - CursorStyle::SteadyUnderScore => "STEADYUNDERLINE", - CursorStyle::BlinkingBar => "BLINKBAR", - CursorStyle::SteadyBar => "STEADYBAR", - } - } -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Stats { - #[serde(default = "Stats::common_prefix_default")] - pub common_prefix: Vec<String>, // sudo, etc. commands we want to strip off - #[serde(default = "Stats::common_subcommands_default")] - pub common_subcommands: Vec<String>, // kubectl, commands we should consider subcommands for - #[serde(default = "Stats::ignored_commands_default")] - pub ignored_commands: Vec<String>, // cd, ls, etc. commands we want to completely hide from stats -} - -impl Stats { - fn common_prefix_default() -> Vec<String> { - vec!["sudo", "doas"].into_iter().map(String::from).collect() - } - - fn common_subcommands_default() -> Vec<String> { - vec![ - "apt", - "cargo", - "composer", - "dnf", - "docker", - "dotnet", - "git", - "go", - "ip", - "jj", - "kubectl", - "nix", - "nmcli", - "npm", - "pecl", - "pnpm", - "podman", - "port", - "systemctl", - "tmux", - "yarn", - ] - .into_iter() - .map(String::from) - .collect() - } - - fn ignored_commands_default() -> Vec<String> { - vec![] - } -} - -impl Default for Stats { - fn default() -> Self { - Self { - common_prefix: Self::common_prefix_default(), - common_subcommands: Self::common_subcommands_default(), - ignored_commands: Self::ignored_commands_default(), - } - } -} - -/// Sync protocol type for authentication. -/// -/// This setting is primarily for development/testing. When not explicitly set, -/// the protocol is inferred from the sync_address: -/// - Default sync address (api.atuin.sh) → Hub protocol -/// - Custom sync address → Legacy protocol -/// -/// Set explicitly to "hub" to use Hub authentication with a custom sync_address -/// (useful for local development against a Hub instance). -#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize, Default)] -#[serde(rename_all = "lowercase")] -pub enum SyncProtocol { - /// Use legacy CLI authentication (Token from CLI register/login) - #[default] - Legacy, -} - -/// Resolved authentication state for sync operations. -/// -/// Determined at runtime by examining which tokens are available and what -/// server the client is configured to talk to. Operations use this to pick -/// the right auth header and endpoint style. -#[cfg(feature = "sync")] -#[derive(Debug, Clone)] -pub enum SyncAuth { - /// Self-hosted Rust server. Uses `Authorization: Token <session>` and - /// legacy endpoints. - Legacy { token: String }, - - /// Not authenticated at all. Contains an actionable user-facing message. - NotLoggedIn { reason: String }, -} - -#[cfg(feature = "sync")] -impl SyncAuth { - /// Convert into the auth token type used by the API client. - /// - /// Returns an error with an actionable message for `NotLoggedIn`. - pub fn into_auth_token(self) -> Result<crate::api_client::AuthToken> { - use crate::api_client::AuthToken; - match self { - SyncAuth::Legacy { token } => Ok(AuthToken::Token(token)), - SyncAuth::NotLoggedIn { reason } => Err(eyre!(reason)), - } - } -} - -#[derive(Clone, Debug, Deserialize, Default, Serialize)] -pub struct Keys { - pub scroll_exits: bool, - pub exit_past_line_start: bool, - pub accept_past_line_end: bool, - pub accept_past_line_start: bool, - pub accept_with_backspace: bool, - pub prefix: String, -} - -impl Keys { - /// The standard default values for all `[keys]` options. - /// These match the config defaults set in `builder_with_data_dir()`. - pub fn standard_defaults() -> Self { - Keys { - scroll_exits: true, - exit_past_line_start: true, - accept_past_line_end: true, - accept_past_line_start: false, - accept_with_backspace: false, - prefix: "a".to_string(), - } - } - - /// Returns true if any value differs from the standard defaults. - pub fn has_non_default_values(&self) -> bool { - let d = Self::standard_defaults(); - self.scroll_exits != d.scroll_exits - || self.exit_past_line_start != d.exit_past_line_start - || self.accept_past_line_end != d.accept_past_line_end - || self.accept_past_line_start != d.accept_past_line_start - || self.accept_with_backspace != d.accept_with_backspace - || self.prefix != d.prefix - } -} - -/// A single rule within a conditional keybinding config. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct KeyRuleConfig { - /// Optional condition expression (e.g. "cursor-at-start", "input-empty && no-results"). - /// If absent, the rule always matches. - #[serde(default)] - pub when: Option<String>, - /// The action to perform (e.g. "exit", "cursor-left", "accept"). - pub action: String, -} - -/// A keybinding config value: either a simple action string or an ordered list of conditional rules. -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(untagged)] -pub enum KeyBindingConfig { - /// Simple unconditional binding: `"ctrl-c" = "return-original"` - Simple(String), - /// Conditional binding: `"left" = [{ when = "cursor-at-start", action = "exit" }, { action = "cursor-left" }]` - Rules(Vec<KeyRuleConfig>), -} - -/// User-facing keymap configuration. Each mode maps key strings to bindings. -/// Keys present here override the defaults for that key; unmentioned keys keep defaults. -#[derive(Clone, Debug, Deserialize, Serialize, Default)] -pub struct KeymapConfig { - #[serde(default)] - pub emacs: HashMap<String, KeyBindingConfig>, - #[serde(default, rename = "vim-normal")] - pub vim_normal: HashMap<String, KeyBindingConfig>, - #[serde(default, rename = "vim-insert")] - pub vim_insert: HashMap<String, KeyBindingConfig>, - #[serde(default)] - pub inspector: HashMap<String, KeyBindingConfig>, - #[serde(default)] - pub prefix: HashMap<String, KeyBindingConfig>, -} - -impl KeymapConfig { - /// Returns true if no keybinding overrides are configured in any mode. - pub fn is_empty(&self) -> bool { - self.emacs.is_empty() - && self.vim_normal.is_empty() - && self.vim_insert.is_empty() - && self.inspector.is_empty() - && self.prefix.is_empty() - } -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Preview { - pub strategy: PreviewStrategy, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Theme { - /// Name of desired theme ("default" for base) - pub name: String, - - /// Whether any available additional theme debug should be shown - pub debug: Option<bool>, - - /// How many levels of parenthood will be traversed if needed - pub max_depth: Option<u8>, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Daemon { - /// Use the daemon to sync - /// If enabled, history hooks are routed through the daemon. - #[serde(alias = "enable")] - pub enabled: bool, - - /// Automatically start and manage a local daemon when needed. - pub autostart: bool, - - /// The daemon will handle sync on an interval. How often to sync, in seconds. - pub sync_frequency: u64, - - /// The path to the unix socket used by the daemon - pub socket_path: String, - - /// Path to the daemon pidfile used for process coordination. - pub pidfile_path: String, - - /// Use a socket passed via systemd's socket activation protocol, instead of the path - pub systemd_socket: bool, - - /// The port that should be used for TCP on non unix systems - pub tcp_port: u64, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Search { - /// The list of enabled filter modes, in order of priority. - pub filters: Vec<FilterMode>, - - /// The recency score multiplier for the search index (default: 1.0). - /// Values < 1.0 reduce weight, > 1.0 increase weight, 0.0 disables. - pub recency_score_multiplier: f64, - - /// The frequency score multiplier for the search index (default: 1.0). - /// Values < 1.0 reduce weight, > 1.0 increase weight, 0.0 disables. - pub frequency_score_multiplier: f64, - - /// The overall frecency score multiplier for the search index (default: 1.0). - /// Applied after combining recency and frequency scores. - pub frecency_score_multiplier: f64, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Tmux { - /// Enable using atuin with tmux popup (tmux >= 3.2) - pub enabled: bool, - - /// Width of the tmux popup (percentage) - pub width: String, - - /// Height of the tmux popup (percentage) - pub height: String, -} - -/// Log level for file logging. Maps to tracing's LevelFilter. -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)] -#[serde(rename_all = "lowercase")] -pub enum LogLevel { - Trace, - Debug, - #[default] - Info, - Warn, - Error, -} - -impl LogLevel { - /// Convert to a tracing directive string for use with EnvFilter. - pub fn as_directive(&self) -> &'static str { - match self { - LogLevel::Trace => "trace", - LogLevel::Debug => "debug", - LogLevel::Info => "info", - LogLevel::Warn => "warn", - LogLevel::Error => "error", - } - } -} - -/// Configuration for a specific log type (search or daemon). -#[derive(Clone, Debug, Default, Deserialize, Serialize)] -pub struct LogConfig { - /// Log file name (relative to dir) or absolute path. - pub file: String, - - /// Override global enabled setting for this log type. - pub enabled: Option<bool>, - - /// Override global level setting for this log type. - pub level: Option<LogLevel>, - - /// Override global retention days setting for this log type. - pub retention: Option<u64>, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Logs { - /// Enable file logging globally. Defaults to true. - #[serde(default = "Logs::default_enabled")] - pub enabled: bool, - - /// Directory for log files. Defaults to ~/.atuin/logs - pub dir: String, - - /// Default log level for file logging. Defaults to "info". - /// Note: ATUIN_LOG environment variable overrides this. - #[serde(default)] - pub level: LogLevel, - - /// Default retention days for log files. Defaults to 4. - #[serde(default = "Logs::default_retention")] - pub retention: u64, - - /// Search log settings - #[serde(default)] - pub search: LogConfig, - - /// Daemon log settings - #[serde(default)] - pub daemon: LogConfig, - - /// AI log settings - #[serde(default)] - pub ai: LogConfig, -} - -#[derive(Default, Clone, Debug, Deserialize, Serialize)] -pub struct Ai { - /// Whether or not the AI features are enabled. - pub enabled: Option<bool>, - - /// The address of the Atuin AI endpoint. Used for AI features like command generation. - /// Only necessary for custom AI endpoints. - pub endpoint: Option<String>, - - /// The API token for the Atuin AI endpoint. Used for AI features like command generation. - /// Only necessary for custom AI endpoints. - pub api_token: Option<String>, - - /// Path to the AI sessions database. - pub db_path: String, - - /// The maximum time in minutes that an AI session can be automatically resumed. - pub session_continue_minutes: i64, - - /// Deprecated: use opening.send_cwd instead. Kept for backwards compatibility. - #[serde(default)] - pub send_cwd: Option<bool>, - - /// Configuration for what context is sent in the opening AI request. - #[serde(default)] - pub opening: AiOpening, - - /// Tool capability flags. - #[serde(default)] - pub capabilities: AiCapabilities, -} - -#[derive(Default, Clone, Debug, Deserialize, Serialize)] -pub struct AiCapabilities { - /// Whether the AI can request to search Atuin history. `None` = unset (defaults to enabled, and the ai will ask for permission). - pub enable_history_search: Option<bool>, - /// Whether the AI can request to view the stored output, if any, for Atuin history entries. `None` = unset (defaults to enabled, and the ai will ask for permission). - pub enable_history_output: Option<bool>, - /// Whether the AI can request to read and write files. `None` = unset (defaults to enabled, and the ai will ask for permission). - pub enable_file_tools: Option<bool>, - /// Whether the AI can request to execute bash commands. `None` = unset (defaults to enabled, and the ai will ask for permission). - pub enable_command_execution: Option<bool>, -} - -#[derive(Default, Clone, Debug, Deserialize, Serialize)] -pub struct AiOpening { - /// Whether or not to send the current working directory to the AI endpoint. - pub send_cwd: Option<bool>, - - /// Whether or not to send the last command as context in the opening AI request. - pub send_last_command: Option<bool>, -} - -impl Default for Preview { - fn default() -> Self { - Self { - strategy: PreviewStrategy::Auto, - } - } -} - -impl Default for Theme { - fn default() -> Self { - Self { - name: "".to_string(), - debug: None::<bool>, - max_depth: Some(10), - } - } -} - -impl Default for Daemon { - fn default() -> Self { - Self { - enabled: false, - autostart: false, - sync_frequency: 300, - socket_path: "".to_string(), - pidfile_path: "".to_string(), - systemd_socket: false, - tcp_port: 8889, - } - } -} - -impl Default for Logs { - fn default() -> Self { - Self { - enabled: true, - dir: "".to_string(), - level: LogLevel::default(), - retention: Self::default_retention(), - search: LogConfig { - file: "search.log".to_string(), - ..Default::default() - }, - daemon: LogConfig { - file: "daemon.log".to_string(), - ..Default::default() - }, - ai: LogConfig { - file: "ai.log".to_string(), - ..Default::default() - }, - } - } -} - -impl Logs { - fn default_enabled() -> bool { - true - } - - fn default_retention() -> u64 { - 4 - } - - /// Returns whether search logging is enabled. - /// Uses search-specific setting if set, otherwise falls back to global. - pub fn search_enabled(&self) -> bool { - self.search.enabled.unwrap_or(self.enabled) - } - - /// Returns whether daemon logging is enabled. - /// Uses daemon-specific setting if set, otherwise falls back to global. - pub fn daemon_enabled(&self) -> bool { - self.daemon.enabled.unwrap_or(self.enabled) - } - - /// Returns whether AI logging is enabled. - /// Uses AI-specific setting if set, otherwise falls back to global. - pub fn ai_enabled(&self) -> bool { - self.ai.enabled.unwrap_or(self.enabled) - } - - /// Returns the log level for search logging. - /// Uses search-specific setting if set, otherwise falls back to global. - pub fn search_level(&self) -> LogLevel { - self.search.level.unwrap_or(self.level) - } - - /// Returns the log level for daemon logging. - /// Uses daemon-specific setting if set, otherwise falls back to global. - pub fn daemon_level(&self) -> LogLevel { - self.daemon.level.unwrap_or(self.level) - } - - /// Returns the log level for AI logging. - /// Uses AI-specific setting if set, otherwise falls back to global. - pub fn ai_level(&self) -> LogLevel { - self.ai.level.unwrap_or(self.level) - } - - /// Returns the retention days for search logging. - /// Uses search-specific setting if set, otherwise falls back to global. - pub fn search_retention(&self) -> u64 { - self.search.retention.unwrap_or(self.retention) - } - - /// Returns the retention days for daemon logging. - /// Uses daemon-specific setting if set, otherwise falls back to global. - pub fn daemon_retention(&self) -> u64 { - self.daemon.retention.unwrap_or(self.retention) - } - - /// Returns the retention days for AI logging. - /// Uses AI-specific setting if set, otherwise falls back to global. - pub fn ai_retention(&self) -> u64 { - self.ai.retention.unwrap_or(self.retention) - } - - /// Returns the full path for the search log file. - pub fn search_path(&self) -> PathBuf { - let path = PathBuf::from(&self.search.file); - PathBuf::from(&self.dir).join(path) - } - - /// Returns the full path for the daemon log file. - pub fn daemon_path(&self) -> PathBuf { - let path = PathBuf::from(&self.daemon.file); - PathBuf::from(&self.dir).join(path) - } - - /// Returns the full path for the AI log file. - pub fn ai_path(&self) -> PathBuf { - let path = PathBuf::from(&self.ai.file); - PathBuf::from(&self.dir).join(path) - } -} - -impl Default for Search { - fn default() -> Self { - Self { - filters: vec![ - FilterMode::Global, - FilterMode::Host, - FilterMode::Session, - FilterMode::SessionPreload, - FilterMode::Workspace, - FilterMode::Directory, - ], - - recency_score_multiplier: 1.0, - frequency_score_multiplier: 1.0, - frecency_score_multiplier: 1.0, - } - } -} - -impl Default for Tmux { - fn default() -> Self { - Self { - enabled: false, - width: "80%".to_string(), - height: "60%".to_string(), - } - } -} - -// The preview height strategy also takes max_preview_height into account. -#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] -pub enum PreviewStrategy { - // Preview height is calculated for the length of the selected command. - #[serde(rename = "auto")] - Auto, - - // Preview height is calculated for the length of the longest command stored in the history. - #[serde(rename = "static")] - Static, - - // max_preview_height is used as fixed height. - #[serde(rename = "fixed")] - Fixed, -} - -/// Column types available for the interactive search UI. -#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize)] -#[serde(rename_all = "lowercase")] -pub enum UiColumnType { - /// Command execution duration (e.g., "123ms") - Duration, - /// Relative time since execution (e.g., "59s ago") - Time, - /// Absolute timestamp (e.g., "2025-01-22 14:35") - Datetime, - /// Working directory - Directory, - /// Hostname - Host, - /// Username - User, - /// Exit code - Exit, - /// The command itself (should be last, expands to fill) - Command, -} - -impl UiColumnType { - /// Returns the default width for this column type (in characters). - /// The Command column returns 0 as it expands to fill remaining space. - pub fn default_width(&self) -> u16 { - match self { - UiColumnType::Duration => 5, // "814ms" - UiColumnType::Time => 9, // "459ms ago" - UiColumnType::Datetime => 16, // "2025-01-22 14:35" - UiColumnType::Directory => 20, - UiColumnType::Host => 15, - UiColumnType::User => 10, - UiColumnType::Exit => { - if cfg!(windows) { - 11 // 32-bit integer on Windows: "-1978335212" - } else { - 3 // Usually a byte on Unix - } - } - UiColumnType::Command => 0, // Expands to fill - } - } -} - -/// A column configuration with type and optional custom width. -/// Can be specified as just a string (uses default width) or as an object with type and width. -#[derive(Clone, Debug, Serialize)] -pub struct UiColumn { - pub column_type: UiColumnType, - pub width: u16, - /// If true, this column expands to fill remaining space. Only one column should expand. - pub expand: bool, -} - -impl UiColumn { - pub fn new(column_type: UiColumnType) -> Self { - Self { - width: column_type.default_width(), - expand: column_type == UiColumnType::Command, - column_type, - } - } - - pub fn with_width(column_type: UiColumnType, width: u16) -> Self { - Self { - column_type, - width, - expand: column_type == UiColumnType::Command, - } - } -} - -// Custom deserialize to handle both string and object formats: -// "duration" or { type = "duration", width = 8, expand = true } -impl<'de> serde::Deserialize<'de> for UiColumn { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - use serde::de::{self, MapAccess, Visitor}; - - struct UiColumnVisitor; - - impl<'de> Visitor<'de> for UiColumnVisitor { - type Value = UiColumn; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str( - "a column type string or an object with 'type' and optional 'width'/'expand'", - ) - } - - fn visit_str<E>(self, value: &str) -> Result<UiColumn, E> - where - E: de::Error, - { - let column_type: UiColumnType = - serde::Deserialize::deserialize(serde::de::value::StrDeserializer::new(value))?; - Ok(UiColumn::new(column_type)) - } - - fn visit_map<M>(self, mut map: M) -> Result<UiColumn, M::Error> - where - M: MapAccess<'de>, - { - let mut column_type: Option<UiColumnType> = None; - let mut width: Option<u16> = None; - let mut expand: Option<bool> = None; - - while let Some(key) = map.next_key::<String>()? { - match key.as_str() { - "type" => { - column_type = Some(map.next_value()?); - } - "width" => { - width = Some(map.next_value()?); - } - "expand" => { - expand = Some(map.next_value()?); - } - _ => { - let _: serde::de::IgnoredAny = map.next_value()?; - } - } - } - - let column_type = column_type.ok_or_else(|| de::Error::missing_field("type"))?; - let width = width.unwrap_or_else(|| column_type.default_width()); - let expand = expand.unwrap_or(column_type == UiColumnType::Command); - Ok(UiColumn { - column_type, - width, - expand, - }) - } - } - - deserializer.deserialize_any(UiColumnVisitor) - } -} - -/// UI-specific settings for the interactive search. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Ui { - /// Columns to display in interactive search, from left to right. - /// The indicator column (" > ") is always shown first implicitly. - /// The "command" column should be last as it expands to fill remaining space. - /// Can be simple strings or objects with type and width. - #[serde(default = "Ui::default_columns")] - pub columns: Vec<UiColumn>, -} - -impl Ui { - fn default_columns() -> Vec<UiColumn> { - vec![ - UiColumn::new(UiColumnType::Duration), - UiColumn::new(UiColumnType::Time), - UiColumn::new(UiColumnType::Command), - ] - } - - /// Validate the UI configuration. - /// Returns an error if more than one column has expand = true. - pub fn validate(&self) -> Result<()> { - let expand_count = self.columns.iter().filter(|c| c.expand).count(); - if expand_count > 1 { - bail!( - "Only one column can have expand = true, but {} columns are set to expand", - expand_count - ); - } - Ok(()) - } -} - -impl Default for Ui { - fn default() -> Self { - Self { - columns: Self::default_columns(), - } - } -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Settings { - pub data_dir: Option<String>, - pub dialect: Dialect, - pub timezone: Timezone, - pub style: Style, - pub auto_sync: bool, - pub update_check: bool, - - /// The sync address for atuin. - pub sync_address: String, - - #[serde(default)] - pub sync_protocol: SyncProtocol, - - pub sync_frequency: String, - pub db_path: String, - pub record_store_path: String, - pub key_path: String, - pub search_mode: SearchMode, - pub filter_mode: Option<FilterMode>, - pub filter_mode_shell_up_key_binding: Option<FilterMode>, - pub search_mode_shell_up_key_binding: Option<SearchMode>, - pub shell_up_key_binding: bool, - pub inline_height: u16, - pub inline_height_shell_up_key_binding: Option<u16>, - pub invert: bool, - pub show_preview: bool, - pub max_preview_height: u16, - pub show_help: bool, - pub show_tabs: bool, - pub show_numeric_shortcuts: bool, - pub auto_hide_height: u16, - pub exit_mode: ExitMode, - pub keymap_mode: KeymapMode, - pub keymap_mode_shell: KeymapMode, - pub keymap_cursor: HashMap<String, CursorStyle>, - pub word_jump_mode: WordJumpMode, - pub word_chars: String, - pub scroll_context_lines: usize, - pub history_format: String, - pub strip_trailing_whitespace: bool, - pub prefers_reduced_motion: bool, - pub store_failed: bool, - pub no_mouse: bool, - - #[serde(with = "serde_regex", default = "RegexSet::empty", skip_serializing)] - pub history_filter: RegexSet, - - #[serde(with = "serde_regex", default = "RegexSet::empty", skip_serializing)] - pub cwd_filter: RegexSet, - - pub secrets_filter: bool, - pub workspaces: bool, - pub ctrl_n_shortcuts: bool, - - pub network_connect_timeout: u64, - pub network_timeout: u64, - pub local_timeout: f64, - pub enter_accept: bool, - pub smart_sort: bool, - pub command_chaining: bool, - - #[serde(default)] - pub stats: Stats, - - #[serde(default)] - pub keys: Keys, - - #[serde(default)] - pub keymap: KeymapConfig, - - #[serde(default)] - pub preview: Preview, - - #[serde(default)] - pub daemon: Daemon, - - #[serde(default)] - pub search: Search, - - #[serde(default)] - pub theme: Theme, - - #[serde(default)] - pub ui: Ui, - - #[serde(default)] - pub tmux: Tmux, - - #[serde(default)] - pub logs: Logs, - - #[serde(default)] - pub meta: meta::Settings, -} - -impl Settings { - pub fn utc() -> Self { - Self::builder() - .expect("Could not build default") - .set_override("timezone", "0") - .expect("failed to override timezone with UTC") - .build() - .expect("Could not build config") - .try_deserialize() - .expect("Could not deserialize config") - } - - pub(crate) fn effective_data_dir() -> PathBuf { - DATA_DIR - .get() - .cloned() - .unwrap_or_else(atuin_common::utils::data_dir) - } - - // -- Meta store: lazily initialized on first access -- - - pub async fn meta_store() -> Result<&'static crate::meta::MetaStore> { - META_STORE - .get_or_try_init(|| async { - let (db_path, timeout) = META_CONFIG.get().ok_or_else(|| { - eyre!("meta store config not set — Settings::new() has not been called") - })?; - crate::meta::MetaStore::new(db_path, *timeout).await - }) - .await - } - - pub async fn host_id() -> Result<HostId> { - Self::meta_store().await?.host_id().await - } - - pub async fn last_sync() -> Result<OffsetDateTime> { - Self::meta_store().await?.last_sync().await - } - - pub async fn save_sync_time() -> Result<()> { - Self::meta_store().await?.save_sync_time().await - } - - pub async fn last_version_check() -> Result<OffsetDateTime> { - Self::meta_store().await?.last_version_check().await - } - - pub async fn save_version_check_time() -> Result<()> { - Self::meta_store().await?.save_version_check_time().await - } - - pub async fn should_sync(&self) -> Result<bool> { - if !self.auto_sync || !Self::meta_store().await?.logged_in().await? { - return Ok(false); - } - - if self.sync_frequency == "0" { - return Ok(true); - } - - match parse_duration(self.sync_frequency.as_str()) { - Ok(d) => { - let d = time::Duration::try_from(d)?; - Ok(OffsetDateTime::now_utc() - Settings::last_sync().await? >= d) - } - Err(e) => Err(eyre!("failed to check sync: {}", e)), - } - } - - pub async fn logged_in(&self) -> Result<bool> { - Self::meta_store().await?.logged_in().await - } - - pub async fn session_token(&self) -> Result<String> { - match Self::meta_store().await?.session_token().await? { - Some(token) => Ok(token), - None => Err(eyre!("Tried to load session; not logged in")), - } - } - - /// Examines the configured sync target and available tokens to determine - /// the correct auth strategy. Also performs cleanup of mis-stored tokens - /// (e.g. a CLI token incorrectly saved in the Hub session slot). - #[cfg(feature = "sync")] - pub async fn resolve_sync_auth(&self) -> SyncAuth { - let meta = match Self::meta_store().await { - Ok(m) => m, - Err(e) => { - return SyncAuth::NotLoggedIn { - reason: format!("Failed to open meta store: {e}"), - }; - } - }; - - // Self-hosted / legacy server - match meta.session_token().await { - Ok(Some(token)) => SyncAuth::Legacy { token }, - _ => SyncAuth::NotLoggedIn { - reason: "Not logged in. Run 'atuin login' to authenticate \ - with your sync server." - .into(), - }, - } - } - - /// Returns the appropriate auth token for sync operations. - /// - /// Delegates to [`resolve_sync_auth`] and converts the result to an - /// `AuthToken`. Callers that need to distinguish between auth states - /// (e.g. to show different UI) should call `resolve_sync_auth` directly. - #[cfg(feature = "sync")] - pub async fn sync_auth_token(&self) -> Result<crate::api_client::AuthToken> { - self.resolve_sync_auth().await.into_auth_token() - } - - pub fn default_filter_mode(&self, git_root: bool) -> FilterMode { - self.filter_mode - .filter(|x| self.search.filters.contains(x)) - .or_else(|| { - self.search - .filters - .iter() - .find(|x| match (x, git_root, self.workspaces) { - (FilterMode::Workspace, true, true) => true, - (FilterMode::Workspace, _, _) => false, - (_, _, _) => true, - }) - .copied() - }) - .unwrap_or(FilterMode::Global) - } - - pub fn builder() -> Result<ConfigBuilder<DefaultState>> { - Self::builder_with_data_dir(&atuin_common::utils::data_dir()) - } - - fn builder_with_data_dir(data_dir: &std::path::Path) -> Result<ConfigBuilder<DefaultState>> { - let db_path = data_dir.join("history.db"); - let record_store_path = data_dir.join("records.db"); - let kv_path = data_dir.join("kv.db"); - let scripts_path = data_dir.join("scripts.db"); - let ai_sessions_path = data_dir.join("ai_sessions.db"); - let socket_path = atuin_common::utils::runtime_dir().join("atuin.sock"); - let pidfile_path = data_dir.join("atuin-daemon.pid"); - let logs_dir = atuin_common::utils::logs_dir(); - - let key_path = data_dir.join("key"); - let meta_path = data_dir.join("meta.db"); - - Ok(Config::builder() - .set_default("history_format", "{time}\t{command}\t{duration}")? - .set_default("db_path", db_path.to_str())? - .set_default("record_store_path", record_store_path.to_str())? - .set_default("key_path", key_path.to_str())? - .set_default("dialect", "us")? - .set_default("timezone", "local")? - .set_default("auto_sync", true)? - .set_default("update_check", cfg!(feature = "check-update"))? - .set_default("sync_address", "https://api.atuin.sh")? - .set_default("sync_frequency", "5m")? - .set_default("search_mode", "fuzzy")? - .set_default("filter_mode", None::<String>)? - .set_default("style", "compact")? - .set_default("inline_height", 40)? - .set_default("show_preview", true)? - .set_default("preview.strategy", "auto")? - .set_default("max_preview_height", 4)? - .set_default("show_help", true)? - .set_default("show_tabs", true)? - .set_default("show_numeric_shortcuts", true)? - .set_default("auto_hide_height", 8)? - .set_default("invert", false)? - .set_default("exit_mode", "return-original")? - .set_default("word_jump_mode", "emacs")? - .set_default( - "word_chars", - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", - )? - .set_default("scroll_context_lines", 1)? - .set_default("shell_up_key_binding", false)? - .set_default("workspaces", false)? - .set_default("ctrl_n_shortcuts", false)? - .set_default("secrets_filter", true)? - .set_default("strip_trailing_whitespace", true)? - .set_default("network_connect_timeout", 5)? - .set_default("network_timeout", 30)? - .set_default("local_timeout", 2.0)? - // enter_accept defaults to false here, but true in the default config file. The dissonance is - // intentional! - // Existing users will get the default "False", so we don't mess with any potential - // muscle memory. - // New users will get the new default, that is more similar to what they are used to. - .set_default("enter_accept", false)? - .set_default("keys.scroll_exits", true)? - .set_default("keys.accept_past_line_end", true)? - .set_default("keys.exit_past_line_start", true)? - .set_default("keys.accept_past_line_start", false)? - .set_default("keys.accept_with_backspace", false)? - .set_default("keys.prefix", "a")? - .set_default("keymap_mode", "emacs")? - .set_default("keymap_mode_shell", "auto")? - .set_default("keymap_cursor", HashMap::<String, String>::new())? - .set_default("smart_sort", false)? - .set_default("command_chaining", false)? - .set_default("store_failed", true)? - .set_default("daemon.sync_frequency", 300)? - .set_default("daemon.enabled", false)? - .set_default("daemon.autostart", false)? - .set_default("daemon.socket_path", socket_path.to_str())? - .set_default("daemon.pidfile_path", pidfile_path.to_str())? - .set_default("daemon.systemd_socket", false)? - .set_default("daemon.tcp_port", 8889)? - .set_default("logs.enabled", true)? - .set_default("logs.dir", logs_dir.to_str())? - .set_default("logs.level", "info")? - .set_default("logs.search.file", "search.log")? - .set_default("logs.daemon.file", "daemon.log")? - .set_default("logs.ai.file", "ai.log")? - .set_default("kv.db_path", kv_path.to_str())? - .set_default("scripts.db_path", scripts_path.to_str())? - .set_default("search.recency_score_multiplier", 1.0)? - .set_default("search.frequency_score_multiplier", 1.0)? - .set_default("search.frecency_score_multiplier", 1.0)? - .set_default("meta.db_path", meta_path.to_str())? - .set_default("ai.db_path", ai_sessions_path.to_str())? - .set_default("ai.session_continue_minutes", 60)? - .set_default("ai.send_cwd", false)? - .set_default("ai.opening.send_cwd", false)? - .set_default("ai.opening.send_last_command", false)? - .set_default( - "search.filters", - vec![ - "global", - "host", - "session", - "workspace", - "directory", - "session-preload", - ], - )? - .set_default("theme.name", "default")? - .set_default("theme.debug", None::<bool>)? - .set_default("tmux.enabled", false)? - .set_default("tmux.width", "80%")? - .set_default("tmux.height", "60%")? - .set_default( - "prefers_reduced_motion", - std::env::var("NO_MOTION") - .ok() - .map(|_| config::Value::new(None, config::ValueKind::Boolean(true))) - .unwrap_or_else(|| config::Value::new(None, config::ValueKind::Boolean(false))), - )? - .set_default("no_mouse", false)? - .add_source( - Environment::with_prefix("atuin") - .prefix_separator("_") - .separator("__"), - )) - } - - pub fn get_config_path() -> Result<PathBuf> { - let config_dir = atuin_common::utils::config_dir(); - - create_dir_all(&config_dir) - .wrap_err_with(|| format!("could not create dir {config_dir:?}"))?; - - let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { - PathBuf::from(p) - } else { - let mut config_file = PathBuf::new(); - config_file.push(config_dir); - config_file - }; - - config_file.push("config.toml"); - - Ok(config_file) - } - - /// Build a merged `Config` from defaults, config file, and environment. - /// - /// This resolves `data_dir`, initializes the data directory on disk, - /// and layers defaults → config file → env overrides. Both `new()` and - /// `get_config_value()` use this so the resolution logic lives in one place. - fn build_config() -> Result<Config> { - let config_file = Self::get_config_path()?; - - // extract data_dir first so we can use it as the base for other path defaults - let effective_data_dir = if config_file.exists() { - #[derive(Deserialize, Default)] - struct DataDirOnly { - data_dir: Option<String>, - } - - let config_file_str = config_file - .to_str() - .ok_or_else(|| eyre!("config file path is not valid UTF-8"))?; - - let partial_config = Config::builder() - .add_source(ConfigFile::new(config_file_str, FileFormat::Toml)) - .add_source( - Environment::with_prefix("atuin") - .prefix_separator("_") - .separator("__"), - ) - .build() - .ok(); - - let custom_data_dir = partial_config - .and_then(|c| c.try_deserialize::<DataDirOnly>().ok()) - .and_then(|d| d.data_dir); - - match custom_data_dir { - Some(dir) => { - let expanded = shellexpand::full(&dir) - .map_err(|e| eyre!("failed to expand data_dir path: {}", e))?; - PathBuf::from(expanded.as_ref()) - } - None => atuin_common::utils::data_dir(), - } - } else { - atuin_common::utils::data_dir() - }; - - DATA_DIR.set(effective_data_dir.clone()).ok(); - - create_dir_all(&effective_data_dir) - .wrap_err_with(|| format!("could not create dir {effective_data_dir:?}"))?; - - let mut config_builder = Self::builder_with_data_dir(&effective_data_dir)?; - - config_builder = if config_file.exists() { - let config_file_str = config_file - .to_str() - .ok_or_else(|| eyre!("config file path is not valid UTF-8"))?; - config_builder.add_source(ConfigFile::new(config_file_str, FileFormat::Toml)) - } else { - let mut file = File::create(config_file).wrap_err("could not create config file")?; - file.write_all(EXAMPLE_CONFIG.as_bytes()) - .wrap_err("could not write default config file")?; - - config_builder - }; - - // all paths should be expanded - let built = config_builder.build_cloned()?; - config_builder = [ - "db_path", - "record_store_path", - "key_path", - "daemon.socket_path", - "daemon.pidfile_path", - "logs.dir", - "logs.search.file", - "logs.daemon.file", - ] - .iter() - .map(|key| (key, built.get_string(key).unwrap_or_default())) - .filter_map(|(key, value)| match Self::expand_path(value) { - Ok(expanded) => Some((key, expanded)), - Err(e) => { - log::warn!("failed to expand path for {key}: {e}"); - None - } - }) - .fold(config_builder, |builder, (key, value)| { - builder - .set_override(key, value) - .unwrap_or_else(|_| panic!("failed to set absolute path override for {key}")) - }); - - config_builder.build().map_err(Into::into) - } - - /// Look up a single config value by dotted key (e.g. `"daemon.sync_frequency"`). - /// - /// Returns the effective value after merging defaults, config file, and - /// environment — without the side-effects of full `Settings` construction - /// (meta store init, path expansion, etc.). - pub fn get_config_value(key: &str) -> Result<String> { - let config = Self::build_config()?; - let value: config::Value = config - .get(key) - .map_err(|e| eyre!("failed to get config value '{}': {}", key, e))?; - Ok(Self::format_resolved_value(&value, key)) - } - - fn format_resolved_value(value: &config::Value, prefix: &str) -> String { - use config::ValueKind; - - match &value.kind { - ValueKind::Nil => String::new(), - ValueKind::Boolean(b) => b.to_string(), - ValueKind::I64(i) => i.to_string(), - ValueKind::I128(i) => i.to_string(), - ValueKind::U64(u) => u.to_string(), - ValueKind::U128(u) => u.to_string(), - ValueKind::Float(f) => f.to_string(), - ValueKind::String(s) => s.clone(), - ValueKind::Array(arr) => { - let items: Vec<String> = arr - .iter() - .map(|v| Self::format_resolved_value(v, "")) - .collect(); - format!("[{}]", items.join(", ")) - } - ValueKind::Table(map) => { - let mut lines = Vec::new(); - let mut keys: Vec<_> = map.keys().collect(); - keys.sort(); - - for k in keys { - let v = &map[k]; - let full_key = if prefix.is_empty() { - k.clone() - } else { - format!("{}.{}", prefix, k) - }; - - match &v.kind { - ValueKind::Table(_) => { - lines.push(Self::format_resolved_value(v, &full_key)); - } - _ => { - lines.push(format!( - "{} = {}", - full_key, - Self::format_resolved_value(v, "") - )); - } - } - } - - lines.join("\n") - } - } - } - - pub fn new() -> Result<Self> { - let config = Self::build_config()?; - let settings: Settings = config - .try_deserialize() - .map_err(|e| eyre!("failed to deserialize: {}", e))?; - - // Validate UI settings - settings.ui.validate()?; - - // Register meta store config for lazy initialization on first access - META_CONFIG - .set((settings.meta.db_path.clone(), settings.local_timeout)) - .ok(); - - Ok(settings) - } - - fn expand_path(path: String) -> Result<String> { - shellexpand::full(&path) - .map(|p| p.to_string()) - .map_err(|e| eyre!("failed to expand path: {}", e)) - } - - pub fn example_config() -> &'static str { - EXAMPLE_CONFIG - } - - pub fn paths_ok(&self) -> bool { - let paths = [ - &self.db_path, - &self.record_store_path, - &self.key_path, - &self.meta.db_path, - ]; - paths.iter().all(|p| !utils::broken_symlink(p)) - } -} - -impl Default for Settings { - fn default() -> Self { - // if this panics something is very wrong, as the default config - // does not build or deserialize into the settings struct - Self::builder() - .expect("Could not build default") - .build() - .expect("Could not build config") - .try_deserialize() - .expect("Could not deserialize config") - } -} - -/// Initialize the meta store configuration for testing. -/// -/// This should only be used in tests. It allows tests to bypass the normal -/// Settings::new() flow while still being able to use Settings::host_id() -/// and other meta store dependent functions. -/// -/// # Safety -/// This function is not thread-safe with concurrent calls to Settings::new() -/// or other meta store initialization. Only call from tests. -#[doc(hidden)] -pub fn init_meta_config_for_testing(meta_db_path: impl Into<String>, local_timeout: f64) { - META_CONFIG.set((meta_db_path.into(), local_timeout)).ok(); -} - -#[cfg(test)] -pub(crate) fn test_local_timeout() -> f64 { - std::env::var("ATUIN_TEST_LOCAL_TIMEOUT") - .ok() - .and_then(|x| x.parse().ok()) - // this hardcoded value should be replaced by a simple way to get the - // default local_timeout of Settings if possible - .unwrap_or(2.0) -} - -#[cfg(test)] -mod tests { - use std::str::FromStr; - - use eyre::Result; - - use super::Timezone; - - #[test] - fn can_parse_offset_timezone_spec() -> Result<()> { - assert_eq!(Timezone::from_str("+02")?.0.as_hms(), (2, 0, 0)); - assert_eq!(Timezone::from_str("-04")?.0.as_hms(), (-4, 0, 0)); - assert_eq!(Timezone::from_str("+05:30")?.0.as_hms(), (5, 30, 0)); - assert_eq!(Timezone::from_str("-09:30")?.0.as_hms(), (-9, -30, 0)); - - // single digit hours are allowed - assert_eq!(Timezone::from_str("+2")?.0.as_hms(), (2, 0, 0)); - assert_eq!(Timezone::from_str("-4")?.0.as_hms(), (-4, 0, 0)); - assert_eq!(Timezone::from_str("+5:30")?.0.as_hms(), (5, 30, 0)); - assert_eq!(Timezone::from_str("-9:30")?.0.as_hms(), (-9, -30, 0)); - - // fully qualified form - assert_eq!(Timezone::from_str("+09:30:00")?.0.as_hms(), (9, 30, 0)); - assert_eq!(Timezone::from_str("-09:30:00")?.0.as_hms(), (-9, -30, 0)); - - // these offsets don't really exist but are supported anyway - assert_eq!(Timezone::from_str("+0:5")?.0.as_hms(), (0, 5, 0)); - assert_eq!(Timezone::from_str("-0:5")?.0.as_hms(), (0, -5, 0)); - assert_eq!(Timezone::from_str("+01:23:45")?.0.as_hms(), (1, 23, 45)); - assert_eq!(Timezone::from_str("-01:23:45")?.0.as_hms(), (-1, -23, -45)); - - // require a leading sign for clarity - assert!(Timezone::from_str("5").is_err()); - assert!(Timezone::from_str("10:30").is_err()); - - Ok(()) - } - - #[test] - fn can_choose_workspace_filters_when_in_git_context() -> Result<()> { - let mut settings = super::Settings::default(); - settings.search.filters = vec![ - super::FilterMode::Workspace, - super::FilterMode::Host, - super::FilterMode::Directory, - super::FilterMode::Session, - super::FilterMode::Global, - ]; - settings.workspaces = true; - - assert_eq!( - settings.default_filter_mode(true), - super::FilterMode::Workspace, - ); - - Ok(()) - } - - #[test] - fn wont_choose_workspace_filters_when_not_in_git_context() -> Result<()> { - let mut settings = super::Settings::default(); - settings.search.filters = vec![ - super::FilterMode::Workspace, - super::FilterMode::Host, - super::FilterMode::Directory, - super::FilterMode::Session, - super::FilterMode::Global, - ]; - settings.workspaces = true; - - assert_eq!(settings.default_filter_mode(false), super::FilterMode::Host,); - - Ok(()) - } - - #[test] - fn wont_choose_workspace_filters_when_workspaces_disabled() -> Result<()> { - let mut settings = super::Settings::default(); - settings.search.filters = vec![ - super::FilterMode::Workspace, - super::FilterMode::Host, - super::FilterMode::Directory, - super::FilterMode::Session, - super::FilterMode::Global, - ]; - settings.workspaces = false; - - assert_eq!(settings.default_filter_mode(true), super::FilterMode::Host,); - - Ok(()) - } - - #[test] - fn builder_with_data_dir_uses_custom_paths() -> Result<()> { - use std::path::PathBuf; - - let custom_dir = PathBuf::from("/custom/data/dir"); - let builder = super::Settings::builder_with_data_dir(&custom_dir)?; - let config = builder.build()?; - - let db_path: String = config.get("db_path")?; - let key_path: String = config.get("key_path")?; - let record_store_path: String = config.get("record_store_path")?; - let kv_db_path: String = config.get("kv.db_path")?; - let scripts_db_path: String = config.get("scripts.db_path")?; - let meta_db_path: String = config.get("meta.db_path")?; - let daemon_socket_path: String = config.get("daemon.socket_path")?; - let daemon_pidfile_path: String = config.get("daemon.pidfile_path")?; - let daemon_autostart: bool = config.get("daemon.autostart")?; - - assert_eq!(db_path, custom_dir.join("history.db").to_str().unwrap()); - assert_eq!(key_path, custom_dir.join("key").to_str().unwrap()); - assert_eq!( - record_store_path, - custom_dir.join("records.db").to_str().unwrap() - ); - assert_eq!(kv_db_path, custom_dir.join("kv.db").to_str().unwrap()); - assert_eq!( - scripts_db_path, - custom_dir.join("scripts.db").to_str().unwrap() - ); - assert_eq!(meta_db_path, custom_dir.join("meta.db").to_str().unwrap()); - assert_eq!( - daemon_socket_path, - atuin_common::utils::runtime_dir() - .join("atuin.sock") - .to_str() - .unwrap() - ); - assert_eq!( - daemon_pidfile_path, - custom_dir.join("atuin-daemon.pid").to_str().unwrap() - ); - assert!(!daemon_autostart); - - Ok(()) - } - - #[test] - fn effective_data_dir_returns_default_when_not_set() { - let effective = super::Settings::effective_data_dir(); - let default = atuin_common::utils::data_dir(); - - assert!(effective.to_str().is_some()); - assert!(effective.ends_with("atuin") || effective == default); - } - - #[test] - fn keymap_config_deserializes_simple_binding() { - let json = r#"{"emacs": {"ctrl-c": "exit"}}"#; - let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); - assert_eq!(config.emacs.len(), 1); - match &config.emacs["ctrl-c"] { - super::KeyBindingConfig::Simple(s) => assert_eq!(s, "exit"), - _ => panic!("expected Simple variant"), - } - } - - #[test] - fn keymap_config_deserializes_conditional_binding() { - let json = r#"{ - "emacs": { - "left": [ - {"when": "cursor-at-start", "action": "exit"}, - {"action": "cursor-left"} - ] - } - }"#; - let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); - match &config.emacs["left"] { - super::KeyBindingConfig::Rules(rules) => { - assert_eq!(rules.len(), 2); - assert_eq!(rules[0].when.as_deref(), Some("cursor-at-start")); - assert_eq!(rules[0].action, "exit"); - assert!(rules[1].when.is_none()); - assert_eq!(rules[1].action, "cursor-left"); - } - _ => panic!("expected Rules variant"), - } - } - - #[test] - fn keymap_config_deserializes_vim_normal() { - let json = r#"{"vim-normal": {"j": "select-next", "k": "select-previous"}}"#; - let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); - assert_eq!(config.vim_normal.len(), 2); - assert!(config.emacs.is_empty()); - } - - #[test] - fn keymap_config_is_empty_when_default() { - let config = super::KeymapConfig::default(); - assert!(config.is_empty()); - } - - #[test] - fn keymap_config_mixed_modes() { - let json = r#"{ - "emacs": {"ctrl-c": "exit"}, - "vim-normal": {"q": "exit"}, - "inspector": {"d": "delete"} - }"#; - let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); - assert!(!config.is_empty()); - assert_eq!(config.emacs.len(), 1); - assert_eq!(config.vim_normal.len(), 1); - assert_eq!(config.inspector.len(), 1); - assert!(config.vim_insert.is_empty()); - assert!(config.prefix.is_empty()); - } -} diff --git a/crates/atuin-client/src/settings/meta.rs b/crates/atuin-client/src/settings/meta.rs deleted file mode 100644 index 108d74ec..00000000 --- a/crates/atuin-client/src/settings/meta.rs +++ /dev/null @@ -1,17 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct Settings { - pub db_path: String, -} - -impl Default for Settings { - fn default() -> Self { - let dir = atuin_common::utils::data_dir(); - let path = dir.join("meta.db"); - - Self { - db_path: path.to_string_lossy().to_string(), - } - } -} diff --git a/crates/atuin-client/src/settings/watcher.rs b/crates/atuin-client/src/settings/watcher.rs deleted file mode 100644 index 740b8d12..00000000 --- a/crates/atuin-client/src/settings/watcher.rs +++ /dev/null @@ -1,256 +0,0 @@ -//! Config file watching for automatic settings reload. -//! -//! This module provides a `SettingsWatcher` that monitors the config file -//! for changes and broadcasts updated settings via a `tokio::sync::watch` channel. -//! -//! # Example -//! -//! ```no_run -//! use atuin_client::settings::watcher::global_settings_watcher; -//! -//! async fn example() -> eyre::Result<()> { -//! let watcher = global_settings_watcher()?; -//! let mut rx = watcher.subscribe(); -//! -//! // React to settings changes -//! while rx.changed().await.is_ok() { -//! let settings = rx.borrow(); -//! println!("Settings updated!"); -//! } -//! Ok(()) -//! } -//! ``` - -use std::{ - path::PathBuf, - sync::{Arc, OnceLock}, - time::Duration, -}; - -use eyre::{Result, WrapErr}; -use log::{debug, error, info, warn}; -use notify::{ - Config as NotifyConfig, RecommendedWatcher, RecursiveMode, Watcher, - event::{EventKind, ModifyKind}, -}; -use tokio::sync::watch; - -use super::Settings; - -/// Global singleton for the settings watcher. -static SETTINGS_WATCHER: OnceLock<Result<SettingsWatcher, String>> = OnceLock::new(); - -/// Get the global settings watcher singleton. -/// -/// Initializes the watcher on first call. Subsequent calls return the same instance. -/// The watcher monitors the config file for changes and broadcasts updates. -pub fn global_settings_watcher() -> Result<&'static SettingsWatcher> { - let result = SETTINGS_WATCHER.get_or_init(|| SettingsWatcher::new().map_err(|e| e.to_string())); - - match result { - Ok(watcher) => Ok(watcher), - Err(e) => Err(eyre::eyre!("{}", e)), - } -} - -/// Watches the config file for changes and broadcasts updated settings. -/// -/// Uses `notify` for cross-platform file watching and `tokio::sync::watch` -/// for efficient broadcast to multiple subscribers. -pub struct SettingsWatcher { - /// Receiver for settings updates. Clone this to subscribe. - rx: watch::Receiver<Arc<Settings>>, - /// Keeps the file watcher alive for the lifetime of this struct. - _watcher: RecommendedWatcher, -} - -impl SettingsWatcher { - /// Create a new settings watcher. - /// - /// Loads initial settings and starts watching the config file for changes. - /// Changes are debounced (500ms) to avoid multiple reloads during saves. - pub fn new() -> Result<Self> { - let initial_settings = Arc::new(Settings::new()?); - let (tx, rx) = watch::channel(initial_settings); - - let config_path = Self::config_path(); - info!("starting config file watcher: {:?}", config_path); - - let watcher = Self::create_watcher(tx, config_path)?; - - Ok(Self { - rx, - _watcher: watcher, - }) - } - - /// Subscribe to settings updates. - /// - /// Returns a receiver that will be notified when settings change. - /// Use `changed().await` to wait for the next update, then `borrow()` - /// to access the current settings. - pub fn subscribe(&self) -> watch::Receiver<Arc<Settings>> { - self.rx.clone() - } - - /// Get the current settings without subscribing to updates. - pub fn current(&self) -> Arc<Settings> { - self.rx.borrow().clone() - } - - /// Get the config file path. - fn config_path() -> PathBuf { - let config_dir = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { - PathBuf::from(p) - } else { - atuin_common::utils::config_dir() - }; - config_dir.join("config.toml") - } - - /// Create the file watcher with debouncing. - fn create_watcher( - tx: watch::Sender<Arc<Settings>>, - config_path: PathBuf, - ) -> Result<RecommendedWatcher> { - // Channel for debouncing file events - let (debounce_tx, debounce_rx) = std::sync::mpsc::channel::<()>(); - - // Spawn debounce thread - let config_path_clone = config_path.clone(); - std::thread::spawn(move || { - Self::debounce_loop(debounce_rx, tx, config_path_clone); - }); - - // Clone config_path for use in the watcher callback - let config_path_for_watcher = config_path.clone(); - - // Canonicalize config path for reliable comparison on macOS - // (handles symlinks like /var -> /private/var) - let canonical_config_path = config_path_for_watcher - .canonicalize() - .unwrap_or_else(|_| config_path_for_watcher.clone()); - - // Create file watcher - let mut watcher = RecommendedWatcher::new( - move |res: Result<notify::Event, notify::Error>| { - match res { - Ok(event) => { - // Defensive: if paths is empty, we can't filter, so assume - // it might be our config file and trigger a reload to be safe - if event.paths.is_empty() { - warn!( - "config watcher: event has no paths, triggering reload to be safe" - ); - let _ = debounce_tx.send(()); - return; - } - - // Only react to events for our specific config file - // (filter out editor temp files, backups, etc.) - let is_config_file = event.paths.iter().any(|path| { - // Canonicalize for reliable comparison (handles macOS symlinks) - let canonical_event_path = - path.canonicalize().unwrap_or_else(|_| path.clone()); - - // Check if this event is for our config file - // (either exact match or the file was renamed to our config) - canonical_event_path == canonical_config_path - || path.file_name() == config_path_for_watcher.file_name() - }); - - if !is_config_file { - return; - } - - // Only react to modify events (content changes) or creates - if matches!( - event.kind, - EventKind::Modify(ModifyKind::Data(_) | ModifyKind::Any) - | EventKind::Create(_) - ) { - debug!("config file event detected: {:?}", event); - // Send to debounce channel (ignore send errors - receiver might be gone) - let _ = debounce_tx.send(()); - } - } - Err(e) => { - error!("file watcher error: {}", e); - } - } - }, - NotifyConfig::default(), - ) - .wrap_err("failed to create file watcher")?; - - // Watch the config file's parent directory (some editors create new files) - let watch_path = config_path.parent().unwrap_or(&config_path); - - // Defensive: ensure watch path exists before trying to watch - if !watch_path.exists() { - warn!( - "config directory does not exist, creating it: {:?}", - watch_path - ); - std::fs::create_dir_all(watch_path) - .wrap_err_with(|| format!("failed to create config directory: {:?}", watch_path))?; - } - - watcher - .watch(watch_path, RecursiveMode::NonRecursive) - .wrap_err_with(|| format!("failed to watch config directory: {:?}", watch_path))?; - - info!("config file watcher initialized for: {:?}", watch_path); - Ok(watcher) - } - - /// Debounce loop that batches file events and reloads settings. - fn debounce_loop( - rx: std::sync::mpsc::Receiver<()>, - tx: watch::Sender<Arc<Settings>>, - config_path: PathBuf, - ) { - const DEBOUNCE_DURATION: Duration = Duration::from_millis(500); - - loop { - // Wait for first event - if rx.recv().is_err() { - // Channel closed, watcher was dropped - debug!("config watcher debounce loop exiting"); - return; - } - - // Drain any additional events within debounce window - while rx.recv_timeout(DEBOUNCE_DURATION).is_ok() { - // Keep draining - } - - // Defensive: check if config file exists before reloading - // (handles case where file was deleted - we'll get notified when it's recreated) - if !config_path.exists() { - debug!( - "config file does not exist, skipping reload: {:?}", - config_path - ); - continue; - } - - // Now reload settings - info!("config file changed, reloading settings: {:?}", config_path); - match Settings::new() { - Ok(settings) => { - if tx.send(Arc::new(settings)).is_err() { - // All receivers dropped - debug!("all settings subscribers dropped, exiting"); - return; - } - info!("settings reloaded successfully"); - } - Err(e) => { - warn!("failed to reload settings: {}", e); - // Keep the old settings, don't broadcast the error - } - } - } - } -} diff --git a/crates/atuin-client/src/sync.rs b/crates/atuin-client/src/sync.rs deleted file mode 100644 index 2c902794..00000000 --- a/crates/atuin-client/src/sync.rs +++ /dev/null @@ -1,213 +0,0 @@ -use std::collections::HashSet; -use std::iter::FromIterator; - -use eyre::Result; - -use atuin_common::api::AddHistoryRequest; -use crypto_secretbox::Key; -use time::OffsetDateTime; - -use crate::{ - api_client, - database::Database, - encryption::{decrypt, encrypt, load_key}, - settings::Settings, -}; - -pub fn hash_str(string: &str) -> String { - use sha2::{Digest, Sha256}; - let mut hasher = Sha256::new(); - hasher.update(string.as_bytes()); - hex::encode(hasher.finalize()) -} - -// Currently sync is kinda naive, and basically just pages backwards through -// history. This means newly added stuff shows up properly! We also just use -// the total count in each database to indicate whether a sync is needed. -// I think this could be massively improved! If we had a way of easily -// indicating count per time period (hour, day, week, year, etc) then we can -// easily pinpoint where we are missing data and what needs downloading. Start -// with year, then find the week, then the day, then the hour, then download it -// all! The current naive approach will do for now. - -// Check if remote has things we don't, and if so, download them. -// Returns (num downloaded, total local) -async fn sync_download( - key: &Key, - force: bool, - client: &api_client::Client<'_>, - db: &impl Database, -) -> Result<(i64, i64)> { - debug!("starting sync download"); - - let remote_status = client.status().await?; - let remote_count = remote_status.count; - - // useful to ensure we don't even save something that hasn't yet been synced + deleted - let remote_deleted = - HashSet::<&str>::from_iter(remote_status.deleted.iter().map(String::as_str)); - - let initial_local = db.history_count(true).await?; - let mut local_count = initial_local; - - let mut last_sync = if force { - OffsetDateTime::UNIX_EPOCH - } else { - Settings::last_sync().await? - }; - - let mut last_timestamp = OffsetDateTime::UNIX_EPOCH; - - let host = if force { Some(String::from("")) } else { None }; - - while remote_count > local_count { - let page = client - .get_history(last_sync, last_timestamp, host.clone()) - .await?; - - let history: Vec<_> = page - .history - .iter() - // TODO: handle deletion earlier in this chain - .map(|h| serde_json::from_str(h).expect("invalid base64")) - .map(|h| decrypt(h, key).expect("failed to decrypt history! check your key")) - .map(|mut h| { - if remote_deleted.contains(h.id.0.as_str()) { - h.deleted_at = Some(time::OffsetDateTime::now_utc()); - h.command = String::from(""); - } - - h - }) - .collect(); - - db.save_bulk(&history).await?; - - local_count = db.history_count(true).await?; - let remote_page_size = std::cmp::max(remote_status.page_size, 0) as usize; - - if history.len() < remote_page_size { - break; - } - - let page_last = history - .last() - .expect("could not get last element of page") - .timestamp; - - // in the case of a small sync frequency, it's possible for history to - // be "lost" between syncs. In this case we need to rewind the sync - // timestamps - if page_last == last_timestamp { - last_timestamp = OffsetDateTime::UNIX_EPOCH; - last_sync -= time::Duration::hours(1); - } else { - last_timestamp = page_last; - } - } - - for i in remote_status.deleted { - // we will update the stored history to have this data - // pretty much everything can be nullified - match db.load(i.as_str()).await? { - Some(h) => { - db.delete(h).await?; - } - _ => { - info!( - "could not delete history with id {}, not found locally", - i.as_str() - ); - } - } - } - - Ok((local_count - initial_local, local_count)) -} - -// Check if we have things remote doesn't, and if so, upload them -async fn sync_upload( - key: &Key, - _force: bool, - client: &api_client::Client<'_>, - db: &impl Database, -) -> Result<()> { - debug!("starting sync upload"); - - let remote_status = client.status().await?; - let remote_deleted: HashSet<String> = HashSet::from_iter(remote_status.deleted.clone()); - - let initial_remote_count = client.count().await?; - let mut remote_count = initial_remote_count; - - let local_count = db.history_count(true).await?; - - debug!("remote has {remote_count}, we have {local_count}"); - - // first just try the most recent set - let mut cursor = OffsetDateTime::now_utc(); - - while local_count > remote_count { - let last = db.before(cursor, remote_status.page_size).await?; - let mut buffer = Vec::new(); - - if last.is_empty() { - break; - } - - for i in last { - let data = encrypt(&i, key)?; - let data = serde_json::to_string(&data)?; - - let add_hist = AddHistoryRequest { - id: i.id.to_string(), - timestamp: i.timestamp, - data, - hostname: hash_str(&i.hostname), - }; - - buffer.push(add_hist); - } - - // anything left over outside of the 100 block size - client.post_history(&buffer).await?; - cursor = buffer.last().unwrap().timestamp; - remote_count = client.count().await?; - - debug!("upload cursor: {cursor:?}"); - } - - let deleted = db.deleted().await?; - - for i in deleted { - if remote_deleted.contains(&i.id.to_string()) { - continue; - } - - info!("deleting {} on remote", i.id); - client.delete_history(i).await?; - } - - Ok(()) -} - -pub async fn sync(settings: &Settings, force: bool, db: &impl Database) -> Result<()> { - let client = api_client::Client::new( - &settings.sync_address, - settings.sync_auth_token().await?, - settings.network_connect_timeout, - settings.network_timeout, - )?; - - Settings::save_sync_time().await?; - - let key = load_key(settings)?; // encryption key - - sync_upload(&key, force, &client, db).await?; - - let download = sync_download(&key, force, &client, db).await?; - - debug!("sync downloaded {}", download.0); - - Ok(()) -} diff --git a/crates/atuin-client/src/theme.rs b/crates/atuin-client/src/theme.rs deleted file mode 100644 index a277ac13..00000000 --- a/crates/atuin-client/src/theme.rs +++ /dev/null @@ -1,831 +0,0 @@ -use config::{Config, File as ConfigFile, FileFormat}; -use log; -use palette::named; -use serde::{Deserialize, Serialize}; -use serde_json; -use std::collections::HashMap; -use std::error; -use std::io::{Error, ErrorKind}; -use std::path::PathBuf; -use std::sync::LazyLock; -use strum_macros; - -static DEFAULT_MAX_DEPTH: u8 = 10; - -// Collection of settable "meanings" that can have colors set. -// NOTE: You can add a new meaning here without breaking backwards compatibility but please: -// - update the atuin/docs repository, which has a list of available meanings -// - add a fallback in the MEANING_FALLBACKS below, so that themes which do not have it -// get a sensible fallback (see Title as an example) -#[derive( - Serialize, Deserialize, Copy, Clone, Hash, Debug, Eq, PartialEq, strum_macros::Display, -)] -#[strum(serialize_all = "camel_case")] -pub enum Meaning { - AlertInfo, - AlertWarn, - AlertError, - Annotation, - Base, - Guidance, - Important, - Title, - Muted, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ThemeConfig { - // Definition of the theme - pub theme: ThemeDefinitionConfigBlock, - - // Colors - pub colors: HashMap<Meaning, String>, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ThemeDefinitionConfigBlock { - /// Name of theme ("default" for base) - pub name: String, - - /// Whether any theme should be treated as a parent _if available_ - pub parent: Option<String>, -} - -use crossterm::style::{Attribute, Attributes, Color, ContentStyle}; - -// For now, a theme is loaded as a mapping of meanings to colors, but it may be desirable to -// expand that in the future to general styles, so we populate a Meaning->ContentStyle hashmap. -pub struct Theme { - pub name: String, - pub parent: Option<String>, - pub styles: HashMap<Meaning, ContentStyle>, -} - -// Themes have a number of convenience functions for the most commonly used meanings. -// The general purpose `as_style` routine gives back a style, but for ease-of-use and to keep -// theme-related boilerplate minimal, the convenience functions give a color. -impl Theme { - // This is the base "default" color, for general text - pub fn get_base(&self) -> ContentStyle { - self.styles[&Meaning::Base] - } - - pub fn get_info(&self) -> ContentStyle { - self.get_alert(log::Level::Info) - } - - pub fn get_warning(&self) -> ContentStyle { - self.get_alert(log::Level::Warn) - } - - pub fn get_error(&self) -> ContentStyle { - self.get_alert(log::Level::Error) - } - - // The alert meanings may be chosen by the Level enum, rather than the methods above - // or the full Meaning enum, to simplify programmatic selection of a log-level. - pub fn get_alert(&self, severity: log::Level) -> ContentStyle { - self.styles[ALERT_TYPES.get(&severity).unwrap()] - } - - pub fn new( - name: String, - parent: Option<String>, - styles: HashMap<Meaning, ContentStyle>, - ) -> Theme { - Theme { - name, - parent, - styles, - } - } - - pub fn closest_meaning<'a>(&self, meaning: &'a Meaning) -> &'a Meaning { - if self.styles.contains_key(meaning) { - meaning - } else if MEANING_FALLBACKS.contains_key(meaning) { - self.closest_meaning(&MEANING_FALLBACKS[meaning]) - } else { - &Meaning::Base - } - } - - // General access - if you have a meaning, this will give you a (crossterm) style - pub fn as_style(&self, meaning: Meaning) -> ContentStyle { - self.styles[self.closest_meaning(&meaning)] - } - - // Turns a map of meanings to colornames into a theme - // If theme-debug is on, then we will print any colornames that we cannot load, - // but we do not have this on in general, as it could print unfiltered text to the terminal - // from a theme TOML file. However, it will always return a theme, falling back to - // defaults on error, so that a TOML file does not break loading - pub fn from_foreground_colors( - name: String, - parent: Option<&Theme>, - foreground_colors: HashMap<Meaning, String>, - debug: bool, - ) -> Theme { - let styles: HashMap<Meaning, ContentStyle> = foreground_colors - .iter() - .map(|(name, color)| { - ( - *name, - StyleFactory::from_fg_string(color).unwrap_or_else(|err| { - if debug { - log::warn!("Tried to load string as a color unsuccessfully: ({name}={color}) {err}"); - } - ContentStyle::default() - }), - ) - }) - .collect(); - Theme::from_map(name, parent, &styles) - } - - // Boil down a meaning-color hashmap into a theme, by taking the defaults - // for any unknown colors - fn from_map( - name: String, - parent: Option<&Theme>, - overrides: &HashMap<Meaning, ContentStyle>, - ) -> Theme { - let styles = match parent { - Some(theme) => Box::new(theme.styles.clone()), - None => Box::new(DEFAULT_THEME.styles.clone()), - } - .iter() - .map(|(name, color)| match overrides.get(name) { - Some(value) => (*name, *value), - None => (*name, *color), - }) - .collect(); - Theme::new(name, parent.map(|p| p.name.clone()), styles) - } -} - -// Use palette to get a color from a string name, if possible -fn from_string(name: &str) -> Result<Color, String> { - if name.is_empty() { - return Err("Empty string".into()); - } - let first_char = name.chars().next().unwrap(); - match first_char { - '#' => { - let hexcode = &name[1..]; - let vec: Vec<u8> = hexcode - .chars() - .collect::<Vec<char>>() - .chunks(2) - .map(|pair| u8::from_str_radix(pair.iter().collect::<String>().as_str(), 16)) - .filter_map(|n| n.ok()) - .collect(); - if vec.len() != 3 { - return Err("Could not parse 3 hex values from string".into()); - } - Ok(Color::Rgb { - r: vec[0], - g: vec[1], - b: vec[2], - }) - } - '@' => { - // For full flexibility, we need to use serde_json, given - // crossterm's approach. - serde_json::from_str::<Color>(format!("\"{}\"", &name[1..]).as_str()) - .map_err(|_| format!("Could not convert color name {name} to Crossterm color")) - } - _ => { - let srgb = named::from_str(name).ok_or("No such color in palette")?; - Ok(Color::Rgb { - r: srgb.red, - g: srgb.green, - b: srgb.blue, - }) - } - } -} - -pub struct StyleFactory {} - -impl StyleFactory { - fn from_fg_string(name: &str) -> Result<ContentStyle, String> { - match from_string(name) { - Ok(color) => Ok(Self::from_fg_color(color)), - Err(err) => Err(err), - } - } - - // For succinctness, if we are confident that the name will be known, - // this routine is available to keep the code readable - fn known_fg_string(name: &str) -> ContentStyle { - Self::from_fg_string(name).unwrap() - } - - fn from_fg_color(color: Color) -> ContentStyle { - ContentStyle { - foreground_color: Some(color), - ..ContentStyle::default() - } - } - - fn from_fg_color_and_attributes(color: Color, attributes: Attributes) -> ContentStyle { - ContentStyle { - foreground_color: Some(color), - attributes, - ..ContentStyle::default() - } - } -} - -// Built-in themes. Rather than having extra files added before any theming -// is available, this gives a couple of basic options, demonstrating the use -// of themes: autumn and marine -static ALERT_TYPES: LazyLock<HashMap<log::Level, Meaning>> = LazyLock::new(|| { - HashMap::from([ - (log::Level::Info, Meaning::AlertInfo), - (log::Level::Warn, Meaning::AlertWarn), - (log::Level::Error, Meaning::AlertError), - ]) -}); - -static MEANING_FALLBACKS: LazyLock<HashMap<Meaning, Meaning>> = LazyLock::new(|| { - HashMap::from([ - (Meaning::Guidance, Meaning::AlertInfo), - (Meaning::Annotation, Meaning::AlertInfo), - (Meaning::Title, Meaning::Important), - ]) -}); - -static DEFAULT_THEME: LazyLock<Theme> = LazyLock::new(|| { - Theme::new( - "default".to_string(), - None, - HashMap::from([ - ( - Meaning::AlertError, - StyleFactory::from_fg_color(Color::DarkRed), - ), - ( - Meaning::AlertWarn, - StyleFactory::from_fg_color(Color::DarkYellow), - ), - ( - Meaning::AlertInfo, - StyleFactory::from_fg_color(Color::DarkGreen), - ), - ( - Meaning::Annotation, - StyleFactory::from_fg_color(Color::DarkGrey), - ), - ( - Meaning::Guidance, - StyleFactory::from_fg_color(Color::DarkBlue), - ), - ( - Meaning::Important, - StyleFactory::from_fg_color_and_attributes( - Color::White, - Attributes::from(Attribute::Bold), - ), - ), - (Meaning::Muted, StyleFactory::from_fg_color(Color::Grey)), - (Meaning::Base, ContentStyle::default()), - ]), - ) -}); - -static BUILTIN_THEMES: LazyLock<HashMap<&'static str, Theme>> = LazyLock::new(|| { - HashMap::from([ - ("default", HashMap::new()), - ( - "(none)", - HashMap::from([ - (Meaning::AlertError, ContentStyle::default()), - (Meaning::AlertWarn, ContentStyle::default()), - (Meaning::AlertInfo, ContentStyle::default()), - (Meaning::Annotation, ContentStyle::default()), - (Meaning::Guidance, ContentStyle::default()), - (Meaning::Important, ContentStyle::default()), - (Meaning::Muted, ContentStyle::default()), - (Meaning::Base, ContentStyle::default()), - ]), - ), - ( - "autumn", - HashMap::from([ - ( - Meaning::AlertError, - StyleFactory::known_fg_string("saddlebrown"), - ), - ( - Meaning::AlertWarn, - StyleFactory::known_fg_string("darkorange"), - ), - (Meaning::AlertInfo, StyleFactory::known_fg_string("gold")), - ( - Meaning::Annotation, - StyleFactory::from_fg_color(Color::DarkGrey), - ), - (Meaning::Guidance, StyleFactory::known_fg_string("brown")), - ]), - ), - ( - "marine", - HashMap::from([ - ( - Meaning::AlertError, - StyleFactory::known_fg_string("yellowgreen"), - ), - (Meaning::AlertWarn, StyleFactory::known_fg_string("cyan")), - ( - Meaning::AlertInfo, - StyleFactory::known_fg_string("turquoise"), - ), - ( - Meaning::Annotation, - StyleFactory::known_fg_string("steelblue"), - ), - ( - Meaning::Base, - StyleFactory::known_fg_string("lightsteelblue"), - ), - (Meaning::Guidance, StyleFactory::known_fg_string("teal")), - ]), - ), - ]) - .iter() - .map(|(name, theme)| (*name, Theme::from_map(name.to_string(), None, theme))) - .collect() -}); - -// To avoid themes being repeatedly loaded, we store them in a theme manager -pub struct ThemeManager { - loaded_themes: HashMap<String, Theme>, - debug: bool, - override_theme_dir: Option<String>, -} - -// Theme-loading logic -impl ThemeManager { - pub fn new(debug: Option<bool>, theme_dir: Option<String>) -> Self { - Self { - loaded_themes: HashMap::new(), - debug: debug.unwrap_or(false), - override_theme_dir: match theme_dir { - Some(theme_dir) => Some(theme_dir), - None => std::env::var("ATUIN_THEME_DIR").ok(), - }, - } - } - - // Try to load a theme from a `{name}.toml` file in the theme directory. If an override is set - // for the theme dir (via ATUIN_THEME_DIR env) we should load the theme from there - pub fn load_theme_from_file( - &mut self, - name: &str, - max_depth: u8, - ) -> Result<&Theme, Box<dyn error::Error>> { - let mut theme_file = if let Some(p) = &self.override_theme_dir { - if p.is_empty() { - return Err(Box::new(Error::new( - ErrorKind::NotFound, - "Empty theme directory override and could not find theme elsewhere", - ))); - } - PathBuf::from(p) - } else { - let config_dir = atuin_common::utils::config_dir(); - let mut theme_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { - PathBuf::from(p) - } else { - let mut theme_file = PathBuf::new(); - theme_file.push(config_dir); - theme_file - }; - theme_file.push("themes"); - theme_file - }; - - let theme_toml = format!["{name}.toml"]; - theme_file.push(theme_toml); - - let mut config_builder = Config::builder(); - - config_builder = config_builder.add_source(ConfigFile::new( - theme_file.to_str().unwrap(), - FileFormat::Toml, - )); - - let config = config_builder.build()?; - self.load_theme_from_config(name, config, max_depth) - } - - pub fn load_theme_from_config( - &mut self, - name: &str, - config: Config, - max_depth: u8, - ) -> Result<&Theme, Box<dyn error::Error>> { - let debug = self.debug; - let theme_config: ThemeConfig = match config.try_deserialize() { - Ok(tc) => tc, - Err(e) => { - return Err(Box::new(Error::new( - ErrorKind::InvalidInput, - format!( - "Failed to deserialize theme: {}", - if debug { - e.to_string() - } else { - "set theme debug on for more info".to_string() - } - ), - ))); - } - }; - let colors: HashMap<Meaning, String> = theme_config.colors; - let parent: Option<&Theme> = match theme_config.theme.parent { - Some(parent_name) => { - if max_depth == 0 { - return Err(Box::new(Error::new( - ErrorKind::InvalidInput, - "Parent requested but we hit the recursion limit", - ))); - } - Some(self.load_theme(parent_name.as_str(), Some(max_depth - 1))) - } - None => Some(self.load_theme("default", Some(max_depth - 1))), - }; - - if debug && name != theme_config.theme.name { - log::warn!( - "Your theme config name is not the name of your loaded theme {} != {}", - name, - theme_config.theme.name - ); - } - - let theme = Theme::from_foreground_colors(theme_config.theme.name, parent, colors, debug); - let name = name.to_string(); - self.loaded_themes.insert(name.clone(), theme); - let theme = self.loaded_themes.get(&name).unwrap(); - Ok(theme) - } - - // Check if the requested theme is loaded and, if not, then attempt to get it - // from the builtins or, if not there, from file - pub fn load_theme(&mut self, name: &str, max_depth: Option<u8>) -> &Theme { - if self.loaded_themes.contains_key(name) { - return self.loaded_themes.get(name).unwrap(); - } - let built_ins = &BUILTIN_THEMES; - match built_ins.get(name) { - Some(theme) => theme, - None => match self.load_theme_from_file(name, max_depth.unwrap_or(DEFAULT_MAX_DEPTH)) { - Ok(theme) => theme, - Err(err) => { - log::warn!("Could not load theme {name}: {err}"); - built_ins.get("(none)").unwrap() - } - }, - } - } -} - -#[cfg(test)] -mod theme_tests { - use super::*; - - #[test] - fn test_can_load_builtin_theme() { - let mut manager = ThemeManager::new(Some(false), Some("".to_string())); - let theme = manager.load_theme("autumn", None); - assert_eq!( - theme.as_style(Meaning::Guidance).foreground_color, - from_string("brown").ok() - ); - } - - #[test] - fn test_can_create_theme() { - let mut manager = ThemeManager::new(Some(false), Some("".to_string())); - let mytheme = Theme::new( - "mytheme".to_string(), - None, - HashMap::from([( - Meaning::AlertError, - StyleFactory::known_fg_string("yellowgreen"), - )]), - ); - manager.loaded_themes.insert("mytheme".to_string(), mytheme); - let theme = manager.load_theme("mytheme", None); - assert_eq!( - theme.as_style(Meaning::AlertError).foreground_color, - from_string("yellowgreen").ok() - ); - } - - #[test] - fn test_can_fallback_when_meaning_missing() { - let mut manager = ThemeManager::new(Some(false), Some("".to_string())); - - // We use title as an example of a meaning that is not defined - // even in the base theme. - assert!(!DEFAULT_THEME.styles.contains_key(&Meaning::Title)); - - let config = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"title_theme\" - - [colors] - Guidance = \"white\" - AlertInfo = \"zomp\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - let theme = manager - .load_theme_from_config("config_theme", config, 1) - .unwrap(); - - // Correctly picks overridden color. - assert_eq!( - theme.as_style(Meaning::Guidance).foreground_color, - from_string("white").ok() - ); - - // Does not fall back to any color. - assert_eq!(theme.as_style(Meaning::AlertInfo).foreground_color, None); - - // Even for the base. - assert_eq!(theme.as_style(Meaning::Base).foreground_color, None); - - // Falls back to red as meaning missing from theme, so picks base default. - assert_eq!( - theme.as_style(Meaning::AlertError).foreground_color, - Some(Color::DarkRed) - ); - - // Falls back to Important as Title not available. - assert_eq!( - theme.as_style(Meaning::Title).foreground_color, - theme.as_style(Meaning::Important).foreground_color, - ); - - let title_config = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"title_theme\" - - [colors] - Title = \"white\" - AlertInfo = \"zomp\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - let title_theme = manager - .load_theme_from_config("title_theme", title_config, 1) - .unwrap(); - - assert_eq!( - title_theme.as_style(Meaning::Title).foreground_color, - Some(Color::White) - ); - } - - #[test] - fn test_no_fallbacks_are_circular() { - let mytheme = Theme::new("mytheme".to_string(), None, HashMap::from([])); - MEANING_FALLBACKS - .iter() - .for_each(|pair| assert_eq!(mytheme.closest_meaning(pair.0), &Meaning::Base)) - } - - #[test] - fn test_can_get_colors_via_convenience_functions() { - let mut manager = ThemeManager::new(Some(true), Some("".to_string())); - let theme = manager.load_theme("default", None); - assert_eq!(theme.get_error().foreground_color.unwrap(), Color::DarkRed); - assert_eq!( - theme.get_warning().foreground_color.unwrap(), - Color::DarkYellow - ); - assert_eq!(theme.get_info().foreground_color.unwrap(), Color::DarkGreen); - assert_eq!(theme.get_base().foreground_color, None); - assert_eq!( - theme.get_alert(log::Level::Error).foreground_color.unwrap(), - Color::DarkRed - ) - } - - #[test] - fn test_can_use_parent_theme_for_fallbacks() { - testing_logger::setup(); - - let mut manager = ThemeManager::new(Some(false), Some("".to_string())); - - // First, we introduce a base theme - let solarized = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"solarized\" - - [colors] - Guidance = \"white\" - AlertInfo = \"pink\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - let solarized_theme = manager - .load_theme_from_config("solarized", solarized, 1) - .unwrap(); - - assert_eq!( - solarized_theme - .as_style(Meaning::AlertInfo) - .foreground_color, - from_string("pink").ok() - ); - - // Then we introduce a derived theme - let unsolarized = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"unsolarized\" - parent = \"solarized\" - - [colors] - AlertInfo = \"red\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - let unsolarized_theme = manager - .load_theme_from_config("unsolarized", unsolarized, 1) - .unwrap(); - - // It will take its own values - assert_eq!( - unsolarized_theme - .as_style(Meaning::AlertInfo) - .foreground_color, - from_string("red").ok() - ); - - // ...or fall back to the parent - assert_eq!( - unsolarized_theme - .as_style(Meaning::Guidance) - .foreground_color, - from_string("white").ok() - ); - - testing_logger::validate(|captured_logs| assert_eq!(captured_logs.len(), 0)); - - // If the parent is not found, we end up with the no theme colors or styling - // as this is considered a (soft) error state. - let nunsolarized = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"nunsolarized\" - parent = \"nonsolarized\" - - [colors] - AlertInfo = \"red\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - let nunsolarized_theme = manager - .load_theme_from_config("nunsolarized", nunsolarized, 1) - .unwrap(); - - assert_eq!( - nunsolarized_theme - .as_style(Meaning::Guidance) - .foreground_color, - None - ); - - testing_logger::validate(|captured_logs| { - assert_eq!(captured_logs.len(), 1); - assert_eq!( - captured_logs[0].body, - "Could not load theme nonsolarized: Empty theme directory override and could not find theme elsewhere" - ); - assert_eq!(captured_logs[0].level, log::Level::Warn) - }); - } - - #[test] - fn test_can_debug_theme() { - testing_logger::setup(); - [true, false].iter().for_each(|debug| { - let mut manager = ThemeManager::new(Some(*debug), Some("".to_string())); - let config = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"mytheme\" - - [colors] - Guidance = \"white\" - AlertInfo = \"xinetic\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - manager - .load_theme_from_config("config_theme", config, 1) - .unwrap(); - testing_logger::validate(|captured_logs| { - if *debug { - assert_eq!(captured_logs.len(), 2); - assert_eq!( - captured_logs[0].body, - "Your theme config name is not the name of your loaded theme config_theme != mytheme" - ); - assert_eq!(captured_logs[0].level, log::Level::Warn); - assert_eq!( - captured_logs[1].body, - "Tried to load string as a color unsuccessfully: (AlertInfo=xinetic) No such color in palette" - ); - assert_eq!(captured_logs[1].level, log::Level::Warn) - } else { - assert_eq!(captured_logs.len(), 0) - } - }) - }) - } - - #[test] - fn test_can_parse_color_strings_correctly() { - assert_eq!( - from_string("brown").unwrap(), - Color::Rgb { - r: 165, - g: 42, - b: 42 - } - ); - - assert_eq!(from_string(""), Err("Empty string".into())); - - ["manatee", "caput mortuum", "123456"] - .iter() - .for_each(|inp| { - assert_eq!(from_string(inp), Err("No such color in palette".into())); - }); - - assert_eq!( - from_string("#ff1122").unwrap(), - Color::Rgb { - r: 255, - g: 17, - b: 34 - } - ); - ["#1122", "#ffaa112", "#brown"].iter().for_each(|inp| { - assert_eq!( - from_string(inp), - Err("Could not parse 3 hex values from string".into()) - ); - }); - - assert_eq!(from_string("@dark_grey").unwrap(), Color::DarkGrey); - assert_eq!( - from_string("@rgb_(255,255,255)").unwrap(), - Color::Rgb { - r: 255, - g: 255, - b: 255 - } - ); - assert_eq!(from_string("@ansi_(255)").unwrap(), Color::AnsiValue(255)); - ["@", "@DarkGray", "@Dark 4ay", "@ansi(256)"] - .iter() - .for_each(|inp| { - assert_eq!( - from_string(inp), - Err(format!( - "Could not convert color name {inp} to Crossterm color" - )) - ); - }); - } -} diff --git a/crates/atuin-client/src/utils.rs b/crates/atuin-client/src/utils.rs deleted file mode 100644 index 35d7db26..00000000 --- a/crates/atuin-client/src/utils.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub(crate) fn get_hostname() -> String { - std::env::var("ATUIN_HOST_NAME") - .unwrap_or_else(|_| whoami::hostname().unwrap_or_else(|_| "unknown-host".to_string())) -} - -pub(crate) fn get_username() -> String { - std::env::var("ATUIN_HOST_USER") - .unwrap_or_else(|_| whoami::username().unwrap_or_else(|_| "unknown-user".to_string())) -} - -/// Returns a pair of the hostname and username, separated by a colon. -pub(crate) fn get_host_user() -> String { - format!("{}:{}", get_hostname(), get_username()) -} |
