diff options
| author | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-11 00:54:30 +0200 |
|---|---|---|
| committer | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-11 00:54:30 +0200 |
| commit | 5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8 (patch) | |
| tree | c64baa8d5866c8e339eaf660dd3f94f30a3f7d8a /crates/turtle | |
| parent | chore: Somewhat simplify sync code (diff) | |
| download | atuin-5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8.zip | |
chore: Move everything into one big crate
That helps remove duplicated code and rustc/cargo will now also show
dead code correctly.
Diffstat (limited to 'crates/turtle')
164 files changed, 39127 insertions, 0 deletions
diff --git a/crates/turtle/Cargo.toml b/crates/turtle/Cargo.toml new file mode 100644 index 00000000..87557905 --- /dev/null +++ b/crates/turtle/Cargo.toml @@ -0,0 +1,142 @@ +[package] +name = "atuin" +edition = "2024" +description = "atuin - magical shell history" +readme = "./README.md" + +rust-version = { workspace = true } +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[features] +default = [ + "clipboard", + "daemon", + "hex", + "sync", + "client", +] + +clipboard = ["arboard"] +daemon = ["pty-proxy"] +pty-proxy = [] +client = [] +hex = ["dep:hex"] +sync = ["urlencoding", "reqwest", "sha2", "hex"] + +[dependencies] +argon2 = "0.5" +async-trait = "0.1.58" +atuin-nucleo-matcher = { workspace = true } +atuin-nucleo = { workspace = true } +axum = "0.8" +base64 = "0.22" +clap = { version = "4.5.7", features = ["derive"] } +clap_complete = "4.5.8" +clap_complete_nushell = "4.5.4" +colored = "2.0.4" +config = { version = "0.15.8", default-features = false, features = ["toml"] } +crossterm = {version = "0.29.0", features = ["use-dev-tty", "serde"] } +crypto_secretbox = "0.1.1" +dashmap = "6.1.0" +directories = "6.0.0" +eyre = "0.6" +fs-err = "3.1" +fs4 = "0.13.1" +futures = "0.3" +futures-util = "0.3" +fuzzy-matcher = "0.3.7" +generic-array = { version = "0.14", features = ["serde"] } +getrandom = "0.2" +glob-match = "0.2.1" +hex = { version = "0.4", optional = true } +humantime = "2.1.0" +hyper-util = "0.1" +imara-diff = "0.2" +indicatif = "0.18.0" +interim = { version = "0.2.0", features = ["time_0_3"] } +itertools = "0.14.0" +lasso = { version = "0.7", features = ["multi-threaded"] } +log = "0.4" +memchr = "2.7" +metrics = "0.24" +metrics-exporter-prometheus = { version = "0.18", default-features = false } +minijinja = "2.9.0" +minspan = "0.1.5" +norm = { version = "0.1.1", features = ["fzf-v2"] } +notify = "7" +open = "5" +palette = { version = "0.7.5", features = ["serializing"] } +pretty_assertions = "1.3.0" +prost = "0.14" +prost-types = "0.14" +rand = { version = "0.8.5", features = ["std"] } +ratatui = "0.30.0" +regex = "1.10.5" +reqwest = { version = "0.13", optional = true, features = ["json", "rustls-no-provider", "stream"], default-features = false } +rmp = { version = "0.8.14" } +rpassword = "7.0" +runtime-format = "0.1.3" +rustix = { version = "1.1.4", features = ["process", "fs"] } +rustls = { version = "0.23", default-features = false, features = [ "ring", "std", "tls12", ] } +rusty_paserk = { version = "0.5.0", default-features = false, features = [ "v4", "serde", ] } +rusty_paseto = { version = "0.8.0", default-features = false } +semver = "1.0.20" +serde = { version = "1.0.202", features = ["derive"] } +serde_json = "1.0.119" +serde_regex = "1.1.0" +serde_with = "3.8.1" +sha2 = { version = "0.10", optional = true } +shellexpand = "3" +shlex = "1.3.0" +sql-builder = "3" +sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "time", "postgres", "uuid", "sqlite", "regexp"] } +strum = { version = "0.27", features = ["strum_macros"] } +strum_macros = "0.27" +sysinfo = "0.30.7" +tempfile = { version = "3.19" } +thiserror = "2" +time = { version = "0.3.47", features = [ "serde-human-readable", "macros", "local-offset", "macros", "formatting", "parsing"] } +tokio = { version = "1", features = ["full"] } +tokio-stream = { version = "0.1.14", features = ["net"] } +toml_edit = "0.25.4" +tonic = "0.14" +tonic-prost = "0.14" +tonic-types = "0.14" +tower = "0.5" +tower-http = { version = "0.6", features = ["trace"] } +tracing = "0.1" +tracing-appender = "0.2" +tracing-subscriber = { version = "0.3", features = ["ansi", "fmt", "registry", "env-filter", "json"] } +typed-builder = "0.18.2" +unicode-segmentation = "1.11.0" +unicode-width = "0.2" +url = "2.5.2" +urlencoding = { version = "2.1.0", optional = true } +uuid = { version = "1.9", features = ["v4", "v7", "serde"] } +vt100 = "0.16" +whoami = "2.1.0" +xxhash-rust = { version = "0.8", features = ["xxh3"] } + +[target.'cfg(target_os = "linux")'.dependencies] +arboard = { version = "3.4", optional = true, default-features = false, features = [ "wayland-data-control", ] } +listenfd = "1.0.1" + +[target.'cfg(unix)'.dependencies] +daemonize = "0.5.0" +portable-pty = "0.9" +signal-hook = "0.3" + +[dev-dependencies] +tracing-tree = "0.4" +divan = "0.1.14" +tokio = { version = "1", features = ["full"] } +testing_logger = "0.1.1" + +[build-dependencies] +protox = "0.9" +tonic-build = "0.14" +tonic-prost-build = "0.14" diff --git a/crates/turtle/build.rs b/crates/turtle/build.rs new file mode 100644 index 00000000..5f26e26c --- /dev/null +++ b/crates/turtle/build.rs @@ -0,0 +1,39 @@ +use std::process::Command; +use std::{env, fs, path::PathBuf}; + +use protox::prost::Message; + +fn main() -> Result<(), std::io::Error> { + { + let output = Command::new("git").args(["rev-parse", "HEAD"]).output(); + + let sha = match output { + Ok(sha) => String::from_utf8(sha.stdout).unwrap(), + Err(_) => String::from("NO_GIT"), + }; + + println!("cargo:rustc-env=GIT_HASH={sha}"); + } + + { + let proto_paths = [ + "proto/history.proto", + "proto/search.proto", + "proto/control.proto", + "proto/semantic.proto", + ]; + let proto_include_dirs = ["proto"]; + + let file_descriptors = protox::compile(proto_paths, proto_include_dirs).unwrap(); + + let file_descriptor_path = PathBuf::from(env::var_os("OUT_DIR").expect("OUT_DIR not set")) + .join("file_descriptor_set.bin"); + fs::write(&file_descriptor_path, file_descriptors.encode_to_vec()).unwrap(); + + tonic_prost_build::configure() + .build_server(true) + .file_descriptor_set_path(&file_descriptor_path) + .skip_protoc_run() + .compile_protos(&proto_paths, &proto_include_dirs) + } +} diff --git a/crates/turtle/proto/control.proto b/crates/turtle/proto/control.proto new file mode 100644 index 00000000..06347902 --- /dev/null +++ b/crates/turtle/proto/control.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; +package control; + +// The Control service allows external processes (CLI commands, etc.) +// to inject events into the running daemon. +service Control { + // Send an event to the daemon's event bus + rpc SendEvent(SendEventRequest) returns (SendEventResponse); +} + +message SendEventRequest { + oneof event { + // History was pruned - search index needs full rebuild + HistoryPrunedEvent history_pruned = 1; + + // Specific history items were deleted + HistoryDeletedEvent history_deleted = 2; + + // Request immediate sync + ForceSyncEvent force_sync = 3; + + // Settings have changed, reload if needed + SettingsReloadedEvent settings_reloaded = 4; + + // Request graceful shutdown + ShutdownEvent shutdown = 5; + + // History was rebuilt - search index needs full rebuild + HistoryRebuiltEvent history_rebuilt = 6; + } +} + +message SendEventResponse { + // Empty on success; errors come through gRPC status +} + +// Individual event message types + +message HistoryPrunedEvent { + // No fields needed - just signals that pruning happened +} + +message HistoryRebuiltEvent { + // No fields needed - just signals that rebuilding happened +} + +message HistoryDeletedEvent { + // IDs of deleted history items (UUIDs as strings) + repeated string ids = 1; +} + +message ForceSyncEvent { + // No fields needed - just triggers sync +} + +message SettingsReloadedEvent { + // No fields needed - components should re-read settings +} + +message ShutdownEvent { + // No fields needed - triggers graceful shutdown +} diff --git a/crates/turtle/proto/history.proto b/crates/turtle/proto/history.proto new file mode 100644 index 00000000..59c12471 --- /dev/null +++ b/crates/turtle/proto/history.proto @@ -0,0 +1,81 @@ +syntax = "proto3"; +package history; + +message StartHistoryRequest { + // If people are still using my software in ~530 years, they can figure out a u128 migration + uint64 timestamp = 1; // nanosecond unix epoch + string command = 2; + string cwd = 3; + string session = 4; + string hostname = 5; + string author = 6; + string intent = 7; +} + +message EndHistoryRequest { + string id = 1; + int64 exit = 2; + uint64 duration = 3; +} + +message StartHistoryReply { + string id = 1; + string version = 2; + uint32 protocol = 3; +} + +message EndHistoryReply { + string id = 1; + uint64 idx = 2; + string version = 3; + uint32 protocol = 4; +} + +message StatusRequest {} + +message StatusReply { + bool healthy = 1; + string version = 2; + uint32 pid = 3; + uint32 protocol = 4; +} + +message ShutdownRequest {} + +message ShutdownReply { + bool accepted = 1; +} + +message TailHistoryRequest {} + +enum HistoryEventKind { + HISTORY_EVENT_KIND_UNSPECIFIED = 0; + HISTORY_EVENT_KIND_STARTED = 1; + HISTORY_EVENT_KIND_ENDED = 2; +} + +message HistoryEntry { + uint64 timestamp = 1; // nanosecond unix epoch + string id = 2; + string command = 3; + string cwd = 4; + string session = 5; + string hostname = 6; + string author = 7; + string intent = 8; + int64 exit = 9; + int64 duration = 10; +} + +message TailHistoryReply { + HistoryEventKind kind = 1; + HistoryEntry history = 2; +} + +service History { + rpc StartHistory(StartHistoryRequest) returns (StartHistoryReply); + rpc EndHistory(EndHistoryRequest) returns (EndHistoryReply); + rpc TailHistory(TailHistoryRequest) returns (stream TailHistoryReply); + rpc Status(StatusRequest) returns (StatusReply); + rpc Shutdown(ShutdownRequest) returns (ShutdownReply); +} diff --git a/crates/turtle/proto/search.proto b/crates/turtle/proto/search.proto new file mode 100644 index 00000000..6b84acbd --- /dev/null +++ b/crates/turtle/proto/search.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; +package search; + +enum FilterMode { + GLOBAL = 0; + HOST = 1; + SESSION = 2; + DIRECTORY = 3; + WORKSPACE = 4; + SESSION_PRELOAD = 5; +} + +message SearchContext { + string session_id = 1; + string cwd = 2; + string hostname = 3; + string host_id = 4; + optional string git_root = 5; +} + +message SearchRequest { + string query = 1; + uint64 query_id = 2; // Incrementing ID to match responses to queries + FilterMode filter_mode = 3; + SearchContext context = 4; +} + +message SearchResponse { + uint64 query_id = 1; // Echo back the query ID + repeated bytes ids = 2; +} + +service Search { + rpc Search(stream SearchRequest) returns (stream SearchResponse); +} diff --git a/crates/turtle/proto/semantic.proto b/crates/turtle/proto/semantic.proto new file mode 100644 index 00000000..07e550c8 --- /dev/null +++ b/crates/turtle/proto/semantic.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; +package semantic; + +service Semantic { + rpc RecordCommands(stream CommandCapture) returns (RecordCommandsReply); + rpc CommandOutput(CommandOutputRequest) returns (CommandOutputReply); +} + +message CommandCapture { + string prompt = 1; + string command = 2; + string output = 3; + optional int32 exit_code = 4; + optional string history_id = 5; + optional string session_id = 6; + bool output_truncated = 7; + uint64 output_observed_bytes = 8; +} + +message RecordCommandsReply { + uint64 accepted = 1; +} + +message CommandOutputRequest { + string history_id = 1; + repeated OutputRange ranges = 2; +} + +message OutputRange { + int64 start = 1; + int64 end = 2; +} + +message OutputLine { + uint64 line_number = 1; + string content = 2; +} + +message CommandOutputReply { + bool found = 1; + string output = 2; + uint64 total_bytes = 3; + uint64 total_lines = 4; + repeated OutputLine lines = 5; + bool output_truncated = 6; + uint64 output_observed_bytes = 7; +} diff --git a/crates/turtle/src/atuin_client/api_client.rs b/crates/turtle/src/atuin_client/api_client.rs new file mode 100644 index 00000000..7955c2da --- /dev/null +++ b/crates/turtle/src/atuin_client/api_client.rs @@ -0,0 +1,438 @@ +use std::collections::HashMap; +use std::env; +use std::time::Duration; + +use eyre::{Result, bail, eyre}; +use reqwest::{ + Response, StatusCode, Url, + header::{AUTHORIZATION, HeaderMap, USER_AGENT}, +}; +use tracing::debug; + +use crate::atuin_common::{ + api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION}, + record::{EncryptedData, HostId, Record, RecordIdx}, + tls::ensure_crypto_provider, +}; +use crate::atuin_common::{ + api::{ + AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest, + ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse, StatusResponse, + SyncHistoryResponse, + }, + record::RecordStatus, +}; + +use semver::Version; +use time::OffsetDateTime; +use time::format_description::well_known::Rfc3339; + +use crate::atuin_client::{history::History, sync::hash_str, utils::get_host_user}; + +static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),); + +/// Authentication token for sync API requests. +/// +/// The sync API supports two authentication methods: +/// - `Bearer`: Hub API tokens (for users authenticated via Atuin Hub) +/// - `Token`: Legacy CLI session tokens (for users registered via CLI or self-hosted) +/// +/// When both are available, Hub tokens are preferred as they provide unified +/// authentication across CLI and Hub features. +#[derive(Debug, Clone)] +pub enum AuthToken { + /// Legacy CLI session token, used with "Token {token}" header + Token(String), +} + +impl AuthToken { + /// Format the token as an Authorization header value + fn to_header_value(&self) -> String { + match self { + AuthToken::Token(token) => format!("Token {token}"), + } + } +} + +pub struct Client<'a> { + sync_addr: &'a str, + client: reqwest::Client, +} + +fn make_url(address: &str, path: &str) -> Result<String> { + // `join()` expects a trailing `/` in order to join paths + // e.g. it treats `http://host:port/subdir` as a file called `subdir` + let address = if address.ends_with("/") { + address + } else { + &format!("{address}/") + }; + + // passing a path with a leading `/` will cause `join()` to replace the entire URL path + let path = path.strip_prefix("/").unwrap_or(path); + + let url = Url::parse(address) + .map(|url| url.join(path))? + .map_err(|_| eyre!("invalid address"))?; + + Ok(url.to_string()) +} + +pub async fn register( + address: &str, + username: &str, + email: &str, + password: &str, +) -> Result<RegisterResponse> { + ensure_crypto_provider(); + let mut map = HashMap::new(); + map.insert("username", username); + map.insert("email", email); + map.insert("password", password); + + let url = make_url(address, &format!("/user/{username}"))?; + let resp = reqwest::get(url).await?; + + if resp.status().is_success() { + bail!("username already in use"); + } + + let url = make_url(address, "/register")?; + let client = reqwest::Client::new(); + let resp = client + .post(url) + .header(USER_AGENT, APP_USER_AGENT) + .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION) + .json(&map) + .send() + .await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not register user due to version mismatch"); + } + + let session = resp.json::<RegisterResponse>().await?; + Ok(session) +} + +pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> { + ensure_crypto_provider(); + let url = make_url(address, "/login")?; + let client = reqwest::Client::new(); + + let resp = client + .post(url) + .header(USER_AGENT, APP_USER_AGENT) + .json(&req) + .send() + .await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("Could not login due to version mismatch"); + } + + let session = resp.json::<LoginResponse>().await?; + Ok(session) +} + +pub fn ensure_version(response: &Response) -> Result<bool> { + let version = response.headers().get(ATUIN_HEADER_VERSION); + + let version = if let Some(version) = version { + match version.to_str() { + Ok(v) => Version::parse(v), + Err(e) => bail!("failed to parse server version: {:?}", e), + } + } else { + bail!("Server not reporting its version: it is either too old or unhealthy"); + }?; + + // If the client is newer than the server + if version.major < ATUIN_VERSION.major { + println!( + "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin" + ); + println!("Client: {ATUIN_CARGO_VERSION}"); + println!("Server: {version}"); + + return Ok(false); + } + + Ok(true) +} + +async fn handle_resp_error(resp: Response) -> Result<Response> { + let status = resp.status(); + let url = resp.url().to_string(); + + if status == StatusCode::SERVICE_UNAVAILABLE { + bail!( + "Service unavailable: check https://status.atuin.sh (or get in touch with your host)" + ); + } + + if status == StatusCode::TOO_MANY_REQUESTS { + bail!("Rate limited; please wait before doing that again"); + } + + if !status.is_success() { + if let Ok(error) = resp.json::<ErrorResponse>().await { + let reason = error.reason; + + if status.is_client_error() { + bail!("Invalid request to the service at {url}, {status} - {reason}.") + } + + bail!( + "There was an error with the atuin sync service at {url}, server error {status}: {reason}.\nIf the problem persists, contact the host" + ) + } + + bail!( + "There was an error with the atuin sync service at {url}, Status {status:?}.\nIf the problem persists, contact the host" + ) + } + + Ok(resp) +} + +impl<'a> Client<'a> { + pub fn new( + sync_addr: &'a str, + auth: AuthToken, + connect_timeout: u64, + timeout: u64, + ) -> Result<Self> { + ensure_crypto_provider(); + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, auth.to_header_value().parse()?); + + // used for semver server check + headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); + + Ok(Client { + sync_addr, + client: reqwest::Client::builder() + .user_agent(APP_USER_AGENT) + .default_headers(headers) + .connect_timeout(Duration::new(connect_timeout, 0)) + .timeout(Duration::new(timeout, 0)) + .build()?, + }) + } + + pub async fn count(&self) -> Result<i64> { + let url = make_url(self.sync_addr, "/sync/count")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync due to version mismatch"); + } + + if resp.status() != StatusCode::OK { + bail!("failed to get count (are you logged in?)"); + } + + let count = resp.json::<CountResponse>().await?; + + Ok(count.count) + } + + pub async fn status(&self) -> Result<StatusResponse> { + let url = make_url(self.sync_addr, "/sync/status")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync due to version mismatch"); + } + + let status = resp.json::<StatusResponse>().await?; + + Ok(status) + } + + pub async fn me(&self) -> Result<MeResponse> { + let url = make_url(self.sync_addr, "/api/v0/me")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let status = resp.json::<MeResponse>().await?; + + Ok(status) + } + + pub async fn get_history( + &self, + sync_ts: OffsetDateTime, + history_ts: OffsetDateTime, + host: Option<String>, + ) -> Result<SyncHistoryResponse> { + let host = host.unwrap_or_else(|| hash_str(&get_host_user())); + + let url = make_url( + self.sync_addr, + &format!( + "/sync/history?sync_ts={}&history_ts={}&host={}", + urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()), + urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()), + host, + ), + )?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let history = resp.json::<SyncHistoryResponse>().await?; + Ok(history) + } + + pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { + let url = make_url(self.sync_addr, "/history")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.post(url).json(history).send().await?; + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn delete_history(&self, h: History) -> Result<()> { + let url = make_url(self.sync_addr, "/history")?; + let url = Url::parse(url.as_str())?; + + let resp = self + .client + .delete(url) + .json(&DeleteHistoryRequest { + client_id: h.id.to_string(), + }) + .send() + .await?; + + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn delete_store(&self) -> Result<()> { + let url = make_url(self.sync_addr, "/api/v0/store")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.delete(url).send().await?; + + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> { + let url = make_url(self.sync_addr, "/api/v0/record")?; + let url = Url::parse(url.as_str())?; + + debug!("uploading {} records to {url}", records.len()); + + let resp = self.client.post(url).json(records).send().await?; + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn next_records( + &self, + host: HostId, + tag: String, + start: RecordIdx, + count: u64, + ) -> Result<Vec<Record<EncryptedData>>> { + debug!("fetching record/s from host {}/{}/{}", host.0, tag, start); + + let url = make_url( + self.sync_addr, + &format!( + "/api/v0/record/next?host={}&tag={}&count={}&start={}", + host.0, tag, count, start + ), + )?; + + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let records = resp.json::<Vec<Record<EncryptedData>>>().await?; + + Ok(records) + } + + pub async fn record_status(&self) -> Result<RecordStatus> { + let url = make_url(self.sync_addr, "/api/v0/record")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync records due to version mismatch"); + } + + let index = resp.json().await?; + + debug!("got remote index {index:?}"); + + Ok(index) + } + + pub async fn delete(&self) -> Result<()> { + let url = make_url(self.sync_addr, "/account")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.delete(url).send().await?; + + if resp.status() == 403 { + bail!("invalid login details"); + } else if resp.status() == 200 { + Ok(()) + } else { + bail!("Unknown error"); + } + } + + pub async fn change_password( + &self, + current_password: String, + new_password: String, + ) -> Result<()> { + let url = make_url(self.sync_addr, "/account/password")?; + let url = Url::parse(url.as_str())?; + + let resp = self + .client + .patch(url) + .json(&ChangePasswordRequest { + current_password, + new_password, + }) + .send() + .await?; + + if resp.status() == 401 { + bail!("current password is incorrect") + } else if resp.status() == 403 { + bail!("invalid login details"); + } else if resp.status() == 200 { + Ok(()) + } else { + bail!("Unknown error"); + } + } +} diff --git a/crates/turtle/src/atuin_client/auth.rs b/crates/turtle/src/atuin_client/auth.rs new file mode 100644 index 00000000..b770c488 --- /dev/null +++ b/crates/turtle/src/atuin_client/auth.rs @@ -0,0 +1,223 @@ +use async_trait::async_trait; +use eyre::{Context, Result, bail}; +use reqwest::{Url, header::USER_AGENT}; + +use crate::{ + atuin_client::api_client, + atuin_common::{ + api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ChangePasswordRequest, LoginRequest}, + tls::ensure_crypto_provider, + }, +}; + +use crate::atuin_client::settings::Settings; + +static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION")); + +/// Result of an auth operation that may require 2FA. +pub enum AuthResponse { + /// Operation succeeded; for login/register, contains the session token. + /// `auth_type` indicates the kind of token: `Some("hub")` for Hub API + /// tokens (prefixed `atapi_`), `Some("cli")` for legacy CLI session + /// tokens. `None` when the server didn't include the field (old servers). + Success { + session: String, + auth_type: Option<String>, + }, + /// Two-factor authentication is required; the caller should prompt for a + /// TOTP code and retry with it. + TwoFactorRequired, +} + +/// Result of a mutating account operation that may require 2FA. +pub enum MutateResponse { + /// Operation completed successfully. + Success, + /// Two-factor authentication is required; the caller should prompt for a + /// TOTP code and retry. + TwoFactorRequired, +} + +/// Abstraction over the legacy (Rust sync server) and Hub auth APIs. +/// +/// CLI commands use this trait so they don't need to know which backend is +/// active — they just prompt for input and call these methods. +#[async_trait] +pub trait AuthClient: Send + Sync { + /// Log in with username + password, optionally providing a TOTP code. + async fn login(&self, username: &str, password: &str) -> Result<AuthResponse>; + + /// Register a new account. + async fn register(&self, username: &str, email: &str, password: &str) -> Result<AuthResponse>; + + /// Change the account password, optionally providing a TOTP code. + async fn change_password( + &self, + current_password: &str, + new_password: &str, + totp_code: Option<&str>, + ) -> Result<MutateResponse>; + + /// Delete the account, requiring the current password and optionally a TOTP code. + async fn delete_account( + &self, + password: &str, + totp_code: Option<&str>, + ) -> Result<MutateResponse>; +} + +/// Resolve the appropriate [`AuthClient`] for the current settings. +pub async fn auth_client(settings: &Settings) -> Box<dyn AuthClient> { + Box::new(LegacyAuthClient::new( + &settings.sync_address, + settings.session_token().await.ok(), + settings.network_connect_timeout, + settings.network_timeout, + )) as Box<dyn AuthClient> +} + +// --------------------------------------------------------------------------- +// Legacy backend — talks to the Rust sync server +// --------------------------------------------------------------------------- + +pub struct LegacyAuthClient { + address: String, + session_token: Option<String>, + connect_timeout: u64, + timeout: u64, +} + +impl LegacyAuthClient { + pub fn new( + address: &str, + session_token: Option<String>, + connect_timeout: u64, + timeout: u64, + ) -> Self { + Self { + address: address.to_string(), + session_token, + connect_timeout, + timeout, + } + } + + fn authenticated_client(&self) -> Result<reqwest::Client> { + let token = self + .session_token + .as_deref() + .ok_or_else(|| eyre::eyre!("Not logged in"))?; + + ensure_crypto_provider(); + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::AUTHORIZATION, + format!("Token {token}").parse()?, + ); + headers.insert(USER_AGENT, APP_USER_AGENT.parse()?); + headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); + + Ok(reqwest::Client::builder() + .default_headers(headers) + .connect_timeout(std::time::Duration::new(self.connect_timeout, 0)) + .timeout(std::time::Duration::new(self.timeout, 0)) + .build()?) + } +} + +#[async_trait] +impl AuthClient for LegacyAuthClient { + async fn login(&self, username: &str, password: &str) -> Result<AuthResponse> { + // The legacy server has no 2FA support; totp_code is ignored. + let resp = api_client::login( + &self.address, + LoginRequest { + username: username.to_string(), + password: password.to_string(), + }, + ) + .await?; + + Ok(AuthResponse::Success { + session: resp.session, + auth_type: resp.auth.or(Some("cli".into())), + }) + } + + async fn register(&self, username: &str, email: &str, password: &str) -> Result<AuthResponse> { + let resp = api_client::register(&self.address, username, email, password).await?; + Ok(AuthResponse::Success { + session: resp.session, + auth_type: resp.auth.or(Some("cli".into())), + }) + } + + async fn change_password( + &self, + current_password: &str, + new_password: &str, + _totp_code: Option<&str>, + ) -> Result<MutateResponse> { + let client = self.authenticated_client()?; + let url = make_url(&self.address, "/account/password")?; + + let resp = client + .patch(&url) + .json(&ChangePasswordRequest { + current_password: current_password.to_string(), + new_password: new_password.to_string(), + }) + .send() + .await?; + + match resp.status().as_u16() { + 200 => Ok(MutateResponse::Success), + 401 => bail!("current password is incorrect"), + 403 => bail!("invalid login details"), + _ => bail!("unknown error"), + } + } + + async fn delete_account( + &self, + password: &str, + _totp_code: Option<&str>, + ) -> Result<MutateResponse> { + let client = self.authenticated_client()?; + let url = make_url(&self.address, "/account")?; + + let resp = client + .delete(&url) + .json(&serde_json::json!({ "password": password })) + .send() + .await?; + + match resp.status().as_u16() { + 200 => Ok(MutateResponse::Success), + 401 => bail!("password is incorrect"), + 403 => bail!("invalid login details"), + _ => bail!("unknown error"), + } + } +} + +// --------------------------------------------------------------------------- +// Shared helpers +// --------------------------------------------------------------------------- + +fn make_url(address: &str, path: &str) -> Result<String> { + let address = if address.ends_with('/') { + address.to_string() + } else { + format!("{address}/") + }; + + let path = path.strip_prefix('/').unwrap_or(path); + + let url = Url::parse(&address) + .context("failed to parse server address")? + .join(path) + .context("failed to join URL path")?; + + Ok(url.to_string()) +} diff --git a/crates/turtle/src/atuin_client/database.rs b/crates/turtle/src/atuin_client/database.rs new file mode 100644 index 00000000..75b1200c --- /dev/null +++ b/crates/turtle/src/atuin_client/database.rs @@ -0,0 +1,1526 @@ +use std::{ + env, + path::{Path, PathBuf}, + str::FromStr, + time::Duration, +}; + +use crate::atuin_client::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS}; +use crate::atuin_common::utils; +use async_trait::async_trait; +use fs_err as fs; +use itertools::Itertools; +use rand::{Rng, distributions::Alphanumeric}; +use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote}; +use sqlx::{ + Result, Row, + sqlite::{ + SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, + SqliteSynchronous, + }, +}; +use time::OffsetDateTime; +use tracing::debug; +use uuid::Uuid; + +use crate::atuin_client::{ + history::{HistoryId, HistoryStats}, + utils::get_host_user, +}; + +use super::{ + history::History, + ordering, + settings::{FilterMode, SearchMode, Settings}, +}; + +#[derive(Clone)] +pub struct Context { + pub session: String, + pub cwd: String, + pub hostname: String, + pub host_id: String, + pub git_root: Option<PathBuf>, +} + +#[derive(Default, Clone)] +pub struct OptFilters { + pub exit: Option<i64>, + pub exclude_exit: Option<i64>, + pub cwd: Option<String>, + pub exclude_cwd: Option<String>, + pub before: Option<String>, + pub after: Option<String>, + pub limit: Option<i64>, + pub offset: Option<i64>, + pub reverse: bool, + pub include_duplicates: bool, + /// Author filter. Supports special values `$all-user` and `$all-agent`. + pub authors: Vec<String>, +} + +pub async fn current_context() -> eyre::Result<Context> { + let session = env::var("ATUIN_SESSION").map_err(|_| { + eyre::eyre!("Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell.") + })?; + let hostname = get_host_user(); + let cwd = utils::get_current_dir(); + let host_id = Settings::host_id().await?; + let git_root = utils::in_git_repo(cwd.as_str()); + + Ok(Context { + session, + hostname, + cwd, + git_root, + host_id: host_id.0.as_simple().to_string(), + }) +} + +impl Context { + pub fn from_history(entry: &History) -> Self { + Context { + session: entry.session.to_string(), + cwd: entry.cwd.to_string(), + hostname: entry.hostname.to_string(), + host_id: String::new(), + git_root: utils::in_git_repo(entry.cwd.as_str()), + } + } +} + +/// Each entry is OR'd: `$all-user` → NOT IN agents, `$all-agent` → IN agents, literal → exact match. +fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) { + let mut conditions: Vec<String> = Vec::new(); + let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", "); + let author_expr = "CASE \ + WHEN author IS NULL OR trim(author) = '' THEN \ + CASE \ + WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \ + ELSE hostname \ + END \ + ELSE author \ + END"; + + for author in authors { + match author.as_str() { + AUTHOR_FILTER_ALL_USER => { + conditions.push(format!("{author_expr} NOT IN ({agent_list})")); + } + AUTHOR_FILTER_ALL_AGENT => { + conditions.push(format!("{author_expr} IN ({agent_list})")); + } + literal => { + conditions.push(format!("{author_expr} = {}", quote(literal))); + } + } + } + + if !conditions.is_empty() { + sql.and_where(format!("({})", conditions.join(" OR "))); + } +} + +fn get_session_start_time(session_id: &str) -> Option<i64> { + if let Ok(uuid) = Uuid::parse_str(session_id) + && let Some(timestamp) = uuid.get_timestamp() + { + let (seconds, nanos) = timestamp.to_unix(); + return Some(seconds as i64 * 1_000_000_000 + nanos as i64); + } + None +} + +#[async_trait] +pub trait Database: Send + Sync + 'static { + async fn save(&self, h: &History) -> Result<()>; + async fn save_bulk(&self, h: &[History]) -> Result<()>; + + async fn load(&self, id: &str) -> Result<Option<History>>; + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option<usize>, + unique: bool, + include_deleted: bool, + ) -> Result<Vec<History>>; + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>; + + async fn update(&self, h: &History) -> Result<()>; + async fn history_count(&self, include_deleted: bool) -> Result<i64>; + + async fn last(&self) -> Result<Option<History>>; + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>; + + async fn delete(&self, h: History) -> Result<()>; + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>; + async fn deleted(&self) -> Result<Vec<History>>; + + // Yes I know, it's a lot. + // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. + // Been debating maybe a DSL for search? eg "before:time limit:1 the query" + #[expect(clippy::too_many_arguments)] + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result<Vec<History>>; + + async fn query_history(&self, query: &str) -> Result<Vec<History>>; + + async fn all_with_count(&self) -> Result<Vec<(History, i32)>>; + + fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged; + + async fn stats(&self, h: &History) -> Result<HistoryStats>; + + async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>; + + fn clone_boxed(&self) -> Box<dyn Database + 'static>; +} + +// Intended for use on a developer machine and not a sync server. +// TODO: implement IntoIterator +#[derive(Debug, Clone)] +pub struct Sqlite { + pub pool: SqlitePool, +} + +impl Sqlite { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + debug!("opening sqlite database at {path:?}"); + + if utils::broken_symlink(path) { + eprintln!( + "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + ); + std::process::exit(1); + } + + if !path.exists() + && let Some(dir) = path.parent() + { + fs::create_dir_all(dir)?; + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .optimize_on_close(true, None) + .synchronous(SqliteSynchronous::Normal) + .with_regexp() + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + Ok(Self { pool }) + } + + pub async fn sqlite_version(&self) -> Result<String> { + sqlx::query_scalar("SELECT sqlite_version()") + .fetch_one(&self.pool) + .await + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { + sqlx::query( + "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, author, intent, deleted_at) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.author.as_str()) + .bind(h.intent.as_deref()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + async fn delete_row_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + id: HistoryId, + ) -> Result<()> { + sqlx::query("delete from history where id = ?1") + .bind(id.0.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_history(row: SqliteRow) -> History { + let deleted_at: Option<i64> = row.get("deleted_at"); + let hostname: String = row.get("hostname"); + let author: Option<String> = row.try_get("author").ok().flatten(); + let author = author + .filter(|author| !author.trim().is_empty()) + .unwrap_or_else(|| History::author_from_hostname(hostname.as_str())); + let intent: Option<String> = row.try_get("intent").ok().flatten(); + let intent = intent.filter(|intent| !intent.trim().is_empty()); + + History::from_db() + .id(row.get("id")) + .timestamp( + OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128) + .unwrap(), + ) + .duration(row.get("duration")) + .exit(row.get("exit")) + .command(row.get("command")) + .cwd(row.get("cwd")) + .session(row.get("session")) + .hostname(hostname) + .author(author) + .intent(intent) + .deleted_at( + deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()), + ) + .build() + .into() + } +} + +#[async_trait] +impl Database for Sqlite { + async fn save(&self, h: &History) -> Result<()> { + debug!("saving history to sqlite"); + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, h).await?; + tx.commit().await?; + + Ok(()) + } + + async fn save_bulk(&self, h: &[History]) -> Result<()> { + debug!("saving history to sqlite"); + + let mut tx = self.pool.begin().await?; + + for i in h { + Self::save_raw(&mut tx, i).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn load(&self, id: &str) -> Result<Option<History>> { + debug!("loading history item {}", id); + + let res = sqlx::query("select * from history where id = ?1") + .bind(id) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn update(&self, h: &History) -> Result<()> { + debug!("updating sqlite history"); + + sqlx::query( + "update history + set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, author = ?9, intent = ?10, deleted_at = ?11 + where id = ?1", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.author.as_str()) + .bind(h.intent.as_deref()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&self.pool) + .await?; + + Ok(()) + } + + // make a unique list, that only shows the *newest* version of things + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option<usize>, + unique: bool, + include_deleted: bool, + ) -> Result<Vec<History>> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + query.field("*").order_desc("timestamp"); + if !include_deleted { + query.and_where_is_null("deleted_at"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + let session_start = get_session_start_time(&context.session); + + for filter in filters { + match filter { + FilterMode::Global => &mut query, + FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)), + FilterMode::Session => query.and_where_eq("session", quote(&context.session)), + FilterMode::SessionPreload => { + query.and_where_eq("session", quote(&context.session)); + if let Some(session_start) = session_start { + query.or_where_lt("timestamp", session_start); + } + &mut query + } + FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => query.and_where_like_left("cwd", &git_root), + }; + } + + if unique { + query.group_by("command").having("max(timestamp)"); + } + + if let Some(max) = max { + query.limit(max); + } + + let query = query.sql().expect("bug in list query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> { + debug!("listing history from {:?} to {:?}", from, to); + + let res = sqlx::query( + "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", + ) + .bind(from.unix_timestamp_nanos() as i64) + .bind(to.unix_timestamp_nanos() as i64) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn last(&self) -> Result<Option<History>> { + let res = sqlx::query( + "select * from history where duration >= 0 order by timestamp desc limit 1", + ) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> { + let res = sqlx::query( + "select * from history where timestamp < ?1 order by timestamp desc limit ?2", + ) + .bind(timestamp.unix_timestamp_nanos() as i64) + .bind(count) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn deleted(&self) -> Result<Vec<History>> { + let res = sqlx::query("select * from history where deleted_at is not null") + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn history_count(&self, include_deleted: bool) -> Result<i64> { + let query = if include_deleted { + "select count(1) from history" + } else { + "select count(1) from history where deleted_at is null" + }; + + let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?; + Ok(res.0) + } + + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result<Vec<History>> { + let mut sql = SqlBuilder::select_from("history"); + + if !filter_options.include_duplicates { + sql.group_by("command").having("max(timestamp)"); + } + + if let Some(limit) = filter_options.limit { + sql.limit(limit); + } + + if let Some(offset) = filter_options.offset { + sql.offset(offset); + } + + if filter_options.reverse { + sql.order_asc("timestamp"); + } else { + sql.order_desc("timestamp"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + let session_start = get_session_start_time(&context.session); + + match filter { + FilterMode::Global => &mut sql, + FilterMode::Host => { + sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase())) + } + FilterMode::Session => sql.and_where_eq("session", quote(&context.session)), + FilterMode::SessionPreload => { + sql.and_where_eq("session", quote(&context.session)); + if let Some(session_start) = session_start { + sql.or_where_lt("timestamp", session_start); + } + &mut sql + } + FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => sql.and_where_like_left("cwd", git_root), + }; + + let orig_query = query; + + let mut regexes = Vec::new(); + match search_mode { + SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")), + _ => { + let mut is_or = false; + for token in QueryTokenizer::new(query) { + // TODO smart case mode could be made configurable like in fzf + let (is_glob, glob) = if token.has_uppercase() { + (true, "*") + } else { + (false, "%") + }; + let param = match token { + QueryToken::Regex(r) => { + regexes.push(String::from(r)); + continue; + } + QueryToken::Or => { + if !is_or { + is_or = true; + continue; + } else { + format!("{glob}|{glob}") + } + } + QueryToken::MatchStart(term, _) => { + format!("{term}{glob}") + } + QueryToken::MatchEnd(term, _) => { + format!("{glob}{term}") + } + QueryToken::MatchFull(term, _) => { + format!("{glob}{term}{glob}") + } + QueryToken::Match(term, _) => { + if search_mode == SearchMode::FullText { + format!("{glob}{term}{glob}") + } else { + term.split("").join(glob) + } + } + }; + + sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or); + is_or = false; + } + + &mut sql + } + }; + + for regex in regexes { + sql.and_where("command regexp ?".bind(®ex)); + } + + filter_options + .exit + .map(|exit| sql.and_where_eq("exit", exit)); + + filter_options + .exclude_exit + .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit)); + + filter_options + .cwd + .map(|cwd| sql.and_where_eq("cwd", quote(cwd))); + + filter_options + .exclude_cwd + .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd))); + + filter_options.before.map(|before| { + interim::parse_date_string( + before.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + ) + .map(|before| { + sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64)) + }) + }); + + filter_options.after.map(|after| { + interim::parse_date_string( + after.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + ) + .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) + }); + + if !filter_options.authors.is_empty() { + apply_author_filter(&mut sql, &filter_options.authors); + } + + sql.and_where_is_null("deleted_at"); + + let query = sql.sql().expect("bug in search query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(ordering::reorder_fuzzy(search_mode, orig_query, res)) + } + + async fn query_history(&self, query: &str) -> Result<Vec<History>> { + let res = sqlx::query(query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn all_with_count(&self) -> Result<Vec<(History, i32)>> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + + query + .fields(&[ + "id", + "max(timestamp) as timestamp", + "max(duration) as duration", + "exit", + "command", + "deleted_at", + "null as author", + "null as intent", + "group_concat(cwd, ':') as cwd", + "group_concat(session) as session", + "group_concat(hostname, ',') as hostname", + "count(*) as count", + ]) + .group_by("command") + .group_by("exit") + .and_where("deleted_at is null") + .order_desc("timestamp"); + + let query = query.sql().expect("bug in list query. please report"); + + let res = sqlx::query(&query) + .map(|row: SqliteRow| { + let count: i32 = row.get("count"); + (Self::query_history(row), count) + }) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { + Paged::new(Box::new(self.clone()), page_size, include_deleted, unique) + } + + // deleted_at doesn't mean the actual time that the user deleted it, + // but the time that the system marks it as deleted + async fn delete(&self, mut h: History) -> Result<()> { + let now = OffsetDateTime::now_utc(); + h.command = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(); // overwrite with random string + h.deleted_at = Some(now); // delete it + + self.update(&h).await?; // save it + + Ok(()) + } + + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for id in ids { + Self::delete_row_raw(&mut tx, id.clone()).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn stats(&self, h: &History) -> Result<HistoryStats> { + // We select the previous in the session by time + let mut prev = SqlBuilder::select_from("history"); + prev.field("*") + .and_where("timestamp < ?1") + .and_where("session = ?2") + .order_by("timestamp", true) + .limit(1); + + let mut next = SqlBuilder::select_from("history"); + next.field("*") + .and_where("timestamp > ?1") + .and_where("session = ?2") + .order_by("timestamp", false) + .limit(1); + + let mut total = SqlBuilder::select_from("history"); + total.field("count(1)").and_where("command = ?1"); + + let mut average = SqlBuilder::select_from("history"); + average.field("avg(duration)").and_where("command = ?1"); + + let mut exits = SqlBuilder::select_from("history"); + exits + .fields(&["exit", "count(1) as count"]) + .and_where("command = ?1") + .group_by("exit"); + + // rewrite the following with sqlbuilder + let mut day_of_week = SqlBuilder::select_from("history"); + day_of_week + .fields(&[ + "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week", + "count(1) as count", + ]) + .and_where("command = ?1") + .group_by("day_of_week"); + + // Intentionally format the string with 01 hardcoded. We want the average runtime for the + // _entire month_, but will later parse it as a datetime for sorting + // Sqlite has no datetime so we cannot do it there, and otherwise sorting will just be a + // string sort, which won't be correct. + let mut duration_over_time = SqlBuilder::select_from("history"); + duration_over_time + .fields(&[ + "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year", + "avg(duration) as duration", + ]) + .and_where("command = ?1") + .group_by("month_year") + .having("duration > 0"); + + let prev = prev.sql().expect("issue in stats previous query"); + let next = next.sql().expect("issue in stats next query"); + let total = total.sql().expect("issue in stats average query"); + let average = average.sql().expect("issue in stats previous query"); + let exits = exits.sql().expect("issue in stats exits query"); + let day_of_week = day_of_week.sql().expect("issue in stats day of week query"); + let duration_over_time = duration_over_time + .sql() + .expect("issue in stats duration over time query"); + + let prev = sqlx::query(&prev) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(&h.session) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + let next = sqlx::query(&next) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(&h.session) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + let total: (i64,) = sqlx::query_as(&total) + .bind(&h.command) + .fetch_one(&self.pool) + .await?; + + let average: (f64,) = sqlx::query_as(&average) + .bind(&h.command) + .fetch_one(&self.pool) + .await?; + + let exits: Vec<(i64, i64)> = sqlx::query_as(&exits) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let duration_over_time = duration_over_time + .iter() + .map(|f| (f.0.clone(), f.1.round() as i64)) + .collect(); + + Ok(HistoryStats { + next, + previous: prev, + total: total.0 as u64, + average_duration: average.0 as u64, + exits, + day_of_week, + duration_over_time, + }) + } + + async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> { + let res = sqlx::query( + "SELECT * FROM ( + SELECT *, ROW_NUMBER() + OVER (PARTITION BY command, cwd, hostname ORDER BY timestamp DESC) + AS rn + FROM history + ) sub + WHERE rn > ?1 and timestamp < ?2; + ", + ) + .bind(dupkeep) + .bind(before) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + fn clone_boxed(&self) -> Box<dyn Database + 'static> { + Box::new(self.clone()) + } +} + +pub struct Paged { + database: Box<dyn Database + 'static>, + page_size: usize, + last_id: Option<String>, + include_deleted: bool, + unique: bool, +} + +impl Paged { + pub fn new( + database: Box<dyn Database + 'static>, + page_size: usize, + include_deleted: bool, + unique: bool, + ) -> Self { + Self { + database, + page_size, + last_id: None, + include_deleted, + unique, + } + } + + pub async fn next(&mut self) -> Result<Option<Vec<History>>> { + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + + query.field("*").order_desc("id"); + + if !self.include_deleted { + query.and_where_is_null("deleted_at"); + } + + if self.unique { + // We want to deduplicate on command, but the user can search via cwd, hostname, and session. + // Without those fields, filter modes won't work right. With those fields, we get duplicates. + // This must be handled upstream. + query + .group_by("command, cwd, hostname, session") + .having("max(timestamp)"); + } + + query.limit(self.page_size); + + if let Some(last_id) = &self.last_id { + query.and_where_lt("id", quote(last_id)); + } + + let query = query.sql().expect("bug in list query. please report"); + let res = self.database.query_history(&query).await?; + + if res.is_empty() { + Ok(None) + } else { + self.last_id = Some(res.last().unwrap().id.0.clone()); + Ok(Some(res)) + } + } +} + +trait SqlBuilderExt { + fn fuzzy_condition<S: ToString, T: ToString>( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self; +} + +impl SqlBuilderExt for SqlBuilder { + /// adapted from the sql-builder *like functions + fn fuzzy_condition<S: ToString, T: ToString>( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self { + let mut cond = field.to_string(); + if inverse { + cond.push_str(" NOT"); + } + if glob { + cond.push_str(" GLOB '"); + } else { + cond.push_str(" LIKE '"); + } + cond.push_str(&esc(mask.to_string())); + cond.push('\''); + if is_or { + self.or_where(cond) + } else { + self.and_where(cond) + } + } +} + +#[cfg(test)] +mod test { + use crate::atuin_client::settings::test_local_timeout; + + use super::*; + use std::time::{Duration, Instant}; + + async fn assert_search_eq( + db: &impl Database, + mode: SearchMode, + filter_mode: FilterMode, + query: &str, + expected: usize, + ) -> Result<Vec<History>> { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), + git_root: None, + }; + + let results = db + .search( + mode, + filter_mode, + &context, + query, + OptFilters { + ..Default::default() + }, + ) + .await?; + + assert_eq!( + results.len(), + expected, + "query \"{}\", commands: {:?}", + query, + results.iter().map(|a| &a.command).collect::<Vec<&String>>() + ); + Ok(results) + } + + async fn assert_search_commands( + db: &impl Database, + mode: SearchMode, + filter_mode: FilterMode, + query: &str, + expected_commands: Vec<&str>, + ) { + let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len()) + .await + .unwrap(); + let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect(); + assert_eq!(commands, expected_commands); + } + + async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> { + let mut captured: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(cmd) + .cwd("/home/ellie") + .build() + .into(); + + captured.exit = 0; + captured.duration = 1; + captured.session = "beep boop".to_string(); + captured.hostname = "booop".to_string(); + + db.save(&captured).await + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_prefix() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_fulltext() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0) + .await + .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / ie$", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / !ie", + 0, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "meow r/ls/", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home//", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home///", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/home.*e", + 1, + ) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_fuzzy() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + new_history_item(&mut db, "ls /home/frank").await.unwrap(); + new_history_item(&mut db, "cd /home/Ellie").await.unwrap(); + new_history_item(&mut db, "/home/ellie/.bin/rustup") + .await + .unwrap(); + + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4) + .await + .unwrap(); + + // single term operators + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2) + .await + .unwrap(); + + // multiple terms + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup", + 2, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup 'ls", + 1, + ) + .await + .unwrap(); + + // case matching + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1) + .await + .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_reordered_fuzzy() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + // test ordering of results: we should choose the first, even though it happened longer ago. + + new_history_item(&mut db, "curl").await.unwrap(); + new_history_item(&mut db, "corburl").await.unwrap(); + + // if fuzzy reordering is on, it should come back in a more sensible order + assert_search_commands( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "curl", + vec!["curl", "corburl"], + ) + .await; + + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_basic() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add 5 history items + for i in 0..5 { + new_history_item(&mut db, &format!("command{}", i)) + .await + .unwrap(); + } + + // Create a paged iterator with page_size of 2 + let mut paged = db.all_paged(2, false, false); + + // First page should have 2 items + let page1 = paged.next().await.unwrap(); + assert!(page1.is_some()); + assert_eq!(page1.unwrap().len(), 2); + + // Second page should have 2 items + let page2 = paged.next().await.unwrap(); + assert!(page2.is_some()); + assert_eq!(page2.unwrap().len(), 2); + + // Third page should have 1 item + let page3 = paged.next().await.unwrap(); + assert!(page3.is_some()); + assert_eq!(page3.unwrap().len(), 1); + + // Fourth page should be None (exhausted) + let page4 = paged.next().await.unwrap(); + assert!(page4.is_none()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_empty() { + let db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Create a paged iterator on empty database + let mut paged = db.all_paged(10, false, false); + + // Should return None immediately + let page = paged.next().await.unwrap(); + assert!(page.is_none()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_unique() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add duplicate commands + new_history_item(&mut db, "duplicate").await.unwrap(); + new_history_item(&mut db, "duplicate").await.unwrap(); + new_history_item(&mut db, "unique1").await.unwrap(); + new_history_item(&mut db, "unique2").await.unwrap(); + + // Without unique flag - should get all 4 + let mut paged = db.all_paged(10, false, false); + let page = paged.next().await.unwrap().unwrap(); + assert_eq!(page.len(), 4); + + // With unique flag - should get 3 (duplicates collapsed) + let mut paged_unique = db.all_paged(10, false, true); + let page_unique = paged_unique.next().await.unwrap().unwrap(); + assert_eq!(page_unique.len(), 3); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_include_deleted() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add items + new_history_item(&mut db, "keep1").await.unwrap(); + new_history_item(&mut db, "keep2").await.unwrap(); + new_history_item(&mut db, "delete_me").await.unwrap(); + + // Delete one item + let all = db + .list( + &[], + &Context { + hostname: "".to_string(), + session: "".to_string(), + cwd: "".to_string(), + host_id: "".to_string(), + git_root: None, + }, + None, + false, + false, + ) + .await + .unwrap(); + + let to_delete = all + .iter() + .find(|h| h.command == "delete_me") + .unwrap() + .clone(); + db.delete(to_delete).await.unwrap(); + + // Without include_deleted - should get 2 + let mut paged = db.all_paged(10, false, false); + let page = paged.next().await.unwrap().unwrap(); + assert_eq!(page.len(), 2); + + // With include_deleted - should get 3 + let mut paged_deleted = db.all_paged(10, true, false); + let page_deleted = paged_deleted.next().await.unwrap().unwrap(); + assert_eq!(page_deleted.len(), 3); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_bench_dupes() { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), + git_root: None, + }; + + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + for _i in 1..10000 { + new_history_item(&mut db, "i am a duplicated command") + .await + .unwrap(); + } + let start = Instant::now(); + let _results = db + .search( + SearchMode::Fuzzy, + FilterMode::Global, + &context, + "", + OptFilters { + ..Default::default() + }, + ) + .await + .unwrap(); + let duration = start.elapsed(); + + assert!(duration < Duration::from_secs(15)); + } +} + +pub struct QueryTokenizer<'a> { + query: &'a str, + last_pos: usize, +} + +pub enum QueryToken<'a> { + Match(&'a str, bool), + MatchStart(&'a str, bool), + MatchEnd(&'a str, bool), + MatchFull(&'a str, bool), + Or, + Regex(&'a str), +} + +impl<'a> QueryToken<'a> { + pub fn has_uppercase(&self) -> bool { + match self { + Self::Match(term, _) + | Self::MatchStart(term, _) + | Self::MatchEnd(term, _) + | Self::MatchFull(term, _) => term.contains(char::is_uppercase), + _ => false, + } + } + + pub fn is_inverse(&self) -> bool { + match self { + Self::Match(_, inv) + | Self::MatchStart(_, inv) + | Self::MatchEnd(_, inv) + | Self::MatchFull(_, inv) => *inv, + _ => false, + } + } +} + +impl<'a> QueryTokenizer<'a> { + pub fn new(query: &'a str) -> Self { + Self { query, last_pos: 0 } + } +} + +impl<'a> Iterator for QueryTokenizer<'a> { + type Item = QueryToken<'a>; + fn next(&mut self) -> Option<Self::Item> { + let remaining = &self.query[self.last_pos..]; + if remaining.is_empty() { + return None; + } + + if let Some(remaining) = remaining.strip_prefix("r/") { + let (regex, next_pos) = if let Some(end) = remaining.find("/ ") { + (&remaining[..end], self.last_pos + 2 + end + 2) + } else if let Some(remaining) = remaining.strip_suffix('/') { + (remaining, self.query.len()) + } else { + (remaining, self.query.len()) + }; + self.last_pos = next_pos; + Some(QueryToken::Regex(regex)) + } else { + let (mut part, next_pos) = if let Some(sp) = remaining.find(' ') { + (&remaining[..sp], self.last_pos + sp + 1) + } else { + (remaining, self.query.len()) + }; + self.last_pos = next_pos; + + if part == "|" { + return Some(QueryToken::Or); + } + + let mut is_inverse = false; + if let Some(s) = part.strip_prefix('!') { + part = s; + is_inverse = true; + } + let token = if let Some(s) = part.strip_prefix('^') { + QueryToken::MatchStart(s, is_inverse) + } else if let Some(s) = part.strip_suffix('$') { + QueryToken::MatchEnd(s, is_inverse) + } else if let Some(s) = part.strip_prefix('\'') { + QueryToken::MatchFull(s, is_inverse) + } else { + QueryToken::Match(part, is_inverse) + }; + Some(token) + } + } +} diff --git a/crates/turtle/src/atuin_client/distro.rs b/crates/turtle/src/atuin_client/distro.rs new file mode 100644 index 00000000..dead8355 --- /dev/null +++ b/crates/turtle/src/atuin_client/distro.rs @@ -0,0 +1,89 @@ +use std::process::Command; + +/// Detect the Linux distribution from the system, +/// using system-specific release files and falling +/// back to lsb_release. +pub fn detect_linux_distribution() -> String { + detect_from_os_release() + .or_else(detect_from_debian_version) + .or_else(detect_from_centos_release) + .or_else(detect_from_redhat_release) + .or_else(detect_from_fedora_release) + .or_else(detect_from_arch_release) + .or_else(detect_from_alpine_release) + .or_else(detect_from_suse_release) + .or_else(detect_from_lsb_release) + .unwrap_or_else(|| "Unknown".to_string()) +} + +fn detect_from_os_release() -> Option<String> { + let content = std::fs::read_to_string("/etc/os-release").ok()?; + + content + .lines() + .find(|l| l.starts_with("PRETTY_NAME=")) + .and_then(|l| l.split_once('=').map(|s| s.1)) + .map(|s| s.trim_matches('"').to_string()) +} + +fn detect_from_debian_version() -> Option<String> { + std::fs::read_to_string("/etc/debian_version") + .ok() + .map(|v| format!("Debian {}", v.trim())) +} + +fn detect_from_centos_release() -> Option<String> { + std::fs::read_to_string("/etc/centos-release") + .ok() + .map(|v| v.trim().to_string()) +} + +fn detect_from_redhat_release() -> Option<String> { + std::fs::read_to_string("/etc/redhat-release") + .ok() + .map(|v| v.trim().to_string()) +} + +fn detect_from_fedora_release() -> Option<String> { + std::fs::read_to_string("/etc/fedora-release") + .ok() + .map(|v| v.trim().to_string()) +} + +fn detect_from_arch_release() -> Option<String> { + std::fs::read_to_string("/etc/arch-release") + .ok() + .filter(|v| !v.trim().is_empty()) + .map(|_| "Arch Linux".to_string()) +} + +fn detect_from_alpine_release() -> Option<String> { + std::fs::read_to_string("/etc/alpine-release") + .ok() + .map(|v| format!("Alpine {}", v.trim())) +} + +fn detect_from_suse_release() -> Option<String> { + std::fs::read_to_string("/etc/SuSE-release") + .ok() + .and_then(|content| content.lines().next().map(|l| l.trim().to_string())) +} + +fn detect_from_lsb_release() -> Option<String> { + let output = Command::new("lsb_release").arg("-a").output().ok()?; + + if !output.status.success() { + return None; + } + + let output = String::from_utf8(output.stdout).ok()?; + linux_distro_from_lsb_release(&output) +} + +fn linux_distro_from_lsb_release(output: &str) -> Option<String> { + output + .lines() + .find(|line| line.starts_with("Description:")) + .and_then(|line| line.split_once(':').map(|s| s.1)) + .map(|s| s.trim().to_string()) +} diff --git a/crates/turtle/src/atuin_client/encryption.rs b/crates/turtle/src/atuin_client/encryption.rs new file mode 100644 index 00000000..20a0cd90 --- /dev/null +++ b/crates/turtle/src/atuin_client/encryption.rs @@ -0,0 +1,440 @@ +// The general idea is that we NEVER send cleartext history to the server +// This way the odds of anything private ending up where it should not are +// very low +// The server authenticates via the usual username and password. This has +// nothing to do with the encryption, and is purely authentication! The client +// generates its own secret key, and encrypts all shell history with libsodium's +// secretbox. The data is then sent to the server, where it is stored. All +// clients must share the secret in order to be able to sync, as it is needed +// to decrypt + +use std::{io::prelude::*, path::PathBuf}; + +use base64::prelude::{BASE64_STANDARD, Engine}; +pub use crypto_secretbox::Key; +use crypto_secretbox::{ + AeadCore, AeadInPlace, KeyInit, XSalsa20Poly1305, + aead::{Nonce, OsRng}, +}; +use eyre::{Context, Result, bail, ensure, eyre}; +use fs_err as fs; +use rmp::{Marker, decode::Bytes}; +use serde::{Deserialize, Serialize}; +use time::{OffsetDateTime, format_description::well_known::Rfc3339, macros::format_description}; + +use crate::atuin_client::{history::History, settings::Settings}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct EncryptedHistory { + pub ciphertext: Vec<u8>, + pub nonce: Nonce<XSalsa20Poly1305>, +} + +pub fn generate_encoded_key() -> Result<(Key, String)> { + let key = XSalsa20Poly1305::generate_key(&mut OsRng); + let encoded = encode_key(&key)?; + + Ok((key, encoded)) +} + +pub fn new_key(settings: &Settings) -> Result<Key> { + let path = settings.key_path.as_str(); + let path = PathBuf::from(path); + + if path.exists() { + bail!("key already exists! cannot overwrite"); + } + + let (key, encoded) = generate_encoded_key()?; + + let mut file = fs::File::create(path)?; + file.write_all(encoded.as_bytes())?; + + Ok(key) +} + +// Loads the secret key, will create + save if it doesn't exist +pub fn load_key(settings: &Settings) -> Result<Key> { + let path = settings.key_path.as_str(); + + let key = if PathBuf::from(path).exists() { + let key = fs_err::read_to_string(path)?; + decode_key(key)? + } else { + new_key(settings)? + }; + + Ok(key) +} + +pub fn encode_key(key: &Key) -> Result<String> { + let mut buf = vec![]; + rmp::encode::write_array_len(&mut buf, key.len() as u32) + .wrap_err("could not encode key to message pack")?; + for b in key { + rmp::encode::write_uint(&mut buf, *b as u64) + .wrap_err("could not encode key to message pack")?; + } + let buf = BASE64_STANDARD.encode(buf); + + Ok(buf) +} + +pub fn decode_key(key: String) -> Result<Key> { + use rmp::decode; + + let buf = BASE64_STANDARD + .decode(key.trim_end()) + .wrap_err("encryption key is not a valid base64 encoding")?; + + // old code wrote the key as a fixed length array of 32 bytes + // new code writes the key with a length prefix + match <[u8; 32]>::try_from(&*buf) { + Ok(key) => Ok(key.into()), + Err(_) => { + let mut bytes = rmp::decode::Bytes::new(&buf); + + match Marker::from_u8(buf[0]) { + Marker::Bin8 => { + let len = decode::read_bin_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + ensure!(len == 32, "encryption key is not the correct size"); + let key = <[u8; 32]>::try_from(bytes.remaining_slice()) + .context("could not decode encryption key")?; + Ok(key.into()) + } + Marker::Array16 => { + let len = decode::read_array_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + ensure!(len == 32, "encryption key is not the correct size"); + + let mut key = Key::default(); + for i in &mut key { + *i = rmp::decode::read_int(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + } + Ok(key) + } + _ => bail!("could not decode encryption key"), + } + } + } +} + +pub fn encrypt(history: &History, key: &Key) -> Result<EncryptedHistory> { + // serialize with msgpack + let mut buf = encode(history)?; + + let nonce = XSalsa20Poly1305::generate_nonce(&mut OsRng); + XSalsa20Poly1305::new(key) + .encrypt_in_place(&nonce, &[], &mut buf) + .map_err(|_| eyre!("could not encrypt"))?; + + Ok(EncryptedHistory { + ciphertext: buf, + nonce, + }) +} + +pub fn decrypt(mut encrypted_history: EncryptedHistory, key: &Key) -> Result<History> { + XSalsa20Poly1305::new(key) + .decrypt_in_place( + &encrypted_history.nonce, + &[], + &mut encrypted_history.ciphertext, + ) + .map_err(|_| eyre!("could not decrypt history"))?; + let plaintext = encrypted_history.ciphertext; + + let history = decode(&plaintext)?; + + Ok(history) +} + +fn format_rfc3339(ts: OffsetDateTime) -> Result<String> { + // horrible hack. chrono AutoSI limits to 0, 3, 6, or 9 decimal places for nanoseconds. + // time does not have this functionality. + static PARTIAL_RFC3339_0: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z"); + static PARTIAL_RFC3339_3: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"); + static PARTIAL_RFC3339_6: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:6]Z"); + static PARTIAL_RFC3339_9: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:9]Z"); + + let fmt = match ts.nanosecond() { + 0 => PARTIAL_RFC3339_0, + ns if ns % 1_000_000 == 0 => PARTIAL_RFC3339_3, + ns if ns % 1_000 == 0 => PARTIAL_RFC3339_6, + _ => PARTIAL_RFC3339_9, + }; + + Ok(ts.format(fmt)?) +} + +fn encode(h: &History) -> Result<Vec<u8>> { + use rmp::encode; + + let mut output = vec![]; + // INFO: ensure this is updated when adding new fields + encode::write_array_len(&mut output, 9)?; + + encode::write_str(&mut output, &h.id.0)?; + encode::write_str(&mut output, &(format_rfc3339(h.timestamp)?))?; + encode::write_sint(&mut output, h.duration)?; + encode::write_sint(&mut output, h.exit)?; + encode::write_str(&mut output, &h.command)?; + encode::write_str(&mut output, &h.cwd)?; + encode::write_str(&mut output, &h.session)?; + encode::write_str(&mut output, &h.hostname)?; + match h.deleted_at { + Some(d) => encode::write_str(&mut output, &format_rfc3339(d)?)?, + None => encode::write_nil(&mut output)?, + } + + Ok(output) +} + +fn decode(bytes: &[u8]) -> Result<History> { + use rmp::decode::{self, DecodeStringError}; + + let mut bytes = Bytes::new(bytes); + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + if nfields < 8 { + bail!("malformed decrypted history") + } + if nfields > 9 { + bail!("cannot decrypt history from a newer version of atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (timestamp, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + // if we have more fields, try and get the deleted_at + let mut deleted_at = None; + let mut bytes = bytes; + if nfields > 8 { + bytes = match decode::read_str_from_slice(bytes) { + Ok((d, b)) => { + deleted_at = Some(d); + b + } + // we accept null here + Err(DecodeStringError::TypeMismatch(Marker::Null)) => { + // consume the null marker + let mut c = Bytes::new(bytes); + decode::read_nil(&mut c).map_err(error_report)?; + c.remaining_slice() + } + Err(err) => return Err(error_report(err)), + }; + } + + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::parse(timestamp, &Rfc3339)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + author: History::author_from_hostname(hostname), + intent: None, + deleted_at: deleted_at + .map(|t| OffsetDateTime::parse(t, &Rfc3339)) + .transpose()?, + }) +} + +fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") +} + +#[cfg(test)] +mod test { + use crypto_secretbox::{KeyInit, XSalsa20Poly1305, aead::OsRng}; + use pretty_assertions::assert_eq; + use time::{OffsetDateTime, macros::datetime}; + + use crate::history::History; + + use super::{decode, decrypt, encode, encrypt}; + + #[test] + fn test_encrypt_decrypt() { + let key1 = XSalsa20Poly1305::generate_key(&mut OsRng); + let key2 = XSalsa20Poly1305::generate_key(&mut OsRng); + + let history = History::from_db() + .id("1".into()) + .timestamp(OffsetDateTime::now_utc()) + .command("ls".into()) + .cwd("/home/ellie".into()) + .exit(0) + .duration(1) + .session("beep boop".into()) + .hostname("booop".into()) + .author("booop".into()) + .intent(None) + .deleted_at(None) + .build() + .into(); + + let e1 = encrypt(&history, &key1).unwrap(); + let e2 = encrypt(&history, &key2).unwrap(); + + assert_ne!(e1.ciphertext, e2.ciphertext); + assert_ne!(e1.nonce, e2.nonce); + + // test decryption works + // this should pass + match decrypt(e1, &key1) { + Err(e) => panic!("failed to decrypt, got {e}"), + Ok(h) => assert_eq!(h, history), + }; + + // this should err + let _ = decrypt(e2, &key1).expect_err("expected an error decrypting with invalid key"); + } + + #[test] + fn test_decode() { + let bytes = [ + 0x99, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, + 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, + 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, + 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, + 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, + 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, + 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, + 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, + 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, + 108, 117, 100, 103, 97, 116, 101, 192, + ]; + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let h = decode(&bytes).unwrap(); + assert_eq!(history, h); + + let b = encode(&h).unwrap(); + assert_eq!(&bytes, &*b); + } + + #[test] + fn test_decode_deleted() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: Some(datetime!(2023-05-28 18:35:40.633872 +00:00)), + }; + + let b = encode(&history).unwrap(); + let h = decode(&b).unwrap(); + assert_eq!(history, h); + } + + #[test] + fn test_decode_old() { + let bytes = [ + 0x98, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, + 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, + 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, + 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, + 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, + 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, + 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, + 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, + 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, + 108, 117, 100, 103, 97, 116, 101, + ]; + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let h = decode(&bytes).unwrap(); + assert_eq!(history, h); + } + + #[test] + fn key_encodings() { + use super::{Key, decode_key, encode_key}; + + // a history of our key encodings. + // v11.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v12.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v13.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v13.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v14.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v14.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // c7d89c1 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/805) + // b53ca35 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/974) + // v15.0.0 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== + // b8b57c8 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== (https://github.com/ellie/atuin/pull/1057) + // 8c94d79 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/1089) + + let key = Key::from([ + 27, 91, 42, 91, 210, 107, 9, 216, 170, 190, 242, 62, 6, 84, 69, 148, 148, 53, 251, 117, + 226, 167, 173, 52, 82, 34, 138, 110, 169, 124, 92, 229, + ]); + + assert_eq!( + encode_key(&key).unwrap(), + "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==" + ); + + // key encodings we have to support + let valid_encodings = [ + "xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q==", + "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==", + ]; + + for k in valid_encodings { + assert_eq!(decode_key(k.to_owned()).expect(k), key); + } + } +} diff --git a/crates/turtle/src/atuin_client/history.rs b/crates/turtle/src/atuin_client/history.rs new file mode 100644 index 00000000..cef65115 --- /dev/null +++ b/crates/turtle/src/atuin_client/history.rs @@ -0,0 +1,756 @@ +use core::fmt::Formatter; +use rmp::decode::DecodeStringError; +use rmp::decode::ValueReadError; +use rmp::{Marker, decode::Bytes}; +use std::env; +use std::fmt::Display; + +use crate::atuin_common::record::DecryptedData; +use crate::atuin_common::utils::uuid_v7; + +use eyre::{Result, bail, eyre}; + +use crate::atuin_client::secrets::SECRET_PATTERNS_RE; +use crate::atuin_client::settings::Settings; +use crate::atuin_client::utils::get_host_user; +use time::OffsetDateTime; + +mod builder; +pub mod store; + +/// Known AI agent author values. Used to expand `$all-agent` and `$all-user` filters. +pub const KNOWN_AGENTS: &[&str] = &["claude-code", "codex", "copilot", "pi"]; +pub const AUTHOR_FILTER_ALL_USER: &str = "$all-user"; +pub const AUTHOR_FILTER_ALL_AGENT: &str = "$all-agent"; + +pub fn is_known_agent(author: &str) -> bool { + KNOWN_AGENTS.contains(&author) +} + +pub fn author_matches_filters(author: &str, filters: &[String]) -> bool { + filters.is_empty() + || filters.iter().any(|filter| match filter.as_str() { + AUTHOR_FILTER_ALL_USER => !is_known_agent(author), + AUTHOR_FILTER_ALL_AGENT => is_known_agent(author), + literal => author == literal, + }) +} + +pub(crate) const HISTORY_VERSION_V0: &str = "v0"; +pub(crate) const HISTORY_VERSION_V1: &str = "v1"; +const HISTORY_RECORD_VERSION_V0: u16 = 0; +const HISTORY_RECORD_VERSION_V1: u16 = 1; +pub(crate) const HISTORY_VERSION: &str = HISTORY_VERSION_V1; +pub const HISTORY_TAG: &str = "history"; +const HISTORY_AUTHOR_ENV: &str = "ATUIN_HISTORY_AUTHOR"; +const HISTORY_INTENT_ENV: &str = "ATUIN_HISTORY_INTENT"; + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct HistoryId(pub String); + +impl Display for HistoryId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<String> for HistoryId { + fn from(s: String) -> Self { + Self(s) + } +} + +/// Client-side history entry. +/// +/// Client stores data unencrypted, and only encrypts it before sending to the server. +/// +/// To create a new history entry, use one of the builders: +/// - [`History::import()`] to import an entry from the shell history file +/// - [`History::capture()`] to capture an entry via hook +/// - [`History::from_db()`] to create an instance from the database entry +// +// ## Implementation Notes +// +// New fields must be added to `History::{serialize,deserialize}` in a backwards +// compatible way (sensible defaults and careful `nfields` handling). +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +pub struct History { + /// A client-generated ID, used to identify the entry when syncing. + /// + /// Stored as `client_id` in the database. + pub id: HistoryId, + /// When the command was run. + pub timestamp: OffsetDateTime, + /// How long the command took to run. + pub duration: i64, + /// The exit code of the command. + pub exit: i64, + /// The command that was run. + pub command: String, + /// The current working directory when the command was run. + pub cwd: String, + /// The session ID, associated with a terminal session. + pub session: String, + /// The hostname of the machine the command was run on. + pub hostname: String, + /// Who wrote this command (human user or automation/agent identity). + pub author: String, + /// Optional rationale for why the command was executed. + pub intent: Option<String>, + /// Timestamp, which is set when the entry is deleted, allowing a soft delete. + pub deleted_at: Option<OffsetDateTime>, +} + +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +pub struct HistoryStats { + /// The command that was ran after this one in the session + pub next: Option<History>, + /// + /// The command that was ran before this one in the session + pub previous: Option<History>, + + /// How many times has this command been ran? + pub total: u64, + + pub average_duration: u64, + + pub exits: Vec<(i64, i64)>, + + pub day_of_week: Vec<(String, i64)>, + + pub duration_over_time: Vec<(String, i64)>, +} + +impl History { + pub(crate) fn author_from_hostname(hostname: &str) -> String { + hostname + .split_once(':') + .map_or_else(|| hostname.to_owned(), |(_, user)| user.to_owned()) + } + + fn normalize_optional_field(field: Option<String>) -> Option<String> { + field.and_then(|value| { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_owned()) + } + }) + } + + #[expect(clippy::too_many_arguments)] + fn new( + timestamp: OffsetDateTime, + command: String, + cwd: String, + exit: i64, + duration: i64, + session: Option<String>, + hostname: Option<String>, + author: Option<String>, + intent: Option<String>, + deleted_at: Option<OffsetDateTime>, + ) -> Self { + let session = session + .or_else(|| env::var("ATUIN_SESSION").ok()) + .unwrap_or_else(|| uuid_v7().as_simple().to_string()); + let hostname = hostname.unwrap_or_else(get_host_user); + let author = Self::normalize_optional_field(author) + .or_else(|| Self::normalize_optional_field(env::var(HISTORY_AUTHOR_ENV).ok())) + .unwrap_or_else(|| Self::author_from_hostname(hostname.as_str())); + let intent = Self::normalize_optional_field(intent) + .or_else(|| Self::normalize_optional_field(env::var(HISTORY_INTENT_ENV).ok())); + + Self { + id: uuid_v7().as_simple().to_string().into(), + timestamp, + command, + cwd, + exit, + duration, + session, + hostname, + author, + intent, + deleted_at, + } + } + + pub fn serialize(&self) -> Result<DecryptedData> { + // This is pretty much the same as what we used for the old history, with one difference - + // it uses integers for timestamps rather than a string format. + + use rmp::encode; + + let mut output = vec![]; + + // write the version + encode::write_u16(&mut output, HISTORY_RECORD_VERSION_V1)?; + let include_intent = self.intent.is_some(); + encode::write_array_len(&mut output, 10 + u32::from(include_intent))?; + + encode::write_str(&mut output, &self.id.0)?; + encode::write_u64(&mut output, self.timestamp.unix_timestamp_nanos() as u64)?; + encode::write_sint(&mut output, self.duration)?; + encode::write_sint(&mut output, self.exit)?; + encode::write_str(&mut output, &self.command)?; + encode::write_str(&mut output, &self.cwd)?; + encode::write_str(&mut output, &self.session)?; + encode::write_str(&mut output, &self.hostname)?; + + match self.deleted_at { + Some(d) => encode::write_u64(&mut output, d.unix_timestamp_nanos() as u64)?, + None => encode::write_nil(&mut output)?, + } + + encode::write_str(&mut output, self.author.as_str())?; + if let Some(intent) = &self.intent { + encode::write_str(&mut output, intent.as_str())?; + } + + Ok(DecryptedData(output)) + } + + fn read_optional_string(bytes: &[u8]) -> Result<(Option<String>, &[u8])> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + match decode::read_str_from_slice(bytes) { + Ok((value, bytes)) => Ok((Some(value.to_owned()), bytes)), + Err(DecodeStringError::TypeMismatch(Marker::Null)) => { + let mut cursor = Bytes::new(bytes); + decode::read_nil(&mut cursor).map_err(error_report)?; + + Ok((None, cursor.remaining_slice())) + } + Err(err) => Err(error_report(err)), + } + } + + fn deserialize_v0(bytes: &[u8]) -> Result<History> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(bytes); + + let version = decode::read_u16(&mut bytes).map_err(error_report)?; + + if version != HISTORY_RECORD_VERSION_V0 { + bail!("expected decoding v0 record, found v{version}"); + } + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + + if nfields != 9 { + bail!("cannot decrypt history from a different version of Atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + + let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { + Ok(unix) => (Some(unix), bytes.remaining_slice()), + // we accept null here + Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), + Err(err) => return Err(error_report(err)), + }; + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + author: Self::author_from_hostname(hostname), + intent: None, + deleted_at: deleted_at + .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) + .transpose()?, + }) + } + + fn deserialize_v1(bytes: &[u8]) -> Result<History> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(bytes); + + let version = decode::read_u16(&mut bytes).map_err(error_report)?; + + if version != HISTORY_RECORD_VERSION_V1 { + bail!("expected decoding v1 record, found v{version}"); + } + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + + if !(10..=11).contains(&nfields) { + bail!("cannot decrypt history from a different version of Atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + + let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { + Ok(unix) => (Some(unix), bytes.remaining_slice()), + // we accept null here + Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), + Err(err) => return Err(error_report(err)), + }; + let (author, bytes) = Self::read_optional_string(bytes)?; + let (intent, bytes) = if nfields > 10 { + Self::read_optional_string(bytes)? + } else { + (None, bytes) + }; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + author: author.unwrap_or_else(|| Self::author_from_hostname(hostname)), + intent, + deleted_at: deleted_at + .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) + .transpose()?, + }) + } + + pub fn deserialize(bytes: &[u8], version: &str) -> Result<History> { + match version { + HISTORY_VERSION_V0 => Self::deserialize_v0(bytes), + HISTORY_VERSION_V1 => Self::deserialize_v1(bytes), + + _ => bail!("unknown version {version:?}"), + } + } + + /// Builder for a history entry that is imported from shell history. + /// + /// The only two required fields are `timestamp` and `command`. + /// + /// ## Examples + /// ``` + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::import() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + /// + /// If shell history contains more information, it can be added to the builder: + /// ``` + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::import() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .exit(0) + /// .duration(100) + /// .build() + /// .into(); + /// ``` + /// + /// Unknown command or command without timestamp cannot be imported, which + /// is forced at compile time: + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because timestamp is missing + /// let history: History = History::import() + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + pub fn import() -> builder::HistoryImportedBuilder { + builder::HistoryImported::builder() + } + + /// Builder for a history entry that is captured via hook. + /// + /// This builder is used only at the `start` step of the hook, + /// so it doesn't have any fields which are known only after + /// the command is finished, such as `exit` or `duration`. + /// + /// ## Examples + /// ```rust + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::capture() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .build() + /// .into(); + /// ``` + /// + /// Command without any required info cannot be captured, which is forced at compile time: + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because `cwd` is missing + /// let history: History = History::capture() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + pub fn capture() -> builder::HistoryCapturedBuilder { + builder::HistoryCaptured::builder() + } + + /// Builder for a history entry that is captured via hook, and sent to the daemon. + /// + /// This builder is used only at the `start` step of the hook, + /// so it doesn't have any fields which are known only after + /// the command is finished, such as `exit` or `duration`. + /// + /// It does, however, include information that can usually be inferred. + /// + /// This is because the daemon we are sending a request to lacks the context of the command + /// + /// ## Examples + /// ```rust + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::daemon() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .session("018deb6e8287781f9973ef40e0fde76b") + /// .hostname("computer:ellie") + /// .build() + /// .into(); + /// ``` + /// + /// Command without any required info cannot be captured, which is forced at compile time: + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because `hostname` is missing + /// let history: History = History::daemon() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .session("018deb6e8287781f9973ef40e0fde76b") + /// .build() + /// .into(); + /// ``` + pub fn daemon() -> builder::HistoryDaemonCaptureBuilder { + builder::HistoryDaemonCapture::builder() + } + + /// Builder for a history entry that is imported from the database. + /// + /// All fields are required, as they are all present in the database. + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because `id` field is missing + /// let history: History = History::from_db() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la".to_string()) + /// .cwd("/home/user".to_string()) + /// .exit(0) + /// .duration(100) + /// .session("somesession".to_string()) + /// .hostname("localhost".to_string()) + /// .author("user".to_string()) + /// .intent(None) + /// .deleted_at(None) + /// .build() + /// .into(); + /// ``` + pub fn from_db() -> builder::HistoryFromDbBuilder { + builder::HistoryFromDb::builder() + } + + pub fn success(&self) -> bool { + self.exit == 0 || self.duration == -1 + } + + pub fn should_save(&self, settings: &Settings) -> bool { + !(self.command.starts_with(' ') + || self.command.is_empty() + || settings.history_filter.is_match(&self.command) + || settings.cwd_filter.is_match(&self.cwd) + || (settings.secrets_filter && SECRET_PATTERNS_RE.is_match(&self.command))) + } +} + +#[cfg(test)] +mod tests { + use regex::RegexSet; + use time::macros::datetime; + + use crate::{ + history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, HISTORY_VERSION}, + settings::Settings, + }; + + use super::{History, author_matches_filters, is_known_agent}; + + // Test that we don't save history where necessary + #[test] + fn privacy_test() { + let settings = Settings { + cwd_filter: RegexSet::new(["^/supasecret"]).unwrap(), + history_filter: RegexSet::new(["^psql"]).unwrap(), + ..Settings::utc() + }; + + let normal_command: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo foo") + .cwd("/") + .build() + .into(); + + let with_space: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command(" echo bar") + .cwd("/") + .build() + .into(); + + let empty: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("") + .cwd("/") + .build() + .into(); + + let stripe_key: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") + .cwd("/") + .build() + .into(); + + let secret_dir: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo ohno") + .cwd("/supasecret") + .build() + .into(); + + let with_psql: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("psql") + .cwd("/supasecret") + .build() + .into(); + + assert!(normal_command.should_save(&settings)); + assert!(!with_space.should_save(&settings)); + assert!(!empty.should_save(&settings)); + assert!(!stripe_key.should_save(&settings)); + assert!(!secret_dir.should_save(&settings)); + assert!(!with_psql.should_save(&settings)); + } + + #[test] + fn known_agents_include_pi() { + assert!(is_known_agent("pi")); + assert!(author_matches_filters( + "pi", + &[AUTHOR_FILTER_ALL_AGENT.to_string()] + )); + assert!(!author_matches_filters( + "pi", + &[AUTHOR_FILTER_ALL_USER.to_string()] + )); + } + + #[test] + fn disable_secrets() { + let settings = Settings { + secrets_filter: false, + ..Settings::utc() + }; + + let stripe_key: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") + .cwd("/") + .build() + .into(); + + assert!(stripe_key.should_save(&settings)); + } + + #[test] + fn test_serialize_deserialize() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let serialized = history.serialize().expect("failed to serialize history"); + assert_eq!( + &serialized.0[0..3], + [205, 0, 1], + "should encode as history v1" + ); + + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_deleted() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: Some(datetime!(2023-11-19 20:18 +00:00)), + }; + + let serialized = history.serialize().expect("failed to serialize history"); + + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_with_author_and_intent() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "claude".to_owned(), + intent: Some("check repository status".to_owned()), + deleted_at: None, + }; + + let serialized = history.serialize().expect("failed to serialize history"); + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_version() { + // v0 + let bytes_v0 = [ + 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, + 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, + 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, + 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, + 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, + 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, + 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, + 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, + 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192, + ]; + + let deserialized = History::deserialize(&bytes_v0, "v0"); + assert!(deserialized.is_ok()); + + let deserialized = History::deserialize(&bytes_v0, HISTORY_VERSION); + assert!(deserialized.is_err()); + + let current = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let bytes_v1 = current.serialize().expect("failed to serialize history"); + let deserialized = History::deserialize(&bytes_v1.0, HISTORY_VERSION); + assert!(deserialized.is_ok()); + + let deserialized = History::deserialize(&bytes_v1.0, "v0"); + assert!(deserialized.is_err()); + } +} diff --git a/crates/turtle/src/atuin_client/history/builder.rs b/crates/turtle/src/atuin_client/history/builder.rs new file mode 100644 index 00000000..72a505fd --- /dev/null +++ b/crates/turtle/src/atuin_client/history/builder.rs @@ -0,0 +1,154 @@ +use typed_builder::TypedBuilder; + +use super::History; + +/// Builder for a history entry that is imported from shell history. +/// +/// The only two required fields are `timestamp` and `command`. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryImported { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(default = "unknown".into(), setter(into))] + cwd: String, + #[builder(default = -1)] + exit: i64, + #[builder(default = -1)] + duration: i64, + #[builder(default, setter(strip_option, into))] + session: Option<String>, + #[builder(default, setter(strip_option, into))] + hostname: Option<String>, + #[builder(default, setter(strip_option, into))] + author: Option<String>, + #[builder(default, setter(strip_option, into))] + intent: Option<String>, +} + +impl From<HistoryImported> for History { + fn from(imported: HistoryImported) -> Self { + History::new( + imported.timestamp, + imported.command, + imported.cwd, + imported.exit, + imported.duration, + imported.session, + imported.hostname, + imported.author, + imported.intent, + None, + ) + } +} + +/// Builder for a history entry that is captured via hook. +/// +/// This builder is used only at the `start` step of the hook, +/// so it doesn't have any fields which are known only after +/// the command is finished, such as `exit` or `duration`. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryCaptured { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(setter(into))] + cwd: String, + #[builder(default, setter(strip_option, into))] + author: Option<String>, + #[builder(default, setter(strip_option, into))] + intent: Option<String>, +} + +impl From<HistoryCaptured> for History { + fn from(captured: HistoryCaptured) -> Self { + History::new( + captured.timestamp, + captured.command, + captured.cwd, + -1, + -1, + None, + None, + captured.author, + captured.intent, + None, + ) + } +} + +/// Builder for a history entry that is loaded from the database. +/// +/// All fields are required, as they are all present in the database. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryFromDb { + id: String, + timestamp: time::OffsetDateTime, + command: String, + cwd: String, + exit: i64, + duration: i64, + session: String, + hostname: String, + author: String, + intent: Option<String>, + deleted_at: Option<time::OffsetDateTime>, +} + +impl From<HistoryFromDb> for History { + fn from(from_db: HistoryFromDb) -> Self { + History { + id: from_db.id.into(), + timestamp: from_db.timestamp, + exit: from_db.exit, + command: from_db.command, + cwd: from_db.cwd, + duration: from_db.duration, + session: from_db.session, + hostname: from_db.hostname, + author: from_db.author, + intent: from_db.intent, + deleted_at: from_db.deleted_at, + } + } +} + +/// Builder for a history entry that is captured via hook and sent to the daemon +/// +/// This builder is similar to Capture, but we just require more information up front. +/// For the old setup, we could just rely on History::new to read some of the missing +/// data. This is no longer the case. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryDaemonCapture { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(setter(into))] + cwd: String, + #[builder(setter(into))] + session: String, + #[builder(setter(into))] + hostname: String, + #[builder(default, setter(strip_option, into))] + author: Option<String>, + #[builder(default, setter(strip_option, into))] + intent: Option<String>, +} + +impl From<HistoryDaemonCapture> for History { + fn from(captured: HistoryDaemonCapture) -> Self { + History::new( + captured.timestamp, + captured.command, + captured.cwd, + -1, + -1, + Some(captured.session), + Some(captured.hostname), + captured.author, + captured.intent, + None, + ) + } +} diff --git a/crates/turtle/src/atuin_client/history/store.rs b/crates/turtle/src/atuin_client/history/store.rs new file mode 100644 index 00000000..66d9db47 --- /dev/null +++ b/crates/turtle/src/atuin_client/history/store.rs @@ -0,0 +1,435 @@ +use std::{collections::HashSet, fmt::Write, time::Duration}; + +use eyre::{Result, bail, eyre}; +use indicatif::{ProgressBar, ProgressState, ProgressStyle}; +use rmp::decode::Bytes; +use tracing::debug; + +use crate::atuin_client::{ + database::{Database, current_context}, + record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store}, +}; +use crate::atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx}; + +use super::{HISTORY_TAG, HISTORY_VERSION, HISTORY_VERSION_V0, History, HistoryId}; + +#[derive(Debug, Clone)] +pub struct HistoryStore { + pub store: SqliteStore, + pub host_id: HostId, + pub encryption_key: [u8; 32], +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum HistoryRecord { + Create(History), // Create a history record + Delete(HistoryId), // Delete a history record, identified by ID +} + +impl HistoryRecord { + /// Serialize a history record, returning DecryptedData + /// The record will be of a certain type + /// We map those like so: + /// + /// HistoryRecord::Create -> 0 + /// HistoryRecord::Delete-> 1 + /// + /// This numeric identifier is then written as the first byte to the buffer. For history, we + /// append the serialized history right afterwards, to avoid having to handle serialization + /// twice. + /// + /// Deletion simply refers to the history by ID + pub fn serialize(&self) -> Result<DecryptedData> { + // probably don't actually need to use rmp here, but if we ever need to extend it, it's a + // nice wrapper around raw byte stuff + use rmp::encode; + + let mut output = vec![]; + + match self { + HistoryRecord::Create(history) => { + // 0 -> a history create + encode::write_u8(&mut output, 0)?; + + let bytes = history.serialize()?; + + encode::write_bin(&mut output, &bytes.0)?; + } + HistoryRecord::Delete(id) => { + // 1 -> a history delete + encode::write_u8(&mut output, 1)?; + encode::write_str(&mut output, id.0.as_str())?; + } + }; + + Ok(DecryptedData(output)) + } + + pub fn deserialize(bytes: &DecryptedData, version: &str) -> Result<Self> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(&bytes.0); + + let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; + + match record_type { + // 0 -> HistoryRecord::Create + 0 => { + // not super useful to us atm, but perhaps in the future + // written by write_bin above + let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; + + let record = History::deserialize(bytes.remaining_slice(), version)?; + + Ok(HistoryRecord::Create(record)) + } + + // 1 -> HistoryRecord::Delete + 1 => { + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!( + "trailing bytes decoding HistoryRecord::Delete - malformed? got {bytes:?}" + ); + } + + Ok(HistoryRecord::Delete(id.to_string().into())) + } + + n => { + bail!("unknown HistoryRecord type {n}") + } + } + } +} + +impl HistoryStore { + pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self { + HistoryStore { + store, + host_id, + encryption_key, + } + } + + async fn push_record(&self, record: HistoryRecord) -> Result<(RecordId, RecordIdx)> { + let bytes = record.serialize()?; + let idx = self + .store + .last(self.host_id, HISTORY_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(HISTORY_VERSION.to_string()) + .tag(HISTORY_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + let id = record.id; + + self.store + .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) + .await?; + + Ok((id, idx)) + } + + async fn push_batch(&self, records: impl Iterator<Item = HistoryRecord>) -> Result<()> { + let mut ret = Vec::new(); + + let idx = self + .store + .last(self.host_id, HISTORY_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + // Could probably _also_ do this as an iterator, but let's see how this is for now. + // optimizing for minimal sqlite transactions, this code can be optimised later + for (n, record) in records.enumerate() { + let bytes = record.serialize()?; + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(HISTORY_VERSION.to_string()) + .tag(HISTORY_TAG.to_string()) + .idx(idx + n as u64) + .data(bytes) + .build(); + + let record = record.encrypt::<PASETO_V4>(&self.encryption_key); + + ret.push(record); + } + + self.store.push_batch(ret.iter()).await?; + + Ok(()) + } + + pub async fn delete(&self, id: HistoryId) -> Result<(RecordId, RecordIdx)> { + let record = HistoryRecord::Delete(id); + + self.push_record(record).await + } + + /// Delete a batch of history entries via the record store. + /// Returns the record IDs so the caller can run incremental_build when ready. + pub async fn delete_entries( + &self, + entries: impl IntoIterator<Item = History>, + ) -> Result<Vec<RecordId>> { + let mut record_ids = Vec::new(); + for entry in entries { + let (id, _) = self.delete(entry.id).await?; + record_ids.push(id); + } + Ok(record_ids) + } + + pub async fn push(&self, history: History) -> Result<(RecordId, RecordIdx)> { + // TODO(ellie): move the history store to its own file + // it's tiny rn so fine as is + let record = HistoryRecord::Create(history); + + self.push_record(record).await + } + + pub async fn history(&self) -> Result<Vec<HistoryRecord>> { + // Atm this loads all history into memory + // Not ideal as that is potentially quite a lot, although history will be small. + let records = self.store.all_tagged(HISTORY_TAG).await?; + let mut ret = Vec::with_capacity(records.len()); + + for record in records.into_iter() { + let hist = match record.version.as_str() { + HISTORY_VERSION_V0 | HISTORY_VERSION => { + let version = record.version.clone(); + let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?; + + HistoryRecord::deserialize(&decrypted.data, version.as_str()) + } + version => bail!("unknown history version {version:?}"), + }?; + + ret.push(hist); + } + + Ok(ret) + } + + pub async fn build(&self, database: &dyn Database) -> Result<()> { + // I'd like to change how we rebuild and not couple this with the database, but need to + // consider the structure more deeply. This will be easy to change. + + // TODO(ellie): page or iterate this + let history = self.history().await?; + + // In theory we could flatten this here + // The current issue is that the database may have history in it already, from the old sync + // This didn't actually delete old history + // If we're sure we have a DB only maintained by the new store, we can flatten + // create/delete before we even get to sqlite + let mut creates = Vec::new(); + let mut deletes = Vec::new(); + + for i in history { + match i { + HistoryRecord::Create(h) => { + creates.push(h); + } + HistoryRecord::Delete(id) => { + deletes.push(id); + } + } + } + + database.save_bulk(&creates).await?; + database.delete_rows(&deletes).await?; + + Ok(()) + } + + pub async fn incremental_build(&self, database: &dyn Database, ids: &[RecordId]) -> Result<()> { + for id in ids { + let record = self.store.get(*id).await; + + let record = match record { + Ok(record) => record, + _ => { + continue; + } + }; + + if record.tag != HISTORY_TAG { + continue; + } + + let version = record.version.clone(); + let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?; + let record = match version.as_str() { + HISTORY_VERSION_V0 | HISTORY_VERSION => { + HistoryRecord::deserialize(&decrypted.data, version.as_str())? + } + version => bail!("unknown history version {version:?}"), + }; + + match record { + HistoryRecord::Create(h) => { + // TODO: benchmark CPU time/memory tradeoff of batch commit vs one at a time + database.save(&h).await?; + } + HistoryRecord::Delete(id) => { + database.delete_rows(&[id]).await?; + } + } + } + + Ok(()) + } + + /// Get a list of history IDs that exist in the store + /// Note: This currently involves loading all history into memory. This is not going to be a + /// large amount in absolute terms, but do not all it in a hot loop. + pub async fn history_ids(&self) -> Result<HashSet<HistoryId>> { + let history = self.history().await?; + + let ret = HashSet::from_iter(history.iter().map(|h| match h { + HistoryRecord::Create(h) => h.id.clone(), + HistoryRecord::Delete(id) => id.clone(), + })); + + Ok(ret) + } + + pub async fn init_store(&self, db: &impl Database) -> Result<()> { + let pb = ProgressBar::new_spinner(); + pb.set_style( + ProgressStyle::with_template("{spinner:.blue} {msg}") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| { + write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() + }) + .progress_chars("#>-"), + ); + pb.enable_steady_tick(Duration::from_millis(500)); + + pb.set_message("Fetching history from old database"); + + let context = current_context().await?; + let history = db.list(&[], &context, None, false, true).await?; + + pb.set_message("Fetching history already in store"); + let store_ids = self.history_ids().await?; + + pb.set_message("Converting old history to new store"); + let mut records = Vec::new(); + + for i in history { + debug!("loaded {}", i.id); + + if store_ids.contains(&i.id) { + debug!("skipping {} - already exists", i.id); + continue; + } + + if i.deleted_at.is_some() { + records.push(HistoryRecord::Delete(i.id)); + } else { + records.push(HistoryRecord::Create(i)); + } + } + + pb.set_message("Writing to db"); + + if !records.is_empty() { + self.push_batch(records.into_iter()).await?; + } + + pb.finish_with_message("Import complete"); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::record::DecryptedData; + use time::macros::datetime; + + use crate::atuin_client::history::{HISTORY_VERSION, store::HistoryRecord}; + + use super::History; + + #[test] + fn test_serialize_deserialize_create() { + let bytes = [ + 204, 0, 196, 147, 205, 0, 1, 154, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, + 55, 53, 55, 99, 100, 50, 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, + 56, 49, 207, 23, 166, 251, 212, 181, 82, 0, 0, 100, 0, 162, 108, 115, 217, 41, 47, 85, + 115, 101, 114, 115, 47, 101, 108, 108, 105, 101, 47, 115, 114, 99, 47, 103, 105, 116, + 104, 117, 98, 46, 99, 111, 109, 47, 97, 116, 117, 105, 110, 115, 104, 47, 97, 116, 117, + 105, 110, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 97, 100, 56, 57, 55, 53, 57, 55, + 56, 53, 50, 53, 50, 55, 97, 51, 49, 99, 57, 57, 56, 48, 53, 57, 170, 98, 111, 111, 112, + 58, 101, 108, 108, 105, 101, 192, 165, 101, 108, 108, 105, 101, + ]; + + let history = History { + id: "018cd4fe81757cd2aee65cd7861f9c81".to_owned().into(), + timestamp: datetime!(2024-01-04 00:00:00.000000 +00:00), + duration: 100, + exit: 0, + command: "ls".to_owned(), + cwd: "/Users/ellie/src/github.com/atuinsh/atuin".to_owned(), + session: "018cd4fead897597852527a31c998059".to_owned(), + hostname: "boop:ellie".to_owned(), + author: "ellie".to_owned(), + intent: None, + deleted_at: None, + }; + + let record = HistoryRecord::Create(history); + + let serialized = record.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + + // check the snapshot too + let deserialized = + HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + } + + #[test] + fn test_serialize_deserialize_delete() { + let bytes = [ + 204, 1, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, 55, 53, 55, 99, 100, 50, + 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, 56, 49, + ]; + let record = HistoryRecord::Delete("018cd4fe81757cd2aee65cd7861f9c81".to_string().into()); + + let serialized = record.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + + let deserialized = + HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + } +} diff --git a/crates/turtle/src/atuin_client/import/bash.rs b/crates/turtle/src/atuin_client/import/bash.rs new file mode 100644 index 00000000..d92fdfa0 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/bash.rs @@ -0,0 +1,221 @@ +use std::{path::PathBuf, str}; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use itertools::Itertools; +use time::{Duration, OffsetDateTime}; +use tracing::warn; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Bash { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".bash_history")) +} + +#[async_trait] +impl Importer for Bash { + const NAME: &'static str = "bash"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + let count = unix_byte_lines(&self.bytes) + .map(LineType::from) + .filter(|line| matches!(line, LineType::Command(_))) + .count(); + Ok(count) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let lines = unix_byte_lines(&self.bytes) + .map(LineType::from) + .filter(|line| !matches!(line, LineType::NotUtf8)) // invalid utf8 are ignored + .collect_vec(); + + let (commands_before_first_timestamp, first_timestamp) = lines + .iter() + .enumerate() + .find_map(|(i, line)| match line { + LineType::Timestamp(t) => Some((i, *t)), + _ => None, + }) + // if no known timestamps, use now as base + .unwrap_or((lines.len(), OffsetDateTime::now_utc())); + + // if no timestamp is recorded, then use this increment to set an arbitrary timestamp + // to preserve ordering + // this increment is deliberately very small to prevent particularly fast fingers + // causing ordering issues; it also helps in handling the "here document" syntax, + // where several lines are recorded in succession without individual timestamps + let timestamp_increment = Duration::milliseconds(1); + + // make sure there is a minimum amount of time before the first known timestamp + // to fit all commands, given the default increment + let mut next_timestamp = + first_timestamp - timestamp_increment * commands_before_first_timestamp as i32; + + for line in lines.into_iter() { + match line { + LineType::NotUtf8 => unreachable!(), // already filtered + LineType::Empty => {} // do nothing + LineType::Timestamp(t) => { + if t < next_timestamp { + warn!( + "Time reversal detected in Bash history! Commands may be ordered incorrectly." + ); + } + next_timestamp = t; + } + LineType::Command(c) => { + let imported = History::import().timestamp(next_timestamp).command(c); + + h.push(imported.build().into()).await?; + next_timestamp += timestamp_increment; + } + } + } + + Ok(()) + } +} + +#[derive(Debug, Clone)] +enum LineType<'a> { + NotUtf8, + /// Can happen when using the "here document" syntax. + Empty, + /// A timestamp line start with a '#', followed immediately by an integer + /// that represents seconds since UNIX epoch. + Timestamp(OffsetDateTime), + /// Anything else. + Command(&'a str), +} +impl<'a> From<&'a [u8]> for LineType<'a> { + fn from(bytes: &'a [u8]) -> Self { + let Ok(line) = str::from_utf8(bytes) else { + return LineType::NotUtf8; + }; + if line.is_empty() { + return LineType::Empty; + } + + match try_parse_line_as_timestamp(line) { + Some(time) => LineType::Timestamp(time), + None => LineType::Command(line), + } + } +} + +fn try_parse_line_as_timestamp(line: &str) -> Option<OffsetDateTime> { + let seconds = line.strip_prefix('#')?.parse().ok()?; + OffsetDateTime::from_unix_timestamp(seconds).ok() +} + +#[cfg(test)] +mod test { + use std::cmp::Ordering; + + use itertools::{Itertools, assert_equal}; + + use crate::atuin_client::import::{Importer, tests::TestLoader}; + + use super::Bash; + + #[tokio::test] + async fn parse_no_timestamps() { + let bytes = r"cargo install atuin +cargo update +cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ +" + .as_bytes() + .to_owned(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + [ + "cargo install atuin", + "cargo update", + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", + ], + ); + assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) + } + + #[tokio::test] + async fn parse_with_timestamps() { + let bytes = b"#1672918999 +git reset +#1672919006 +git clean -dxf +#1672919020 +cd ../ +" + .to_vec(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["git reset", "git clean -dxf", "cd ../"], + ); + assert_equal( + loader.buf.iter().map(|h| h.timestamp.unix_timestamp()), + [1_672_918_999, 1_672_919_006, 1_672_919_020], + ) + } + + #[tokio::test] + async fn parse_with_partial_timestamps() { + let bytes = b"git reset +#1672919006 +git clean -dxf +cd ../ +" + .to_vec(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["git reset", "git clean -dxf", "cd ../"], + ); + assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) + } + + fn is_strictly_sorted<T>(iter: impl IntoIterator<Item = T>) -> bool + where + T: Clone + PartialOrd, + { + iter.into_iter() + .tuple_windows() + .all(|(a, b)| matches!(a.partial_cmp(&b), Some(Ordering::Less))) + } +} diff --git a/crates/turtle/src/atuin_client/import/fish.rs b/crates/turtle/src/atuin_client/import/fish.rs new file mode 100644 index 00000000..1375bdd6 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/fish.rs @@ -0,0 +1,179 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use time::OffsetDateTime; + +use super::{Importer, Loader, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Fish { + bytes: Vec<u8>, +} + +/// see https://fishshell.com/docs/current/interactive.html#searchable-command-history +fn default_histpath() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let data = std::env::var("XDG_DATA_HOME").map_or_else( + |_| base.home_dir().join(".local").join("share"), + PathBuf::from, + ); + + // fish supports multiple history sessions + // If `fish_history` var is missing, or set to `default`, use `fish` as the session + let session = std::env::var("fish_history").unwrap_or_else(|_| String::from("fish")); + let session = if session == "default" { + String::from("fish") + } else { + session + }; + + let mut histpath = data.join("fish"); + histpath.push(format!("{session}_history")); + + if histpath.exists() { + Ok(histpath) + } else { + Err(eyre!("Could not find history file.")) + } +} + +#[async_trait] +impl Importer for Fish { + const NAME: &'static str = "fish"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(default_histpath()?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + let mut time: Option<OffsetDateTime> = None; + let mut cmd: Option<String> = None; + + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + + if let Some(c) = s.strip_prefix("- cmd: ") { + // first, we must deal with the prev cmd + if let Some(cmd) = cmd.take() { + let time = time.unwrap_or(now); + let entry = History::import().timestamp(time).command(cmd); + + loader.push(entry.build().into()).await?; + } + + // using raw strings to avoid needing escaping. + // replaces double backslashes with single backslashes + let c = c.replace(r"\\", r"\"); + // replaces escaped newlines + let c = c.replace(r"\n", "\n"); + // TODO: any other escape characters? + + cmd = Some(c); + } else if let Some(t) = s.strip_prefix(" when: ") { + // if t is not an int, just ignore this line + if let Ok(t) = t.parse::<i64>() { + time = Some(OffsetDateTime::from_unix_timestamp(t)?); + } + } else { + // ... ignore paths lines + } + } + + // we might have a trailing cmd + if let Some(cmd) = cmd.take() { + let time = time.unwrap_or(now); + let entry = History::import().timestamp(time).command(cmd); + + loader.push(entry.build().into()).await?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use crate::import::{Importer, tests::TestLoader}; + + use super::Fish; + + #[tokio::test] + async fn parse_complex() { + // complicated input with varying contents and escaped strings. + let bytes = r#"- cmd: history --help + when: 1639162832 +- cmd: cat ~/.bash_history + when: 1639162851 + paths: + - ~/.bash_history +- cmd: ls ~/.local/share/fish/fish_history + when: 1639162890 + paths: + - ~/.local/share/fish/fish_history +- cmd: cat ~/.local/share/fish/fish_history + when: 1639162893 + paths: + - ~/.local/share/fish/fish_history +ERROR +- CORRUPTED: ENTRY + CONTINUE: + - AS + - NORMAL +- cmd: echo "foo" \\\n'bar' baz + when: 1639162933 +- cmd: cat ~/.local/share/fish/fish_history + when: 1639162939 + paths: + - ~/.local/share/fish/fish_history +- cmd: echo "\\"" \\\\ "\\\\" + when: 1639163063 +- cmd: cat ~/.local/share/fish/fish_history + when: 1639163066 + paths: + - ~/.local/share/fish/fish_history +"# + .as_bytes() + .to_owned(); + + let fish = Fish { bytes }; + + let mut loader = TestLoader::default(); + fish.load(&mut loader).await.unwrap(); + let mut history = loader.buf.into_iter(); + + // simple wrapper for fish history entry + macro_rules! fishtory { + ($timestamp:expr_2021, $command:expr_2021) => { + let h = history.next().expect("missing entry in history"); + assert_eq!(h.command.as_str(), $command); + assert_eq!(h.timestamp.unix_timestamp(), $timestamp); + }; + } + + fishtory!(1639162832, "history --help"); + fishtory!(1639162851, "cat ~/.bash_history"); + fishtory!(1639162890, "ls ~/.local/share/fish/fish_history"); + fishtory!(1639162893, "cat ~/.local/share/fish/fish_history"); + fishtory!(1639162933, "echo \"foo\" \\\n'bar' baz"); + fishtory!(1639162939, "cat ~/.local/share/fish/fish_history"); + fishtory!(1639163063, r#"echo "\"" \\ "\\""#); + fishtory!(1639163066, "cat ~/.local/share/fish/fish_history"); + } +} diff --git a/crates/turtle/src/atuin_client/import/mod.rs b/crates/turtle/src/atuin_client/import/mod.rs new file mode 100644 index 00000000..7726ead7 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/mod.rs @@ -0,0 +1,140 @@ +use std::fs::File; +use std::io::Read; +use std::path::PathBuf; + +use async_trait::async_trait; +use eyre::{Result, bail}; +use memchr::Memchr; + +use crate::atuin_client::history::History; + +pub mod bash; +pub mod fish; +pub mod nu; +pub mod nu_histdb; +pub mod powershell; +pub mod replxx; +pub mod resh; +pub mod xonsh; +pub mod xonsh_sqlite; +pub mod zsh; +pub mod zsh_histdb; + +#[async_trait] +pub trait Importer: Sized { + const NAME: &'static str; + async fn new() -> Result<Self>; + async fn entries(&mut self) -> Result<usize>; + async fn load(self, loader: &mut impl Loader) -> Result<()>; +} + +#[async_trait] +pub trait Loader: Sync + Send { + async fn push(&mut self, hist: History) -> eyre::Result<()>; +} + +fn unix_byte_lines(input: &[u8]) -> impl Iterator<Item = &[u8]> { + UnixByteLines { + iter: memchr::memchr_iter(b'\n', input), + bytes: input, + i: 0, + } +} + +struct UnixByteLines<'a> { + iter: Memchr<'a>, + bytes: &'a [u8], + i: usize, +} + +impl<'a> Iterator for UnixByteLines<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option<Self::Item> { + let j = self.iter.next()?; + let out = &self.bytes[self.i..j]; + self.i = j + 1; + Some(out) + } + + fn count(self) -> usize + where + Self: Sized, + { + self.iter.count() + } +} + +fn count_lines(input: &[u8]) -> usize { + unix_byte_lines(input).count() +} + +fn get_histpath<D>(def: D) -> Result<PathBuf> +where + D: FnOnce() -> Result<PathBuf>, +{ + if let Ok(p) = std::env::var("HISTFILE") { + Ok(PathBuf::from(p)) + } else { + def() + } +} + +fn get_histfile_path<D>(def: D) -> Result<PathBuf> +where + D: FnOnce() -> Result<PathBuf>, +{ + get_histpath(def).and_then(is_file) +} + +fn get_histdir_path<D>(def: D) -> Result<PathBuf> +where + D: FnOnce() -> Result<PathBuf>, +{ + get_histpath(def).and_then(is_dir) +} + +fn read_to_end(path: PathBuf) -> Result<Vec<u8>> { + let mut bytes = Vec::new(); + let mut f = File::open(path)?; + f.read_to_end(&mut bytes)?; + Ok(bytes) +} +fn is_file(p: PathBuf) -> Result<PathBuf> { + if p.is_file() { + Ok(p) + } else { + bail!( + "Could not find history file {:?}. Try setting and exporting $HISTFILE", + p + ) + } +} +fn is_dir(p: PathBuf) -> Result<PathBuf> { + if p.is_dir() { + Ok(p) + } else { + bail!( + "Could not find history directory {:?}. Try setting and exporting $HISTFILE", + p + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Default)] + pub struct TestLoader { + pub buf: Vec<History>, + } + + #[async_trait] + impl Loader for TestLoader { + async fn push(&mut self, hist: History) -> Result<()> { + self.buf.push(hist); + Ok(()) + } + } +} diff --git a/crates/turtle/src/atuin_client/import/nu.rs b/crates/turtle/src/atuin_client/import/nu.rs new file mode 100644 index 00000000..c93789b8 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/nu.rs @@ -0,0 +1,67 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use time::OffsetDateTime; + +use super::{Importer, Loader, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Nu { + bytes: Vec<u8>, +} + +fn get_histpath() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let config_dir = base.config_dir().join("nushell"); + + let histpath = config_dir.join("history.txt"); + if histpath.exists() { + Ok(histpath) + } else { + Err(eyre!("Could not find history file.")) + } +} + +#[async_trait] +impl Importer for Nu { + const NAME: &'static str = "nu"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histpath()?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + + let mut counter = 0; + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + + let cmd: String = s.replace("<\\n>", "\n"); + + let offset = time::Duration::nanoseconds(counter); + counter += 1; + + let entry = History::import().timestamp(now - offset).command(cmd); + + h.push(entry.build().into()).await?; + } + + Ok(()) + } +} diff --git a/crates/turtle/src/atuin_client/import/nu_histdb.rs b/crates/turtle/src/atuin_client/import/nu_histdb.rs new file mode 100644 index 00000000..7de18369 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/nu_histdb.rs @@ -0,0 +1,113 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use sqlx::{Pool, sqlite::SqlitePool}; +use time::{Duration, OffsetDateTime}; + +use super::Importer; +use crate::atuin_client::history::History; +use crate::atuin_client::import::Loader; + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntry { + pub id: i64, + pub command_line: Vec<u8>, + pub start_timestamp: i64, + pub session_id: i64, + pub hostname: Vec<u8>, + pub cwd: Vec<u8>, + pub duration_ms: i64, + pub exit_status: i64, + pub more_info: Vec<u8>, +} + +impl From<HistDbEntry> for History { + fn from(histdb_item: HistDbEntry) -> Self { + let ts_secs = histdb_item.start_timestamp / 1000; + let ts_ns = (histdb_item.start_timestamp % 1000) * 1_000_000; + let imported = History::import() + .timestamp( + OffsetDateTime::from_unix_timestamp(ts_secs).unwrap() + + Duration::nanoseconds(ts_ns), + ) + .command(String::from_utf8(histdb_item.command_line).unwrap()) + .cwd(String::from_utf8(histdb_item.cwd).unwrap()) + .exit(histdb_item.exit_status) + .duration(histdb_item.duration_ms) + .session(format!("{:x}", histdb_item.session_id)) + .hostname(String::from_utf8(histdb_item.hostname).unwrap()); + + imported.build().into() + } +} + +#[derive(Debug)] +pub struct NuHistDb { + histdb: Vec<HistDbEntry>, +} + +/// Read db at given file, return vector of entries. +async fn hist_from_db(dbpath: PathBuf) -> Result<Vec<HistDbEntry>> { + let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; + hist_from_db_conn(pool).await +} + +async fn hist_from_db_conn(pool: Pool<sqlx::Sqlite>) -> Result<Vec<HistDbEntry>> { + let query = r#" + SELECT + id, command_line, start_timestamp, session_id, hostname, cwd, duration_ms, exit_status, + more_info + FROM history + ORDER BY start_timestamp + "#; + let histdb_vec: Vec<HistDbEntry> = sqlx::query_as::<_, HistDbEntry>(query) + .fetch_all(&pool) + .await?; + Ok(histdb_vec) +} + +impl NuHistDb { + pub fn histpath() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let config_dir = base.config_dir().join("nushell"); + + let histdb_path = config_dir.join("history.sqlite3"); + if histdb_path.exists() { + Ok(histdb_path) + } else { + Err(eyre!("Could not find history file.")) + } + } +} + +#[async_trait] +impl Importer for NuHistDb { + // Not sure how this is used + const NAME: &'static str = "nu_histdb"; + + /// Creates a new NuHistDb and populates the history based on the pre-populated data + /// structure. + async fn new() -> Result<Self> { + let dbpath = NuHistDb::histpath()?; + let histdb_entry_vec = hist_from_db(dbpath).await?; + Ok(Self { + histdb: histdb_entry_vec, + }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(self.histdb.len()) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + for i in self.histdb { + h.push(i.into()).await?; + } + Ok(()) + } +} diff --git a/crates/turtle/src/atuin_client/import/powershell.rs b/crates/turtle/src/atuin_client/import/powershell.rs new file mode 100644 index 00000000..8adcc850 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/powershell.rs @@ -0,0 +1,202 @@ +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use std::path::PathBuf; +use time::{Duration, OffsetDateTime}; + +use super::{Importer, Loader, count_lines, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct PowerShell { + bytes: Vec<u8>, + line_count: Option<usize>, +} + +fn get_history_path() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + + // The command line history in PowerShell is maintained by the PSReadLine module: + // https://learn.microsoft.com/en-us/powershell/module/psreadline/about/about_psreadline#command-history + // + // > PSReadLine maintains a history file containing all the commands and data you've entered from the command line. + // > The history files are a file named `$($Host.Name)_history.txt`. + // > On Windows systems the history file is stored at `$Env:APPDATA\Microsoft\Windows\PowerShell\PSReadLine`. + // > On non-Windows systems, the history files are stored at `$Env:XDG_DATA_HOME/powershell/PSReadLine` + // > or `$Env:HOME/.local/share/powershell/PSReadLine`. + + let dir = if cfg!(windows) { + base.data_dir() + .join("Microsoft") + .join("Windows") + .join("PowerShell") + .join("PSReadLine") + } else { + std::env::var("XDG_DATA_HOME") + .map_or_else( + |_| base.home_dir().join(".local").join("share"), + PathBuf::from, + ) + .join("powershell") + .join("PSReadLine") + }; + + // The history is stored in a file named `$($Host.Name)_history.txt`. + // For the default console host shipped by Microsoft,`$Host.Name` is `ConsoleHost`: + // https://learn.microsoft.com/en-us/dotnet/api/system.management.automation.host.pshost.name#remarks + + let file = dir.join("ConsoleHost_history.txt"); + + if file.is_file() { + Ok(file) + } else { + Err(eyre!("Could not find history file: {}", file.display())) + } +} + +#[async_trait] +impl Importer for PowerShell { + const NAME: &'static str = "PowerShell"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_history_path()?)?; + Ok(Self { + bytes, + line_count: None, + }) + } + + async fn entries(&mut self) -> Result<usize> { + // Commands can be split over multiple lines, + // but this is only used for a progress bar, and multi-line commands + // should be quite rare, so this is not an issue in practice. + if self.line_count.is_none() { + self.line_count = Some(count_lines(&self.bytes)); + } + Ok(self.line_count.unwrap()) + } + + async fn load(mut self, h: &mut impl Loader) -> Result<()> { + let line_count = self.entries().await?; + let start = OffsetDateTime::now_utc() - Duration::milliseconds(line_count as i64); + + let mut counter = 0; + let mut iter = unix_byte_lines(&self.bytes); + + while let Some(s) = iter.next() { + let Ok(s) = read_line(s) else { + continue; // We can skip past things like invalid utf8 + }; + + let mut cmd = s.to_string(); + + // Multi-line commands end with a backtick, append the following lines. + while cmd.ends_with('`') { + cmd.pop(); + + let Some(next) = iter.next() else { + break; + }; + let Ok(next) = read_line(next) else { + break; + }; + + cmd.push('\n'); + cmd.push_str(next); + } + + if cmd.is_empty() { + continue; + } + + let offset = Duration::milliseconds(counter); + counter += 1; + + let entry = History::import().timestamp(start + offset).command(cmd); + h.push(entry.build().into()).await?; + } + + Ok(()) + } +} + +fn read_line(s: &[u8]) -> Result<&str> { + let s = str::from_utf8(s)?; + + // History is stored in CRLF on Windows, normalize the input to LF on all platforms. + let s = s.strip_suffix('\r').unwrap_or(s); + + Ok(s) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::import::tests::TestLoader; + use itertools::assert_equal; + + const INPUT: &str = r#"cargo install atuin +cargo update +echo "first line` +second line` +` +last line" +echo foo + +echo bar +echo baz +"#; + + const EXPECTED: &[&str] = &[ + "cargo install atuin", + "cargo update", + "echo \"first line\nsecond line\n\nlast line\"", + "echo foo", + "echo bar", + "echo baz", + ]; + + #[tokio::test] + async fn test_import() { + let loader = import(INPUT).await; + + let actual = loader.buf.iter().map(|h| h.command.clone()); + let expected = EXPECTED.iter().map(|s| s.to_string()); + + assert_equal(actual, expected); + } + + #[tokio::test] + async fn test_crlf() { + let input = INPUT.replace("\n", "\r\n"); + let loader = import(input.as_str()).await; + + let actual = loader.buf.iter().map(|h| h.command.clone()); + let expected = EXPECTED.iter().map(|s| s.to_string()); + + assert_equal(actual, expected); + } + + #[tokio::test] + async fn test_timestamps() { + let loader = import(INPUT).await; + + let mut prev = loader.buf.first().unwrap().timestamp; + for current in loader.buf.iter().skip(1).map(|h| h.timestamp) { + assert!(current > prev); + prev = current; + } + } + + async fn import(input: &str) -> TestLoader { + let powershell = PowerShell { + bytes: input.as_bytes().to_vec(), + line_count: None, + }; + + let mut loader = TestLoader::default(); + powershell.load(&mut loader).await.unwrap(); + loader + } +} diff --git a/crates/turtle/src/atuin_client/import/replxx.rs b/crates/turtle/src/atuin_client/import/replxx.rs new file mode 100644 index 00000000..42f84df5 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/replxx.rs @@ -0,0 +1,137 @@ +use std::{path::PathBuf, str}; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use time::{OffsetDateTime, PrimitiveDateTime, macros::format_description}; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Replxx { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + // There is no default histfile for replxx. + // Here we try a couple of common names. + let mut candidates = ["replxx_history.txt", ".histfile"].iter(); + loop { + match candidates.next() { + Some(candidate) => { + let histpath = home_dir.join(candidate); + if histpath.exists() { + break Ok(histpath); + } + } + None => { + break Err(eyre!( + "Could not find history file. Try setting and exporting $HISTFILE" + )); + } + } + } +} + +#[async_trait] +impl Importer for Replxx { + const NAME: &'static str = "replxx"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes) / 2) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let mut timestamp = OffsetDateTime::UNIX_EPOCH; + + for b in unix_byte_lines(&self.bytes) { + let s = std::str::from_utf8(b)?; + match try_parse_line_as_timestamp(s) { + Some(t) => timestamp = t, + None => { + // replxx uses ETB character (0x17) as line breaker + let cmd = s.replace('\u{0017}', "\n"); + let imported = History::import().timestamp(timestamp).command(cmd); + + h.push(imported.build().into()).await?; + } + } + } + + Ok(()) + } +} + +fn try_parse_line_as_timestamp(line: &str) -> Option<OffsetDateTime> { + // replxx history date time format: ### yyyy-mm-dd hh:mm:ss.xxx + let date_time_str = line.strip_prefix("### ")?; + let format = + format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]"); + + let primitive_date_time = PrimitiveDateTime::parse(date_time_str, format).ok()?; + // There is no safe way to get local time offset. + // For simplicity let's just assume UTC. + Some(primitive_date_time.assume_utc()) +} + +#[cfg(test)] +mod test { + + use crate::import::{Importer, tests::TestLoader}; + + use super::Replxx; + + #[tokio::test] + async fn parse_complex() { + let bytes = r#"### 2024-02-10 22:16:28.302 +select * from remote('127.0.0.1:20222', view(select 1)) +### 2024-02-10 22:16:36.919 +select * from numbers(10) +### 2024-02-10 22:16:41.710 +select * from system.numbers +### 2024-02-10 22:19:28.655 +select 1 +### 2024-02-22 11:15:33.046 +CREATE TABLE test( stamp DateTime('UTC'))ENGINE = MergeTreePARTITION BY toDate(stamp)order by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000); +"# + .as_bytes() + .to_owned(); + + let replxx = Replxx { bytes }; + + let mut loader = TestLoader::default(); + replxx.load(&mut loader).await.unwrap(); + let mut history = loader.buf.into_iter(); + + // simple wrapper for replxx history entry + macro_rules! history { + ($timestamp:expr_2021, $command:expr_2021) => { + let h = history.next().expect("missing entry in history"); + assert_eq!(h.command.as_str(), $command); + assert_eq!(h.timestamp.unix_timestamp(), $timestamp); + }; + } + + history!( + 1707603388, + "select * from remote('127.0.0.1:20222', view(select 1))" + ); + history!(1707603396, "select * from numbers(10)"); + history!(1707603401, "select * from system.numbers"); + history!(1707603568, "select 1"); + history!( + 1708600533, + "CREATE TABLE test\n( stamp DateTime('UTC'))\nENGINE = MergeTree\nPARTITION BY toDate(stamp)\norder by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000);" + ); + } +} diff --git a/crates/turtle/src/atuin_client/import/resh.rs b/crates/turtle/src/atuin_client/import/resh.rs new file mode 100644 index 00000000..c5980c44 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/resh.rs @@ -0,0 +1,140 @@ +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use serde::Deserialize; + +use crate::atuin_common::utils::uuid_v7; +use time::OffsetDateTime; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct ReshEntry { + pub cmd_line: String, + pub exit_code: i64, + pub shell: String, + pub uname: String, + pub session_id: String, + pub home: String, + pub lang: String, + pub lc_all: String, + pub login: String, + pub pwd: String, + pub pwd_after: String, + pub shell_env: String, + pub term: String, + pub real_pwd: String, + pub real_pwd_after: String, + pub pid: i64, + pub session_pid: i64, + pub host: String, + pub hosttype: String, + pub ostype: String, + pub machtype: String, + pub shlvl: i64, + pub timezone_before: String, + pub timezone_after: String, + pub realtime_before: f64, + pub realtime_after: f64, + pub realtime_before_local: f64, + pub realtime_after_local: f64, + pub realtime_duration: f64, + pub realtime_since_session_start: f64, + pub realtime_since_boot: f64, + pub git_dir: String, + pub git_real_dir: String, + pub git_origin_remote: String, + pub git_dir_after: String, + pub git_real_dir_after: String, + pub git_origin_remote_after: String, + pub machine_id: String, + pub os_release_id: String, + pub os_release_version_id: String, + pub os_release_id_like: String, + pub os_release_name: String, + pub os_release_pretty_name: String, + pub resh_uuid: String, + pub resh_version: String, + pub resh_revision: String, + pub parts_merged: bool, + pub recalled: bool, + pub recall_last_cmd_line: String, + pub cols: String, + pub lines: String, +} + +#[derive(Debug)] +pub struct Resh { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".resh_history.json")) +} + +#[async_trait] +impl Importer for Resh { + const NAME: &'static str = "resh"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let entry = match serde_json::from_str::<ReshEntry>(s) { + Ok(e) => e, + Err(_) => continue, // skip invalid json :shrug: + }; + + #[expect(clippy::cast_possible_truncation)] + #[expect(clippy::cast_sign_loss)] + let timestamp = { + let secs = entry.realtime_before.floor() as i64; + let nanosecs = (entry.realtime_before.fract() * 1_000_000_000_f64).round() as i64; + OffsetDateTime::from_unix_timestamp(secs)? + time::Duration::nanoseconds(nanosecs) + }; + #[expect(clippy::cast_possible_truncation)] + #[expect(clippy::cast_sign_loss)] + let duration = { + let secs = entry.realtime_after.floor() as i64; + let nanosecs = (entry.realtime_after.fract() * 1_000_000_000_f64).round() as i64; + let base = OffsetDateTime::from_unix_timestamp(secs)? + + time::Duration::nanoseconds(nanosecs); + let difference = base - timestamp; + difference.whole_nanoseconds() as i64 + }; + + let imported = History::import() + .command(entry.cmd_line) + .timestamp(timestamp) + .duration(duration) + .exit(entry.exit_code) + .cwd(entry.pwd) + .hostname(entry.host) + // CHECK: should we add uuid here? It's not set in the other importers + .session(uuid_v7().as_simple().to_string()); + + h.push(imported.build().into()).await?; + } + + Ok(()) + } +} diff --git a/crates/turtle/src/atuin_client/import/xonsh.rs b/crates/turtle/src/atuin_client/import/xonsh.rs new file mode 100644 index 00000000..a7217826 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/xonsh.rs @@ -0,0 +1,234 @@ +use std::env; +use std::fs::{self, File}; +use std::path::{Path, PathBuf}; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use serde::Deserialize; +use time::OffsetDateTime; +use uuid::Uuid; +use uuid::timestamp::{Timestamp, context::NoContext}; + +use super::{Importer, Loader, get_histdir_path}; +use crate::atuin_client::history::History; +use crate::atuin_client::utils::get_host_user; + +// Note: both HistoryFile and HistoryData have other keys present in the JSON, we don't +// care about them so we leave them unspecified so as to avoid deserializing unnecessarily. +#[derive(Debug, Deserialize)] +struct HistoryFile { + data: HistoryData, +} + +#[derive(Debug, Deserialize)] +struct HistoryData { + sessionid: String, + cmds: Vec<HistoryCmd>, +} + +#[derive(Debug, Deserialize)] +struct HistoryCmd { + cwd: String, + inp: String, + rtn: Option<i64>, + ts: (f64, f64), +} + +#[derive(Debug)] +pub struct Xonsh { + // history is stored as a bunch of json files, one per session + sessions: Vec<HistoryData>, + hostname: String, +} + +fn xonsh_hist_dir(xonsh_data_dir: Option<String>) -> Result<PathBuf> { + // if running within xonsh, this will be available + if let Some(d) = xonsh_data_dir { + let mut path = PathBuf::from(d); + path.push("history_json"); + return Ok(path); + } + + // otherwise, fall back to default + let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; + + let hist_dir = base.data_dir().join("xonsh/history_json"); + if hist_dir.exists() || cfg!(test) { + Ok(hist_dir) + } else { + Err(eyre!("Could not find xonsh history files")) + } +} + +fn load_sessions(hist_dir: &Path) -> Result<Vec<HistoryData>> { + let mut sessions = vec![]; + for entry in fs::read_dir(hist_dir)? { + let p = entry?.path(); + let ext = p.extension().and_then(|e| e.to_str()); + if p.is_file() + && ext == Some("json") + && let Some(data) = load_session(&p)? + { + sessions.push(data); + } + } + Ok(sessions) +} + +fn load_session(path: &Path) -> Result<Option<HistoryData>> { + let file = File::open(path)?; + // empty files are not valid json, so we can't deserialize them + if file.metadata()?.len() == 0 { + return Ok(None); + } + + let mut hist_file: HistoryFile = serde_json::from_reader(file)?; + + // if there are commands in this session, replace the existing UUIDv4 + // with a UUIDv7 generated from the timestamp of the first command + if let Some(cmd) = hist_file.data.cmds.first() { + let seconds = cmd.ts.0.trunc() as u64; + let nanos = (cmd.ts.0.fract() * 1_000_000_000_f64) as u32; + let ts = Timestamp::from_unix(NoContext, seconds, nanos); + hist_file.data.sessionid = Uuid::new_v7(ts).to_string(); + } + Ok(Some(hist_file.data)) +} + +#[async_trait] +impl Importer for Xonsh { + const NAME: &'static str = "xonsh"; + + async fn new() -> Result<Self> { + // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH + let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); + let hist_dir = get_histdir_path(|| xonsh_hist_dir(xonsh_data_dir))?; + let sessions = load_sessions(&hist_dir)?; + let hostname = get_host_user(); + Ok(Xonsh { sessions, hostname }) + } + + async fn entries(&mut self) -> Result<usize> { + let total = self.sessions.iter().map(|s| s.cmds.len()).sum(); + Ok(total) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + for session in self.sessions { + for cmd in session.cmds { + let (start, end) = cmd.ts; + let ts_nanos = (start * 1_000_000_000_f64) as i128; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos)?; + + let duration = (end - start) * 1_000_000_000_f64; + + match cmd.rtn { + Some(exit) => { + let entry = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .exit(exit) + .command(cmd.inp.trim()) + .cwd(cmd.cwd) + .session(session.sessionid.clone()) + .hostname(self.hostname.clone()); + loader.push(entry.build().into()).await?; + } + None => { + let entry = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .command(cmd.inp.trim()) + .cwd(cmd.cwd) + .session(session.sessionid.clone()) + .hostname(self.hostname.clone()); + loader.push(entry.build().into()).await?; + } + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use super::*; + + use crate::history::History; + use crate::import::tests::TestLoader; + + #[test] + fn test_hist_dir_xonsh() { + let hist_dir = xonsh_hist_dir(Some("/home/user/xonsh_data".to_string())).unwrap(); + assert_eq!( + hist_dir, + PathBuf::from("/home/user/xonsh_data/history_json") + ); + } + + #[tokio::test] + async fn test_import() { + let dir = PathBuf::from("tests/data/xonsh"); + let sessions = load_sessions(&dir).unwrap(); + let hostname = "box:user".to_string(); + let xonsh = Xonsh { sessions, hostname }; + + let mut loader = TestLoader::default(); + xonsh.load(&mut loader).await.unwrap(); + // order in buf will depend on filenames, so sort by timestamp for consistency + loader.buf.sort_by_key(|h| h.timestamp); + for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { + assert_eq!(actual.timestamp, expected.timestamp); + assert_eq!(actual.command, expected.command); + assert_eq!(actual.cwd, expected.cwd); + assert_eq!(actual.exit, expected.exit); + assert_eq!(actual.duration, expected.duration); + assert_eq!(actual.hostname, expected.hostname); + } + } + + fn expected_hist_entries() -> [History; 4] { + [ + History::import() + .timestamp(datetime!(2024-02-6 04:17:59.478272256 +00:00:00)) + .command("echo hello world!".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(4651069) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 04:18:01.70632832 +00:00:00)) + .command("ls -l".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(21288633) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:41:31.142515968 +00:00:00)) + .command("false".to_string()) + .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) + .exit(1) + .duration(10269403) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:41:32.271584 +00:00:00)) + .command("exit".to_string()) + .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) + .exit(0) + .duration(4259347) + .hostname("box:user".to_string()) + .build() + .into(), + ] + } +} diff --git a/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs b/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs new file mode 100644 index 00000000..ceedf7e9 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs @@ -0,0 +1,217 @@ +use std::env; +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use futures::TryStreamExt; +use sqlx::{FromRow, Row, sqlite::SqlitePool}; +use time::OffsetDateTime; +use uuid::Uuid; +use uuid::timestamp::{Timestamp, context::NoContext}; + +use super::{Importer, Loader, get_histfile_path}; +use crate::atuin_client::history::History; +use crate::atuin_client::utils::get_host_user; + +#[derive(Debug, FromRow)] +struct HistDbEntry { + inp: String, + rtn: Option<i64>, + tsb: f64, + tse: f64, + cwd: String, + session_start: f64, +} + +impl HistDbEntry { + fn into_hist_with_hostname(self, hostname: String) -> History { + let ts_nanos = (self.tsb * 1_000_000_000_f64) as i128; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos).unwrap(); + + let session_ts_seconds = self.session_start.trunc() as u64; + let session_ts_nanos = (self.session_start.fract() * 1_000_000_000_f64) as u32; + let session_ts = Timestamp::from_unix(NoContext, session_ts_seconds, session_ts_nanos); + let session_id = Uuid::new_v7(session_ts).to_string(); + let duration = (self.tse - self.tsb) * 1_000_000_000_f64; + + if let Some(exit) = self.rtn { + let imported = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .exit(exit) + .command(self.inp) + .cwd(self.cwd) + .session(session_id) + .hostname(hostname); + imported.build().into() + } else { + let imported = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .command(self.inp) + .cwd(self.cwd) + .session(session_id) + .hostname(hostname); + imported.build().into() + } + } +} + +fn xonsh_db_path(xonsh_data_dir: Option<String>) -> Result<PathBuf> { + // if running within xonsh, this will be available + if let Some(d) = xonsh_data_dir { + let mut path = PathBuf::from(d); + path.push("xonsh-history.sqlite"); + return Ok(path); + } + + // otherwise, fall back to default + let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; + + let hist_file = base.data_dir().join("xonsh/xonsh-history.sqlite"); + if hist_file.exists() || cfg!(test) { + Ok(hist_file) + } else { + Err(eyre!( + "Could not find xonsh history db at: {}", + hist_file.to_string_lossy() + )) + } +} + +#[derive(Debug)] +pub struct XonshSqlite { + pool: SqlitePool, + hostname: String, +} + +#[async_trait] +impl Importer for XonshSqlite { + const NAME: &'static str = "xonsh_sqlite"; + + async fn new() -> Result<Self> { + // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH + let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); + let db_path = get_histfile_path(|| xonsh_db_path(xonsh_data_dir))?; + let connection_str = db_path.to_str().ok_or_else(|| { + eyre!( + "Invalid path for SQLite database: {}", + db_path.to_string_lossy() + ) + })?; + + let pool = SqlitePool::connect(connection_str).await?; + let hostname = get_host_user(); + Ok(XonshSqlite { pool, hostname }) + } + + async fn entries(&mut self) -> Result<usize> { + let query = "SELECT COUNT(*) FROM xonsh_history"; + let row = sqlx::query(query).fetch_one(&self.pool).await?; + let count: u32 = row.get(0); + Ok(count as usize) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + let query = r#" + SELECT inp, rtn, tsb, tse, cwd, + MIN(tsb) OVER (PARTITION BY sessionid) AS session_start + FROM xonsh_history + ORDER BY rowid + "#; + + let mut entries = sqlx::query_as::<_, HistDbEntry>(query).fetch(&self.pool); + + let mut count = 0; + while let Some(entry) = entries.try_next().await? { + let hist = entry.into_hist_with_hostname(self.hostname.clone()); + loader.push(hist).await?; + count += 1; + } + + println!("Loaded: {count}"); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use super::*; + + use crate::history::History; + use crate::import::tests::TestLoader; + + #[test] + fn test_db_path_xonsh() { + let db_path = xonsh_db_path(Some("/home/user/xonsh_data".to_string())).unwrap(); + assert_eq!( + db_path, + PathBuf::from("/home/user/xonsh_data/xonsh-history.sqlite") + ); + } + + #[tokio::test] + async fn test_import() { + let connection_str = "tests/data/xonsh-history.sqlite"; + let xonsh_sqlite = XonshSqlite { + pool: SqlitePool::connect(connection_str).await.unwrap(), + hostname: "box:user".to_string(), + }; + + let mut loader = TestLoader::default(); + xonsh_sqlite.load(&mut loader).await.unwrap(); + + for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { + assert_eq!(actual.timestamp, expected.timestamp); + assert_eq!(actual.command, expected.command); + assert_eq!(actual.cwd, expected.cwd); + assert_eq!(actual.exit, expected.exit); + assert_eq!(actual.duration, expected.duration); + assert_eq!(actual.hostname, expected.hostname); + } + } + + fn expected_hist_entries() -> [History; 4] { + [ + History::import() + .timestamp(datetime!(2024-02-6 17:56:21.130956288 +00:00:00)) + .command("echo hello world!".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(2628564) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:28.190406144 +00:00:00)) + .command("ls -l".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(9371519) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:46.989020928 +00:00:00)) + .command("false".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(1) + .duration(17337560) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:48.218384128 +00:00:00)) + .command("exit".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(4599094) + .hostname("box:user".to_string()) + .build() + .into(), + ] + } +} diff --git a/crates/turtle/src/atuin_client/import/zsh.rs b/crates/turtle/src/atuin_client/import/zsh.rs new file mode 100644 index 00000000..e1fd813a --- /dev/null +++ b/crates/turtle/src/atuin_client/import/zsh.rs @@ -0,0 +1,230 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::borrow::Cow; +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use time::OffsetDateTime; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Zsh { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + // oh-my-zsh sets HISTFILE=~/.zhistory + // zsh has no default value for this var, but uses ~/.zhistory. + // zsh-newuser-install propose as default .histfile https://github.com/zsh-users/zsh/blob/master/Functions/Newuser/zsh-newuser-install#L794 + // we could maybe be smarter about this in the future :) + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + let mut candidates = [".zhistory", ".zsh_history", ".histfile"].iter(); + loop { + match candidates.next() { + Some(candidate) => { + let histpath = home_dir.join(candidate); + if histpath.exists() { + break Ok(histpath); + } + } + None => { + break Err(eyre!( + "Could not find history file. Try setting and exporting $HISTFILE" + )); + } + } + } +} + +#[async_trait] +impl Importer for Zsh { + const NAME: &'static str = "zsh"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + let mut line = String::new(); + + let mut counter = 0; + for b in unix_byte_lines(&self.bytes) { + let s = match unmetafy(b) { + Some(s) => s, + _ => continue, // we can skip past things like invalid utf8 + }; + + if let Some(s) = s.strip_suffix('\\') { + line.push_str(s); + line.push('\n'); + } else { + line.push_str(&s); + let command = std::mem::take(&mut line); + + if let Some(command) = command.strip_prefix(": ") { + counter += 1; + h.push(parse_extended(command, counter)).await?; + } else { + let offset = time::Duration::seconds(counter); + counter += 1; + + let imported = History::import() + // preserve ordering + .timestamp(now - offset) + .command(command.trim_end().to_string()); + + h.push(imported.build().into()).await?; + } + } + } + + Ok(()) + } +} + +fn parse_extended(line: &str, counter: i64) -> History { + let (time, duration) = line.split_once(':').unwrap(); + let (duration, command) = duration.split_once(';').unwrap(); + + let time = time + .parse::<i64>() + .ok() + .and_then(|t| OffsetDateTime::from_unix_timestamp(t).ok()) + .unwrap_or_else(OffsetDateTime::now_utc) + + time::Duration::milliseconds(counter); + + // use nanos, because why the hell not? we won't display them. + let duration = duration.parse::<i64>().map_or(-1, |t| t * 1_000_000_000); + + let imported = History::import() + .timestamp(time) + .command(command.trim_end().to_string()) + .duration(duration); + + imported.build().into() +} + +fn unmetafy(line: &[u8]) -> Option<Cow<'_, str>> { + if line.contains(&0x83) { + let mut s = Vec::with_capacity(line.len()); + let mut is_meta = false; + for ch in line { + if *ch == 0x83 { + is_meta = true; + } else if is_meta { + is_meta = false; + s.push(*ch ^ 32); + } else { + s.push(*ch) + } + } + String::from_utf8(s).ok().map(Cow::Owned) + } else { + std::str::from_utf8(line).ok().map(Cow::Borrowed) + } +} + +#[cfg(test)] +mod test { + use itertools::assert_equal; + + use crate::import::tests::TestLoader; + + use super::*; + + #[test] + fn test_parse_extended_simple() { + let parsed = parse_extended("1613322469:0;cargo install atuin", 0); + + assert_eq!(parsed.command, "cargo install atuin"); + assert_eq!(parsed.duration, 0); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo install atuin;cargo update", 0); + + assert_eq!(parsed.command, "cargo install atuin;cargo update"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); + + assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo install \\n atuin\n", 0); + + assert_eq!(parsed.command, "cargo install \\n atuin"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + } + + #[tokio::test] + async fn test_parse_file() { + let bytes = r": 1613322469:0;cargo install atuin +: 1613322469:10;cargo install atuin; \\ +cargo update +: 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ +" + .as_bytes() + .to_owned(); + + let mut zsh = Zsh { bytes }; + assert_eq!(zsh.entries().await.unwrap(), 4); + + let mut loader = TestLoader::default(); + zsh.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + [ + "cargo install atuin", + "cargo install atuin; \\\ncargo update", + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", + ], + ); + } + + #[tokio::test] + async fn test_parse_metafied() { + let bytes = + b"echo \xe4\xbd\x83\x80\xe5\xa5\xbd\nls ~/\xe9\x83\xbf\xb3\xe4\xb9\x83\xb0\n".to_vec(); + + let mut zsh = Zsh { bytes }; + assert_eq!(zsh.entries().await.unwrap(), 2); + + let mut loader = TestLoader::default(); + zsh.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["echo 你好", "ls ~/音乐"], + ); + } +} diff --git a/crates/turtle/src/atuin_client/import/zsh_histdb.rs b/crates/turtle/src/atuin_client/import/zsh_histdb.rs new file mode 100644 index 00000000..f61bb74f --- /dev/null +++ b/crates/turtle/src/atuin_client/import/zsh_histdb.rs @@ -0,0 +1,249 @@ +// import old shell history from zsh-histdb! +// automatically hoover up all that we can find + +// As far as i can tell there are no version numbers in the histdb sqlite DB, so we're going based +// on the schema from 2022-05-01 +// +// I have run into some histories that will not import b/c of non UTF-8 characters. +// + +// +// An Example sqlite query for hsitdb data: +// +//id|session|command_id|place_id|exit_status|start_time|duration|id|argv|id|host|dir +// +// +// select +// history.id, +// history.start_time, +// places.host, +// places.dir, +// commands.argv +// from history +// left join commands on history.command_id = commands.id +// left join places on history.place_id = places.id ; +// +// CREATE TABLE history (id integer primary key autoincrement, +// session int, +// command_id int references commands (id), +// place_id int references places (id), +// exit_status int, +// start_time int, +// duration int); +// + +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + +use async_trait::async_trait; +use crate::atuin_common::utils::uuid_v7; +use directories::UserDirs; +use eyre::{Result, eyre}; +use sqlx::{Pool, sqlite::SqlitePool}; +use time::PrimitiveDateTime; + +use super::Importer; +use crate::atuin_client::history::History; +use crate::atuin_client::import::Loader; +use crate::atuin_client::utils::{get_hostname, get_username}; + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntryCount { + pub count: usize, +} + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntry { + pub id: i64, + pub start_time: PrimitiveDateTime, + pub host: Vec<u8>, + pub dir: Vec<u8>, + pub argv: Vec<u8>, + pub duration: i64, + pub exit_status: i64, + pub session: i64, +} + +#[derive(Debug)] +pub struct ZshHistDb { + histdb: Vec<HistDbEntry>, + username: String, +} + +/// Read db at given file, return vector of entries. +async fn hist_from_db(dbpath: PathBuf) -> Result<Vec<HistDbEntry>> { + let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; + hist_from_db_conn(pool).await +} + +async fn hist_from_db_conn(pool: Pool<sqlx::Sqlite>) -> Result<Vec<HistDbEntry>> { + let query = r#" + SELECT + history.id, history.start_time, history.duration, places.host, places.dir, + commands.argv, history.exit_status, history.session + FROM history + LEFT JOIN commands ON history.command_id = commands.id + LEFT JOIN places ON history.place_id = places.id + ORDER BY history.start_time + "#; + let histdb_vec: Vec<HistDbEntry> = sqlx::query_as::<_, HistDbEntry>(query) + .fetch_all(&pool) + .await?; + Ok(histdb_vec) +} + +impl ZshHistDb { + pub fn histpath_candidate() -> PathBuf { + // By default histdb database is `${HOME}/.histdb/zsh-history.db` + // This can be modified by ${HISTDB_FILE} + // + // if [[ -z ${HISTDB_FILE} ]]; then + // typeset -g HISTDB_FILE="${HOME}/.histdb/zsh-history.db" + let user_dirs = UserDirs::new().unwrap(); // should catch error here? + let home_dir = user_dirs.home_dir(); + std::env::var("HISTDB_FILE") + .as_ref() + .map(|x| Path::new(x).to_path_buf()) + .unwrap_or_else(|_err| home_dir.join(".histdb/zsh-history.db")) + } + pub fn histpath() -> Result<PathBuf> { + let histdb_path = ZshHistDb::histpath_candidate(); + if histdb_path.exists() { + Ok(histdb_path) + } else { + Err(eyre!( + "Could not find history file. Try setting $HISTDB_FILE" + )) + } + } +} + +#[async_trait] +impl Importer for ZshHistDb { + // Not sure how this is used + const NAME: &'static str = "zsh_histdb"; + + /// Creates a new ZshHistDb and populates the history based on the pre-populated data + /// structure. + async fn new() -> Result<Self> { + let dbpath = ZshHistDb::histpath()?; + let histdb_entry_vec = hist_from_db(dbpath).await?; + Ok(Self { + histdb: histdb_entry_vec, + username: get_username(), + }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(self.histdb.len()) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let mut session_map = HashMap::new(); + for entry in self.histdb { + let command = match std::str::from_utf8(&entry.argv) { + Ok(s) => s.trim_end(), + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let cwd = match std::str::from_utf8(&entry.dir) { + Ok(s) => s.trim_end(), + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let hostname = format!( + "{}:{}", + String::from_utf8(entry.host).unwrap_or_else(|_e| get_hostname()), + self.username + ); + let session = session_map.entry(entry.session).or_insert_with(uuid_v7); + + let imported = History::import() + .timestamp(entry.start_time.assume_utc()) + .command(command) + .cwd(cwd) + .duration(entry.duration * 1_000_000_000) + .exit(entry.exit_status) + .session(session.as_simple().to_string()) + .hostname(hostname) + .build(); + h.push(imported.into()).await?; + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use super::*; + use sqlx::sqlite::SqlitePoolOptions; + use std::env; + #[tokio::test(flavor = "multi_thread")] + #[expect(unsafe_code)] + async fn test_env_vars() { + let test_env_db = "nonstd-zsh-history.db"; + let key = "HISTDB_FILE"; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var(key, test_env_db) }; + + // test the env got set + assert_eq!(env::var(key).unwrap(), test_env_db.to_string()); + + // test histdb returns the proper db from previous step + let histdb_path = ZshHistDb::histpath_candidate(); + assert_eq!(histdb_path.to_str().unwrap(), test_env_db); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_import() { + let pool: SqlitePool = SqlitePoolOptions::new() + .min_connections(2) + .connect(":memory:") + .await + .unwrap(); + + // sql dump directly from a test database. + let db_sql = r#" + PRAGMA foreign_keys=OFF; + BEGIN TRANSACTION; + CREATE TABLE commands (id integer primary key autoincrement, argv text, unique(argv) on conflict ignore); + INSERT INTO commands VALUES(1,'pwd'); + INSERT INTO commands VALUES(2,'curl google.com'); + INSERT INTO commands VALUES(3,'bash'); + CREATE TABLE places (id integer primary key autoincrement, host text, dir text, unique(host, dir) on conflict ignore); + INSERT INTO places VALUES(1,'mbp16.local','/home/noyez'); + CREATE TABLE history (id integer primary key autoincrement, + session int, + command_id int references commands (id), + place_id int references places (id), + exit_status int, + start_time int, + duration int); + INSERT INTO history VALUES(1,0,1,1,0,1651497918,1); + INSERT INTO history VALUES(2,0,2,1,0,1651497923,1); + INSERT INTO history VALUES(3,0,3,1,NULL,1651497930,NULL); + DELETE FROM sqlite_sequence; + INSERT INTO sqlite_sequence VALUES('commands',3); + INSERT INTO sqlite_sequence VALUES('places',3); + INSERT INTO sqlite_sequence VALUES('history',3); + CREATE INDEX hist_time on history(start_time); + CREATE INDEX place_dir on places(dir); + CREATE INDEX place_host on places(host); + CREATE INDEX history_command_place on history(command_id, place_id); + COMMIT; "#; + + sqlx::query(db_sql).execute(&pool).await.unwrap(); + + // test histdb iterator + let histdb_vec = hist_from_db_conn(pool).await.unwrap(); + let histdb = ZshHistDb { + histdb: histdb_vec, + username: get_username(), + }; + + println!("h: {:#?}", histdb.histdb); + println!("counter: {:?}", histdb.histdb.len()); + for i in histdb.histdb { + println!("{i:?}"); + } + } +} diff --git a/crates/turtle/src/atuin_client/login.rs b/crates/turtle/src/atuin_client/login.rs new file mode 100644 index 00000000..ca4e16fe --- /dev/null +++ b/crates/turtle/src/atuin_client/login.rs @@ -0,0 +1,68 @@ +use std::path::PathBuf; + +use crate::atuin_common::api::LoginRequest; +use eyre::{Context, Result, bail}; +use tokio::fs::File; +use tokio::io::AsyncWriteExt; + +use crate::atuin_client::{ + api_client, + encryption::{decode_key, load_key}, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +pub async fn login( + settings: &Settings, + store: &SqliteStore, + username: String, + password: String, + key: String, +) -> Result<String> { + let key_path = settings.key_path.as_str(); + let key_path = PathBuf::from(key_path); + + if !key_path.exists() { + if decode_key(key.clone()).is_err() { + bail!("the specified key was invalid"); + } + + let mut file = File::create(&key_path).await?; + file.write_all(key.as_bytes()).await?; + } else { + // we now know that the user has logged in specifying a key, AND that the key path + // exists + + // 1. check if the saved key and the provided key match. if so, nothing to do. + // 2. if not, re-encrypt the local history and overwrite the key + let current_key: [u8; 32] = load_key(settings)?.into(); + + let encoded = key.clone(); // gonna want to save it in a bit + let new_key: [u8; 32] = decode_key(key) + .context("could not decode provided key - is not valid base64")? + .into(); + + if new_key != current_key { + println!("\nRe-encrypting local store with new key"); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Writing new key"); + let mut file = File::create(&key_path).await?; + file.write_all(encoded.as_bytes()).await?; + } + } + + let session = api_client::login( + settings.sync_address.as_str(), + LoginRequest { username, password }, + ) + .await?; + + Settings::meta_store() + .await? + .save_session(&session.session) + .await?; + + Ok(session.session) +} diff --git a/crates/turtle/src/atuin_client/logout.rs b/crates/turtle/src/atuin_client/logout.rs new file mode 100644 index 00000000..343934b9 --- /dev/null +++ b/crates/turtle/src/atuin_client/logout.rs @@ -0,0 +1,16 @@ +use eyre::Result; + +use crate::atuin_client::settings::Settings; + +pub async fn logout() -> Result<()> { + let meta = Settings::meta_store().await?; + + if meta.logged_in().await? { + meta.delete_session().await?; + println!("You have logged out!"); + } else { + println!("You are not logged in"); + } + + Ok(()) +} diff --git a/crates/turtle/src/atuin_client/meta.rs b/crates/turtle/src/atuin_client/meta.rs new file mode 100644 index 00000000..1eea7061 --- /dev/null +++ b/crates/turtle/src/atuin_client/meta.rs @@ -0,0 +1,366 @@ +use std::path::Path; +use std::str::FromStr; +use std::time::Duration; + +use crate::atuin_common::record::HostId; +use eyre::{Result, eyre}; +use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions}; +use time::{OffsetDateTime, format_description::well_known::Rfc3339}; +use tokio::sync::OnceCell; +use tracing::{debug, warn}; +use uuid::Uuid; + +// Filenames for the legacy plain-text files that we migrate from. +const LEGACY_HOST_ID_FILENAME: &str = "host_id"; +const LEGACY_LAST_SYNC_FILENAME: &str = "last_sync_time"; +const LEGACY_LAST_VERSION_CHECK_FILENAME: &str = "last_version_check_time"; +const LEGACY_LATEST_VERSION_FILENAME: &str = "latest_version"; +const LEGACY_SESSION_FILENAME: &str = "session"; + +const KEY_HOST_ID: &str = "host_id"; +const KEY_LAST_SYNC: &str = "last_sync_time"; +const KEY_LAST_VERSION_CHECK: &str = "last_version_check_time"; +const KEY_LATEST_VERSION: &str = "latest_version"; +const KEY_SESSION: &str = "session"; +const KEY_FILES_MIGRATED: &str = "files_migrated"; + +pub struct MetaStore { + pool: SqlitePool, + cached_host_id: OnceCell<HostId>, +} + +impl MetaStore { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + let path_str = path + .as_os_str() + .to_str() + .ok_or_else(|| eyre!("meta database path is not valid UTF-8: {path:?}"))?; + debug!("opening meta sqlite database at {path:?}"); + + let is_memory = path_str.contains(":memory:"); + + if !is_memory + && !path.exists() + && let Some(dir) = path.parent() + { + fs_err::create_dir_all(dir)?; + } + + // Use DELETE journal mode instead of WAL. This is a small, infrequently- + // written KV store — WAL's concurrency benefits aren't needed, and DELETE + // mode avoids creating auxiliary -wal/-shm files that complicate + // permission handling. + let opts = SqliteConnectOptions::from_str(path_str)? + .journal_mode(SqliteJournalMode::Delete) + .optimize_on_close(true, None) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + sqlx::migrate!("./meta-migrations").run(&pool).await?; + + // Session tokens are stored in this database, so restrict permissions. + #[cfg(unix)] + if !is_memory { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?; + } + + let store = Self { + pool, + cached_host_id: OnceCell::const_new(), + }; + + if !is_memory { + store.migrate_files().await?; + } + + Ok(store) + } + + // Generic key-value operations + + pub async fn get(&self, key: &str) -> Result<Option<String>> { + let row: Option<(String,)> = sqlx::query_as("SELECT value FROM meta WHERE key = ?1") + .bind(key) + .fetch_optional(&self.pool) + .await?; + + Ok(row.map(|r| r.0)) + } + + pub async fn set(&self, key: &str, value: &str) -> Result<()> { + sqlx::query( + "INSERT INTO meta (key, value, updated_at) VALUES (?1, ?2, strftime('%s', 'now')) + ON CONFLICT(key) DO UPDATE SET value = ?2, updated_at = strftime('%s', 'now')", + ) + .bind(key) + .bind(value) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn delete(&self, key: &str) -> Result<()> { + sqlx::query("DELETE FROM meta WHERE key = ?1") + .bind(key) + .execute(&self.pool) + .await?; + + Ok(()) + } + + // Typed accessors + + pub async fn host_id(&self) -> Result<HostId> { + self.cached_host_id + .get_or_try_init(|| async { + if let Some(id) = self.get(KEY_HOST_ID).await? { + let parsed = Uuid::from_str(id.as_str()) + .map_err(|e| eyre!("failed to parse host ID: {e}"))?; + return Ok(HostId(parsed)); + } + + let uuid = crate::atuin_common::utils::uuid_v7(); + self.set(KEY_HOST_ID, uuid.as_simple().to_string().as_ref()) + .await?; + + Ok(HostId(uuid)) + }) + .await + .copied() + } + + pub async fn last_sync(&self) -> Result<OffsetDateTime> { + match self.get(KEY_LAST_SYNC).await? { + Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), + None => Ok(OffsetDateTime::UNIX_EPOCH), + } + } + + pub async fn save_sync_time(&self) -> Result<()> { + self.set( + KEY_LAST_SYNC, + OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), + ) + .await + } + + pub async fn last_version_check(&self) -> Result<OffsetDateTime> { + match self.get(KEY_LAST_VERSION_CHECK).await? { + Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), + None => Ok(OffsetDateTime::UNIX_EPOCH), + } + } + + pub async fn save_version_check_time(&self) -> Result<()> { + self.set( + KEY_LAST_VERSION_CHECK, + OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), + ) + .await + } + + pub async fn latest_version(&self) -> Result<Option<String>> { + self.get(KEY_LATEST_VERSION).await + } + + pub async fn save_latest_version(&self, version: &str) -> Result<()> { + self.set(KEY_LATEST_VERSION, version).await + } + + pub async fn session_token(&self) -> Result<Option<String>> { + self.get(KEY_SESSION).await + } + + pub async fn save_session(&self, token: &str) -> Result<()> { + self.set(KEY_SESSION, token).await + } + + pub async fn delete_session(&self) -> Result<()> { + self.delete(KEY_SESSION).await + } + + pub async fn logged_in(&self) -> Result<bool> { + Ok(self.session_token().await?.is_some()) + } + + // File migration: on first open, migrate old plain-text files into the database. + // Old files are left in place for safe downgrades. + + async fn migrate_files(&self) -> Result<()> { + if self.get(KEY_FILES_MIGRATED).await?.is_some() { + return Ok(()); + } + + let data_dir = crate::atuin_client::settings::Settings::effective_data_dir(); + + // host_id — validate as UUID + let host_id_path = data_dir.join(LEGACY_HOST_ID_FILENAME); + if host_id_path.exists() + && let Ok(value) = fs_err::read_to_string(&host_id_path) + { + let value = value.trim(); + if !value.is_empty() { + if Uuid::from_str(value).is_ok() { + self.set(KEY_HOST_ID, value).await?; + } else { + warn!("skipping migration of host_id: invalid UUID {value:?}"); + } + } + } + + // last_sync_time — validate as RFC3339 + let sync_path = data_dir.join(LEGACY_LAST_SYNC_FILENAME); + if sync_path.exists() + && let Ok(value) = fs_err::read_to_string(&sync_path) + { + let value = value.trim(); + if !value.is_empty() { + if OffsetDateTime::parse(value, &Rfc3339).is_ok() { + self.set(KEY_LAST_SYNC, value).await?; + } else { + warn!("skipping migration of last_sync_time: invalid RFC3339 {value:?}"); + } + } + } + + // last_version_check_time — validate as RFC3339 + let version_check_path = data_dir.join(LEGACY_LAST_VERSION_CHECK_FILENAME); + if version_check_path.exists() + && let Ok(value) = fs_err::read_to_string(&version_check_path) + { + let value = value.trim(); + if !value.is_empty() { + if OffsetDateTime::parse(value, &Rfc3339).is_ok() { + self.set(KEY_LAST_VERSION_CHECK, value).await?; + } else { + warn!( + "skipping migration of last_version_check_time: invalid RFC3339 {value:?}" + ); + } + } + } + + // latest_version — no strict validation, just non-empty + let latest_version_path = data_dir.join(LEGACY_LATEST_VERSION_FILENAME); + if latest_version_path.exists() + && let Ok(value) = fs_err::read_to_string(&latest_version_path) + { + let value = value.trim(); + if !value.is_empty() { + self.set(KEY_LATEST_VERSION, value).await?; + } + } + + // session token — no strict validation, just non-empty + let session_path = data_dir.join(LEGACY_SESSION_FILENAME); + if session_path.exists() + && let Ok(value) = fs_err::read_to_string(&session_path) + { + let value = value.trim(); + if !value.is_empty() { + self.set(KEY_SESSION, value).await?; + } + } + + self.set(KEY_FILES_MIGRATED, "true").await?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + async fn new_test_store() -> MetaStore { + MetaStore::new("sqlite::memory:", 2.0).await.unwrap() + } + + #[tokio::test] + async fn test_get_set_delete() { + let store = new_test_store().await; + + assert_eq!(store.get("foo").await.unwrap(), None); + + store.set("foo", "bar").await.unwrap(); + assert_eq!(store.get("foo").await.unwrap(), Some("bar".to_string())); + + store.set("foo", "baz").await.unwrap(); + assert_eq!(store.get("foo").await.unwrap(), Some("baz".to_string())); + + store.delete("foo").await.unwrap(); + assert_eq!(store.get("foo").await.unwrap(), None); + } + + #[tokio::test] + async fn test_host_id_generation_and_stability() { + let store = new_test_store().await; + + let id1 = store.host_id().await.unwrap(); + let id2 = store.host_id().await.unwrap(); + + assert_eq!(id1, id2, "host_id should be stable across calls"); + } + + #[tokio::test] + async fn test_sync_time() { + let store = new_test_store().await; + + let t = store.last_sync().await.unwrap(); + assert_eq!(t, OffsetDateTime::UNIX_EPOCH); + + store.save_sync_time().await.unwrap(); + let t = store.last_sync().await.unwrap(); + assert!(t > OffsetDateTime::UNIX_EPOCH); + } + + #[tokio::test] + async fn test_version_check_time() { + let store = new_test_store().await; + + let t = store.last_version_check().await.unwrap(); + assert_eq!(t, OffsetDateTime::UNIX_EPOCH); + + store.save_version_check_time().await.unwrap(); + let t = store.last_version_check().await.unwrap(); + assert!(t > OffsetDateTime::UNIX_EPOCH); + } + + #[tokio::test] + async fn test_session_crud() { + let store = new_test_store().await; + + assert!(!store.logged_in().await.unwrap()); + assert_eq!(store.session_token().await.unwrap(), None); + + store.save_session("tok123").await.unwrap(); + assert!(store.logged_in().await.unwrap()); + assert_eq!( + store.session_token().await.unwrap(), + Some("tok123".to_string()) + ); + + store.delete_session().await.unwrap(); + assert!(!store.logged_in().await.unwrap()); + } + + #[tokio::test] + async fn test_latest_version() { + let store = new_test_store().await; + + assert_eq!(store.latest_version().await.unwrap(), None); + + store.save_latest_version("1.2.3").await.unwrap(); + assert_eq!( + store.latest_version().await.unwrap(), + Some("1.2.3".to_string()) + ); + } +} diff --git a/crates/turtle/src/atuin_client/mod.rs b/crates/turtle/src/atuin_client/mod.rs new file mode 100644 index 00000000..7f07f2e2 --- /dev/null +++ b/crates/turtle/src/atuin_client/mod.rs @@ -0,0 +1,26 @@ +#[cfg(feature = "sync")] +pub mod api_client; +#[cfg(feature = "sync")] +pub mod auth; +#[cfg(feature = "sync")] +pub mod login; +#[cfg(feature = "sync")] +pub mod register; +#[cfg(feature = "sync")] +pub mod sync; + +pub mod database; +pub mod distro; +pub mod encryption; +pub mod history; +pub mod import; +pub mod logout; +pub mod meta; +pub mod ordering; +pub mod plugin; +pub mod record; +pub mod secrets; +pub mod settings; +pub mod theme; + +mod utils; diff --git a/crates/turtle/src/atuin_client/ordering.rs b/crates/turtle/src/atuin_client/ordering.rs new file mode 100644 index 00000000..4e5ec84c --- /dev/null +++ b/crates/turtle/src/atuin_client/ordering.rs @@ -0,0 +1,32 @@ +use minspan::minspan; + +use super::{history::History, settings::SearchMode}; + +pub fn reorder_fuzzy(mode: SearchMode, query: &str, res: Vec<History>) -> Vec<History> { + match mode { + SearchMode::Fuzzy => reorder(query, |x| &x.command, res), + _ => res, + } +} + +fn reorder<F, A>(query: &str, f: F, res: Vec<A>) -> Vec<A> +where + F: Fn(&A) -> &String, + A: Clone, +{ + let mut r = res.clone(); + let qvec = &query.chars().collect(); + r.sort_by_cached_key(|h| { + // TODO for fzf search we should sum up scores for each matched term + let (from, to) = match minspan::span(qvec, &(f(h).chars().collect())) { + Some(x) => x, + // this is a little unfortunate: when we are asked to match a query that is found nowhere, + // we don't want to return a None, as the comparison behaviour would put the worst matches + // at the front. therefore, we'll return a set of indices that are one larger than the longest + // possible legitimate match. This is meaningless except as a comparison. + None => (0, res.len()), + }; + 1 + to - from + }); + r +} diff --git a/crates/turtle/src/atuin_client/plugin.rs b/crates/turtle/src/atuin_client/plugin.rs new file mode 100644 index 00000000..6f351bf1 --- /dev/null +++ b/crates/turtle/src/atuin_client/plugin.rs @@ -0,0 +1,150 @@ +use std::collections::HashMap; + +#[derive(Debug, Clone)] +pub struct OfficialPlugin { + pub name: String, + pub description: String, + pub install_message: String, +} + +impl OfficialPlugin { + pub fn new(name: &str, description: &str, install_message: &str) -> Self { + Self { + name: name.to_string(), + description: description.to_string(), + install_message: install_message.to_string(), + } + } +} + +pub struct OfficialPluginRegistry { + plugins: HashMap<String, OfficialPlugin>, +} + +impl OfficialPluginRegistry { + pub fn new() -> Self { + let mut registry = Self { + plugins: HashMap::new(), + }; + + // Register official plugins + registry.register_official_plugins(); + + registry + } + + fn register_official_plugins(&mut self) { + // atuin-update plugin + self.plugins.insert( + "update".to_string(), + OfficialPlugin::new( + "update", + "Update atuin to the latest version", + "The 'atuin update' command is provided by the atuin-update plugin.\n\ + It is only installed if you used the install script\n \ + If you used a package manager (brew, apt, etc), please continue to use it for updates", + ), + ); + } + + pub fn get_plugin(&self, name: &str) -> Option<&OfficialPlugin> { + self.plugins.get(name) + } + + pub fn is_official_plugin(&self, name: &str) -> bool { + self.plugins.contains_key(name) + } + + pub fn get_install_message(&self, name: &str) -> Option<&str> { + self.plugins + .get(name) + .map(|plugin| plugin.install_message.as_str()) + } +} + +impl Default for OfficialPluginRegistry { + fn default() -> Self { + Self::new() + } +} + +pub struct PluginContext { + #[cfg(windows)] + _update_on_windows: Option<UpdateOnWindowsContext>, +} + +impl PluginContext { + pub fn new(_subcommand: &str) -> Self { + PluginContext { + #[cfg(windows)] + _update_on_windows: (_subcommand == "update").then(UpdateOnWindowsContext::new), + } + } +} + +impl Drop for PluginContext { + fn drop(&mut self) {} +} + +#[cfg(windows)] +struct UpdateOnWindowsContext { + initial_exe: Option<std::path::PathBuf>, +} + +#[cfg(windows)] +impl UpdateOnWindowsContext { + const OLD_FILE_NAME: &'static str = "atuin.old"; + + pub fn new() -> Self { + // Windows doesn't let you overwrite a running exe, but it lets you rename it, + // so make some room for atuin-update to install the new version. + let initial_exe = std::env::current_exe().ok().and_then(|exe| { + std::fs::rename(&exe, exe.with_file_name(Self::OLD_FILE_NAME)).ok()?; + Some(exe) + }); + + Self { initial_exe } + } +} + +#[cfg(windows)] +impl Drop for UpdateOnWindowsContext { + fn drop(&mut self) { + if let Some(exe) = &self.initial_exe + && !exe.exists() + { + // The update failed, roll back the current exe to its initial name. + std::fs::rename(exe.with_file_name(Self::OLD_FILE_NAME), exe).unwrap_or_else(|e| { + eprintln!("Failed to roll back the update, you may need to reinstall Atuin: {e}"); + }); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_registry_creation() { + let registry = OfficialPluginRegistry::new(); + assert!(registry.is_official_plugin("update")); + assert!(!registry.is_official_plugin("nonexistent")); + } + + #[test] + fn test_get_plugin() { + let registry = OfficialPluginRegistry::new(); + let plugin = registry.get_plugin("update"); + assert!(plugin.is_some()); + assert_eq!(plugin.unwrap().name, "update"); + } + + #[test] + fn test_get_install_message() { + let registry = OfficialPluginRegistry::new(); + let message = registry.get_install_message("update"); + assert!(message.is_some()); + assert!(message.unwrap().contains("atuin-update")); + } +} diff --git a/crates/turtle/src/atuin_client/record/encryption.rs b/crates/turtle/src/atuin_client/record/encryption.rs new file mode 100644 index 00000000..22dcdec3 --- /dev/null +++ b/crates/turtle/src/atuin_client/record/encryption.rs @@ -0,0 +1,373 @@ +use crate::atuin_common::record::{ + AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx, +}; +use base64::{Engine, engine::general_purpose}; +use eyre::{Context, Result, ensure}; +use rusty_paserk::{Key, KeyId, Local, PieWrappedKey}; +use rusty_paseto::core::{ + ImplicitAssertion, Key as DataKey, Local as LocalPurpose, Paseto, PasetoNonce, Payload, V4, +}; +use serde::{Deserialize, Serialize}; + +/// Use PASETO V4 Local encryption using the additional data as an implicit assertion. +#[expect(non_camel_case_types)] +pub struct PASETO_V4; + +/* +Why do we use a random content-encryption key? +Originally I was planning on using a derived key for encryption based on additional data. +This would be a lot more secure than using the master key directly. + +However, there's an established norm of using a random key. This scheme might be otherwise known as +- client-side encryption +- envelope encryption +- key wrapping + +A HSM (Hardware Security Module) provider, eg: AWS, Azure, GCP, or even a physical device like a YubiKey +will have some keys that they keep to themselves. These keys never leave their physical hardware. +If they never leave the hardware, then encrypting large amounts of data means giving them the data and waiting. +This is not a practical solution. Instead, generate a unique key for your data, encrypt that using your HSM +and then store that with your data. + +See + - <https://docs.aws.amazon.com/wellarchitected/latest/financial-services-industry-lens/use-envelope-encryption-with-customer-master-keys.html> + - <https://cloud.google.com/kms/docs/envelope-encryption> + - <https://learn.microsoft.com/en-us/azure/storage/blobs/client-side-encryption?tabs=dotnet#encryption-and-decryption-via-the-envelope-technique> + - <https://www.yubico.com/gb/product/yubihsm-2-fips/> + - <https://cheatsheetseries.owasp.org/cheatsheets/Cryptographic_Storage_Cheat_Sheet.html#encrypting-stored-keys> + +Why would we care? In the past we have received some requests for company solutions. If in future we can configure a +KMS service with little effort, then that would solve a lot of issues for their security team. + +Even for personal use, if a user is not comfortable with sharing keys between hosts, +GCP HSM costs $1/month and $0.03 per 10,000 key operations. Assuming an active user runs +1000 atuin records a day, that would only cost them $1 and 10 cent a month. + +Additionally, key rotations are much simpler using this scheme. Rotating a key is as simple as re-encrypting the CEK, and not the message contents. +This makes it very fast to rotate a key in bulk. + +For future reference, with asymmetric encryption, you can encrypt the CEK without the HSM's involvement, but decrypting +will need the HSM. This allows the encryption path to still be extremely fast (no network calls) but downloads/decryption +that happens in the background can make the network calls to the HSM +*/ + +impl Encryption for PASETO_V4 { + fn re_encrypt( + mut data: EncryptedData, + _ad: AdditionalData, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result<EncryptedData> { + let cek = Self::decrypt_cek(data.content_encryption_key, old_key)?; + data.content_encryption_key = Self::encrypt_cek(cek, new_key); + Ok(data) + } + + fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData { + // generate a random key for this entry + // aka content-encryption-key (CEK) + let random_key = Key::<V4, Local>::new_os_random(); + + // encode the implicit assertions + let assertions = Assertions::from(ad).encode(); + + // build the payload and encrypt the token + let payload = serde_json::to_string(&AtuinPayload { + data: general_purpose::URL_SAFE_NO_PAD.encode(data.0), + }) + .expect("json encoding can't fail"); + let nonce = DataKey::<32>::try_new_random().expect("could not source from random"); + let nonce = PasetoNonce::<V4, LocalPurpose>::from(&nonce); + + let token = Paseto::<V4, LocalPurpose>::builder() + .set_payload(Payload::from(payload.as_str())) + .set_implicit_assertion(ImplicitAssertion::from(assertions.as_str())) + .try_encrypt(&random_key.into(), &nonce) + .expect("error encrypting atuin data"); + + EncryptedData { + data: token, + content_encryption_key: Self::encrypt_cek(random_key, key), + } + } + + fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result<DecryptedData> { + let token = data.data; + let cek = Self::decrypt_cek(data.content_encryption_key, key)?; + + // encode the implicit assertions + let assertions = Assertions::from(ad).encode(); + + // decrypt the payload with the footer and implicit assertions + let payload = Paseto::<V4, LocalPurpose>::try_decrypt( + &token, + &cek.into(), + None, + ImplicitAssertion::from(&*assertions), + ) + .context("could not decrypt entry")?; + + let payload: AtuinPayload = serde_json::from_str(&payload)?; + let data = general_purpose::URL_SAFE_NO_PAD.decode(payload.data)?; + Ok(DecryptedData(data)) + } +} + +impl PASETO_V4 { + fn decrypt_cek(wrapped_cek: String, key: &[u8; 32]) -> Result<Key<V4, Local>> { + let wrapping_key = Key::<V4, Local>::from_bytes(*key); + + // let wrapping_key = PasetoSymmetricKey::from(Key::from(key)); + + let AtuinFooter { kid, wpk } = serde_json::from_str(&wrapped_cek) + .context("wrapped cek did not contain the correct contents")?; + + // check that the wrapping key matches the required key to decrypt. + // In future, we could support multiple keys and use this key to + // look up the key rather than only allow one key. + // For now though we will only support the one key and key rotation will + // have to be a hard reset + let current_kid = wrapping_key.to_id(); + + ensure!( + current_kid == kid, + "attempting to decrypt with incorrect key. currently using {current_kid}, expecting {kid}" + ); + + // decrypt the random key + Ok(wpk.unwrap_key(&wrapping_key)?) + } + + fn encrypt_cek(cek: Key<V4, Local>, key: &[u8; 32]) -> String { + // aka key-encryption-key (KEK) + let wrapping_key = Key::<V4, Local>::from_bytes(*key); + + // wrap the random key so we can decrypt it later + let wrapped_cek = AtuinFooter { + wpk: cek.wrap_pie(&wrapping_key), + kid: wrapping_key.to_id(), + }; + serde_json::to_string(&wrapped_cek).expect("could not serialize wrapped cek") + } +} + +#[derive(Serialize, Deserialize)] +struct AtuinPayload { + data: String, +} + +#[derive(Serialize, Deserialize)] +/// Well-known footer claims for decrypting. This is not encrypted but is stored in the record. +/// <https://github.com/paseto-standard/paseto-spec/blob/master/docs/02-Implementation-Guide/04-Claims.md#optional-footer-claims> +struct AtuinFooter { + /// Wrapped key + wpk: PieWrappedKey<V4, Local>, + /// ID of the key which was used to wrap + kid: KeyId<V4, Local>, +} + +/// Used in the implicit assertions. This is not encrypted and not stored in the data blob. +// This cannot be changed, otherwise it breaks the authenticated encryption. +#[derive(Debug, Copy, Clone, Serialize)] +struct Assertions<'a> { + id: &'a RecordId, + idx: &'a RecordIdx, + version: &'a str, + tag: &'a str, + host: &'a HostId, +} + +impl<'a> From<AdditionalData<'a>> for Assertions<'a> { + fn from(ad: AdditionalData<'a>) -> Self { + Self { + id: ad.id, + version: ad.version, + tag: ad.tag, + host: ad.host, + idx: ad.idx, + } + } +} + +impl Assertions<'_> { + fn encode(&self) -> String { + serde_json::to_string(self).expect("could not serialize implicit assertions") + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::{ + record::{Host, Record}, + utils::uuid_v7, + }; + + use super::*; + + #[test] + fn round_trip() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); + let decrypted = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap(); + assert_eq!(decrypted, data); + } + + #[test] + fn same_entry_different_output() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); + let encrypted2 = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + + assert_ne!( + encrypted.data, encrypted2.data, + "re-encrypting the same contents should have different output due to key randomization" + ); + } + + #[test] + fn cannot_decrypt_different_key() { + let key = Key::<V4, Local>::new_os_random(); + let fake_key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + let _ = PASETO_V4::decrypt(encrypted, ad, &fake_key.to_bytes()).unwrap_err(); + } + + #[test] + fn cannot_decrypt_different_id() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + ..ad + }; + let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err(); + } + + #[test] + fn re_encrypt_round_trip() { + let key1 = Key::<V4, Local>::new_os_random(); + let key2 = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted1 = PASETO_V4::encrypt(data.clone(), ad, &key1.to_bytes()); + let encrypted2 = + PASETO_V4::re_encrypt(encrypted1.clone(), ad, &key1.to_bytes(), &key2.to_bytes()) + .unwrap(); + + // we only re-encrypt the content keys + assert_eq!(encrypted1.data, encrypted2.data); + assert_ne!( + encrypted1.content_encryption_key, + encrypted2.content_encryption_key + ); + + let decrypted = PASETO_V4::decrypt(encrypted2, ad, &key2.to_bytes()).unwrap(); + + assert_eq!(decrypted, data); + } + + #[test] + fn full_record_round_trip() { + let key = [0x55; 32]; + let record = Record::builder() + .id(RecordId(uuid_v7())) + .version("v0".to_owned()) + .tag("kv".to_owned()) + .host(Host::new(HostId(uuid_v7()))) + .timestamp(1687244806000000) + .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) + .build(); + + let encrypted = record.encrypt::<PASETO_V4>(&key); + + assert!(!encrypted.data.data.is_empty()); + assert!(!encrypted.data.content_encryption_key.is_empty()); + + let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap(); + + assert_eq!(decrypted.data.0, [1, 2, 3, 4]); + } + + #[test] + fn full_record_round_trip_fail() { + let key = [0x55; 32]; + let record = Record::builder() + .id(RecordId(uuid_v7())) + .version("v0".to_owned()) + .tag("kv".to_owned()) + .host(Host::new(HostId(uuid_v7()))) + .timestamp(1687244806000000) + .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) + .build(); + + let encrypted = record.encrypt::<PASETO_V4>(&key); + + let mut enc1 = encrypted.clone(); + enc1.host = Host::new(HostId(uuid_v7())); + let _ = enc1 + .decrypt::<PASETO_V4>(&key) + .expect_err("tampering with the host should result in auth failure"); + + let mut enc2 = encrypted; + enc2.id = RecordId(uuid_v7()); + let _ = enc2 + .decrypt::<PASETO_V4>(&key) + .expect_err("tampering with the id should result in auth failure"); + } +} diff --git a/crates/turtle/src/atuin_client/record/mod.rs b/crates/turtle/src/atuin_client/record/mod.rs new file mode 100644 index 00000000..c40fd395 --- /dev/null +++ b/crates/turtle/src/atuin_client/record/mod.rs @@ -0,0 +1,6 @@ +pub mod encryption; +pub mod sqlite_store; +pub mod store; + +#[cfg(feature = "sync")] +pub mod sync; diff --git a/crates/turtle/src/atuin_client/record/sqlite_store.rs b/crates/turtle/src/atuin_client/record/sqlite_store.rs new file mode 100644 index 00000000..5fab999d --- /dev/null +++ b/crates/turtle/src/atuin_client/record/sqlite_store.rs @@ -0,0 +1,643 @@ +// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries. +// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index +// by tag/host + +use std::str::FromStr; +use std::{path::Path, time::Duration}; + +use async_trait::async_trait; +use eyre::{Result, eyre}; +use fs_err as fs; + +use sqlx::{ + Row, + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, +}; +use tracing::debug; + +use crate::atuin_common::record::{ + EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus, +}; +use crate::atuin_common::utils; +use uuid::Uuid; + +use super::encryption::PASETO_V4; +use super::store::Store; + +#[derive(Debug, Clone)] +pub struct SqliteStore { + pool: SqlitePool, +} + +impl SqliteStore { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + + debug!("opening sqlite database at {path:?}"); + + if utils::broken_symlink(path) { + eprintln!( + "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + ); + std::process::exit(1); + } + + if !path.exists() + && let Some(dir) = path.parent() + { + fs::create_dir_all(dir)?; + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .foreign_keys(true) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + + Ok(Self { pool }) + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./record-migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + r: &Record<EncryptedData>, + ) -> Result<()> { + // In sqlite, we are "limited" to i64. But that is still fine, until 2262. + sqlx::query( + "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + ) + .bind(r.id.0.as_hyphenated().to_string()) + .bind(r.idx as i64) + .bind(r.host.id.0.as_hyphenated().to_string()) + .bind(r.tag.as_str()) + .bind(r.timestamp as i64) + .bind(r.version.as_str()) + .bind(r.data.data.as_str()) + .bind(r.data.content_encryption_key.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_row(row: SqliteRow) -> Record<EncryptedData> { + let idx: i64 = row.get("idx"); + let timestamp: i64 = row.get("timestamp"); + + // tbh at this point things are pretty fucked so just panic + let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); + let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); + + Record { + id: RecordId(id), + idx: idx as u64, + host: Host::new(HostId(host)), + timestamp: timestamp as u64, + tag: row.get("tag"), + version: row.get("version"), + data: EncryptedData { + data: row.get("data"), + content_encryption_key: row.get("cek"), + }, + } + } + + async fn load_all(&self) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query("select * from store ") + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } +} + +#[async_trait] +impl Store for SqliteStore { + async fn push_batch( + &self, + records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, + ) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for record in records { + Self::save_raw(&mut tx, record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> { + let res = sqlx::query("select * from store where store.id = ?1") + .bind(id.0.as_hyphenated().to_string()) + .map(Self::query_row) + .fetch_one(&self.pool) + .await?; + + Ok(res) + } + + async fn delete(&self, id: RecordId) -> Result<()> { + sqlx::query("delete from store where id = ?1") + .bind(id.0.as_hyphenated().to_string()) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn delete_all(&self) -> Result<()> { + sqlx::query("delete from store").execute(&self.pool).await?; + + Ok(()) + } + + async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { + let res = + sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1") + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occurred: {}", e)), + Ok(record) => Ok(Some(record)), + } + } + + async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { + self.idx(host, tag, 0).await + } + + async fn len_all(&self) -> Result<u64> { + let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store") + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + + async fn len_tag(&self, tag: &str) -> Result<u64> { + let res: Result<(i64,), sqlx::Error> = + sqlx::query_as("select count(*) from store where tag=?1") + .bind(tag) + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + + async fn len(&self, host: HostId, tag: &str) -> Result<u64> { + let last = self.last(host, tag).await?; + + if let Some(last) = last { + return Ok(last.idx + 1); + } + + return Ok(0); + } + + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query( + "select * from store where idx >= ?1 and host = ?2 and tag = ?3 order by idx asc limit ?4", + ) + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .bind(limit as i64) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result<Option<Record<EncryptedData>>> { + let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3") + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occurred: {}", e)), + Ok(v) => Ok(Some(v)), + } + } + + async fn status(&self) -> Result<RecordStatus> { + let mut status = RecordStatus::new(); + + let res: Result<Vec<(String, String, i64)>, sqlx::Error> = + sqlx::query_as("select host, tag, max(idx) from store group by host, tag") + .fetch_all(&self.pool) + .await; + + let res = match res { + Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)), + Ok(v) => v, + }; + + for i in res { + let host = HostId( + Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"), + ); + + status.set_raw(host, i.1, i.2 as u64); + } + + Ok(status) + } + + async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc") + .bind(tag) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + /// Reencrypt every single item in this store with a new key + /// Be careful - this may mess with sync. + async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> { + // Load all the records + // In memory like some of the other code here + // This will never be called in a hot loop, and only under the following circumstances + // 1. The user has logged into a new account, with a new key. They are unlikely to have a + // lot of data + // 2. The user has encountered some sort of issue, and runs a maintenance command that + // invokes this + let all = self.load_all().await?; + + let re_encrypted = all + .into_iter() + .map(|record| record.re_encrypt::<PASETO_V4>(old_key, new_key)) + .collect::<Result<Vec<_>>>()?; + + // next up, we delete all the old data and reinsert the new stuff + // do it in one transaction, so if anything fails we rollback OK + + let mut tx = self.pool.begin().await?; + + let res = sqlx::query("delete from store").execute(&mut *tx).await?; + + let rows = res.rows_affected(); + debug!("deleted {rows} rows"); + + // don't call push_batch, as it will start its own transaction + // call the underlying save_raw + + for record in re_encrypted { + Self::save_raw(&mut tx, &record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + /// Verify that every record in this store can be decrypted with the current key + /// Someday maybe also check each tag/record can be deserialized, but not for now. + async fn verify(&self, key: &[u8; 32]) -> Result<()> { + let all = self.load_all().await?; + + all.into_iter() + .map(|record| record.decrypt::<PASETO_V4>(key)) + .collect::<Result<Vec<_>>>()?; + + Ok(()) + } + + /// Verify that every record in this store can be decrypted with the current key + /// Someday maybe also check each tag/record can be deserialized, but not for now. + async fn purge(&self, key: &[u8; 32]) -> Result<()> { + let all = self.load_all().await?; + + for record in all.iter() { + match record.clone().decrypt::<PASETO_V4>(key) { + Ok(_) => continue, + Err(_) => { + println!( + "Failed to decrypt {}, deleting", + record.id.0.as_hyphenated() + ); + + self.delete(record.id).await?; + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::{ + record::{DecryptedData, EncryptedData, Host, HostId, Record}, + utils::uuid_v7, + }; + + use crate::{ + encryption::generate_encoded_key, + record::{encryption::PASETO_V4, store::Store}, + settings::test_local_timeout, + }; + + use super::SqliteStore; + + fn test_record() -> Record<EncryptedData> { + Record::builder() + .host(Host::new(HostId(atuin_common::utils::uuid_v7()))) + .version("v1".into()) + .tag(atuin_common::utils::uuid_v7().simple().to_string()) + .data(EncryptedData { + data: "1234".into(), + content_encryption_key: "1234".into(), + }) + .idx(0) + .build() + } + + #[tokio::test] + async fn create_db() { + let db = SqliteStore::new(":memory:", test_local_timeout()).await; + + assert!( + db.is_ok(), + "db could not be created, {:?}", + db.err().unwrap() + ); + } + + #[tokio::test] + async fn push_record() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + + db.push(&record).await.expect("failed to insert record"); + } + + #[tokio::test] + async fn get_record() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let new_record = db.get(record.id).await.expect("failed to fetch record"); + + assert_eq!(record, new_record, "records are not equal"); + } + + #[tokio::test] + async fn last() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let last = db + .last(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + last.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn first() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let first = db + .first(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + first.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn len() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_tag() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len_tag(record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_different_tags() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + + // these have different tags, so the len should be the same + // we model multiple stores within one database + // new store = new tag = independent length + let first = test_record(); + let second = test_record(); + + db.push(&first).await.unwrap(); + db.push(&second).await.unwrap(); + + let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap(); + let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap(); + + assert_eq!(first_len, 1, "expected length of 1 after insert"); + assert_eq!(second_len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn append_a_bunch() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + + let mut tail = test_record(); + db.push(&tail).await.expect("failed to push record"); + + for _ in 1..100 { + tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]); + db.push(&tail).await.unwrap(); + } + + assert_eq!( + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + + assert_eq!( + db.len_tag(tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + } + + #[tokio::test] + async fn append_a_big_bunch() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + + let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(10000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..10000 { + tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + assert_eq!( + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), + 10000, + "failed to insert 10k records" + ); + } + + #[tokio::test] + async fn re_encrypt() { + let store = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let (key, _) = generate_encoded_key().unwrap(); + let data = vec![0u8, 1u8, 2u8, 3u8]; + let host_id = HostId(uuid_v7()); + + for i in 0..10 { + let record = Record::builder() + .host(Host::new(host_id)) + .version(String::from("test")) + .tag(String::from("test")) + .idx(i) + .data(DecryptedData(data.clone())) + .build(); + + let record = record.encrypt::<PASETO_V4>(&key.into()); + store + .push(&record) + .await + .expect("failed to push encrypted record"); + } + + // first, check that we can decrypt the data with the current key + let all = store.all_tagged("test").await.unwrap(); + + assert_eq!(all.len(), 10, "failed to fetch all records"); + + for record in all { + let decrypted = record.decrypt::<PASETO_V4>(&key.into()).unwrap(); + assert_eq!(decrypted.data.0, data); + } + + // reencrypt the store, then check if + // 1) it cannot be decrypted with the old key + // 2) it can be decrypted with the new key + + let (new_key, _) = generate_encoded_key().unwrap(); + store + .re_encrypt(&key.into(), &new_key.into()) + .await + .expect("failed to re-encrypt store"); + + let all = store.all_tagged("test").await.unwrap(); + + for record in all.iter() { + let decrypted = record.clone().decrypt::<PASETO_V4>(&key.into()); + assert!( + decrypted.is_err(), + "did not get error decrypting with old key after re-encrypt" + ) + } + + for record in all { + let decrypted = record.decrypt::<PASETO_V4>(&new_key.into()).unwrap(); + assert_eq!(decrypted.data.0, data); + } + + assert_eq!(store.len(host_id, "test").await.unwrap(), 10); + } +} diff --git a/crates/turtle/src/atuin_client/record/store.rs b/crates/turtle/src/atuin_client/record/store.rs new file mode 100644 index 00000000..f99085d0 --- /dev/null +++ b/crates/turtle/src/atuin_client/record/store.rs @@ -0,0 +1,60 @@ +use async_trait::async_trait; +use eyre::Result; + +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; + +/// A record store stores records +/// In more detail - we tend to need to process this into _another_ format to actually query it. +/// As is, the record store is intended as the source of truth for arbitrary data, which could +/// be shell history, kvs, etc. +#[async_trait] +pub trait Store { + // Push a record + async fn push(&self, record: &Record<EncryptedData>) -> Result<()> { + self.push_batch(std::iter::once(record)).await + } + + // Push a batch of records, all in one transaction + async fn push_batch( + &self, + records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, + ) -> Result<()>; + + async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>; + + async fn delete(&self, id: RecordId) -> Result<()>; + async fn delete_all(&self) -> Result<()>; + + async fn len_all(&self) -> Result<u64>; + async fn len(&self, host: HostId, tag: &str) -> Result<u64>; + async fn len_tag(&self, tag: &str) -> Result<u64>; + + async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + + async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()>; + async fn verify(&self, key: &[u8; 32]) -> Result<()>; + async fn purge(&self, key: &[u8; 32]) -> Result<()>; + + /// Get the next `limit` records, after and including the given index + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result<Vec<Record<EncryptedData>>>; + + /// Get the first record for a given host and tag + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result<Option<Record<EncryptedData>>>; + + async fn status(&self) -> Result<RecordStatus>; + + /// Get all records for a given tag + async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>; +} diff --git a/crates/turtle/src/atuin_client/record/sync.rs b/crates/turtle/src/atuin_client/record/sync.rs new file mode 100644 index 00000000..f831570b --- /dev/null +++ b/crates/turtle/src/atuin_client/record/sync.rs @@ -0,0 +1,664 @@ +// do a sync :O +use std::{cmp::Ordering, fmt::Write}; + +use eyre::Result; +use thiserror::Error; +use tracing::error; + +use super::{encryption::PASETO_V4, store::Store}; +use crate::atuin_client::{api_client::Client, settings::Settings}; + +use crate::atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus}; +use indicatif::{ProgressBar, ProgressState, ProgressStyle}; + +#[derive(Error, Debug)] +pub enum SyncError { + #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] + LocalAheadOtherHost, + + #[error("an issue with the local database occurred: {msg:?}")] + LocalStoreError { msg: String }, + + #[error("something has gone wrong with the sync logic: {msg:?}")] + SyncLogicError { msg: String }, + + #[error("operational error: {msg:?}")] + OperationalError { msg: String }, + + #[error("a request to the sync server failed: {msg:?}")] + RemoteRequestError { msg: String }, + + #[error( + "the encryption key on this machine does not match the data on the server. \ + this usually means a new machine was set up without copying the existing key. \ + to fix: run `atuin key` on a machine that already syncs correctly, then run \ + `atuin store rekey <key>` on this machine with the value from the other machine" + )] + WrongKey, +} + +#[derive(Debug, Eq, PartialEq)] +pub enum Operation { + // Either upload or download until the states matches the below + Upload { + local: RecordIdx, + remote: Option<RecordIdx>, + host: HostId, + tag: String, + }, + Download { + local: Option<RecordIdx>, + remote: RecordIdx, + host: HostId, + tag: String, + }, + Noop { + host: HostId, + tag: String, + }, +} + +pub async fn build_client(settings: &Settings) -> Result<Client<'_>, SyncError> { + Client::new( + &settings.sync_address, + settings + .sync_auth_token() + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?, + settings.network_connect_timeout, + settings.network_timeout, + ) + .map_err(|e| SyncError::OperationalError { msg: e.to_string() }) +} + +pub async fn diff( + client: &Client<'_>, + store: &impl Store, +) -> Result<(Vec<Diff>, RecordStatus), SyncError> { + let local_index = store + .status() + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; + + let remote_index = client + .record_status() + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + let diff = local_index.diff(&remote_index); + + Ok((diff, remote_index)) +} + +// Take a diff, along with a local store, and resolve it into a set of operations. +// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download. +// In theory this could be done as a part of the diffing stage, but it's easier to reason +// about and test this way +pub async fn operations( + diffs: Vec<Diff>, + _store: &impl Store, +) -> Result<Vec<Operation>, SyncError> { + let mut operations = Vec::with_capacity(diffs.len()); + + for diff in diffs { + let op = match (diff.local, diff.remote) { + // We both have it! Could be either. Compare. + (Some(local), Some(remote)) => match local.cmp(&remote) { + Ordering::Equal => Operation::Noop { + host: diff.host, + tag: diff.tag, + }, + Ordering::Greater => Operation::Upload { + local, + remote: Some(remote), + host: diff.host, + tag: diff.tag, + }, + Ordering::Less => Operation::Download { + local: Some(local), + remote, + host: diff.host, + tag: diff.tag, + }, + }, + + // Remote has it, we don't. Gotta be download + (None, Some(remote)) => Operation::Download { + local: None, + remote, + host: diff.host, + tag: diff.tag, + }, + + // We have it, remote doesn't. Gotta be upload. + (Some(local), None) => Operation::Upload { + local, + remote: None, + host: diff.host, + tag: diff.tag, + }, + + // something is pretty fucked. + (None, None) => { + return Err(SyncError::SyncLogicError { + msg: String::from( + "diff has nothing for local or remote - (host, tag) does not exist", + ), + }); + } + }; + + operations.push(op); + } + + // sort them - purely so we have a stable testing order, and can rely on + // same input = same output + // We can sort by ID so long as we continue to use UUIDv7 or something + // with the same properties + + operations.sort_by_key(|op| match op { + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. } => (2, *host, tag.clone()), + }); + + Ok(operations) +} + +async fn sync_upload( + store: &impl Store, + client: &Client<'_>, + host: HostId, + tag: String, + local: RecordIdx, + remote: Option<RecordIdx>, + page_size: u64, +) -> Result<i64, SyncError> { + let remote = remote.unwrap_or(0); + let expected = local - remote; + let mut progress = 0; + + let pb = ProgressBar::new(expected); + pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) + .progress_chars("#>-")); + + println!( + "Uploading {} records to {}/{}", + expected, + host.0.as_simple(), + tag + ); + + loop { + let page = store + .next(host, tag.as_str(), remote + progress, page_size) + .await + .map_err(|e| { + error!("failed to read upload page: {e:?}"); + + SyncError::LocalStoreError { msg: e.to_string() } + })?; + + if page.is_empty() { + break; + } + + client.post_records(&page).await.map_err(|e| { + error!("failed to post records: {e:?}"); + + SyncError::RemoteRequestError { msg: e.to_string() } + })?; + + progress += page.len() as u64; + pb.set_position(progress); + + if progress >= expected { + break; + } + } + + pb.finish_with_message("Uploaded records"); + + Ok(progress as i64) +} + +async fn sync_download( + store: &impl Store, + client: &Client<'_>, + host: HostId, + tag: String, + local: Option<RecordIdx>, + remote: RecordIdx, + page_size: u64, +) -> Result<Vec<RecordId>, SyncError> { + let local = local.unwrap_or(0); + let expected = remote - local; + let mut progress = 0; + let mut ret = Vec::new(); + + println!( + "Downloading {} records from {}/{}", + expected, + host.0.as_simple(), + tag + ); + + let pb = ProgressBar::new(expected); + pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) + .progress_chars("#>-")); + + loop { + let page = client + .next_records(host, tag.clone(), local + progress, page_size) + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + if page.is_empty() { + break; + } + + store + .push_batch(page.iter()) + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; + + ret.extend(page.iter().map(|f| f.id)); + + progress += page.len() as u64; + pb.set_position(progress); + + if progress >= expected { + break; + } + } + + pb.finish_with_message("Downloaded records"); + + Ok(ret) +} + +pub async fn sync_remote( + client: &Client<'_>, + operations: Vec<Operation>, + local_store: &impl Store, + page_size: u64, +) -> Result<(i64, Vec<RecordId>), SyncError> { + let mut uploaded = 0; + let mut downloaded = Vec::new(); + + // this can totally run in parallel, but lets get it working first + for i in operations { + match i { + Operation::Upload { + host, + tag, + local, + remote, + } => { + uploaded += + sync_upload(local_store, client, host, tag, local, remote, page_size).await? + } + + Operation::Download { + host, + tag, + local, + remote, + } => { + let mut d = + sync_download(local_store, client, host, tag, local, remote, page_size).await?; + downloaded.append(&mut d) + } + + Operation::Noop { .. } => continue, + } + } + + Ok((uploaded, downloaded)) +} + +pub async fn check_encryption_key( + client: &Client<'_>, + remote_index: &RecordStatus, + encryption_key: &[u8; 32], +) -> Result<(), SyncError> { + let sample = remote_index + .hosts + .iter() + .flat_map(|(host, tags)| tags.keys().map(move |tag| (*host, tag.clone()))) + .next(); + + let Some((host, tag)) = sample else { + return Ok(()); + }; + + let records = client + .next_records(host, tag, 0, 1) + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + let Some(record) = records.into_iter().next() else { + return Ok(()); + }; + + record + .decrypt::<PASETO_V4>(encryption_key) + .map_err(|_| SyncError::WrongKey)?; + + Ok(()) +} + +pub async fn sync( + settings: &Settings, + store: &impl Store, + encryption_key: &[u8; 32], +) -> Result<(i64, Vec<RecordId>), SyncError> { + let client = build_client(settings).await?; + let (diff, remote_index) = diff(&client, store).await?; + + // Bail before mutating either side if the local key can't read the remote. + check_encryption_key(&client, &remote_index, encryption_key).await?; + + let operations = operations(diff, store).await?; + let (uploaded, downloaded) = sync_remote(&client, operations, store, 100).await?; + + Ok((uploaded, downloaded)) +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::record::{Diff, EncryptedData, HostId, Record}; + use pretty_assertions::assert_eq; + + use crate::atuin_client::{ + record::{ + encryption::PASETO_V4, + sqlite_store::SqliteStore, + store::Store, + sync::{self, Operation}, + }, + settings::test_local_timeout, + }; + + fn test_record() -> Record<EncryptedData> { + Record::builder() + .host(crate::atuin_common::record::Host::new(HostId( + crate::atuin_common::utils::uuid_v7(), + ))) + .version("v1".into()) + .tag(crate::atuin_common::utils::uuid_v7().simple().to_string()) + .data(EncryptedData { + data: String::new(), + content_encryption_key: String::new(), + }) + .idx(0) + .build() + } + + // Take a list of local records, and a list of remote records. + // Return the local database, and a diff of local/remote, ready to build + // ops + async fn build_test_diff( + local_records: Vec<Record<EncryptedData>>, + remote_records: Vec<Record<EncryptedData>>, + ) -> (SqliteStore, Vec<Diff>) { + let local_store = SqliteStore::new(":memory:", test_local_timeout()) + .await + .expect("failed to open in memory sqlite"); + let remote_store = SqliteStore::new(":memory:", test_local_timeout()) + .await + .expect("failed to open in memory sqlite"); // "remote" + + for i in local_records { + local_store.push(&i).await.unwrap(); + } + + for i in remote_records { + remote_store.push(&i).await.unwrap(); + } + + let local_index = local_store.status().await.unwrap(); + let remote_index = remote_store.status().await.unwrap(); + + let diff = local_index.diff(&remote_index); + + (local_store, diff) + } + + #[tokio::test] + async fn test_basic_diff() { + // a diff where local is ahead of remote. nothing else. + + let record = test_record(); + let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await; + + assert_eq!(diff.len(), 1); + + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 1); + + assert_eq!( + operations[0], + Operation::Upload { + host: record.host.id, + tag: record.tag, + local: record.idx, + remote: None, + } + ); + } + + #[tokio::test] + async fn build_two_way_diff() { + // a diff where local is ahead of remote for one, and remote for + // another. One upload, one download + + let shared_record = test_record(); + let remote_ahead = test_record(); + + let local_ahead = shared_record + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + assert_eq!(local_ahead.idx, 1); + + let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store + let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store + + let (store, diff) = build_test_diff(local, remote).await; + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 2); + + assert_eq!( + operations, + vec![ + // Or in otherwords, local is ahead by one + Operation::Upload { + host: local_ahead.host.id, + tag: local_ahead.tag, + local: 1, + remote: Some(0), + }, + // Or in other words, remote knows of a record in an entirely new store (tag) + Operation::Download { + host: remote_ahead.host.id, + tag: remote_ahead.tag, + local: None, + remote: 0, + }, + ] + ); + } + + #[tokio::test] + async fn build_complex_diff() { + // One shared, ahead but known only by remote + // One known only by local + // One known only by remote + + let shared_record = test_record(); + let local_only = test_record(); + + let local_only_20 = test_record(); + let local_only_21 = local_only_20 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let local_only_22 = local_only_21 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let local_only_23 = local_only_22 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let remote_only = test_record(); + + let remote_only_20 = test_record(); + let remote_only_21 = remote_only_20 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_22 = remote_only_21 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_23 = remote_only_22 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_24 = remote_only_23 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let second_shared = test_record(); + let second_shared_remote_ahead = second_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let second_shared_remote_ahead2 = second_shared_remote_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let third_shared = test_record(); + let third_shared_local_ahead = third_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let third_shared_local_ahead2 = third_shared_local_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let fourth_shared = test_record(); + let fourth_shared_remote_ahead = fourth_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let local = vec![ + shared_record.clone(), + second_shared.clone(), + third_shared.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + // single store, only local has it + local_only.clone(), + // bigger store, also only known by local + local_only_20.clone(), + local_only_21.clone(), + local_only_22.clone(), + local_only_23.clone(), + // another shared store, but local is ahead on this one + third_shared_local_ahead.clone(), + third_shared_local_ahead2.clone(), + ]; + + let remote = vec![ + remote_only.clone(), + remote_only_20.clone(), + remote_only_21.clone(), + remote_only_22.clone(), + remote_only_23.clone(), + remote_only_24.clone(), + shared_record.clone(), + second_shared.clone(), + third_shared.clone(), + second_shared_remote_ahead.clone(), + second_shared_remote_ahead2.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + fourth_shared_remote_ahead2.clone(), + ]; // remote knows about the already-synced, and one new record in a new store + + let (store, diff) = build_test_diff(local, remote).await; + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 7); + + let mut result_ops = vec![ + // We started with a shared record, but the remote knows of two newer records in the + // same store + Operation::Download { + local: Some(0), + remote: 2, + host: second_shared_remote_ahead.host.id, + tag: second_shared_remote_ahead.tag, + }, + // We have a shared record, local knows of the first two but not the last + Operation::Download { + local: Some(1), + remote: 2, + host: fourth_shared_remote_ahead2.host.id, + tag: fourth_shared_remote_ahead2.tag, + }, + // Remote knows of a store with a single record that local does not have + Operation::Download { + local: None, + remote: 0, + host: remote_only.host.id, + tag: remote_only.tag, + }, + // Remote knows of a store with a bunch of records that local does not have + Operation::Download { + local: None, + remote: 4, + host: remote_only_20.host.id, + tag: remote_only_20.tag, + }, + // Local knows of a record in a store that remote does not have + Operation::Upload { + local: 0, + remote: None, + host: local_only.host.id, + tag: local_only.tag, + }, + // Local knows of 4 records in a store that remote does not have + Operation::Upload { + local: 3, + remote: None, + host: local_only_20.host.id, + tag: local_only_20.tag, + }, + // Local knows of 2 more records in a shared store that remote only has one of + Operation::Upload { + local: 2, + remote: Some(0), + host: third_shared.host.id, + tag: third_shared.tag, + }, + ]; + + result_ops.sort_by_key(|op| match op { + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. } => (2, *host, tag.clone()), + }); + + assert_eq!(result_ops, operations); + } +} diff --git a/crates/turtle/src/atuin_client/register.rs b/crates/turtle/src/atuin_client/register.rs new file mode 100644 index 00000000..4b14c233 --- /dev/null +++ b/crates/turtle/src/atuin_client/register.rs @@ -0,0 +1,20 @@ +use eyre::Result; + +use crate::atuin_client::{api_client, settings::Settings}; + +pub async fn register_classic( + settings: &Settings, + username: String, + email: String, + password: String, +) -> Result<String> { + let session = + api_client::register(settings.sync_address.as_str(), &username, &email, &password).await?; + + let meta = Settings::meta_store().await?; + meta.save_session(&session.session).await?; + + let _key = crate::atuin_client::encryption::load_key(settings)?; + + Ok(session.session) +} diff --git a/crates/turtle/src/atuin_client/secrets.rs b/crates/turtle/src/atuin_client/secrets.rs new file mode 100644 index 00000000..e8a6ab62 --- /dev/null +++ b/crates/turtle/src/atuin_client/secrets.rs @@ -0,0 +1,194 @@ +// This file will probably trigger a lot of scanners. Sorry. + +use regex::RegexSet; +use std::sync::LazyLock; + +pub enum TestValue<'a> { + Single(&'a str), + Multiple(&'a [&'a str]), +} + +/// A list of `(name, regex, test)`, where `test` should match against `regex`. +pub static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ + ( + "AWS Access Key ID", + "A[KS]IA[0-9A-Z]{16}", + TestValue::Single("AKIAIOSFODNN7EXAMPLE"), + ), + ( + "AWS Secret Access Key env var", + "AWS_SECRET_ACCESS_KEY", + TestValue::Single("AWS_SECRET_ACCESS_KEY=KEYDATA"), + ), + ( + "AWS Session Token env var", + "AWS_SESSION_TOKEN", + TestValue::Single("AWS_SESSION_TOKEN=KEYDATA"), + ), + ( + "Microsoft Azure secret access key env var", + "AZURE_.*_KEY", + TestValue::Single("export AZURE_STORAGE_ACCOUNT_KEY=KEYDATA"), + ), + ( + "Google cloud platform key env var", + "GOOGLE_SERVICE_ACCOUNT_KEY", + TestValue::Single("export GOOGLE_SERVICE_ACCOUNT_KEY=KEYDATA"), + ), + ( + "Atuin login", + r"atuin\s+login", + TestValue::Single( + "atuin login -u mycoolusername -p mycoolpassword -k \"lots of random words\"", + ), + ), + ( + "GitHub PAT (old)", + "ghp_[a-zA-Z0-9]{36}", + TestValue::Single("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"), // legit, I expired it + ), + ( + "GitHub PAT (new)", + "gh1_[A-Za-z0-9]{21}_[A-Za-z0-9]{59}|github_pat_[0-9][A-Za-z0-9]{21}_[A-Za-z0-9]{59}", + TestValue::Multiple(&[ + "gh1_1234567890abcdefghijk_1234567890abcdefghijklmnopqrstuvwxyz1234567890abcdefghijklm", + "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", // also legit, also expired + ]), + ), + ( + "GitHub OAuth Access Token", + "gho_[A-Za-z0-9]{36}", + TestValue::Single("gho_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token + ), + ( + "GitHub OAuth Access Token (user)", + "ghu_[A-Za-z0-9]{36}", + TestValue::Single("ghu_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token + ), + ( + "GitHub App Installation Access Token", + "ghs_[A-Za-z0-9._-]{36,}", + TestValue::Multiple(&[ + "ghs_1234567890abcdefghijklmnopqrstuvwx000", // not a real token + "ghs_abc-def.ghi_jklMNOP0123456789qrstuv-wxyzABCD", // new token format, fake data + ]), + ), + ( + "GitHub Refresh Token", + "ghr_[A-Za-z0-9]{76}", + TestValue::Single( + "ghr_1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx", + ), // not a real token + ), + ( + "GitHub App Installation Access Token v1", + "v1\\.[0-9A-Fa-f]{40}", + TestValue::Single("v1.1234567890abcdef1234567890abcdef12345678"), // not a real token + ), + ( + "GitLab PAT", + "glpat-[a-zA-Z0-9_]{20}", + TestValue::Single("glpat-RkE_BG5p_bbjML21WSfy"), + ), + ( + "Slack OAuth v2 bot", + "xoxb-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", + TestValue::Single("xoxb-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), + ), + ( + "Slack OAuth v2 user token", + "xoxp-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", + TestValue::Single("xoxp-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), + ), + ( + "Slack webhook", + "T[a-zA-Z0-9_]{8}/B[a-zA-Z0-9_]{8}/[a-zA-Z0-9_]{24}", + TestValue::Single( + "https://hooks.slack.com/services/T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX", + ), + ), + ( + "Stripe test key", + "sk_test_[0-9a-zA-Z]{24}", + TestValue::Single("sk_test_1234567890abcdefghijklmnop"), + ), + ( + "Stripe live key", + "sk_live_[0-9a-zA-Z]{24}", + TestValue::Single("sk_live_1234567890abcdefghijklmnop"), + ), + ( + "Netlify authentication token", + "nf[pcoub]_[0-9a-zA-Z]{36}", + TestValue::Single("nfp_nBh7BdJxUwyaBBwFzpyD29MMFT6pZ9wq5634"), + ), + ( + "npm token", + "npm_[A-Za-z0-9]{36}", + TestValue::Single("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"), + ), + ( + "Pulumi personal access token", + "pul-[0-9a-f]{40}", + TestValue::Single("pul-683c2770662c51d960d72ec27613be7653c5cb26"), + ), +]; + +/// The `regex` expressions from [`SECRET_PATTERNS`] compiled into a `RegexSet`. +pub static SECRET_PATTERNS_RE: LazyLock<RegexSet> = LazyLock::new(|| { + let exprs = SECRET_PATTERNS.iter().map(|f| f.1); + RegexSet::new(exprs).expect("Failed to build secrets regex") +}); + +#[cfg(test)] +mod tests { + use regex::Regex; + + use crate::secrets::{SECRET_PATTERNS, TestValue}; + + #[test] + fn test_secrets() { + for (name, regex, test) in SECRET_PATTERNS { + let re = + Regex::new(regex).unwrap_or_else(|_| panic!("Failed to compile regex for {name}")); + + match test { + TestValue::Single(test) => { + assert!(re.is_match(test), "{name} test failed!"); + } + TestValue::Multiple(tests) => { + for test_str in tests.iter() { + assert!( + re.is_match(test_str), + "{name} test with value \"{test_str}\" failed!" + ); + } + } + } + } + } + + #[test] + fn test_secrets_embedded() { + for (name, regex, test) in SECRET_PATTERNS { + let re = + Regex::new(regex).unwrap_or_else(|_| panic!("Failed to compile regex for {name}")); + + match test { + TestValue::Single(test) => { + let embedded = format!("some random text {test} some more random text"); + assert!(re.is_match(&embedded), "{name} embedded test failed!"); + } + TestValue::Multiple(tests) => { + for test_str in tests.iter() { + let embedded = format!("some random text {test_str} some more random text"); + assert!( + re.is_match(&embedded), + "{name} embedded test with value \"{test_str}\" failed!" + ); + } + } + } + } + } +} diff --git a/crates/turtle/src/atuin_client/settings.rs b/crates/turtle/src/atuin_client/settings.rs new file mode 100644 index 00000000..b0ffc7c1 --- /dev/null +++ b/crates/turtle/src/atuin_client/settings.rs @@ -0,0 +1,1851 @@ +use std::{collections::HashMap, fmt, io::prelude::*, path::PathBuf, str::FromStr, sync::OnceLock}; +use tokio::sync::OnceCell; + +use crate::atuin_common::record::HostId; +use crate::atuin_common::utils; +use clap::ValueEnum; +use config::{ + Config, ConfigBuilder, Environment, File as ConfigFile, FileFormat, builder::DefaultState, +}; +use eyre::{Context, Error, Result, bail, eyre}; +use fs_err::{File, create_dir_all}; +use humantime::parse_duration; +use regex::RegexSet; +use serde::{Deserialize, Serialize}; +use serde_with::DeserializeFromStr; +use time::{OffsetDateTime, UtcOffset, format_description::FormatItem, macros::format_description}; + +pub const HISTORY_PAGE_SIZE: i64 = 100; + +static DATA_DIR: OnceLock<PathBuf> = OnceLock::new(); +static META_CONFIG: OnceLock<(String, f64)> = OnceLock::new(); +static META_STORE: OnceCell<crate::atuin_client::meta::MetaStore> = OnceCell::const_new(); + +pub(crate) mod meta; +pub mod watcher; + +/// Default sync address for Atuin's hosted service +pub const DEFAULT_SYNC_ADDRESS: &str = "https://api.atuin.sh"; + +#[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq, Serialize)] +pub enum SearchMode { + #[serde(rename = "prefix")] + Prefix, + + #[serde(rename = "fulltext")] + #[clap(aliases = &["fulltext"])] + FullText, + + #[serde(rename = "fuzzy")] + Fuzzy, + + #[serde(rename = "skim")] + Skim, + + #[serde(rename = "daemon-fuzzy")] + #[clap(aliases = &["daemon-fuzzy"])] + DaemonFuzzy, +} + +impl SearchMode { + pub fn as_str(&self) -> &'static str { + match self { + SearchMode::Prefix => "PREFIX", + SearchMode::FullText => "FULLTXT", + SearchMode::Fuzzy => "FUZZY", + SearchMode::Skim => "SKIM", + SearchMode::DaemonFuzzy => "DAEMON", + } + } + pub fn next(&self, settings: &Settings) -> Self { + match self { + SearchMode::Prefix => SearchMode::FullText, + // if the user is using skim, we go to skim + SearchMode::FullText if settings.search_mode == SearchMode::Skim => SearchMode::Skim, + // if the user is using daemon-fuzzy, we go to daemon-fuzzy + SearchMode::FullText if settings.search_mode == SearchMode::DaemonFuzzy => { + SearchMode::DaemonFuzzy + } + // otherwise fuzzy. + SearchMode::FullText => SearchMode::Fuzzy, + SearchMode::Fuzzy | SearchMode::Skim | SearchMode::DaemonFuzzy => SearchMode::Prefix, + } + } +} + +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum FilterMode { + #[serde(rename = "global")] + Global = 0, + + #[serde(rename = "host")] + Host = 1, + + #[serde(rename = "session")] + Session = 2, + + #[serde(rename = "directory")] + Directory = 3, + + #[serde(rename = "workspace")] + Workspace = 4, + + #[serde(rename = "session-preload")] + SessionPreload = 5, +} + +impl FilterMode { + pub fn as_str(&self) -> &'static str { + match self { + FilterMode::Global => "GLOBAL", + FilterMode::Host => "HOST", + FilterMode::Session => "SESSION", + FilterMode::Directory => "DIRECTORY", + FilterMode::Workspace => "WORKSPACE", + FilterMode::SessionPreload => "SESSION+", + } + } +} + +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum ExitMode { + #[serde(rename = "return-original")] + ReturnOriginal, + + #[serde(rename = "return-query")] + ReturnQuery, +} + +// FIXME: Can use upstream Dialect enum if https://github.com/stevedonovan/chrono-english/pull/16 is merged +// FIXME: Above PR was merged, but dependency was changed to interim (fork of chrono-english) in the ... interim +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum Dialect { + #[serde(rename = "us")] + Us, + + #[serde(rename = "uk")] + Uk, +} + +impl From<Dialect> for interim::Dialect { + fn from(d: Dialect) -> interim::Dialect { + match d { + Dialect::Uk => interim::Dialect::Uk, + Dialect::Us => interim::Dialect::Us, + } + } +} + +/// Type wrapper around `time::UtcOffset` to support a wider variety of timezone formats. +/// +/// Note that the parsing of this struct needs to be done before starting any +/// multithreaded runtime, otherwise it will fail on most Unix systems. +/// +/// See: <https://github.com/atuinsh/atuin/pull/1517#discussion_r1447516426> +#[derive(Clone, Copy, Debug, Eq, PartialEq, DeserializeFromStr, Serialize)] +pub struct Timezone(pub UtcOffset); +impl fmt::Display for Timezone { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} +/// format: <+|-><hour>[:<minute>[:<second>]] +static OFFSET_FMT: &[FormatItem<'_>] = format_description!( + "[offset_hour sign:mandatory padding:none][optional [:[offset_minute padding:none][optional [:[offset_second padding:none]]]]]" +); +impl FromStr for Timezone { + type Err = Error; + + fn from_str(s: &str) -> Result<Self> { + // local timezone + if matches!(s.to_lowercase().as_str(), "l" | "local") { + // There have been some timezone issues, related to errors fetching it on some + // platforms + // Rather than fail to start, fallback to UTC. The user should still be able to specify + // their timezone manually in the config file. + let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + return Ok(Self(offset)); + } + + if matches!(s.to_lowercase().as_str(), "0" | "utc") { + let offset = UtcOffset::UTC; + return Ok(Self(offset)); + } + + // offset from UTC + if let Ok(offset) = UtcOffset::parse(s, OFFSET_FMT) { + return Ok(Self(offset)); + } + + // IDEA: Currently named timezones are not supported, because the well-known crate + // for this is `chrono_tz`, which is not really interoperable with the datetime crate + // that we currently use - `time`. If ever we migrate to using `chrono`, this would + // be a good feature to add. + + bail!(r#""{s}" is not a valid timezone spec"#) + } +} + +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum Style { + #[serde(rename = "auto")] + Auto, + + #[serde(rename = "full")] + Full, + + #[serde(rename = "compact")] + Compact, +} + +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum WordJumpMode { + #[serde(rename = "emacs")] + Emacs, + + #[serde(rename = "subl")] + Subl, +} + +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum KeymapMode { + #[serde(rename = "emacs")] + Emacs, + + #[serde(rename = "vim-normal")] + VimNormal, + + #[serde(rename = "vim-insert")] + VimInsert, + + #[serde(rename = "auto")] + Auto, +} + +impl KeymapMode { + pub fn as_str(&self) -> &'static str { + match self { + KeymapMode::Emacs => "EMACS", + KeymapMode::VimNormal => "VIMNORMAL", + KeymapMode::VimInsert => "VIMINSERT", + KeymapMode::Auto => "AUTO", + } + } +} + +// We want to translate the config to crossterm::cursor::SetCursorStyle, but +// the original type does not implement trait serde::Deserialize unfortunately. +// It seems impossible to implement Deserialize for external types when it is +// used in HashMap (https://stackoverflow.com/questions/67142663). We instead +// define an adapter type. +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum CursorStyle { + #[serde(rename = "default")] + DefaultUserShape, + + #[serde(rename = "blink-block")] + BlinkingBlock, + + #[serde(rename = "steady-block")] + SteadyBlock, + + #[serde(rename = "blink-underline")] + BlinkingUnderScore, + + #[serde(rename = "steady-underline")] + SteadyUnderScore, + + #[serde(rename = "blink-bar")] + BlinkingBar, + + #[serde(rename = "steady-bar")] + SteadyBar, +} + +impl CursorStyle { + pub fn as_str(&self) -> &'static str { + match self { + CursorStyle::DefaultUserShape => "DEFAULT", + CursorStyle::BlinkingBlock => "BLINKBLOCK", + CursorStyle::SteadyBlock => "STEADYBLOCK", + CursorStyle::BlinkingUnderScore => "BLINKUNDERLINE", + CursorStyle::SteadyUnderScore => "STEADYUNDERLINE", + CursorStyle::BlinkingBar => "BLINKBAR", + CursorStyle::SteadyBar => "STEADYBAR", + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Stats { + #[serde(default = "Stats::common_prefix_default")] + pub common_prefix: Vec<String>, // sudo, etc. commands we want to strip off + #[serde(default = "Stats::common_subcommands_default")] + pub common_subcommands: Vec<String>, // kubectl, commands we should consider subcommands for + #[serde(default = "Stats::ignored_commands_default")] + pub ignored_commands: Vec<String>, // cd, ls, etc. commands we want to completely hide from stats +} + +impl Stats { + fn common_prefix_default() -> Vec<String> { + vec!["sudo", "doas"].into_iter().map(String::from).collect() + } + + fn common_subcommands_default() -> Vec<String> { + vec![ + "apt", + "cargo", + "composer", + "dnf", + "docker", + "dotnet", + "git", + "go", + "ip", + "jj", + "kubectl", + "nix", + "nmcli", + "npm", + "pecl", + "pnpm", + "podman", + "port", + "systemctl", + "tmux", + "yarn", + ] + .into_iter() + .map(String::from) + .collect() + } + + fn ignored_commands_default() -> Vec<String> { + vec![] + } +} + +impl Default for Stats { + fn default() -> Self { + Self { + common_prefix: Self::common_prefix_default(), + common_subcommands: Self::common_subcommands_default(), + ignored_commands: Self::ignored_commands_default(), + } + } +} + +/// Sync protocol type for authentication. +/// +/// This setting is primarily for development/testing. When not explicitly set, +/// the protocol is inferred from the sync_address: +/// - Default sync address (api.atuin.sh) → Hub protocol +/// - Custom sync address → Legacy protocol +/// +/// Set explicitly to "hub" to use Hub authentication with a custom sync_address +/// (useful for local development against a Hub instance). +#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum SyncProtocol { + /// Use legacy CLI authentication (Token from CLI register/login) + #[default] + Legacy, +} + +/// Resolved authentication state for sync operations. +/// +/// Determined at runtime by examining which tokens are available and what +/// server the client is configured to talk to. Operations use this to pick +/// the right auth header and endpoint style. +#[cfg(feature = "sync")] +#[derive(Debug, Clone)] +pub enum SyncAuth { + /// Self-hosted Rust server. Uses `Authorization: Token <session>` and + /// legacy endpoints. + Legacy { token: String }, + + /// Not authenticated at all. Contains an actionable user-facing message. + NotLoggedIn { reason: String }, +} + +#[cfg(feature = "sync")] +impl SyncAuth { + /// Convert into the auth token type used by the API client. + /// + /// Returns an error with an actionable message for `NotLoggedIn`. + pub fn into_auth_token(self) -> Result<crate::atuin_client::api_client::AuthToken> { + use crate::atuin_client::api_client::AuthToken; + match self { + SyncAuth::Legacy { token } => Ok(AuthToken::Token(token)), + SyncAuth::NotLoggedIn { reason } => Err(eyre!(reason)), + } + } +} + +#[derive(Clone, Debug, Deserialize, Default, Serialize)] +pub struct Keys { + pub scroll_exits: bool, + pub exit_past_line_start: bool, + pub accept_past_line_end: bool, + pub accept_past_line_start: bool, + pub accept_with_backspace: bool, + pub prefix: String, +} + +impl Keys { + /// The standard default values for all `[keys]` options. + /// These match the config defaults set in `builder_with_data_dir()`. + pub fn standard_defaults() -> Self { + Keys { + scroll_exits: true, + exit_past_line_start: true, + accept_past_line_end: true, + accept_past_line_start: false, + accept_with_backspace: false, + prefix: "a".to_string(), + } + } + + /// Returns true if any value differs from the standard defaults. + pub fn has_non_default_values(&self) -> bool { + let d = Self::standard_defaults(); + self.scroll_exits != d.scroll_exits + || self.exit_past_line_start != d.exit_past_line_start + || self.accept_past_line_end != d.accept_past_line_end + || self.accept_past_line_start != d.accept_past_line_start + || self.accept_with_backspace != d.accept_with_backspace + || self.prefix != d.prefix + } +} + +/// A single rule within a conditional keybinding config. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct KeyRuleConfig { + /// Optional condition expression (e.g. "cursor-at-start", "input-empty && no-results"). + /// If absent, the rule always matches. + #[serde(default)] + pub when: Option<String>, + /// The action to perform (e.g. "exit", "cursor-left", "accept"). + pub action: String, +} + +/// A keybinding config value: either a simple action string or an ordered list of conditional rules. +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(untagged)] +pub enum KeyBindingConfig { + /// Simple unconditional binding: `"ctrl-c" = "return-original"` + Simple(String), + /// Conditional binding: `"left" = [{ when = "cursor-at-start", action = "exit" }, { action = "cursor-left" }]` + Rules(Vec<KeyRuleConfig>), +} + +/// User-facing keymap configuration. Each mode maps key strings to bindings. +/// Keys present here override the defaults for that key; unmentioned keys keep defaults. +#[derive(Clone, Debug, Deserialize, Serialize, Default)] +pub struct KeymapConfig { + #[serde(default)] + pub emacs: HashMap<String, KeyBindingConfig>, + #[serde(default, rename = "vim-normal")] + pub vim_normal: HashMap<String, KeyBindingConfig>, + #[serde(default, rename = "vim-insert")] + pub vim_insert: HashMap<String, KeyBindingConfig>, + #[serde(default)] + pub inspector: HashMap<String, KeyBindingConfig>, + #[serde(default)] + pub prefix: HashMap<String, KeyBindingConfig>, +} + +impl KeymapConfig { + /// Returns true if no keybinding overrides are configured in any mode. + pub fn is_empty(&self) -> bool { + self.emacs.is_empty() + && self.vim_normal.is_empty() + && self.vim_insert.is_empty() + && self.inspector.is_empty() + && self.prefix.is_empty() + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Preview { + pub strategy: PreviewStrategy, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Theme { + /// Name of desired theme ("default" for base) + pub name: String, + + /// Whether any available additional theme debug should be shown + pub debug: Option<bool>, + + /// How many levels of parenthood will be traversed if needed + pub max_depth: Option<u8>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Daemon { + /// Use the daemon to sync + /// If enabled, history hooks are routed through the daemon. + #[serde(alias = "enable")] + pub enabled: bool, + + /// Automatically start and manage a local daemon when needed. + pub autostart: bool, + + /// The daemon will handle sync on an interval. How often to sync, in seconds. + pub sync_frequency: u64, + + /// The path to the unix socket used by the daemon + pub socket_path: String, + + /// Path to the daemon pidfile used for process coordination. + pub pidfile_path: String, + + /// Use a socket passed via systemd's socket activation protocol, instead of the path + pub systemd_socket: bool, + + /// The port that should be used for TCP on non unix systems + pub tcp_port: u64, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Search { + /// The list of enabled filter modes, in order of priority. + pub filters: Vec<FilterMode>, + + /// The recency score multiplier for the search index (default: 1.0). + /// Values < 1.0 reduce weight, > 1.0 increase weight, 0.0 disables. + pub recency_score_multiplier: f64, + + /// The frequency score multiplier for the search index (default: 1.0). + /// Values < 1.0 reduce weight, > 1.0 increase weight, 0.0 disables. + pub frequency_score_multiplier: f64, + + /// The overall frecency score multiplier for the search index (default: 1.0). + /// Applied after combining recency and frequency scores. + pub frecency_score_multiplier: f64, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Tmux { + /// Enable using atuin with tmux popup (tmux >= 3.2) + pub enabled: bool, + + /// Width of the tmux popup (percentage) + pub width: String, + + /// Height of the tmux popup (percentage) + pub height: String, +} + +/// Log level for file logging. Maps to tracing's LevelFilter. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + Trace, + Debug, + #[default] + Info, + Warn, + Error, +} + +impl LogLevel { + /// Convert to a tracing directive string for use with EnvFilter. + pub fn as_directive(&self) -> &'static str { + match self { + LogLevel::Trace => "trace", + LogLevel::Debug => "debug", + LogLevel::Info => "info", + LogLevel::Warn => "warn", + LogLevel::Error => "error", + } + } +} + +/// Configuration for a specific log type (search or daemon). +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct LogConfig { + /// Log file name (relative to dir) or absolute path. + pub file: String, + + /// Override global enabled setting for this log type. + pub enabled: Option<bool>, + + /// Override global level setting for this log type. + pub level: Option<LogLevel>, + + /// Override global retention days setting for this log type. + pub retention: Option<u64>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Logs { + /// Enable file logging globally. Defaults to true. + #[serde(default = "Logs::default_enabled")] + pub enabled: bool, + + /// Directory for log files. Defaults to ~/.atuin/logs + pub dir: String, + + /// Default log level for file logging. Defaults to "info". + /// Note: ATUIN_LOG environment variable overrides this. + #[serde(default)] + pub level: LogLevel, + + /// Default retention days for log files. Defaults to 4. + #[serde(default = "Logs::default_retention")] + pub retention: u64, + + /// Search log settings + #[serde(default)] + pub search: LogConfig, + + /// Daemon log settings + #[serde(default)] + pub daemon: LogConfig, + + /// AI log settings + #[serde(default)] + pub ai: LogConfig, +} + +#[derive(Default, Clone, Debug, Deserialize, Serialize)] +pub struct Ai { + /// Whether or not the AI features are enabled. + pub enabled: Option<bool>, + + /// The address of the Atuin AI endpoint. Used for AI features like command generation. + /// Only necessary for custom AI endpoints. + pub endpoint: Option<String>, + + /// The API token for the Atuin AI endpoint. Used for AI features like command generation. + /// Only necessary for custom AI endpoints. + pub api_token: Option<String>, + + /// Path to the AI sessions database. + pub db_path: String, + + /// The maximum time in minutes that an AI session can be automatically resumed. + pub session_continue_minutes: i64, + + /// Deprecated: use opening.send_cwd instead. Kept for backwards compatibility. + #[serde(default)] + pub send_cwd: Option<bool>, + + /// Configuration for what context is sent in the opening AI request. + #[serde(default)] + pub opening: AiOpening, + + /// Tool capability flags. + #[serde(default)] + pub capabilities: AiCapabilities, +} + +#[derive(Default, Clone, Debug, Deserialize, Serialize)] +pub struct AiCapabilities { + /// Whether the AI can request to search Atuin history. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_history_search: Option<bool>, + /// Whether the AI can request to view the stored output, if any, for Atuin history entries. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_history_output: Option<bool>, + /// Whether the AI can request to read and write files. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_file_tools: Option<bool>, + /// Whether the AI can request to execute bash commands. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_command_execution: Option<bool>, +} + +#[derive(Default, Clone, Debug, Deserialize, Serialize)] +pub struct AiOpening { + /// Whether or not to send the current working directory to the AI endpoint. + pub send_cwd: Option<bool>, + + /// Whether or not to send the last command as context in the opening AI request. + pub send_last_command: Option<bool>, +} + +impl Default for Preview { + fn default() -> Self { + Self { + strategy: PreviewStrategy::Auto, + } + } +} + +impl Default for Theme { + fn default() -> Self { + Self { + name: "".to_string(), + debug: None::<bool>, + max_depth: Some(10), + } + } +} + +impl Default for Daemon { + fn default() -> Self { + Self { + enabled: false, + autostart: false, + sync_frequency: 300, + socket_path: "".to_string(), + pidfile_path: "".to_string(), + systemd_socket: false, + tcp_port: 8889, + } + } +} + +impl Default for Logs { + fn default() -> Self { + Self { + enabled: true, + dir: "".to_string(), + level: LogLevel::default(), + retention: Self::default_retention(), + search: LogConfig { + file: "search.log".to_string(), + ..Default::default() + }, + daemon: LogConfig { + file: "daemon.log".to_string(), + ..Default::default() + }, + ai: LogConfig { + file: "ai.log".to_string(), + ..Default::default() + }, + } + } +} + +impl Logs { + fn default_enabled() -> bool { + true + } + + fn default_retention() -> u64 { + 4 + } + + /// Returns whether search logging is enabled. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_enabled(&self) -> bool { + self.search.enabled.unwrap_or(self.enabled) + } + + /// Returns whether daemon logging is enabled. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_enabled(&self) -> bool { + self.daemon.enabled.unwrap_or(self.enabled) + } + + /// Returns whether AI logging is enabled. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_enabled(&self) -> bool { + self.ai.enabled.unwrap_or(self.enabled) + } + + /// Returns the log level for search logging. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_level(&self) -> LogLevel { + self.search.level.unwrap_or(self.level) + } + + /// Returns the log level for daemon logging. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_level(&self) -> LogLevel { + self.daemon.level.unwrap_or(self.level) + } + + /// Returns the log level for AI logging. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_level(&self) -> LogLevel { + self.ai.level.unwrap_or(self.level) + } + + /// Returns the retention days for search logging. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_retention(&self) -> u64 { + self.search.retention.unwrap_or(self.retention) + } + + /// Returns the retention days for daemon logging. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_retention(&self) -> u64 { + self.daemon.retention.unwrap_or(self.retention) + } + + /// Returns the retention days for AI logging. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_retention(&self) -> u64 { + self.ai.retention.unwrap_or(self.retention) + } + + /// Returns the full path for the search log file. + pub fn search_path(&self) -> PathBuf { + let path = PathBuf::from(&self.search.file); + PathBuf::from(&self.dir).join(path) + } + + /// Returns the full path for the daemon log file. + pub fn daemon_path(&self) -> PathBuf { + let path = PathBuf::from(&self.daemon.file); + PathBuf::from(&self.dir).join(path) + } + + /// Returns the full path for the AI log file. + pub fn ai_path(&self) -> PathBuf { + let path = PathBuf::from(&self.ai.file); + PathBuf::from(&self.dir).join(path) + } +} + +impl Default for Search { + fn default() -> Self { + Self { + filters: vec![ + FilterMode::Global, + FilterMode::Host, + FilterMode::Session, + FilterMode::SessionPreload, + FilterMode::Workspace, + FilterMode::Directory, + ], + + recency_score_multiplier: 1.0, + frequency_score_multiplier: 1.0, + frecency_score_multiplier: 1.0, + } + } +} + +impl Default for Tmux { + fn default() -> Self { + Self { + enabled: false, + width: "80%".to_string(), + height: "60%".to_string(), + } + } +} + +// The preview height strategy also takes max_preview_height into account. +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum PreviewStrategy { + // Preview height is calculated for the length of the selected command. + #[serde(rename = "auto")] + Auto, + + // Preview height is calculated for the length of the longest command stored in the history. + #[serde(rename = "static")] + Static, + + // max_preview_height is used as fixed height. + #[serde(rename = "fixed")] + Fixed, +} + +/// Column types available for the interactive search UI. +#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum UiColumnType { + /// Command execution duration (e.g., "123ms") + Duration, + /// Relative time since execution (e.g., "59s ago") + Time, + /// Absolute timestamp (e.g., "2025-01-22 14:35") + Datetime, + /// Working directory + Directory, + /// Hostname + Host, + /// Username + User, + /// Exit code + Exit, + /// The command itself (should be last, expands to fill) + Command, +} + +impl UiColumnType { + /// Returns the default width for this column type (in characters). + /// The Command column returns 0 as it expands to fill remaining space. + pub fn default_width(&self) -> u16 { + match self { + UiColumnType::Duration => 5, // "814ms" + UiColumnType::Time => 9, // "459ms ago" + UiColumnType::Datetime => 16, // "2025-01-22 14:35" + UiColumnType::Directory => 20, + UiColumnType::Host => 15, + UiColumnType::User => 10, + UiColumnType::Exit => { + if cfg!(windows) { + 11 // 32-bit integer on Windows: "-1978335212" + } else { + 3 // Usually a byte on Unix + } + } + UiColumnType::Command => 0, // Expands to fill + } + } +} + +/// A column configuration with type and optional custom width. +/// Can be specified as just a string (uses default width) or as an object with type and width. +#[derive(Clone, Debug, Serialize)] +pub struct UiColumn { + pub column_type: UiColumnType, + pub width: u16, + /// If true, this column expands to fill remaining space. Only one column should expand. + pub expand: bool, +} + +impl UiColumn { + pub fn new(column_type: UiColumnType) -> Self { + Self { + width: column_type.default_width(), + expand: column_type == UiColumnType::Command, + column_type, + } + } + + pub fn with_width(column_type: UiColumnType, width: u16) -> Self { + Self { + column_type, + width, + expand: column_type == UiColumnType::Command, + } + } +} + +// Custom deserialize to handle both string and object formats: +// "duration" or { type = "duration", width = 8, expand = true } +impl<'de> serde::Deserialize<'de> for UiColumn { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + use serde::de::{self, MapAccess, Visitor}; + + struct UiColumnVisitor; + + impl<'de> Visitor<'de> for UiColumnVisitor { + type Value = UiColumn; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str( + "a column type string or an object with 'type' and optional 'width'/'expand'", + ) + } + + fn visit_str<E>(self, value: &str) -> Result<UiColumn, E> + where + E: de::Error, + { + let column_type: UiColumnType = + serde::Deserialize::deserialize(serde::de::value::StrDeserializer::new(value))?; + Ok(UiColumn::new(column_type)) + } + + fn visit_map<M>(self, mut map: M) -> Result<UiColumn, M::Error> + where + M: MapAccess<'de>, + { + let mut column_type: Option<UiColumnType> = None; + let mut width: Option<u16> = None; + let mut expand: Option<bool> = None; + + while let Some(key) = map.next_key::<String>()? { + match key.as_str() { + "type" => { + column_type = Some(map.next_value()?); + } + "width" => { + width = Some(map.next_value()?); + } + "expand" => { + expand = Some(map.next_value()?); + } + _ => { + let _: serde::de::IgnoredAny = map.next_value()?; + } + } + } + + let column_type = column_type.ok_or_else(|| de::Error::missing_field("type"))?; + let width = width.unwrap_or_else(|| column_type.default_width()); + let expand = expand.unwrap_or(column_type == UiColumnType::Command); + Ok(UiColumn { + column_type, + width, + expand, + }) + } + } + + deserializer.deserialize_any(UiColumnVisitor) + } +} + +/// UI-specific settings for the interactive search. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Ui { + /// Columns to display in interactive search, from left to right. + /// The indicator column (" > ") is always shown first implicitly. + /// The "command" column should be last as it expands to fill remaining space. + /// Can be simple strings or objects with type and width. + #[serde(default = "Ui::default_columns")] + pub columns: Vec<UiColumn>, +} + +impl Ui { + fn default_columns() -> Vec<UiColumn> { + vec![ + UiColumn::new(UiColumnType::Duration), + UiColumn::new(UiColumnType::Time), + UiColumn::new(UiColumnType::Command), + ] + } + + /// Validate the UI configuration. + /// Returns an error if more than one column has expand = true. + pub fn validate(&self) -> Result<()> { + let expand_count = self.columns.iter().filter(|c| c.expand).count(); + if expand_count > 1 { + bail!( + "Only one column can have expand = true, but {} columns are set to expand", + expand_count + ); + } + Ok(()) + } +} + +impl Default for Ui { + fn default() -> Self { + Self { + columns: Self::default_columns(), + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Settings { + pub data_dir: Option<String>, + pub dialect: Dialect, + pub timezone: Timezone, + pub style: Style, + pub auto_sync: bool, + + /// The sync address for atuin. + pub sync_address: String, + + #[serde(default)] + pub sync_protocol: SyncProtocol, + + pub sync_frequency: String, + pub db_path: String, + pub record_store_path: String, + pub key_path: String, + pub search_mode: SearchMode, + pub filter_mode: Option<FilterMode>, + pub filter_mode_shell_up_key_binding: Option<FilterMode>, + pub search_mode_shell_up_key_binding: Option<SearchMode>, + pub shell_up_key_binding: bool, + pub inline_height: u16, + pub inline_height_shell_up_key_binding: Option<u16>, + pub invert: bool, + pub show_preview: bool, + pub max_preview_height: u16, + pub show_help: bool, + pub show_tabs: bool, + pub show_numeric_shortcuts: bool, + pub auto_hide_height: u16, + pub exit_mode: ExitMode, + pub keymap_mode: KeymapMode, + pub keymap_mode_shell: KeymapMode, + pub keymap_cursor: HashMap<String, CursorStyle>, + pub word_jump_mode: WordJumpMode, + pub word_chars: String, + pub scroll_context_lines: usize, + pub history_format: String, + pub strip_trailing_whitespace: bool, + pub prefers_reduced_motion: bool, + pub store_failed: bool, + pub no_mouse: bool, + + #[serde(with = "serde_regex", default = "RegexSet::empty", skip_serializing)] + pub history_filter: RegexSet, + + #[serde(with = "serde_regex", default = "RegexSet::empty", skip_serializing)] + pub cwd_filter: RegexSet, + + pub secrets_filter: bool, + pub workspaces: bool, + pub ctrl_n_shortcuts: bool, + + pub network_connect_timeout: u64, + pub network_timeout: u64, + pub local_timeout: f64, + pub enter_accept: bool, + pub smart_sort: bool, + pub command_chaining: bool, + + #[serde(default)] + pub stats: Stats, + + #[serde(default)] + pub keys: Keys, + + #[serde(default)] + pub keymap: KeymapConfig, + + #[serde(default)] + pub preview: Preview, + + #[serde(default)] + pub daemon: Daemon, + + #[serde(default)] + pub search: Search, + + #[serde(default)] + pub theme: Theme, + + #[serde(default)] + pub ui: Ui, + + #[serde(default)] + pub tmux: Tmux, + + #[serde(default)] + pub logs: Logs, + + #[serde(default)] + pub meta: meta::Settings, +} + +impl Settings { + pub fn utc() -> Self { + Self::builder() + .expect("Could not build default") + .set_override("timezone", "0") + .expect("failed to override timezone with UTC") + .build() + .expect("Could not build config") + .try_deserialize() + .expect("Could not deserialize config") + } + + pub(crate) fn effective_data_dir() -> PathBuf { + DATA_DIR + .get() + .cloned() + .unwrap_or_else(crate::atuin_common::utils::data_dir) + } + + // -- Meta store: lazily initialized on first access -- + + pub async fn meta_store() -> Result<&'static crate::atuin_client::meta::MetaStore> { + META_STORE + .get_or_try_init(|| async { + let (db_path, timeout) = META_CONFIG.get().ok_or_else(|| { + eyre!("meta store config not set — Settings::new() has not been called") + })?; + crate::atuin_client::meta::MetaStore::new(db_path, *timeout).await + }) + .await + } + + pub async fn host_id() -> Result<HostId> { + Self::meta_store().await?.host_id().await + } + + pub async fn last_sync() -> Result<OffsetDateTime> { + Self::meta_store().await?.last_sync().await + } + + pub async fn save_sync_time() -> Result<()> { + Self::meta_store().await?.save_sync_time().await + } + + pub async fn last_version_check() -> Result<OffsetDateTime> { + Self::meta_store().await?.last_version_check().await + } + + pub async fn save_version_check_time() -> Result<()> { + Self::meta_store().await?.save_version_check_time().await + } + + pub async fn should_sync(&self) -> Result<bool> { + if !self.auto_sync || !Self::meta_store().await?.logged_in().await? { + return Ok(false); + } + + if self.sync_frequency == "0" { + return Ok(true); + } + + match parse_duration(self.sync_frequency.as_str()) { + Ok(d) => { + let d = time::Duration::try_from(d)?; + Ok(OffsetDateTime::now_utc() - Settings::last_sync().await? >= d) + } + Err(e) => Err(eyre!("failed to check sync: {}", e)), + } + } + + pub async fn logged_in(&self) -> Result<bool> { + Self::meta_store().await?.logged_in().await + } + + pub async fn session_token(&self) -> Result<String> { + match Self::meta_store().await?.session_token().await? { + Some(token) => Ok(token), + None => Err(eyre!("Tried to load session; not logged in")), + } + } + + /// Examines the configured sync target and available tokens to determine + /// the correct auth strategy. Also performs cleanup of mis-stored tokens + /// (e.g. a CLI token incorrectly saved in the Hub session slot). + #[cfg(feature = "sync")] + pub async fn resolve_sync_auth(&self) -> SyncAuth { + let meta = match Self::meta_store().await { + Ok(m) => m, + Err(e) => { + return SyncAuth::NotLoggedIn { + reason: format!("Failed to open meta store: {e}"), + }; + } + }; + + // Self-hosted / legacy server + match meta.session_token().await { + Ok(Some(token)) => SyncAuth::Legacy { token }, + _ => SyncAuth::NotLoggedIn { + reason: "Not logged in. Run 'atuin login' to authenticate \ + with your sync server." + .into(), + }, + } + } + + /// Returns the appropriate auth token for sync operations. + /// + /// Delegates to [`resolve_sync_auth`] and converts the result to an + /// `AuthToken`. Callers that need to distinguish between auth states + /// (e.g. to show different UI) should call `resolve_sync_auth` directly. + #[cfg(feature = "sync")] + pub async fn sync_auth_token(&self) -> Result<crate::atuin_client::api_client::AuthToken> { + self.resolve_sync_auth().await.into_auth_token() + } + + pub fn default_filter_mode(&self, git_root: bool) -> FilterMode { + self.filter_mode + .filter(|x| self.search.filters.contains(x)) + .or_else(|| { + self.search + .filters + .iter() + .find(|x| match (x, git_root, self.workspaces) { + (FilterMode::Workspace, true, true) => true, + (FilterMode::Workspace, _, _) => false, + (_, _, _) => true, + }) + .copied() + }) + .unwrap_or(FilterMode::Global) + } + + pub fn builder() -> Result<ConfigBuilder<DefaultState>> { + Self::builder_with_data_dir(&crate::atuin_common::utils::data_dir()) + } + + fn builder_with_data_dir(data_dir: &std::path::Path) -> Result<ConfigBuilder<DefaultState>> { + let db_path = data_dir.join("history.db"); + let record_store_path = data_dir.join("records.db"); + let kv_path = data_dir.join("kv.db"); + let scripts_path = data_dir.join("scripts.db"); + let ai_sessions_path = data_dir.join("ai_sessions.db"); + let socket_path = crate::atuin_common::utils::runtime_dir().join("atuin.sock"); + let pidfile_path = data_dir.join("atuin-daemon.pid"); + let logs_dir = crate::atuin_common::utils::logs_dir(); + + let key_path = data_dir.join("key"); + let meta_path = data_dir.join("meta.db"); + + Ok(Config::builder() + .set_default("history_format", "{time}\t{command}\t{duration}")? + .set_default("db_path", db_path.to_str())? + .set_default("record_store_path", record_store_path.to_str())? + .set_default("key_path", key_path.to_str())? + .set_default("dialect", "us")? + .set_default("timezone", "local")? + .set_default("auto_sync", true)? + .set_default("sync_address", "https://api.atuin.sh")? + .set_default("sync_frequency", "5m")? + .set_default("search_mode", "fuzzy")? + .set_default("filter_mode", None::<String>)? + .set_default("style", "compact")? + .set_default("inline_height", 40)? + .set_default("show_preview", true)? + .set_default("preview.strategy", "auto")? + .set_default("max_preview_height", 4)? + .set_default("show_help", true)? + .set_default("show_tabs", true)? + .set_default("show_numeric_shortcuts", true)? + .set_default("auto_hide_height", 8)? + .set_default("invert", false)? + .set_default("exit_mode", "return-original")? + .set_default("word_jump_mode", "emacs")? + .set_default( + "word_chars", + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", + )? + .set_default("scroll_context_lines", 1)? + .set_default("shell_up_key_binding", false)? + .set_default("workspaces", false)? + .set_default("ctrl_n_shortcuts", false)? + .set_default("secrets_filter", true)? + .set_default("strip_trailing_whitespace", true)? + .set_default("network_connect_timeout", 5)? + .set_default("network_timeout", 30)? + .set_default("local_timeout", 2.0)? + // enter_accept defaults to false here, but true in the default config file. The dissonance is + // intentional! + // Existing users will get the default "False", so we don't mess with any potential + // muscle memory. + // New users will get the new default, that is more similar to what they are used to. + .set_default("enter_accept", false)? + .set_default("keys.scroll_exits", true)? + .set_default("keys.accept_past_line_end", true)? + .set_default("keys.exit_past_line_start", true)? + .set_default("keys.accept_past_line_start", false)? + .set_default("keys.accept_with_backspace", false)? + .set_default("keys.prefix", "a")? + .set_default("keymap_mode", "emacs")? + .set_default("keymap_mode_shell", "auto")? + .set_default("keymap_cursor", HashMap::<String, String>::new())? + .set_default("smart_sort", false)? + .set_default("command_chaining", false)? + .set_default("store_failed", true)? + .set_default("daemon.sync_frequency", 300)? + .set_default("daemon.enabled", false)? + .set_default("daemon.autostart", false)? + .set_default("daemon.socket_path", socket_path.to_str())? + .set_default("daemon.pidfile_path", pidfile_path.to_str())? + .set_default("daemon.systemd_socket", false)? + .set_default("daemon.tcp_port", 8889)? + .set_default("logs.enabled", true)? + .set_default("logs.dir", logs_dir.to_str())? + .set_default("logs.level", "info")? + .set_default("logs.search.file", "search.log")? + .set_default("logs.daemon.file", "daemon.log")? + .set_default("logs.ai.file", "ai.log")? + .set_default("kv.db_path", kv_path.to_str())? + .set_default("scripts.db_path", scripts_path.to_str())? + .set_default("search.recency_score_multiplier", 1.0)? + .set_default("search.frequency_score_multiplier", 1.0)? + .set_default("search.frecency_score_multiplier", 1.0)? + .set_default("meta.db_path", meta_path.to_str())? + .set_default("ai.db_path", ai_sessions_path.to_str())? + .set_default("ai.session_continue_minutes", 60)? + .set_default("ai.send_cwd", false)? + .set_default("ai.opening.send_cwd", false)? + .set_default("ai.opening.send_last_command", false)? + .set_default( + "search.filters", + vec![ + "global", + "host", + "session", + "workspace", + "directory", + "session-preload", + ], + )? + .set_default("theme.name", "default")? + .set_default("theme.debug", None::<bool>)? + .set_default("tmux.enabled", false)? + .set_default("tmux.width", "80%")? + .set_default("tmux.height", "60%")? + .set_default( + "prefers_reduced_motion", + std::env::var("NO_MOTION") + .ok() + .map(|_| config::Value::new(None, config::ValueKind::Boolean(true))) + .unwrap_or_else(|| config::Value::new(None, config::ValueKind::Boolean(false))), + )? + .set_default("no_mouse", false)? + .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + )) + } + + pub fn get_config_path() -> Result<PathBuf> { + let config_dir = crate::atuin_common::utils::config_dir(); + + create_dir_all(&config_dir) + .wrap_err_with(|| format!("could not create dir {config_dir:?}"))?; + + let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut config_file = PathBuf::new(); + config_file.push(config_dir); + config_file + }; + + config_file.push("config.toml"); + + Ok(config_file) + } + + /// Build a merged `Config` from defaults, config file, and environment. + /// + /// This resolves `data_dir`, initializes the data directory on disk, + /// and layers defaults → config file → env overrides. Both `new()` and + /// `get_config_value()` use this so the resolution logic lives in one place. + fn build_config() -> Result<Config> { + let config_file = Self::get_config_path()?; + + // extract data_dir first so we can use it as the base for other path defaults + let effective_data_dir = if config_file.exists() { + #[derive(Deserialize, Default)] + struct DataDirOnly { + data_dir: Option<String>, + } + + let config_file_str = config_file + .to_str() + .ok_or_else(|| eyre!("config file path is not valid UTF-8"))?; + + let partial_config = Config::builder() + .add_source(ConfigFile::new(config_file_str, FileFormat::Toml)) + .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + ) + .build() + .ok(); + + let custom_data_dir = partial_config + .and_then(|c| c.try_deserialize::<DataDirOnly>().ok()) + .and_then(|d| d.data_dir); + + match custom_data_dir { + Some(dir) => { + let expanded = shellexpand::full(&dir) + .map_err(|e| eyre!("failed to expand data_dir path: {}", e))?; + PathBuf::from(expanded.as_ref()) + } + None => crate::atuin_common::utils::data_dir(), + } + } else { + crate::atuin_common::utils::data_dir() + }; + + DATA_DIR.set(effective_data_dir.clone()).ok(); + + create_dir_all(&effective_data_dir) + .wrap_err_with(|| format!("could not create dir {effective_data_dir:?}"))?; + + let mut config_builder = Self::builder_with_data_dir(&effective_data_dir)?; + + config_builder = if config_file.exists() { + let config_file_str = config_file + .to_str() + .ok_or_else(|| eyre!("config file path is not valid UTF-8"))?; + config_builder.add_source(ConfigFile::new(config_file_str, FileFormat::Toml)) + } else { + let mut file = File::create(config_file).wrap_err("could not create config file")?; + + let config = config_builder.build_cloned()?; + // TODO(@bpeetz): Not so sure about this <2026-06-10> + file.write_all(config.cache.to_string().as_bytes()) + .wrap_err("could not write default config file")?; + + config_builder + }; + + // all paths should be expanded + let built = config_builder.build_cloned()?; + config_builder = [ + "db_path", + "record_store_path", + "key_path", + "daemon.socket_path", + "daemon.pidfile_path", + "logs.dir", + "logs.search.file", + "logs.daemon.file", + ] + .iter() + .map(|key| (key, built.get_string(key).unwrap_or_default())) + .filter_map(|(key, value)| match Self::expand_path(value) { + Ok(expanded) => Some((key, expanded)), + Err(e) => { + log::warn!("failed to expand path for {key}: {e}"); + None + } + }) + .fold(config_builder, |builder, (key, value)| { + builder + .set_override(key, value) + .unwrap_or_else(|_| panic!("failed to set absolute path override for {key}")) + }); + + config_builder.build().map_err(Into::into) + } + + /// Look up a single config value by dotted key (e.g. `"daemon.sync_frequency"`). + /// + /// Returns the effective value after merging defaults, config file, and + /// environment — without the side-effects of full `Settings` construction + /// (meta store init, path expansion, etc.). + pub fn get_config_value(key: &str) -> Result<String> { + let config = Self::build_config()?; + let value: config::Value = config + .get(key) + .map_err(|e| eyre!("failed to get config value '{}': {}", key, e))?; + Ok(Self::format_resolved_value(&value, key)) + } + + fn format_resolved_value(value: &config::Value, prefix: &str) -> String { + use config::ValueKind; + + match &value.kind { + ValueKind::Nil => String::new(), + ValueKind::Boolean(b) => b.to_string(), + ValueKind::I64(i) => i.to_string(), + ValueKind::I128(i) => i.to_string(), + ValueKind::U64(u) => u.to_string(), + ValueKind::U128(u) => u.to_string(), + ValueKind::Float(f) => f.to_string(), + ValueKind::String(s) => s.clone(), + ValueKind::Array(arr) => { + let items: Vec<String> = arr + .iter() + .map(|v| Self::format_resolved_value(v, "")) + .collect(); + format!("[{}]", items.join(", ")) + } + ValueKind::Table(map) => { + let mut lines = Vec::new(); + let mut keys: Vec<_> = map.keys().collect(); + keys.sort(); + + for k in keys { + let v = &map[k]; + let full_key = if prefix.is_empty() { + k.clone() + } else { + format!("{}.{}", prefix, k) + }; + + match &v.kind { + ValueKind::Table(_) => { + lines.push(Self::format_resolved_value(v, &full_key)); + } + _ => { + lines.push(format!( + "{} = {}", + full_key, + Self::format_resolved_value(v, "") + )); + } + } + } + + lines.join("\n") + } + } + } + + pub fn new() -> Result<Self> { + let config = Self::build_config()?; + let settings: Settings = config + .try_deserialize() + .map_err(|e| eyre!("failed to deserialize: {}", e))?; + + // Validate UI settings + settings.ui.validate()?; + + // Register meta store config for lazy initialization on first access + META_CONFIG + .set((settings.meta.db_path.clone(), settings.local_timeout)) + .ok(); + + Ok(settings) + } + + fn expand_path(path: String) -> Result<String> { + shellexpand::full(&path) + .map(|p| p.to_string()) + .map_err(|e| eyre!("failed to expand path: {}", e)) + } + + pub fn paths_ok(&self) -> bool { + let paths = [ + &self.db_path, + &self.record_store_path, + &self.key_path, + &self.meta.db_path, + ]; + paths.iter().all(|p| !utils::broken_symlink(p)) + } +} + +impl Default for Settings { + fn default() -> Self { + // if this panics something is very wrong, as the default config + // does not build or deserialize into the settings struct + Self::builder() + .expect("Could not build default") + .build() + .expect("Could not build config") + .try_deserialize() + .expect("Could not deserialize config") + } +} + +/// Initialize the meta store configuration for testing. +/// +/// This should only be used in tests. It allows tests to bypass the normal +/// Settings::new() flow while still being able to use Settings::host_id() +/// and other meta store dependent functions. +/// +/// # Safety +/// This function is not thread-safe with concurrent calls to Settings::new() +/// or other meta store initialization. Only call from tests. +#[doc(hidden)] +pub fn init_meta_config_for_testing(meta_db_path: impl Into<String>, local_timeout: f64) { + META_CONFIG.set((meta_db_path.into(), local_timeout)).ok(); +} + +#[cfg(test)] +pub(crate) fn test_local_timeout() -> f64 { + std::env::var("ATUIN_TEST_LOCAL_TIMEOUT") + .ok() + .and_then(|x| x.parse().ok()) + // this hardcoded value should be replaced by a simple way to get the + // default local_timeout of Settings if possible + .unwrap_or(2.0) +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use eyre::Result; + + use super::Timezone; + + #[test] + fn can_parse_offset_timezone_spec() -> Result<()> { + assert_eq!(Timezone::from_str("+02")?.0.as_hms(), (2, 0, 0)); + assert_eq!(Timezone::from_str("-04")?.0.as_hms(), (-4, 0, 0)); + assert_eq!(Timezone::from_str("+05:30")?.0.as_hms(), (5, 30, 0)); + assert_eq!(Timezone::from_str("-09:30")?.0.as_hms(), (-9, -30, 0)); + + // single digit hours are allowed + assert_eq!(Timezone::from_str("+2")?.0.as_hms(), (2, 0, 0)); + assert_eq!(Timezone::from_str("-4")?.0.as_hms(), (-4, 0, 0)); + assert_eq!(Timezone::from_str("+5:30")?.0.as_hms(), (5, 30, 0)); + assert_eq!(Timezone::from_str("-9:30")?.0.as_hms(), (-9, -30, 0)); + + // fully qualified form + assert_eq!(Timezone::from_str("+09:30:00")?.0.as_hms(), (9, 30, 0)); + assert_eq!(Timezone::from_str("-09:30:00")?.0.as_hms(), (-9, -30, 0)); + + // these offsets don't really exist but are supported anyway + assert_eq!(Timezone::from_str("+0:5")?.0.as_hms(), (0, 5, 0)); + assert_eq!(Timezone::from_str("-0:5")?.0.as_hms(), (0, -5, 0)); + assert_eq!(Timezone::from_str("+01:23:45")?.0.as_hms(), (1, 23, 45)); + assert_eq!(Timezone::from_str("-01:23:45")?.0.as_hms(), (-1, -23, -45)); + + // require a leading sign for clarity + assert!(Timezone::from_str("5").is_err()); + assert!(Timezone::from_str("10:30").is_err()); + + Ok(()) + } + + #[test] + fn can_choose_workspace_filters_when_in_git_context() -> Result<()> { + let mut settings = super::Settings::default(); + settings.search.filters = vec![ + super::FilterMode::Workspace, + super::FilterMode::Host, + super::FilterMode::Directory, + super::FilterMode::Session, + super::FilterMode::Global, + ]; + settings.workspaces = true; + + assert_eq!( + settings.default_filter_mode(true), + super::FilterMode::Workspace, + ); + + Ok(()) + } + + #[test] + fn wont_choose_workspace_filters_when_not_in_git_context() -> Result<()> { + let mut settings = super::Settings::default(); + settings.search.filters = vec![ + super::FilterMode::Workspace, + super::FilterMode::Host, + super::FilterMode::Directory, + super::FilterMode::Session, + super::FilterMode::Global, + ]; + settings.workspaces = true; + + assert_eq!(settings.default_filter_mode(false), super::FilterMode::Host,); + + Ok(()) + } + + #[test] + fn wont_choose_workspace_filters_when_workspaces_disabled() -> Result<()> { + let mut settings = super::Settings::default(); + settings.search.filters = vec![ + super::FilterMode::Workspace, + super::FilterMode::Host, + super::FilterMode::Directory, + super::FilterMode::Session, + super::FilterMode::Global, + ]; + settings.workspaces = false; + + assert_eq!(settings.default_filter_mode(true), super::FilterMode::Host,); + + Ok(()) + } + + #[test] + fn builder_with_data_dir_uses_custom_paths() -> Result<()> { + use std::path::PathBuf; + + let custom_dir = PathBuf::from("/custom/data/dir"); + let builder = super::Settings::builder_with_data_dir(&custom_dir)?; + let config = builder.build()?; + + let db_path: String = config.get("db_path")?; + let key_path: String = config.get("key_path")?; + let record_store_path: String = config.get("record_store_path")?; + let kv_db_path: String = config.get("kv.db_path")?; + let scripts_db_path: String = config.get("scripts.db_path")?; + let meta_db_path: String = config.get("meta.db_path")?; + let daemon_socket_path: String = config.get("daemon.socket_path")?; + let daemon_pidfile_path: String = config.get("daemon.pidfile_path")?; + let daemon_autostart: bool = config.get("daemon.autostart")?; + + assert_eq!(db_path, custom_dir.join("history.db").to_str().unwrap()); + assert_eq!(key_path, custom_dir.join("key").to_str().unwrap()); + assert_eq!( + record_store_path, + custom_dir.join("records.db").to_str().unwrap() + ); + assert_eq!(kv_db_path, custom_dir.join("kv.db").to_str().unwrap()); + assert_eq!( + scripts_db_path, + custom_dir.join("scripts.db").to_str().unwrap() + ); + assert_eq!(meta_db_path, custom_dir.join("meta.db").to_str().unwrap()); + assert_eq!( + daemon_socket_path, + crate::atuin_common::utils::runtime_dir() + .join("atuin.sock") + .to_str() + .unwrap() + ); + assert_eq!( + daemon_pidfile_path, + custom_dir.join("atuin-daemon.pid").to_str().unwrap() + ); + assert!(!daemon_autostart); + + Ok(()) + } + + #[test] + fn effective_data_dir_returns_default_when_not_set() { + let effective = super::Settings::effective_data_dir(); + let default = crate::atuin_common::utils::data_dir(); + + assert!(effective.to_str().is_some()); + assert!(effective.ends_with("atuin") || effective == default); + } + + #[test] + fn keymap_config_deserializes_simple_binding() { + let json = r#"{"emacs": {"ctrl-c": "exit"}}"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.emacs.len(), 1); + match &config.emacs["ctrl-c"] { + super::KeyBindingConfig::Simple(s) => assert_eq!(s, "exit"), + _ => panic!("expected Simple variant"), + } + } + + #[test] + fn keymap_config_deserializes_conditional_binding() { + let json = r#"{ + "emacs": { + "left": [ + {"when": "cursor-at-start", "action": "exit"}, + {"action": "cursor-left"} + ] + } + }"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + match &config.emacs["left"] { + super::KeyBindingConfig::Rules(rules) => { + assert_eq!(rules.len(), 2); + assert_eq!(rules[0].when.as_deref(), Some("cursor-at-start")); + assert_eq!(rules[0].action, "exit"); + assert!(rules[1].when.is_none()); + assert_eq!(rules[1].action, "cursor-left"); + } + _ => panic!("expected Rules variant"), + } + } + + #[test] + fn keymap_config_deserializes_vim_normal() { + let json = r#"{"vim-normal": {"j": "select-next", "k": "select-previous"}}"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.vim_normal.len(), 2); + assert!(config.emacs.is_empty()); + } + + #[test] + fn keymap_config_is_empty_when_default() { + let config = super::KeymapConfig::default(); + assert!(config.is_empty()); + } + + #[test] + fn keymap_config_mixed_modes() { + let json = r#"{ + "emacs": {"ctrl-c": "exit"}, + "vim-normal": {"q": "exit"}, + "inspector": {"d": "delete"} + }"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + assert!(!config.is_empty()); + assert_eq!(config.emacs.len(), 1); + assert_eq!(config.vim_normal.len(), 1); + assert_eq!(config.inspector.len(), 1); + assert!(config.vim_insert.is_empty()); + assert!(config.prefix.is_empty()); + } +} diff --git a/crates/turtle/src/atuin_client/settings/meta.rs b/crates/turtle/src/atuin_client/settings/meta.rs new file mode 100644 index 00000000..450d0432 --- /dev/null +++ b/crates/turtle/src/atuin_client/settings/meta.rs @@ -0,0 +1,17 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Settings { + pub db_path: String, +} + +impl Default for Settings { + fn default() -> Self { + let dir = crate::atuin_common::utils::data_dir(); + let path = dir.join("meta.db"); + + Self { + db_path: path.to_string_lossy().to_string(), + } + } +} diff --git a/crates/turtle/src/atuin_client/settings/watcher.rs b/crates/turtle/src/atuin_client/settings/watcher.rs new file mode 100644 index 00000000..7548573e --- /dev/null +++ b/crates/turtle/src/atuin_client/settings/watcher.rs @@ -0,0 +1,256 @@ +//! Config file watching for automatic settings reload. +//! +//! This module provides a `SettingsWatcher` that monitors the config file +//! for changes and broadcasts updated settings via a `tokio::sync::watch` channel. +//! +//! # Example +//! +//! ```no_run +//! use crate::atuin_client::settings::watcher::global_settings_watcher; +//! +//! async fn example() -> eyre::Result<()> { +//! let watcher = global_settings_watcher()?; +//! let mut rx = watcher.subscribe(); +//! +//! // React to settings changes +//! while rx.changed().await.is_ok() { +//! let settings = rx.borrow(); +//! println!("Settings updated!"); +//! } +//! Ok(()) +//! } +//! ``` + +use std::{ + path::PathBuf, + sync::{Arc, OnceLock}, + time::Duration, +}; + +use eyre::{Result, WrapErr}; +use log::{debug, error, info, warn}; +use notify::{ + Config as NotifyConfig, RecommendedWatcher, RecursiveMode, Watcher, + event::{EventKind, ModifyKind}, +}; +use tokio::sync::watch; + +use super::Settings; + +/// Global singleton for the settings watcher. +static SETTINGS_WATCHER: OnceLock<Result<SettingsWatcher, String>> = OnceLock::new(); + +/// Get the global settings watcher singleton. +/// +/// Initializes the watcher on first call. Subsequent calls return the same instance. +/// The watcher monitors the config file for changes and broadcasts updates. +pub fn global_settings_watcher() -> Result<&'static SettingsWatcher> { + let result = SETTINGS_WATCHER.get_or_init(|| SettingsWatcher::new().map_err(|e| e.to_string())); + + match result { + Ok(watcher) => Ok(watcher), + Err(e) => Err(eyre::eyre!("{}", e)), + } +} + +/// Watches the config file for changes and broadcasts updated settings. +/// +/// Uses `notify` for cross-platform file watching and `tokio::sync::watch` +/// for efficient broadcast to multiple subscribers. +pub struct SettingsWatcher { + /// Receiver for settings updates. Clone this to subscribe. + rx: watch::Receiver<Arc<Settings>>, + /// Keeps the file watcher alive for the lifetime of this struct. + _watcher: RecommendedWatcher, +} + +impl SettingsWatcher { + /// Create a new settings watcher. + /// + /// Loads initial settings and starts watching the config file for changes. + /// Changes are debounced (500ms) to avoid multiple reloads during saves. + pub fn new() -> Result<Self> { + let initial_settings = Arc::new(Settings::new()?); + let (tx, rx) = watch::channel(initial_settings); + + let config_path = Self::config_path(); + info!("starting config file watcher: {:?}", config_path); + + let watcher = Self::create_watcher(tx, config_path)?; + + Ok(Self { + rx, + _watcher: watcher, + }) + } + + /// Subscribe to settings updates. + /// + /// Returns a receiver that will be notified when settings change. + /// Use `changed().await` to wait for the next update, then `borrow()` + /// to access the current settings. + pub fn subscribe(&self) -> watch::Receiver<Arc<Settings>> { + self.rx.clone() + } + + /// Get the current settings without subscribing to updates. + pub fn current(&self) -> Arc<Settings> { + self.rx.borrow().clone() + } + + /// Get the config file path. + fn config_path() -> PathBuf { + let config_dir = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + crate::atuin_common::utils::config_dir() + }; + config_dir.join("config.toml") + } + + /// Create the file watcher with debouncing. + fn create_watcher( + tx: watch::Sender<Arc<Settings>>, + config_path: PathBuf, + ) -> Result<RecommendedWatcher> { + // Channel for debouncing file events + let (debounce_tx, debounce_rx) = std::sync::mpsc::channel::<()>(); + + // Spawn debounce thread + let config_path_clone = config_path.clone(); + std::thread::spawn(move || { + Self::debounce_loop(debounce_rx, tx, config_path_clone); + }); + + // Clone config_path for use in the watcher callback + let config_path_for_watcher = config_path.clone(); + + // Canonicalize config path for reliable comparison on macOS + // (handles symlinks like /var -> /private/var) + let canonical_config_path = config_path_for_watcher + .canonicalize() + .unwrap_or_else(|_| config_path_for_watcher.clone()); + + // Create file watcher + let mut watcher = RecommendedWatcher::new( + move |res: Result<notify::Event, notify::Error>| { + match res { + Ok(event) => { + // Defensive: if paths is empty, we can't filter, so assume + // it might be our config file and trigger a reload to be safe + if event.paths.is_empty() { + warn!( + "config watcher: event has no paths, triggering reload to be safe" + ); + let _ = debounce_tx.send(()); + return; + } + + // Only react to events for our specific config file + // (filter out editor temp files, backups, etc.) + let is_config_file = event.paths.iter().any(|path| { + // Canonicalize for reliable comparison (handles macOS symlinks) + let canonical_event_path = + path.canonicalize().unwrap_or_else(|_| path.clone()); + + // Check if this event is for our config file + // (either exact match or the file was renamed to our config) + canonical_event_path == canonical_config_path + || path.file_name() == config_path_for_watcher.file_name() + }); + + if !is_config_file { + return; + } + + // Only react to modify events (content changes) or creates + if matches!( + event.kind, + EventKind::Modify(ModifyKind::Data(_) | ModifyKind::Any) + | EventKind::Create(_) + ) { + debug!("config file event detected: {:?}", event); + // Send to debounce channel (ignore send errors - receiver might be gone) + let _ = debounce_tx.send(()); + } + } + Err(e) => { + error!("file watcher error: {}", e); + } + } + }, + NotifyConfig::default(), + ) + .wrap_err("failed to create file watcher")?; + + // Watch the config file's parent directory (some editors create new files) + let watch_path = config_path.parent().unwrap_or(&config_path); + + // Defensive: ensure watch path exists before trying to watch + if !watch_path.exists() { + warn!( + "config directory does not exist, creating it: {:?}", + watch_path + ); + std::fs::create_dir_all(watch_path) + .wrap_err_with(|| format!("failed to create config directory: {:?}", watch_path))?; + } + + watcher + .watch(watch_path, RecursiveMode::NonRecursive) + .wrap_err_with(|| format!("failed to watch config directory: {:?}", watch_path))?; + + info!("config file watcher initialized for: {:?}", watch_path); + Ok(watcher) + } + + /// Debounce loop that batches file events and reloads settings. + fn debounce_loop( + rx: std::sync::mpsc::Receiver<()>, + tx: watch::Sender<Arc<Settings>>, + config_path: PathBuf, + ) { + const DEBOUNCE_DURATION: Duration = Duration::from_millis(500); + + loop { + // Wait for first event + if rx.recv().is_err() { + // Channel closed, watcher was dropped + debug!("config watcher debounce loop exiting"); + return; + } + + // Drain any additional events within debounce window + while rx.recv_timeout(DEBOUNCE_DURATION).is_ok() { + // Keep draining + } + + // Defensive: check if config file exists before reloading + // (handles case where file was deleted - we'll get notified when it's recreated) + if !config_path.exists() { + debug!( + "config file does not exist, skipping reload: {:?}", + config_path + ); + continue; + } + + // Now reload settings + info!("config file changed, reloading settings: {:?}", config_path); + match Settings::new() { + Ok(settings) => { + if tx.send(Arc::new(settings)).is_err() { + // All receivers dropped + debug!("all settings subscribers dropped, exiting"); + return; + } + info!("settings reloaded successfully"); + } + Err(e) => { + warn!("failed to reload settings: {}", e); + // Keep the old settings, don't broadcast the error + } + } + } + } +} diff --git a/crates/turtle/src/atuin_client/sync.rs b/crates/turtle/src/atuin_client/sync.rs new file mode 100644 index 00000000..361b6636 --- /dev/null +++ b/crates/turtle/src/atuin_client/sync.rs @@ -0,0 +1,214 @@ +use std::collections::HashSet; +use std::iter::FromIterator; + +use eyre::Result; +use tracing::{debug, info}; + +use crate::atuin_common::api::AddHistoryRequest; +use crypto_secretbox::Key; +use time::OffsetDateTime; + +use crate::atuin_client::{ + api_client, + database::Database, + encryption::{decrypt, encrypt, load_key}, + settings::Settings, +}; + +pub fn hash_str(string: &str) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(string.as_bytes()); + hex::encode(hasher.finalize()) +} + +// Currently sync is kinda naive, and basically just pages backwards through +// history. This means newly added stuff shows up properly! We also just use +// the total count in each database to indicate whether a sync is needed. +// I think this could be massively improved! If we had a way of easily +// indicating count per time period (hour, day, week, year, etc) then we can +// easily pinpoint where we are missing data and what needs downloading. Start +// with year, then find the week, then the day, then the hour, then download it +// all! The current naive approach will do for now. + +// Check if remote has things we don't, and if so, download them. +// Returns (num downloaded, total local) +async fn sync_download( + key: &Key, + force: bool, + client: &api_client::Client<'_>, + db: &impl Database, +) -> Result<(i64, i64)> { + debug!("starting sync download"); + + let remote_status = client.status().await?; + let remote_count = remote_status.count; + + // useful to ensure we don't even save something that hasn't yet been synced + deleted + let remote_deleted = + HashSet::<&str>::from_iter(remote_status.deleted.iter().map(String::as_str)); + + let initial_local = db.history_count(true).await?; + let mut local_count = initial_local; + + let mut last_sync = if force { + OffsetDateTime::UNIX_EPOCH + } else { + Settings::last_sync().await? + }; + + let mut last_timestamp = OffsetDateTime::UNIX_EPOCH; + + let host = if force { Some(String::from("")) } else { None }; + + while remote_count > local_count { + let page = client + .get_history(last_sync, last_timestamp, host.clone()) + .await?; + + let history: Vec<_> = page + .history + .iter() + // TODO: handle deletion earlier in this chain + .map(|h| serde_json::from_str(h).expect("invalid base64")) + .map(|h| decrypt(h, key).expect("failed to decrypt history! check your key")) + .map(|mut h| { + if remote_deleted.contains(h.id.0.as_str()) { + h.deleted_at = Some(time::OffsetDateTime::now_utc()); + h.command = String::from(""); + } + + h + }) + .collect(); + + db.save_bulk(&history).await?; + + local_count = db.history_count(true).await?; + let remote_page_size = std::cmp::max(remote_status.page_size, 0) as usize; + + if history.len() < remote_page_size { + break; + } + + let page_last = history + .last() + .expect("could not get last element of page") + .timestamp; + + // in the case of a small sync frequency, it's possible for history to + // be "lost" between syncs. In this case we need to rewind the sync + // timestamps + if page_last == last_timestamp { + last_timestamp = OffsetDateTime::UNIX_EPOCH; + last_sync -= time::Duration::hours(1); + } else { + last_timestamp = page_last; + } + } + + for i in remote_status.deleted { + // we will update the stored history to have this data + // pretty much everything can be nullified + match db.load(i.as_str()).await? { + Some(h) => { + db.delete(h).await?; + } + _ => { + info!( + "could not delete history with id {}, not found locally", + i.as_str() + ); + } + } + } + + Ok((local_count - initial_local, local_count)) +} + +// Check if we have things remote doesn't, and if so, upload them +async fn sync_upload( + key: &Key, + _force: bool, + client: &api_client::Client<'_>, + db: &impl Database, +) -> Result<()> { + debug!("starting sync upload"); + + let remote_status = client.status().await?; + let remote_deleted: HashSet<String> = HashSet::from_iter(remote_status.deleted.clone()); + + let initial_remote_count = client.count().await?; + let mut remote_count = initial_remote_count; + + let local_count = db.history_count(true).await?; + + debug!("remote has {remote_count}, we have {local_count}"); + + // first just try the most recent set + let mut cursor = OffsetDateTime::now_utc(); + + while local_count > remote_count { + let last = db.before(cursor, remote_status.page_size).await?; + let mut buffer = Vec::new(); + + if last.is_empty() { + break; + } + + for i in last { + let data = encrypt(&i, key)?; + let data = serde_json::to_string(&data)?; + + let add_hist = AddHistoryRequest { + id: i.id.to_string(), + timestamp: i.timestamp, + data, + hostname: hash_str(&i.hostname), + }; + + buffer.push(add_hist); + } + + // anything left over outside of the 100 block size + client.post_history(&buffer).await?; + cursor = buffer.last().unwrap().timestamp; + remote_count = client.count().await?; + + debug!("upload cursor: {cursor:?}"); + } + + let deleted = db.deleted().await?; + + for i in deleted { + if remote_deleted.contains(&i.id.to_string()) { + continue; + } + + info!("deleting {} on remote", i.id); + client.delete_history(i).await?; + } + + Ok(()) +} + +pub async fn sync(settings: &Settings, force: bool, db: &impl Database) -> Result<()> { + let client = api_client::Client::new( + &settings.sync_address, + settings.sync_auth_token().await?, + settings.network_connect_timeout, + settings.network_timeout, + )?; + + Settings::save_sync_time().await?; + + let key = load_key(settings)?; // encryption key + + sync_upload(&key, force, &client, db).await?; + + let download = sync_download(&key, force, &client, db).await?; + + debug!("sync downloaded {}", download.0); + + Ok(()) +} diff --git a/crates/turtle/src/atuin_client/theme.rs b/crates/turtle/src/atuin_client/theme.rs new file mode 100644 index 00000000..1d9c0b9e --- /dev/null +++ b/crates/turtle/src/atuin_client/theme.rs @@ -0,0 +1,831 @@ +use config::{Config, File as ConfigFile, FileFormat}; +use log; +use palette::named; +use serde::{Deserialize, Serialize}; +use serde_json; +use std::collections::HashMap; +use std::error; +use std::io::{Error, ErrorKind}; +use std::path::PathBuf; +use std::sync::LazyLock; +use strum_macros; + +static DEFAULT_MAX_DEPTH: u8 = 10; + +// Collection of settable "meanings" that can have colors set. +// NOTE: You can add a new meaning here without breaking backwards compatibility but please: +// - update the atuin/docs repository, which has a list of available meanings +// - add a fallback in the MEANING_FALLBACKS below, so that themes which do not have it +// get a sensible fallback (see Title as an example) +#[derive( + Serialize, Deserialize, Copy, Clone, Hash, Debug, Eq, PartialEq, strum_macros::Display, +)] +#[strum(serialize_all = "camel_case")] +pub enum Meaning { + AlertInfo, + AlertWarn, + AlertError, + Annotation, + Base, + Guidance, + Important, + Title, + Muted, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ThemeConfig { + // Definition of the theme + pub theme: ThemeDefinitionConfigBlock, + + // Colors + pub colors: HashMap<Meaning, String>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ThemeDefinitionConfigBlock { + /// Name of theme ("default" for base) + pub name: String, + + /// Whether any theme should be treated as a parent _if available_ + pub parent: Option<String>, +} + +use crossterm::style::{Attribute, Attributes, Color, ContentStyle}; + +// For now, a theme is loaded as a mapping of meanings to colors, but it may be desirable to +// expand that in the future to general styles, so we populate a Meaning->ContentStyle hashmap. +pub struct Theme { + pub name: String, + pub parent: Option<String>, + pub styles: HashMap<Meaning, ContentStyle>, +} + +// Themes have a number of convenience functions for the most commonly used meanings. +// The general purpose `as_style` routine gives back a style, but for ease-of-use and to keep +// theme-related boilerplate minimal, the convenience functions give a color. +impl Theme { + // This is the base "default" color, for general text + pub fn get_base(&self) -> ContentStyle { + self.styles[&Meaning::Base] + } + + pub fn get_info(&self) -> ContentStyle { + self.get_alert(log::Level::Info) + } + + pub fn get_warning(&self) -> ContentStyle { + self.get_alert(log::Level::Warn) + } + + pub fn get_error(&self) -> ContentStyle { + self.get_alert(log::Level::Error) + } + + // The alert meanings may be chosen by the Level enum, rather than the methods above + // or the full Meaning enum, to simplify programmatic selection of a log-level. + pub fn get_alert(&self, severity: log::Level) -> ContentStyle { + self.styles[ALERT_TYPES.get(&severity).unwrap()] + } + + pub fn new( + name: String, + parent: Option<String>, + styles: HashMap<Meaning, ContentStyle>, + ) -> Theme { + Theme { + name, + parent, + styles, + } + } + + pub fn closest_meaning<'a>(&self, meaning: &'a Meaning) -> &'a Meaning { + if self.styles.contains_key(meaning) { + meaning + } else if MEANING_FALLBACKS.contains_key(meaning) { + self.closest_meaning(&MEANING_FALLBACKS[meaning]) + } else { + &Meaning::Base + } + } + + // General access - if you have a meaning, this will give you a (crossterm) style + pub fn as_style(&self, meaning: Meaning) -> ContentStyle { + self.styles[self.closest_meaning(&meaning)] + } + + // Turns a map of meanings to colornames into a theme + // If theme-debug is on, then we will print any colornames that we cannot load, + // but we do not have this on in general, as it could print unfiltered text to the terminal + // from a theme TOML file. However, it will always return a theme, falling back to + // defaults on error, so that a TOML file does not break loading + pub fn from_foreground_colors( + name: String, + parent: Option<&Theme>, + foreground_colors: HashMap<Meaning, String>, + debug: bool, + ) -> Theme { + let styles: HashMap<Meaning, ContentStyle> = foreground_colors + .iter() + .map(|(name, color)| { + ( + *name, + StyleFactory::from_fg_string(color).unwrap_or_else(|err| { + if debug { + log::warn!("Tried to load string as a color unsuccessfully: ({name}={color}) {err}"); + } + ContentStyle::default() + }), + ) + }) + .collect(); + Theme::from_map(name, parent, &styles) + } + + // Boil down a meaning-color hashmap into a theme, by taking the defaults + // for any unknown colors + fn from_map( + name: String, + parent: Option<&Theme>, + overrides: &HashMap<Meaning, ContentStyle>, + ) -> Theme { + let styles = match parent { + Some(theme) => Box::new(theme.styles.clone()), + None => Box::new(DEFAULT_THEME.styles.clone()), + } + .iter() + .map(|(name, color)| match overrides.get(name) { + Some(value) => (*name, *value), + None => (*name, *color), + }) + .collect(); + Theme::new(name, parent.map(|p| p.name.clone()), styles) + } +} + +// Use palette to get a color from a string name, if possible +fn from_string(name: &str) -> Result<Color, String> { + if name.is_empty() { + return Err("Empty string".into()); + } + let first_char = name.chars().next().unwrap(); + match first_char { + '#' => { + let hexcode = &name[1..]; + let vec: Vec<u8> = hexcode + .chars() + .collect::<Vec<char>>() + .chunks(2) + .map(|pair| u8::from_str_radix(pair.iter().collect::<String>().as_str(), 16)) + .filter_map(|n| n.ok()) + .collect(); + if vec.len() != 3 { + return Err("Could not parse 3 hex values from string".into()); + } + Ok(Color::Rgb { + r: vec[0], + g: vec[1], + b: vec[2], + }) + } + '@' => { + // For full flexibility, we need to use serde_json, given + // crossterm's approach. + serde_json::from_str::<Color>(format!("\"{}\"", &name[1..]).as_str()) + .map_err(|_| format!("Could not convert color name {name} to Crossterm color")) + } + _ => { + let srgb = named::from_str(name).ok_or("No such color in palette")?; + Ok(Color::Rgb { + r: srgb.red, + g: srgb.green, + b: srgb.blue, + }) + } + } +} + +pub struct StyleFactory {} + +impl StyleFactory { + fn from_fg_string(name: &str) -> Result<ContentStyle, String> { + match from_string(name) { + Ok(color) => Ok(Self::from_fg_color(color)), + Err(err) => Err(err), + } + } + + // For succinctness, if we are confident that the name will be known, + // this routine is available to keep the code readable + fn known_fg_string(name: &str) -> ContentStyle { + Self::from_fg_string(name).unwrap() + } + + fn from_fg_color(color: Color) -> ContentStyle { + ContentStyle { + foreground_color: Some(color), + ..ContentStyle::default() + } + } + + fn from_fg_color_and_attributes(color: Color, attributes: Attributes) -> ContentStyle { + ContentStyle { + foreground_color: Some(color), + attributes, + ..ContentStyle::default() + } + } +} + +// Built-in themes. Rather than having extra files added before any theming +// is available, this gives a couple of basic options, demonstrating the use +// of themes: autumn and marine +static ALERT_TYPES: LazyLock<HashMap<log::Level, Meaning>> = LazyLock::new(|| { + HashMap::from([ + (log::Level::Info, Meaning::AlertInfo), + (log::Level::Warn, Meaning::AlertWarn), + (log::Level::Error, Meaning::AlertError), + ]) +}); + +static MEANING_FALLBACKS: LazyLock<HashMap<Meaning, Meaning>> = LazyLock::new(|| { + HashMap::from([ + (Meaning::Guidance, Meaning::AlertInfo), + (Meaning::Annotation, Meaning::AlertInfo), + (Meaning::Title, Meaning::Important), + ]) +}); + +static DEFAULT_THEME: LazyLock<Theme> = LazyLock::new(|| { + Theme::new( + "default".to_string(), + None, + HashMap::from([ + ( + Meaning::AlertError, + StyleFactory::from_fg_color(Color::DarkRed), + ), + ( + Meaning::AlertWarn, + StyleFactory::from_fg_color(Color::DarkYellow), + ), + ( + Meaning::AlertInfo, + StyleFactory::from_fg_color(Color::DarkGreen), + ), + ( + Meaning::Annotation, + StyleFactory::from_fg_color(Color::DarkGrey), + ), + ( + Meaning::Guidance, + StyleFactory::from_fg_color(Color::DarkBlue), + ), + ( + Meaning::Important, + StyleFactory::from_fg_color_and_attributes( + Color::White, + Attributes::from(Attribute::Bold), + ), + ), + (Meaning::Muted, StyleFactory::from_fg_color(Color::Grey)), + (Meaning::Base, ContentStyle::default()), + ]), + ) +}); + +static BUILTIN_THEMES: LazyLock<HashMap<&'static str, Theme>> = LazyLock::new(|| { + HashMap::from([ + ("default", HashMap::new()), + ( + "(none)", + HashMap::from([ + (Meaning::AlertError, ContentStyle::default()), + (Meaning::AlertWarn, ContentStyle::default()), + (Meaning::AlertInfo, ContentStyle::default()), + (Meaning::Annotation, ContentStyle::default()), + (Meaning::Guidance, ContentStyle::default()), + (Meaning::Important, ContentStyle::default()), + (Meaning::Muted, ContentStyle::default()), + (Meaning::Base, ContentStyle::default()), + ]), + ), + ( + "autumn", + HashMap::from([ + ( + Meaning::AlertError, + StyleFactory::known_fg_string("saddlebrown"), + ), + ( + Meaning::AlertWarn, + StyleFactory::known_fg_string("darkorange"), + ), + (Meaning::AlertInfo, StyleFactory::known_fg_string("gold")), + ( + Meaning::Annotation, + StyleFactory::from_fg_color(Color::DarkGrey), + ), + (Meaning::Guidance, StyleFactory::known_fg_string("brown")), + ]), + ), + ( + "marine", + HashMap::from([ + ( + Meaning::AlertError, + StyleFactory::known_fg_string("yellowgreen"), + ), + (Meaning::AlertWarn, StyleFactory::known_fg_string("cyan")), + ( + Meaning::AlertInfo, + StyleFactory::known_fg_string("turquoise"), + ), + ( + Meaning::Annotation, + StyleFactory::known_fg_string("steelblue"), + ), + ( + Meaning::Base, + StyleFactory::known_fg_string("lightsteelblue"), + ), + (Meaning::Guidance, StyleFactory::known_fg_string("teal")), + ]), + ), + ]) + .iter() + .map(|(name, theme)| (*name, Theme::from_map(name.to_string(), None, theme))) + .collect() +}); + +// To avoid themes being repeatedly loaded, we store them in a theme manager +pub struct ThemeManager { + loaded_themes: HashMap<String, Theme>, + debug: bool, + override_theme_dir: Option<String>, +} + +// Theme-loading logic +impl ThemeManager { + pub fn new(debug: Option<bool>, theme_dir: Option<String>) -> Self { + Self { + loaded_themes: HashMap::new(), + debug: debug.unwrap_or(false), + override_theme_dir: match theme_dir { + Some(theme_dir) => Some(theme_dir), + None => std::env::var("ATUIN_THEME_DIR").ok(), + }, + } + } + + // Try to load a theme from a `{name}.toml` file in the theme directory. If an override is set + // for the theme dir (via ATUIN_THEME_DIR env) we should load the theme from there + pub fn load_theme_from_file( + &mut self, + name: &str, + max_depth: u8, + ) -> Result<&Theme, Box<dyn error::Error>> { + let mut theme_file = if let Some(p) = &self.override_theme_dir { + if p.is_empty() { + return Err(Box::new(Error::new( + ErrorKind::NotFound, + "Empty theme directory override and could not find theme elsewhere", + ))); + } + PathBuf::from(p) + } else { + let config_dir = crate::atuin_common::utils::config_dir(); + let mut theme_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut theme_file = PathBuf::new(); + theme_file.push(config_dir); + theme_file + }; + theme_file.push("themes"); + theme_file + }; + + let theme_toml = format!("{name}.toml"); + theme_file.push(theme_toml); + + let mut config_builder = Config::builder(); + + config_builder = config_builder.add_source(ConfigFile::new( + theme_file.to_str().unwrap(), + FileFormat::Toml, + )); + + let config = config_builder.build()?; + self.load_theme_from_config(name, config, max_depth) + } + + pub fn load_theme_from_config( + &mut self, + name: &str, + config: Config, + max_depth: u8, + ) -> Result<&Theme, Box<dyn error::Error>> { + let debug = self.debug; + let theme_config: ThemeConfig = match config.try_deserialize() { + Ok(tc) => tc, + Err(e) => { + return Err(Box::new(Error::new( + ErrorKind::InvalidInput, + format!( + "Failed to deserialize theme: {}", + if debug { + e.to_string() + } else { + "set theme debug on for more info".to_string() + } + ), + ))); + } + }; + let colors: HashMap<Meaning, String> = theme_config.colors; + let parent: Option<&Theme> = match theme_config.theme.parent { + Some(parent_name) => { + if max_depth == 0 { + return Err(Box::new(Error::new( + ErrorKind::InvalidInput, + "Parent requested but we hit the recursion limit", + ))); + } + Some(self.load_theme(parent_name.as_str(), Some(max_depth - 1))) + } + None => Some(self.load_theme("default", Some(max_depth - 1))), + }; + + if debug && name != theme_config.theme.name { + log::warn!( + "Your theme config name is not the name of your loaded theme {} != {}", + name, + theme_config.theme.name + ); + } + + let theme = Theme::from_foreground_colors(theme_config.theme.name, parent, colors, debug); + let name = name.to_string(); + self.loaded_themes.insert(name.clone(), theme); + let theme = self.loaded_themes.get(&name).unwrap(); + Ok(theme) + } + + // Check if the requested theme is loaded and, if not, then attempt to get it + // from the builtins or, if not there, from file + pub fn load_theme(&mut self, name: &str, max_depth: Option<u8>) -> &Theme { + if self.loaded_themes.contains_key(name) { + return self.loaded_themes.get(name).unwrap(); + } + let built_ins = &BUILTIN_THEMES; + match built_ins.get(name) { + Some(theme) => theme, + None => match self.load_theme_from_file(name, max_depth.unwrap_or(DEFAULT_MAX_DEPTH)) { + Ok(theme) => theme, + Err(err) => { + log::warn!("Could not load theme {name}: {err}"); + built_ins.get("(none)").unwrap() + } + }, + } + } +} + +#[cfg(test)] +mod theme_tests { + use super::*; + + #[test] + fn test_can_load_builtin_theme() { + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + let theme = manager.load_theme("autumn", None); + assert_eq!( + theme.as_style(Meaning::Guidance).foreground_color, + from_string("brown").ok() + ); + } + + #[test] + fn test_can_create_theme() { + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + let mytheme = Theme::new( + "mytheme".to_string(), + None, + HashMap::from([( + Meaning::AlertError, + StyleFactory::known_fg_string("yellowgreen"), + )]), + ); + manager.loaded_themes.insert("mytheme".to_string(), mytheme); + let theme = manager.load_theme("mytheme", None); + assert_eq!( + theme.as_style(Meaning::AlertError).foreground_color, + from_string("yellowgreen").ok() + ); + } + + #[test] + fn test_can_fallback_when_meaning_missing() { + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + + // We use title as an example of a meaning that is not defined + // even in the base theme. + assert!(!DEFAULT_THEME.styles.contains_key(&Meaning::Title)); + + let config = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"title_theme\" + + [colors] + Guidance = \"white\" + AlertInfo = \"zomp\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let theme = manager + .load_theme_from_config("config_theme", config, 1) + .unwrap(); + + // Correctly picks overridden color. + assert_eq!( + theme.as_style(Meaning::Guidance).foreground_color, + from_string("white").ok() + ); + + // Does not fall back to any color. + assert_eq!(theme.as_style(Meaning::AlertInfo).foreground_color, None); + + // Even for the base. + assert_eq!(theme.as_style(Meaning::Base).foreground_color, None); + + // Falls back to red as meaning missing from theme, so picks base default. + assert_eq!( + theme.as_style(Meaning::AlertError).foreground_color, + Some(Color::DarkRed) + ); + + // Falls back to Important as Title not available. + assert_eq!( + theme.as_style(Meaning::Title).foreground_color, + theme.as_style(Meaning::Important).foreground_color, + ); + + let title_config = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"title_theme\" + + [colors] + Title = \"white\" + AlertInfo = \"zomp\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let title_theme = manager + .load_theme_from_config("title_theme", title_config, 1) + .unwrap(); + + assert_eq!( + title_theme.as_style(Meaning::Title).foreground_color, + Some(Color::White) + ); + } + + #[test] + fn test_no_fallbacks_are_circular() { + let mytheme = Theme::new("mytheme".to_string(), None, HashMap::from([])); + MEANING_FALLBACKS + .iter() + .for_each(|pair| assert_eq!(mytheme.closest_meaning(pair.0), &Meaning::Base)) + } + + #[test] + fn test_can_get_colors_via_convenience_functions() { + let mut manager = ThemeManager::new(Some(true), Some("".to_string())); + let theme = manager.load_theme("default", None); + assert_eq!(theme.get_error().foreground_color.unwrap(), Color::DarkRed); + assert_eq!( + theme.get_warning().foreground_color.unwrap(), + Color::DarkYellow + ); + assert_eq!(theme.get_info().foreground_color.unwrap(), Color::DarkGreen); + assert_eq!(theme.get_base().foreground_color, None); + assert_eq!( + theme.get_alert(log::Level::Error).foreground_color.unwrap(), + Color::DarkRed + ) + } + + #[test] + fn test_can_use_parent_theme_for_fallbacks() { + testing_logger::setup(); + + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + + // First, we introduce a base theme + let solarized = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"solarized\" + + [colors] + Guidance = \"white\" + AlertInfo = \"pink\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let solarized_theme = manager + .load_theme_from_config("solarized", solarized, 1) + .unwrap(); + + assert_eq!( + solarized_theme + .as_style(Meaning::AlertInfo) + .foreground_color, + from_string("pink").ok() + ); + + // Then we introduce a derived theme + let unsolarized = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"unsolarized\" + parent = \"solarized\" + + [colors] + AlertInfo = \"red\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let unsolarized_theme = manager + .load_theme_from_config("unsolarized", unsolarized, 1) + .unwrap(); + + // It will take its own values + assert_eq!( + unsolarized_theme + .as_style(Meaning::AlertInfo) + .foreground_color, + from_string("red").ok() + ); + + // ...or fall back to the parent + assert_eq!( + unsolarized_theme + .as_style(Meaning::Guidance) + .foreground_color, + from_string("white").ok() + ); + + testing_logger::validate(|captured_logs| assert_eq!(captured_logs.len(), 0)); + + // If the parent is not found, we end up with the no theme colors or styling + // as this is considered a (soft) error state. + let nunsolarized = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"nunsolarized\" + parent = \"nonsolarized\" + + [colors] + AlertInfo = \"red\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let nunsolarized_theme = manager + .load_theme_from_config("nunsolarized", nunsolarized, 1) + .unwrap(); + + assert_eq!( + nunsolarized_theme + .as_style(Meaning::Guidance) + .foreground_color, + None + ); + + testing_logger::validate(|captured_logs| { + assert_eq!(captured_logs.len(), 1); + assert_eq!( + captured_logs[0].body, + "Could not load theme nonsolarized: Empty theme directory override and could not find theme elsewhere" + ); + assert_eq!(captured_logs[0].level, log::Level::Warn) + }); + } + + #[test] + fn test_can_debug_theme() { + testing_logger::setup(); + [true, false].iter().for_each(|debug| { + let mut manager = ThemeManager::new(Some(*debug), Some("".to_string())); + let config = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"mytheme\" + + [colors] + Guidance = \"white\" + AlertInfo = \"xinetic\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + manager + .load_theme_from_config("config_theme", config, 1) + .unwrap(); + testing_logger::validate(|captured_logs| { + if *debug { + assert_eq!(captured_logs.len(), 2); + assert_eq!( + captured_logs[0].body, + "Your theme config name is not the name of your loaded theme config_theme != mytheme" + ); + assert_eq!(captured_logs[0].level, log::Level::Warn); + assert_eq!( + captured_logs[1].body, + "Tried to load string as a color unsuccessfully: (AlertInfo=xinetic) No such color in palette" + ); + assert_eq!(captured_logs[1].level, log::Level::Warn) + } else { + assert_eq!(captured_logs.len(), 0) + } + }) + }) + } + + #[test] + fn test_can_parse_color_strings_correctly() { + assert_eq!( + from_string("brown").unwrap(), + Color::Rgb { + r: 165, + g: 42, + b: 42 + } + ); + + assert_eq!(from_string(""), Err("Empty string".into())); + + ["manatee", "caput mortuum", "123456"] + .iter() + .for_each(|inp| { + assert_eq!(from_string(inp), Err("No such color in palette".into())); + }); + + assert_eq!( + from_string("#ff1122").unwrap(), + Color::Rgb { + r: 255, + g: 17, + b: 34 + } + ); + ["#1122", "#ffaa112", "#brown"].iter().for_each(|inp| { + assert_eq!( + from_string(inp), + Err("Could not parse 3 hex values from string".into()) + ); + }); + + assert_eq!(from_string("@dark_grey").unwrap(), Color::DarkGrey); + assert_eq!( + from_string("@rgb_(255,255,255)").unwrap(), + Color::Rgb { + r: 255, + g: 255, + b: 255 + } + ); + assert_eq!(from_string("@ansi_(255)").unwrap(), Color::AnsiValue(255)); + ["@", "@DarkGray", "@Dark 4ay", "@ansi(256)"] + .iter() + .for_each(|inp| { + assert_eq!( + from_string(inp), + Err(format!( + "Could not convert color name {inp} to Crossterm color" + )) + ); + }); + } +} diff --git a/crates/turtle/src/atuin_client/utils.rs b/crates/turtle/src/atuin_client/utils.rs new file mode 100644 index 00000000..35d7db26 --- /dev/null +++ b/crates/turtle/src/atuin_client/utils.rs @@ -0,0 +1,14 @@ +pub(crate) fn get_hostname() -> String { + std::env::var("ATUIN_HOST_NAME") + .unwrap_or_else(|_| whoami::hostname().unwrap_or_else(|_| "unknown-host".to_string())) +} + +pub(crate) fn get_username() -> String { + std::env::var("ATUIN_HOST_USER") + .unwrap_or_else(|_| whoami::username().unwrap_or_else(|_| "unknown-user".to_string())) +} + +/// Returns a pair of the hostname and username, separated by a colon. +pub(crate) fn get_host_user() -> String { + format!("{}:{}", get_hostname(), get_username()) +} diff --git a/crates/turtle/src/atuin_common/api.rs b/crates/turtle/src/atuin_common/api.rs new file mode 100644 index 00000000..1a9f348c --- /dev/null +++ b/crates/turtle/src/atuin_common/api.rs @@ -0,0 +1,144 @@ +use semver::Version; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use std::sync::LazyLock; +use time::OffsetDateTime; + +// the usage of X- has been deprecated for quite along time, it turns out +pub static ATUIN_HEADER_VERSION: &str = "Atuin-Version"; +pub static ATUIN_CARGO_VERSION: &str = env!("CARGO_PKG_VERSION"); + +pub static ATUIN_VERSION: LazyLock<Version> = + LazyLock::new(|| Version::parse(ATUIN_CARGO_VERSION).expect("failed to parse self semver")); + +#[derive(Debug, Serialize, Deserialize)] +pub struct UserResponse { + pub username: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterRequest { + pub email: String, + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterResponse { + pub session: String, + /// Auth type: "hub" for Hub API tokens, "cli" for legacy CLI session tokens. + /// Old servers that don't return this field will deserialize as None. + #[serde(default)] + pub auth: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DeleteUserResponse {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChangePasswordRequest { + pub current_password: String, + pub new_password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChangePasswordResponse {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginRequest { + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginResponse { + pub session: String, + /// Auth type: "hub" for Hub API tokens, "cli" for legacy CLI session tokens. + /// Old servers that don't return this field will deserialize as None. + #[serde(default)] + pub auth: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AddHistoryRequest { + pub id: String, + #[serde(with = "time::serde::rfc3339")] + pub timestamp: OffsetDateTime, + pub data: String, + pub hostname: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CountResponse { + pub count: i64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SyncHistoryRequest { + #[serde(with = "time::serde::rfc3339")] + pub sync_ts: OffsetDateTime, + #[serde(with = "time::serde::rfc3339")] + pub history_ts: OffsetDateTime, + pub host: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SyncHistoryResponse { + pub history: Vec<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorResponse<'a> { + pub reason: Cow<'a, str>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IndexResponse { + pub homage: String, + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StatusResponse { + pub count: i64, + pub username: String, + pub deleted: Vec<String>, + + // These could/should also go on the index of the server + // However, we do not request the server index as a part of normal sync + // I'd rather slightly increase the size of this response, than add an extra HTTP request + pub page_size: i64, // max page size supported by the server + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DeleteHistoryRequest { + pub client_id: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageResponse { + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MeResponse { + pub username: String, +} + +// Hub CLI authentication types + +/// Response from POST /auth/cli/code - generates a code for CLI auth +#[derive(Debug, Serialize, Deserialize)] +pub struct CliCodeResponse { + pub code: String, +} + +/// Response from GET /auth/cli/verify?code=<code> - polls for authorization +#[derive(Debug, Serialize, Deserialize)] +pub struct CliVerifyResponse { + /// Session token, present only when authorization is complete + pub token: Option<String>, + pub success: Option<bool>, + pub error: Option<String>, +} diff --git a/crates/turtle/src/atuin_common/calendar.rs b/crates/turtle/src/atuin_common/calendar.rs new file mode 100644 index 00000000..d3b1d921 --- /dev/null +++ b/crates/turtle/src/atuin_common/calendar.rs @@ -0,0 +1,16 @@ +// Calendar data +use serde::{Serialize, Deserialize}; + +pub enum TimePeriod { + YEAR, + MONTH, + DAY, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TimePeriodInfo { + pub count: u64, + + // TODO: Use this for merkle tree magic + pub hash: String, +} diff --git a/crates/turtle/src/atuin_common/mod.rs b/crates/turtle/src/atuin_common/mod.rs new file mode 100644 index 00000000..d886520d --- /dev/null +++ b/crates/turtle/src/atuin_common/mod.rs @@ -0,0 +1,58 @@ +/// Defines a new UUID type wrapper +macro_rules! new_uuid { + ($name:ident) => { + #[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + serde::Serialize, + serde::Deserialize, + )] + #[serde(transparent)] + pub struct $name(pub Uuid); + + impl<DB: sqlx::Database> sqlx::Type<DB> for $name + where + Uuid: sqlx::Type<DB>, + { + fn type_info() -> <DB as sqlx::Database>::TypeInfo { + Uuid::type_info() + } + } + + impl<'r, DB: sqlx::Database> sqlx::Decode<'r, DB> for $name + where + Uuid: sqlx::Decode<'r, DB>, + { + fn decode( + value: DB::ValueRef<'r>, + ) -> std::result::Result<Self, sqlx::error::BoxDynError> { + Uuid::decode(value).map(Self) + } + } + + impl<'q, DB: sqlx::Database> sqlx::Encode<'q, DB> for $name + where + Uuid: sqlx::Encode<'q, DB>, + { + fn encode_by_ref( + &self, + buf: &mut DB::ArgumentBuffer<'q>, + ) -> Result<sqlx::encode::IsNull, Box<dyn std::error::Error + Send + Sync + 'static>> + { + self.0.encode_by_ref(buf) + } + } + }; +} + +pub mod api; +pub mod record; +pub mod shell; +pub mod tls; +pub mod utils; diff --git a/crates/turtle/src/atuin_common/record.rs b/crates/turtle/src/atuin_common/record.rs new file mode 100644 index 00000000..05c29338 --- /dev/null +++ b/crates/turtle/src/atuin_common/record.rs @@ -0,0 +1,426 @@ +use std::collections::HashMap; + +use eyre::Result; +use serde::{Deserialize, Serialize}; +use typed_builder::TypedBuilder; +use uuid::Uuid; + +#[derive(Clone, Debug, PartialEq)] +pub struct DecryptedData(pub Vec<u8>); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct EncryptedData { + pub data: String, + pub content_encryption_key: String, +} + +#[derive(Debug, PartialEq, PartialOrd, Ord, Eq)] +pub struct Diff { + pub host: HostId, + pub tag: String, + pub local: Option<RecordIdx>, + pub remote: Option<RecordIdx>, +} + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub struct Host { + pub id: HostId, + pub name: String, +} + +impl Host { + pub fn new(id: HostId) -> Self { + Host { + id, + name: String::new(), + } + } +} + +new_uuid!(RecordId); +new_uuid!(HostId); + +pub type RecordIdx = u64; + +/// A single record stored inside of our local database +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)] +pub struct Record<Data> { + /// a unique ID + #[builder(default = RecordId(crate::atuin_common::utils::uuid_v7()))] + pub id: RecordId, + + /// The integer record ID. This is only unique per (host, tag). + pub idx: RecordIdx, + + /// The unique ID of the host. + // TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store + // as strings. I would rather avoid normalization, so store as UUID binary instead of + // encoding to a string and wasting much more storage. + pub host: Host, + + /// The creation time in nanoseconds since unix epoch + #[builder(default = time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64)] + pub timestamp: u64, + + /// The version the data in the entry conforms to + // However we want to track versions for this tag, eg v2 + pub version: String, + + /// The type of data we are storing here. Eg, "history" + pub tag: String, + + /// Some data. This can be anything you wish to store. Use the tag field to know how to handle it. + pub data: Data, +} + +/// Extra data from the record that should be encoded in the data +#[derive(Debug, Copy, Clone)] +pub struct AdditionalData<'a> { + pub id: &'a RecordId, + pub idx: &'a u64, + pub version: &'a str, + pub tag: &'a str, + pub host: &'a HostId, +} + +impl<Data> Record<Data> { + pub fn append(&self, data: Vec<u8>) -> Record<DecryptedData> { + Record::builder() + .host(self.host.clone()) + .version(self.version.clone()) + .idx(self.idx + 1) + .tag(self.tag.clone()) + .data(DecryptedData(data)) + .build() + } +} + +/// An index representing the current state of the record stores +/// This can be both remote, or local, and compared in either direction +#[derive(Debug, Serialize, Deserialize)] +pub struct RecordStatus { + // A map of host -> tag -> max(idx) + pub hosts: HashMap<HostId, HashMap<String, RecordIdx>>, +} + +impl Default for RecordStatus { + fn default() -> Self { + Self::new() + } +} + +impl Extend<(HostId, String, RecordIdx)> for RecordStatus { + fn extend<T: IntoIterator<Item = (HostId, String, RecordIdx)>>(&mut self, iter: T) { + for (host, tag, tail_idx) in iter { + self.set_raw(host, tag, tail_idx); + } + } +} + +impl RecordStatus { + pub fn new() -> RecordStatus { + RecordStatus { + hosts: HashMap::new(), + } + } + + /// Insert a new tail record into the store + pub fn set(&mut self, tail: Record<DecryptedData>) { + self.set_raw(tail.host.id, tail.tag, tail.idx) + } + + pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordIdx) { + self.hosts.entry(host).or_default().insert(tag, tail_id); + } + + pub fn get(&self, host: HostId, tag: String) -> Option<RecordIdx> { + self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned() + } + + /// Diff this index with another, likely remote index. + /// The two diffs can then be reconciled, and the optimal change set calculated + /// Returns a tuple, with (host, tag, Option(OTHER)) + /// OTHER is set to the value of the idx on the other machine. If it is greater than our index, + /// then we need to do some downloading. If it is smaller, then we need to do some uploading + /// Note that we cannot upload if we are not the owner of the record store - hosts can only + /// write to their own store. + pub fn diff(&self, other: &Self) -> Vec<Diff> { + let mut ret = Vec::new(); + + // First, we check if other has everything that self has + for (host, tag_map) in self.hosts.iter() { + for (tag, idx) in tag_map.iter() { + match other.get(*host, tag.clone()) { + // The other store is all up to date! No diff. + Some(t) if t.eq(idx) => continue, + + // The other store does exist, and it is either ahead or behind us. A diff regardless + Some(t) => ret.push(Diff { + host: *host, + tag: tag.clone(), + local: Some(*idx), + remote: Some(t), + }), + + // The other store does not exist :O + None => ret.push(Diff { + host: *host, + tag: tag.clone(), + local: Some(*idx), + remote: None, + }), + }; + } + } + + // At this point, there is a single case we have not yet considered. + // If the other store knows of a tag that we are not yet aware of, then the diff will be missed + + // account for that! + for (host, tag_map) in other.hosts.iter() { + for (tag, idx) in tag_map.iter() { + match self.get(*host, tag.clone()) { + // If we have this host/tag combo, the comparison and diff will have already happened above + Some(_) => continue, + + None => ret.push(Diff { + host: *host, + tag: tag.clone(), + remote: Some(*idx), + local: None, + }), + }; + } + } + + // Stability is a nice property to have + ret.sort(); + ret + } +} + +pub trait Encryption { + fn re_encrypt( + data: EncryptedData, + ad: AdditionalData, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result<EncryptedData> { + let data = Self::decrypt(data, ad, old_key)?; + Ok(Self::encrypt(data, ad, new_key)) + } + fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData; + fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result<DecryptedData>; +} + +impl Record<DecryptedData> { + pub fn encrypt<E: Encryption>(self, key: &[u8; 32]) -> Record<EncryptedData> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Record { + data: E::encrypt(self.data, ad, key), + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + } + } +} + +impl Record<EncryptedData> { + pub fn decrypt<E: Encryption>(self, key: &[u8; 32]) -> Result<Record<DecryptedData>> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Ok(Record { + data: E::decrypt(self.data, ad, key)?, + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + }) + } + + pub fn re_encrypt<E: Encryption>( + self, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result<Record<EncryptedData>> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Ok(Record { + data: E::re_encrypt(self.data, ad, old_key, new_key)?, + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::record::{Host, HostId}; + + use super::{DecryptedData, Diff, Record, RecordStatus}; + use pretty_assertions::assert_eq; + + fn test_record() -> Record<DecryptedData> { + Record::builder() + .host(Host::new(HostId(crate::atuin_common::utils::uuid_v7()))) + .version("v1".into()) + .tag(crate::atuin_common::utils::uuid_v7().simple().to_string()) + .data(DecryptedData(vec![0, 1, 2, 3])) + .idx(0) + .build() + } + + #[test] + fn record_index() { + let mut index = RecordStatus::new(); + let record = test_record(); + + index.set(record.clone()); + + let tail = index.get(record.host.id, record.tag); + + assert_eq!( + record.idx, + tail.expect("tail not in store"), + "tail in store did not match" + ); + } + + #[test] + fn record_index_overwrite() { + let mut index = RecordStatus::new(); + let record = test_record(); + let child = record.append(vec![1, 2, 3]); + + index.set(record.clone()); + index.set(child.clone()); + + let tail = index.get(record.host.id, record.tag); + + assert_eq!( + child.idx, + tail.expect("tail not in store"), + "tail in store did not match" + ); + } + + #[test] + fn record_index_no_diff() { + // Here, they both have the same version and should have no diff + + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let record1 = test_record(); + + index1.set(record1.clone()); + index2.set(record1); + + let diff = index1.diff(&index2); + + assert_eq!(0, diff.len(), "expected empty diff"); + } + + #[test] + fn record_index_single_diff() { + // Here, they both have the same stores, but one is ahead by a single record + + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let record1 = test_record(); + let record2 = record1.append(vec![1, 2, 3]); + + index1.set(record1); + index2.set(record2.clone()); + + let diff = index1.diff(&index2); + + assert_eq!(1, diff.len(), "expected single diff"); + assert_eq!( + diff[0], + Diff { + host: record2.host.id, + tag: record2.tag, + remote: Some(1), + local: Some(0) + } + ); + } + + #[test] + fn record_index_multi_diff() { + // A much more complex case, with a bunch more checks + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let store1record1 = test_record(); + let store1record2 = store1record1.append(vec![1, 2, 3]); + + let store2record1 = test_record(); + let store2record2 = store2record1.append(vec![1, 2, 3]); + + let store3record1 = test_record(); + + let store4record1 = test_record(); + + // index1 only knows about the first two entries of the first two stores + index1.set(store1record1); + index1.set(store2record1); + + // index2 is fully up to date with the first two stores, and knows of a third + index2.set(store1record2); + index2.set(store2record2); + index2.set(store3record1); + + // index1 knows of a 4th store + index1.set(store4record1); + + let diff1 = index1.diff(&index2); + let diff2 = index2.diff(&index1); + + // both diffs the same length + assert_eq!(4, diff1.len()); + assert_eq!(4, diff2.len()); + + dbg!(&diff1, &diff2); + + // both diffs should be ALMOST the same. They will agree on which hosts and tags + // require updating, but the "other" value will not be the same. + let smol_diff_1: Vec<(HostId, String)> = + diff1.iter().map(|v| (v.host, v.tag.clone())).collect(); + let smol_diff_2: Vec<(HostId, String)> = + diff1.iter().map(|v| (v.host, v.tag.clone())).collect(); + + assert_eq!(smol_diff_1, smol_diff_2); + + // diffing with yourself = no diff + assert_eq!(index1.diff(&index1).len(), 0); + assert_eq!(index2.diff(&index2).len(), 0); + } +} diff --git a/crates/turtle/src/atuin_common/shell.rs b/crates/turtle/src/atuin_common/shell.rs new file mode 100644 index 00000000..7f9a7b8f --- /dev/null +++ b/crates/turtle/src/atuin_common/shell.rs @@ -0,0 +1,183 @@ +use std::{ffi::OsStr, path::Path, process::Command}; + +use serde::Serialize; +use sysinfo::{Process, System, get_current_pid}; +use thiserror::Error; + +#[derive(PartialEq)] +pub enum Shell { + Sh, + Bash, + Fish, + Zsh, + Xonsh, + Nu, + Powershell, + + Unknown, +} + +impl std::fmt::Display for Shell { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let shell = match self { + Shell::Bash => "bash", + Shell::Fish => "fish", + Shell::Zsh => "zsh", + Shell::Nu => "nu", + Shell::Xonsh => "xonsh", + Shell::Sh => "sh", + Shell::Powershell => "powershell", + + Shell::Unknown => "unknown", + }; + + write!(f, "{shell}") + } +} + +#[derive(Debug, Error, Serialize)] +pub enum ShellError { + #[error("shell not supported")] + NotSupported, + + #[error("failed to execute shell command: {0}")] + ExecError(String), +} + +impl Shell { + pub fn current() -> Shell { + let sys = System::new_all(); + + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + let parent = sys + .process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist"); + + let shell = parent.name().trim().to_lowercase(); + let shell = shell.strip_prefix('-').unwrap_or(&shell); + + Shell::from_string(shell.to_string()) + } + + pub fn from_env() -> Shell { + std::env::var("ATUIN_SHELL").map_or(Shell::Unknown, |shell| { + Shell::from_string(shell.trim().to_lowercase()) + }) + } + + pub fn config_file(&self) -> Option<std::path::PathBuf> { + let mut path = if let Some(base) = directories::BaseDirs::new() { + base.home_dir().to_owned() + } else { + return None; + }; + + // TODO: handle all shells + match self { + Shell::Bash => path.push(".bashrc"), + Shell::Zsh => path.push(".zshrc"), + Shell::Fish => path.push(".config/fish/config.fish"), + + _ => return None, + }; + + Some(path) + } + + /// Best-effort attempt to determine the default shell + /// This implementation will be different across different platforms + /// Caller should ensure to handle Shell::Unknown correctly + pub fn default_shell() -> Result<Shell, ShellError> { + let sys = System::name().unwrap_or("".to_string()).to_lowercase(); + + // TODO: Support Linux + // I'm pretty sure we can use /etc/passwd there, though there will probably be some issues + let path = if sys.contains("darwin") { + // This works in my testing so far + Shell::Sh.run_interactive([ + "dscl localhost -read \"/Local/Default/Users/$USER\" shell | awk '{print $2}'", + ])? + } else if cfg!(windows) { + return Ok(Shell::Powershell); + } else { + Shell::Sh.run_interactive(["getent passwd $LOGNAME | cut -d: -f7"])? + }; + + let path = Path::new(path.trim()); + let shell = path.file_name(); + + if shell.is_none() { + return Err(ShellError::NotSupported); + } + + Ok(Shell::from_string( + shell.unwrap().to_string_lossy().to_string(), + )) + } + + pub fn from_string(name: String) -> Shell { + match name.as_str() { + "bash" => Shell::Bash, + "fish" => Shell::Fish, + "zsh" => Shell::Zsh, + "xonsh" => Shell::Xonsh, + "nu" => Shell::Nu, + "sh" => Shell::Sh, + "powershell" => Shell::Powershell, + + _ => Shell::Unknown, + } + } + + /// Returns true if the shell is posix-like + /// Note that while fish is not posix compliant, it behaves well enough for our current + /// featureset that this does not matter. + pub fn is_posixish(&self) -> bool { + matches!(self, Shell::Bash | Shell::Fish | Shell::Zsh) + } + + pub fn run_interactive<I, S>(&self, args: I) -> Result<String, ShellError> + where + I: IntoIterator<Item = S>, + S: AsRef<OsStr>, + { + let shell = self.to_string(); + let output = if self == &Self::Powershell { + Command::new(shell) + .args(args) + .output() + .map_err(|e| ShellError::ExecError(e.to_string()))? + } else { + Command::new(shell) + .arg("-ic") + .args(args) + .output() + .map_err(|e| ShellError::ExecError(e.to_string()))? + }; + + Ok(String::from_utf8(output.stdout).unwrap()) + } +} + +pub fn shell_name(parent: Option<&Process>) -> String { + let sys = System::new_all(); + + let parent = if let Some(parent) = parent { + parent + } else { + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + sys.process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist") + }; + + let shell = parent.name().trim().to_lowercase(); + let shell = shell.strip_prefix('-').unwrap_or(&shell); + + shell.to_string() +} diff --git a/crates/turtle/src/atuin_common/tls.rs b/crates/turtle/src/atuin_common/tls.rs new file mode 100644 index 00000000..e8c840e0 --- /dev/null +++ b/crates/turtle/src/atuin_common/tls.rs @@ -0,0 +1,15 @@ +use std::sync::Once; + +static INIT: Once = Once::new(); + +/// Ensure the rustls crypto provider (ring) is installed. +/// +/// Must be called before creating any reqwest clients. Safe to call +/// multiple times — only the first call installs the provider. +pub fn ensure_crypto_provider() { + INIT.call_once(|| { + rustls::crypto::ring::default_provider() + .install_default() + .expect("Failed to install rustls crypto provider"); + }); +} diff --git a/crates/turtle/src/atuin_common/utils.rs b/crates/turtle/src/atuin_common/utils.rs new file mode 100644 index 00000000..d7382fb2 --- /dev/null +++ b/crates/turtle/src/atuin_common/utils.rs @@ -0,0 +1,383 @@ +use std::borrow::Cow; +use std::env; +use std::path::{Path, PathBuf}; + +use eyre::{Result, eyre}; + +use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine}; +use getrandom::getrandom; +use uuid::Uuid; + +/// Generate N random bytes, using a cryptographically secure source +pub fn crypto_random_bytes<const N: usize>() -> [u8; N] { + // rand say they are in principle safe for crypto purposes, but that it is perhaps a better + // idea to use getrandom for things such as passwords. + let mut ret = [0u8; N]; + + getrandom(&mut ret).expect("Failed to generate random bytes!"); + + ret +} + +/// Generate N random bytes using a cryptographically secure source, return encoded as a string +pub fn crypto_random_string<const N: usize>() -> String { + let bytes = crypto_random_bytes::<N>(); + + // We only use this to create a random string, and won't be reversing it to find the original + // data - no padding is OK there. It may be in URLs. + BASE64_URL_SAFE_NO_PAD.encode(bytes) +} + +pub fn uuid_v7() -> Uuid { + Uuid::now_v7() +} + +pub fn uuid_v4() -> String { + Uuid::new_v4().as_simple().to_string() +} + +pub fn has_git_dir(path: &str) -> bool { + let mut gitdir = PathBuf::from(path); + gitdir.push(".git"); + + gitdir.exists() +} + +// in a git worktree, .git is a file containing "gitdir: <path>" pointing +// to the main repo's .git/worktrees/<name> directory. follow the pointer +// back to the main repo root so all worktrees share a workspace. +fn resolve_git_worktree(path: &Path) -> Option<PathBuf> { + let git_path = path.join(".git"); + + if !git_path.is_file() { + return None; + } + + let contents = std::fs::read_to_string(&git_path).ok()?; + let gitdir_str = contents.strip_prefix("gitdir: ")?.trim(); + + let gitdir = PathBuf::from(gitdir_str); + let gitdir = if gitdir.is_absolute() { + gitdir + } else { + path.join(gitdir_str) + }; + + // walk up from e.g. /repo/.git/worktrees/feature to find /repo + let mut candidate = gitdir.as_path(); + while let Some(parent) = candidate.parent() { + if parent.join(".git").is_dir() { + return Some(parent.to_path_buf()); + } + candidate = parent; + } + + None +} + +// detect if any parent dir has a git repo in it +// I really don't want to bring in libgit for something simple like this +// If we start to do anything more advanced, then perhaps +pub fn in_git_repo(path: &str) -> Option<PathBuf> { + let mut gitdir = PathBuf::from(path); + + while gitdir.parent().is_some() && !has_git_dir(gitdir.to_str().unwrap()) { + gitdir.pop(); + } + + // No parent? then we hit root, finding no git + if gitdir.parent().is_some() { + // if .git is a file (worktree), resolve to the main repo root + if let Some(main_repo) = resolve_git_worktree(&gitdir) { + return Some(main_repo); + } + return Some(gitdir); + } + + None +} + +// TODO: more reliable, more tested +// I don't want to use ProjectDirs, it puts config in awkward places on +// mac. Data too. Seems to be more intended for GUI apps. + +pub fn home_dir() -> PathBuf { + directories::BaseDirs::new() + .map(|d| d.home_dir().to_path_buf()) + .expect("could not determine home directory") +} + +pub fn config_dir() -> PathBuf { + let config_dir = + std::env::var("XDG_CONFIG_HOME").map_or_else(|_| home_dir().join(".config"), PathBuf::from); + config_dir.join("atuin") +} + +pub fn data_dir() -> PathBuf { + let data_dir = std::env::var("XDG_DATA_HOME") + .map_or_else(|_| home_dir().join(".local").join("share"), PathBuf::from); + + data_dir.join("atuin") +} + +pub fn runtime_dir() -> PathBuf { + std::env::var("XDG_RUNTIME_DIR").map_or_else(|_| data_dir(), PathBuf::from) +} + +pub fn logs_dir() -> PathBuf { + home_dir().join(".atuin").join("logs") +} + +pub fn dotfiles_cache_dir() -> PathBuf { + // In most cases, this will be ~/.local/share/atuin/dotfiles/cache + let data_dir = std::env::var("XDG_DATA_HOME") + .map_or_else(|_| home_dir().join(".local").join("share"), PathBuf::from); + + data_dir.join("atuin").join("dotfiles").join("cache") +} + +pub fn get_current_dir() -> String { + // Prefer PWD environment variable over cwd if available to better support symbolic links + match env::var("PWD") { + Ok(v) => v, + Err(_) => match env::current_dir() { + Ok(dir) => dir.display().to_string(), + Err(_) => String::from(""), + }, + } +} + +pub fn broken_symlink<P: Into<PathBuf>>(path: P) -> bool { + let path = path.into(); + path.is_symlink() && !path.exists() +} + +/// Extension trait for anything that can behave like a string to make it easy to escape control +/// characters. +/// +/// Intended to help prevent control characters being printed and interpreted by the terminal when +/// printing history as well as to ensure the commands that appear in the interactive search +/// reflect the actual command run rather than just the printable characters. +pub trait Escapable: AsRef<str> { + fn escape_control(&self) -> Cow<'_, str> { + if !self.as_ref().contains(|c: char| c.is_ascii_control()) { + self.as_ref().into() + } else { + let mut remaining = self.as_ref(); + // Not a perfect way to reserve space but should reduce the allocations + let mut buf = String::with_capacity(remaining.len()); + while let Some(i) = remaining.find(|c: char| c.is_ascii_control()) { + // safe to index with `..i`, `i` and `i+1..` as part[i] is a single byte ascii char + buf.push_str(&remaining[..i]); + buf.push('^'); + buf.push(match remaining.as_bytes()[i] { + 0x7F => '?', + code => char::from_u32(u32::from(code) + 64).unwrap(), + }); + remaining = &remaining[i + 1..]; + } + buf.push_str(remaining); + buf.into() + } + } +} + +pub fn unquote(s: &str) -> Result<String> { + if s.chars().count() < 2 { + return Err(eyre!("not enough chars")); + } + + let quote = s.chars().next().unwrap(); + + // not quoted, do nothing + if quote != '"' && quote != '\'' && quote != '`' { + return Ok(s.to_string()); + } + + if s.chars().last().unwrap() != quote { + return Err(eyre!("unexpected eof, quotes do not match")); + } + + // removes quote characters + // the sanity checks performed above ensure that the quotes will be ASCII and this will not + // panic + let s = &s[1..s.len() - 1]; + + Ok(s.to_string()) +} + +impl<T: AsRef<str>> Escapable for T {} + +#[expect(unsafe_code)] +#[cfg(test)] +mod tests { + use pretty_assertions::assert_ne; + + use super::*; + + use std::collections::HashSet; + + #[cfg(not(windows))] + #[test] + fn test_dirs() { + // these tests need to be run sequentially to prevent race condition + test_config_dir_xdg(); + test_config_dir(); + test_data_dir_xdg(); + test_data_dir(); + } + + #[cfg(not(windows))] + fn test_config_dir_xdg() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("XDG_CONFIG_HOME", "/home/user/custom_config") }; + assert_eq!( + config_dir(), + PathBuf::from("/home/user/custom_config/atuin") + ); + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_CONFIG_HOME") }; + } + + #[cfg(not(windows))] + fn test_config_dir() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("HOME", "/home/user") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_CONFIG_HOME") }; + + assert_eq!(config_dir(), PathBuf::from("/home/user/.config/atuin")); + + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + } + + #[cfg(not(windows))] + fn test_data_dir_xdg() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("XDG_DATA_HOME", "/home/user/custom_data") }; + assert_eq!(data_dir(), PathBuf::from("/home/user/custom_data/atuin")); + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_DATA_HOME") }; + } + + #[cfg(not(windows))] + fn test_data_dir() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("HOME", "/home/user") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_DATA_HOME") }; + assert_eq!(data_dir(), PathBuf::from("/home/user/.local/share/atuin")); + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + } + + #[test] + fn uuid_is_unique() { + let how_many: usize = 1000000; + + // for peace of mind + let mut uuids: HashSet<Uuid> = HashSet::with_capacity(how_many); + + // there will be many in the same millisecond + for _ in 0..how_many { + let uuid = uuid_v7(); + uuids.insert(uuid); + } + + assert_eq!(uuids.len(), how_many); + } + + #[test] + fn escape_control_characters() { + use super::Escapable; + // CSI colour sequence + assert_eq!("\x1b[31mfoo".escape_control(), "^[[31mfoo"); + + // Tabs count as control chars + assert_eq!("foo\tbar".escape_control(), "foo^Ibar"); + + // space is in control char range but should be excluded + assert_eq!("two words".escape_control(), "two words"); + + // unicode multi-byte characters + let s = "🐢\x1b[32m🦀"; + assert_eq!(s.escape_control(), s.replace("\x1b", "^[")); + } + + #[test] + fn escape_no_control_characters() { + use super::Escapable as _; + assert!(matches!( + "no control characters".escape_control(), + Cow::Borrowed(_) + )); + assert!(matches!( + "with \x1b[31mcontrol\x1b[0m characters".escape_control(), + Cow::Owned(_) + )); + } + + #[cfg(not(windows))] + #[test] + fn in_git_repo_regular() { + // regular git repo should resolve to the directory containing .git + let tmp = std::env::temp_dir().join("atuin-test-regular-git"); + let _ = std::fs::remove_dir_all(&tmp); + let subdir = tmp.join("src").join("deep"); + std::fs::create_dir_all(&subdir).unwrap(); + std::fs::create_dir_all(tmp.join(".git")).unwrap(); + + let result = in_git_repo(subdir.to_str().unwrap()); + assert_eq!(result, Some(tmp.clone())); + + std::fs::remove_dir_all(&tmp).unwrap(); + } + + #[cfg(not(windows))] + #[test] + fn in_git_repo_worktree_resolves_to_main_repo() { + // worktree .git is a file pointing back to the main repo — + // in_git_repo should follow it so all worktrees share a workspace + let tmp = std::env::temp_dir().join("atuin-test-worktree-git"); + let _ = std::fs::remove_dir_all(&tmp); + + // main repo at tmp/main with a real .git directory + let main_repo = tmp.join("main"); + let worktree_git_dir = main_repo.join(".git").join("worktrees").join("feature"); + std::fs::create_dir_all(&worktree_git_dir).unwrap(); + + // worktree at tmp/worktree with a .git file + let worktree = tmp.join("worktree"); + let worktree_subdir = worktree.join("src"); + std::fs::create_dir_all(&worktree_subdir).unwrap(); + std::fs::write( + worktree.join(".git"), + format!("gitdir: {}", worktree_git_dir.to_str().unwrap()), + ) + .unwrap(); + + // should resolve to the main repo root, not the worktree root + let result = in_git_repo(worktree_subdir.to_str().unwrap()); + assert_eq!(result, Some(main_repo.clone())); + + std::fs::remove_dir_all(&tmp).unwrap(); + } + + #[test] + fn dumb_random_test() { + // Obviously not a test of randomness, but make sure we haven't made some + // catastrophic error + + assert_ne!(crypto_random_string::<1>(), crypto_random_string::<1>()); + assert_ne!(crypto_random_string::<2>(), crypto_random_string::<2>()); + assert_ne!(crypto_random_string::<4>(), crypto_random_string::<4>()); + assert_ne!(crypto_random_string::<8>(), crypto_random_string::<8>()); + assert_ne!(crypto_random_string::<16>(), crypto_random_string::<16>()); + assert_ne!(crypto_random_string::<32>(), crypto_random_string::<32>()); + } +} diff --git a/crates/turtle/src/atuin_daemon/client.rs b/crates/turtle/src/atuin_daemon/client.rs new file mode 100644 index 00000000..45ef19e9 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/client.rs @@ -0,0 +1,418 @@ +use crate::atuin_client::database::Context; +use crate::atuin_client::settings::{FilterMode, Settings}; +use eyre::{Context as EyreContext, Result}; +use tonic::Code; +use tonic::transport::{Channel, Endpoint, Uri}; +use tower::service_fn; + +use hyper_util::rt::TokioIo; + +#[cfg(unix)] +use tokio::net::UnixStream; + +use crate::atuin_client::history::History; +use tracing::{Level, instrument, span}; + +use crate::atuin_daemon::control::HistoryRebuiltEvent; +use crate::atuin_daemon::control::{ + ForceSyncEvent, HistoryDeletedEvent, HistoryPrunedEvent, SendEventRequest, + SettingsReloadedEvent, ShutdownEvent, control_client::ControlClient as ControlServiceClient, +}; +use crate::atuin_daemon::events::DaemonEvent; +use crate::atuin_daemon::history::{ + EndHistoryReply, EndHistoryRequest, ShutdownRequest, StartHistoryReply, StartHistoryRequest, + StatusReply, StatusRequest, TailHistoryReply, TailHistoryRequest, + history_client::HistoryClient as HistoryServiceClient, +}; +use crate::atuin_daemon::search::{ + FilterMode as RpcFilterMode, SearchContext as RpcSearchContext, SearchRequest, SearchResponse, + search_client::SearchClient as SearchServiceClient, +}; +use crate::atuin_daemon::semantic::{ + CommandCapture, CommandOutputReply, CommandOutputRequest, OutputRange, RecordCommandsReply, + semantic_client::SemanticClient as SemanticServiceClient, +}; + +pub struct HistoryClient { + client: HistoryServiceClient<Channel>, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum DaemonClientErrorKind { + Connect, + Unavailable, + Unimplemented, + Other, +} + +#[must_use] +pub fn classify_error(error: &eyre::Report) -> DaemonClientErrorKind { + for cause in error.chain() { + if cause.downcast_ref::<tonic::transport::Error>().is_some() { + return DaemonClientErrorKind::Connect; + } + + if let Some(status) = cause.downcast_ref::<tonic::Status>() { + return match status.code() { + Code::Unavailable => DaemonClientErrorKind::Unavailable, + Code::Unimplemented => DaemonClientErrorKind::Unimplemented, + _ => DaemonClientErrorKind::Other, + }; + } + } + + DaemonClientErrorKind::Other +} + +// Wrap the grpc client +impl HistoryClient { + #[cfg(unix)] + pub async fn new(path: String) -> Result<Self> { + use eyre::Context; + + let log_path = path.clone(); + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at {}. Is it running?", + &log_path + ) + })?; + + let client = HistoryServiceClient::new(channel); + + Ok(HistoryClient { client }) + } + + pub async fn start_history(&mut self, h: History) -> Result<StartHistoryReply> { + let req = StartHistoryRequest { + command: h.command, + cwd: h.cwd, + hostname: h.hostname, + session: h.session, + timestamp: h.timestamp.unix_timestamp_nanos() as u64, + author: h.author, + intent: h.intent.unwrap_or_default(), + }; + + Ok(self.client.start_history(req).await?.into_inner()) + } + + pub async fn end_history( + &mut self, + id: String, + duration: u64, + exit: i64, + ) -> Result<EndHistoryReply> { + let req = EndHistoryRequest { id, duration, exit }; + + Ok(self.client.end_history(req).await?.into_inner()) + } + + pub async fn status(&mut self) -> Result<StatusReply> { + Ok(self.client.status(StatusRequest {}).await?.into_inner()) + } + + pub async fn tail_history(&mut self) -> Result<tonic::Streaming<TailHistoryReply>> { + Ok(self + .client + .tail_history(TailHistoryRequest {}) + .await? + .into_inner()) + } + + pub async fn shutdown(&mut self) -> Result<bool> { + let resp = self.client.shutdown(ShutdownRequest {}).await?.into_inner(); + Ok(resp.accepted) + } +} + +pub struct SearchClient { + client: SearchServiceClient<Channel>, +} + +impl SearchClient { + #[cfg(unix)] + pub async fn new(path: String) -> Result<Self> { + let log_path = path.clone(); + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at {}. Is it running?", + &log_path + ) + })?; + + let client = SearchServiceClient::new(channel); + + Ok(SearchClient { client }) + } + + #[instrument(skip_all, level = Level::TRACE, name = "daemon_client_search", fields(query = %query, query_id = query_id))] + pub async fn search( + &mut self, + query: String, + query_id: u64, + filter_mode: FilterMode, + context: Option<Context>, + ) -> Result<tonic::Streaming<SearchResponse>> { + let request = SearchRequest { + query, + query_id, + filter_mode: RpcFilterMode::from(filter_mode).into(), + context: context.map(RpcSearchContext::from), + }; + let request_stream = tokio_stream::once(request); + let response = span!(Level::TRACE, "daemon_client_search.request") + .in_scope(async || self.client.search(request_stream).await) + .await?; + + Ok(response.into_inner()) + } +} + +impl From<FilterMode> for RpcFilterMode { + fn from(filter_mode: FilterMode) -> Self { + match filter_mode { + FilterMode::Global => RpcFilterMode::Global, + FilterMode::Host => RpcFilterMode::Host, + FilterMode::Session => RpcFilterMode::Session, + FilterMode::Directory => RpcFilterMode::Directory, + FilterMode::Workspace => RpcFilterMode::Workspace, + FilterMode::SessionPreload => RpcFilterMode::SessionPreload, + } + } +} + +impl From<Context> for RpcSearchContext { + fn from(context: Context) -> Self { + RpcSearchContext { + session_id: context.session, + cwd: context.cwd, + hostname: context.hostname, + host_id: context.host_id, + git_root: context + .git_root + .map(|path| path.to_string_lossy().to_string()), + } + } +} + +pub struct SemanticClient { + client: SemanticServiceClient<Channel>, +} + +impl SemanticClient { + #[cfg(unix)] + pub async fn new(path: String) -> Result<Self> { + let log_path = path.clone(); + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at {}. Is it running?", + &log_path + ) + })?; + + let client = SemanticServiceClient::new(channel); + + Ok(SemanticClient { client }) + } + + #[cfg(unix)] + pub async fn from_settings(settings: &Settings) -> Result<Self> { + Self::new(settings.daemon.socket_path.clone()).await + } + + pub async fn record_commands( + &mut self, + captures: Vec<CommandCapture>, + ) -> Result<RecordCommandsReply> { + let stream = tokio_stream::iter(captures); + Ok(self.client.record_commands(stream).await?.into_inner()) + } + + pub async fn command_output( + &mut self, + history_id: String, + ranges: Vec<(i64, i64)>, + ) -> Result<CommandOutputReply> { + let request = CommandOutputRequest { + history_id, + ranges: ranges + .into_iter() + .map(|(start, end)| OutputRange { start, end }) + .collect(), + }; + + Ok(self.client.command_output(request).await?.into_inner()) + } +} + +// ============================================================================ +// Control Client +// ============================================================================ + +/// Client for the Control gRPC service. +/// +/// Used to inject events into a running daemon from external processes. +pub struct ControlClient { + client: ControlServiceClient<Channel>, +} + +impl ControlClient { + /// Connect to the daemon's control service. + #[cfg(unix)] + pub async fn new(path: String) -> Result<Self> { + let log_path = path.clone(); + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at {}. Is it running?", + &log_path + ) + })?; + + let client = ControlServiceClient::new(channel); + + Ok(ControlClient { client }) + } + + /// Connect using settings. + #[cfg(unix)] + pub async fn from_settings(settings: &Settings) -> Result<Self> { + Self::new(settings.daemon.socket_path.clone()).await + } + + /// Send an event to the daemon. + pub async fn send_event(&mut self, event: DaemonEvent) -> Result<()> { + let proto_event = daemon_event_to_proto(event); + let request = SendEventRequest { + event: Some(proto_event), + }; + self.client.send_event(request).await?; + Ok(()) + } +} + +/// Convert a daemon event to its proto representation. +fn daemon_event_to_proto( + event: DaemonEvent, +) -> crate::atuin_daemon::control::send_event_request::Event { + use crate::atuin_daemon::control::send_event_request::Event; + + match event { + DaemonEvent::HistoryPruned => Event::HistoryPruned(HistoryPrunedEvent {}), + DaemonEvent::HistoryRebuilt => Event::HistoryRebuilt(HistoryRebuiltEvent {}), + DaemonEvent::HistoryDeleted { ids } => Event::HistoryDeleted(HistoryDeletedEvent { + ids: ids.into_iter().map(|id| id.0).collect(), + }), + DaemonEvent::ForceSync => Event::ForceSync(ForceSyncEvent {}), + DaemonEvent::SettingsReloaded => Event::SettingsReloaded(SettingsReloadedEvent {}), + DaemonEvent::ShutdownRequested => Event::Shutdown(ShutdownEvent {}), + // These events are internal and not sent via the control service + DaemonEvent::HistoryStarted(_) + | DaemonEvent::HistoryEnded(_) + | DaemonEvent::RecordsAdded(_) + | DaemonEvent::SyncCompleted { .. } + | DaemonEvent::SyncFailed { .. } => { + // Use shutdown as a fallback, though this shouldn't happen + tracing::warn!("attempted to send internal event via control service"); + Event::Shutdown(ShutdownEvent {}) + } + } +} + +// ============================================================================ +// Convenience Functions +// ============================================================================ + +/// Emit an event to the daemon. +/// +/// This is a fire-and-forget helper for sending events to the daemon from +/// external processes like CLI commands. If the daemon isn't running, this +/// will silently succeed (returns Ok). +/// +/// # Example +/// +/// ```ignore +/// // After pruning history +/// emit_event(DaemonEvent::HistoryPruned).await?; +/// +/// // After deleting specific history items +/// emit_event(DaemonEvent::HistoryDeleted { ids: vec![...] }).await?; +/// +/// // Request immediate sync +/// emit_event(DaemonEvent::ForceSync).await?; +/// ``` +pub async fn emit_event(event: DaemonEvent) -> Result<()> { + emit_event_with_settings(event, None).await +} + +/// Emit an event to the daemon with explicit settings. +/// +/// If settings are not provided, they will be loaded from the default location. +/// If the daemon isn't running, this will silently succeed. +pub async fn emit_event_with_settings( + event: DaemonEvent, + settings: Option<&Settings>, +) -> Result<()> { + // Load settings if not provided + let owned_settings; + let settings = match settings { + Some(s) => s, + None => { + owned_settings = Settings::new()?; + &owned_settings + } + }; + + // Try to connect - if daemon isn't running, that's fine + let mut client = match ControlClient::from_settings(settings).await { + Ok(c) => c, + Err(e) => { + tracing::debug!(?e, "daemon not running, skipping event emission"); + return Ok(()); + } + }; + + // Send the event + if let Err(e) = client.send_event(event).await { + tracing::debug!(?e, "failed to send event to daemon"); + // Don't fail - this is fire-and-forget + } + + Ok(()) +} diff --git a/crates/turtle/src/atuin_daemon/components/history.rs b/crates/turtle/src/atuin_daemon/components/history.rs new file mode 100644 index 00000000..95d34b69 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/components/history.rs @@ -0,0 +1,327 @@ +//! History component. +//! +//! Handles command history lifecycle (start/end) and provides the History gRPC service. + +use std::{pin::Pin, sync::Arc}; + +use crate::atuin_client::{ + database::Database, + history::{History, HistoryId, store::HistoryStore}, + settings::Settings, +}; +use dashmap::DashMap; +use eyre::Result; +use time::OffsetDateTime; +use tokio_stream::Stream; +use tonic::{Request, Response, Status}; +use tracing::{Level, instrument}; + +use crate::atuin_daemon::{ + daemon::{Component, DaemonHandle}, + events::DaemonEvent, + history::{ + EndHistoryReply, EndHistoryRequest, HistoryEntry, HistoryEventKind, ShutdownReply, + ShutdownRequest, StartHistoryReply, StartHistoryRequest, StatusReply, StatusRequest, + TailHistoryReply, TailHistoryRequest, + history_server::{History as HistorySvc, HistoryServer}, + }, +}; + +const DAEMON_PROTOCOL_VERSION: u32 = 1; + +/// History component - manages command history lifecycle. +/// +/// This component: +/// - Tracks currently running commands (stored in memory) +/// - Saves completed commands to the database and record store +/// - Emits history events for other components (e.g., search indexing) +/// - Provides the History gRPC service +pub struct HistoryComponent { + inner: Arc<HistoryComponentInner>, +} + +struct HistoryComponentInner { + /// Commands currently running (not yet completed). + running: DashMap<HistoryId, History>, + + /// Handle to the daemon (set during start). + handle: tokio::sync::RwLock<Option<DaemonHandle>>, + + /// History store for pushing records (set during start). + history_store: tokio::sync::RwLock<Option<HistoryStore>>, +} + +impl HistoryComponent { + /// Create a new history component. + pub fn new() -> Self { + Self { + inner: Arc::new(HistoryComponentInner { + running: DashMap::new(), + handle: tokio::sync::RwLock::new(None), + history_store: tokio::sync::RwLock::new(None), + }), + } + } + + /// Get the gRPC service for this component. + /// + /// This returns a tonic service that can be added to a gRPC server. + pub fn grpc_service(&self) -> HistoryServer<HistoryGrpcService> { + HistoryServer::new(HistoryGrpcService { + inner: self.inner.clone(), + }) + } +} + +impl Default for HistoryComponent { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl Component for HistoryComponent { + fn name(&self) -> &'static str { + "history" + } + + async fn start(&mut self, handle: DaemonHandle) -> Result<()> { + // Create the history store + let host_id = Settings::host_id().await?; + let history_store = + HistoryStore::new(handle.store().clone(), host_id, *handle.encryption_key()); + + *self.inner.history_store.write().await = Some(history_store); + *self.inner.handle.write().await = Some(handle); + + tracing::info!("history component started"); + Ok(()) + } + + async fn handle_event(&mut self, _event: &DaemonEvent) -> Result<()> { + // History component produces events but doesn't need to react to them + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + tracing::info!("history component stopped"); + Ok(()) + } +} + +/// The gRPC service implementation. +/// +/// This is a thin wrapper that delegates to the component's shared state. +pub struct HistoryGrpcService { + inner: Arc<HistoryComponentInner>, +} + +fn history_to_tail_reply(kind: HistoryEventKind, history: History) -> TailHistoryReply { + TailHistoryReply { + kind: kind as i32, + history: Some(HistoryEntry { + timestamp: history.timestamp.unix_timestamp_nanos() as u64, + id: history.id.0, + command: history.command, + cwd: history.cwd, + session: history.session, + hostname: history.hostname, + author: history.author, + intent: history.intent.unwrap_or_default(), + exit: history.exit, + duration: history.duration, + }), + } +} + +#[tonic::async_trait] +impl HistorySvc for HistoryGrpcService { + type TailHistoryStream = Pin<Box<dyn Stream<Item = Result<TailHistoryReply, Status>> + Send>>; + + #[instrument(skip_all, level = Level::INFO)] + async fn start_history( + &self, + request: Request<StartHistoryRequest>, + ) -> Result<Response<StartHistoryReply>, Status> { + let req = request.into_inner(); + + let timestamp = + OffsetDateTime::from_unix_timestamp_nanos(req.timestamp as i128).map_err(|_| { + Status::invalid_argument( + "failed to parse timestamp as unix time (expected nanos since epoch)", + ) + })?; + + let h: History = History::daemon() + .timestamp(timestamp) + .command(req.command) + .cwd(req.cwd) + .session(req.session) + .hostname(req.hostname) + .author(req.author) + .intent(req.intent) + .build() + .into(); + + // Emit the event + if let Some(handle) = self.inner.handle.read().await.as_ref() { + handle.emit(DaemonEvent::HistoryStarted(h.clone())); + } + + let id = h.id.clone(); + tracing::info!(id = id.to_string(), "start history"); + self.inner.running.insert(id.clone(), h); + + let reply = StartHistoryReply { + id: id.to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + protocol: DAEMON_PROTOCOL_VERSION, + }; + + Ok(Response::new(reply)) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn end_history( + &self, + request: Request<EndHistoryRequest>, + ) -> Result<Response<EndHistoryReply>, Status> { + let req = request.into_inner(); + let id = HistoryId(req.id); + + if let Some((_, mut history)) = self.inner.running.remove(&id) { + history.exit = req.exit; + history.duration = match req.duration { + 0 => i64::try_from( + (OffsetDateTime::now_utc() - history.timestamp).whole_nanoseconds(), + ) + .expect("failed to convert calculated duration to i64"), + value => i64::try_from(value).expect("failed to get i64 duration"), + }; + + // Get the handle and store to save the history + let handle_guard = self.inner.handle.read().await; + let handle = handle_guard + .as_ref() + .ok_or_else(|| Status::internal("component not initialized"))?; + + let store_guard = self.inner.history_store.read().await; + let history_store = store_guard + .as_ref() + .ok_or_else(|| Status::internal("component not initialized"))?; + + // Save to database + handle + .history_db() + .save(&history) + .await + .map_err(|e| Status::internal(format!("failed to write to db: {e:?}")))?; + + tracing::info!( + id = id.0.to_string(), + duration = history.duration, + "end history" + ); + + // Push to record store + let (record_id, idx) = history_store + .push(history.clone()) + .await + .map_err(|e| Status::internal(format!("failed to push record to store: {e:?}")))?; + + // Emit the event + handle.emit(DaemonEvent::HistoryEnded(history)); + + let reply = EndHistoryReply { + id: record_id.0.to_string(), + idx, + version: env!("CARGO_PKG_VERSION").to_string(), + protocol: DAEMON_PROTOCOL_VERSION, + }; + + return Ok(Response::new(reply)); + } + + Err(Status::not_found(format!( + "could not find history with id: {id}" + ))) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn tail_history( + &self, + _request: Request<TailHistoryRequest>, + ) -> Result<Response<Self::TailHistoryStream>, Status> { + let handle_guard = self.inner.handle.read().await; + let handle = handle_guard + .as_ref() + .cloned() + .ok_or_else(|| Status::internal("component not initialized"))?; + + let mut rx = handle.subscribe(); + let (tx, out_rx) = tokio::sync::mpsc::channel::<Result<TailHistoryReply, Status>>(128); + + tokio::spawn(async move { + loop { + let event = match rx.recv().await { + Ok(event) => event, + Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { + let _ = tx + .send(Err(Status::resource_exhausted(format!( + "tail stream lagged behind and dropped {skipped} events" + )))) + .await; + break; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + }; + + let reply = match event { + DaemonEvent::HistoryStarted(history) => { + Some(history_to_tail_reply(HistoryEventKind::Started, history)) + } + DaemonEvent::HistoryEnded(history) => { + Some(history_to_tail_reply(HistoryEventKind::Ended, history)) + } + _ => None, + }; + + if let Some(reply) = reply + && tx.send(Ok(reply)).await.is_err() + { + break; + } + } + }); + + let stream = tokio_stream::wrappers::ReceiverStream::new(out_rx); + Ok(Response::new(Box::pin(stream))) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn status( + &self, + _request: Request<StatusRequest>, + ) -> Result<Response<StatusReply>, Status> { + let reply = StatusReply { + healthy: true, + version: env!("CARGO_PKG_VERSION").to_string(), + pid: std::process::id(), + protocol: DAEMON_PROTOCOL_VERSION, + }; + + Ok(Response::new(reply)) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn shutdown( + &self, + _request: Request<ShutdownRequest>, + ) -> Result<Response<ShutdownReply>, Status> { + // Use the daemon handle to request shutdown + if let Some(handle) = self.inner.handle.read().await.as_ref() { + handle.shutdown(); + } + Ok(Response::new(ShutdownReply { accepted: true })) + } +} diff --git a/crates/turtle/src/atuin_daemon/components/mod.rs b/crates/turtle/src/atuin_daemon/components/mod.rs new file mode 100644 index 00000000..447e31df --- /dev/null +++ b/crates/turtle/src/atuin_daemon/components/mod.rs @@ -0,0 +1,25 @@ +//! Daemon components. +//! +//! Components are the building blocks of the daemon. Each component handles +//! a specific domain and can: +//! +//! - Expose gRPC services +//! - React to events +//! - Spawn background tasks +//! +//! Available components: +//! +//! - [`history::HistoryComponent`]: Command history lifecycle management +//! - [`search::SearchComponent`]: Fuzzy search over history +//! - [`semantic::SemanticComponent`]: In-memory semantic command captures +//! - [`sync::SyncComponent`]: Cloud sync + +pub mod history; +pub mod search; +pub mod semantic; +pub mod sync; + +pub use history::HistoryComponent; +pub use search::SearchComponent; +pub use semantic::SemanticComponent; +pub use sync::SyncComponent; diff --git a/crates/turtle/src/atuin_daemon/components/search.rs b/crates/turtle/src/atuin_daemon/components/search.rs new file mode 100644 index 00000000..85191cff --- /dev/null +++ b/crates/turtle/src/atuin_daemon/components/search.rs @@ -0,0 +1,413 @@ +//! Search component. +//! +//! Provides fuzzy search over command history using the Nucleo search library +//! with frecency-based ranking and dynamic filtering. + +use std::{pin::Pin, sync::Arc}; + +use crate::atuin_client::database::Database; +use eyre::Result; +use tokio::sync::RwLock; +use tokio_stream::Stream; +use tonic::{Request, Response, Status, Streaming}; +use tracing::{Level, debug, info, instrument, span, trace}; +use uuid::Uuid; + +use crate::atuin_daemon::{ + daemon::{Component, DaemonHandle}, + events::DaemonEvent, + search::{ + FilterMode, IndexFilterMode, QueryContext, SearchIndex, SearchRequest, SearchResponse, + search_server::{Search as SearchSvc, SearchServer}, + }, +}; + +const PAGE_SIZE: usize = 5000; +const RESULTS_LIMIT: u32 = 200; +/// How often to rebuild the frecency map (in seconds). +const FRECENCY_REFRESH_INTERVAL_SECS: u64 = 60; + +/// Search component - provides fuzzy search over command history. +/// +/// This component: +/// - Maintains a deduplicated search index with frecency ranking +/// - Loads history from the database on startup +/// - Updates the index when history events occur +/// - Provides the Search gRPC service +pub struct SearchComponent { + index: Arc<RwLock<SearchIndex>>, + handle: tokio::sync::RwLock<Option<DaemonHandle>>, + loader_handle: Option<tokio::task::JoinHandle<()>>, + frecency_handle: Option<tokio::task::JoinHandle<()>>, +} + +impl SearchComponent { + /// Create a new search component. + pub fn new() -> Self { + Self { + index: Arc::new(RwLock::new(SearchIndex::new())), + handle: tokio::sync::RwLock::new(None), + loader_handle: None, + frecency_handle: None, + } + } + + /// Get the gRPC service for this component. + pub fn grpc_service(&self) -> SearchServer<SearchGrpcService> { + SearchServer::new(SearchGrpcService { + index: self.index.clone(), + }) + } + + /// Rebuild the entire search index from the database. + async fn rebuild_index(&self) -> Result<()> { + let handle_guard = self.handle.read().await; + let handle = handle_guard + .as_ref() + .ok_or_else(|| eyre::eyre!("component not initialized"))?; + + info!("Rebuilding search index from database"); + + // Create a new index + let new_index = SearchIndex::new(); + + // Load all history into the new index + let db = handle.history_db().clone(); + let mut pager = db.all_paged(PAGE_SIZE, false, true); + loop { + match pager.next().await { + Ok(Some(histories)) => { + info!( + "Loading {} history entries into search index", + histories.len() + ); + new_index.add_histories(&histories); + } + Ok(None) => break, + Err(e) => { + tracing::error!("Failed to load history during rebuild: {}", e); + break; + } + } + } + + info!( + "Search index rebuild complete; {} unique commands", + new_index.command_count() + ); + + // Replace the old index with the new one + *self.index.write().await = new_index; + Ok(()) + } +} + +impl Default for SearchComponent { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl Component for SearchComponent { + fn name(&self) -> &'static str { + "search" + } + + async fn start(&mut self, handle: DaemonHandle) -> Result<()> { + *self.handle.write().await = Some(handle.clone()); + + // Spawn background task to load history into index + let index = self.index.clone(); + let db = handle.history_db().clone(); + let handle_for_loader = handle.clone(); + + self.loader_handle = Some(tokio::spawn(async move { + info!( + "Loading history into search index; page size = {}", + PAGE_SIZE + ); + let mut pager = db.all_paged(PAGE_SIZE, false, true); + loop { + match pager.next().await { + Ok(Some(histories)) => { + info!( + "Loading {} history entries into search index", + histories.len() + ); + index.read().await.add_histories(&histories); + } + Ok(None) => { + info!( + "Initial history load complete; {} unique commands indexed", + index.read().await.command_count() + ); + // Build initial frecency map with current settings + let settings = handle_for_loader.settings().await; + index.read().await.rebuild_frecency(&settings.search).await; + info!("Initial frecency map built"); + break; + } + Err(e) => { + tracing::error!("Failed to load history: {}", e); + break; + } + } + } + })); + + // Spawn background task to periodically refresh frecency + let index_for_frecency = self.index.clone(); + let handle_for_frecency = handle.clone(); + self.frecency_handle = Some(tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs( + FRECENCY_REFRESH_INTERVAL_SECS, + )); + loop { + interval.tick().await; + trace!("Refreshing frecency map"); + let settings = handle_for_frecency.settings().await; + index_for_frecency + .read() + .await + .rebuild_frecency(&settings.search) + .await; + } + })); + + tracing::info!("search component started"); + Ok(()) + } + + async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { + match event { + DaemonEvent::RecordsAdded(records) => { + debug!( + count = records.len(), + "Processing added records for search index" + ); + + let handle_guard = self.handle.read().await; + if let Some(handle) = handle_guard.as_ref() { + let histories: Vec<_> = handle + .history_db() + .query_history( + format!( + "select * from history where id in ({})", + records + .iter() + .map(|record| record.0.to_string()) + .collect::<Vec<_>>() + .join(",") + ) + .as_str(), + ) + .await + .unwrap_or_default(); + + span!(Level::TRACE, "inject_records", count = histories.len()) + .in_scope(async || { + self.index.read().await.add_histories(&histories); + }) + .await; + } + } + DaemonEvent::HistoryStarted(history) => { + debug!(id = %history.id, command = %history.command, "History started (no index action)"); + } + DaemonEvent::HistoryEnded(history) => { + span!(Level::TRACE, "inject_history_ended") + .in_scope(async || { + self.index.read().await.add_history(history); + }) + .await; + } + DaemonEvent::HistoryPruned | DaemonEvent::HistoryRebuilt => { + info!("History store pruned or rebuilt, rebuilding search index"); + if let Err(e) = self.rebuild_index().await { + tracing::error!("Failed to rebuild search index: {}", e); + } + } + DaemonEvent::HistoryDeleted { ids } => { + info!( + count = ids.len(), + "History deleted, rebuilding search index" + ); + // For now, just rebuild the entire index. A more efficient implementation + // would remove specific items from the index. + if let Err(e) = self.rebuild_index().await { + tracing::error!("Failed to rebuild search index: {}", e); + } + } + DaemonEvent::SettingsReloaded => { + info!("Settings reloaded, rebuilding frecency map with new multipliers"); + let handle_guard = self.handle.read().await; + if let Some(handle) = handle_guard.as_ref() { + let settings = handle.settings().await; + self.index + .read() + .await + .rebuild_frecency(&settings.search) + .await; + } + } + // Events we don't care about + DaemonEvent::SyncCompleted { .. } + | DaemonEvent::SyncFailed { .. } + | DaemonEvent::ForceSync + | DaemonEvent::ShutdownRequested => {} + } + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + if let Some(handle) = self.loader_handle.take() { + handle.abort(); + } + if let Some(handle) = self.frecency_handle.take() { + handle.abort(); + } + tracing::info!("search component stopped"); + Ok(()) + } +} + +/// The gRPC service implementation. +pub struct SearchGrpcService { + index: Arc<RwLock<SearchIndex>>, +} + +#[tonic::async_trait] +impl SearchSvc for SearchGrpcService { + type SearchStream = Pin<Box<dyn Stream<Item = Result<SearchResponse, Status>> + Send>>; + + #[instrument(skip_all, level = Level::TRACE, name = "search_rpc")] + async fn search( + &self, + request: Request<Streaming<SearchRequest>>, + ) -> Result<Response<Self::SearchStream>, Status> { + let mut in_stream = request.into_inner(); + let index = self.index.clone(); + + // Create output channel + let (tx, rx) = tokio::sync::mpsc::channel::<Result<SearchResponse, Status>>(128); + + // Spawn task to handle incoming requests and send responses + tokio::spawn(async move { + while let Some(req) = in_stream.message().await.transpose() { + match req { + Ok(search_req) => { + let query = search_req.query; + let query_id = search_req.query_id; + let filter_mode: FilterMode = search_req + .filter_mode + .try_into() + .unwrap_or(FilterMode::Global); + let proto_context = search_req.context; + + debug!( + "search request: query = {}, query_id = {}, filter_mode = {}, context = {:?}", + query, + query_id, + filter_mode.as_str_name(), + proto_context + ); + + // Convert proto FilterMode + context to IndexFilterMode + let index_filter = convert_filter_mode(filter_mode, &proto_context); + + // Build QueryContext from proto context + let query_context = proto_context + .map(|ctx| QueryContext { + cwd: Some(with_trailing_slash(&ctx.cwd)), + git_root: ctx.git_root.map(|s| with_trailing_slash(&s)), + hostname: Some(ctx.hostname), + session_id: Some(ctx.session_id), + }) + .unwrap_or_default(); + + // Perform the search + let history_ids = + span!(Level::TRACE, "daemon_search_query", %query, query_id) + .in_scope(|| async { + let index = index.read().await; + index + .search(&query, index_filter, &query_context, RESULTS_LIMIT) + .await + }) + .await; + + // Convert history IDs to bytes + let ids: Vec<Vec<u8>> = history_ids + .iter() + .filter_map(|id| { + Uuid::parse_str(id) + .ok() + .map(|uuid| uuid.as_bytes().to_vec()) + }) + .collect(); + + if tx.send(Ok(SearchResponse { query_id, ids })).await.is_err() { + break; // Client disconnected + } + } + Err(e) => { + let _ = tx.send(Err(e)).await; + break; + } + } + } + }); + + // Convert receiver to stream + let out_stream = tokio_stream::wrappers::ReceiverStream::new(rx); + Ok(Response::new(Box::pin(out_stream))) + } +} + +/// Convert proto FilterMode and context to IndexFilterMode. +fn convert_filter_mode( + mode: FilterMode, + context: &Option<crate::atuin_daemon::search::SearchContext>, +) -> IndexFilterMode { + match (mode, context) { + (FilterMode::Global, _) => IndexFilterMode::Global, + (FilterMode::Directory, Some(ctx)) => { + IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd)) + } + (FilterMode::Workspace, Some(ctx)) => { + if let Some(ref git_root) = ctx.git_root { + IndexFilterMode::Workspace(with_trailing_slash(git_root)) + } else { + // Fall back to directory if no git root + IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd)) + } + } + (FilterMode::Host, Some(ctx)) => IndexFilterMode::Host(ctx.hostname.clone()), + (FilterMode::Session, Some(ctx)) => IndexFilterMode::Session(ctx.session_id.clone()), + (FilterMode::SessionPreload, Some(ctx)) => { + // SessionPreload is similar to Session - filter by session + IndexFilterMode::Session(ctx.session_id.clone()) + } + // If no context provided, fall back to global + _ => IndexFilterMode::Global, + } +} + +#[cfg(windows)] +pub fn with_trailing_slash(s: &str) -> String { + if s.ends_with('\\') { + s.to_string() + } else { + format!("{}\\", s) + } +} + +#[cfg(not(windows))] +pub fn with_trailing_slash(s: &str) -> String { + if s.ends_with('/') { + s.to_string() + } else { + format!("{}/", s) + } +} diff --git a/crates/turtle/src/atuin_daemon/components/semantic.rs b/crates/turtle/src/atuin_daemon/components/semantic.rs new file mode 100644 index 00000000..a42fd5cb --- /dev/null +++ b/crates/turtle/src/atuin_daemon/components/semantic.rs @@ -0,0 +1,903 @@ +//! Semantic command capture component. +//! +//! This is a prototype in-memory store for completed command captures emitted +//! by atuin-pty-proxy. It keeps recent captures per Atuin session and indexes +//! them by history ID for AI tool lookup. + +use std::collections::{HashMap, VecDeque}; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use crate::atuin_client::history::{History, HistoryId}; +use eyre::Result; +use tokio::sync::Mutex; +use tonic::{Request, Response, Status, Streaming}; +use tracing::{Level, instrument}; + +use crate::atuin_daemon::{ + daemon::{Component, DaemonHandle}, + events::DaemonEvent, + semantic::{ + CommandCapture, CommandOutputReply, CommandOutputRequest, OutputLine, RecordCommandsReply, + semantic_server::{Semantic as SemanticSvc, SemanticServer}, + }, +}; + +const MAX_SESSIONS: usize = 20; +const MAX_COMMANDS_PER_SESSION: usize = 128; +const MAX_BYTES_PER_SESSION: usize = 32 * 1024 * 1024; +const MAX_PENDING_HISTORIES: usize = 128; + +/// Stores completed command captures and associates them with history events. +pub struct SemanticComponent { + inner: Arc<SemanticComponentInner>, +} + +struct SemanticComponentInner { + state: Mutex<SemanticState>, +} + +#[derive(Default)] +struct SemanticState { + sessions: HashMap<SessionId, SessionCaptures>, + session_lru: VecDeque<SessionId>, + history_index: HashMap<HistoryId, CaptureRef>, + pending_histories: VecDeque<History>, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct SessionId(String); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct CaptureId(u64); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct CaptureRef { + session_id: SessionId, + capture_id: CaptureId, +} + +#[derive(Default)] +struct SessionCaptures { + next_id: u64, + records: VecDeque<StoredCapture>, + output_bytes: usize, +} + +struct StoredCapture { + id: CaptureId, + history_id: HistoryId, + output_bytes: usize, + record: SemanticCommandRecord, +} + +struct EvictedCapture { + history_id: HistoryId, + capture_id: CaptureId, +} + +#[derive(Debug, Clone)] +struct SemanticCommandRecord { + capture: CommandCapture, + history: Option<History>, +} + +impl SemanticComponent { + pub fn new() -> Self { + Self { + inner: Arc::new(SemanticComponentInner { + state: Mutex::new(SemanticState::default()), + }), + } + } + + pub fn grpc_service(&self) -> SemanticServer<SemanticGrpcService> { + SemanticServer::new(SemanticGrpcService { + inner: self.inner.clone(), + }) + } +} + +impl Default for SemanticComponent { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl Component for SemanticComponent { + fn name(&self) -> &'static str { + "semantic" + } + + async fn start(&mut self, _handle: DaemonHandle) -> Result<()> { + tracing::info!("semantic component started"); + Ok(()) + } + + async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { + if let DaemonEvent::HistoryEnded(history) = event { + self.inner.record_history(history.clone()).await; + } + + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + let state = self.inner.state.lock().await; + tracing::info!( + sessions = state.sessions.len(), + records = state.record_count(), + indexed_histories = state.history_index.len(), + pending_histories = state.pending_histories.len(), + "semantic component stopped" + ); + Ok(()) + } +} + +impl SemanticComponentInner { + async fn record_capture(&self, capture: CommandCapture) -> bool { + let mut state = self.state.lock().await; + state.record_capture(capture) + } + + async fn record_history(&self, history: History) { + let mut state = self.state.lock().await; + state.record_history(history); + } + + async fn command_output(&self, request: &CommandOutputRequest) -> CommandOutputReply { + let mut state = self.state.lock().await; + state.command_output(request) + } +} + +impl SemanticState { + fn record_capture(&mut self, mut capture: CommandCapture) -> bool { + let Some(history_id) = history_id_from_str(capture.history_id.as_deref()) else { + tracing::debug!( + command_bytes = capture.command.len(), + prompt_bytes = capture.prompt.len(), + output_bytes = capture.output.len(), + output_truncated = capture.output_truncated, + "dropping semantic command capture without history id" + ); + return false; + }; + + let history = take_pending_history(&mut self.pending_histories, &history_id); + let Some(session_id) = capture + .session_id + .as_deref() + .and_then(|session_id| SessionId::try_from(session_id).ok()) + .or_else(|| { + history + .as_ref() + .and_then(|history| SessionId::try_from(history.session.as_str()).ok()) + }) + else { + tracing::debug!( + history_id = %history_id, + command_bytes = capture.command.len(), + prompt_bytes = capture.prompt.len(), + output_bytes = capture.output.len(), + output_truncated = capture.output_truncated, + "dropping semantic command capture without session id" + ); + return false; + }; + + capture.history_id = Some(history_id.to_string()); + capture.session_id = Some(session_id.to_string()); + if capture.output_observed_bytes == 0 { + capture.output_observed_bytes = capture.output.len() as u64; + } + + let record = SemanticCommandRecord { capture, history }; + log_record(&record, "recorded semantic command capture"); + self.push_record(session_id, history_id, record); + true + } + + fn record_history(&mut self, history: History) { + let history_id = history.id.clone(); + + if let Some(capture_ref) = self.history_index.get(&history_id).cloned() { + if let Some(stored) = self.stored_capture_mut(&capture_ref) { + stored.record.history = Some(history); + log_record( + &stored.record, + "associated semantic command capture with history", + ); + return; + } + + self.history_index.remove(&history_id); + } + + tracing::debug!( + id = %history.id, + command_bytes = history.command.len(), + "history ended before semantic capture arrived" + ); + push_pending_history(&mut self.pending_histories, history); + } + + fn command_output(&mut self, request: &CommandOutputRequest) -> CommandOutputReply { + let Some(history_id) = history_id_from_str(Some(&request.history_id)) else { + return command_output_not_found(); + }; + let Some(capture_ref) = self.history_index.get(&history_id).cloned() else { + return command_output_not_found(); + }; + + let Some(reply) = self.command_output_for_ref(&capture_ref, &request.ranges) else { + self.history_index.remove(&history_id); + return command_output_not_found(); + }; + + self.touch_session(&capture_ref.session_id); + reply + } + + fn command_output_for_ref( + &self, + capture_ref: &CaptureRef, + ranges: &[crate::atuin_daemon::semantic::OutputRange], + ) -> Option<CommandOutputReply> { + let stored = self + .sessions + .get(&capture_ref.session_id)? + .stored_capture(capture_ref.capture_id)?; + let output = &stored.record.capture.output; + let output_observed_bytes = stored + .record + .capture + .output_observed_bytes + .max(output.len() as u64); + + Some(CommandOutputReply { + found: true, + output: String::new(), + total_bytes: output.len() as u64, + total_lines: output.lines().count() as u64, + lines: select_output_ranges(output, ranges), + output_truncated: stored.record.capture.output_truncated, + output_observed_bytes, + }) + } + + fn push_record( + &mut self, + session_id: SessionId, + history_id: HistoryId, + record: SemanticCommandRecord, + ) { + self.touch_session(&session_id); + + let (capture_id, evicted) = { + let session = self.sessions.entry(session_id.clone()).or_default(); + session.push(history_id.clone(), record) + }; + + let capture_ref = CaptureRef { + session_id: session_id.clone(), + capture_id, + }; + self.history_index.insert(history_id, capture_ref); + + for evicted in evicted { + self.remove_history_index_if_matches( + &session_id, + &evicted.history_id, + evicted.capture_id, + ); + } + + self.expire_lru_sessions(); + } + + fn touch_session(&mut self, session_id: &SessionId) { + if let Some(index) = self.session_lru.iter().position(|id| id == session_id) { + self.session_lru.remove(index); + } + self.session_lru.push_back(session_id.clone()); + } + + fn expire_lru_sessions(&mut self) { + while self.session_lru.len() > MAX_SESSIONS { + let Some(session_id) = self.session_lru.pop_front() else { + break; + }; + let Some(session) = self.sessions.remove(&session_id) else { + continue; + }; + + for stored in session.records { + self.remove_history_index_if_matches(&session_id, &stored.history_id, stored.id); + } + } + } + + fn remove_history_index_if_matches( + &mut self, + session_id: &SessionId, + history_id: &HistoryId, + capture_id: CaptureId, + ) { + if self + .history_index + .get(history_id) + .is_some_and(|capture_ref| { + &capture_ref.session_id == session_id && capture_ref.capture_id == capture_id + }) + { + self.history_index.remove(history_id); + } + } + + fn stored_capture_mut(&mut self, capture_ref: &CaptureRef) -> Option<&mut StoredCapture> { + self.sessions + .get_mut(&capture_ref.session_id)? + .stored_capture_mut(capture_ref.capture_id) + } + + fn record_count(&self) -> usize { + self.sessions + .values() + .map(|session| session.records.len()) + .sum() + } +} + +impl SessionCaptures { + fn push( + &mut self, + history_id: HistoryId, + record: SemanticCommandRecord, + ) -> (CaptureId, Vec<EvictedCapture>) { + self.push_with_limits( + history_id, + record, + MAX_COMMANDS_PER_SESSION, + MAX_BYTES_PER_SESSION, + ) + } + + fn push_with_limits( + &mut self, + history_id: HistoryId, + record: SemanticCommandRecord, + max_commands: usize, + max_output_bytes: usize, + ) -> (CaptureId, Vec<EvictedCapture>) { + let capture_id = CaptureId(self.next_id); + self.next_id = self.next_id.saturating_add(1); + let output_bytes = record.capture.output.len(); + self.output_bytes = self.output_bytes.saturating_add(output_bytes); + self.records.push_back(StoredCapture { + id: capture_id, + history_id, + output_bytes, + record, + }); + + ( + capture_id, + self.evict_to_limits(max_commands, max_output_bytes), + ) + } + + fn evict_to_limits( + &mut self, + max_commands: usize, + max_output_bytes: usize, + ) -> Vec<EvictedCapture> { + let mut evicted = Vec::new(); + while self.records.len() > max_commands || self.output_bytes > max_output_bytes { + let Some(record) = self.records.pop_front() else { + break; + }; + self.output_bytes = self.output_bytes.saturating_sub(record.output_bytes); + evicted.push(EvictedCapture { + history_id: record.history_id, + capture_id: record.id, + }); + } + evicted + } + + fn stored_capture(&self, capture_id: CaptureId) -> Option<&StoredCapture> { + self.records.iter().find(|record| record.id == capture_id) + } + + fn stored_capture_mut(&mut self, capture_id: CaptureId) -> Option<&mut StoredCapture> { + self.records + .iter_mut() + .find(|record| record.id == capture_id) + } +} + +impl TryFrom<&str> for SessionId { + type Error = (); + + fn try_from(value: &str) -> std::result::Result<Self, Self::Error> { + let value = value.trim(); + if value.is_empty() { + return Err(()); + } + + Ok(Self(value.to_string())) + } +} + +impl TryFrom<String> for SessionId { + type Error = (); + + fn try_from(value: String) -> std::result::Result<Self, Self::Error> { + Self::try_from(value.as_str()) + } +} + +impl AsRef<str> for SessionId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl Display for SessionId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +pub struct SemanticGrpcService { + inner: Arc<SemanticComponentInner>, +} + +#[tonic::async_trait] +impl SemanticSvc for SemanticGrpcService { + #[instrument(skip_all, level = Level::INFO)] + async fn record_commands( + &self, + request: Request<Streaming<CommandCapture>>, + ) -> Result<Response<RecordCommandsReply>, Status> { + let mut stream = request.into_inner(); + let mut accepted = 0_u64; + + while let Some(capture) = stream.message().await? { + if self.inner.record_capture(capture).await { + accepted += 1; + } + } + + Ok(Response::new(RecordCommandsReply { accepted })) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn command_output( + &self, + request: Request<CommandOutputRequest>, + ) -> Result<Response<CommandOutputReply>, Status> { + let request = request.into_inner(); + if request.history_id.trim().is_empty() { + return Err(Status::invalid_argument("history_id is required")); + } + + Ok(Response::new(self.inner.command_output(&request).await)) + } +} + +fn history_id_from_str(value: Option<&str>) -> Option<HistoryId> { + let value = value?.trim(); + (!value.is_empty()).then(|| HistoryId(value.to_string())) +} + +fn take_pending_history( + histories: &mut VecDeque<History>, + history_id: &HistoryId, +) -> Option<History> { + let index = histories + .iter() + .position(|history| &history.id == history_id)?; + histories.remove(index) +} + +fn push_pending_history(histories: &mut VecDeque<History>, history: History) { + if let Some(index) = histories + .iter() + .position(|pending| pending.id == history.id) + { + histories.remove(index); + } + + histories.push_back(history); + trim_front(histories, MAX_PENDING_HISTORIES); +} + +fn trim_front<T>(records: &mut VecDeque<T>, max_len: usize) { + while records.len() > max_len { + records.pop_front(); + } +} + +fn command_output_not_found() -> CommandOutputReply { + CommandOutputReply { + found: false, + output: String::new(), + total_bytes: 0, + total_lines: 0, + lines: Vec::new(), + output_truncated: false, + output_observed_bytes: 0, + } +} + +fn select_output_ranges( + output: &str, + ranges: &[crate::atuin_daemon::semantic::OutputRange], +) -> Vec<OutputLine> { + let lines: Vec<&str> = output.lines().collect(); + if lines.is_empty() { + return Vec::new(); + } + + let ranges = if ranges.is_empty() { + vec![crate::atuin_daemon::semantic::OutputRange { start: 0, end: 999 }] + } else { + ranges.to_vec() + }; + + let mut ranges = ranges + .into_iter() + .filter_map(|range| normalize_line_range(range.start, range.end, lines.len())) + .collect::<Vec<_>>(); + ranges.sort_unstable_by_key(|(start, _)| *start); + + let mut merged: Vec<(usize, usize)> = Vec::new(); + for (start, end) in ranges { + match merged.last_mut() { + Some((_, merged_end)) if start <= merged_end.saturating_add(1) => { + *merged_end = (*merged_end).max(end); + } + _ => merged.push((start, end)), + } + } + + merged + .into_iter() + .flat_map(|(start, end)| { + lines[start..=end] + .iter() + .enumerate() + .map(move |(offset, line)| OutputLine { + line_number: (start + offset + 1) as u64, + content: (*line).to_string(), + }) + }) + .collect() +} + +fn normalize_line_range(start: i64, end: i64, line_count: usize) -> Option<(usize, usize)> { + let line_count = i64::try_from(line_count).ok()?; + let start = if start < 0 { line_count + start } else { start }; + let end = if end < 0 { line_count + end } else { end }; + + if end < 0 || start >= line_count { + return None; + } + + let start = start.max(0); + let end = end.min(line_count - 1); + + (start <= end).then_some((start as usize, end as usize)) +} + +fn log_record(record: &SemanticCommandRecord, message: &'static str) { + let history_id = record.capture.history_id.as_deref().unwrap_or("<missing>"); + let associated_history_id = record + .history + .as_ref() + .map(|history| history.id.to_string()); + let exit = record.history.as_ref().map(|history| history.exit); + let duration = record.history.as_ref().map(|history| history.duration); + let author = record + .history + .as_ref() + .map(|history| history.author.as_str()); + let session_id = record.capture.session_id.as_deref(); + + tracing::debug!( + history_id = %history_id, + associated_history_id = ?associated_history_id, + session_id = ?session_id, + command_bytes = record.capture.command.len(), + prompt_bytes = record.capture.prompt.len(), + output_bytes = record.capture.output.len(), + output_truncated = record.capture.output_truncated, + output_observed_bytes = record.capture.output_observed_bytes, + capture_exit_code = ?record.capture.exit_code, + history_exit = ?exit, + duration = ?duration, + author = ?author, + "{message}" + ); +} + +#[cfg(test)] +mod tests { + use super::*; + use time::OffsetDateTime; + + fn history(id: &str, session: &str, command: &str) -> History { + History { + id: HistoryId(id.to_string()), + timestamp: OffsetDateTime::UNIX_EPOCH, + duration: 0, + exit: 0, + command: command.to_string(), + cwd: String::new(), + session: session.to_string(), + hostname: String::new(), + author: String::new(), + intent: None, + deleted_at: None, + } + } + + fn capture(history_id: Option<&str>, session_id: Option<&str>, output: &str) -> CommandCapture { + CommandCapture { + prompt: String::new(), + command: String::new(), + output: output.to_string(), + exit_code: None, + history_id: history_id.map(str::to_string), + session_id: session_id.map(str::to_string), + output_truncated: false, + output_observed_bytes: output.len() as u64, + } + } + + fn command_output(state: &mut SemanticState, history_id: &str) -> CommandOutputReply { + state.command_output(&CommandOutputRequest { + history_id: history_id.to_string(), + ranges: Vec::new(), + }) + } + + fn output_line(line_number: u64, content: &str) -> OutputLine { + OutputLine { + line_number, + content: content.to_string(), + } + } + + #[test] + fn drops_capture_without_history_id() { + let mut state = SemanticState::default(); + + assert!(!state.record_capture(capture(None, Some("session-1"), "output"))); + assert!(!command_output(&mut state, "id-1").found); + assert_eq!(state.record_count(), 0); + } + + #[test] + fn stores_capture_by_session_and_history_id() { + let mut state = SemanticState::default(); + + assert!(state.record_capture(capture(Some("id-1"), Some("session-1"), "output"))); + + let reply = command_output(&mut state, "id-1"); + assert!(reply.found); + assert_eq!(reply.total_bytes, 6); + assert_eq!(reply.output_observed_bytes, 6); + assert_eq!(reply.lines, vec![output_line(1, "output")]); + } + + #[test] + fn uses_pending_history_session_when_capture_session_is_missing() { + let mut state = SemanticState::default(); + + state.record_history(history("id-1", "session-from-history", "cargo test")); + assert!(state.record_capture(capture(Some("id-1"), None, "output"))); + + assert!( + state + .sessions + .contains_key(&SessionId("session-from-history".to_string())) + ); + assert!(command_output(&mut state, "id-1").found); + } + + #[test] + fn associates_history_by_id_after_capture_arrives() { + let mut state = SemanticState::default(); + + assert!(state.record_capture(capture(Some("id-1"), Some("session-1"), "output"))); + state.record_history(history("id-1", "session-1", "different command")); + + let capture_ref = state + .history_index + .get(&HistoryId("id-1".to_string())) + .unwrap(); + let stored = state + .sessions + .get(&capture_ref.session_id) + .unwrap() + .stored_capture(capture_ref.capture_id) + .unwrap(); + assert!(stored.record.history.is_some()); + } + + #[test] + fn evicts_oldest_command_when_session_ring_is_full() { + let mut state = SemanticState::default(); + + for index in 0..=MAX_COMMANDS_PER_SESSION { + assert!(state.record_capture(capture( + Some(&format!("id-{index}")), + Some("session-1"), + "output", + ))); + } + + assert!(!command_output(&mut state, "id-0").found); + assert!(command_output(&mut state, &format!("id-{MAX_COMMANDS_PER_SESSION}")).found); + assert_eq!(state.record_count(), MAX_COMMANDS_PER_SESSION); + } + + #[test] + fn evicts_oldest_session_after_lru_limit() { + let mut state = SemanticState::default(); + + for index in 0..MAX_SESSIONS { + assert!(state.record_capture(capture( + Some(&format!("id-{index}")), + Some(&format!("session-{index}")), + "output", + ))); + } + assert!(command_output(&mut state, "id-0").found); + + assert!(state.record_capture(capture(Some("new-id"), Some("new-session"), "output",))); + + assert!(command_output(&mut state, "id-0").found); + assert!(!command_output(&mut state, "id-1").found); + assert!(command_output(&mut state, "new-id").found); + assert_eq!(state.sessions.len(), MAX_SESSIONS); + } + + #[test] + fn evicts_by_session_byte_limit() { + let mut session = SessionCaptures::default(); + let first_output = "x".repeat(10); + let second_output = "y"; + let (_, evicted_first) = session.push_with_limits( + HistoryId("first".to_string()), + SemanticCommandRecord { + capture: capture(Some("first"), Some("session-1"), &first_output), + history: None, + }, + MAX_COMMANDS_PER_SESSION, + 10, + ); + assert!(evicted_first.is_empty()); + + let (_, evicted_second) = session.push_with_limits( + HistoryId("second".to_string()), + SemanticCommandRecord { + capture: capture(Some("second"), Some("session-1"), second_output), + history: None, + }, + MAX_COMMANDS_PER_SESSION, + 10, + ); + + assert_eq!(evicted_second.len(), 1); + assert_eq!(evicted_second[0].history_id, HistoryId("first".to_string())); + assert_eq!(session.records.len(), 1); + assert_eq!(session.output_bytes, 1); + } + + #[test] + fn command_output_reports_truncation_metadata() { + let mut state = SemanticState::default(); + let mut capture = capture(Some("id-1"), Some("session-1"), "partial"); + capture.output_truncated = true; + capture.output_observed_bytes = 1024; + + assert!(state.record_capture(capture)); + + let reply = command_output(&mut state, "id-1"); + assert!(reply.output_truncated); + assert_eq!(reply.total_bytes, 7); + assert_eq!(reply.output_observed_bytes, 1024); + } + + #[test] + fn output_ranges_are_line_based_inclusive_and_support_negative_offsets() { + let output = "zero\none\ntwo\nthree\nfour"; + let ranges = vec![ + crate::atuin_daemon::semantic::OutputRange { start: 1, end: 2 }, + crate::atuin_daemon::semantic::OutputRange { start: -2, end: -1 }, + ]; + + assert_eq!( + select_output_ranges(output, &ranges), + vec![ + output_line(2, "one"), + output_line(3, "two"), + output_line(4, "three"), + output_line(5, "four"), + ] + ); + } + + #[test] + fn output_ranges_merge_overlaps_and_adjacent_ranges() { + let output = (0..100) + .map(|n| format!("line {n}")) + .collect::<Vec<_>>() + .join("\n"); + let ranges = vec![ + crate::atuin_daemon::semantic::OutputRange { start: 0, end: 100 }, + crate::atuin_daemon::semantic::OutputRange { + start: -100, + end: -1, + }, + ]; + + let selected = select_output_ranges(&output, &ranges); + + assert_eq!(selected.len(), 100); + assert_eq!(selected.first(), Some(&output_line(1, "line 0"))); + assert_eq!(selected.last(), Some(&output_line(100, "line 99"))); + } + + #[test] + fn output_ranges_can_leave_gaps_for_client_formatting() { + let output = "zero\none\ntwo\nthree\nfour"; + let ranges = vec![ + crate::atuin_daemon::semantic::OutputRange { start: 0, end: 1 }, + crate::atuin_daemon::semantic::OutputRange { start: 4, end: 4 }, + ]; + + assert_eq!( + select_output_ranges(output, &ranges), + vec![ + output_line(1, "zero"), + output_line(2, "one"), + output_line(5, "four"), + ] + ); + } + + #[test] + fn empty_output_ranges_default_to_first_thousand_lines() { + let output = (0..1001) + .map(|n| format!("line {n}")) + .collect::<Vec<_>>() + .join("\n"); + + let selected = select_output_ranges(&output, &[]); + + assert_eq!(selected.len(), 1000); + assert_eq!(selected.first(), Some(&output_line(1, "line 0"))); + assert_eq!(selected.last(), Some(&output_line(1000, "line 999"))); + } + + #[test] + fn output_ranges_skip_ranges_fully_outside_output() { + let output = "zero\none\ntwo"; + let ranges = vec![ + crate::atuin_daemon::semantic::OutputRange { start: 10, end: 20 }, + crate::atuin_daemon::semantic::OutputRange { + start: -20, + end: -10, + }, + ]; + + assert_eq!(select_output_ranges(output, &ranges), Vec::new()); + } +} diff --git a/crates/turtle/src/atuin_daemon/components/sync.rs b/crates/turtle/src/atuin_daemon/components/sync.rs new file mode 100644 index 00000000..c76fb71b --- /dev/null +++ b/crates/turtle/src/atuin_daemon/components/sync.rs @@ -0,0 +1,279 @@ +//! Sync component. +//! +//! Handles periodic synchronization with the Atuin cloud server. + +use std::time::Duration; + +use eyre::Result; +use rand::Rng; +use tokio::sync::mpsc; +use tokio::time::{self, MissedTickBehavior}; + +use crate::atuin_client::{history::store::HistoryStore, record::sync, settings::Settings}; + +use crate::atuin_daemon::{ + daemon::{Component, DaemonHandle}, + events::DaemonEvent, +}; + +/// Commands that can be sent to the sync task. +enum SyncCommand { + /// Trigger an immediate sync. + ForceSync, + /// Stop the sync loop. + Stop, +} + +/// Sync state - tracks whether we're in normal operation or retrying after failure. +#[derive(Clone, Copy, PartialEq, Eq)] +enum SyncState { + /// Normal operation. Periodic syncs only run if auto_sync is enabled. + Idle, + /// Retrying after a sync failure. Retries continue regardless of auto_sync + /// until the sync succeeds. + Retrying, +} + +/// Sync component - handles periodic cloud synchronization. +/// +/// This component: +/// - Runs a background sync loop on a configurable interval +/// - Implements exponential backoff on sync failures +/// - Responds to ForceSync events for immediate sync +/// - Emits SyncCompleted/SyncFailed events +pub struct SyncComponent { + task_handle: Option<tokio::task::JoinHandle<()>>, + command_tx: Option<mpsc::Sender<SyncCommand>>, +} + +impl SyncComponent { + /// Create a new sync component. + pub fn new() -> Self { + Self { + task_handle: None, + command_tx: None, + } + } +} + +impl Default for SyncComponent { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl Component for SyncComponent { + fn name(&self) -> &'static str { + "sync" + } + + async fn start(&mut self, handle: DaemonHandle) -> Result<()> { + let (cmd_tx, cmd_rx) = mpsc::channel(16); + self.command_tx = Some(cmd_tx); + + // Spawn the sync loop with its own copy of the handle + self.task_handle = Some(tokio::spawn(sync_loop(handle, cmd_rx))); + + tracing::info!("sync component started"); + Ok(()) + } + + async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { + if let DaemonEvent::ForceSync = event { + tracing::info!("force sync requested"); + if let Some(tx) = &self.command_tx { + let _ = tx.send(SyncCommand::ForceSync).await; + } + } + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + if let Some(tx) = &self.command_tx { + let _ = tx.send(SyncCommand::Stop).await; + } + if let Some(handle) = self.task_handle.take() { + // Give the task a moment to shut down gracefully + let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await; + } + tracing::info!("sync component stopped"); + Ok(()) + } +} + +/// The main sync loop. +/// +/// This runs in a spawned task and handles periodic sync as well as +/// force sync requests. +async fn sync_loop(handle: DaemonHandle, mut cmd_rx: mpsc::Receiver<SyncCommand>) { + tracing::info!("sync loop starting"); + + // Clone settings since we need them across await points + let settings = handle.settings().await.clone(); + let host_id = match Settings::host_id().await { + Ok(id) => id, + Err(e) => { + tracing::error!("failed to get host id, sync disabled: {e}"); + return; + } + }; + + // Create the stores we need + let encryption_key = *handle.encryption_key(); + let history_store = HistoryStore::new(handle.store().clone(), host_id, encryption_key); + + // Don't backoff by more than 30 mins (with a random jitter of up to 1 min) + let max_interval: f64 = 60.0 * 30.0 + rand::thread_rng().gen_range(0.0..60.0); + + let mut ticker = time::interval(time::Duration::from_secs(settings.daemon.sync_frequency)); + + // IMPORTANT: without this, if we miss ticks because a sync takes ages or is otherwise delayed, + // we may end up running a lot of syncs in a hot loop. + ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); + + let mut sync_state = SyncState::Idle; + + loop { + tokio::select! { + _ = ticker.tick() => { + let settings = handle.settings().await; + + // Skip periodic ticks if auto_sync is disabled AND we're not retrying + // a previous failure. Retries must continue regardless of auto_sync. + if !settings.auto_sync && sync_state == SyncState::Idle { + tracing::debug!("auto_sync disabled, skipping periodic sync tick"); + continue; + } + + sync_state = do_sync_tick( + &handle, + &history_store, + &mut ticker, + max_interval, + &settings, + ).await; + } + cmd = cmd_rx.recv() => { + match cmd { + Some(SyncCommand::ForceSync) => { + tracing::info!("executing force sync"); + let settings = handle.settings().await; + sync_state = do_sync_tick( + &handle, + &history_store, + &mut ticker, + max_interval, + &settings, + ).await; + } + Some(SyncCommand::Stop) | None => { + tracing::info!("sync loop stopping"); + break; + } + } + } + } + } +} + +/// Execute a single sync tick. +/// +/// Returns the new sync state: `Idle` on success, `Retrying` on failure. +async fn do_sync_tick( + handle: &DaemonHandle, + history_store: &HistoryStore, + ticker: &mut time::Interval, + max_interval: f64, + settings: &Settings, +) -> SyncState { + tracing::info!("sync tick"); + + // Check if logged in + let logged_in = match settings.logged_in().await { + Ok(v) => v, + Err(e) => { + tracing::warn!("failed to check login status, skipping sync tick: {e}"); + return SyncState::Idle; + } + }; + + if !logged_in { + tracing::debug!("not logged in, skipping sync tick"); + return SyncState::Idle; + } + + // Perform the sync + let res = sync::sync(settings, handle.store(), handle.encryption_key()).await; + + match res { + Err(e) => { + tracing::error!("sync tick failed with {e}"); + + // Emit failure event + handle.emit(DaemonEvent::SyncFailed { + error: e.to_string(), + }); + + // Exponential backoff + let mut rng = rand::thread_rng(); + let mut new_interval = ticker.period().as_secs_f64() * rng.gen_range(2.0..2.2); + + if new_interval > max_interval { + new_interval = max_interval; + } + + *ticker = time::interval_at( + tokio::time::Instant::now() + Duration::from_secs(new_interval as u64), + time::Duration::from_secs(new_interval as u64), + ); + ticker.reset_after(time::Duration::from_secs(new_interval as u64)); + ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); + + tracing::error!("backing off, next sync tick in {new_interval}"); + + SyncState::Retrying + } + Ok((uploaded_count, downloaded_records)) => { + tracing::info!( + uploaded = uploaded_count, + downloaded = downloaded_records.len(), + "sync complete" + ); + + // Build history from downloaded records + if let Err(e) = history_store + .incremental_build(handle.history_db(), &downloaded_records) + .await + { + tracing::error!("failed to build history from downloaded records: {e}"); + } + + // Emit the records added event (for search indexing) + handle.emit(DaemonEvent::RecordsAdded(downloaded_records.clone())); + + // Emit sync completed event + handle.emit(DaemonEvent::SyncCompleted { + uploaded: uploaded_count as usize, + downloaded: downloaded_records.len(), + }); + + // Reset backoff on success + if ticker.period().as_secs() != settings.daemon.sync_frequency { + *ticker = time::interval_at( + tokio::time::Instant::now() + + Duration::from_secs(settings.daemon.sync_frequency), + time::Duration::from_secs(settings.daemon.sync_frequency), + ); + ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); + } + + // Store sync time + if let Err(e) = Settings::save_sync_time().await { + tracing::error!("failed to save sync time: {e}"); + } + + SyncState::Idle + } + } +} diff --git a/crates/turtle/src/atuin_daemon/control/mod.rs b/crates/turtle/src/atuin_daemon/control/mod.rs new file mode 100644 index 00000000..afb29c57 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/control/mod.rs @@ -0,0 +1,12 @@ +//! Control module for external event injection. +//! +//! This module provides the gRPC service that allows external processes +//! (like CLI commands) to inject events into the daemon's event bus. + +mod service; + +// Include the generated proto code +tonic::include_proto!("control"); + +// Re-export the service +pub use service::ControlService; diff --git a/crates/turtle/src/atuin_daemon/control/service.rs b/crates/turtle/src/atuin_daemon/control/service.rs new file mode 100644 index 00000000..cb2ff74e --- /dev/null +++ b/crates/turtle/src/atuin_daemon/control/service.rs @@ -0,0 +1,71 @@ +//! Control service implementation. +//! +//! This gRPC service allows external processes (like CLI commands) to inject +//! events into the daemon's event bus. + +use crate::atuin_client::history::HistoryId; +use tonic::{Request, Response, Status}; +use tracing::{Level, info, instrument}; + +use super::{ + SendEventRequest, SendEventResponse, + control_server::{Control, ControlServer}, + send_event_request::Event, +}; +use crate::atuin_daemon::{daemon::DaemonHandle, events::DaemonEvent}; + +/// The Control gRPC service. +/// +/// This service is used by external processes to inject events into the daemon. +/// It's not a component - it's part of the daemon's core infrastructure. +pub struct ControlService { + handle: DaemonHandle, +} + +impl ControlService { + /// Create a new control service with the given daemon handle. + pub fn new(handle: DaemonHandle) -> Self { + Self { handle } + } + + /// Get a tonic server for this service. + pub fn into_server(self) -> ControlServer<Self> { + ControlServer::new(self) + } +} + +#[tonic::async_trait] +impl Control for ControlService { + #[instrument(skip_all, level = Level::INFO, name = "control_send_event")] + async fn send_event( + &self, + request: Request<SendEventRequest>, + ) -> Result<Response<SendEventResponse>, Status> { + let req = request.into_inner(); + + let event = req + .event + .ok_or_else(|| Status::invalid_argument("event is required"))?; + + let daemon_event = proto_event_to_daemon_event(event)?; + + info!(?daemon_event, "received control event"); + self.handle.emit(daemon_event); + + Ok(Response::new(SendEventResponse {})) + } +} + +/// Convert a proto event to a daemon event. +fn proto_event_to_daemon_event(event: Event) -> Result<DaemonEvent, Status> { + match event { + Event::HistoryPruned(_) => Ok(DaemonEvent::HistoryPruned), + Event::HistoryRebuilt(_) => Ok(DaemonEvent::HistoryRebuilt), + Event::HistoryDeleted(e) => Ok(DaemonEvent::HistoryDeleted { + ids: e.ids.into_iter().map(HistoryId).collect(), + }), + Event::ForceSync(_) => Ok(DaemonEvent::ForceSync), + Event::SettingsReloaded(_) => Ok(DaemonEvent::SettingsReloaded), + Event::Shutdown(_) => Ok(DaemonEvent::ShutdownRequested), + } +} diff --git a/crates/turtle/src/atuin_daemon/daemon.rs b/crates/turtle/src/atuin_daemon/daemon.rs new file mode 100644 index 00000000..77c0d8a5 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/daemon.rs @@ -0,0 +1,458 @@ +//! Core daemon infrastructure. +//! +//! This module provides the foundational types for building the atuin daemon: +//! +//! - [`DaemonState`]: Shared state owned by the daemon +//! - [`DaemonHandle`]: A lightweight, cloneable handle for accessing daemon state +//! - [`Component`]: A trait for implementing daemon components +//! - [`Daemon`]: The main daemon orchestrator +//! - [`DaemonBuilder`]: Builder for constructing and configuring the daemon + +use std::sync::Arc; + +use crate::atuin_client::{ + database::Sqlite as HistoryDatabase, encryption, record::sqlite_store::SqliteStore, + settings::Settings, +}; +use eyre::{Context, Result}; +use tokio::sync::{RwLock, broadcast}; + +use crate::atuin_daemon::events::DaemonEvent; + +// ============================================================================ +// DaemonState +// ============================================================================ + +/// Shared state owned by the daemon. +/// +/// This contains all the resources that components and services need access to. +/// The state is wrapped in an `Arc` and accessed via [`DaemonHandle`]. +pub struct DaemonState { + // Event bus + event_tx: broadcast::Sender<DaemonEvent>, + + // Configuration (mutable - can be reloaded) + settings: RwLock<Settings>, + + // Encryption key (immutable - derived at startup) + encryption_key: [u8; 32], + + // Database handles + history_db: HistoryDatabase, + store: SqliteStore, +} + +// ============================================================================ +// DaemonHandle +// ============================================================================ + +/// A lightweight handle to the daemon's shared state. +/// +/// This is the primary way for components, gRPC services, and spawned tasks to +/// interact with the daemon. It provides access to: +/// +/// - Event emission and subscription +/// - Configuration (settings, encryption key) +/// - Database handles +/// +/// The handle is cheaply cloneable (wraps an `Arc`) and can be freely passed +/// around to any code that needs daemon access. +/// +/// # Example +/// +/// ```ignore +/// // Emit an event +/// handle.emit(DaemonEvent::HistoryPruned); +/// +/// // Access settings +/// let settings = handle.settings().await; +/// let sync_freq = settings.daemon.sync_frequency; +/// +/// // Access database +/// let history = handle.history_db().load(id).await?; +/// ``` +#[derive(Clone)] +pub struct DaemonHandle { + state: Arc<DaemonState>, +} + +impl DaemonHandle { + // ---- Events ---- + + /// Emit an event to the daemon's event bus. + /// + /// This is fire-and-forget - if no receivers are listening (which shouldn't + /// happen in normal operation), the event is dropped silently. + pub fn emit(&self, event: DaemonEvent) { + if let Err(e) = self.state.event_tx.send(event) { + tracing::warn!("failed to emit event (no receivers?): {e}"); + } + } + + /// Subscribe to the event bus. + /// + /// Returns a receiver that will receive all events emitted after this call. + /// Useful for components that need to listen for events outside of the + /// normal `handle_event` callback flow. + pub fn subscribe(&self) -> broadcast::Receiver<DaemonEvent> { + self.state.event_tx.subscribe() + } + + /// Request graceful shutdown of the daemon. + pub fn shutdown(&self) { + self.emit(DaemonEvent::ShutdownRequested); + } + + // ---- Configuration ---- + + /// Get the current settings. + /// + /// This acquires a read lock on the settings. For most use cases, clone + /// the settings if you need to hold onto them. + pub async fn settings(&self) -> tokio::sync::RwLockReadGuard<'_, Settings> { + self.state.settings.read().await + } + + /// Reload settings from disk and emit a SettingsReloaded event. + /// + /// Components listening for `SettingsReloaded` can then re-read settings + /// via `handle.settings()` to pick up the changes. + pub async fn reload_settings(&self) -> Result<()> { + let new_settings = Settings::new()?; + self.apply_settings(new_settings).await; + Ok(()) + } + + /// Apply already-loaded settings and emit a SettingsReloaded event. + /// + /// Use this when settings have already been loaded (e.g., from a file watcher) + /// to avoid parsing the config file twice. + pub async fn apply_settings(&self, settings: Settings) { + *self.state.settings.write().await = settings; + self.emit(DaemonEvent::SettingsReloaded); + tracing::info!("settings applied"); + } + + /// Get the encryption key. + pub fn encryption_key(&self) -> &[u8; 32] { + &self.state.encryption_key + } + + // ---- Database ---- + + /// Get a reference to the history database. + pub fn history_db(&self) -> &HistoryDatabase { + &self.state.history_db + } + + /// Get a reference to the record store. + pub fn store(&self) -> &SqliteStore { + &self.state.store + } +} + +impl std::fmt::Debug for DaemonHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DaemonHandle").finish_non_exhaustive() + } +} + +// ============================================================================ +// Component Trait +// ============================================================================ + +/// A daemon component that handles a specific domain. +/// +/// Components are the building blocks of the daemon. Each component: +/// +/// - Has a unique name for logging and debugging +/// - Can optionally expose gRPC services +/// - Receives a [`DaemonHandle`] on startup for accessing daemon resources +/// - Handles events from the event bus +/// - Performs cleanup on shutdown +/// +/// # Lifecycle +/// +/// 1. **Construction**: Component is created (usually via `new()`) +/// 2. **Start**: `start()` is called with a [`DaemonHandle`] +/// 3. **Running**: `handle_event()` is called for each event on the bus +/// 4. **Shutdown**: `stop()` is called for cleanup +/// +/// # Example +/// +/// ```ignore +/// pub struct MyComponent { +/// handle: Option<DaemonHandle>, +/// } +/// +/// #[async_trait] +/// impl Component for MyComponent { +/// fn name(&self) -> &'static str { "my-component" } +/// +/// async fn start(&mut self, handle: DaemonHandle) -> Result<()> { +/// self.handle = Some(handle); +/// Ok(()) +/// } +/// +/// async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { +/// match event { +/// DaemonEvent::SomeEvent => { +/// // Handle the event +/// if let Some(handle) = &self.handle { +/// handle.emit(DaemonEvent::ResponseEvent); +/// } +/// } +/// _ => {} +/// } +/// Ok(()) +/// } +/// +/// async fn stop(&mut self) -> Result<()> { +/// Ok(()) +/// } +/// } +/// ``` +#[tonic::async_trait] +pub trait Component: Send + Sync { + /// Human-readable name for logging and debugging. + fn name(&self) -> &'static str; + + /// Called once at startup. + /// + /// Store the handle if you need to emit events or access daemon resources + /// later. The handle is cheaply cloneable, so feel free to clone it for + /// spawned tasks. + async fn start(&mut self, handle: DaemonHandle) -> Result<()>; + + /// Handle an incoming event. + /// + /// Called for every event on the bus. To emit new events in response, + /// use the handle stored during `start()`. Events emitted here will be + /// processed in subsequent event loop iterations. + async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()>; + + /// Called on graceful shutdown. + /// + /// Use this to clean up resources, abort spawned tasks, etc. + async fn stop(&mut self) -> Result<()>; +} + +// ============================================================================ +// Daemon +// ============================================================================ + +/// The main daemon orchestrator. +/// +/// The daemon manages components, runs the event loop, and coordinates startup +/// and shutdown. It is constructed via [`DaemonBuilder`]. +/// +/// # Event Loop +/// +/// The daemon runs a simple event loop: +/// +/// 1. Wait for an event on the bus +/// 2. Dispatch the event to all components (in registration order) +/// 3. Components may emit new events in response +/// 4. Repeat until `ShutdownRequested` is received +/// +/// Events emitted during handling are queued and processed in subsequent +/// iterations, ensuring the loop eventually drains. +pub struct Daemon { + components: Vec<Box<dyn Component>>, + handle: DaemonHandle, +} + +impl Daemon { + /// Create a new daemon builder. + pub fn builder(settings: Settings) -> DaemonBuilder { + DaemonBuilder::new(settings) + } + + /// Get a clone of the daemon handle. + /// + /// The handle can be used to emit events, access settings, etc. + pub fn handle(&self) -> DaemonHandle { + self.handle.clone() + } + + /// Start all components. + /// + /// This must be called before `run_event_loop()`. It initializes all + /// registered components with the daemon handle. + pub async fn start_components(&mut self) -> Result<()> { + for component in &mut self.components { + tracing::info!(component = component.name(), "starting component"); + component + .start(self.handle.clone()) + .await + .with_context(|| format!("failed to start component: {}", component.name()))?; + } + Ok(()) + } + + /// Run the daemon event loop. + /// + /// This processes events until a ShutdownRequested event is received. + /// Components must be started first via `start_components()`. + pub async fn run_event_loop(&mut self) -> Result<()> { + let mut event_rx = self.handle.subscribe(); + loop { + match event_rx.recv().await { + Ok(DaemonEvent::ShutdownRequested) => { + tracing::info!("shutdown requested, stopping daemon"); + break; + } + Ok(event) => { + tracing::debug!(?event, "processing event"); + self.dispatch_event(&event).await; + } + Err(broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!( + skipped = n, + "event receiver lagged, some events were dropped" + ); + } + Err(broadcast::error::RecvError::Closed) => { + tracing::info!("event bus closed, stopping daemon"); + break; + } + } + } + Ok(()) + } + + /// Stop all components. + /// + /// This performs graceful shutdown of all components. + pub async fn stop_components(&mut self) { + for component in &mut self.components { + tracing::info!(component = component.name(), "stopping component"); + if let Err(e) = component.stop().await { + tracing::error!( + component = component.name(), + error = ?e, + "error stopping component" + ); + } + } + tracing::info!("all components stopped"); + } + + /// Run the daemon. + /// + /// This is a convenience method that starts components, runs the event loop, + /// and handles shutdown. It does not return until the daemon is shut down. + pub async fn run(mut self) -> Result<()> { + self.start_components().await?; + self.run_event_loop().await?; + self.stop_components().await; + tracing::info!("daemon stopped"); + Ok(()) + } + + async fn dispatch_event(&mut self, event: &DaemonEvent) { + for component in &mut self.components { + if let Err(e) = component.handle_event(event).await { + tracing::error!( + component = component.name(), + error = ?e, + "error handling event" + ); + } + } + } +} + +// ============================================================================ +// DaemonBuilder +// ============================================================================ + +/// Builder for constructing a [`Daemon`]. +/// +/// # Example +/// +/// ```ignore +/// let daemon = Daemon::builder(settings) +/// .store(store) +/// .history_db(history_db) +/// .component(HistoryComponent::new()) +/// .component(SearchComponent::new()) +/// .component(SyncComponent::new()) +/// .build() +/// .await?; +/// +/// daemon.run().await?; +/// ``` +pub struct DaemonBuilder { + settings: Settings, + store: Option<SqliteStore>, + history_db: Option<HistoryDatabase>, + components: Vec<Box<dyn Component>>, +} + +impl DaemonBuilder { + /// Create a new daemon builder with the given settings. + pub fn new(settings: Settings) -> Self { + Self { + settings, + store: None, + history_db: None, + components: Vec::new(), + } + } + + /// Set the record store. + pub fn store(mut self, store: SqliteStore) -> Self { + self.store = Some(store); + self + } + + /// Set the history database. + pub fn history_db(mut self, db: HistoryDatabase) -> Self { + self.history_db = Some(db); + self + } + + /// Register a component. + /// + /// Components are started in registration order and stopped in reverse order. + pub fn component(mut self, component: impl Component + 'static) -> Self { + self.components.push(Box::new(component)); + self + } + + /// Build the daemon. + /// + /// This loads the encryption key and creates the daemon state. + pub async fn build(self) -> Result<Daemon> { + let store = self.store.ok_or_else(|| eyre::eyre!("store is required"))?; + let history_db = self + .history_db + .ok_or_else(|| eyre::eyre!("history_db is required"))?; + + // Load encryption key + let encryption_key: [u8; 32] = encryption::load_key(&self.settings) + .context("could not load encryption key")? + .into(); + + // Create the event bus + let (event_tx, _) = broadcast::channel(64); + + // Create the shared state + let state = Arc::new(DaemonState { + event_tx, + settings: RwLock::new(self.settings), + encryption_key, + history_db, + store, + }); + + // Create the handle (just a reference to the state) + let handle = DaemonHandle { state }; + + Ok(Daemon { + components: self.components, + handle, + }) + } +} diff --git a/crates/turtle/src/atuin_daemon/events.rs b/crates/turtle/src/atuin_daemon/events.rs new file mode 100644 index 00000000..9a398925 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/events.rs @@ -0,0 +1,74 @@ +//! Daemon events. +//! +//! Events are the primary communication mechanism within the daemon. +//! Components emit events to notify others of state changes, and handle +//! events to react to changes elsewhere in the system. +//! +//! External processes (like CLI commands) can also inject events via the +//! Control gRPC service. + +use crate::atuin_client::history::{History, HistoryId}; +use crate::atuin_common::record::RecordId; + +/// Events that flow through the daemon's event bus. +/// +/// Events are broadcast to all components. Each component decides which +/// events it cares about in its `handle_event` implementation. +#[derive(Debug, Clone)] +pub enum DaemonEvent { + // ---- History lifecycle ---- + /// A command has started running. + HistoryStarted(History), + + /// A command has finished running. + HistoryEnded(History), + + // ---- Sync ---- + /// Records were synced from the server. + /// + /// The search component uses this to update its index with new history. + RecordsAdded(Vec<RecordId>), + + /// Sync completed successfully. + SyncCompleted { + /// Number of records uploaded. + uploaded: usize, + /// Number of records downloaded. + downloaded: usize, + }, + + /// Sync failed. + SyncFailed { + /// Error message describing what went wrong. + error: String, + }, + + /// Request an immediate sync (external trigger). + ForceSync, + + // ---- External commands ---- + /// History was pruned - search index needs a full rebuild. + /// + /// Emitted when the user runs `atuin history prune` or similar. + HistoryPruned, + + /// History was rebuilt - search index needs a full rebuild. + /// + /// Emitted when the user runs `atuin store rebuild history` or similar. + HistoryRebuilt, + + /// Specific history items were deleted. + /// + /// The search component should remove these from its index. + HistoryDeleted { + /// IDs of the deleted history entries. + ids: Vec<HistoryId>, + }, + + /// Settings have changed, components should reload if needed. + SettingsReloaded, + + // ---- Lifecycle ---- + /// Request graceful shutdown of the daemon. + ShutdownRequested, +} diff --git a/crates/turtle/src/atuin_daemon/history/mod.rs b/crates/turtle/src/atuin_daemon/history/mod.rs new file mode 100644 index 00000000..b71853df --- /dev/null +++ b/crates/turtle/src/atuin_daemon/history/mod.rs @@ -0,0 +1,6 @@ +//! History module for the daemon gRPC history service. +//! +//! This module contains the proto-generated types for the history gRPC service. + +// Include the generated proto code +tonic::include_proto!("history"); diff --git a/crates/turtle/src/atuin_daemon/mod.rs b/crates/turtle/src/atuin_daemon/mod.rs new file mode 100644 index 00000000..b05eb95c --- /dev/null +++ b/crates/turtle/src/atuin_daemon/mod.rs @@ -0,0 +1,128 @@ +use crate::atuin_client::database::Sqlite as HistoryDatabase; +use crate::atuin_client::record::sqlite_store::SqliteStore; +use crate::atuin_client::settings::{Settings, watcher::global_settings_watcher}; +use eyre::Result; + +pub mod client; +pub mod components; +pub mod control; +pub mod daemon; +pub mod events; +pub mod history; +pub mod search; +pub mod semantic; +pub mod server; + +// Re-export core daemon types for convenience +pub use daemon::{Component, Daemon, DaemonBuilder, DaemonHandle}; +pub use events::DaemonEvent; + +// Re-export components +pub use components::{HistoryComponent, SearchComponent, SemanticComponent, SyncComponent}; + +// Re-export client helpers +pub use client::{ControlClient, SemanticClient, emit_event, emit_event_with_settings}; + +/// Boot the daemon using the new component-based architecture. +/// +/// This creates a daemon with the standard components (history, search, sync), +/// starts the gRPC server with their services, and runs the event loop. +pub async fn boot( + settings: Settings, + store: SqliteStore, + history_db: HistoryDatabase, +) -> Result<()> { + // Create the components + let history_component = HistoryComponent::new(); + let search_component = SearchComponent::new(); + let semantic_component = SemanticComponent::new(); + let sync_component = SyncComponent::new(); + + // Get the gRPC services before moving components into the daemon + // (The services share state with the components via Arc) + let history_service = history_component.grpc_service(); + let search_service = search_component.grpc_service(); + let semantic_service = semantic_component.grpc_service(); + + // Build the daemon + let mut daemon = Daemon::builder(settings.clone()) + .store(store) + .history_db(history_db) + .component(history_component) + .component(search_component) + .component(semantic_component) + .component(sync_component) + .build() + .await?; + + // Get a handle for the control service and gRPC server shutdown + let handle = daemon.handle(); + + // Create the control service + let control_service = control::ControlService::new(handle.clone()); + + // Start all components first (so gRPC services can work) + daemon.start_components().await?; + + // Spawn config file watcher to reload settings on changes + if let Ok(watcher) = global_settings_watcher() { + let mut settings_rx = watcher.subscribe(); + let watcher_handle = handle.clone(); + tokio::spawn(async move { + tracing::info!("config file watcher started"); + while settings_rx.changed().await.is_ok() { + // Use the already-loaded settings from the watcher + // (avoids parsing the config file twice) + let new_settings = (*settings_rx.borrow()).clone(); + watcher_handle.apply_settings((*new_settings).clone()).await; + } + tracing::debug!("config file watcher stopped"); + }); + } else { + tracing::warn!( + "failed to start config file watcher; settings changes will require daemon restart" + ); + } + + // Spawn signal handler to emit ShutdownRequested on Ctrl+C/SIGTERM + let signal_handle = handle.clone(); + tokio::spawn(async move { + shutdown_signal().await; + tracing::info!("received shutdown signal"); + signal_handle.shutdown(); + }); + + // Start the gRPC server in the background + server::run_grpc_server( + settings, + history_service, + search_service, + semantic_service, + control_service.into_server(), + handle, + ) + .await?; + + // Run the daemon event loop + daemon.run_event_loop().await?; + + // Stop all components on shutdown + daemon.stop_components().await; + + tracing::info!("daemon shut down complete"); + Ok(()) +} + +/// Wait for a shutdown signal (Ctrl+C or SIGTERM). +#[cfg(unix)] +async fn shutdown_signal() { + let mut term = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("failed to register sigterm handler"); + let mut int = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) + .expect("failed to register sigint handler"); + + tokio::select! { + _ = term.recv() => {}, + _ = int.recv() => {}, + } +} diff --git a/crates/turtle/src/atuin_daemon/search/index.rs b/crates/turtle/src/atuin_daemon/search/index.rs new file mode 100644 index 00000000..df627e1b --- /dev/null +++ b/crates/turtle/src/atuin_daemon/search/index.rs @@ -0,0 +1,684 @@ +//! Search index with frecency-based ranking. +//! +//! This module provides a deduplicated search index where each unique command +//! is stored once, with metadata about all its invocations. This enables: +//! +//! - Efficient fuzzy matching (fewer items to match) +//! - Frecency-based ranking (frequency + recency) +//! - Dynamic filtering by directory, host, session, etc. + +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use crate::atuin_client::settings::Search; +use crate::{ + atuin_client::history::{History, is_known_agent}, + atuin_daemon::components::search::with_trailing_slash, +}; +use atuin_nucleo::{Injector, Nucleo, pattern}; +use dashmap::DashMap; +use lasso::{Spur, ThreadedRodeo}; +use time::OffsetDateTime; +use tokio::sync::RwLock; +use tracing::{Level, instrument}; +use uuid::Uuid; + +/// Parse a UUID string into a 16-byte array. +/// Returns None if the string is not a valid UUID. +fn parse_uuid_bytes(s: &str) -> Option<[u8; 16]> { + Uuid::parse_str(s).ok().map(|u| *u.as_bytes()) +} + +/// Format a 16-byte array as a UUID string. +fn format_uuid_bytes(bytes: &[u8; 16]) -> String { + Uuid::from_bytes(*bytes).to_string() +} + +/// Pre-computed frecency data for O(1) lookup. +#[derive(Debug, Clone, Default)] +pub struct FrecencyData { + /// Total number of times this command was used. + pub count: u32, + /// Most recent usage timestamp (unix seconds). + pub last_used: i64, +} + +impl FrecencyData { + /// Record a new usage of this command. + pub fn record_use(&mut self, timestamp: i64) { + self.count += 1; + if timestamp > self.last_used { + self.last_used = timestamp; + } + } + + /// Compute frecency score based on count and recency. + /// + /// Uses a decay function where more recent commands score higher. + /// The formula balances frequency (how often) with recency (how recent). + /// + /// Multipliers allow tuning the relative weights: + /// - `recency_mul`: Multiplier for recency score (default: 1.0) + /// - `frequency_mul`: Multiplier for frequency score (default: 1.0) + /// + /// A multiplier of 0.0 disables that component, 1.0 is unchanged, 2.0 doubles weight. + /// Values like 0.5 reduce weight by half, 1.5 increases by 50%, etc. + #[instrument(level = tracing::Level::TRACE, name = "index_frecency_compute")] + pub fn compute(&self, now: i64, recency_mul: f64, frequency_mul: f64) -> u32 { + if self.count == 0 { + return 0; + } + + // Time-based decay: score decreases as time passes + let age_seconds = (now - self.last_used).max(0) as u64; + let age_hours = age_seconds / 3600; + + // Decay factor: recent commands get higher scores + // - Last hour: multiplier ~1.0 + // - Last day: multiplier ~0.5 + // - Last week: multiplier ~0.1 + // - Older: multiplier approaches 0 + let recency_score: f64 = match age_hours { + 0 => 100.0, + 1..=6 => 90.0, + 7..=24 => 70.0, + 25..=72 => 50.0, + 73..=168 => 30.0, + 169..=720 => 15.0, + _ => 5.0, + }; + + // Frequency boost: more uses = higher score (with diminishing returns) + let frequency_score = ((self.count as f64).ln() * 20.0).min(100.0); + + // Apply multipliers and combine scores, then round to u32 + ((recency_score * recency_mul) + (frequency_score * frequency_mul)).round() as u32 + } +} + +/// Data for a unique command. +pub struct CommandData { + /// History ID of the most recent invocation (16-byte UUID). + most_recent_id: [u8; 16], + /// Timestamp of the most recent invocation. + most_recent_timestamp: i64, + /// Pre-computed global frecency. + pub global_frecency: FrecencyData, + + // Pre-computed indexes for O(1) filter lookups + // Using HashSet instead of DashSet since CommandData lives inside DashMap (already synchronized) + /// All directories where this command has been run (interned keys). + directories: HashSet<Spur>, + /// All hostnames where this command has been run (interned keys). + hosts: HashSet<Spur>, + /// All sessions where this command has been run (as 16-byte UUIDs). + sessions: HashSet<[u8; 16]>, +} + +impl CommandData { + /// Create a new CommandData from a history entry. + /// Returns None if the history entry has invalid UUIDs. + pub fn new(history: &History, interner: &ThreadedRodeo) -> Option<Self> { + let history_id = parse_uuid_bytes(&history.id.0)?; + let session = parse_uuid_bytes(&history.session)?; + let timestamp = history.timestamp.unix_timestamp(); + + let dir_key = interner.get_or_intern(with_trailing_slash(&history.cwd)); + let host_key = interner.get_or_intern(&history.hostname); + + let mut directories = HashSet::new(); + directories.insert(dir_key); + + let mut hosts = HashSet::new(); + hosts.insert(host_key); + + let mut sessions = HashSet::new(); + sessions.insert(session); + + let mut global_frecency = FrecencyData::default(); + global_frecency.record_use(timestamp); + + Some(Self { + most_recent_id: history_id, + most_recent_timestamp: timestamp, + global_frecency, + directories, + hosts, + sessions, + }) + } + + /// Add an invocation from a history entry. + /// Returns false if the history entry has invalid UUIDs. + pub fn add_invocation(&mut self, history: &History, interner: &ThreadedRodeo) -> bool { + let Some(history_id) = parse_uuid_bytes(&history.id.0) else { + return false; + }; + let Some(session) = parse_uuid_bytes(&history.session) else { + return false; + }; + + let timestamp = history.timestamp.unix_timestamp(); + + // Update global frecency + self.global_frecency.record_use(timestamp); + + // Update pre-computed indexes for O(1) filter lookups + let dir_key = interner.get_or_intern(with_trailing_slash(&history.cwd)); + self.directories.insert(dir_key); + self.hosts.insert(interner.get_or_intern(&history.hostname)); + self.sessions.insert(session); + + // Update most recent if this invocation is newer + if timestamp > self.most_recent_timestamp { + self.most_recent_id = history_id; + self.most_recent_timestamp = timestamp; + } + + true + } + + /// Get the most recent history ID for this command. + pub fn most_recent_id(&self) -> String { + format_uuid_bytes(&self.most_recent_id) + } + + /// Check if any invocation matches a directory filter (exact match). + /// O(1) lookup using pre-computed index. + pub fn has_invocation_in_dir(&self, dir: &str, interner: &ThreadedRodeo) -> bool { + interner + .get(dir) + .is_some_and(|spur| self.directories.contains(&spur)) + } + + /// Check if any invocation matches a directory prefix (workspace/git root). + /// O(n) where n = number of unique directories for this command. + pub fn has_invocation_in_workspace(&self, prefix: &str, interner: &ThreadedRodeo) -> bool { + self.directories + .iter() + .any(|&spur| interner.resolve(&spur).starts_with(prefix)) + } + + /// Check if any invocation matches a hostname. + /// O(1) lookup using pre-computed index. + pub fn has_invocation_on_host(&self, hostname: &str, interner: &ThreadedRodeo) -> bool { + interner + .get(hostname) + .is_some_and(|spur| self.hosts.contains(&spur)) + } + + /// Check if any invocation matches a session. + /// O(1) lookup using pre-computed index. + pub fn has_invocation_in_session(&self, session: &str) -> bool { + parse_uuid_bytes(session).is_some_and(|bytes| self.sessions.contains(&bytes)) + } +} + +/// Filter mode for search queries. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum IndexFilterMode { + /// No filtering - search all commands. + Global, + /// Filter to commands run in a specific directory. + Directory(String), + /// Filter to commands run in a workspace (directory prefix). + Workspace(String), + /// Filter to commands run on a specific host. + Host(String), + /// Filter to commands run in a specific session. + Session(String), +} + +/// Context for search queries. +#[derive(Debug, Clone, Default)] +pub struct QueryContext { + pub cwd: Option<String>, + pub git_root: Option<String>, + pub hostname: Option<String>, + pub session_id: Option<String>, +} + +/// Shareable frecency map: command -> frecency score. +/// Wrapped in Arc for zero-copy sharing with scorer callbacks. +type FrecencyMap = Arc<HashMap<Arc<str>, u32>>; + +/// A deduplicated search index with frecency-based ranking. +/// +/// Commands are stored by their text, with metadata about all invocations. +/// Nucleo handles fuzzy matching, while frecency is computed via scorer callback. +/// +/// Global frecency is precomputed by a background task and used for scoring. +/// If frecency data is not available, search still works but without frecency ranking; +/// although this should never happen due to precomputing the frecency map. +pub struct SearchIndex { + /// Map from command text to command data. + /// Using DashMap for concurrent read/write access, wrapped in Arc for sharing with scorer. + /// Keys are Arc<str> to enable zero-copy sharing with frecency_map. + commands: Arc<DashMap<Arc<str>, CommandData>>, + /// Nucleo fuzzy matcher - items are command strings. + nucleo: RwLock<Nucleo<String>>, + /// Injector for adding new commands to Nucleo. + injector: Injector<String>, + /// Precomputed global frecency map. Updated by background task. + frecency_map: RwLock<Option<FrecencyMap>>, + /// String interner for deduplicating cwd, hostname, and directory paths. + interner: Arc<ThreadedRodeo>, +} + +impl SearchIndex { + /// Create a new empty search index. + pub fn new() -> Self { + let nucleo_config = atuin_nucleo::Config::DEFAULT; + // Single column for command text + let nucleo = Nucleo::<String>::new(nucleo_config, Arc::new(|| {}), None, 1); + let injector = nucleo.injector(); + + Self { + commands: Arc::new(DashMap::new()), + nucleo: RwLock::new(nucleo), + injector, + frecency_map: RwLock::new(None), + interner: Arc::new(ThreadedRodeo::new()), + } + } + + /// Add a history entry to the index. + /// + /// If the command already exists, updates its invocation data. + /// If it's a new command, adds it to both the map and Nucleo. + pub fn add_history(&self, history: &History) { + if is_known_agent(&history.author) { + return; + } + + let command = history.command.as_str(); + + // DashMap with Arc<str> keys can be looked up with &str via Borrow trait + if let Some(mut entry) = self.commands.get_mut(command) { + // Existing command - just update invocations + entry.add_invocation(history, &self.interner); + } else { + // New command - create Arc<str> once and share it + let Some(data) = CommandData::new(history, &self.interner) else { + return; // Invalid UUIDs, skip this entry + }; + let command_arc: Arc<str> = command.into(); + self.commands.insert(Arc::clone(&command_arc), data); + // Nucleo still needs String (unavoidable copy for fuzzy matching) + self.injector.push(command_arc.to_string(), |cmd, cols| { + cols[0] = cmd.clone().into(); + }); + } + // Note: frecency_map is rebuilt by background task, not invalidated here + } + + /// Add multiple history entries to the index. + pub fn add_histories(&self, histories: &[History]) { + for history in histories { + self.add_history(history); + } + } + + /// Get the number of unique commands in the index. + pub fn command_count(&self) -> usize { + self.commands.len() + } + + /// Get the number of items in Nucleo (should match command_count). + pub async fn nucleo_item_count(&self) -> u32 { + self.nucleo.read().await.snapshot().item_count() + } + + /// Search for commands matching a query. + /// + /// Returns a list of history IDs (most recent invocation per command). + /// Uses precomputed global frecency for scoring if available. + #[instrument(skip_all, level = tracing::Level::TRACE, name = "index_search", fields(query = %query))] + pub async fn search( + &self, + query: &str, + filter_mode: IndexFilterMode, + _context: &QueryContext, + limit: u32, + ) -> Vec<String> { + let mut nucleo = self.nucleo.write().await; + + // Get precomputed frecency map (may be None if not yet computed) + let frecency_map = self.frecency_map.read().await.clone(); + + // Build filter based on mode + let filter = self.build_filter(&filter_mode); + nucleo.set_filter(filter); + + // Build scorer from precomputed frecency (or None if not available) + let scorer = Self::build_scorer(frecency_map); + nucleo.set_scorer(scorer); + + // Update pattern + nucleo.pattern.reparse( + 0, + query, + pattern::CaseMatching::Smart, + pattern::Normalization::Smart, + false, + ); + + tracing::span!(Level::TRACE, "index_search_tick").in_scope(|| { + // Tick until complete + while nucleo.tick(10).running {} + }); + + // Collect results + let snapshot = nucleo.snapshot(); + let matched_count = snapshot.matched_item_count().min(limit); + + tracing::span!(Level::TRACE, "index_search_results").in_scope(|| { + snapshot + .matched_items(..matched_count) + .filter_map(|item| { + let cmd = item.data; + // DashMap<Arc<str>, _>::get accepts &str via Borrow trait + self.commands + .get(cmd.as_str()) + .map(|data| data.most_recent_id()) + }) + .collect() + }) + } + + /// Rebuild the global frecency map. + /// + /// This should be called by a background task periodically. + /// The map is used for scoring search results. + /// + /// Uses multipliers from search settings: + /// - `recency_score_multiplier`: Weight for recency component + /// - `frequency_score_multiplier`: Weight for frequency component + /// - `frecency_score_multiplier`: Overall multiplier for final score + #[instrument(skip_all, level = tracing::Level::DEBUG, name = "rebuild_frecency")] + pub async fn rebuild_frecency(&self, search_settings: &Search) { + let now = OffsetDateTime::now_utc().unix_timestamp(); + let mut frecency_map: HashMap<Arc<str>, u32> = HashMap::new(); + + // Clamp multipliers to non-negative values to prevent broken frecency ranking + // (negative values would produce unexpected results when cast to u32) + let recency_mul = search_settings.recency_score_multiplier.max(0.0); + let frequency_mul = search_settings.frequency_score_multiplier.max(0.0); + let frecency_mul = search_settings.frecency_score_multiplier.max(0.0); + + for entry in self.commands.iter() { + let frecency = entry + .global_frecency + .compute(now, recency_mul, frequency_mul); + // Apply overall frecency multiplier and round to u32 + let frecency = (frecency as f64 * frecency_mul).round() as u32; + // Arc::clone is cheap - just increments reference count + frecency_map.insert(Arc::clone(entry.key()), frecency); + } + + *self.frecency_map.write().await = Some(Arc::new(frecency_map)); + } + + /// Build filter predicate for the given mode. + fn build_filter(&self, mode: &IndexFilterMode) -> Option<atuin_nucleo::Filter<String>> { + // For Global mode, no filter needed + if matches!(mode, IndexFilterMode::Global) { + return None; + } + + // Pre-compute which commands pass the filter + // Use HashSet<String> for the short-lived filter (simpler than Arc lookup) + let passing_commands: Arc<HashSet<String>> = { + let mut set = HashSet::new(); + for entry in self.commands.iter() { + let passes = match mode { + IndexFilterMode::Global => unreachable!(), + IndexFilterMode::Directory(dir) => { + entry.has_invocation_in_dir(dir, &self.interner) + } + IndexFilterMode::Workspace(prefix) => { + entry.has_invocation_in_workspace(prefix, &self.interner) + } + IndexFilterMode::Host(hostname) => { + entry.has_invocation_on_host(hostname, &self.interner) + } + IndexFilterMode::Session(session) => entry.has_invocation_in_session(session), + }; + if passes { + // Convert Arc<str> to String for filter lookup + set.insert(entry.key().to_string()); + } + } + Arc::new(set) + }; + + Some(Arc::new(move |cmd: &String| passing_commands.contains(cmd))) + } + + /// Build scorer from precomputed frecency map. + /// + /// Returns None if frecency map is not available (search still works, just without frecency ranking). + fn build_scorer(frecency_map: Option<FrecencyMap>) -> Option<atuin_nucleo::Scorer<String>> { + let map = frecency_map?; + Some(Arc::new(move |cmd: &String, fuzzy_score: u32| { + // HashMap<Arc<str>, _>::get accepts &str via Borrow trait + let frecency = map.get(cmd.as_str()).copied().unwrap_or(0); + fuzzy_score + frecency + })) + } +} + +impl Default for SearchIndex { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use time::macros::datetime; + + fn make_history(command: &str, cwd: &str, timestamp: OffsetDateTime) -> History { + History::import() + .timestamp(timestamp) + .command(command) + .cwd(cwd) + .build() + .into() + } + + #[test] + fn frecency_data_compute() { + let now = 1_000_000i64; + + // Recent command (with default multipliers of 1.0) + let recent = FrecencyData { + count: 5, + last_used: now - 60, // 1 minute ago + }; + assert!(recent.compute(now, 1.0, 1.0) > 100); // High score + + // Old command + let old = FrecencyData { + count: 5, + last_used: now - 86400 * 30, // 30 days ago + }; + assert!(old.compute(now, 1.0, 1.0) < recent.compute(now, 1.0, 1.0)); + + // Frequently used old command + let frequent_old = FrecencyData { + count: 100, + last_used: now - 86400 * 7, // 1 week ago + }; + // Should still have decent score due to frequency + assert!(frequent_old.compute(now, 1.0, 1.0) > 50); + } + + #[test] + fn frecency_data_compute_with_multipliers() { + let now = 1_000_000_i64; + + let data = FrecencyData { + count: 5, + last_used: now - 60, // 1 minute ago (recency_score = 100) + }; + + // Default multipliers (1.0, 1.0) + let default_score = data.compute(now, 1.0, 1.0); + + // Double recency weight + let double_recency = data.compute(now, 2.0, 1.0); + assert!(double_recency > default_score); + + // Double frequency weight + let double_frequency = data.compute(now, 1.0, 2.0); + assert!(double_frequency > default_score); + + // Zero out recency (only frequency counts) + let no_recency = data.compute(now, 0.0, 1.0); + assert!(no_recency < default_score); + + // Zero out frequency (only recency counts) + let no_frequency = data.compute(now, 1.0, 0.0); + assert!(no_frequency < default_score); + + // Zero both (should be zero) + let no_score = data.compute(now, 0.0, 0.0); + assert_eq!(no_score, 0); + + // Fractional multipliers + let half_recency = data.compute(now, 0.5, 1.0); + assert!(half_recency < default_score); + assert!(half_recency > no_recency); + + // 1.5x multiplier + let boost_recency = data.compute(now, 1.5, 1.0); + assert!(boost_recency > default_score); + assert!(boost_recency < double_recency); + } + + #[test] + fn command_data_add_invocation() { + let interner = ThreadedRodeo::new(); + + let (dir1, dir2) = if cfg!(windows) { + ("C:\\Users\\User\\project", "C:\\Users\\User\\other") + } else { + ("/home/user/project", "/home/user/other") + }; + + let history1 = make_history("git status", dir1, datetime!(2024-01-01 10:00 UTC)); + let history2 = make_history("git status", dir2, datetime!(2024-01-01 12:00 UTC)); + + let mut data = CommandData::new(&history1, &interner).unwrap(); + assert_eq!(data.global_frecency.count, 1); + let id1 = data.most_recent_id(); + + data.add_invocation(&history2, &interner); + assert_eq!(data.global_frecency.count, 2); + + // Most recent ID should update to history2 (newer timestamp) + let id2 = data.most_recent_id(); + assert_ne!(id1, id2); + } + + #[test] + fn command_data_filters() { + let interner = ThreadedRodeo::new(); + + let (dir1, dir2) = if cfg!(windows) { + ("C:\\Users\\User\\project", "C:\\Users\\User\\other") + } else { + ("/home/user/project", "/home/user/other") + }; + + let h1 = make_history("git status", dir1, datetime!(2024-01-01 10:00 UTC)); + let h2 = make_history("git status", dir2, datetime!(2024-01-01 12:00 UTC)); + + let mut data = CommandData::new(&h1, &interner).unwrap(); + data.add_invocation(&h2, &interner); + + let (check1, check2, check3) = if cfg!(windows) { + ( + with_trailing_slash("C:\\Users\\User\\project"), + with_trailing_slash("C:\\Users\\User\\other"), + with_trailing_slash("C:\\Users\\User\\missing"), + ) + } else { + ( + with_trailing_slash("/home/user/project"), + with_trailing_slash("/home/user/other"), + with_trailing_slash("/home/user/missing"), + ) + }; + + assert!(data.has_invocation_in_dir(&check1, &interner)); + assert!(data.has_invocation_in_dir(&check2, &interner)); + assert!(!data.has_invocation_in_dir(&check3, &interner)); + + let (check1, check2, check3) = if cfg!(windows) { + ( + with_trailing_slash("C:\\Users\\User"), + with_trailing_slash("C:\\Users"), + with_trailing_slash("C:\\Users\\User\\var"), + ) + } else { + ( + with_trailing_slash("/home/user"), + with_trailing_slash("/home"), + with_trailing_slash("/var"), + ) + }; + + assert!(data.has_invocation_in_workspace(&check1, &interner)); + assert!(data.has_invocation_in_workspace(&check2, &interner)); + assert!(!data.has_invocation_in_workspace(&check3, &interner)); + } + + #[tokio::test] + async fn search_index_add_and_search() { + let index = SearchIndex::new(); + + let h1 = make_history( + "git status", + "/home/user/project", + datetime!(2024-01-01 10:00 UTC), + ); + let h2 = make_history( + "git commit -m 'test'", + "/home/user/project", + datetime!(2024-01-01 10:05 UTC), + ); + let h3 = make_history( + "ls -la", + "/home/user/other", + datetime!(2024-01-01 10:10 UTC), + ); + + index.add_history(&h1); + index.add_history(&h2); + index.add_history(&h3); + + assert_eq!(index.command_count(), 3); + + // Search for "git" - should match 2 commands + let results = index + .search("git", IndexFilterMode::Global, &QueryContext::default(), 10) + .await; + assert_eq!(results.len(), 2); + + // Search with directory filter + let results = index + .search( + "", + IndexFilterMode::Directory(with_trailing_slash("/home/user/project")), + &QueryContext::default(), + 10, + ) + .await; + assert_eq!(results.len(), 2); // git status and git commit + } +} diff --git a/crates/turtle/src/atuin_daemon/search/mod.rs b/crates/turtle/src/atuin_daemon/search/mod.rs new file mode 100644 index 00000000..4d261956 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/search/mod.rs @@ -0,0 +1,11 @@ +//! Search module for the daemon gRPC search service. +//! +//! This module provides fuzzy search over command history using Nucleo. + +mod index; + +// Include the generated proto code +tonic::include_proto!("search"); + +// Re-export the service and index +pub use index::{IndexFilterMode, QueryContext, SearchIndex}; diff --git a/crates/turtle/src/atuin_daemon/semantic/mod.rs b/crates/turtle/src/atuin_daemon/semantic/mod.rs new file mode 100644 index 00000000..c3511676 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/semantic/mod.rs @@ -0,0 +1,3 @@ +//! Semantic command capture gRPC service types. + +tonic::include_proto!("semantic"); diff --git a/crates/turtle/src/atuin_daemon/server.rs b/crates/turtle/src/atuin_daemon/server.rs new file mode 100644 index 00000000..23b04342 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/server.rs @@ -0,0 +1,115 @@ +use eyre::Result; + +use crate::atuin_daemon::components::history::HistoryGrpcService; +use crate::atuin_daemon::components::search::SearchGrpcService; +use crate::atuin_daemon::components::semantic::SemanticGrpcService; +use crate::atuin_daemon::control::{ControlService, control_server::ControlServer}; +use crate::atuin_daemon::daemon::DaemonHandle; +use crate::atuin_daemon::history::history_server::HistoryServer; +use crate::atuin_daemon::search::search_server::SearchServer; +use crate::atuin_daemon::semantic::semantic_server::SemanticServer; + +use crate::atuin_client::settings::Settings; + +/// Run the gRPC server with the given services. +/// +/// This starts the gRPC server in the background and returns immediately. +/// The server will shut down when a ShutdownRequested event is received. +#[cfg(unix)] +pub async fn run_grpc_server( + settings: Settings, + history_service: HistoryServer<HistoryGrpcService>, + search_service: SearchServer<SearchGrpcService>, + semantic_service: SemanticServer<SemanticGrpcService>, + control_service: ControlServer<ControlService>, + handle: DaemonHandle, +) -> Result<()> { + use tokio::net::UnixListener; + use tokio_stream::wrappers::UnixListenerStream; + + let socket_path = settings.daemon.socket_path.clone(); + + let (uds, cleanup) = if cfg!(target_os = "linux") && settings.daemon.systemd_socket { + #[cfg(target_os = "linux")] + { + use eyre::{OptionExt, WrapErr}; + use std::os::unix::net::SocketAddr; + use std::path::PathBuf; + tracing::info!("getting systemd socket"); + let listener = listenfd::ListenFd::from_env() + .take_unix_listener(0)? + .ok_or_eyre("missing systemd socket")?; + listener.set_nonblocking(true)?; + let actual_path: Result<PathBuf, eyre::Report> = listener + .local_addr() + .context("getting systemd socket's path") + .and_then(|addr: SocketAddr| { + addr.as_pathname() + .ok_or_eyre("systemd socket missing path") + .map(|path: &std::path::Path| path.to_owned()) + }); + match actual_path { + Ok(actual_path) => { + tracing::info!("listening on systemd socket: {actual_path:?}"); + if actual_path != std::path::Path::new(&socket_path) { + tracing::warn!( + "systemd socket is not at configured client path: {socket_path:?}" + ); + } + } + Err(err) => { + tracing::warn!( + "could not detect systemd socket path, ensure that it's at the configured path: {socket_path:?}, error: {err:?}" + ); + } + } + (UnixListener::from_std(listener)?, false) + } + } else { + tracing::info!("listening on unix socket {socket_path:?}"); + (UnixListener::bind(socket_path.clone())?, true) + }; + + let uds_stream = UnixListenerStream::new(uds); + + // Create shutdown signal from daemon handle + let shutdown_signal = async move { + let mut rx = handle.subscribe(); + loop { + use crate::atuin_daemon::DaemonEvent; + + match rx.recv().await { + Ok(DaemonEvent::ShutdownRequested) => break, + Ok(_) => continue, + Err(_) => break, // Channel closed + } + } + if cleanup { + eprintln!("Removing socket..."); + if let Err(e) = std::fs::remove_file(&socket_path) + && e.kind() != std::io::ErrorKind::NotFound + { + eprintln!("failed to remove socket: {e}"); + } + } + eprintln!("Shutting down gRPC server..."); + }; + + // Spawn the server in the background + tokio::spawn(async move { + use tonic::transport::Server; + + if let Err(e) = Server::builder() + .add_service(history_service) + .add_service(search_service) + .add_service(semantic_service) + .add_service(control_service) + .serve_with_incoming_shutdown(uds_stream, shutdown_signal) + .await + { + tracing::error!("gRPC server error: {e}"); + } + }); + + Ok(()) +} diff --git a/crates/turtle/src/atuin_history/mod.rs b/crates/turtle/src/atuin_history/mod.rs new file mode 100644 index 00000000..e7b33916 --- /dev/null +++ b/crates/turtle/src/atuin_history/mod.rs @@ -0,0 +1,2 @@ +pub mod sort; +pub mod stats; diff --git a/crates/turtle/src/atuin_history/sort.rs b/crates/turtle/src/atuin_history/sort.rs new file mode 100644 index 00000000..b162c810 --- /dev/null +++ b/crates/turtle/src/atuin_history/sort.rs @@ -0,0 +1,46 @@ +use crate::atuin_client::history::History; + +type ScoredHistory = (f64, History); + +// Fuzzy search already comes sorted by minspan +// This sorting should be applicable to all search modes, and solve the more "obvious" issues +// first. +// Later on, we can pass in context and do some boosts there too. +pub fn sort(query: &str, input: Vec<History>) -> Vec<History> { + // This can totally be extended. We need to be _careful_ that it's not slow. + // We also need to balance sorting db-side with sorting here. SQLite can do a lot, + // but some things are just much easier/more doable in Rust. + + let mut scored = input + .into_iter() + .map(|h| { + // If history is _prefixed_ with the query, score it more highly + let score = if h.command.starts_with(query) { + 2.0 + } else if h.command.contains(query) { + 1.75 + } else { + 1.0 + }; + + // calculate how long ago the history was, in seconds + let now = time::OffsetDateTime::now_utc().unix_timestamp(); + let time = h.timestamp.unix_timestamp(); + let diff = std::cmp::max(1, now - time); // no /0 please + + // prefer newer history, but not hugely so as to offset the other scoring + // the numbers will get super small over time, but I don't want time to overpower other + // scoring + #[expect(clippy::cast_precision_loss)] + let time_score = 1.0 + (1.0 / diff as f64); + let score = score * time_score; + + (score, h) + }) + .collect::<Vec<ScoredHistory>>(); + + scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap().reverse()); + + // Remove the scores and return the history + scored.into_iter().map(|(_, h)| h).collect::<Vec<History>>() +} diff --git a/crates/turtle/src/atuin_history/stats.rs b/crates/turtle/src/atuin_history/stats.rs new file mode 100644 index 00000000..e47d6c8e --- /dev/null +++ b/crates/turtle/src/atuin_history/stats.rs @@ -0,0 +1,548 @@ +use std::collections::{HashMap, HashSet}; + +use crossterm::style::{Color, ResetColor, SetAttribute, SetForegroundColor}; +use serde::{Deserialize, Serialize}; +use unicode_segmentation::UnicodeSegmentation; + +use crate::atuin_client::{history::History, settings::Settings, theme::Meaning, theme::Theme}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Stats { + pub total_commands: usize, + pub unique_commands: usize, + pub top: Vec<(Vec<String>, usize)>, +} + +fn first_non_whitespace(s: &str) -> Option<usize> { + s.char_indices() + // find the first non whitespace char + .find(|(_, c)| !c.is_ascii_whitespace()) + // return the index of that char + .map(|(i, _)| i) +} + +fn first_whitespace(s: &str) -> usize { + s.char_indices() + // find the first whitespace char + .find(|(_, c)| c.is_ascii_whitespace()) + // return the index of that char, (or the max length of the string) + .map_or(s.len(), |(i, _)| i) +} + +fn interesting_command<'a>(settings: &Settings, mut command: &'a str) -> &'a str { + // Sort by length so that we match the longest prefix first + let mut common_prefix = settings.stats.common_prefix.clone(); + common_prefix.sort_by_key(|b| std::cmp::Reverse(b.len())); + + // Trim off the common prefix, if it exists + for p in &common_prefix { + if command.starts_with(p) { + let i = p.len(); + let prefix = &command[..i]; + command = command[i..].trim_start(); + if command.is_empty() { + // no commands following, just use the prefix + return prefix; + } + break; + } + } + + // Sort the common_subcommands by length so that we match the longest subcommand first + let mut common_subcommands = settings.stats.common_subcommands.clone(); + common_subcommands.sort_by_key(|b| std::cmp::Reverse(b.len())); + + // Check for a common subcommand + for p in &common_subcommands { + if command.starts_with(p) { + // if the subcommand is the same length as the command, then we just use the subcommand + if p.len() == command.len() { + return command; + } + // otherwise we need to use the subcommand + the next word + let non_whitespace = first_non_whitespace(&command[p.len()..]).unwrap_or(0); + let j = + p.len() + non_whitespace + first_whitespace(&command[p.len() + non_whitespace..]); + return &command[..j]; + } + } + // Return the first word if there is no subcommand + &command[..first_whitespace(command)] +} + +fn split_at_pipe(command: &str) -> Vec<&str> { + let mut result = vec![]; + let mut quoted = false; + let mut start = 0; + let mut graphemes = UnicodeSegmentation::grapheme_indices(command, true); + + while let Some((i, c)) = graphemes.next() { + let current = i; + match c { + "\"" if command[start..current] != *"\"" => { + quoted = !quoted; + } + "'" if command[start..current] != *"'" => { + quoted = !quoted; + } + "\\" if graphemes.next().is_some() => {} + "|" if !quoted => { + if current > start && command[start..].starts_with('|') { + start += 1; + } + result.push(&command[start..current]); + start = current; + } + _ => {} + } + } + if command[start..].starts_with('|') { + start += 1; + } + result.push(&command[start..]); + result +} + +fn strip_leading_env_vars(command: &str) -> &str { + // fast path: no equals sign, no environment variable + if !command.contains('=') { + return command; + } + + let mut in_token = false; + let mut token_start_pos = 0; + let mut in_single_quotes = false; + let mut in_double_quotes = false; + let mut escape_next = false; + let mut has_equals_outside_quotes = false; + + for (i, g) in UnicodeSegmentation::grapheme_indices(command, true) { + if escape_next { + escape_next = false; + continue; + } + + if !in_token { + token_start_pos = i; + } + + match g { + "\\" => { + escape_next = true; + in_token = true; + } + "'" if !in_double_quotes => { + in_single_quotes = !in_single_quotes; + in_token = true; + } + "\"" if !in_single_quotes => { + in_double_quotes = !in_double_quotes; + in_token = true; + } + "=" if !in_single_quotes && !in_double_quotes => { + has_equals_outside_quotes = true; + in_token = true; + } + " " | "\t" if !in_single_quotes && !in_double_quotes => { + if in_token { + if !has_equals_outside_quotes { + // if we're not in an env var, we can break early + break; + } + in_token = false; + has_equals_outside_quotes = false; + } + } + _ => { + in_token = true; + } + } + } + + command[token_start_pos..].trim() +} + +pub fn pretty_print(stats: Stats, ngram_size: usize, theme: &Theme) { + let max = stats.top.iter().map(|x| x.1).max().unwrap(); + let num_pad = max.ilog10() as usize + 1; + + // Find the length of the longest command name for each column + let column_widths = stats + .top + .iter() + .map(|(commands, _)| commands.iter().map(|c| c.len()).collect::<Vec<usize>>()) + .fold(vec![0; ngram_size], |acc, item| { + acc.iter() + .zip(item.iter()) + .map(|(a, i)| *std::cmp::max(a, i)) + .collect() + }); + + for (command, count) in stats.top { + let gray = SetForegroundColor(match theme.as_style(Meaning::Muted).foreground_color { + Some(color) => color, + None => Color::Grey, + }); + let bold = SetAttribute(crossterm::style::Attribute::Bold); + + let in_ten = 10 * count / max; + + print!("["); + print!( + "{}", + SetForegroundColor(match theme.get_error().foreground_color { + Some(color) => color, + None => Color::Red, + }) + ); + + for i in 0..in_ten { + if i == 2 { + print!( + "{}", + SetForegroundColor(match theme.get_warning().foreground_color { + Some(color) => color, + None => Color::Yellow, + }) + ); + } + + if i == 5 { + print!( + "{}", + SetForegroundColor(match theme.get_info().foreground_color { + Some(color) => color, + None => Color::Green, + }) + ); + } + + print!("▮"); + } + + for _ in in_ten..10 { + print!(" "); + } + + let formatted_command = command + .iter() + .zip(column_widths.iter()) + .map(|(cmd, width)| format!("{cmd:width$}")) + .collect::<Vec<_>>() + .join(" | "); + + println!( + "{ResetColor}] {gray}{count:num_pad$}{ResetColor} {bold}{formatted_command}{ResetColor}" + ); + } + println!("Total commands: {}", stats.total_commands); + println!("Unique commands: {}", stats.unique_commands); +} + +pub fn compute( + settings: &Settings, + history: &[History], + count: usize, + ngram_size: usize, +) -> Option<Stats> { + let mut commands = HashSet::<&str>::with_capacity(history.len()); + let mut total_unignored = 0; + let mut prefixes = HashMap::<Vec<&str>, usize>::with_capacity(history.len()); + + for i in history { + // just in case it somehow has a leading tab or space or something (legacy atuin didn't ignore space prefixes) + let command = strip_leading_env_vars(i.command.trim()); + let prefix = interesting_command(settings, command); + + if settings.stats.ignored_commands.iter().any(|c| c == prefix) { + continue; + } + + total_unignored += 1; + commands.insert(command); + + split_at_pipe(command) + .iter() + .map(|l| { + let command = l.trim(); + commands.insert(command); + command + }) + .collect::<Vec<_>>() + .windows(ngram_size) + .for_each(|w| { + *prefixes + .entry(w.iter().map(|c| interesting_command(settings, c)).collect()) + .or_default() += 1; + }); + } + + let unique = commands.len(); + let mut top = prefixes.into_iter().collect::<Vec<_>>(); + + top.sort_unstable_by_key(|x| std::cmp::Reverse(x.1)); + top.truncate(count); + + if top.is_empty() { + return None; + } + + Some(Stats { + unique_commands: unique, + total_commands: total_unignored, + top: top + .into_iter() + .map(|t| (t.0.into_iter().map(|s| s.to_string()).collect(), t.1)) + .collect(), + }) +} + +#[cfg(test)] +mod tests { + use crate::atuin_client::history::History; + use crate::atuin_client::settings::Settings; + use time::OffsetDateTime; + + use super::compute; + use super::{interesting_command, split_at_pipe, strip_leading_env_vars}; + + #[test] + fn ignored_env_vars() { + let settings = Settings::utc(); + + let history: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("FOO='BAR=🚀' echo foo") + .cwd("/") + .build() + .into(); + + let stats = compute(&settings, &[history], 10, 1).expect("failed to compute stats"); + assert_eq!(stats.top.first().unwrap().0, vec!["echo"]); + } + + #[test] + fn ignored_commands() { + let mut settings = Settings::utc(); + settings.stats.ignored_commands.push("cd".to_string()); + + let history = [ + History::import() + .timestamp(OffsetDateTime::now_utc()) + .command("cd foo") + .build() + .into(), + History::import() + .timestamp(OffsetDateTime::now_utc()) + .command("cargo build stuff") + .build() + .into(), + ]; + + let stats = compute(&settings, &history, 10, 1).expect("failed to compute stats"); + assert_eq!(stats.total_commands, 1); + assert_eq!(stats.unique_commands, 1); + } + + #[test] + fn interesting_commands() { + let settings = Settings::utc(); + + assert_eq!(interesting_command(&settings, "cargo"), "cargo"); + assert_eq!( + interesting_command(&settings, "cargo build foo bar"), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo cargo build foo bar"), + "cargo build" + ); + assert_eq!(interesting_command(&settings, "sudo"), "sudo"); + } + + // Test with spaces in the common_prefix + #[test] + fn interesting_commands_spaces() { + let mut settings = Settings::utc(); + settings.stats.common_prefix.push("sudo test".to_string()); + + assert_eq!(interesting_command(&settings, "sudo test"), "sudo test"); + assert_eq!(interesting_command(&settings, "sudo test "), "sudo test"); + assert_eq!(interesting_command(&settings, "sudo test foo bar"), "foo"); + assert_eq!( + interesting_command(&settings, "sudo test foo bar"), + "foo" + ); + + // Works with a common_subcommand as well + assert_eq!( + interesting_command(&settings, "sudo test cargo build foo bar"), + "cargo build" + ); + + // We still match on just the sudo prefix + assert_eq!(interesting_command(&settings, "sudo"), "sudo"); + assert_eq!(interesting_command(&settings, "sudo foo"), "foo"); + } + + // Test with spaces in the common_subcommand + #[test] + fn interesting_commands_spaces_subcommand() { + let mut settings = Settings::utc(); + settings + .stats + .common_subcommands + .push("cargo build".to_string()); + + assert_eq!(interesting_command(&settings, "cargo build"), "cargo build"); + assert_eq!( + interesting_command(&settings, "cargo build "), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "cargo build foo bar"), + "cargo build foo" + ); + + // Works with a common_prefix as well + assert_eq!( + interesting_command(&settings, "sudo cargo build foo bar"), + "cargo build foo" + ); + + // We still match on just cargo as a subcommand + assert_eq!(interesting_command(&settings, "cargo"), "cargo"); + assert_eq!(interesting_command(&settings, "cargo foo"), "cargo foo"); + } + + // Test with spaces in the common_prefix and common_subcommand + #[test] + fn interesting_commands_spaces_both() { + let mut settings = Settings::utc(); + settings.stats.common_prefix.push("sudo test".to_string()); + settings + .stats + .common_subcommands + .push("cargo build".to_string()); + + assert_eq!( + interesting_command(&settings, "sudo test cargo build"), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo test cargo build"), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo test cargo build "), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo test cargo build foo bar"), + "cargo build foo" + ); + } + + #[test] + fn split_simple() { + assert_eq!(split_at_pipe("fd | rg"), ["fd ", " rg"]); + } + + #[test] + fn split_multi() { + assert_eq!( + split_at_pipe("kubectl | jq | rg"), + ["kubectl ", " jq ", " rg"] + ); + } + + #[test] + fn split_simple_quoted() { + assert_eq!( + split_at_pipe("foo | bar 'baz {} | quux' | xyzzy"), + ["foo ", " bar 'baz {} | quux' ", " xyzzy"] + ); + } + + #[test] + fn split_multi_quoted() { + assert_eq!( + split_at_pipe("foo | bar 'baz \"{}\" | quux' | xyzzy"), + ["foo ", " bar 'baz \"{}\" | quux' ", " xyzzy"] + ); + } + + #[test] + fn escaped_pipes() { + assert_eq!( + split_at_pipe("foo | bar baz \\| quux"), + ["foo ", " bar baz \\| quux"] + ); + } + + #[test] + fn emoji() { + assert_eq!( + split_at_pipe("git commit -m \"🚀\""), + ["git commit -m \"🚀\""] + ); + } + + #[test] + fn starts_with_pipe() { + assert_eq!( + split_at_pipe("| sed 's/[0-9a-f]//g'"), + ["", " sed 's/[0-9a-f]//g'"] + ); + } + + #[test] + fn starts_with_spaces_and_pipe() { + assert_eq!( + split_at_pipe(" | sed 's/[0-9a-f]//g'"), + [" ", " sed 's/[0-9a-f]//g'"] + ); + } + + #[test] + fn strip_leading_env_vars_simple() { + assert_eq!( + strip_leading_env_vars("FOO=bar BAZ=quux echo foo"), + "echo foo" + ); + } + + #[test] + fn strip_leading_env_vars_quoted_single() { + assert_eq!(strip_leading_env_vars("FOO='BAR=baz' echo foo"), "echo foo"); + } + + #[test] + fn strip_leading_env_vars_quoted_double() { + assert_eq!( + strip_leading_env_vars("FOO=\"BAR=baz\" echo foo"), + "echo foo" + ); + } + + #[test] + fn strip_leading_env_vars_quoted_single_and_double() { + assert_eq!( + strip_leading_env_vars("FOO='BAR=\"baz\"' echo foo \"BAR=quux\""), + "echo foo \"BAR=quux\"" + ); + } + + #[test] + fn strip_leading_env_vars_emojis() { + assert_eq!( + strip_leading_env_vars("FOO='BAR=🚀' echo foo \"BAR=quux\" foo"), + "echo foo \"BAR=quux\" foo" + ); + } + + #[test] + fn strip_leading_env_vars_name_same_as_command() { + assert_eq!(strip_leading_env_vars("FOO='bar' bar baz"), "bar baz"); + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/capture.rs b/crates/turtle/src/atuin_pty_proxy/capture.rs new file mode 100644 index 00000000..97ac9b8f --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/capture.rs @@ -0,0 +1,467 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; + +use crate::atuin_pty_proxy::osc133::{Event, Params, Parser, Zone}; + +const HISTORY_ID_PARAM: &str = "history_id"; +const SESSION_ID_PARAM: &str = "session_id"; +const MAX_OUTPUT_CAPTURE_BYTES: usize = 1024 * 1024; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CommandCapture { + pub prompt: String, + pub command: String, + pub output: String, + pub exit_code: Option<i32>, + pub history_id: Option<String>, + pub session_id: Option<String>, + pub output_truncated: bool, + pub output_observed_bytes: u64, +} + +pub type CommandCaptureSink = Box<dyn Fn(CommandCapture) + Send + 'static>; + +#[derive(Default)] +struct CaptureBuffers { + prompt: Vec<u8>, + command: Vec<u8>, + output: Vec<u8>, + output_observed_bytes: u64, + output_truncated: bool, + exit_code: Option<i32>, + history_id: Option<String>, + session_id: Option<String>, +} + +pub(crate) struct CommandCaptureTracker { + parser: Parser, + zone: Zone, + buffers: CaptureBuffers, + cols: Arc<AtomicU16>, +} + +impl CommandCaptureTracker { + pub(crate) fn new(cols: Arc<AtomicU16>) -> Self { + Self { + parser: Parser::new(), + zone: Zone::Unknown, + buffers: CaptureBuffers::default(), + cols, + } + } + + pub(crate) fn push(&mut self, data: &[u8], mut on_capture: impl FnMut(CommandCapture)) { + let mut events = Vec::new(); + self.parser + .push_located(data, |located| events.push(located)); + + let mut start = 0; + for located in events { + let marker_start = located.start_offset.min(data.len()).max(start); + let offset = located.offset.min(data.len()); + self.append(&data[start..marker_start]); + self.handle_event(located.event, &located.params, &mut on_capture); + self.zone = located.zone; + start = offset; + } + + let append_end = self + .parser + .incomplete_osc_sequence_start() + .map_or(data.len(), |sequence_start| { + sequence_start.min(data.len()).max(start) + }); + if start < append_end { + self.append(&data[start..append_end]); + } + } + + fn append(&mut self, data: &[u8]) { + match self.zone { + Zone::Prompt => self.buffers.prompt.extend_from_slice(data), + Zone::Input => self.buffers.command.extend_from_slice(data), + Zone::Output => self.append_output(data), + Zone::Unknown => {} + } + } + + fn append_output(&mut self, data: &[u8]) { + self.buffers.output_observed_bytes = self + .buffers + .output_observed_bytes + .saturating_add(data.len() as u64); + + if self.buffers.output_truncated { + return; + } + + let remaining = MAX_OUTPUT_CAPTURE_BYTES.saturating_sub(self.buffers.output.len()); + let retained = data.len().min(remaining); + self.buffers.output_truncated = retained < data.len(); + + if retained > 0 { + self.buffers.output.extend_from_slice(&data[..retained]); + } + } + + fn handle_event( + &mut self, + event: Event, + params: &Params, + on_capture: &mut impl FnMut(CommandCapture), + ) { + match event { + Event::PromptStart => { + if self.zone != Zone::Prompt { + self.buffers = CaptureBuffers::default(); + } + } + Event::CommandStart | Event::CommandExecuted => {} + Event::CommandFinished { exit_code } => { + let Some(history_id) = params.get(HISTORY_ID_PARAM).map(str::to_owned) else { + return; + }; + + if exit_code.is_some() || self.buffers.exit_code.is_none() { + self.buffers.exit_code = exit_code; + } + self.buffers.history_id = Some(history_id); + self.buffers.session_id = params.get(SESSION_ID_PARAM).map(str::to_owned); + + if let Some(capture) = self.finish_capture() { + on_capture(capture); + } + } + } + } + + fn finish_capture(&mut self) -> Option<CommandCapture> { + let buffers = std::mem::take(&mut self.buffers); + let cols = self.cols.load(Ordering::Relaxed).max(1); + let prompt = render_plain_text(&buffers.prompt, cols); + let command = render_plain_text(&buffers.command, cols) + .trim_matches(|c| c == '\r' || c == '\n') + .to_string(); + let output = render_plain_text(&buffers.output, cols); + let output_truncated = buffers.output_truncated; + let output_observed_bytes = buffers.output_observed_bytes; + let exit_code = buffers.exit_code; + let history_id = buffers.history_id; + let session_id = buffers.session_id; + + if command.is_empty() && output.is_empty() { + return None; + } + + Some(CommandCapture { + prompt, + command, + output, + exit_code, + history_id, + session_id, + output_truncated, + output_observed_bytes, + }) + } +} + +const CLEAN_TEXT_MAX_ROWS: usize = 10_000; + +fn render_plain_text(bytes: &[u8], cols: u16) -> String { + if bytes.is_empty() { + return String::new(); + } + + let cols = cols.max(1); + let mut parser = vt100::Parser::new(estimated_rows(bytes, cols), cols, 0); + parser.process(bytes); + normalize_screen_contents(&parser.screen().contents()) +} + +fn normalize_screen_contents(contents: &str) -> String { + let mut lines = contents.lines().map(str::trim_end).collect::<Vec<_>>(); + while lines.last().is_some_and(|line| line.is_empty()) { + lines.pop(); + } + lines.join("\n") +} + +fn estimated_rows(bytes: &[u8], cols: u16) -> u16 { + let newline_rows = bytes.iter().filter(|byte| **byte == b'\n').count() + 1; + let wrapped_rows = bytes.len() / cols as usize; + newline_rows + .saturating_add(wrapped_rows) + .saturating_add(1) + .clamp(1, CLEAN_TEXT_MAX_ROWS) as u16 +} + +#[cfg(test)] +mod tests { + use super::*; + + fn tracker(cols: u16) -> CommandCaptureTracker { + CommandCaptureTracker::new(Arc::new(AtomicU16::new(cols))) + } + + fn assert_no_terminal_controls(text: &str) { + assert!( + !text + .chars() + .any(|ch| ch.is_control() && ch != '\n' && ch != '\t'), + "text still contains terminal controls: {text:?}" + ); + } + + #[test] + fn command_text_collapses_terminal_echo_edits() { + assert_eq!(render_plain_text(b"e\x08echo hi", 80), "echo hi"); + assert_eq!( + render_plain_text( + b"e\x08echo\x08 \x08\x08 \x08\x08\x08e \x08\x08 \x08e\x08echo hi", + 80 + ), + "echo hi" + ); + assert_eq!(render_plain_text(b"echo hi", 80), "echo hi"); + } + + #[test] + fn text_cleaning_strips_ansi_and_terminal_controls() { + let text = render_plain_text( + b"\x1b[32mhi\x1b[0m\r\n% \r \r", + 80, + ); + + assert_eq!(text, "hi"); + assert_no_terminal_controls(&text); + } + + #[test] + fn text_cleaning_preserves_valid_utf8_after_backspace() { + let text = render_plain_text("🦀x\x08 \x08 crab".as_bytes(), 80); + + assert_eq!(text, "🦀 crab"); + assert_no_terminal_controls(&text); + } + + #[test] + fn command_text_replays_backspaces() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + let input = + b"\x1b]133;A\x07$ \x1b]133;B\x07e\x08echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ "; + tracker.push(input, |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].command, "echo hi"); + assert_eq!(captures[0].output, "hi"); + assert_no_terminal_controls(&captures[0].command); + assert_no_terminal_controls(&captures[0].output); + } + + #[test] + fn captures_complete_command() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;A\x07$ \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", + |capture| captures.push(capture), + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: "$".to_string(), + command: "echo hi".to_string(), + output: "hi".to_string(), + exit_code: Some(0), + history_id: Some("hist".to_string()), + session_id: Some("sess".to_string()), + output_truncated: false, + output_observed_bytes: 4, + }] + ); + } + + #[test] + fn strips_ansi_and_split_markers() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;A\x07\x1b[32m%\x1b[0m ", |_| {}); + tracker.push(b"\x1b]133;B\x07ls\x1b]133;C", |_| {}); + tracker.push( + b"\x07\x1b[31mfile\x1b[0m\r\n\x1b]133;D;1;history_id=hist;session_id=sess\x07\x1b]133;A\x07% ", + |capture| { + captures.push(capture); + }, + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: "%".to_string(), + command: "ls".to_string(), + output: "file".to_string(), + exit_code: Some(1), + history_id: Some("hist".to_string()), + session_id: Some("sess".to_string()), + output_truncated: false, + output_observed_bytes: 15, + }] + ); + } + + #[test] + fn duplicate_prompt_start_does_not_reset_prompt_capture() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;A\x07$ \x1b]133;A\x07continued \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].prompt, "$ continued"); + assert_eq!(captures[0].command, "echo hi"); + assert_eq!(captures[0].output, "hi"); + } + + #[test] + fn bare_finish_without_metadata_is_ignored() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { + captures.push(capture); + }); + + tracker.push(b"\x1b]133;A\x07$ ", |capture| captures.push(capture)); + + assert!(captures.is_empty()); + } + + #[test] + fn bare_finish_before_metadata_in_same_push_ignored() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;1\x07\x1b]133;D;0;history_id=018f;session_id=abcd\x07", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].exit_code, Some(0)); + assert_eq!(captures[0].history_id.as_deref(), Some("018f")); + assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); + } + + #[test] + fn metadata_arriving_after_bare_finish_across_pushes() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { + captures.push(capture); + }); + tracker.push(b"\x1b]133;D;0;history_id=018f", |capture| { + captures.push(capture) + }); + + assert!(captures.is_empty()); + + tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].exit_code, Some(0)); + assert_eq!(captures[0].history_id.as_deref(), Some("018f")); + assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); + } + + #[test] + fn split_finish_marker_is_not_counted_as_output() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f", + |capture| { + captures.push(capture); + }, + ); + assert!(captures.is_empty()); + + tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].output_observed_bytes, 10); + } + + #[test] + fn captures_output_with_history_metadata_from_d_marker() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f;session_id=abcd\x07", + |capture| captures.push(capture), + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: String::new(), + command: String::new(), + output: "line one".to_string(), + exit_code: Some(0), + history_id: Some("018f".to_string()), + session_id: Some("abcd".to_string()), + output_truncated: false, + output_observed_bytes: 10, + }] + ); + } + + #[test] + fn output_capture_is_capped_and_reports_observed_bytes() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + let mut input = b"\x1b]133;C\x07".to_vec(); + input.extend(std::iter::repeat_n(b'x', MAX_OUTPUT_CAPTURE_BYTES + 10)); + input.extend_from_slice(b"\x1b]133;D;0;history_id=big;session_id=session-1\x07"); + + tracker.push(&input, |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert!(captures[0].output_truncated); + assert_eq!( + captures[0].output_observed_bytes, + (MAX_OUTPUT_CAPTURE_BYTES + 10) as u64 + ); + } + + #[test] + fn resets_buffers_between_c_d_only_captures() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07first\r\n\x1b]133;D;0;history_id=one\x07\x1b]133;C\x07second\r\n\x1b]133;D;1;history_id=two\x07", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 2); + assert_eq!(captures[0].output, "first"); + assert_eq!(captures[0].history_id.as_deref(), Some("one")); + assert_eq!(captures[1].output, "second"); + assert_eq!(captures[1].history_id.as_deref(), Some("two")); + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/debug.rs b/crates/turtle/src/atuin_pty_proxy/debug.rs new file mode 100644 index 00000000..bf311281 --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/debug.rs @@ -0,0 +1,53 @@ +use crate::atuin_pty_proxy::osc133::{Event, Parser}; + +pub(crate) const RESET: &[u8] = b"\x1b[0m"; + +pub(crate) struct Osc133DebugHighlighter { + parser: Parser, +} + +impl Osc133DebugHighlighter { + pub(crate) fn new() -> Self { + Self { + parser: Parser::new(), + } + } + + pub(crate) fn render(&mut self, data: &[u8]) -> Vec<u8> { + let mut events = Vec::new(); + self.parser + .push_located(data, |located| events.push(located)); + + if events.is_empty() { + return data.to_vec(); + } + + let mut rendered = Vec::with_capacity(data.len() + (events.len() * 64)); + let mut start = 0; + + for located in events { + let offset = located.offset.min(data.len()); + if offset > start { + rendered.extend_from_slice(&data[start..offset]); + } + + rendered.extend_from_slice(event_label(&located.event)); + rendered.extend_from_slice(RESET); + start = offset; + } + + rendered.extend_from_slice(&data[start..]); + rendered + } +} + +fn event_label(event: &Event) -> &'static [u8] { + match event { + Event::PromptStart => b"\x1b[1;37;45m[OSC133:A prompt]\x1b[0m", + Event::CommandStart => b"\x1b[1;30;43m[OSC133:B input]\x1b[0m", + Event::CommandExecuted => b"\x1b[1;30;46m[OSC133:C output]\x1b[0m", + Event::CommandFinished { exit_code: Some(0) } => b"\x1b[1;37;42m[OSC133:D exit=0]\x1b[0m", + Event::CommandFinished { exit_code: Some(_) } => b"\x1b[1;37;41m[OSC133:D exit!=0]\x1b[0m", + Event::CommandFinished { exit_code: None } => b"\x1b[1;37;44m[OSC133:D exit=?]\x1b[0m", + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/mod.rs b/crates/turtle/src/atuin_pty_proxy/mod.rs new file mode 100644 index 00000000..612943fa --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/mod.rs @@ -0,0 +1,17 @@ +#[cfg(unix)] +mod capture; +#[cfg(unix)] +mod debug; +#[cfg(unix)] +mod osc133; +#[cfg(unix)] +mod pty_proxy; +#[cfg(unix)] +mod runtime; +#[cfg(unix)] +mod screen; + +#[cfg(unix)] +pub use capture::{CommandCapture, CommandCaptureSink}; +#[cfg(unix)] +pub use pty_proxy::PtyProxy; diff --git a/crates/turtle/src/atuin_pty_proxy/osc133.rs b/crates/turtle/src/atuin_pty_proxy/osc133.rs new file mode 100644 index 00000000..5b70f0aa --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/osc133.rs @@ -0,0 +1,900 @@ +//! Streaming parser for OSC 133 (FinalTerm semantic prompt) escape sequences. +//! +//! OSC 133 marks four regions of a shell interaction: +//! +//! | Marker | Meaning | +//! |--------|--------------------------------------| +//! | A | Prompt is about to be printed | +//! | B | Prompt ended — command input begins | +//! | C | Command submitted — output begins | +//! | D[;n] | Command finished with exit code *n* | +//! +//! The wire format is `ESC ] 133 ; <cmd> [; <params>] ST` where ST is BEL +//! (0x07), ESC \ (0x1B 0x5C), or C1 ST (0x9C). +//! +//! # Design goals +//! +//! * **Transparent** — the parser observes the byte stream without modifying it; +//! the caller remains responsible for forwarding bytes to their destination. +//! * **Bounded** — OSC parameter buffering is capped so malformed output cannot +//! grow memory without limit. +//! * **Non-blocking** — [`Parser::push`] processes whatever bytes are available +//! and returns immediately. +//! * **Extensible** — marker parameters are preserved so Atuin-specific metadata +//! can ride alongside standard OSC 133 markers. + +/// Events emitted when an OSC 133 marker is detected. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Event { + /// `ESC ] 133 ; A ST` — the shell is about to display its prompt. + PromptStart, + /// `ESC ] 133 ; B ST` — the prompt has ended; the user may type a command. + CommandStart, + /// `ESC ] 133 ; C ST` — the command has been submitted for execution. + CommandExecuted, + /// `ESC ] 133 ; D [; <exit_code>] ST` — command output is complete. + CommandFinished { + /// The exit code reported after the `;`, if present and valid. + exit_code: Option<i32>, + }, +} + +/// Parameters attached to an OSC 133 marker. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Params { + items: Vec<Param>, +} + +impl Params { + /// Iterate over all marker parameters in order. + #[cfg(test)] + #[inline] + pub fn iter(&self) -> impl Iterator<Item = &Param> { + self.items.iter() + } + + /// Return the value for the first `key=value` parameter with this key. + #[inline] + pub fn get(&self, key: &str) -> Option<&str> { + self.items.iter().find_map(|item| match item { + Param::KeyValue { + key: item_key, + value, + } if item_key == key => Some(value.as_str()), + Param::Value(_) | Param::KeyValue { .. } => None, + }) + } +} + +/// A single OSC 133 marker parameter. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Param { + /// A positional parameter without an equals sign. + Value(String), + /// A `key=value` parameter. + KeyValue { key: String, value: String }, +} + +/// An OSC 133 event with its position in the most recent input chunk. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LocatedEvent { + /// The OSC 133 event that was parsed. + pub event: Event, + /// Offset where this marker starts in the current chunk. + /// + /// If a marker started in an earlier [`Parser::push_located`] call, this is + /// `0` in the chunk that completed the marker. + pub start_offset: usize, + /// Offset immediately after this marker's terminator in the current chunk. + /// + /// If a marker spans multiple [`Parser::push_located`] calls, this is still + /// the offset in the chunk that completed the marker. + pub offset: usize, + /// The semantic zone after applying this event. + pub zone: Zone, + /// Metadata parameters attached to this marker. + pub params: Params, +} + +/// The current semantic zone as determined by the most recent OSC 133 marker. +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +#[expect(dead_code)] +pub enum Zone { + /// No marker seen yet, or after a `D` marker (between commands). + #[default] + Unknown, + /// Between `A` and `B` — the shell is rendering its prompt. + Prompt, + /// Between `B` and `C` — the user is editing a command line. + Input, + /// Between `C` and `D` — command output is being produced. + Output, +} + +// --------------------------------------------------------------------------- +// Internal constants +// --------------------------------------------------------------------------- + +const ESC: u8 = 0x1B; +const BEL: u8 = 0x07; +const C1_ST: u8 = 0x9C; +const BACKSLASH: u8 = b'\\'; +const RIGHT_BRACKET: u8 = b']'; + +/// Maximum bytes we'll buffer for the OSC parameter string. This is large enough +/// for Atuin metadata such as history/session IDs while still bounding malformed +/// OSC sequences. +const PARAM_BUF_CAP: usize = 512; + +// --------------------------------------------------------------------------- +// State machine +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum State { + /// Normal pass-through. + Ground, + /// Saw ESC (0x1B). + Esc, + /// Inside an OSC sequence (`ESC ]`), accumulating parameter bytes. + OscParam, + /// Inside an OSC sequence, saw ESC — next byte decides if this is `ESC \` + /// (string terminator) or something else. + OscEsc, +} + +/// A streaming, zero-allocation parser for OSC 133 escape sequences. +/// +/// Feed arbitrary byte slices into [`Parser::push`]. The parser detects +/// OSC 133 markers and reports [`Event`]s through a caller-supplied callback +/// without modifying the data. It can sit transparently between a PTY reader +/// and stdout. +pub struct Parser { + state: State, + zone: Zone, + sequence_start: Option<usize>, + param_buf: [u8; PARAM_BUF_CAP], + param_len: usize, +} + +impl Default for Parser { + fn default() -> Self { + Self::new() + } +} + +impl Parser { + /// Create a new parser in the initial (ground / unknown-zone) state. + #[inline] + pub fn new() -> Self { + Self { + state: State::Ground, + zone: Zone::Unknown, + sequence_start: None, + param_buf: [0u8; PARAM_BUF_CAP], + param_len: 0, + } + } + + /// The current semantic zone based on markers seen so far. + #[inline] + #[expect(dead_code)] + pub fn zone(&self) -> Zone { + self.zone + } + + /// Start offset of an incomplete OSC sequence in the most recent chunk. + #[inline] + pub(crate) fn incomplete_osc_sequence_start(&self) -> Option<usize> { + matches!(self.state, State::OscParam | State::OscEsc) + .then(|| self.sequence_start.unwrap_or(0)) + } + + /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker + /// found. + /// + /// All bytes in `data` should still be forwarded to the terminal by the + /// caller — this method only *observes* the stream. + #[cfg(test)] + #[inline] + pub fn push(&mut self, data: &[u8], mut on_event: impl FnMut(Event)) { + self.push_located(data, |located| on_event(located.event)); + } + + /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker + /// found with its byte offset in this chunk. + /// + /// The offset points to the first byte after the marker terminator, making + /// it suitable for callers that need to split the original chunk at marker + /// boundaries. + #[inline] + pub fn push_located(&mut self, data: &[u8], mut on_event: impl FnMut(LocatedEvent)) { + self.sequence_start = (self.state != State::Ground).then_some(0); + + for (offset, &byte) in data.iter().enumerate() { + match self.state { + State::Ground => { + if byte == ESC { + self.state = State::Esc; + self.sequence_start = Some(offset); + } + } + State::Esc => { + if byte == RIGHT_BRACKET { + self.state = State::OscParam; + self.param_len = 0; + } else { + self.state = State::Ground; + self.sequence_start = None; + } + } + State::OscParam => { + if byte == BEL || byte == C1_ST { + self.dispatch(offset + 1, &mut on_event); + self.state = State::Ground; + self.sequence_start = None; + } else if byte == ESC { + self.state = State::OscEsc; + } else if self.param_len < PARAM_BUF_CAP { + self.param_buf[self.param_len] = byte; + self.param_len += 1; + } + // If param_len == PARAM_BUF_CAP we silently stop + // accumulating — dispatch will ignore non-133 sequences. + } + State::OscEsc => { + if byte == BACKSLASH { + self.dispatch(offset + 1, &mut on_event); + } + // Whether we got a valid ST or not, return to ground. + // (A new ESC ] would restart accumulation via the Ground + // -> Esc -> OscParam path on the *next* byte.) + self.state = State::Ground; + self.sequence_start = None; + } + } + } + } + + /// Inspect the accumulated parameter buffer. If it holds an OSC 133 + /// payload, emit the corresponding [`Event`] and update the zone. + #[inline] + fn dispatch(&mut self, offset: usize, on_event: &mut impl FnMut(LocatedEvent)) { + let payload = &self.param_buf[..self.param_len]; + + if payload.len() < 5 || &payload[..4] != b"133;" { + return; + } + + if payload.len() > 5 && payload[5] != b';' { + return; + } + + let metadata = payload.get(6..).unwrap_or_default(); + let cmd = payload[4]; + let (event, params) = match cmd { + b'A' => { + self.zone = Zone::Prompt; + (Event::PromptStart, parse_params(metadata)) + } + b'B' => { + self.zone = Zone::Input; + (Event::CommandStart, parse_params(metadata)) + } + b'C' => { + self.zone = Zone::Output; + (Event::CommandExecuted, parse_params(metadata)) + } + b'D' => { + let (exit_code, params) = parse_command_finished_params(metadata); + self.zone = Zone::Unknown; + (Event::CommandFinished { exit_code }, params) + } + _ => return, + }; + + on_event(LocatedEvent { + event, + start_offset: self.sequence_start.unwrap_or(0), + offset, + zone: self.zone, + params, + }); + } +} + +fn parse_command_finished_params(metadata: &[u8]) -> (Option<i32>, Params) { + if metadata.is_empty() { + return (None, Params::default()); + } + + let Some(separator) = metadata.iter().position(|byte| *byte == b';') else { + return parse_exit_code(metadata).map_or_else( + || (None, parse_params(metadata)), + |exit_code| (Some(exit_code), Params::default()), + ); + }; + + let (first, rest) = metadata.split_at(separator); + let rest = &rest[1..]; + + parse_exit_code(first).map_or_else( + || (None, parse_params(metadata)), + |exit_code| (Some(exit_code), parse_params(rest)), + ) +} + +fn parse_exit_code(code: &[u8]) -> Option<i32> { + if code.is_empty() { + return None; + } + + std::str::from_utf8(code) + .ok() + .and_then(|code| code.parse::<i32>().ok()) +} + +fn parse_params(metadata: &[u8]) -> Params { + let items = metadata + .split(|byte| *byte == b';') + .filter(|part| !part.is_empty()) + .map(parse_param) + .collect(); + + Params { items } +} + +fn parse_param(param: &[u8]) -> Param { + let param = String::from_utf8_lossy(param); + + if let Some((key, value)) = param.split_once('=') { + return Param::KeyValue { + key: key.to_string(), + value: value.to_string(), + }; + } + + Param::Value(param.into_owned()) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + /// Collect all events from a single `push` call. + fn parse_events(data: &[u8]) -> Vec<Event> { + let mut parser = Parser::new(); + let mut events = Vec::new(); + parser.push(data, |e| events.push(e)); + events + } + + // -- Basic event detection ------------------------------------------------ + + #[test] + fn detect_prompt_start_bel() { + let data = b"\x1b]133;A\x07"; + assert_eq!(parse_events(data), vec![Event::PromptStart]); + } + + #[test] + fn detect_prompt_start_st() { + let data = b"\x1b]133;A\x1b\\"; + assert_eq!(parse_events(data), vec![Event::PromptStart]); + } + + #[test] + fn detect_command_start_bel() { + let data = b"\x1b]133;B\x07"; + assert_eq!(parse_events(data), vec![Event::CommandStart]); + } + + #[test] + fn detect_command_start_st() { + let data = b"\x1b]133;B\x1b\\"; + assert_eq!(parse_events(data), vec![Event::CommandStart]); + } + + #[test] + fn detect_command_executed_bel() { + let data = b"\x1b]133;C\x07"; + assert_eq!(parse_events(data), vec![Event::CommandExecuted]); + } + + #[test] + fn detect_command_executed_st() { + let data = b"\x1b]133;C\x1b\\"; + assert_eq!(parse_events(data), vec![Event::CommandExecuted]); + } + + #[test] + fn detect_command_finished_no_exit_code() { + let data = b"\x1b]133;D\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { exit_code: None }] + ); + } + + #[test] + fn detect_command_finished_exit_zero() { + let data = b"\x1b]133;D;0\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { exit_code: Some(0) }] + ); + } + + #[test] + fn detect_command_finished_exit_nonzero() { + let data = b"\x1b]133;D;127\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { + exit_code: Some(127) + }] + ); + } + + #[test] + fn detect_command_finished_negative_exit_code() { + let data = b"\x1b]133;D;-1\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { + exit_code: Some(-1) + }] + ); + } + + #[test] + fn detect_command_finished_exit_code_st() { + let data = b"\x1b]133;D;42\x1b\\"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { + exit_code: Some(42) + }] + ); + } + + #[test] + fn invalid_exit_code_yields_none() { + let data = b"\x1b]133;D;abc\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { exit_code: None }] + ); + } + + // -- Zone tracking -------------------------------------------------------- + + #[test] + fn zone_starts_unknown() { + let parser = Parser::new(); + assert_eq!(parser.zone(), Zone::Unknown); + } + + #[test] + fn full_zone_cycle() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b]133;A\x07", |e| events.push(e)); + assert_eq!(parser.zone(), Zone::Prompt); + + parser.push(b"\x1b]133;B\x07", |e| events.push(e)); + assert_eq!(parser.zone(), Zone::Input); + + parser.push(b"\x1b]133;C\x07", |e| events.push(e)); + assert_eq!(parser.zone(), Zone::Output); + + parser.push(b"\x1b]133;D;0\x07", |e| events.push(e)); + assert_eq!(parser.zone(), Zone::Unknown); + + assert_eq!( + events, + vec![ + Event::PromptStart, + Event::CommandStart, + Event::CommandExecuted, + Event::CommandFinished { exit_code: Some(0) }, + ] + ); + } + + // -- Multiple events in one push ------------------------------------------ + + #[test] + fn multiple_events_single_push() { + let data = b"\x1b]133;A\x07$ \x1b]133;B\x07ls\n\x1b]133;C\x07file.txt\n\x1b]133;D;0\x07"; + let events = parse_events(data); + assert_eq!( + events, + vec![ + Event::PromptStart, + Event::CommandStart, + Event::CommandExecuted, + Event::CommandFinished { exit_code: Some(0) }, + ] + ); + } + + // -- Split across push boundaries ----------------------------------------- + + #[test] + fn split_esc_and_bracket() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b", |e| events.push(e)); + assert!(events.is_empty()); + + parser.push(b"]133;A\x07", |e| events.push(e)); + assert_eq!(events, vec![Event::PromptStart]); + } + + #[test] + fn split_mid_param() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b]13", |e| events.push(e)); + assert!(events.is_empty()); + + parser.push(b"3;D;42\x07", |e| events.push(e)); + assert_eq!( + events, + vec![Event::CommandFinished { + exit_code: Some(42) + }] + ); + } + + #[test] + fn split_before_terminator() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b]133;B", |e| events.push(e)); + assert!(events.is_empty()); + + parser.push(b"\x07", |e| events.push(e)); + assert_eq!(events, vec![Event::CommandStart]); + } + + #[test] + fn split_esc_backslash_terminator() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b]133;C\x1b", |e| events.push(e)); + assert!(events.is_empty()); + + parser.push(b"\\", |e| events.push(e)); + assert_eq!(events, vec![Event::CommandExecuted]); + } + + // -- Interleaved normal text ---------------------------------------------- + + #[test] + fn normal_text_before_and_after() { + let data = b"hello world\x1b]133;A\x07prompt text\x1b]133;B\x07command"; + let events = parse_events(data); + assert_eq!(events, vec![Event::PromptStart, Event::CommandStart]); + } + + // -- Non-133 OSC sequences (should be ignored) ---------------------------- + + #[test] + fn non_133_osc_ignored() { + let data = b"\x1b]0;window title\x07\x1b]133;A\x07"; + let events = parse_events(data); + assert_eq!(events, vec![Event::PromptStart]); + } + + #[test] + fn osc_7_ignored() { + let data = b"\x1b]7;file:///home/user\x07"; + assert!(parse_events(data).is_empty()); + } + + // -- Unknown command letter ----------------------------------------------- + + #[test] + fn unknown_command_ignored() { + let data = b"\x1b]133;Z\x07"; + assert!(parse_events(data).is_empty()); + } + + #[test] + fn marker_with_unexpected_trailing_bytes_ignored() { + let data = b"\x1b]133;ABC\x07"; + assert!(parse_events(data).is_empty()); + } + + // -- Malformed sequences -------------------------------------------------- + + #[test] + fn esc_followed_by_non_bracket() { + let data = b"\x1b[31m\x1b]133;A\x07"; + let events = parse_events(data); + assert_eq!(events, vec![Event::PromptStart]); + } + + #[test] + fn lone_esc_at_end_of_chunk() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b", |e| events.push(e)); + assert!(events.is_empty()); + + // Feed non-bracket to abort the escape, then a real sequence. + parser.push(b"x\x1b]133;A\x07", |e| events.push(e)); + assert_eq!(events, vec![Event::PromptStart]); + } + + #[test] + fn truncated_133_prefix() { + // "13" followed by terminator — not "133;" so no event. + let data = b"\x1b]13\x07"; + assert!(parse_events(data).is_empty()); + } + + #[test] + fn empty_osc() { + let data = b"\x1b]\x07"; + assert!(parse_events(data).is_empty()); + } + + // -- Buffer overflow (very long non-133 OSC) ------------------------------ + + #[test] + fn very_long_osc_does_not_panic() { + let mut data = Vec::new(); + data.extend_from_slice(b"\x1b]"); + data.extend(std::iter::repeat_n(b'x', 1000)); + data.push(BEL); + // Should not panic and should produce no event. + assert!(parse_events(&data).is_empty()); + } + + // -- Empty input ---------------------------------------------------------- + + #[test] + fn empty_input() { + assert!(parse_events(b"").is_empty()); + } + + #[test] + fn only_normal_text() { + let data = b"just some regular terminal output\r\n"; + assert!(parse_events(data).is_empty()); + } + + // -- Repeated prompts (empty command) ------------------------------------ + + #[test] + fn repeated_prompt_cycle() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + // User hits enter on an empty prompt twice. + let data = b"\x1b]133;A\x07$ \x1b]133;B\x07\x1b]133;D\x07\x1b]133;A\x07$ \x1b]133;B\x07"; + parser.push(data, |e| events.push(e)); + + assert_eq!( + events, + vec![ + Event::PromptStart, + Event::CommandStart, + Event::CommandFinished { exit_code: None }, + Event::PromptStart, + Event::CommandStart, + ] + ); + assert_eq!(parser.zone(), Zone::Input); + } + + // -- Byte-at-a-time feeding ----------------------------------------------- + + #[test] + fn byte_at_a_time() { + let data = b"\x1b]133;D;99\x07"; + let mut parser = Parser::new(); + let mut events = Vec::new(); + + for &byte in data { + parser.push(&[byte], |e| events.push(e)); + } + + assert_eq!( + events, + vec![Event::CommandFinished { + exit_code: Some(99) + }] + ); + } + + // -- Mixed terminators ---------------------------------------------------- + + #[test] + fn mixed_bel_and_st_terminators() { + let data = b"\x1b]133;A\x07\x1b]133;B\x1b\\\x1b]133;C\x07\x1b]133;D;1\x1b\\"; + let events = parse_events(data); + assert_eq!( + events, + vec![ + Event::PromptStart, + Event::CommandStart, + Event::CommandExecuted, + Event::CommandFinished { exit_code: Some(1) }, + ] + ); + } + + #[test] + fn detects_c1_st_terminator() { + let data = b"\x1b]133;A\x9c"; + assert_eq!(parse_events(data), vec![Event::PromptStart]); + } + + // -- Located event offsets ------------------------------------------------ + + #[test] + fn located_event_reports_offset_after_marker() { + let data = b"before\x1b]133;A\x07prompt"; + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(data, |e| events.push(e)); + + assert_eq!( + events, + vec![LocatedEvent { + event: Event::PromptStart, + start_offset: b"before".len(), + offset: b"before\x1b]133;A\x07".len(), + zone: Zone::Prompt, + params: Params::default(), + }] + ); + } + + #[test] + fn located_event_offset_is_relative_to_completing_chunk() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(b"\x1b]133;", |e| events.push(e)); + parser.push_located(b"D;42\x07after", |e| events.push(e)); + + assert_eq!( + events, + vec![LocatedEvent { + event: Event::CommandFinished { + exit_code: Some(42) + }, + start_offset: 0, + offset: b"D;42\x07".len(), + zone: Zone::Unknown, + params: Params::default(), + }] + ); + } + + #[test] + fn located_event_preserves_metadata_params() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located( + b"\x1b]133;D;127;history_id=018f;session_id=abcd;flag\x07", + |event| events.push(event), + ); + + assert_eq!(events.len(), 1); + let event = &events[0]; + assert_eq!( + event.event, + Event::CommandFinished { + exit_code: Some(127) + } + ); + assert_eq!(event.params.get("history_id"), Some("018f")); + assert_eq!(event.params.get("session_id"), Some("abcd")); + assert!( + event + .params + .iter() + .any(|param| param == &Param::Value("flag".to_string())) + ); + } + + #[test] + fn command_finished_metadata_without_exit_code_is_preserved() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(b"\x1b]133;D;history_id=018f;session_id=abcd\x07", |event| { + events.push(event); + }); + + assert_eq!(events.len(), 1); + let event = &events[0]; + assert_eq!(event.event, Event::CommandFinished { exit_code: None }); + assert_eq!(event.params.get("history_id"), Some("018f")); + assert_eq!(event.params.get("session_id"), Some("abcd")); + } + + // -- Default trait -------------------------------------------------------- + + #[test] + fn parser_default() { + let parser = Parser::default(); + assert_eq!(parser.zone(), Zone::Unknown); + } + + #[test] + fn zone_default() { + assert_eq!(Zone::default(), Zone::Unknown); + } + + // -- D with empty exit code field ----------------------------------------- + + #[test] + fn d_with_semicolon_but_empty_code() { + // "133;D;" — semicolon present but no digits. + let data = b"\x1b]133;D;\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { exit_code: None }] + ); + } + + // -- Consecutive OSC sequences without gap -------------------------------- + + #[test] + fn back_to_back_osc_no_gap() { + let data = b"\x1b]133;A\x07\x1b]133;B\x07"; + let events = parse_events(data); + assert_eq!(events, vec![Event::PromptStart, Event::CommandStart]); + } + + // -- CSI sequences interleaved (should not confuse parser) ---------------- + + #[test] + fn csi_sequences_ignored() { + // CSI (ESC [) color codes mixed with OSC 133. + let data = b"\x1b[32m\x1b]133;A\x07\x1b[0m$ \x1b]133;B\x07"; + let events = parse_events(data); + assert_eq!(events, vec![Event::PromptStart, Event::CommandStart]); + } + + // -- Large exit codes ----------------------------------------------------- + + #[test] + fn large_exit_code() { + let data = b"\x1b]133;D;2147483647\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { + exit_code: Some(i32::MAX) + }] + ); + } + + #[test] + fn overflow_exit_code_yields_none() { + let data = b"\x1b]133;D;9999999999999\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { exit_code: None }] + ); + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/pty_proxy.rs b/crates/turtle/src/atuin_pty_proxy/pty_proxy.rs new file mode 100644 index 00000000..8dde6f53 --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/pty_proxy.rs @@ -0,0 +1,231 @@ +use clap::{Args, Subcommand, ValueEnum}; + +use crate::atuin_pty_proxy::{CommandCaptureSink, runtime}; + +#[derive(Args, Debug)] +pub struct PtyProxy { + /// Highlight OSC 133 prompt, input, output, and exit-code regions + #[arg(long)] + debug_osc133: bool, + + #[command(subcommand)] + cmd: Option<Cmd>, +} + +#[derive(Subcommand, Debug)] +pub enum Cmd { + /// Print shell code to initialize atuin pty-proxy on shell startup + Init(Init), +} + +#[derive(Args, Debug)] +pub struct Init { + /// Shell to generate init for. If omitted, attempt auto-detection + #[arg(value_enum)] + shell: Option<Shell>, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] +#[value(rename_all = "lower")] +#[expect(clippy::enum_variant_names, clippy::doc_markdown)] +enum Shell { + /// Zsh setup + Zsh, + /// Bash setup + Bash, + /// Fish setup + Fish, + /// Nu setup + Nu, +} + +pub(crate) struct RuntimeOptions { + pub(crate) debug_osc133: bool, + pub(crate) command_capture_sink: Option<CommandCaptureSink>, +} + +impl RuntimeOptions { + fn new(debug_osc133: bool, command_capture_sink: Option<CommandCaptureSink>) -> Self { + Self { + debug_osc133: debug_osc133 || env_flag("ATUIN_PTY_PROXY_DEBUG"), + command_capture_sink, + } + } +} + +impl PtyProxy { + pub fn run(self, command_capture_sink: Option<CommandCaptureSink>) { + match self.cmd { + Some(Cmd::Init(init)) => { + if let Err(err) = init.run() { + eprintln!("atuin pty-proxy: {err}"); + std::process::exit(1); + } + } + None => runtime::main(RuntimeOptions::new(self.debug_osc133, command_capture_sink)), + } + } +} + +impl Init { + fn run(self) -> Result<(), String> { + let shell = detect_shell(self.shell)?; + let script = render_init(shell); + print!("{script}"); + Ok(()) + } +} + +fn detect_shell(cli_shell: Option<Shell>) -> Result<Shell, String> { + if let Some(shell) = cli_shell { + return Ok(shell); + } + + if let Ok(shell) = std::env::var("ATUIN_SHELL") + && let Some(shell) = shell_from_name(&shell) + { + return Ok(shell); + } + + if let Ok(shell) = std::env::var("SHELL") + && let Some(shell) = shell_from_name(&shell) + { + return Ok(shell); + } + + Err( + "could not detect a supported shell. Please specify one explicitly: bash, zsh, fish, or nu" + .to_string(), + ) +} + +fn shell_from_name(name: &str) -> Option<Shell> { + let shell = name + .trim() + .rsplit('/') + .next() + .unwrap_or(name) + .trim_start_matches('-') + .to_ascii_lowercase(); + + match shell.as_str() { + "bash" => Some(Shell::Bash), + "zsh" => Some(Shell::Zsh), + "fish" => Some(Shell::Fish), + "nu" => Some(Shell::Nu), + _ => None, + } +} + +fn env_flag(name: &str) -> bool { + std::env::var(name).is_ok_and(|value| { + matches!( + value.trim().to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) + }) +} + +fn render_init(shell: Shell) -> &'static str { + match shell { + Shell::Bash | Shell::Zsh => { + r#"if [[ "$-" == *i* ]] && [[ -t 0 ]] && [[ -t 1 ]]; then + _atuin_pty_proxy_tmux_current="${TMUX:-}" + _atuin_pty_proxy_tmux_previous="${ATUIN_PTY_PROXY_TMUX:-}" + + if [[ -z "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || [[ "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" ]]; then + export ATUIN_PTY_PROXY_ACTIVE=1 + export ATUIN_PTY_PROXY_TMUX="$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + fi + + unset _atuin_pty_proxy_tmux_current _atuin_pty_proxy_tmux_previous +fi +"# + } + Shell::Fish => { + r#"if status is-interactive; and test -t 0; and test -t 1 + set -l _atuin_pty_proxy_tmux_current "" + if set -q TMUX + set _atuin_pty_proxy_tmux_current "$TMUX" + end + + set -l _atuin_pty_proxy_tmux_previous "" + if set -q ATUIN_PTY_PROXY_TMUX + set _atuin_pty_proxy_tmux_previous "$ATUIN_PTY_PROXY_TMUX" + end + + if not set -q ATUIN_PTY_PROXY_ACTIVE + set -gx ATUIN_PTY_PROXY_ACTIVE 1 + set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + else if test "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" + set -gx ATUIN_PTY_PROXY_ACTIVE 1 + set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + end +end +"# + } + // Nushell cannot dynamically source the output of `atuin init nu`, + // so we only output the pty-proxy preamble here. Users must also set up + // `atuin init nu` separately. + Shell::Nu => { + r#"if (is-terminal --stdin) and (is-terminal --stdout) { + let tmux_current = ($env.TMUX? | default "") + let tmux_previous = ($env.ATUIN_PTY_PROXY_TMUX? | default "") + + if (($env.ATUIN_PTY_PROXY_ACTIVE? | default "") | is-empty) or ($tmux_current != $tmux_previous) { + $env.ATUIN_PTY_PROXY_ACTIVE = "1" + $env.ATUIN_PTY_PROXY_TMUX = $tmux_current + exec atuin pty-proxy + } +} +"# + } + } +} + +#[cfg(test)] +mod tests { + use super::{Shell, render_init, shell_from_name}; + + #[test] + fn shell_from_name_handles_paths() { + assert_eq!(shell_from_name("/bin/zsh"), Some(Shell::Zsh)); + assert_eq!(shell_from_name("/usr/local/bin/bash"), Some(Shell::Bash)); + assert_eq!(shell_from_name("fish"), Some(Shell::Fish)); + assert_eq!(shell_from_name("nu"), Some(Shell::Nu)); + } + + #[test] + fn posix_init_uses_exec_and_tmux_guard() { + let script = render_init(Shell::Bash); + assert!(script.contains("exec atuin pty-proxy")); + assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); + assert!(!script.contains("eval \"$(atuin init bash)\"")); + } + + #[test] + fn posix_init_has_no_double_braces() { + let script = render_init(Shell::Bash); + assert!(!script.contains("${{"), "double braces in bash init script"); + } + + #[test] + fn fish_init_uses_source() { + let script = render_init(Shell::Fish); + assert!(script.contains("exec atuin pty-proxy")); + assert!(!script.contains("atuin init fish | source")); + } + + #[test] + fn nu_init_uses_exec_and_tty_guard() { + let script = render_init(Shell::Nu); + assert!(script.contains("exec atuin pty-proxy")); + assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); + assert!(script.contains("is-terminal --stdin")); + assert!(script.contains("is-terminal --stdout")); + assert!(script.contains("ATUIN_PTY_PROXY_ACTIVE")); + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/runtime.rs b/crates/turtle/src/atuin_pty_proxy/runtime.rs new file mode 100644 index 00000000..37c77eef --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/runtime.rs @@ -0,0 +1,184 @@ +use std::io::{Read, Write}; +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::mpsc; + +use crossterm::terminal; +use portable_pty::{CommandBuilder, PtySize, native_pty_system}; + +use crate::atuin_pty_proxy::capture::CommandCaptureTracker; +use crate::atuin_pty_proxy::debug::{Osc133DebugHighlighter, RESET}; +use crate::atuin_pty_proxy::pty_proxy::RuntimeOptions; +use crate::atuin_pty_proxy::screen::{self, Msg}; + +pub(crate) fn main(options: RuntimeOptions) { + if let Err(e) = run(options) { + let _ = terminal::disable_raw_mode(); + eprintln!("atuin pty-proxy: {e:#}"); + std::process::exit(1); + } +} + +fn run(options: RuntimeOptions) -> eyre::Result<()> { + let (cols, rows) = terminal::size()?; + + let pty_system = native_pty_system(); + let pair = pty_system + .openpty(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }) + .map_err(|e| eyre::eyre!("{e:#}"))?; + + let sock_path = screen::socket_path(); + let _ = std::fs::remove_file(&sock_path); + + let mut cmd = CommandBuilder::new_default_prog(); + cmd.cwd(std::env::current_dir()?); + cmd.env("ATUIN_PTY_PROXY_SOCKET", sock_path.as_os_str()); + cmd.env("ATUIN_PTY_PROXY_ACTIVE", "1"); + + let mut child = pair + .slave + .spawn_command(cmd) + .map_err(|e| eyre::eyre!("{e:#}"))?; + + drop(pair.slave); + + let mut pty_reader = pair + .master + .try_clone_reader() + .map_err(|e| eyre::eyre!("{e:#}"))?; + let mut pty_writer = pair + .master + .take_writer() + .map_err(|e| eyre::eyre!("{e:#}"))?; + + let (msg_tx, msg_rx) = mpsc::sync_channel::<Msg>(64); + let current_cols = Arc::new(AtomicU16::new(cols.max(1))); + + screen::spawn_parser_thread(rows, cols, msg_rx); + screen::spawn_socket_server(sock_path.clone(), msg_tx.clone()); + spawn_resize_handler(pair.master, msg_tx.clone(), current_cols.clone())?; + + terminal::enable_raw_mode()?; + + let stdout_thread = std::thread::spawn(move || { + let mut stdout = std::io::stdout(); + let mut highlighter = options.debug_osc133.then(Osc133DebugHighlighter::new); + let mut capture_tracker = options + .command_capture_sink + .as_ref() + .map(|_| CommandCaptureTracker::new(current_cols)); + let mut buf = [0u8; 8192]; + + loop { + match pty_reader.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if let (Some(tracker), Some(sink)) = ( + capture_tracker.as_mut(), + options.command_capture_sink.as_ref(), + ) { + tracker.push(&buf[..n], sink); + } + + if let Some(highlighter) = highlighter.as_mut() { + let rendered = highlighter.render(&buf[..n]); + let _ = msg_tx.try_send(Msg::Data(rendered.clone())); + + if stdout.write_all(&rendered).is_err() { + break; + } + } else { + let _ = msg_tx.try_send(Msg::Data(buf[..n].to_vec())); + + if stdout.write_all(&buf[..n]).is_err() { + break; + } + } + let _ = stdout.flush(); + } + } + } + + if highlighter.is_some() { + let _ = stdout.write_all(RESET); + let _ = stdout.flush(); + } + }); + + std::thread::spawn(move || { + let mut stdin = std::io::stdin(); + let mut buf = [0u8; 8192]; + loop { + match stdin.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if pty_writer.write_all(&buf[..n]).is_err() { + break; + } + } + } + } + }); + + let status = child.wait()?; + let _ = stdout_thread.join(); + + let _ = terminal::disable_raw_mode(); + let _ = std::fs::remove_file(&sock_path); + + std::process::exit(process_exit_code(status.exit_code())); +} + +fn spawn_resize_handler( + master: Box<dyn portable_pty::MasterPty + Send>, + resize_tx: mpsc::SyncSender<Msg>, + current_cols: Arc<AtomicU16>, +) -> eyre::Result<()> { + use signal_hook::consts::SIGWINCH; + use signal_hook::iterator::Signals; + + let mut signals = Signals::new([SIGWINCH])?; + + std::thread::spawn(move || { + for _ in signals.forever() { + if let Ok((cols, rows)) = terminal::size() { + current_cols.store(cols.max(1), Ordering::Relaxed); + let _ = master.resize(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }); + let _ = resize_tx.try_send(Msg::Resize { rows, cols }); + } + } + }); + + Ok(()) +} + +fn process_exit_code(code: u32) -> i32 { + i32::try_from(code).unwrap_or(1) +} + +#[cfg(test)] +mod tests { + use super::process_exit_code; + + #[test] + fn process_exit_code_preserves_valid_values() { + assert_eq!(process_exit_code(0), 0); + assert_eq!(process_exit_code(127), 127); + assert_eq!(process_exit_code(i32::MAX as u32), i32::MAX); + } + + #[test] + fn process_exit_code_defaults_when_out_of_range() { + assert_eq!(process_exit_code(i32::MAX as u32 + 1), 1); + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/screen.rs b/crates/turtle/src/atuin_pty_proxy/screen.rs new file mode 100644 index 00000000..5b892e21 --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/screen.rs @@ -0,0 +1,104 @@ +use std::io::Write; +use std::os::unix::net::UnixListener; +use std::path::PathBuf; +use std::sync::mpsc::{self, Receiver, SyncSender}; + +pub(crate) enum Msg { + Data(Vec<u8>), + Resize { rows: u16, cols: u16 }, + ScreenRequest(mpsc::Sender<Vec<u8>>), +} + +pub(crate) fn socket_path() -> PathBuf { + let dir = std::env::temp_dir(); + dir.join(format!("atuin-pty-proxy-{}.sock", std::process::id())) +} + +pub(crate) fn spawn_parser_thread(rows: u16, cols: u16, msg_rx: Receiver<Msg>) { + std::thread::spawn(move || { + let mut parser = vt100::Parser::new(rows, cols, 0); + + loop { + let first = match msg_rx.recv() { + Ok(msg) => msg, + Err(_) => break, + }; + + handle_parser_msg(&mut parser, first); + + while let Ok(msg) = msg_rx.try_recv() { + handle_parser_msg(&mut parser, msg); + } + } + }); +} + +pub(crate) fn spawn_socket_server(sock_path: PathBuf, screen_tx: SyncSender<Msg>) { + std::thread::spawn(move || { + let listener = match UnixListener::bind(&sock_path) { + Ok(l) => l, + Err(e) => { + eprintln!("atuin pty-proxy: failed to bind socket: {e}"); + return; + } + }; + + for stream in listener.incoming() { + let mut stream = match stream { + Ok(s) => s, + Err(_) => break, + }; + + let (reply_tx, reply_rx) = mpsc::channel(); + if screen_tx.send(Msg::ScreenRequest(reply_tx)).is_err() { + break; + } + if let Ok(data) = reply_rx.recv() { + let _ = stream.write_all(&data); + let _ = stream.flush(); + } + } + }); +} + +/// Wire format written to the Unix socket: +/// +/// ```text +/// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] +/// [row_0_len: u32 BE][row_0_bytes...] +/// [row_1_len: u32 BE][row_1_bytes...] +/// ... +/// ``` +/// +/// Each row's bytes come from `screen.rows_formatted(0, cols)` and contain +/// pre-built ANSI escape sequences. The client can write them directly to +/// stdout without needing its own vt100 parser. +fn encode_screen(parser: &vt100::Parser) -> Vec<u8> { + let screen = parser.screen(); + let (rows, cols) = screen.size(); + let (cursor_row, cursor_col) = screen.cursor_position(); + + let mut buf: Vec<u8> = Vec::with_capacity(256 + (rows as usize * cols as usize)); + buf.extend_from_slice(&rows.to_be_bytes()); + buf.extend_from_slice(&cols.to_be_bytes()); + buf.extend_from_slice(&cursor_row.to_be_bytes()); + buf.extend_from_slice(&cursor_col.to_be_bytes()); + + for row_bytes in screen.rows_formatted(0, cols) { + let len = row_bytes.len() as u32; + buf.extend_from_slice(&len.to_be_bytes()); + buf.extend_from_slice(&row_bytes); + } + + buf +} + +fn handle_parser_msg(parser: &mut vt100::Parser, msg: Msg) { + match msg { + Msg::Data(data) => parser.process(&data), + Msg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols), + Msg::ScreenRequest(reply_tx) => { + let _ = reply_tx.send(encode_screen(parser)); + } + } +} diff --git a/crates/turtle/src/atuin_server/handlers/health.rs b/crates/turtle/src/atuin_server/handlers/health.rs new file mode 100644 index 00000000..aebd1e8f --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/health.rs @@ -0,0 +1,15 @@ +use axum::{Json, http, response::IntoResponse}; + +use serde::Serialize; + +#[derive(Serialize)] +pub struct HealthResponse { + pub status: &'static str, +} + +pub async fn health_check() -> impl IntoResponse { + ( + http::StatusCode::OK, + Json(HealthResponse { status: "healthy" }), + ) +} diff --git a/crates/turtle/src/atuin_server/handlers/history.rs b/crates/turtle/src/atuin_server/handlers/history.rs new file mode 100644 index 00000000..7f09161b --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/history.rs @@ -0,0 +1,237 @@ +use std::{collections::HashMap, convert::TryFrom}; + +use axum::{ + Json, + extract::{Path, Query, State}, + http::{HeaderMap, StatusCode}, +}; +use metrics::counter; +use time::{Month, UtcOffset}; +use tracing::{debug, error, instrument}; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::atuin_server::{ + router::{AppState, UserAuth}, + utils::client_version_min, +}; +use crate::atuin_server_database::{ + Database, + calendar::{TimePeriod, TimePeriodInfo}, + models::NewHistory, +}; + +use crate::atuin_common::api::*; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn count<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + match db.count_history_cached(&user).await { + // By default read out the cached value + Ok(count) => Ok(Json(CountResponse { count })), + + // If that fails, fallback on a full COUNT. Cache is built on a POST + // only + Err(_) => match db.count_history(&user).await { + Ok(count) => Ok(Json(CountResponse { count })), + Err(_) => Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)), + }, + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn list<DB: Database>( + req: Query<SyncHistoryRequest>, + UserAuth(user): UserAuth, + headers: HeaderMap, + state: State<AppState<DB>>, +) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let agent = headers + .get("user-agent") + .map_or("", |v| v.to_str().unwrap_or("")); + + let variable_page_size = client_version_min(agent, ">=15.0.0").unwrap_or(false); + + let page_size = if variable_page_size { + state.settings.page_size + } else { + 100 + }; + + if req.sync_ts.unix_timestamp_nanos() < 0 || req.history_ts.unix_timestamp_nanos() < 0 { + error!("client asked for history from < epoch 0"); + counter!("atuin_history_epoch_before_zero").increment(1); + + return Err( + ErrorResponse::reply("asked for history from before epoch 0") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + let history = db + .list_history(&user, req.sync_ts, req.history_ts, &req.host, page_size) + .await; + + if let Err(e) = history { + error!("failed to load history: {}", e); + return Err(ErrorResponse::reply("failed to load history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + let history: Vec<String> = history + .unwrap() + .iter() + .map(|i| i.data.to_string()) + .collect(); + + debug!( + "loaded {} items of history for user {}", + history.len(), + user.id + ); + + counter!("atuin_history_returned").increment(history.len() as u64); + + Ok(Json(SyncHistoryResponse { history })) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, + Json(req): Json<DeleteHistoryRequest>, +) -> Result<Json<MessageResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + // user_id is the ID of the history, as set by the user (the server has its own ID) + let deleted = db.delete_history(&user, req.client_id).await; + + if let Err(e) = deleted { + error!("failed to delete history: {}", e); + return Err(ErrorResponse::reply("failed to delete history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + Ok(Json(MessageResponse { + message: String::from("deleted OK"), + })) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn add<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, + Json(req): Json<Vec<AddHistoryRequest>>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + + debug!("request to add {} history items", req.len()); + counter!("atuin_history_uploaded").increment(req.len() as u64); + + let mut history: Vec<NewHistory> = req + .into_iter() + .map(|h| NewHistory { + client_id: h.id, + user_id: user.id, + hostname: h.hostname, + timestamp: h.timestamp, + data: h.data, + }) + .collect(); + + history.retain(|h| { + // keep if within limit, or limit is 0 (unlimited) + let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0; + + // Don't return an error here. We want to insert as much of the + // history list as we can, so log the error and continue going. + if !keep { + counter!("atuin_history_too_long").increment(1); + + tracing::warn!( + "history too long, got length {}, max {}", + h.data.len(), + settings.max_history_length + ); + } + + keep + }); + + if let Err(e) = database.add_history(&history).await { + error!("failed to add history: {}", e); + + return Err(ErrorResponse::reply("failed to add history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + Ok(()) +} + +#[derive(serde::Deserialize, Debug)] +pub struct CalendarQuery { + #[serde(default = "serde_calendar::zero")] + year: i32, + #[serde(default = "serde_calendar::one")] + month: u8, + + #[serde(default = "serde_calendar::utc")] + tz: UtcOffset, +} + +mod serde_calendar { + use time::UtcOffset; + + pub fn zero() -> i32 { + 0 + } + + pub fn one() -> u8 { + 1 + } + + pub fn utc() -> UtcOffset { + UtcOffset::UTC + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn calendar<DB: Database>( + Path(focus): Path<String>, + Query(params): Query<CalendarQuery>, + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> { + let focus = focus.as_str(); + + let year = params.year; + let month = Month::try_from(params.month).map_err(|e| ErrorResponseStatus { + error: ErrorResponse { + reason: e.to_string().into(), + }, + status: StatusCode::BAD_REQUEST, + })?; + + let period = match focus { + "year" => TimePeriod::Year, + "month" => TimePeriod::Month { year }, + "day" => TimePeriod::Day { year, month }, + _ => { + return Err(ErrorResponse::reply("invalid focus: use year/month/day") + .with_status(StatusCode::BAD_REQUEST)); + } + }; + + let db = &state.0.database; + let focus = db.calendar(&user, period, params.tz).await.map_err(|_| { + ErrorResponse::reply("failed to query calendar") + .with_status(StatusCode::INTERNAL_SERVER_ERROR) + })?; + + Ok(Json(focus)) +} diff --git a/crates/turtle/src/atuin_server/handlers/mod.rs b/crates/turtle/src/atuin_server/handlers/mod.rs new file mode 100644 index 00000000..7722d03e --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/mod.rs @@ -0,0 +1,60 @@ +use crate::atuin_common::api::{ErrorResponse, IndexResponse}; +use crate::atuin_server_database::Database; +use axum::{Json, extract::State, http, response::IntoResponse}; + +use crate::atuin_server::router::AppState; + +pub mod health; +pub mod history; +pub mod record; +pub mod status; +pub mod user; +pub mod v0; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +pub async fn index<DB: Database>(state: State<AppState<DB>>) -> Json<IndexResponse> { + let homage = r#""Through the fathomless deeps of space swims the star turtle Great A'Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld." -- Sir Terry Pratchett"#; + + let version = state + .settings + .fake_version + .clone() + .unwrap_or(VERSION.to_string()); + + Json(IndexResponse { + homage: homage.to_string(), + version, + }) +} + +impl IntoResponse for ErrorResponseStatus<'_> { + fn into_response(self) -> axum::response::Response { + (self.status, Json(self.error)).into_response() + } +} + +pub struct ErrorResponseStatus<'a> { + pub error: ErrorResponse<'a>, + pub status: http::StatusCode, +} + +pub trait RespExt<'a> { + fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a>; + fn reply(reason: &'a str) -> Self; +} + +impl<'a> RespExt<'a> for ErrorResponse<'a> { + fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a> { + ErrorResponseStatus { + error: self, + status, + } + } + + fn reply(reason: &'a str) -> ErrorResponse<'a> { + Self { + reason: reason.into(), + } + } +} diff --git a/crates/turtle/src/atuin_server/handlers/record.rs b/crates/turtle/src/atuin_server/handlers/record.rs new file mode 100644 index 00000000..63325606 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/record.rs @@ -0,0 +1,42 @@ +use axum::{Json, http::StatusCode, response::IntoResponse}; +use serde_json::json; +use tracing::instrument; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::atuin_server::router::UserAuth; + +use crate::atuin_common::record::{EncryptedData, Record}; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn post(UserAuth(user): UserAuth) -> Result<(), ErrorResponseStatus<'static>> { + // anyone who has actually used the old record store (a very small number) will see this error + // upon trying to sync. + // 1. The status endpoint will say that the server has nothing + // 2. The client will try to upload local records + // 3. Sync will fail with this error + + // If the client has no local records, they will see the empty index and do nothing. For the + // vast majority of users, this is the case. + return Err( + ErrorResponse::reply("record store deprecated; please upgrade") + .with_status(StatusCode::BAD_REQUEST), + ); +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn index(UserAuth(user): UserAuth) -> axum::response::Response { + let ret = json!({ + "hosts": {} + }); + + ret.to_string().into_response() +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn next( + UserAuth(user): UserAuth, +) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> { + let records = Vec::new(); + + Ok(Json(records)) +} diff --git a/crates/turtle/src/atuin_server/handlers/status.rs b/crates/turtle/src/atuin_server/handlers/status.rs new file mode 100644 index 00000000..0cf2ca1e --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/status.rs @@ -0,0 +1,45 @@ +use axum::{Json, extract::State, http::StatusCode}; +use tracing::instrument; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::atuin_server::router::{AppState, UserAuth}; +use crate::atuin_server_database::Database; + +use crate::atuin_common::api::*; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn status<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<StatusResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let deleted = db.deleted_history(&user).await.unwrap_or(vec![]); + + let count = match db.count_history_cached(&user).await { + // By default read out the cached value + Ok(count) => count, + + // If that fails, fallback on a full COUNT. Cache is built on a POST + // only + Err(_) => match db.count_history(&user).await { + Ok(count) => count, + Err(_) => { + return Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }, + }; + + tracing::debug!(user = user.username, "requested sync status"); + + Ok(Json(StatusResponse { + count, + deleted, + username: user.username, + version: VERSION.to_string(), + page_size: state.settings.page_size, + })) +} diff --git a/crates/turtle/src/atuin_server/handlers/user.rs b/crates/turtle/src/atuin_server/handlers/user.rs new file mode 100644 index 00000000..01b72202 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/user.rs @@ -0,0 +1,269 @@ +use std::borrow::Borrow; +use std::collections::HashMap; +use std::time::Duration; + +use argon2::{ + Algorithm, Argon2, Params, PasswordHash, PasswordHasher, PasswordVerifier, Version, + password_hash::SaltString, +}; +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, +}; +use metrics::counter; + +use rand::rngs::OsRng; +use tracing::{debug, error, info, instrument}; + +use crate::atuin_common::tls::ensure_crypto_provider; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::atuin_server::router::{AppState, UserAuth}; +use crate::atuin_server_database::{ + Database, DbError, + models::{NewSession, NewUser}, +}; + +use reqwest::header::CONTENT_TYPE; + +use crate::atuin_common::{api::*, utils::crypto_random_string}; + +pub fn verify_str(hash: &str, password: &str) -> bool { + let arg2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::default()); + let Ok(hash) = PasswordHash::new(hash) else { + return false; + }; + arg2.verify_password(password.as_bytes(), &hash).is_ok() +} + +// Try to send a Discord webhook once - if it fails, we don't retry. "At most once", and best effort. +// Don't return the status because if this fails, we don't really care. +async fn send_register_hook(url: &str, username: String, registered: String) { + ensure_crypto_provider(); + let hook = HashMap::from([ + ("username", username), + ("content", format!("{registered} has just signed up!")), + ]); + + let client = reqwest::Client::new(); + + let resp = client + .post(url) + .timeout(Duration::new(5, 0)) + .header(CONTENT_TYPE, "application/json") + .json(&hook) + .send() + .await; + + match resp { + Ok(_) => info!("register webhook sent ok!"), + Err(e) => error!("failed to send register webhook: {}", e), + } +} + +#[instrument(skip_all, fields(user.username = username.as_str()))] +pub async fn get<DB: Database>( + Path(username): Path<String>, + state: State<AppState<DB>>, +) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + let user = match db.get_user(username.as_ref()).await { + Ok(user) => user, + Err(DbError::NotFound) => { + debug!("user not found: {}", username); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(err)) => { + error!("database error: {}", err); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + Ok(Json(UserResponse { + username: user.username, + })) +} + +#[instrument(skip_all)] +pub async fn register<DB: Database>( + state: State<AppState<DB>>, + Json(register): Json<RegisterRequest>, +) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> { + if !state.settings.open_registration { + return Err( + ErrorResponse::reply("this server is not open for registrations") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + for c in register.username.chars() { + match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' => {} + _ => { + return Err(ErrorResponse::reply( + "Only alphanumeric and hyphens (-) are allowed in usernames", + ) + .with_status(StatusCode::BAD_REQUEST)); + } + } + } + + let hashed = hash_secret(®ister.password); + + let new_user = NewUser { + email: register.email.clone(), + username: register.username.clone(), + password: hashed, + }; + + let db = &state.0.database; + let user_id = match db.add_user(&new_user).await { + Ok(id) => id, + Err(e) => { + error!("failed to add user: {}", e); + return Err( + ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST) + ); + } + }; + + // 24 bytes encoded as base64 + let token = crypto_random_string::<24>(); + + let new_session = NewSession { + user_id, + token: (&token).into(), + }; + + if let Some(url) = &state.settings.register_webhook_url { + // Could probs be run on another thread, but it's ok atm + send_register_hook( + url, + state.settings.register_webhook_username.clone(), + register.username, + ) + .await; + } + + counter!("atuin_users_registered").increment(1); + + match db.add_session(&new_session).await { + Ok(_) => Ok(Json(RegisterResponse { + session: token, + auth: Some("cli".into()), + })), + Err(e) => { + error!("failed to add session: {}", e); + Err(ErrorResponse::reply("failed to register user") + .with_status(StatusCode::BAD_REQUEST)) + } + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<DeleteUserResponse>, ErrorResponseStatus<'static>> { + debug!("request to delete user {}", user.id); + + let db = &state.0.database; + if let Err(e) = db.delete_user(&user).await { + error!("failed to delete user: {}", e); + + return Err(ErrorResponse::reply("failed to delete user") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + counter!("atuin_users_deleted").increment(1); + + Ok(Json(DeleteUserResponse {})) +} + +#[instrument(skip_all, fields(user.id = user.id, change_password))] +pub async fn change_password<DB: Database>( + UserAuth(mut user): UserAuth, + state: State<AppState<DB>>, + Json(change_password): Json<ChangePasswordRequest>, +) -> Result<Json<ChangePasswordResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let verified = verify_str( + user.password.as_str(), + change_password.current_password.borrow(), + ); + if !verified { + return Err( + ErrorResponse::reply("password is not correct").with_status(StatusCode::UNAUTHORIZED) + ); + } + + let hashed = hash_secret(&change_password.new_password); + user.password = hashed; + + if let Err(e) = db.update_user_password(&user).await { + error!("failed to change user password: {}", e); + + return Err(ErrorResponse::reply("failed to change user password") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + Ok(Json(ChangePasswordResponse {})) +} + +#[instrument(skip_all, fields(user.username = login.username.as_str()))] +pub async fn login<DB: Database>( + state: State<AppState<DB>>, + login: Json<LoginRequest>, +) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + let user = match db.get_user(login.username.borrow()).await { + Ok(u) => u, + Err(DbError::NotFound) => { + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(e)) => { + error!("failed to get user {}: {}", login.username.clone(), e); + + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + let session = match db.get_user_session(&user).await { + Ok(u) => u, + Err(DbError::NotFound) => { + debug!("user session not found for user id={}", user.id); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(err)) => { + error!("database error for user {}: {}", login.username, err); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + let verified = verify_str(user.password.as_str(), login.password.borrow()); + + if !verified { + debug!(user = user.username, "login failed"); + return Err( + ErrorResponse::reply("password is not correct").with_status(StatusCode::UNAUTHORIZED) + ); + } + + debug!(user = user.username, "login success"); + + Ok(Json(LoginResponse { + session: session.token, + auth: Some("cli".into()), + })) +} + +fn hash_secret(password: &str) -> String { + let arg2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::default()); + let salt = SaltString::generate(&mut OsRng); + let hash = arg2.hash_password(password.as_bytes(), &salt).unwrap(); + hash.to_string() +} diff --git a/crates/turtle/src/atuin_server/handlers/v0/me.rs b/crates/turtle/src/atuin_server/handlers/v0/me.rs new file mode 100644 index 00000000..a1e2db46 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/v0/me.rs @@ -0,0 +1,16 @@ +use axum::Json; +use tracing::instrument; + +use crate::atuin_server::handlers::ErrorResponseStatus; +use crate::atuin_server::router::UserAuth; + +use crate::atuin_common::api::*; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn get( + UserAuth(user): UserAuth, +) -> Result<Json<MeResponse>, ErrorResponseStatus<'static>> { + Ok(Json(MeResponse { + username: user.username, + })) +} diff --git a/crates/turtle/src/atuin_server/handlers/v0/mod.rs b/crates/turtle/src/atuin_server/handlers/v0/mod.rs new file mode 100644 index 00000000..d6f880f2 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/v0/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod me; +pub(crate) mod record; +pub(crate) mod store; diff --git a/crates/turtle/src/atuin_server/handlers/v0/record.rs b/crates/turtle/src/atuin_server/handlers/v0/record.rs new file mode 100644 index 00000000..9b147a52 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/v0/record.rs @@ -0,0 +1,114 @@ +use axum::{Json, extract::Query, extract::State, http::StatusCode}; +use metrics::counter; +use serde::Deserialize; +use tracing::{error, instrument}; + +use crate::atuin_server::{ + handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, + router::{AppState, UserAuth}, +}; +use crate::atuin_server_database::Database; + +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn post<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, + Json(records): Json<Vec<Record<EncryptedData>>>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + + tracing::debug!( + count = records.len(), + user = user.username, + "request to add records" + ); + + counter!("atuin_record_uploaded").increment(records.len() as u64); + + let keep = records + .iter() + .all(|r| r.data.data.len() <= settings.max_record_size || settings.max_record_size == 0); + + if !keep { + counter!("atuin_record_too_large").increment(1); + + return Err( + ErrorResponse::reply("could not add records; record too large") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + if let Err(e) = database.add_records(&user, &records).await { + error!("failed to add record: {}", e); + + return Err(ErrorResponse::reply("failed to add record") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + Ok(()) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn index<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<RecordStatus>, ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + + let record_index = match database.status(&user).await { + Ok(index) => index, + Err(e) => { + error!("failed to get record index: {}", e); + + return Err(ErrorResponse::reply("failed to calculate record index") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + tracing::debug!(user = user.username, "record index request"); + + Ok(Json(record_index)) +} + +#[derive(Deserialize)] +pub struct NextParams { + host: HostId, + tag: String, + start: Option<RecordIdx>, + count: u64, +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn next<DB: Database>( + params: Query<NextParams>, + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + let params = params.0; + + let records = match database + .next_records(&user, params.host, params.tag, params.start, params.count) + .await + { + Ok(records) => records, + Err(e) => { + error!("failed to get record index: {}", e); + + return Err(ErrorResponse::reply("failed to calculate record index") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + counter!("atuin_record_downloaded").increment(records.len() as u64); + + Ok(Json(records)) +} diff --git a/crates/turtle/src/atuin_server/handlers/v0/store.rs b/crates/turtle/src/atuin_server/handlers/v0/store.rs new file mode 100644 index 00000000..cd184546 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/v0/store.rs @@ -0,0 +1,37 @@ +use axum::{extract::Query, extract::State, http::StatusCode}; +use metrics::counter; +use serde::Deserialize; +use tracing::{error, instrument}; + +use crate::atuin_server::{ + handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, + router::{AppState, UserAuth}, +}; +use crate::atuin_server_database::Database; + +#[derive(Deserialize)] +pub struct DeleteParams {} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete<DB: Database>( + _params: Query<DeleteParams>, + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + + if let Err(e) = database.delete_store(&user).await { + counter!("atuin_store_delete_failed").increment(1); + error!("failed to delete store {e:?}"); + + return Err(ErrorResponse::reply("failed to delete store") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + counter!("atuin_store_deleted").increment(1); + + Ok(()) +} diff --git a/crates/turtle/src/atuin_server/metrics.rs b/crates/turtle/src/atuin_server/metrics.rs new file mode 100644 index 00000000..ebd0dd2d --- /dev/null +++ b/crates/turtle/src/atuin_server/metrics.rs @@ -0,0 +1,55 @@ +use std::time::Instant; + +use axum::{ + extract::{MatchedPath, Request}, + middleware::Next, + response::IntoResponse, +}; +use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; + +pub fn setup_metrics_recorder() -> PrometheusHandle { + const EXPONENTIAL_SECONDS: &[f64] = &[ + 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, + ]; + + PrometheusBuilder::new() + .set_buckets_for_metric( + Matcher::Full("http_requests_duration_seconds".to_string()), + EXPONENTIAL_SECONDS, + ) + .unwrap() + .install_recorder() + .unwrap() +} + +/// Middleware to record some common HTTP metrics +/// Generic over B to allow for arbitrary body types (eg Vec<u8>, Streams, a deserialized thing, etc) +/// Someday tower-http might provide a metrics middleware: https://github.com/tower-rs/tower-http/issues/57 +pub async fn track_metrics(req: Request, next: Next) -> impl IntoResponse { + let start = Instant::now(); + + let path = match req.extensions().get::<MatchedPath>() { + Some(matched_path) => matched_path.as_str().to_owned(), + _ => req.uri().path().to_owned(), + }; + + let method = req.method().clone(); + + // Run the rest of the request handling first, so we can measure it and get response + // codes. + let response = next.run(req).await; + + let latency = start.elapsed().as_secs_f64(); + let status = response.status().as_u16().to_string(); + + let labels = [ + ("method", method.to_string()), + ("path", path), + ("status", status), + ]; + + metrics::counter!("http_requests_total", &labels).increment(1); + metrics::histogram!("http_requests_duration_seconds", &labels).record(latency); + + response +} diff --git a/crates/turtle/src/atuin_server/mod.rs b/crates/turtle/src/atuin_server/mod.rs new file mode 100644 index 00000000..bd0f2168 --- /dev/null +++ b/crates/turtle/src/atuin_server/mod.rs @@ -0,0 +1,86 @@ +use std::future::Future; +use std::net::SocketAddr; + +use crate::atuin_server_database::Database; +use axum::{Router, serve}; +use eyre::{Context, Result}; + +mod handlers; +mod metrics; +mod router; +mod utils; + +pub use settings::Settings; + +pub mod settings; + +use tokio::net::TcpListener; +use tokio::signal; + +#[cfg(target_family = "unix")] +async fn shutdown_signal() { + let mut term = signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to register signal handler"); + let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt()) + .expect("failed to register signal handler"); + + tokio::select! { + _ = term.recv() => {}, + _ = interrupt.recv() => {}, + }; + eprintln!("Shutting down gracefully..."); +} + +pub async fn launch<Db: Database>(settings: Settings, addr: SocketAddr) -> Result<()> { + launch_with_tcp_listener::<Db>( + settings, + TcpListener::bind(addr) + .await + .context("could not connect to socket")?, + shutdown_signal(), + ) + .await +} + +pub async fn launch_with_tcp_listener<Db: Database>( + settings: Settings, + listener: TcpListener, + shutdown: impl Future<Output = ()> + Send + 'static, +) -> Result<()> { + let r = make_router::<Db>(settings).await?; + + serve(listener, r.into_make_service()) + .with_graceful_shutdown(shutdown) + .await?; + + Ok(()) +} + +// The separate listener means it's much easier to ensure metrics are not accidentally exposed to +// the public. +pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> { + let listener = TcpListener::bind((host, port)) + .await + .context("failed to bind metrics tcp")?; + + let recorder_handle = metrics::setup_metrics_recorder(); + + let router = Router::new().route( + "/metrics", + axum::routing::get(move || std::future::ready(recorder_handle.render())), + ); + + serve(listener, router.into_make_service()) + .with_graceful_shutdown(shutdown_signal()) + .await?; + + Ok(()) +} + +async fn make_router<Db: Database>(settings: Settings) -> Result<Router, eyre::Error> { + let db = Db::new(&settings.db_settings) + .await + .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; + let r = router::router(db, settings); + Ok(r) +} diff --git a/crates/turtle/src/atuin_server/router.rs b/crates/turtle/src/atuin_server/router.rs new file mode 100644 index 00000000..11a16148 --- /dev/null +++ b/crates/turtle/src/atuin_server/router.rs @@ -0,0 +1,155 @@ +use crate::atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse}; +use axum::{ + Router, + extract::{FromRequestParts, Request}, + http::{self, request::Parts}, + middleware::Next, + response::{IntoResponse, Response}, + routing::{delete, get, patch, post}, +}; +use eyre::Result; +use tower::ServiceBuilder; +use tower_http::trace::TraceLayer; + +use super::handlers; +use crate::atuin_server::{ + handlers::{ErrorResponseStatus, RespExt}, + metrics, + settings::Settings, +}; +use crate::atuin_server_database::{Database, DbError, models::User}; + +pub struct UserAuth(pub User); + +impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for UserAuth +where + DB: Database, +{ + type Rejection = ErrorResponseStatus<'static>; + + async fn from_request_parts( + req: &mut Parts, + state: &AppState<DB>, + ) -> Result<Self, Self::Rejection> { + let auth_header = req + .headers + .get(http::header::AUTHORIZATION) + .ok_or_else(|| { + ErrorResponse::reply("missing authorization header") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + let auth_header = auth_header.to_str().map_err(|_| { + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + let (typ, token) = auth_header.split_once(' ').ok_or_else(|| { + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + + if typ != "Token" { + return Err( + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST), + ); + } + + let user = state + .database + .get_session_user(token) + .await + .map_err(|e| match e { + DbError::NotFound => ErrorResponse::reply("session not found") + .with_status(http::StatusCode::FORBIDDEN), + DbError::Other(e) => { + tracing::error!(error = ?e, "could not query user session"); + ErrorResponse::reply("could not query user session") + .with_status(http::StatusCode::INTERNAL_SERVER_ERROR) + } + })?; + + Ok(UserAuth(user)) + } +} + +async fn teapot() -> impl IntoResponse { + // This used to return 418: 🫖 + // Much as it was fun, it wasn't as useful or informative as it should be + (http::StatusCode::NOT_FOUND, "404 not found") +} + +async fn clacks_overhead(request: Request, next: Next) -> Response { + let mut response = next.run(request).await; + + let gnu_terry_value = "GNU Terry Pratchett, Kris Nova"; + let gnu_terry_header = "X-Clacks-Overhead"; + + response + .headers_mut() + .insert(gnu_terry_header, gnu_terry_value.parse().unwrap()); + response +} + +/// Ensure that we only try and sync with clients on the same major version +async fn semver(request: Request, next: Next) -> Response { + let mut response = next.run(request).await; + response + .headers_mut() + .insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse().unwrap()); + + response +} + +#[derive(Clone)] +pub struct AppState<DB: Database> { + pub database: DB, + pub settings: Settings, +} + +pub fn router<DB: Database>(database: DB, settings: Settings) -> Router { + let mut routes = Router::new() + .route("/", get(handlers::index)) + .route("/healthz", get(handlers::health::health_check)); + + // Sync v1 routes - can be disabled in favor of record-based sync + if settings.sync_v1_enabled { + routes = routes + .route("/sync/count", get(handlers::history::count)) + .route("/sync/history", get(handlers::history::list)) + .route("/sync/calendar/{focus}", get(handlers::history::calendar)) + .route("/sync/status", get(handlers::status::status)) + .route("/history", post(handlers::history::add)) + .route("/history", delete(handlers::history::delete)); + } + + let routes = routes + .route("/user/{username}", get(handlers::user::get)) + .route("/account", delete(handlers::user::delete)) + .route("/account/password", patch(handlers::user::change_password)) + .route("/register", post(handlers::user::register)) + .route("/login", post(handlers::user::login)) + .route("/record", post(handlers::record::post)) + .route("/record", get(handlers::record::index)) + .route("/record/next", get(handlers::record::next)) + .route("/api/v0/me", get(handlers::v0::me::get)) + .route("/api/v0/record", post(handlers::v0::record::post)) + .route("/api/v0/record", get(handlers::v0::record::index)) + .route("/api/v0/record/next", get(handlers::v0::record::next)) + .route("/api/v0/store", delete(handlers::v0::store::delete)); + + let path = settings.path.as_str(); + if path.is_empty() { + routes + } else { + Router::new().nest(path, routes) + } + .fallback(teapot) + .with_state(AppState { database, settings }) + .layer( + ServiceBuilder::new() + .layer(axum::middleware::from_fn(clacks_overhead)) + .layer(TraceLayer::new_for_http()) + .layer(axum::middleware::from_fn(metrics::track_metrics)) + .layer(axum::middleware::from_fn(semver)), + ) +} diff --git a/crates/turtle/src/atuin_server/settings.rs b/crates/turtle/src/atuin_server/settings.rs new file mode 100644 index 00000000..f6650af0 --- /dev/null +++ b/crates/turtle/src/atuin_server/settings.rs @@ -0,0 +1,110 @@ +use std::{io::prelude::*, path::PathBuf}; + +use crate::atuin_server_database::DbSettings; +use config::{Config, Environment, File as ConfigFile, FileFormat}; +use eyre::{Result, eyre}; +use fs_err::{File, create_dir_all}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Metrics { + #[serde(alias = "enabled")] + pub enable: bool, + pub host: String, + pub port: u16, +} + +impl Default for Metrics { + fn default() -> Self { + Self { + enable: false, + host: String::from("127.0.0.1"), + port: 9001, + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Settings { + pub host: String, + pub port: u16, + pub path: String, + pub open_registration: bool, + pub max_history_length: usize, + pub max_record_size: usize, + pub page_size: i64, + pub register_webhook_url: Option<String>, + pub register_webhook_username: String, + pub metrics: Metrics, + + /// Enable legacy sync v1 routes (history-based sync) + /// Set to false to use only the newer record-based sync + pub sync_v1_enabled: bool, + + /// Advertise a version that is not what we are _actually_ running + /// Many clients compare their version with api.atuin.sh, and if they differ, notify the user + /// that an update is available. + /// Now that we take beta releases, we should be able to advertise a different version to avoid + /// notifying users when the server runs something that is not a stable release. + pub fake_version: Option<String>, + + #[serde(flatten)] + pub db_settings: DbSettings, +} + +impl Settings { + pub fn new() -> Result<Self> { + let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut config_file = PathBuf::new(); + let config_dir = crate::atuin_common::utils::config_dir(); + config_file.push(config_dir); + config_file + }; + + config_file.push("server.toml"); + + // create the config file if it does not exist + let mut config_builder = Config::builder() + .set_default("host", "127.0.0.1")? + .set_default("port", 8888)? + .set_default("open_registration", false)? + .set_default("max_history_length", 8192)? + .set_default("max_record_size", 1024 * 1024 * 1024)? // pretty chonky + .set_default("path", "")? + .set_default("register_webhook_username", "")? + .set_default("page_size", 1100)? + .set_default("metrics.enable", false)? + .set_default("metrics.host", "127.0.0.1")? + .set_default("metrics.port", 9001)? + .set_default("sync_v1_enabled", true)? + .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + ); + + let config = if config_file.exists() { + config_builder + .add_source(ConfigFile::new( + config_file.to_str().unwrap(), + FileFormat::Toml, + )) + .build()? + } else { + create_dir_all(config_file.parent().unwrap())?; + let mut file = File::create(config_file)?; + + let config = config_builder.build()?; + // TODO(@bpeetz): I'm quiet unsure, if this will work <2026-06-10> + file.write_all(config.cache.to_string().as_bytes())?; + + config + }; + + config + .try_deserialize() + .map_err(|e| eyre!("failed to deserialize: {}", e)) + } +} diff --git a/crates/turtle/src/atuin_server/utils.rs b/crates/turtle/src/atuin_server/utils.rs new file mode 100644 index 00000000..12e9ac1b --- /dev/null +++ b/crates/turtle/src/atuin_server/utils.rs @@ -0,0 +1,15 @@ +use eyre::Result; +use semver::{Version, VersionReq}; + +pub fn client_version_min(user_agent: &str, req: &str) -> Result<bool> { + if user_agent.is_empty() { + return Ok(false); + } + + let version = user_agent.replace("atuin/", ""); + + let req = VersionReq::parse(req)?; + let version = Version::parse(version.as_str())?; + + Ok(req.matches(&version)) +} diff --git a/crates/turtle/src/atuin_server_database/calendar.rs b/crates/turtle/src/atuin_server_database/calendar.rs new file mode 100644 index 00000000..2229667b --- /dev/null +++ b/crates/turtle/src/atuin_server_database/calendar.rs @@ -0,0 +1,18 @@ +// Calendar data + +use serde::{Deserialize, Serialize}; +use time::Month; + +pub enum TimePeriod { + Year, + Month { year: i32 }, + Day { year: i32, month: Month }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TimePeriodInfo { + pub count: u64, + + // TODO: Use this for merkle tree magic + pub hash: String, +} diff --git a/crates/turtle/src/atuin_server_database/mod.rs b/crates/turtle/src/atuin_server_database/mod.rs new file mode 100644 index 00000000..91077b84 --- /dev/null +++ b/crates/turtle/src/atuin_server_database/mod.rs @@ -0,0 +1,266 @@ +pub mod calendar; +pub mod models; + +use std::{ + collections::HashMap, + fmt::{Debug, Display}, + ops::Range, +}; + +use self::{ + calendar::{TimePeriod, TimePeriodInfo}, + models::{History, NewHistory, NewSession, NewUser, Session, User}, +}; +use async_trait::async_trait; +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; +use serde::{Deserialize, Serialize}; +use time::{Date, Duration, Month, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; +use tracing::instrument; + +#[derive(Debug)] +pub enum DbError { + NotFound, + Other(eyre::Report), +} + +impl Display for DbError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl From<time::error::ComponentRange> for DbError { + fn from(error: time::error::ComponentRange) -> Self { + DbError::Other(error.into()) + } +} + +impl From<time::error::Error> for DbError { + fn from(error: time::error::Error) -> Self { + DbError::Other(error.into()) + } +} + +impl From<sqlx::Error> for DbError { + fn from(error: sqlx::Error) -> Self { + match error { + sqlx::Error::RowNotFound => DbError::NotFound, + error => DbError::Other(error.into()), + } + } +} + +impl std::error::Error for DbError {} + +pub type DbResult<T> = Result<T, DbError>; + +#[derive(Debug, PartialEq)] +pub enum DbType { + Postgres, + Sqlite, + Unknown, +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct DbSettings { + pub db_uri: String, + /// Optional URI for read replicas. If set, read-only queries will use this connection. + pub read_db_uri: Option<String>, +} + +impl DbSettings { + pub fn db_type(&self) -> DbType { + if self.db_uri.starts_with("postgres://") || self.db_uri.starts_with("postgresql://") { + DbType::Postgres + } else if self.db_uri.starts_with("sqlite://") { + DbType::Sqlite + } else { + DbType::Unknown + } + } +} + +fn redact_db_uri(uri: &str) -> String { + url::Url::parse(uri) + .map(|mut url| { + let _ = url.set_password(Some("****")); + url.to_string() + }) + .unwrap_or_else(|_| uri.to_string()) +} + +// Do our best to redact passwords so they're not logged in the event of an error. +impl Debug for DbSettings { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.db_type() == DbType::Postgres { + let redacted_uri = redact_db_uri(&self.db_uri); + let redacted_read_uri = self.read_db_uri.as_ref().map(|uri| redact_db_uri(uri)); + f.debug_struct("DbSettings") + .field("db_uri", &redacted_uri) + .field("read_db_uri", &redacted_read_uri) + .finish() + } else { + f.debug_struct("DbSettings") + .field("db_uri", &self.db_uri) + .field("read_db_uri", &self.read_db_uri) + .finish() + } + } +} + +#[async_trait] +pub trait Database: Sized + Clone + Send + Sync + 'static { + async fn new(settings: &DbSettings) -> DbResult<Self>; + + async fn get_session(&self, token: &str) -> DbResult<Session>; + async fn get_session_user(&self, token: &str) -> DbResult<User>; + async fn add_session(&self, session: &NewSession) -> DbResult<()>; + + async fn get_user(&self, username: &str) -> DbResult<User>; + async fn get_user_session(&self, u: &User) -> DbResult<Session>; + async fn add_user(&self, user: &NewUser) -> DbResult<i64>; + + async fn update_user_password(&self, u: &User) -> DbResult<()>; + + async fn count_history(&self, user: &User) -> DbResult<i64>; + async fn count_history_cached(&self, user: &User) -> DbResult<i64>; + + async fn delete_user(&self, u: &User) -> DbResult<()>; + async fn delete_history(&self, user: &User, id: String) -> DbResult<()>; + async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>; + async fn delete_store(&self, user: &User) -> DbResult<()>; + + async fn add_records(&self, user: &User, record: &[Record<EncryptedData>]) -> DbResult<()>; + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option<RecordIdx>, + count: u64, + ) -> DbResult<Vec<Record<EncryptedData>>>; + + // Return the tail record ID for each store, so (HostID, Tag, TailRecordID) + async fn status(&self, user: &User) -> DbResult<RecordStatus>; + + async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>) + -> DbResult<i64>; + + async fn list_history( + &self, + user: &User, + created_after: OffsetDateTime, + since: OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult<Vec<History>>; + + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()>; + + async fn oldest_history(&self, user: &User) -> DbResult<History>; + + #[instrument(skip_all)] + async fn calendar( + &self, + user: &User, + period: TimePeriod, + tz: UtcOffset, + ) -> DbResult<HashMap<u64, TimePeriodInfo>> { + let mut ret = HashMap::new(); + let iter: Box<dyn Iterator<Item = DbResult<(u64, Range<Date>)>> + Send> = match period { + TimePeriod::Year => { + // First we need to work out how far back to calculate. Get the + // oldest history item + let oldest = self + .oldest_history(user) + .await? + .timestamp + .to_offset(tz) + .year(); + let current_year = OffsetDateTime::now_utc().to_offset(tz).year(); + + // All the years we need to get data for + // The upper bound is exclusive, so include current +1 + let years = oldest..current_year + 1; + + Box::new(years.map(|year| { + let start = Date::from_calendar_date(year, time::Month::January, 1)?; + let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?; + + Ok((year as u64, start..end)) + })) + } + + TimePeriod::Month { year } => { + let months = + std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12); + + Box::new(months.map(move |month| { + let start = Date::from_calendar_date(year, month, 1)?; + let days = start.month().length(year); + let end = start + Duration::days(days as i64); + + Ok((month as u64, start..end)) + })) + } + + TimePeriod::Day { year, month } => { + let days = 1..month.length(year); + Box::new(days.map(move |day| { + let start = Date::from_calendar_date(year, month, day)?; + let end = start + .next_day() + .ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?; + + Ok((day as u64, start..end)) + })) + } + }; + + for x in iter { + let (index, range) = x?; + + let start = range.start.with_time(Time::MIDNIGHT).assume_offset(tz); + let end = range.end.with_time(Time::MIDNIGHT).assume_offset(tz); + + let count = self.count_history_range(user, start..end).await?; + + ret.insert( + index, + TimePeriodInfo { + count: count as u64, + hash: "".to_string(), + }, + ); + } + + Ok(ret) + } +} + +pub fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime { + let x = x.to_offset(UtcOffset::UTC); + PrimitiveDateTime::new(x.date(), x.time()) +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use crate::into_utc; + + #[test] + fn utc() { + let dt = datetime!(2023-09-26 15:11:02 +05:30); + assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + + let dt = datetime!(2023-09-26 15:11:02 -07:00); + assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + + let dt = datetime!(2023-09-26 15:11:02 +00:00); + assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + } +} diff --git a/crates/turtle/src/atuin_server_database/models.rs b/crates/turtle/src/atuin_server_database/models.rs new file mode 100644 index 00000000..b71a9bc9 --- /dev/null +++ b/crates/turtle/src/atuin_server_database/models.rs @@ -0,0 +1,52 @@ +use time::OffsetDateTime; + +pub struct History { + pub id: i64, + pub client_id: String, // a client generated ID + pub user_id: i64, + pub hostname: String, + pub timestamp: OffsetDateTime, + + /// All the data we have about this command, encrypted. + /// + /// Currently this is an encrypted msgpack object, but this may change in the future. + pub data: String, + + pub created_at: OffsetDateTime, +} + +pub struct NewHistory { + pub client_id: String, + pub user_id: i64, + pub hostname: String, + pub timestamp: OffsetDateTime, + + /// All the data we have about this command, encrypted. + /// + /// Currently this is an encrypted msgpack object, but this may change in the future. + pub data: String, +} + +pub struct User { + pub id: i64, + pub username: String, + pub email: String, + pub password: String, +} + +pub struct Session { + pub id: i64, + pub user_id: i64, + pub token: String, +} + +pub struct NewUser { + pub username: String, + pub email: String, + pub password: String, +} + +pub struct NewSession { + pub user_id: i64, + pub token: String, +} diff --git a/crates/turtle/src/atuin_server_postgres/mod.rs b/crates/turtle/src/atuin_server_postgres/mod.rs new file mode 100644 index 00000000..f506cf25 --- /dev/null +++ b/crates/turtle/src/atuin_server_postgres/mod.rs @@ -0,0 +1,583 @@ +use std::collections::HashMap; +use std::ops::Range; + +use rand::Rng; + +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; +use crate::atuin_server_database::models::{ + History, NewHistory, NewSession, NewUser, Session, User, +}; +use crate::atuin_server_database::{Database, DbError, DbResult, DbSettings, into_utc}; +use async_trait::async_trait; +use futures_util::TryStreamExt; +use sqlx::Row; +use sqlx::postgres::PgPoolOptions; + +use time::OffsetDateTime; +use tracing::instrument; +use uuid::Uuid; +use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; + +mod wrappers; + +const MIN_PG_VERSION: u32 = 14; + +#[derive(Clone)] +pub struct Postgres { + pool: sqlx::Pool<sqlx::postgres::Postgres>, + /// Optional read replica pool for read-only queries + read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>, +} + +impl Postgres { + /// Returns the appropriate pool for read operations. + /// Uses read_pool if available, otherwise falls back to the primary pool. + fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> { + self.read_pool.as_ref().unwrap_or(&self.pool) + } +} + +#[async_trait] +impl Database for Postgres { + async fn new(settings: &DbSettings) -> DbResult<Self> { + let pool = PgPoolOptions::new() + .max_connections(100) + .connect(settings.db_uri.as_str()) + .await?; + + // Call server_version_num to get the DB server's major version number + // The call returns None for servers older than 8.x. + let pg_major_version: u32 = + pool.acquire() + .await? + .server_version_num() + .ok_or(DbError::Other(eyre::Report::msg( + "could not get PostgreSQL version", + )))? + / 10000; + + if pg_major_version < MIN_PG_VERSION { + return Err(DbError::Other(eyre::Report::msg(format!( + "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}" + )))); + } + + sqlx::migrate!("./migrations") + .run(&pool) + .await + .map_err(|error| DbError::Other(error.into()))?; + + // Create read replica pool if configured + let read_pool = if let Some(read_db_uri) = &settings.read_db_uri { + tracing::info!("Connecting to read replica database"); + let read_pool = PgPoolOptions::new() + .max_connections(100) + .connect(read_db_uri.as_str()) + .await?; + + // Verify the read replica is also a supported PostgreSQL version + let read_pg_major_version: u32 = read_pool + .acquire() + .await? + .server_version_num() + .ok_or(DbError::Other(eyre::Report::msg( + "could not get PostgreSQL version from read replica", + )))? + / 10000; + + if read_pg_major_version < MIN_PG_VERSION { + return Err(DbError::Other(eyre::Report::msg(format!( + "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}" + )))); + } + + Some(read_pool) + } else { + None + }; + + Ok(Self { pool, read_pool }) + } + + #[instrument(skip_all)] + async fn get_session(&self, token: &str) -> DbResult<Session> { + sqlx::query_as("select id, user_id, token from sessions where token = $1") + .bind(token) + .fetch_one(self.read_pool()) + .await + .map_err(Into::into) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn get_user(&self, username: &str) -> DbResult<User> { + sqlx::query_as("select id, username, email, password from users where username = $1") + .bind(username) + .fetch_one(self.read_pool()) + .await + .map_err(Into::into) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn get_session_user(&self, token: &str) -> DbResult<User> { + sqlx::query_as( + "select users.id, users.username, users.email, users.password from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_one(self.read_pool()) + .await + .map_err(Into::into) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn count_history(&self, user: &User) -> DbResult<i64> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(self.read_pool()) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history_cached(&self, user: &User) -> DbResult<i64> { + let res: (i32,) = sqlx::query_as( + "select total from total_history_count_user + where user_id = $1", + ) + .bind(user.id) + .fetch_one(self.read_pool()) + .await?; + + Ok(res.0 as i64) + } + + async fn delete_store(&self, user: &User) -> DbResult<()> { + let mut tx = self.pool.begin().await?; + + sqlx::query( + "delete from store + where user_id = $1", + ) + .bind(user.id) + .execute(&mut *tx) + .await?; + + sqlx::query( + "delete from store_idx_cache + where user_id = $1", + ) + .bind(user.id) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(()) + } + + async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { + sqlx::query( + "update history + set deleted_at = $3 + where user_id = $1 + and client_id = $2 + and deleted_at is null", // don't just keep setting it + ) + .bind(user.id) + .bind(id) + .bind(OffsetDateTime::now_utc()) + .fetch_all(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res = sqlx::query( + "select client_id from history + where user_id = $1 + and deleted_at is not null", + ) + .bind(user.id) + .fetch_all(self.read_pool()) + .await?; + + let res = res + .iter() + .map(|row| row.get::<String, _>("client_id")) + .collect(); + + Ok(res) + } + + #[instrument(skip_all)] + async fn count_history_range( + &self, + user: &User, + range: Range<OffsetDateTime>, + ) -> DbResult<i64> { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1 + and timestamp >= $2::date + and timestamp < $3::date", + ) + .bind(user.id) + .bind(into_utc(range.start)) + .bind(into_utc(range.end)) + .fetch_one(self.read_pool()) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn list_history( + &self, + user: &User, + created_after: OffsetDateTime, + since: OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult<Vec<History>> { + let res = sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(into_utc(created_after)) + .bind(into_utc(since)) + .bind(page_size) + .fetch(self.read_pool()) + .map_ok(|DbHistory(h)| h) + .try_collect() + .await?; + + Ok(res) + } + + #[instrument(skip_all)] + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { + let mut tx = self.pool.begin().await?; + + for i in history { + let client_id: &str = &i.client_id; + let hostname: &str = &i.hostname; + let data: &str = &i.data; + + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(client_id) + .bind(i.user_id) + .bind(hostname) + .bind(i.timestamp) + .bind(data) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn delete_user(&self, u: &User) -> DbResult<()> { + sqlx::query("delete from sessions where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from history where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from store where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from total_history_count_user where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from users where id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn update_user_password(&self, user: &User) -> DbResult<()> { + sqlx::query( + "update users + set password = $1 + where id = $2", + ) + .bind(&user.password) + .bind(user.id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn add_user(&self, user: &NewUser) -> DbResult<i64> { + let email: &str = &user.email; + let username: &str = &user.username; + let password: &str = &user.password; + + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(username) + .bind(email) + .bind(password) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn add_session(&self, session: &NewSession) -> DbResult<()> { + let token: &str = &session.token; + + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(token) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn get_user_session(&self, u: &User) -> DbResult<Session> { + sqlx::query_as("select id, user_id, token from sessions where user_id = $1") + .bind(u.id) + .fetch_one(self.read_pool()) + .await + .map_err(Into::into) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn oldest_history(&self, user: &User) -> DbResult<History> { + sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + order by timestamp asc + limit 1", + ) + .bind(user.id) + .fetch_one(self.read_pool()) + .await + .map_err(Into::into) + .map(|DbHistory(h)| h) + } + + #[instrument(skip_all)] + async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { + let mut tx = self.pool.begin().await?; + + // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max + // idx without having to make further database queries. Doing the query on this small + // amount of data should be much, much faster. + // + // Worst case, say we get this wrong. We end up caching data that isn't actually the max + // idx, so clients upload again. The cache logic can be verified with a sql query anyway :) + + let mut heads = HashMap::<(HostId, &str), u64>::new(); + + for i in records { + let id = crate::atuin_common::utils::uuid_v7(); + + let result = sqlx::query( + "insert into store + (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + on conflict do nothing + ", + ) + .bind(id) + .bind(i.id) + .bind(i.host.id) + .bind(i.idx as i64) + .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time + .bind(&i.version) + .bind(&i.tag) + .bind(&i.data.data) + .bind(&i.data.content_encryption_key) + .bind(user.id) + .execute(&mut *tx) + .await?; + + // Only update heads if we actually inserted the record + if result.rows_affected() > 0 { + heads + .entry((i.host.id, &i.tag)) + .and_modify(|e| { + if i.idx > *e { + *e = i.idx + } + }) + .or_insert(i.idx); + } + } + + // we've built the map of heads for this push, so commit it to the database + for ((host, tag), idx) in heads { + sqlx::query( + "insert into store_idx_cache + (user_id, host, tag, idx) + values ($1, $2, $3, $4) + on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4) + ", + ) + .bind(user.id) + .bind(host) + .bind(tag) + .bind(idx as i64) + .execute(&mut *tx) + .await + ?; + } + + tx.commit().await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option<RecordIdx>, + count: u64, + ) -> DbResult<Vec<Record<EncryptedData>>> { + tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); + let start = start.unwrap_or(0); + + let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as( + "select client_id, host, idx, timestamp, version, tag, data, cek from store + where user_id = $1 + and tag = $2 + and host = $3 + and idx >= $4 + order by idx asc + limit $5", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(start as i64) + .bind(count as i64) + .fetch_all(self.read_pool()) + .await + .map_err(Into::into); + + let ret = match records { + Ok(records) => { + let records: Vec<Record<EncryptedData>> = records + .into_iter() + .map(|f| { + let record: Record<EncryptedData> = f.into(); + record + }) + .collect(); + + records + } + Err(DbError::NotFound) => { + tracing::debug!("no records found in store: {:?}/{}", host, tag); + return Ok(vec![]); + } + Err(e) => return Err(e), + }; + + Ok(ret) + } + + async fn status(&self, user: &User) -> DbResult<RecordStatus> { + const STATUS_SQL: &str = + "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; + + // If IDX_CACHE_ROLLOUT is set, then we + // 1. Read the value of the var, use it as a % chance of using the cache + // 2. If we use the cache, just read from the cache table + // 3. If we don't use the cache, read from the store table + // IDX_CACHE_ROLLOUT should be between 0 and 100. + + let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string()); + let idx_cache_rollout = idx_cache_rollout.parse::<f64>().unwrap_or(0.0); + let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0); + + let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache { + tracing::debug!("using idx cache for user {}", user.id); + sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1") + .bind(user.id) + .fetch_all(self.read_pool()) + .await? + } else { + tracing::debug!("using aggregate query for user {}", user.id); + sqlx::query_as(STATUS_SQL) + .bind(user.id) + .fetch_all(self.read_pool()) + .await? + }; + + res.sort(); + + let mut status = RecordStatus::new(); + + for i in res.iter() { + status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64); + } + + Ok(status) + } +} diff --git a/crates/turtle/src/atuin_server_postgres/wrappers.rs b/crates/turtle/src/atuin_server_postgres/wrappers.rs new file mode 100644 index 00000000..214b255d --- /dev/null +++ b/crates/turtle/src/atuin_server_postgres/wrappers.rs @@ -0,0 +1,77 @@ +use ::sqlx::{FromRow, Result}; +use crate::atuin_common::record::{EncryptedData, Host, Record}; +use crate::atuin_server_database::models::{History, Session, User}; +use sqlx::{Row, postgres::PgRow}; +use time::PrimitiveDateTime; + +pub struct DbUser(pub User); +pub struct DbSession(pub Session); +pub struct DbHistory(pub History); +pub struct DbRecord(pub Record<EncryptedData>); + +impl<'a> FromRow<'a, PgRow> for DbUser { + fn from_row(row: &'a PgRow) -> Result<Self> { + Ok(Self(User { + id: row.try_get("id")?, + username: row.try_get("username")?, + email: row.try_get("email")?, + password: row.try_get("password")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession { + fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { + Ok(Self(Session { + id: row.try_get("id")?, + user_id: row.try_get("user_id")?, + token: row.try_get("token")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory { + fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { + Ok(Self(History { + id: row.try_get("id")?, + client_id: row.try_get("client_id")?, + user_id: row.try_get("user_id")?, + hostname: row.try_get("hostname")?, + timestamp: row + .try_get::<PrimitiveDateTime, _>("timestamp")? + .assume_utc(), + data: row.try_get("data")?, + created_at: row + .try_get::<PrimitiveDateTime, _>("created_at")? + .assume_utc(), + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord { + fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { + let timestamp: i64 = row.try_get("timestamp")?; + let idx: i64 = row.try_get("idx")?; + + let data = EncryptedData { + data: row.try_get("data")?, + content_encryption_key: row.try_get("cek")?, + }; + + Ok(Self(Record { + id: row.try_get("client_id")?, + host: Host::new(row.try_get("host")?), + idx: idx as u64, + timestamp: timestamp as u64, + version: row.try_get("version")?, + tag: row.try_get("tag")?, + data, + })) + } +} + +impl From<DbRecord> for Record<EncryptedData> { + fn from(other: DbRecord) -> Record<EncryptedData> { + Record { ..other.0 } + } +} diff --git a/crates/turtle/src/atuin_server_sqlite/mod.rs b/crates/turtle/src/atuin_server_sqlite/mod.rs new file mode 100644 index 00000000..3470a2f1 --- /dev/null +++ b/crates/turtle/src/atuin_server_sqlite/mod.rs @@ -0,0 +1,430 @@ +use std::str::FromStr; + +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; +use crate::atuin_server_database::{ + Database, DbError, DbResult, DbSettings, into_utc, + models::{History, NewHistory, NewSession, NewUser, Session, User}, +}; +use async_trait::async_trait; +use futures_util::TryStreamExt; +use sqlx::{ + Row, + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions}, + types::Uuid, +}; +use tracing::instrument; +use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; + +mod wrappers; + +#[derive(Clone)] +pub struct Sqlite { + pool: sqlx::Pool<sqlx::sqlite::Sqlite>, +} + +#[async_trait] +impl Database for Sqlite { + async fn new(settings: &DbSettings) -> DbResult<Self> { + let opts = SqliteConnectOptions::from_str(&settings.db_uri)? + .journal_mode(SqliteJournalMode::Wal) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new().connect_with(opts).await?; + + sqlx::migrate!("./migrations") + .run(&pool) + .await + .map_err(|error| DbError::Other(error.into()))?; + + Ok(Self { pool }) + } + + #[instrument(skip_all)] + async fn get_session(&self, token: &str) -> DbResult<Session> { + sqlx::query_as("select id, user_id, token from sessions where token = $1") + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(Into::into) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn get_session_user(&self, token: &str) -> DbResult<User> { + sqlx::query_as( + "select users.id, users.username, users.email, users.password from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(Into::into) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn add_session(&self, session: &NewSession) -> DbResult<()> { + let token: &str = &session.token; + + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(token) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn get_user(&self, username: &str) -> DbResult<User> { + sqlx::query_as("select id, username, email, password from users where username = $1") + .bind(username) + .fetch_one(&self.pool) + .await + .map_err(Into::into) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn get_user_session(&self, u: &User) -> DbResult<Session> { + sqlx::query_as("select id, user_id, token from sessions where user_id = $1") + .bind(u.id) + .fetch_one(&self.pool) + .await + .map_err(Into::into) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn add_user(&self, user: &NewUser) -> DbResult<i64> { + let email: &str = &user.email; + let username: &str = &user.username; + let password: &str = &user.password; + + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(username) + .bind(email) + .bind(password) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn update_user_password(&self, user: &User) -> DbResult<()> { + sqlx::query( + "update users + set password = $1 + where id = $2", + ) + .bind(&user.password) + .bind(user.id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn count_history(&self, user: &User) -> DbResult<i64> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history_cached(&self, _user: &User) -> DbResult<i64> { + Err(DbError::NotFound) + } + + #[instrument(skip_all)] + async fn delete_user(&self, u: &User) -> DbResult<()> { + sqlx::query("delete from sessions where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from users where id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from history where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { + sqlx::query( + "update history + set deleted_at = $3 + where user_id = $1 + and client_id = $2 + and deleted_at is null", // don't just keep setting it + ) + .bind(user.id) + .bind(id) + .bind(time::OffsetDateTime::now_utc()) + .fetch_all(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res = sqlx::query( + "select client_id from history + where user_id = $1 + and deleted_at is not null", + ) + .bind(user.id) + .fetch_all(&self.pool) + .await?; + + let res = res.iter().map(|row| row.get("client_id")).collect(); + + Ok(res) + } + + async fn delete_store(&self, user: &User) -> DbResult<()> { + sqlx::query( + "delete from store + where user_id = $1", + ) + .bind(user.id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { + let mut tx = self.pool.begin().await?; + + for i in records { + let id = crate::atuin_common::utils::uuid_v7(); + + sqlx::query( + "insert into store + (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + on conflict do nothing + ", + ) + .bind(id) + .bind(i.id) + .bind(i.host.id) + .bind(i.idx as i64) + .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time + .bind(&i.version) + .bind(&i.tag) + .bind(&i.data.data) + .bind(&i.data.content_encryption_key) + .bind(user.id) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option<RecordIdx>, + count: u64, + ) -> DbResult<Vec<Record<EncryptedData>>> { + tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); + let start = start.unwrap_or(0); + + let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as( + "select client_id, host, idx, timestamp, version, tag, data, cek from store + where user_id = $1 + and tag = $2 + and host = $3 + and idx >= $4 + order by idx asc + limit $5", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(start as i64) + .bind(count as i64) + .fetch_all(&self.pool) + .await + .map_err(Into::into); + + let ret = match records { + Ok(records) => { + let records: Vec<Record<EncryptedData>> = records + .into_iter() + .map(|f| { + let record: Record<EncryptedData> = f.into(); + record + }) + .collect(); + + records + } + Err(DbError::NotFound) => { + tracing::debug!("no records found in store: {:?}/{}", host, tag); + return Ok(vec![]); + } + Err(e) => return Err(e), + }; + + Ok(ret) + } + + async fn status(&self, user: &User) -> DbResult<RecordStatus> { + const STATUS_SQL: &str = + "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; + + let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) + .bind(user.id) + .fetch_all(&self.pool) + .await?; + + let mut status = RecordStatus::new(); + + for i in res { + status.set_raw(HostId(i.0), i.1, i.2 as u64); + } + + Ok(status) + } + + #[instrument(skip_all)] + async fn count_history_range( + &self, + user: &User, + range: std::ops::Range<time::OffsetDateTime>, + ) -> DbResult<i64> { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1 + and timestamp >= $2::date + and timestamp < $3::date", + ) + .bind(user.id) + .bind(into_utc(range.start)) + .bind(into_utc(range.end)) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn list_history( + &self, + user: &User, + created_after: time::OffsetDateTime, + since: time::OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult<Vec<History>> { + let res = sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(into_utc(created_after)) + .bind(into_utc(since)) + .bind(page_size) + .fetch(&self.pool) + .map_ok(|DbHistory(h)| h) + .try_collect() + .await?; + + Ok(res) + } + + #[instrument(skip_all)] + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { + let mut tx = self.pool.begin().await?; + + for i in history { + let client_id: &str = &i.client_id; + let hostname: &str = &i.hostname; + let data: &str = &i.data; + + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(client_id) + .bind(i.user_id) + .bind(hostname) + .bind(i.timestamp) + .bind(data) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn oldest_history(&self, user: &User) -> DbResult<History> { + sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + order by timestamp asc + limit 1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(Into::into) + .map(|DbHistory(h)| h) + } +} diff --git a/crates/turtle/src/atuin_server_sqlite/wrappers.rs b/crates/turtle/src/atuin_server_sqlite/wrappers.rs new file mode 100644 index 00000000..5aa7a982 --- /dev/null +++ b/crates/turtle/src/atuin_server_sqlite/wrappers.rs @@ -0,0 +1,72 @@ +use ::sqlx::{FromRow, Result}; +use crate::atuin_common::record::{EncryptedData, Host, Record}; +use crate::atuin_server_database::models::{History, Session, User}; +use sqlx::{Row, sqlite::SqliteRow}; + +pub struct DbUser(pub User); +pub struct DbSession(pub Session); +pub struct DbHistory(pub History); +pub struct DbRecord(pub Record<EncryptedData>); + +impl<'a> FromRow<'a, SqliteRow> for DbUser { + fn from_row(row: &'a SqliteRow) -> Result<Self> { + Ok(Self(User { + id: row.try_get("id")?, + username: row.try_get("username")?, + email: row.try_get("email")?, + password: row.try_get("password")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbSession { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { + Ok(Self(Session { + id: row.try_get("id")?, + user_id: row.try_get("user_id")?, + token: row.try_get("token")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbHistory { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { + Ok(Self(History { + id: row.try_get("id")?, + client_id: row.try_get("client_id")?, + user_id: row.try_get("user_id")?, + hostname: row.try_get("hostname")?, + timestamp: row.try_get("timestamp")?, + data: row.try_get("data")?, + created_at: row.try_get("created_at")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbRecord { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> { + let idx: i64 = row.try_get("idx")?; + let timestamp: i64 = row.try_get("timestamp")?; + + let data = EncryptedData { + data: row.try_get("data")?, + content_encryption_key: row.try_get("cek")?, + }; + + Ok(Self(Record { + id: row.try_get("client_id")?, + host: Host::new(row.try_get("host")?), + idx: idx as u64, + timestamp: timestamp as u64, + version: row.try_get("version")?, + tag: row.try_get("tag")?, + data, + })) + } +} + +impl From<DbRecord> for Record<EncryptedData> { + fn from(other: DbRecord) -> Record<EncryptedData> { + Record { ..other.0 } + } +} diff --git a/crates/turtle/src/command/CONTRIBUTORS b/crates/turtle/src/command/CONTRIBUTORS new file mode 120000 index 00000000..1ca4115a --- /dev/null +++ b/crates/turtle/src/command/CONTRIBUTORS @@ -0,0 +1 @@ +../../../../CONTRIBUTORS
\ No newline at end of file diff --git a/crates/turtle/src/command/client.rs b/crates/turtle/src/command/client.rs new file mode 100644 index 00000000..20d85303 --- /dev/null +++ b/crates/turtle/src/command/client.rs @@ -0,0 +1,371 @@ +use std::fs::{self, OpenOptions}; +use std::path::{Path, PathBuf}; + +use clap::Subcommand; +use eyre::{Result, WrapErr}; + +use crate::atuin_client::{ + database::Sqlite, record::sqlite_store::SqliteStore, settings::Settings, theme, +}; +use tracing_appender::rolling::{RollingFileAppender, Rotation}; +use tracing_subscriber::{ + Layer, filter::EnvFilter, filter::LevelFilter, fmt, fmt::format::FmtSpan, prelude::*, +}; + +fn cleanup_old_logs(log_dir: &Path, prefix: &str, retention_days: u64) { + let cutoff = std::time::SystemTime::now() + - std::time::Duration::from_secs(retention_days * 24 * 60 * 60); + + let Ok(entries) = fs::read_dir(log_dir) else { + return; + }; + + for entry in entries.flatten() { + let path = entry.path(); + let Some(name) = path.file_name().and_then(|n| n.to_str()) else { + continue; + }; + + // Match files like "search.log.2024-02-23" or "daemon.log.2024-02-23" + if !name.starts_with(prefix) || name == prefix { + continue; + } + + if let Ok(metadata) = entry.metadata() + && let Ok(modified) = metadata.modified() + && modified < cutoff + { + let _ = fs::remove_file(&path); + } + } +} + +#[cfg(feature = "sync")] +mod sync; + +#[cfg(feature = "sync")] +mod account; + +#[cfg(feature = "daemon")] +mod daemon; + +mod config; +mod default_config; +mod doctor; +mod history; +mod import; +mod info; +mod init; +mod search; +mod server; +mod setup; +mod stats; +mod store; +mod wrapped; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Setup Atuin features + #[command()] + Setup, + + /// Manipulate shell history + #[command(subcommand)] + History(history::Cmd), + + /// Import shell history from file + #[command(subcommand)] + Import(import::Cmd), + + /// Calculate statistics for your history + Stats(stats::Cmd), + + /// Interactive history search + Search(search::Cmd), + + #[cfg(feature = "sync")] + #[command(flatten)] + Sync(sync::Cmd), + + /// Manage the atuin server + #[command(subcommand)] + Server(server::Cmd), + + /// Manage your sync account + #[cfg(feature = "sync")] + Account(account::Cmd), + + /// Manage the atuin data store + #[command(subcommand)] + Store(store::Cmd), + + /// Print Atuin's shell init script + #[command()] + Init(init::Cmd), + + /// Information about dotfiles locations and ENV vars + #[command()] + Info, + + /// Run the doctor to check for common issues + #[command()] + Doctor, + + #[command()] + Wrapped { year: Option<i32> }, + + /// *Experimental* Manage the background daemon + #[cfg(feature = "daemon")] + #[command()] + Daemon(daemon::Cmd), + + /// Print the default atuin configuration (config.toml) + #[command()] + DefaultConfig, + + #[command(subcommand)] + Config(config::Cmd), +} + +impl Cmd { + pub fn run(self) -> Result<()> { + // Daemonize before creating the async runtime – fork() inside a live + // tokio runtime corrupts its internal state. + #[cfg(all(unix, feature = "daemon"))] + if let Self::Daemon(ref cmd) = self + && cmd.should_daemonize() + { + daemon::daemonize_current_process()?; + } + + let mut runtime = tokio::runtime::Builder::new_current_thread(); + + let runtime = runtime.enable_all().build().unwrap(); + + // For non-history commands, we want to initialize logging and the theme manager before + // doing anything else. History commands are performance-sensitive and run before and after + // every shell command, so we want to skip any unnecessary initialization for them. + let settings = Settings::new().wrap_err("could not load client settings")?; + let theme_manager = theme::ThemeManager::new(settings.theme.debug, None); + let res = runtime.block_on(self.run_inner(settings, theme_manager)); + + runtime.shutdown_timeout(std::time::Duration::from_millis(50)); + + res + } + + #[expect(clippy::too_many_lines, clippy::future_not_send)] + async fn run_inner( + self, + mut settings: Settings, + mut theme_manager: theme::ThemeManager, + ) -> Result<()> { + // ATUIN_LOG env var overrides config file level settings + let env_log_set = std::env::var("ATUIN_LOG").is_ok(); + + // Base filter from env var (or empty if not set) + let base_filter = + EnvFilter::from_env("ATUIN_LOG").add_directive("sqlx_sqlite::regexp=off".parse()?); + + let is_interactive_search = matches!(&self, Self::Search(cmd) if cmd.is_interactive()); + // Use file-based logging for interactive search (TUI mode) + let use_search_logging = is_interactive_search && settings.logs.search_enabled(); + + // Use file-based logging for daemon + #[cfg(feature = "daemon")] + let use_daemon_logging = matches!(&self, Self::Daemon(_)) && settings.logs.daemon_enabled(); + + #[cfg(not(feature = "daemon"))] + let use_daemon_logging = false; + + // Check if daemon should also log to console + #[cfg(feature = "daemon")] + let daemon_show_logs = matches!(&self, Self::Daemon(cmd) if cmd.show_logs()); + + #[cfg(not(feature = "daemon"))] + let daemon_show_logs = false; + + // Set up span timing JSON logs if ATUIN_SPAN is set + let span_path = std::env::var("ATUIN_SPAN").ok().map(|p| { + if p.is_empty() { + "atuin-spans.json".to_string() + } else { + p + } + }); + + // Helper to create span timing layer + macro_rules! make_span_layer { + ($path:expr) => {{ + let span_file = OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open($path)?; + Some( + fmt::layer() + .json() + .with_writer(span_file) + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) + .with_filter(LevelFilter::TRACE), + ) + }}; + } + + // Build the subscriber with all configured layers + if use_search_logging { + let search_filename = settings.logs.search.file.clone(); + let log_dir = PathBuf::from(&settings.logs.dir); + fs::create_dir_all(&log_dir)?; + + // Clean up old log files + cleanup_old_logs(&log_dir, &search_filename, settings.logs.search_retention()); + + let file_appender = + RollingFileAppender::new(Rotation::DAILY, &log_dir, &search_filename); + + // Use config level unless ATUIN_LOG is set + let filter = if env_log_set { + base_filter + } else { + EnvFilter::default() + .add_directive(settings.logs.search_level().as_directive().parse()?) + .add_directive("sqlx_sqlite::regexp=off".parse()?) + }; + + let base = tracing_subscriber::registry().with( + fmt::layer() + .with_writer(file_appender) + .with_ansi(false) + .with_filter(filter), + ); + + match &span_path { + Some(sp) => { + base.with(make_span_layer!(sp)).init(); + } + None => { + base.init(); + } + } + } else if use_daemon_logging { + let daemon_filename = settings.logs.daemon.file.clone(); + let log_dir = PathBuf::from(&settings.logs.dir); + fs::create_dir_all(&log_dir)?; + + // Clean up old log files + cleanup_old_logs(&log_dir, &daemon_filename, settings.logs.daemon_retention()); + + let file_appender = + RollingFileAppender::new(Rotation::DAILY, &log_dir, &daemon_filename); + + // Use config level unless ATUIN_LOG is set + let file_filter = if env_log_set { + base_filter + } else { + EnvFilter::default() + .add_directive(settings.logs.daemon_level().as_directive().parse()?) + .add_directive("sqlx_sqlite::regexp=off".parse()?) + }; + + let file_layer = fmt::layer() + .with_writer(file_appender) + .with_ansi(false) + .with_filter(file_filter); + + // Optionally add console layer for --show-logs + if daemon_show_logs { + let console_filter = EnvFilter::from_env("ATUIN_LOG") + .add_directive("sqlx_sqlite::regexp=off".parse()?); + + let console_layer = fmt::layer().with_filter(console_filter); + + let base = tracing_subscriber::registry() + .with(file_layer) + .with(console_layer); + + match &span_path { + Some(sp) => { + base.with(make_span_layer!(sp)).init(); + } + None => { + base.init(); + } + } + } else { + let base = tracing_subscriber::registry().with(file_layer); + + match &span_path { + Some(sp) => { + base.with(make_span_layer!(sp)).init(); + } + None => { + base.init(); + } + } + } + } + + tracing::trace!(command = ?self, "client command"); + + // Skip initializing any databases for history + // This is a pretty hot path, as it runs before and after every single command the user + // runs + match self { + Self::History(history) => return history.run(&settings).await, + Self::Init(init) => { + init.run(&settings); + return Ok(()); + } + Self::Doctor => return doctor::run(&settings).await, + Self::Config(config) => return config.run(&settings).await, + _ => {} + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let sqlite_store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let theme_name = settings.theme.name.clone(); + let theme = theme_manager.load_theme(theme_name.as_str(), settings.theme.max_depth); + + match self { + Self::Setup => setup::run(&settings).await, + Self::Import(import) => import.run(&db).await, + Self::Stats(stats) => stats.run(&db, &settings, theme).await, + Self::Search(search) => search.run(db, &mut settings, sqlite_store, theme).await, + + #[cfg(feature = "sync")] + Self::Sync(sync) => sync.run(settings, &db, sqlite_store).await, + + #[cfg(feature = "sync")] + Self::Account(account) => account.run(settings, sqlite_store).await, + + Self::Store(store) => store.run(&settings, &db, sqlite_store).await, + + Self::Server(server) => server.run().await, + + Self::Info => { + info::run(&settings); + Ok(()) + } + + Self::DefaultConfig => { + default_config::run(); + Ok(()) + } + + Self::Wrapped { year } => wrapped::run(year, &db, &settings, theme).await, + + #[cfg(feature = "daemon")] + Self::Daemon(cmd) => cmd.run(settings, sqlite_store, db).await, + + Self::History(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => { + unreachable!() + } + } + } +} diff --git a/crates/turtle/src/command/client/account.rs b/crates/turtle/src/command/client/account.rs new file mode 100644 index 00000000..898f1ac4 --- /dev/null +++ b/crates/turtle/src/command/client/account.rs @@ -0,0 +1,47 @@ +use clap::{Args, Subcommand}; +use eyre::Result; + +use crate::atuin_client::record::sqlite_store::SqliteStore; +use crate::atuin_client::settings::Settings; + +pub mod change_password; +pub mod delete; +pub mod login; +pub mod logout; +pub mod register; + +#[derive(Args, Debug)] +pub struct Cmd { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand, Debug)] +pub enum Commands { + /// Login to the configured server + Login(login::Cmd), + + /// Register a new account + Register(register::Cmd), + + /// Log out + Logout, + + /// Delete your account, and all synced data + Delete(delete::Cmd), + + /// Change your password + ChangePassword(change_password::Cmd), +} + +impl Cmd { + pub async fn run(self, settings: Settings, store: SqliteStore) -> Result<()> { + match self.command { + Commands::Login(l) => l.run(&settings, &store).await, + Commands::Register(r) => r.run(&settings).await, + Commands::Logout => logout::run().await, + Commands::Delete(d) => d.run(&settings).await, + Commands::ChangePassword(c) => c.run(&settings).await, + } + } +} diff --git a/crates/turtle/src/command/client/account/change_password.rs b/crates/turtle/src/command/client/account/change_password.rs new file mode 100644 index 00000000..6112b0df --- /dev/null +++ b/crates/turtle/src/command/client/account/change_password.rs @@ -0,0 +1,67 @@ +use clap::Parser; +use eyre::{Result, bail}; + +use crate::atuin_client::{ + auth::{self, MutateResponse}, + settings::Settings, +}; +use rpassword::prompt_password; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub current_password: Option<String>, + + #[clap(long, short)] + pub new_password: Option<String>, + + /// The two-factor authentication code for your account, if any + #[clap(long, short)] + pub totp_code: Option<String>, +} + +impl Cmd { + pub async fn run(&self, settings: &Settings) -> Result<()> { + if !settings.logged_in().await? { + bail!("You are not logged in"); + } + + let client = auth::auth_client(settings).await; + + let current_password = self.current_password.clone().unwrap_or_else(|| { + prompt_password("Please enter the current password: ") + .expect("Failed to read from input") + }); + + if current_password.is_empty() { + bail!("please provide the current password"); + } + + let new_password = self.new_password.clone().unwrap_or_else(|| { + prompt_password("Please enter the new password: ").expect("Failed to read from input") + }); + + if new_password.is_empty() { + bail!("please provide a new password"); + } + + let mut totp_code = self.totp_code.clone(); + + loop { + let response = client + .change_password(¤t_password, &new_password, totp_code.as_deref()) + .await?; + + match response { + MutateResponse::Success => break, + MutateResponse::TwoFactorRequired => { + totp_code = Some(super::login::or_user_input(None, "two-factor code")); + } + } + } + + println!("Account password successfully changed!"); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/account/delete.rs b/crates/turtle/src/command/client/account/delete.rs new file mode 100644 index 00000000..bcb40bc3 --- /dev/null +++ b/crates/turtle/src/command/client/account/delete.rs @@ -0,0 +1,57 @@ +use crate::atuin_client::{ + auth::{self, MutateResponse}, + settings::Settings, +}; +use clap::Parser; +use eyre::{Result, bail}; + +use super::login::{or_user_input, read_user_password}; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub password: Option<String>, + + /// The two-factor authentication code for your account, if any + #[clap(long, short)] + pub totp_code: Option<String>, +} + +impl Cmd { + pub async fn run(&self, settings: &Settings) -> Result<()> { + if !settings.logged_in().await? { + bail!("You are not logged in"); + } + + let client = auth::auth_client(settings).await; + + let password = self.password.clone().unwrap_or_else(read_user_password); + + if password.is_empty() { + bail!("please provide your password"); + } + + let mut totp_code = self.totp_code.clone(); + + loop { + let response = client + .delete_account(&password, totp_code.as_deref()) + .await?; + + match response { + MutateResponse::Success => break, + MutateResponse::TwoFactorRequired => { + totp_code = Some(or_user_input(None, "two-factor code")); + } + } + } + + // Clean up sessions from meta store + let meta = Settings::meta_store().await?; + meta.delete_session().await?; + + println!("Your account is deleted"); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/account/login.rs b/crates/turtle/src/command/client/account/login.rs new file mode 100644 index 00000000..0c5b66f5 --- /dev/null +++ b/crates/turtle/src/command/client/account/login.rs @@ -0,0 +1,206 @@ +use std::{io, path::PathBuf}; + +use clap::Parser; +use eyre::{Context, Result, bail}; +use tokio::{fs::File, io::AsyncWriteExt}; + +use crate::atuin_client::{ + auth::{self, AuthResponse}, + encryption::{decode_key, load_key}, + record::sqlite_store::SqliteStore, + record::store::Store, + record::sync::{self, SyncError}, + settings::{Settings, SyncAuth}, +}; +use rpassword::prompt_password; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub username: Option<String>, + + #[clap(long, short)] + pub password: Option<String>, + + /// The encryption key for your account + #[clap(long, short)] + pub key: Option<String>, + + /// The two-factor authentication code for your account, if any + #[clap(long, short)] + pub totp_code: Option<String>, + + #[clap(long, hide = true)] + pub from_registration: bool, +} + +fn get_input() -> Result<String> { + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + Ok(input.trim_end_matches(&['\r', '\n'][..]).to_string()) +} + +impl Cmd { + pub async fn run(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { + match settings.resolve_sync_auth().await { + SyncAuth::Legacy { .. } => { + println!("You are logged in to your sync server."); + println!("Run 'atuin logout' to log out."); + return Ok(()); + } + SyncAuth::NotLoggedIn { .. } => {} + } + + self.run_legacy_login(settings, store).await?; + + verify_key_against_remote(settings).await + } + + /// Legacy login: always prompt for username/password interactively + /// (or accept them via flags). + async fn run_legacy_login(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { + let username = or_user_input(self.username.clone(), "username"); + let password = self.password.clone().unwrap_or_else(read_user_password); + + self.prompt_and_store_key(settings, store).await?; + + let client = auth::auth_client(settings).await; + let response = client.login(&username, &password).await?; + + match response { + AuthResponse::Success { session, .. } => { + Settings::meta_store().await?.save_session(&session).await?; + } + AuthResponse::TwoFactorRequired => { + // Legacy server doesn't support 2FA, so this shouldn't happen. + bail!("unexpected two-factor requirement from legacy server"); + } + } + + println!("Logged in!"); + Ok(()) + } + + async fn prompt_and_store_key(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { + let key_path = settings.key_path.as_str(); + let key_path = PathBuf::from(key_path); + + println!("IMPORTANT"); + println!( + "If you are already logged in on another machine, you must ensure that the key you use here is the same as the key you used there." + ); + println!("You can find your key by running 'atuin key' on the other machine."); + println!("Do not share this key with anyone."); + println!("\nRead more here: https://docs.atuin.sh/guide/sync/#login \n"); + + let key = or_user_input( + self.key.clone(), + "encryption key [blank to use existing key file]", + ); + + if key.is_empty() { + if key_path.exists() { + let bytes = fs_err::read_to_string(&key_path).context(format!( + "Existing key file at '{}' could not be read", + key_path.to_string_lossy() + ))?; + if decode_key(bytes).is_err() { + bail!(format!( + "The key in existing key file at '{}' is invalid", + key_path.to_string_lossy() + )); + } + } else { + panic!( + "No key provided and no existing key file found. Please use 'atuin key' on your other machine, or recover your key from a backup" + ) + } + } else if !key_path.exists() { + if decode_key(key.clone()).is_err() { + bail!("The specified key is invalid"); + } + + let mut file = File::create(&key_path).await?; + file.write_all(key.as_bytes()).await?; + } else { + // we now know that the user has logged in specifying a key, AND that the key path + // exists + + // 1. check if the saved key and the provided key match. if so, nothing to do. + // 2. if not, re-encrypt the local history and overwrite the key + let current_key: [u8; 32] = load_key(settings)?.into(); + + let encoded = key.clone(); // gonna want to save it in a bit + let new_key: [u8; 32] = decode_key(key) + .context("Could not decode provided key; is not valid base64-encoded key")? + .into(); + + if new_key != current_key { + println!("\nRe-encrypting local store with new key"); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Writing new key"); + let mut file = File::create(&key_path).await?; + file.write_all(encoded.as_bytes()).await?; + } + } + + Ok(()) + } +} + +async fn verify_key_against_remote(settings: &Settings) -> Result<()> { + let key: [u8; 32] = load_key(settings) + .context("could not load encryption key for verification")? + .into(); + + let client = sync::build_client(settings).await?; + let remote_index = match client.record_status().await { + Ok(idx) => idx, + Err(e) => { + tracing::warn!("could not fetch remote status to verify key: {e}"); + return Ok(()); + } + }; + + match sync::check_encryption_key(&client, &remote_index, &key).await { + Ok(()) => Ok(()), + Err(SyncError::WrongKey) => { + // Roll back the saved session so the user is not left in a + // half-authenticated state with a key that can't read the data. + if let Ok(meta) = Settings::meta_store().await { + let _ = meta.delete_session().await; + } + crate::print_error::print_error( + "Wrong encryption key", + "The encryption key on this machine does not match the data on the server. \ + You have been logged out.\n\n\ + To fix this, find your existing key by running `atuin key` on a machine that \ + already syncs successfully, then run `atuin login` again here with that key.", + ); + std::process::exit(1); + } + Err(e) => { + // Non-key error (e.g. transient network issue). Don't fail the + // login — the user is authenticated and can sync later when the + // network recovers. + tracing::warn!("could not verify encryption key against remote: {e}"); + Ok(()) + } + } +} + +pub(super) fn or_user_input(value: Option<String>, name: &'static str) -> String { + value.unwrap_or_else(|| read_user_input(name)) +} + +pub(super) fn read_user_password() -> String { + let password = prompt_password("Please enter password: "); + password.expect("Failed to read from input") +} + +fn read_user_input(name: &'static str) -> String { + eprint!("Please enter {name}: "); + get_input().expect("Failed to read from input") +} diff --git a/crates/turtle/src/command/client/account/logout.rs b/crates/turtle/src/command/client/account/logout.rs new file mode 100644 index 00000000..6150a52b --- /dev/null +++ b/crates/turtle/src/command/client/account/logout.rs @@ -0,0 +1,5 @@ +use eyre::Result; + +pub async fn run() -> Result<()> { + crate::atuin_client::logout::logout().await +} diff --git a/crates/turtle/src/command/client/account/register.rs b/crates/turtle/src/command/client/account/register.rs new file mode 100644 index 00000000..548c2739 --- /dev/null +++ b/crates/turtle/src/command/client/account/register.rs @@ -0,0 +1,67 @@ +use clap::Parser; +use eyre::{Result, bail}; + +use super::login::or_user_input; +use crate::atuin_client::settings::{Settings, SyncAuth}; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub username: Option<String>, + + #[clap(long, short)] + pub password: Option<String>, + + #[clap(long, short)] + pub email: Option<String>, +} + +impl Cmd { + pub async fn run(&self, settings: &Settings) -> Result<()> { + match settings.resolve_sync_auth().await { + SyncAuth::Legacy { .. } => { + println!("You are already logged in."); + println!("Run 'atuin logout' to log out."); + return Ok(()); + } + + SyncAuth::NotLoggedIn { .. } => {} + } + + // Legacy registration flow + println!("Registering for an Atuin Sync account"); + + let username = or_user_input(self.username.clone(), "username"); + let email = or_user_input(self.email.clone(), "email"); + let password = self + .password + .clone() + .unwrap_or_else(super::login::read_user_password); + + if password.is_empty() { + bail!("please provide a password"); + } + + let session = crate::atuin_client::api_client::register( + settings.sync_address.as_str(), + &username, + &email, + &password, + ) + .await?; + + let meta = Settings::meta_store().await?; + meta.save_session(&session.session).await?; + + let _key = crate::atuin_client::encryption::load_key(settings)?; + + println!( + "Registration successful! Please make a note of your key (run 'atuin key') and keep it safe." + ); + println!( + "You will need it to log in on other devices, and we cannot help recover it if you lose it." + ); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/config.rs b/crates/turtle/src/command/client/config.rs new file mode 100644 index 00000000..1597a8d6 --- /dev/null +++ b/crates/turtle/src/command/client/config.rs @@ -0,0 +1,352 @@ +use crate::atuin_client::settings::Settings; +use clap::{Args, Subcommand, ValueEnum}; +use eyre::Result; +use toml_edit::{Document, DocumentMut, Item, Table, TableLike, Value}; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Get a configuration value from your config.toml file + /// or after defaults and overrides are applied + #[command()] + Get(GetCmd), + + /// Set a configuration value in your config.toml file + #[command()] + Set(SetCmd), + + /// Print all configuration values from your config.toml file + /// in TOML format + /// + /// If a key is provided, only print the value of that key and all its children + #[command()] + Print(PrintCmd), +} + +impl Cmd { + pub async fn run(self, settings: &Settings) -> Result<()> { + match self { + Self::Get(get) => get.run(settings).await, + Self::Set(set) => set.run(settings).await, + Self::Print(print) => print.run(settings).await, + } + } +} + +/// Get a configuration value from your config.toml file, +/// or optionally the effective value after defaults and overrides are applied. +#[derive(Args, Debug)] +pub struct GetCmd { + /// The configuration key to get + pub key: String, + + /// Print the value after defaults and overrides are applied + #[arg(long, short)] + pub resolved: bool, + + /// Print both the config file value and the resolved value + #[arg(long, short)] + pub verbose: bool, +} + +impl GetCmd { + pub async fn run(&self, _settings: &Settings) -> Result<()> { + let key = self.key.trim(); + if key.is_empty() || key.contains(char::is_whitespace) { + eyre::bail!("Config key must be non-empty and must not contain whitespace"); + } + + if self.verbose { + println!("Config file:"); + self.print_current_value(key, " ").await?; + println!("\nResolved:"); + Self::print_effective_value(key, " "); + return Ok(()); + } + + if self.resolved { + Self::print_effective_value(key, ""); + } else { + self.print_current_value(key, "").await?; + } + + Ok(()) + } + + async fn print_current_value(&self, key: &str, prefix: &str) -> Result<()> { + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let doc = config_str.parse::<Document<_>>()?; + + let current = get_deep_key(&doc, key); + + match current { + Some(item) if item.is_table() || item.is_inline_table() => { + let table = item + .as_table_like() + .expect("is_table()/is_inline_table() but no table"); + println!("{prefix}[{key}]"); + dump_table(table, prefix, &mut vec![key.to_string()])?; + } + Some(item) => { + let val = item.to_string(); + let val = val.trim().trim_matches('"'); + println!("{prefix}{val}"); + } + None => { + println!("{prefix}(not set in config file)"); + } + } + + Ok(()) + } + + fn print_effective_value(key: &str, prefix: &str) { + match Settings::get_config_value(key) { + Ok(value) => { + for line in value.lines() { + println!("{prefix}{line}"); + } + } + Err(_) => { + println!("{prefix}(unknown key)"); + } + } + } +} + +#[derive(Args, Debug)] +pub struct SetCmd { + /// The configuration key to set + pub key: String, + + /// The value to set + pub value: String, + + /// Store value as an explicit type + #[arg(long = "type", short, value_enum, default_value_t = ValueType::Auto, value_name = "TYPE")] + pub the_type: ValueType, +} + +#[derive(ValueEnum, Debug, Clone, PartialEq, Eq)] +pub enum ValueType { + /// Automatically determine the type of the value + Auto, + /// Store value as a string + String, + /// Store value as a boolean + Boolean, + /// Store value as an integer + Integer, + /// Store the value as a float + Float, +} + +impl SetCmd { + pub async fn run(self, _settings: &Settings) -> Result<()> { + let key = self.key.trim(); + if key.is_empty() || key.contains(char::is_whitespace) { + eyre::bail!("Config key must be non-empty and must not contain whitespace"); + } + + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let mut doc: DocumentMut = config_str.parse()?; + + // When using auto type detection, try to match the existing value's type + // so we don't accidentally change e.g. "300" (string) to 300 (integer) + let existing_type = detect_existing_type(&doc, key); + let value = self.parse_value(existing_type.as_ref())?; + set_deep_key(&mut doc, key, value)?; + + tokio::fs::write(&config_file, doc.to_string()).await?; + + Ok(()) + } + + fn parse_value(&self, existing_type: Option<&ValueType>) -> Result<Value> { + let raw = &self.value; + + // Explicit --type takes priority, then existing value type, then auto-detect + let effective_type = if self.the_type != ValueType::Auto { + &self.the_type + } else if let Some(existing) = existing_type { + existing + } else { + &ValueType::Auto + }; + + match effective_type { + ValueType::String => Ok(Value::from(raw.as_str())), + ValueType::Boolean => { + let b: bool = raw + .parse() + .map_err(|_| eyre::eyre!("invalid boolean value: {raw}"))?; + Ok(Value::from(b)) + } + ValueType::Integer => { + let i: i64 = raw + .parse() + .map_err(|_| eyre::eyre!("invalid integer value: {raw}"))?; + Ok(Value::from(i)) + } + ValueType::Float => { + let f: f64 = raw + .parse() + .map_err(|_| eyre::eyre!("invalid float value: {raw}"))?; + Ok(Value::from(f)) + } + ValueType::Auto => { + if raw == "true" || raw == "false" { + return Ok(Value::from(raw == "true")); + } + if let Ok(i) = raw.parse::<i64>() { + return Ok(Value::from(i)); + } + if let Ok(f) = raw.parse::<f64>() { + return Ok(Value::from(f)); + } + Ok(Value::from(raw.as_str())) + } + } + } +} + +#[derive(Args, Debug)] +pub struct PrintCmd { + /// Print the value of a specific key and all its children + pub key: Option<String>, +} + +impl PrintCmd { + pub async fn run(&self, _settings: &Settings) -> Result<()> { + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let doc = config_str.parse::<Document<_>>()?; + + if let Some(key) = &self.key { + let current = get_deep_key(&doc, key); + + if let Some(current) = current { + if current.is_table() || current.is_inline_table() { + println!("[{key}]"); + dump_table( + current + .as_table_like() + .expect("is_table()/is_inline_table() but no table"), + "", + &mut vec![key.clone()], + )?; + } else { + println!("{}", current.to_string().trim().trim_matches('"')); + } + } else { + println!("key not found"); + } + } else { + dump_table(doc.as_table(), "", &mut Vec::new())?; + } + + Ok(()) + } +} + +fn dump_table(table: &dyn TableLike, prefix: &str, stack: &mut Vec<String>) -> Result<()> { + for (key, value) in table.iter() { + if value.is_table() || value.is_inline_table() { + stack.push(key.to_string()); + + let table = value + .as_table_like() + .expect("is_table()/is_inline_table() but no table"); + + println!("\n{}[{}]", prefix, stack.join(".")); + + dump_table(table, prefix, stack)?; + + stack.pop(); + } else { + println!("{prefix}{key} = {value}"); + } + } + + Ok(()) +} + +fn get_deep_key<'doc>(doc: &'doc Document<String>, key: &str) -> Option<&'doc Item> { + let parts = key.split('.'); + let mut current: Option<&Item> = Some(doc.as_item()); + + for part in parts { + current = current + .and_then(|item| item.as_table_like()) + .and_then(|table| table.get(part)); + } + + current +} + +/// Detect the TOML type of an existing key in the document, so `set` with auto +/// type detection preserves the original type rather than guessing from the value string. +fn detect_existing_type(doc: &DocumentMut, key: &str) -> Option<ValueType> { + let parts: Vec<&str> = key.split('.').collect(); + let mut current: &dyn TableLike = doc.as_table(); + + for &part in &parts[..parts.len().saturating_sub(1)] { + current = current.get(part)?.as_table_like()?; + } + + let last = parts.last()?; + let v = current.get(last)?.as_value()?; + + if v.is_str() { + Some(ValueType::String) + } else if v.is_bool() { + Some(ValueType::Boolean) + } else if v.is_integer() { + Some(ValueType::Integer) + } else if v.is_float() { + Some(ValueType::Float) + } else { + None + } +} + +fn set_deep_key(doc: &mut DocumentMut, key: &str, value: Value) -> Result<()> { + let parts: Vec<&str> = key.split('.').collect(); + + if parts.is_empty() { + eyre::bail!("empty config key"); + } + + let mut current: &mut dyn TableLike = doc.as_table_mut(); + + // Navigate/create intermediate tables + for &part in &parts[..parts.len() - 1] { + if !current.contains_key(part) { + current.insert(part, Item::Table(Table::new())); + } + current = current + .get_mut(part) + .expect("just inserted or already exists") + .as_table_like_mut() + .ok_or_else(|| eyre::eyre!("'{}' exists but is not a table", part))?; + } + + let last = *parts.last().unwrap(); + + // Don't silently overwrite a table with a scalar value + if let Some(existing) = current.get(last) + && (existing.is_table() || existing.is_inline_table()) + { + eyre::bail!( + "'{}' is a table; use a dotted key like '{}.key' to set a value within it", + key, + key + ); + } + + current.insert(last, Item::Value(value)); + + Ok(()) +} diff --git a/crates/turtle/src/command/client/daemon.rs b/crates/turtle/src/command/client/daemon.rs new file mode 100644 index 00000000..2ee9b759 --- /dev/null +++ b/crates/turtle/src/command/client/daemon.rs @@ -0,0 +1,769 @@ +use std::fs::{self, File, OpenOptions}; +use std::io::{ErrorKind, Write}; +#[cfg(unix)] +use std::os::unix::net::UnixStream as StdUnixStream; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use std::time::{Duration, Instant}; + +use crate::atuin_client::{ + database::Sqlite, history::History, record::sqlite_store::SqliteStore, settings::Settings, +}; +use crate::atuin_daemon::DaemonEvent; +use crate::atuin_daemon::client::{ + ControlClient, DaemonClientErrorKind, HistoryClient, classify_error, +}; +use clap::Subcommand; +#[cfg(unix)] +use daemonize::Daemonize; +use eyre::{Result, WrapErr, bail, eyre}; +use fs4::fs_std::FileExt; +use tokio::time::sleep; + +#[derive(clap::Args, Debug)] +pub struct Cmd { + /// Internal flag for daemonization + #[arg(long, hide = true)] + daemonize: bool, + + /// Also write daemon logs to the console (useful for debugging) + #[arg(long)] + show_logs: bool, + + #[command(subcommand)] + subcmd: Option<SubCmd>, +} + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum SubCmd { + /// Start the daemon server + Start { + #[arg(long, hide = true)] + daemonize: bool, + + /// Also write daemon logs to the console (useful for debugging) + #[arg(long)] + show_logs: bool, + + /// Force start: kill existing daemon process and reset the socket + #[arg(long)] + force: bool, + }, + + /// Show the daemon's current status + Status, + + /// Stop the daemon gracefully + Stop, + + /// Restart the daemon (stop, then start in background) + Restart, +} + +impl Cmd { + /// Returns `true` when the process should daemonize before creating the + /// async runtime or opening any database connections. + #[cfg(unix)] + pub fn should_daemonize(&self) -> bool { + match &self.subcmd { + Some(SubCmd::Start { daemonize, .. }) => *daemonize, + None => self.daemonize, + _ => false, + } + } + + /// Returns `true` when logs should also be written to the console. + pub fn show_logs(&self) -> bool { + match &self.subcmd { + Some(SubCmd::Start { show_logs, .. }) => *show_logs, + None => self.show_logs, + _ => false, + } + } + + pub async fn run( + self, + settings: Settings, + store: SqliteStore, + history_db: Sqlite, + ) -> Result<()> { + match self.subcmd { + None => { + eprintln!("Warning: `atuin daemon` is deprecated, use `atuin daemon start`"); + run(settings, store, history_db, false).await + } + Some(SubCmd::Start { force, .. }) => run(settings, store, history_db, force).await, + Some(SubCmd::Status) => status_cmd(&settings).await, + Some(SubCmd::Stop) => stop_cmd(&settings).await, + Some(SubCmd::Restart) => restart_cmd(&settings).await, + } + } +} + +const DAEMON_VERSION: &str = env!("CARGO_PKG_VERSION"); +const DAEMON_PROTOCOL_VERSION: u32 = 1; +const STARTUP_POLL: Duration = Duration::from_millis(40); +const LOCK_POLL: Duration = Duration::from_millis(20); +const LEGACY_DAEMON_RESTART_MESSAGE: &str = "legacy daemon detected; restart daemon manually"; + +struct PidfileGuard { + file: File, +} + +impl PidfileGuard { + fn acquire(path: &Path) -> Result<Self> { + let mut file = open_lock_file(path)?; + + if !file.try_lock_exclusive()? { + bail!( + "daemon already running (pidfile lock busy at {})", + path.display() + ); + } + + file.set_len(0) + .wrap_err_with(|| format!("could not truncate daemon pidfile {}", path.display()))?; + writeln!(file, "{}", std::process::id()) + .and_then(|()| writeln!(file, "{DAEMON_VERSION}")) + .wrap_err_with(|| format!("could not write daemon pidfile {}", path.display()))?; + + Ok(Self { file }) + } +} + +impl Drop for PidfileGuard { + fn drop(&mut self) { + let _ = self.file.unlock(); + } +} + +enum Probe { + Ready(HistoryClient), + NeedsRestart(String), + Unreachable(eyre::Report), +} + +fn daemon_matches_expected(version: &str, protocol: u32) -> bool { + version == DAEMON_VERSION && protocol == DAEMON_PROTOCOL_VERSION +} + +fn daemon_mismatch_message(version: &str, protocol: u32) -> String { + if protocol == DAEMON_PROTOCOL_VERSION { + format!("daemon is out of date: expected {DAEMON_VERSION}, got {version}") + } else { + format!("daemon protocol mismatch: expected {DAEMON_PROTOCOL_VERSION}, got {protocol}") + } +} + +fn is_legacy_daemon_error(err: &eyre::Report) -> bool { + matches!(classify_error(err), DaemonClientErrorKind::Unimplemented) +} + +fn should_retry_after_error(err: &eyre::Report) -> bool { + matches!( + classify_error(err), + DaemonClientErrorKind::Connect + | DaemonClientErrorKind::Unavailable + | DaemonClientErrorKind::Unimplemented + ) +} + +fn daemon_startup_lock_path(pidfile_path: &Path) -> PathBuf { + let mut os = pidfile_path.as_os_str().to_os_string(); + os.push(".startup.lock"); + PathBuf::from(os) +} + +fn open_lock_file(path: &Path) -> Result<File> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .wrap_err_with(|| format!("could not create lock directory {}", parent.display()))?; + } + + OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(path) + .wrap_err_with(|| format!("could not open lock file {}", path.display())) +} + +async fn wait_for_lock(path: &Path, timeout: Duration) -> Result<File> { + let file = open_lock_file(path)?; + let start = Instant::now(); + + loop { + match file.try_lock_exclusive() { + Ok(true) => return Ok(file), + Ok(false) => { + if start.elapsed() >= timeout { + bail!("timed out waiting for lock at {}", path.display()); + } + + sleep(LOCK_POLL).await; + } + Err(err) => { + return Err(eyre!("could not lock {}: {err}", path.display())); + } + } + } +} + +async fn wait_for_pidfile_available(path: &Path, timeout: Duration) -> Result<()> { + let file = wait_for_lock(path, timeout).await?; + file.unlock() + .wrap_err_with(|| format!("failed to unlock {}", path.display()))?; + Ok(()) +} + +async fn connect_client(settings: &Settings) -> Result<HistoryClient> { + HistoryClient::new( + #[cfg(unix)] + settings.daemon.socket_path.clone(), + ) + .await +} + +async fn probe(settings: &Settings) -> Probe { + let mut client = match connect_client(settings).await { + Ok(client) => client, + Err(err) => return Probe::Unreachable(err), + }; + + match client.status().await { + Ok(status) => { + if daemon_matches_expected(&status.version, status.protocol) { + Probe::Ready(client) + } else { + Probe::NeedsRestart(daemon_mismatch_message(&status.version, status.protocol)) + } + } + Err(err) => Probe::Unreachable(err), + } +} + +async fn request_shutdown(settings: &Settings) { + if let Ok(mut client) = connect_client(settings).await { + let _ = client.shutdown().await; + } +} + +fn spawn_daemon_process() -> Result<()> { + let exe = std::env::current_exe().wrap_err("could not locate atuin executable")?; + + let mut cmd = Command::new(exe); + cmd.arg("daemon") + .arg("start") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + + #[cfg(unix)] + cmd.arg("--daemonize"); + + cmd.spawn().wrap_err("failed to spawn daemon process")?; + + Ok(()) +} + +fn startup_timeout(settings: &Settings) -> Duration { + Duration::from_secs_f64(settings.local_timeout.max(0.5) + 2.0) +} + +#[cfg(unix)] +fn remove_stale_socket_if_present(settings: &Settings) -> Result<()> { + if settings.daemon.systemd_socket { + return Ok(()); + } + + let socket_path = Path::new(&settings.daemon.socket_path); + if !socket_path.exists() { + return Ok(()); + } + + match StdUnixStream::connect(socket_path) { + Ok(stream) => { + drop(stream); + Ok(()) + } + Err(err) if err.kind() == ErrorKind::ConnectionRefused => { + fs::remove_file(socket_path).wrap_err_with(|| { + format!( + "failed to remove stale daemon socket {}", + socket_path.display() + ) + })?; + Ok(()) + } + Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), + Err(_) => Ok(()), + } +} + +async fn wait_until_ready(settings: &Settings, timeout: Duration) -> Result<HistoryClient> { + let start = Instant::now(); + let mut last_error = eyre!("daemon did not become ready"); + + loop { + match probe(settings).await { + Probe::Ready(client) => return Ok(client), + Probe::NeedsRestart(reason) => { + last_error = eyre!(reason); + } + Probe::Unreachable(err) => { + if is_legacy_daemon_error(&err) { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + last_error = err; + } + } + + if start.elapsed() >= timeout { + return Err(last_error.wrap_err(format!( + "timed out waiting for daemon startup after {}ms", + timeout.as_millis() + ))); + } + + sleep(STARTUP_POLL).await; + } +} + +#[expect(clippy::unnecessary_wraps)] +fn ensure_autostart_supported(settings: &Settings) -> Result<()> { + #[cfg(unix)] + if settings.daemon.systemd_socket { + bail!( + "daemon autostart is incompatible with `daemon.systemd_socket = true`; use systemd to manage the daemon" + ); + } + + Ok(()) +} + +/// Ensure the daemon is running, starting it if necessary. +/// +/// If the daemon is already running and up-to-date, this is a no-op. +/// If it is not running or needs a restart, this will spawn a new daemon +/// process and wait for it to become ready. +/// +/// Returns an error if the daemon could not be started. +pub async fn ensure_daemon_running(settings: &Settings) -> Result<()> { + ensure_autostart_supported(settings)?; + + let timeout = startup_timeout(settings); + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let startup_lock_path = daemon_startup_lock_path(&pidfile_path); + let startup_lock = wait_for_lock(&startup_lock_path, timeout).await?; + + match probe(settings).await { + Probe::Ready(_) => { + drop(startup_lock); + return Ok(()); + } + Probe::NeedsRestart(_) => { + request_shutdown(settings).await; + } + Probe::Unreachable(err) => { + if is_legacy_daemon_error(&err) { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + } + } + + // This prevents rapid-fire hook invocations from racing daemon restart. + wait_for_pidfile_available(&pidfile_path, timeout).await?; + + #[cfg(unix)] + remove_stale_socket_if_present(settings)?; + + spawn_daemon_process()?; + let _ = wait_until_ready(settings, timeout).await?; + + drop(startup_lock); + Ok(()) +} + +async fn restart_daemon(settings: &Settings) -> Result<HistoryClient> { + ensure_daemon_running(settings).await?; + connect_client(settings).await +} + +fn ensure_reply_compatible(settings: &Settings, version: &str, protocol: u32) -> Result<()> { + if daemon_matches_expected(version, protocol) { + return Ok(()); + } + + let message = daemon_mismatch_message(version, protocol); + if settings.daemon.autostart { + bail!("{message}"); + } + + bail!("{message}. Enable `daemon.autostart = true` or restart the daemon manually"); +} + +pub async fn start_history(settings: &Settings, history: History) -> Result<String> { + match async { + connect_client(settings) + .await? + .start_history(history.clone()) + .await + } + .await + { + Ok(resp) => { + if daemon_matches_expected(&resp.version, resp.protocol) { + return Ok(resp.id); + } + + if !settings.daemon.autostart { + return Err(eyre!( + "{}. Enable `daemon.autostart = true` or restart the daemon manually", + daemon_mismatch_message(&resp.version, resp.protocol) + )); + } + } + Err(err) if !settings.daemon.autostart => return Err(err), + Err(err) if !should_retry_after_error(&err) => return Err(err), + Err(_) => {} + } + + let resp = restart_daemon(settings) + .await? + .start_history(history) + .await?; + ensure_reply_compatible(settings, &resp.version, resp.protocol)?; + Ok(resp.id) +} + +pub async fn end_history(settings: &Settings, id: String, duration: u64, exit: i64) -> Result<()> { + match async { + connect_client(settings) + .await? + .end_history(id.clone(), duration, exit) + .await + } + .await + { + Ok(resp) => { + if daemon_matches_expected(&resp.version, resp.protocol) { + return Ok(()); + } + + if !settings.daemon.autostart { + return Err(eyre!( + "{}. Enable `daemon.autostart = true` or restart the daemon manually", + daemon_mismatch_message(&resp.version, resp.protocol) + )); + } + + // End succeeded on the running daemon, so avoid replaying it. + // We only restart to make subsequent hook calls target the expected version. + let _ = restart_daemon(settings).await; + return Ok(()); + } + Err(err) if !settings.daemon.autostart => return Err(err), + Err(err) if !should_retry_after_error(&err) => return Err(err), + Err(_) => {} + } + + let resp = restart_daemon(settings) + .await? + .end_history(id, duration, exit) + .await?; + ensure_reply_compatible(settings, &resp.version, resp.protocol)?; + Ok(()) +} + +/// Emit a daemon event, auto-starting the daemon if it is not running. +/// +/// If the daemon is not reachable and `daemon.autostart` is enabled, this +/// will start the daemon and retry the event. If the daemon cannot be +/// started or the retry fails, a warning is printed to stderr. +pub async fn emit_event(settings: &Settings, event: DaemonEvent) { + // Try to connect and send + match ControlClient::from_settings(settings).await { + Ok(mut client) => { + if let Err(e) = client.send_event(event).await { + tracing::debug!(?e, "failed to send event to daemon"); + } + return; + } + Err(e) if !settings.daemon.autostart || !should_retry_after_error(&e) => { + tracing::debug!(?e, "daemon not available, skipping event emission"); + return; + } + Err(_) => {} + } + + // Auto-start the daemon and retry + if let Err(e) = ensure_daemon_running(settings).await { + eprintln!("Could not start daemon: {e}"); + return; + } + + match ControlClient::from_settings(settings).await { + Ok(mut client) => { + if let Err(e) = client.send_event(event).await { + eprintln!("Daemon started but failed to send event: {e}"); + } + } + Err(e) => { + eprintln!("Daemon started but failed to connect: {e}"); + } + } +} + +pub async fn tail_client(settings: &Settings) -> Result<HistoryClient> { + match probe(settings).await { + Probe::Ready(client) => return Ok(client), + Probe::NeedsRestart(reason) if !settings.daemon.autostart => { + bail!("{reason}. Enable `daemon.autostart = true` or restart the daemon manually"); + } + Probe::Unreachable(err) if is_legacy_daemon_error(&err) => { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + Probe::Unreachable(err) if !settings.daemon.autostart => return Err(err), + Probe::Unreachable(err) if !should_retry_after_error(&err) => return Err(err), + Probe::NeedsRestart(_) | Probe::Unreachable(_) => {} + } + + restart_daemon(settings).await +} + +async fn status_cmd(settings: &Settings) -> Result<()> { + match probe(settings).await { + Probe::Ready(mut client) => { + let status = client.status().await?; + println!("Daemon running"); + println!(" PID: {}", status.pid); + println!(" Version: {}", status.version); + println!(" Protocol: {}", status.protocol); + println!(" Healthy: {}", status.healthy); + #[cfg(unix)] + println!(" Socket: {}", settings.daemon.socket_path); + } + Probe::NeedsRestart(reason) => { + println!("Daemon running (needs restart)"); + println!(" Reason: {reason}"); + } + Probe::Unreachable(_) => { + println!("Daemon is not running"); + } + } + + Ok(()) +} + +async fn stop_cmd(settings: &Settings) -> Result<()> { + let Ok(mut client) = connect_client(settings).await else { + println!("Daemon is not running"); + return Ok(()); + }; + + match client.shutdown().await { + Ok(true) => { + println!("Shutdown requested"); + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let timeout = Duration::from_secs(5); + match wait_for_pidfile_available(&pidfile_path, timeout).await { + Ok(()) => println!("Daemon stopped"), + Err(_) => println!("Daemon may still be shutting down"), + } + + Ok(()) + } + Ok(false) => bail!("Daemon rejected shutdown request"), + Err(err) => Err(err.wrap_err("Failed to send shutdown request")), + } +} + +async fn restart_cmd(settings: &Settings) -> Result<()> { + // Stop if running + match probe(settings).await { + Probe::Ready(_) | Probe::NeedsRestart(_) => { + request_shutdown(settings).await; + println!("Stopping daemon..."); + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let timeout = Duration::from_secs(5); + wait_for_pidfile_available(&pidfile_path, timeout) + .await + .wrap_err("Timed out waiting for old daemon to stop")?; + } + Probe::Unreachable(_) => { + println!("No daemon running"); + } + } + + #[cfg(unix)] + remove_stale_socket_if_present(settings)?; + + spawn_daemon_process()?; + println!("Starting daemon..."); + + let timeout = startup_timeout(settings); + let status = wait_until_ready(settings, timeout).await?.status().await?; + + println!("Daemon restarted"); + println!(" PID: {}", status.pid); + println!(" Version: {}", status.version); + + Ok(()) +} + +/// Daemonize the current process. Must be called before creating the tokio +/// runtime or opening database connections, since `fork()` inside an async +/// runtime corrupts its internal state. +#[cfg(unix)] +pub fn daemonize_current_process() -> Result<()> { + let cwd = + std::env::current_dir().wrap_err("could not determine current directory for daemon")?; + + Daemonize::new() + .working_directory(cwd) + .start() + .wrap_err("failed to daemonize process")?; + + Ok(()) +} + +async fn run( + settings: Settings, + store: SqliteStore, + history_db: Sqlite, + force: bool, +) -> Result<()> { + if force { + force_cleanup(&settings); + } + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let _pidfile_guard = PidfileGuard::acquire(&pidfile_path)?; + + crate::atuin_daemon::boot(settings, store, history_db).await?; + + Ok(()) +} + +/// Force cleanup: kill existing daemon process and remove socket. +fn force_cleanup(settings: &Settings) { + let pidfile_path = Path::new(&settings.daemon.pidfile_path); + + // Read and kill the existing process if pidfile exists + if pidfile_path.exists() { + if let Ok(contents) = fs::read_to_string(pidfile_path) + && let Some(pid_str) = contents.lines().next() + && let Ok(pid) = pid_str.parse::<u32>() + { + kill_process(pid); + // Give it a moment to release resources + std::thread::sleep(Duration::from_millis(100)); + } + + // Remove the pidfile + if let Err(e) = fs::remove_file(pidfile_path) + && e.kind() != ErrorKind::NotFound + { + tracing::warn!("failed to remove pidfile: {e}"); + } + } + + // Remove the socket file + #[cfg(unix)] + { + let socket_path = Path::new(&settings.daemon.socket_path); + if socket_path.exists() + && let Err(e) = fs::remove_file(socket_path) + && e.kind() != ErrorKind::NotFound + { + tracing::warn!("failed to remove socket: {e}"); + } + } +} + +/// Kill a process by PID. +#[cfg(unix)] +fn kill_process(pid: u32) { + // Use kill command to send SIGTERM for graceful shutdown + let _ = Command::new("kill") + .args(["-TERM", &pid.to_string()]) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_version_matches() { + assert!(daemon_matches_expected( + DAEMON_VERSION, + DAEMON_PROTOCOL_VERSION + )); + } + + #[test] + fn test_version_mismatch() { + assert!(!daemon_matches_expected("0.0.0", DAEMON_PROTOCOL_VERSION)); + assert!(!daemon_matches_expected(DAEMON_VERSION, 999)); + assert!(!daemon_matches_expected("0.0.0", 999)); + } + + #[test] + fn test_mismatch_message_version() { + let msg = daemon_mismatch_message("0.0.0", DAEMON_PROTOCOL_VERSION); + assert!(msg.contains("out of date"), "got: {msg}"); + assert!(msg.contains("0.0.0")); + assert!(msg.contains(DAEMON_VERSION)); + } + + #[test] + fn test_mismatch_message_protocol() { + let msg = daemon_mismatch_message(DAEMON_VERSION, 999); + assert!(msg.contains("protocol mismatch"), "got: {msg}"); + } + + #[test] + fn test_startup_lock_path() { + let pidfile = Path::new("/tmp/atuin-daemon.pid"); + let lock = daemon_startup_lock_path(pidfile); + assert_eq!(lock, PathBuf::from("/tmp/atuin-daemon.pid.startup.lock")); + } + + #[test] + fn test_pidfile_guard_acquire_and_drop() { + let tmp = tempfile::tempdir().unwrap(); + let pidfile = tmp.path().join("daemon.pid"); + + { + let _guard = PidfileGuard::acquire(&pidfile).unwrap(); + // Guard holds an exclusive lock — on Windows other handles cannot + // read the file, so we verify contents after the guard is dropped. + } + + let contents = std::fs::read_to_string(&pidfile).unwrap(); + let lines: Vec<&str> = contents.lines().collect(); + assert_eq!(lines.len(), 2); + assert_eq!(lines[0], std::process::id().to_string()); + assert_eq!(lines[1], DAEMON_VERSION); + + // After guard is dropped, lock should be released — acquiring again must succeed. + let _guard2 = PidfileGuard::acquire(&pidfile).unwrap(); + } + + #[test] + fn test_pidfile_guard_prevents_double_acquire() { + let tmp = tempfile::tempdir().unwrap(); + let pidfile = tmp.path().join("daemon.pid"); + + let _guard = PidfileGuard::acquire(&pidfile).unwrap(); + let result = PidfileGuard::acquire(&pidfile); + assert!(result.is_err()); + } +} diff --git a/crates/turtle/src/command/client/default_config.rs b/crates/turtle/src/command/client/default_config.rs new file mode 100644 index 00000000..e8cc15f9 --- /dev/null +++ b/crates/turtle/src/command/client/default_config.rs @@ -0,0 +1,4 @@ +pub fn run() { + // TODO(@bpeetz): Re-add the default settings option back (Settings::example_config()) <2026-06-11> + println!("TODO"); +} diff --git a/crates/turtle/src/command/client/doctor.rs b/crates/turtle/src/command/client/doctor.rs new file mode 100644 index 00000000..09fa6e77 --- /dev/null +++ b/crates/turtle/src/command/client/doctor.rs @@ -0,0 +1,412 @@ +use std::process::Command; +use std::{env, str::FromStr}; + +use crate::atuin_client::database::Sqlite; +use crate::atuin_client::settings::Settings; +use crate::atuin_common::shell::{Shell, shell_name}; +use crate::atuin_common::utils; +use colored::Colorize; +use eyre::Result; +use serde::Serialize; + +use sysinfo::{Disks, System, get_current_pid}; + +#[derive(Debug, Serialize)] +struct ShellInfo { + pub name: String, + + // best-effort, not supported on all OSes + pub default: String, + + // Detect some shell plugins that the user has installed. + // I'm just going to start with preexec/blesh + pub plugins: Vec<String>, + + // The preexec framework used in the current session, if Atuin is loaded. + pub preexec: Option<String>, +} + +impl ShellInfo { + // HACK ALERT! + // Many of the shell vars we need to detect are not exported :( + // So, we're going to run a interactive session and directly check the + // variable. There's a chance this won't work, so it should not be fatal. + // + // Every shell we support handles `shell -ic 'command'` + fn shellvar_exists(shell: &str, var: &str) -> bool { + let cmd = Command::new(shell) + .args([ + "-ic", + format!("[ -z ${var} ] || echo ATUIN_DOCTOR_ENV_FOUND").as_str(), + ]) + .output() + .map_or(String::new(), |v| { + let out = v.stdout; + String::from_utf8(out).unwrap_or_default() + }); + + cmd.contains("ATUIN_DOCTOR_ENV_FOUND") + } + + fn detect_preexec_framework(shell: &str) -> Option<String> { + if env::var("ATUIN_SESSION").ok().is_none() { + None + } else if shell.starts_with("bash") || shell == "sh" { + env::var("ATUIN_PREEXEC_BACKEND") + .ok() + .filter(|value| !value.is_empty()) + .and_then(|atuin_preexec_backend| { + atuin_preexec_backend.rfind(':').and_then(|pos_colon| { + u32::from_str(&atuin_preexec_backend[..pos_colon]) + .ok() + .is_some_and(|preexec_shlvl| { + env::var("SHLVL") + .ok() + .and_then(|shlvl| u32::from_str(&shlvl).ok()) + .is_some_and(|shlvl| shlvl == preexec_shlvl) + }) + .then(|| atuin_preexec_backend[pos_colon + 1..].to_string()) + }) + }) + } else { + Some("built-in".to_string()) + } + } + + fn validate_plugin_blesh( + _shell: &str, + shell_process: &sysinfo::Process, + ble_session_id: &str, + ) -> Option<String> { + ble_session_id + .split('/') + .nth(1) + .and_then(|field| u32::from_str(field).ok()) + .filter(|&blesh_pid| blesh_pid == shell_process.pid().as_u32()) + .map(|_| "blesh".to_string()) + } + + pub fn plugins(shell: &str, shell_process: &sysinfo::Process) -> Vec<String> { + // consider a different detection approach if there are plugins + // that don't set shell vars + + enum PluginShellType { + Any, + Bash, + + // Note: these are currently unused + #[expect(dead_code)] + Zsh, + #[expect(dead_code)] + Fish, + #[expect(dead_code)] + Nushell, + #[expect(dead_code)] + Xonsh, + } + + enum PluginProbeType { + EnvironmentVariable(&'static str), + InteractiveShellVariable(&'static str), + } + + type PluginValidator = fn(&str, &sysinfo::Process, &str) -> Option<String>; + + let plugin_list: [( + &str, + PluginShellType, + PluginProbeType, + Option<PluginValidator>, + ); 3] = [ + ( + "atuin", + PluginShellType::Any, + PluginProbeType::EnvironmentVariable("ATUIN_SESSION"), + None, + ), + ( + "blesh", + PluginShellType::Bash, + PluginProbeType::EnvironmentVariable("BLE_SESSION_ID"), + Some(Self::validate_plugin_blesh), + ), + ( + "bash-preexec", + PluginShellType::Bash, + PluginProbeType::InteractiveShellVariable("bash_preexec_imported"), + None, + ), + ]; + + plugin_list + .into_iter() + .filter(|(_, shell_type, _, _)| match shell_type { + PluginShellType::Any => true, + PluginShellType::Bash => shell.starts_with("bash") || shell == "sh", + PluginShellType::Zsh => shell.starts_with("zsh"), + PluginShellType::Fish => shell.starts_with("fish"), + PluginShellType::Nushell => shell.starts_with("nu"), + PluginShellType::Xonsh => shell.starts_with("xonsh"), + }) + .filter_map(|(plugin, _, probe_type, validator)| -> Option<String> { + match probe_type { + PluginProbeType::EnvironmentVariable(env) => { + env::var(env).ok().filter(|value| !value.is_empty()) + } + PluginProbeType::InteractiveShellVariable(shellvar) => { + ShellInfo::shellvar_exists(shell, shellvar).then_some(String::default()) + } + } + .and_then(|value| { + validator.map_or_else( + || Some(plugin.to_string()), + |validator| validator(shell, shell_process, &value), + ) + }) + }) + .collect() + } + + pub fn new() -> Self { + // TODO: rework to use crate::atuin_common::Shell + + let sys = System::new_all(); + + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + let parent = sys + .process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist"); + + let name = shell_name(Some(parent)); + + let plugins = ShellInfo::plugins(name.as_str(), parent); + + let default = Shell::default_shell().unwrap_or(Shell::Unknown).to_string(); + + let preexec = Self::detect_preexec_framework(name.as_str()); + + Self { + name, + default, + plugins, + preexec, + } + } +} + +#[derive(Debug, Serialize)] +struct DiskInfo { + pub name: String, + pub filesystem: String, +} + +#[derive(Debug, Serialize)] +struct SystemInfo { + pub os: String, + + pub arch: String, + + pub version: String, + pub disks: Vec<DiskInfo>, +} + +impl SystemInfo { + pub fn new() -> Self { + let disks = Disks::new_with_refreshed_list(); + let disks = disks + .list() + .iter() + .map(|d| DiskInfo { + name: d.name().to_os_string().into_string().unwrap(), + filesystem: d.file_system().to_os_string().into_string().unwrap(), + }) + .collect(); + + Self { + os: System::name().unwrap_or_else(|| "unknown".to_string()), + arch: System::cpu_arch().unwrap_or_else(|| "unknown".to_string()), + version: System::os_version().unwrap_or_else(|| "unknown".to_string()), + disks, + } + } +} + +#[derive(Debug, Serialize)] +struct SyncInfo { + pub auth_state: String, + pub auto_sync: bool, + + pub last_sync: String, +} + +impl SyncInfo { + pub async fn new(settings: &Settings) -> Self { + // Build auth state description from raw token state without calling + // resolve_sync_auth(), which has side effects (token migration cleanup) + // that a diagnostic command should not trigger. + let meta = Settings::meta_store().await.ok(); + let has_cli_token = match &meta { + Some(m) => m.session_token().await.ok().flatten().is_some(), + None => false, + }; + + let auth_state = if has_cli_token { + "Self-hosted (authenticated)".into() + } else { + "Not authenticated".into() + }; + + Self { + auth_state, + auto_sync: settings.auto_sync, + last_sync: Settings::last_sync() + .await + .map_or_else(|_| "no last sync".to_string(), |v| v.to_string()), + } + } +} + +#[derive(Debug)] +struct SettingPaths { + db: String, + record_store: String, + key: String, +} + +impl SettingPaths { + pub fn new(settings: &Settings) -> Self { + Self { + db: settings.db_path.clone(), + record_store: settings.record_store_path.clone(), + key: settings.key_path.clone(), + } + } + + pub fn verify(&self) { + let paths = vec![ + ("ATUIN_DB_PATH", &self.db), + ("ATUIN_RECORD_STORE", &self.record_store), + ("ATUIN_KEY", &self.key), + ]; + + for (path_env_var, path) in paths { + if utils::broken_symlink(path) { + eprintln!( + "{path} (${path_env_var}) is a broken symlink. This may cause issues with Atuin." + ); + } + } + } +} + +#[derive(Debug, Serialize)] +struct AtuinInfo { + pub version: String, + pub commit: String, + + /// Whether the main Atuin sync server is in use + /// I'm just calling it Atuin Cloud for lack of a better name atm + pub sync: Option<SyncInfo>, + + pub sqlite_version: String, + + #[serde(skip)] // probably unnecessary to expose this + pub setting_paths: SettingPaths, +} + +impl AtuinInfo { + pub async fn new(settings: &Settings) -> Self { + let logged_in = settings.logged_in().await.unwrap_or(false); + + let sync = if logged_in { + Some(SyncInfo::new(settings).await) + } else { + None + }; + + let sqlite_version = match Sqlite::new("sqlite::memory:", 0.1).await { + Ok(db) => db + .sqlite_version() + .await + .unwrap_or_else(|_| "unknown".to_string()), + Err(_) => "error".to_string(), + }; + + Self { + version: crate::VERSION.to_string(), + commit: crate::SHA.to_string(), + sync, + sqlite_version, + setting_paths: SettingPaths::new(settings), + } + } +} + +#[derive(Debug, Serialize)] +struct DoctorDump { + pub atuin: AtuinInfo, + pub shell: ShellInfo, + pub system: SystemInfo, +} + +impl DoctorDump { + pub async fn new(settings: &Settings) -> Self { + Self { + atuin: AtuinInfo::new(settings).await, + shell: ShellInfo::new(), + system: SystemInfo::new(), + } + } +} + +fn checks(info: &DoctorDump) { + println!(); // spacing + // + let zfs_error = "[Filesystem] ZFS is known to have some issues with SQLite. Atuin uses SQLite heavily. If you are having poor performance, there are some workarounds here: https://github.com/atuinsh/atuin/issues/952".bold().red(); + let bash_plugin_error = "[Shell] If you are using Bash, Atuin requires that either bash-preexec or ble.sh (>= 0.4) be installed. An older ble.sh may not be detected. so ignore this if you have ble.sh >= 0.4 set up! Read more here: https://docs.atuin.sh/guide/installation/#bash".bold().red(); + let blesh_integration_error = "[Shell] Atuin and ble.sh seem to be loaded in the session, but the integration does not seem to be working. Please check the setup in .bashrc.".bold().red(); + + // ZFS: https://github.com/atuinsh/atuin/issues/952 + if info.system.disks.iter().any(|d| d.filesystem == "zfs") { + println!("{zfs_error}"); + } + + info.atuin.setting_paths.verify(); + + // Shell + if info.shell.name == "bash" { + if !info + .shell + .plugins + .iter() + .any(|p| p == "blesh" || p == "bash-preexec") + { + println!("{bash_plugin_error}"); + } + + if info.shell.plugins.iter().any(|plugin| plugin == "atuin") + && info.shell.plugins.iter().any(|plugin| plugin == "blesh") + && info.shell.preexec.as_ref().is_some_and(|val| val == "none") + { + println!("{blesh_integration_error}"); + } + } +} + +pub async fn run(settings: &Settings) -> Result<()> { + println!("{}", "Atuin Doctor".bold()); + println!("Checking for diagnostics"); + let dump = DoctorDump::new(settings).await; + + checks(&dump); + + let dump = serde_json::to_string_pretty(&dump)?; + + println!("\nPlease include the output below with any bug reports or issues\n"); + println!("{dump}"); + + Ok(()) +} diff --git a/crates/turtle/src/command/client/history.rs b/crates/turtle/src/command/client/history.rs new file mode 100644 index 00000000..0c61392c --- /dev/null +++ b/crates/turtle/src/command/client/history.rs @@ -0,0 +1,1340 @@ +use std::{ + fmt::{self, Display}, + io::{self, IsTerminal, Write}, + path::PathBuf, + time::Duration, +}; + +use crate::atuin_common::utils::{self, Escapable as _}; +use clap::Subcommand; +use eyre::{Context, Result, bail}; +use runtime_format::{FormatKey, FormatKeyError, ParseSegment, ParsedFmt}; + +#[cfg(feature = "daemon")] +use super::daemon as daemon_cmd; +#[cfg(feature = "daemon")] +use colored::Colorize; +#[cfg(feature = "daemon")] +use serde::Serialize; + +#[cfg(feature = "daemon")] +use crate::atuin_daemon::history::{HistoryEventKind, TailHistoryReply}; + +use crate::atuin_client::{ + database::{Database, Sqlite, current_context}, + encryption, + history::{History, store::HistoryStore}, + record::sqlite_store::SqliteStore, + settings::{ + FilterMode::{Directory, Global, Session}, + Settings, Timezone, + }, +}; + +#[cfg(feature = "sync")] +use crate::atuin_client::record; + +use log::{debug, warn}; +use time::{OffsetDateTime, macros::format_description}; + +#[cfg(feature = "daemon")] +use super::daemon; +use super::search::format_duration_into; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Begins a new command in the history + Start { + /// Collects the command from the `ATUIN_COMMAND_LINE` environment variable, + /// which does not need escaping and is more compatible between OS and shells + #[arg(long = "command-from-env", hide = true)] + cmd_env: bool, + + /// Author of this command, eg `ellie`, `claude`, or `copilot` + #[arg(long)] + author: Option<String>, + + /// Optional intent/rationale for running this command + #[arg(long)] + intent: Option<String>, + + command: Vec<String>, + }, + + /// Finishes a new command in the history (adds time, exit code) + End { + id: String, + #[arg(long, short)] + exit: i64, + #[arg(long, short)] + duration: Option<u64>, + }, + + /// Stream history events from the daemon as they are received + Tail, + + /// List all items in history + List { + #[arg(long, short)] + cwd: bool, + + #[arg(long, short)] + session: bool, + + #[arg(long)] + human: bool, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Terminate the output with a null, for better multiline support + #[arg(long)] + print0: bool, + + #[arg(long, short, default_value = "true")] + // accept no value + #[arg(num_args(0..=1), default_missing_value("true"))] + // accept a value + #[arg(action = clap::ArgAction::Set)] + reverse: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + timezone: Option<Timezone>, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {author}, {intent}, {exit}, {time}, {session}, and {uuid} + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option<String>, + }, + + /// Get the last command ran + Last { + #[arg(long)] + human: bool, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + timezone: Option<Timezone>, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {author}, {intent}, {time}, {session}, {uuid} and {relativetime}. + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option<String>, + }, + + InitStore, + + /// Delete history entries matching the configured exclusion filters + Prune { + /// List matching history lines without performing the actual deletion. + #[arg(short = 'n', long)] + dry_run: bool, + }, + + /// Delete duplicate history entries (that have the same command, cwd and hostname) + Dedup { + /// List matching history lines without performing the actual deletion. + #[arg(short = 'n', long)] + dry_run: bool, + + /// Only delete results added before this date + #[arg(long, short)] + before: String, + + /// How many recent duplicates to keep + #[arg(long)] + dupkeep: u32, + }, +} + +#[derive(Clone, Copy, Debug)] +pub enum ListMode { + Human, + CmdOnly, + Regular, +} + +impl ListMode { + pub const fn from_flags(human: bool, cmd_only: bool) -> Self { + if human { + ListMode::Human + } else if cmd_only { + ListMode::CmdOnly + } else { + ListMode::Regular + } + } +} + +#[expect(clippy::cast_sign_loss)] +pub fn print_list( + h: &[History], + list_mode: ListMode, + format: Option<&str>, + print0: bool, + reverse: bool, + tz: Timezone, +) { + let w = std::io::stdout(); + let mut w = w.lock(); + + let fmt_str = match list_mode { + ListMode::Human => format + .unwrap_or("{time} · {duration}\t{command}") + .replace("\\t", "\t"), + ListMode::Regular => format + .unwrap_or("{time}\t{command}\t{duration}") + .replace("\\t", "\t"), + // not used + ListMode::CmdOnly => String::new(), + }; + + let parsed_fmt = match list_mode { + ListMode::Human | ListMode::Regular => parse_fmt(&fmt_str), + ListMode::CmdOnly => std::iter::once(ParseSegment::Key("command")).collect(), + }; + + let iterator = if reverse { + Box::new(h.iter().rev()) as Box<dyn Iterator<Item = &History>> + } else { + Box::new(h.iter()) as Box<dyn Iterator<Item = &History>> + }; + + let entry_terminator = if print0 { "\0" } else { "\n" }; + let flush_each_line = print0; + + for history in iterator { + let fh = FmtHistory { + history, + cmd_format: CmdFormat::for_output(&w), + tz: &tz, + }; + let args = parsed_fmt.with_args(&fh); + + // Check for formatting errors before attempting to write + if let Err(err) = args.status() { + eprintln!("ERROR: history output failed with: {err}"); + std::process::exit(1); + } + + let write_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + write!(w, "{args}{entry_terminator}") + })); + + match write_result { + Ok(Ok(())) => { + // Write succeeded + } + Ok(Err(err)) => { + if err.kind() != io::ErrorKind::BrokenPipe { + eprintln!("ERROR: Failed to write history output: {err}"); + std::process::exit(1); + } + } + Err(_) => { + eprintln!("ERROR: Format string caused a formatting error."); + eprintln!( + "This may be due to an unsupported format string containing special characters." + ); + eprintln!( + "Please check your format string syntax and ensure literal braces are properly escaped." + ); + std::process::exit(1); + } + } + if flush_each_line { + check_for_write_errors(w.flush()); + } + } + + if !flush_each_line { + check_for_write_errors(w.flush()); + } +} + +fn check_for_write_errors(write: Result<(), io::Error>) { + if let Err(err) = write { + // Ignore broken pipe (issue #626) + if err.kind() != io::ErrorKind::BrokenPipe { + eprintln!("ERROR: History output failed with the following error: {err}"); + std::process::exit(1); + } + } +} + +/// Type wrapper around `History` with formatting settings. +#[derive(Clone, Copy, Debug)] +struct FmtHistory<'a> { + history: &'a History, + cmd_format: CmdFormat, + tz: &'a Timezone, +} + +#[derive(Clone, Copy, Debug)] +enum CmdFormat { + Literal, + Escaped, +} +impl CmdFormat { + fn for_output<O: IsTerminal>(out: &O) -> Self { + if out.is_terminal() { + Self::Escaped + } else { + Self::Literal + } + } +} + +static TIME_FMT: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day] [hour repr:24]:[minute]:[second]"); + +/// defines how to format the history +impl FormatKey for FmtHistory<'_> { + #[expect(clippy::cast_sign_loss)] + fn fmt(&self, key: &str, f: &mut fmt::Formatter<'_>) -> Result<(), FormatKeyError> { + match key { + "command" => match self.cmd_format { + CmdFormat::Literal => f.write_str(self.history.command.trim()), + CmdFormat::Escaped => f.write_str(&self.history.command.trim().escape_control()), + }?, + "directory" => f.write_str(self.history.cwd.trim())?, + "exit" => f.write_str(&self.history.exit.to_string())?, + "duration" => { + let dur = Duration::from_nanos(std::cmp::max(self.history.duration, 0) as u64); + format_duration_into(dur, f)?; + } + "time" => { + self.history + .timestamp + .to_offset(self.tz.0) + .format(TIME_FMT) + .map_err(|_| fmt::Error)? + .fmt(f)?; + } + "relativetime" => { + let since = OffsetDateTime::now_utc() - self.history.timestamp; + let d = Duration::try_from(since).unwrap_or_default(); + format_duration_into(d, f)?; + } + "host" => f.write_str( + self.history + .hostname + .split_once(':') + .map_or(&self.history.hostname, |(host, _)| host), + )?, + "author" => f.write_str(&self.history.author)?, + "intent" => f.write_str(self.history.intent.as_deref().unwrap_or_default())?, + "user" => f.write_str( + self.history + .hostname + .split_once(':') + .map_or("", |(_, user)| user), + )?, + "session" => f.write_str(&self.history.session)?, + "uuid" => f.write_str(&self.history.id.0)?, + _ => return Err(FormatKeyError::UnknownKey), + } + Ok(()) + } +} + +fn parse_fmt(format: &str) -> ParsedFmt<'_> { + match ParsedFmt::new(format) { + Ok(fmt) => fmt, + Err(err) => { + eprintln!("ERROR: History formatting failed with the following error: {err}"); + + if format.contains('"') && (format.contains(":{") || format.contains(",{")) { + eprintln!("It looks like you're trying to create JSON output."); + eprintln!("For JSON, you need to escape literal braces by doubling them:"); + eprintln!("Example: '{{\"command\":\"{{command}}\",\"time\":\"{{time}}\"}}'"); + } else { + eprintln!( + "If your formatting string contains literal curly braces, you need to escape them by doubling:" + ); + eprintln!("Use {{{{ for literal {{ and }}}} for literal }}"); + } + std::process::exit(1) + } + } +} + +fn apply_start_metadata(history: &mut History, author: Option<&str>, intent: Option<&str>) { + if let Some(author) = author.map(str::trim).filter(|author| !author.is_empty()) { + author.clone_into(&mut history.author); + } + + if let Some(intent) = intent.map(str::trim).filter(|intent| !intent.is_empty()) { + history.intent = Some(intent.to_owned()); + } else if intent.is_some() { + history.intent = None; + } +} + +fn normalize_command_for_storage<'a>(command: &'a str, settings: &Settings) -> &'a str { + if !settings.strip_trailing_whitespace { + return command; + } + + let trimmed = command.trim_end_matches([' ', '\t']); + if trimmed.len() == command.len() { + return command; + } + + let trailing_backslashes = trimmed + .as_bytes() + .iter() + .rev() + .take_while(|&&byte| byte == b'\\') + .count(); + + if trailing_backslashes % 2 == 1 { + command + } else { + trimmed + } +} + +async fn handle_start( + db: &impl Database, + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + let command = normalize_command_for_storage(command, settings); + + let mut h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + apply_start_metadata(&mut h, author, intent); + + if !h.should_save(settings) { + return Ok(None); + } + + let id = h.id.0.clone(); + + // Silently ignore database errors to avoid breaking the shell + // This is important when disk is full or database is locked + if let Err(e) = db.save(&h).await { + debug!("failed to save history: {e}"); + } + + Ok(Some(id)) +} + +#[cfg(feature = "daemon")] +async fn handle_daemon_start( + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + let command = normalize_command_for_storage(command, settings); + + let mut h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + apply_start_metadata(&mut h, author, intent); + + if !h.should_save(settings) { + return Ok(None); + } + + // Attempt to start history via daemon, but silently ignore errors + // to avoid breaking the shell when the daemon is unavailable or disk is full + let resp = match daemon::start_history(settings, h.clone()).await { + Ok(id) => id, + Err(e) => { + debug!("failed to start history via daemon: {e}"); + h.id.0.clone() + } + }; + + Ok(Some(resp)) +} + +#[expect(unused_variables)] +async fn handle_end( + db: &impl Database, + store: SqliteStore, + history_store: HistoryStore, + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + if id.trim() == "" { + return Ok(()); + } + + let Some(mut h) = db.load(id).await? else { + warn!("history entry is missing"); + return Ok(()); + }; + + if h.duration > 0 { + debug!("cannot end history - already has duration"); + + // returning OK as this can occur if someone Ctrl-c a prompt + return Ok(()); + } + + if !settings.store_failed && exit > 0 { + debug!("history has non-zero exit code, and store_failed is false"); + + // the history has already been inserted half complete. remove it + db.delete(h).await?; + + return Ok(()); + } + + h.exit = exit; + h.duration = match duration { + Some(value) => i64::try_from(value).context("command took over 292 years")?, + None => i64::try_from((OffsetDateTime::now_utc() - h.timestamp).whole_nanoseconds()) + .context("command took over 292 years")?, + }; + + db.update(&h).await?; + history_store.push(h).await?; + + if settings.should_sync().await? { + let (_, downloaded) = + record::sync::sync(settings, &store, &history_store.encryption_key).await?; + Settings::save_sync_time().await?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + } else { + debug!("sync disabled! not syncing"); + } + + Ok(()) +} + +#[cfg(feature = "daemon")] +async fn handle_daemon_end( + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + daemon::end_history(settings, id.to_string(), duration.unwrap_or(0), exit).await?; + + Ok(()) +} + +pub(super) async fn start_history_entry( + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + #[cfg(feature = "daemon")] + if settings.daemon.enabled { + return handle_daemon_start(settings, command, author, intent).await; + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let db = Sqlite::new(db_path, settings.local_timeout).await?; + handle_start(&db, settings, command, author, intent).await +} + +pub(super) async fn end_history_entry( + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + #[cfg(feature = "daemon")] + if settings.daemon.enabled { + return handle_daemon_end(settings, id, exit, duration).await; + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + handle_end(&db, store, history_store, settings, id, exit, duration).await +} + +#[cfg(feature = "daemon")] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum TailKind { + Started, + Ended, +} + +#[cfg(feature = "daemon")] +#[derive(Clone, Debug, Eq, PartialEq)] +struct TailEvent { + kind: TailKind, + history: History, +} + +#[cfg(feature = "daemon")] +#[derive(Serialize)] +struct TailJsonEvent<'a> { + event: &'static str, + history: TailJsonHistory<'a>, +} + +#[cfg(feature = "daemon")] +#[derive(Serialize)] +struct TailJsonHistory<'a> { + id: &'a str, + timestamp: String, + timestamp_unix_ns: u64, + command: &'a str, + cwd: &'a str, + session: &'a str, + hostname: &'a str, + host: &'a str, + user: &'a str, + author: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + intent: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + exit: Option<i64>, + #[serde(skip_serializing_if = "Option::is_none")] + duration_ns: Option<i64>, + #[serde(skip_serializing_if = "Option::is_none")] + duration: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + success: Option<bool>, + #[serde(skip_serializing_if = "Option::is_none")] + finished_at: Option<String>, +} + +#[cfg(feature = "daemon")] +impl TailEvent { + fn from_proto(reply: TailHistoryReply) -> Result<Self> { + let history = reply + .history + .ok_or_else(|| eyre::eyre!("daemon sent a history tail event without history"))?; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(i128::from(history.timestamp)) + .context("invalid daemon history timestamp")?; + let kind = match HistoryEventKind::try_from(reply.kind) + .unwrap_or(HistoryEventKind::Unspecified) + { + HistoryEventKind::Started => TailKind::Started, + HistoryEventKind::Ended => TailKind::Ended, + HistoryEventKind::Unspecified => bail!("daemon sent an unspecified history tail event"), + }; + + Ok(Self { + kind, + history: History { + id: history.id.into(), + timestamp, + duration: history.duration, + exit: history.exit, + command: history.command, + cwd: history.cwd, + session: history.session, + hostname: history.hostname, + author: history.author, + intent: normalize_optional_field(&history.intent), + deleted_at: None, + }, + }) + } + + fn render(&self, tty: bool, tz: Timezone) -> Result<String> { + if tty { + Ok(self.render_pretty(tz)) + } else { + let mut json = self.render_json(tz)?; + json.push('\n'); + Ok(json) + } + } + + fn render_json(&self, tz: Timezone) -> Result<String> { + let payload = TailJsonEvent { + event: self.kind.as_str(), + history: TailJsonHistory { + id: &self.history.id.0, + timestamp: format_history_time(self.history.timestamp, tz)?, + timestamp_unix_ns: u64::try_from(self.history.timestamp.unix_timestamp_nanos()) + .context("history timestamp predates unix epoch")?, + command: &self.history.command, + cwd: &self.history.cwd, + session: &self.history.session, + hostname: &self.history.hostname, + host: self.host(), + user: self.user(), + author: &self.history.author, + intent: self.history.intent.as_deref(), + exit: self.exit_value(), + duration_ns: self.duration_value(), + duration: self.duration_value().map(format_duration_ns), + success: self.success_value(), + finished_at: self + .finished_at() + .map(|time| format_history_time(time, tz)) + .transpose()?, + }, + }; + + Ok(serde_json::to_string(&payload)?) + } + + fn render_pretty(&self, tz: Timezone) -> String { + let mut out = String::new(); + let border = match self.kind { + TailKind::Started => "-".repeat(72).bright_blue().to_string(), + TailKind::Ended if self.history.exit == 0 => "-".repeat(72).bright_green().to_string(), + TailKind::Ended => "-".repeat(72).bright_red().to_string(), + }; + + out.push_str(&border); + out.push('\n'); + + let command = self.history.command.trim(); + let escaped_command = command.escape_control(); + let mut command_lines = escaped_command.lines(); + let header = format!( + "{} {}", + self.kind.badge(self.history.exit), + command_lines.next().unwrap_or_default().bold() + ); + out.push_str(&header); + out.push('\n'); + + for line in command_lines { + out.push_str(" "); + out.push_str(line); + out.push('\n'); + } + + push_pretty_field( + &mut out, + "start", + &format_history_time(self.history.timestamp, tz) + .unwrap_or_else(|_| "invalid".to_owned()), + ); + push_pretty_field(&mut out, "history", &self.history.id.0); + push_pretty_field(&mut out, "session", &self.history.session); + push_pretty_field(&mut out, "exit", &self.exit_display()); + push_pretty_field(&mut out, "duration", &self.duration_display()); + + out.push('\n'); + + push_pretty_field(&mut out, "cwd", &self.history.cwd); + push_pretty_field(&mut out, "hostname", &self.history.hostname); + push_pretty_field(&mut out, "host", self.host()); + push_pretty_field(&mut out, "user", self.user()); + push_pretty_field(&mut out, "author", &self.history.author); + + if let Some(intent) = self.history.intent.as_deref() { + push_pretty_field(&mut out, "intent", intent); + } + + if let Some(finished) = self.finished_at() { + let finished = + format_history_time(finished, tz).unwrap_or_else(|_| "invalid".to_owned()); + push_pretty_field(&mut out, "finished", &finished); + } + + out.push_str(&border); + out.push_str("\n\n"); + out + } + + fn host(&self) -> &str { + self.history + .hostname + .split_once(':') + .map_or(self.history.hostname.as_str(), |(host, _)| host) + } + + fn user(&self) -> &str { + self.history + .hostname + .split_once(':') + .map_or("", |(_, user)| user) + } + + fn exit_value(&self) -> Option<i64> { + matches!(self.kind, TailKind::Ended).then_some(self.history.exit) + } + + fn duration_value(&self) -> Option<i64> { + matches!(self.kind, TailKind::Ended).then_some(self.history.duration) + } + + fn success_value(&self) -> Option<bool> { + matches!(self.kind, TailKind::Ended).then_some(self.history.exit == 0) + } + + fn finished_at(&self) -> Option<OffsetDateTime> { + self.duration_value() + .filter(|duration| *duration >= 0) + .map(time::Duration::nanoseconds) + .and_then(|duration| self.history.timestamp.checked_add(duration)) + } + + fn exit_display(&self) -> String { + match self.exit_value() { + Some(0) => "0 (success)".bright_green().to_string(), + Some(code) => format!("{code} (failure)").bright_red().to_string(), + None => "pending".bright_yellow().to_string(), + } + } + + fn duration_display(&self) -> String { + match self.duration_value() { + Some(duration) if duration >= 0 => format_duration_ns(duration), + Some(_) => "unknown".bright_yellow().to_string(), + None => "running".bright_yellow().to_string(), + } + } +} + +#[cfg(feature = "daemon")] +impl TailKind { + const fn as_str(self) -> &'static str { + match self { + Self::Started => "started", + Self::Ended => "ended", + } + } + + fn badge(self, exit: i64) -> colored::ColoredString { + match self { + Self::Started => "STARTED".bold().bright_blue(), + Self::Ended if exit == 0 => "ENDED".bold().bright_green(), + Self::Ended => "ENDED".bold().bright_red(), + } + } +} + +#[cfg(feature = "daemon")] +fn format_history_time(timestamp: OffsetDateTime, tz: Timezone) -> Result<String> { + Ok(timestamp.to_offset(tz.0).format(TIME_FMT)?) +} + +#[cfg(feature = "daemon")] +fn format_duration_ns(duration_ns: i64) -> String { + struct F(Duration); + impl Display for F { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + format_duration_into(self.0, f) + } + } + + F(Duration::from_nanos(duration_ns.max(0).cast_unsigned())).to_string() +} + +#[cfg(feature = "daemon")] +fn push_pretty_field(out: &mut String, label: &str, value: &str) { + out.push_str(" "); + let label = format!("{label}:"); + out.push_str(&label.bright_cyan().bold().to_string()); + if label.len() < 10 { + out.push_str(&" ".repeat(10 - label.len())); + } + + let mut lines = value.lines(); + if let Some(first) = lines.next() { + out.push_str(first); + } + out.push('\n'); + + for line in lines { + out.push_str(" "); + out.push_str(line); + out.push('\n'); + } +} + +#[cfg(feature = "daemon")] +fn normalize_optional_field(value: &str) -> Option<String> { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_owned()) + } +} + +impl Cmd { + #[cfg(feature = "daemon")] + async fn handle_tail(settings: &Settings) -> Result<()> { + let tty = std::io::stdout().is_terminal(); + let mut client = daemon::tail_client(settings).await?; + let mut stream = client.tail_history().await?; + let stdout = std::io::stdout(); + + while let Some(reply) = stream.message().await? { + let event = TailEvent::from_proto(reply)?; + let rendered = event.render(tty, settings.timezone)?; + let mut out = stdout.lock(); + + match out.write_all(rendered.as_bytes()) { + Ok(()) => out.flush()?, + Err(err) if err.kind() == io::ErrorKind::BrokenPipe => break, + Err(err) => return Err(err.into()), + } + } + + Ok(()) + } + + #[expect(clippy::too_many_lines, clippy::cast_possible_truncation)] + #[expect(clippy::too_many_arguments)] + #[expect(clippy::fn_params_excessive_bools)] + async fn handle_list( + db: &impl Database, + settings: &Settings, + context: crate::atuin_client::database::Context, + session: bool, + cwd: bool, + mode: ListMode, + format: Option<String>, + include_deleted: bool, + print0: bool, + reverse: bool, + tz: Timezone, + ) -> Result<()> { + let filters = match (session, cwd) { + (true, true) => [Session, Directory], + (true, false) => [Session, Global], + (false, true) => [Global, Directory], + (false, false) => [ + settings.default_filter_mode(context.git_root.is_some()), + Global, + ], + }; + + let history = db + .list(&filters, &context, None, false, include_deleted) + .await?; + + print_list( + &history, + mode, + match format { + None => Some(settings.history_format.as_str()), + _ => format.as_deref(), + }, + print0, + reverse, + tz, + ); + + Ok(()) + } + + async fn handle_prune( + db: &impl Database, + settings: &Settings, + store: SqliteStore, + context: crate::atuin_client::database::Context, + dry_run: bool, + ) -> Result<()> { + // Grab all executed commands and filter them using History::should_save. + // We could iterate or paginate here if memory usage becomes an issue. + let matches: Vec<History> = db + .list(&[Global], &context, None, false, false) + .await? + .into_iter() + .filter(|h| !h.should_save(settings)) + .collect(); + + match matches.len() { + 0 => { + println!("No entries to prune."); + return Ok(()); + } + 1 => println!("Found 1 entry to prune."), + n => println!("Found {n} entries to prune."), + } + + if dry_run { + print_list( + &matches, + ListMode::Human, + Some(settings.history_format.as_str()), + false, + false, + settings.timezone, + ); + } else { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + for entry in matches { + eprintln!("deleting {}", entry.id); + let (id, _) = history_store.delete(entry.id.clone()).await?; + history_store.incremental_build(db, &[id]).await?; + } + + #[cfg(feature = "daemon")] + daemon_cmd::emit_event(settings, crate::atuin_daemon::DaemonEvent::HistoryPruned).await; + } + Ok(()) + } + + async fn handle_dedup( + db: &impl Database, + settings: &Settings, + store: SqliteStore, + before: i64, + dupkeep: u32, + dry_run: bool, + ) -> Result<()> { + if dupkeep == 0 { + eprintln!( + "\"--dupkeep 0\" would keep 0 copies of duplicate commands and thus delete all of them! Use \"atuin search --delete ...\" if you really want that." + ); + std::process::exit(1); + } + + let matches: Vec<History> = db.get_dups(before, dupkeep).await?; + + match matches.len() { + 0 => { + println!("No duplicates to delete."); + return Ok(()); + } + 1 => println!("Found 1 duplicate to delete."), + n => println!("Found {n} duplicates to delete."), + } + + if dry_run { + print_list( + &matches, + ListMode::Human, + Some(settings.history_format.as_str()), + false, + false, + settings.timezone, + ); + } else { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + #[cfg(feature = "daemon")] + let ids = matches.iter().map(|h| h.id.clone()).collect::<Vec<_>>(); + + for entry in matches { + eprintln!("deleting {}", entry.id); + let (id, _) = history_store.delete(entry.id).await?; + history_store.incremental_build(db, &[id]).await?; + } + + #[cfg(feature = "daemon")] + daemon_cmd::emit_event( + settings, + crate::atuin_daemon::DaemonEvent::HistoryDeleted { ids }, + ) + .await; + } + Ok(()) + } + + #[expect(clippy::too_many_lines)] + pub async fn run(self, settings: &Settings) -> Result<()> { + match self { + Self::Start { + cmd_env, + author, + intent, + command, + } => { + let command = if cmd_env { + std::env::var("ATUIN_COMMAND_LINE").unwrap_or_default() + } else { + command.join(" ") + }; + + if let Some(id) = + start_history_entry(settings, &command, author.as_deref(), intent.as_deref()) + .await? + { + println!("{id}"); + } + + Ok(()) + } + Self::End { id, exit, duration } => { + end_history_entry(settings, &id, exit, duration).await + } + Self::Tail => { + #[cfg(feature = "daemon")] + { + return Self::handle_tail(settings).await; + } + + #[cfg(not(feature = "daemon"))] + bail!("`atuin history tail` requires Atuin to be built with the `daemon` feature"); + } + cmd => { + let context = current_context().await?; + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + match cmd { + Self::List { + session, + cwd, + human, + cmd_only, + print0, + reverse, + timezone, + format, + } => { + let mode = ListMode::from_flags(human, cmd_only); + let tz = timezone.unwrap_or(settings.timezone); + Self::handle_list( + &db, settings, context, session, cwd, mode, format, false, print0, + reverse, tz, + ) + .await + } + + Self::Last { + human, + cmd_only, + timezone, + format, + } => { + let last = db.last().await?; + let last = last.as_slice(); + let tz = timezone.unwrap_or(settings.timezone); + print_list( + last, + ListMode::from_flags(human, cmd_only), + match format { + None => Some(settings.history_format.as_str()), + _ => format.as_deref(), + }, + false, + true, + tz, + ); + + Ok(()) + } + + Self::InitStore => history_store.init_store(&db).await, + + Self::Prune { dry_run } => { + Self::handle_prune(&db, settings, store, context, dry_run).await + } + + Self::Dedup { + dry_run, + before, + dupkeep, + } => { + let before = i64::try_from( + interim::parse_date_string( + before.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + )? + .unix_timestamp_nanos(), + )?; + Self::handle_dedup(&db, settings, store, before, dupkeep, dry_run).await + } + + Self::Start { .. } | Self::End { .. } | Self::Tail => unreachable!(), + } + } + } + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "daemon")] + use time::macros::datetime; + + use super::*; + + #[test] + fn normalize_command_strips_trailing_spaces_and_tabs() { + let settings = Settings::utc(); + + assert!(settings.strip_trailing_whitespace); + assert_eq!(normalize_command_for_storage("ls \t", &settings), "ls"); + } + + #[test] + fn normalize_command_preserves_escaped_trailing_space() { + let settings = Settings::utc(); + + assert_eq!( + normalize_command_for_storage("printf foo\\ ", &settings), + "printf foo\\ " + ); + assert_eq!( + normalize_command_for_storage("printf foo\\\\ ", &settings), + "printf foo\\\\" + ); + } + + #[tokio::test] + async fn handle_start_saves_trimmed_command() { + let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); + let settings = Settings::utc(); + + handle_start(&db, &settings, "ls \t", None, None) + .await + .unwrap(); + + let history = db + .before(OffsetDateTime::now_utc() + time::Duration::SECOND, 1) + .await + .unwrap() + .pop() + .unwrap(); + assert_eq!(history.command, "ls"); + } + + #[tokio::test] + async fn handle_start_can_keep_trailing_whitespace() { + let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); + let settings = Settings { + strip_trailing_whitespace: false, + ..Settings::utc() + }; + + handle_start(&db, &settings, "ls \t", None, None) + .await + .unwrap(); + + let history = db + .before(OffsetDateTime::now_utc() + time::Duration::SECOND, 1) + .await + .unwrap() + .pop() + .unwrap(); + assert_eq!(history.command, "ls \t"); + } + + #[test] + fn test_format_string_no_panic() { + // Don't panic but provide helpful output (issue #2776) + let malformed_json = r#"{"command":"{command}","key":"value"}"#; + + let result = std::panic::catch_unwind(|| parse_fmt(malformed_json)); + + assert!(result.is_ok()); + } + + #[test] + fn test_valid_formats_still_work() { + assert!(std::panic::catch_unwind(|| parse_fmt("{command}")).is_ok()); + assert!(std::panic::catch_unwind(|| parse_fmt("{time} - {command}")).is_ok()); + } + + #[cfg(feature = "daemon")] + fn sample_tail_event(kind: TailKind) -> TailEvent { + TailEvent { + kind, + history: History { + id: "history-id".to_owned().into(), + timestamp: datetime!(2026-04-09 17:18:19 UTC), + duration: 12_345_678, + exit: 0, + command: "git status".to_owned(), + cwd: "/tmp/repo".to_owned(), + session: "session-id".to_owned(), + hostname: "host:ellie".to_owned(), + author: "claude".to_owned(), + intent: Some("inspect repository state".to_owned()), + deleted_at: None, + }, + } + } + + #[cfg(feature = "daemon")] + #[test] + fn test_tail_json_output_contains_history_fields() { + let json = sample_tail_event(TailKind::Ended) + .render(false, Timezone(time::UtcOffset::UTC)) + .unwrap(); + let value: serde_json::Value = serde_json::from_str(&json).unwrap(); + + assert_eq!(value["event"], "ended"); + assert_eq!(value["history"]["id"], "history-id"); + assert_eq!(value["history"]["duration_ns"], 12_345_678); + assert_eq!(value["history"]["success"], true); + assert!(value.get("record").is_none()); + } + + #[cfg(feature = "daemon")] + #[test] + fn test_tail_pretty_output_shows_pending_fields_for_started_events() { + let rendered = sample_tail_event(TailKind::Started) + .render(true, Timezone(time::UtcOffset::UTC)) + .unwrap(); + let plain = regex::Regex::new(r"\x1b\[[0-9;]*m") + .unwrap() + .replace_all(&rendered, ""); + + assert!(plain.contains("STARTED git status")); + assert!(plain.contains("exit:")); + assert!(plain.contains("pending")); + assert!(plain.contains("duration:")); + assert!(plain.contains("running")); + } +} diff --git a/crates/turtle/src/command/client/import.rs b/crates/turtle/src/command/client/import.rs new file mode 100644 index 00000000..363e6405 --- /dev/null +++ b/crates/turtle/src/command/client/import.rs @@ -0,0 +1,186 @@ +use std::env; + +use async_trait::async_trait; +use clap::Parser; +use eyre::Result; +use indicatif::ProgressBar; + +use crate::atuin_client::{ + database::Database, + history::History, + import::{ + Importer, Loader, bash::Bash, fish::Fish, nu::Nu, nu_histdb::NuHistDb, + powershell::PowerShell, replxx::Replxx, resh::Resh, xonsh::Xonsh, + xonsh_sqlite::XonshSqlite, zsh::Zsh, zsh_histdb::ZshHistDb, + }, +}; + +#[derive(Parser, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Import history for the current shell + Auto, + + /// Import history from the zsh history file + Zsh, + /// Import history from the zsh history file + ZshHistDb, + /// Import history from the bash history file + Bash, + /// Import history from the replxx history file + Replxx, + /// Import history from the resh history file + Resh, + /// Import history from the fish history file + Fish, + /// Import history from the nu history file + Nu, + /// Import history from the nu history file + NuHistDb, + /// Import history from xonsh json files + Xonsh, + /// Import history from xonsh sqlite db + XonshSqlite, + /// Import history from the powershell history file + Powershell, +} + +const BATCH_SIZE: usize = 100; + +impl Cmd { + #[expect(clippy::cognitive_complexity)] + pub async fn run<DB: Database>(&self, db: &DB) -> Result<()> { + println!(" Atuin "); + println!("======================"); + println!(" \u{1f30d} "); + println!(" \u{1f418}\u{1f418}\u{1f418}\u{1f418} "); + println!(" \u{1f422} "); + println!("======================"); + println!("Importing history..."); + + match self { + Self::Auto => { + if cfg!(windows) { + return if env::var("PSModulePath").is_ok() { + println!("Detected PowerShell"); + import::<PowerShell, DB>(db).await + } else { + println!("Could not detect the current shell."); + println!("Please run atuin import <SHELL>."); + println!("To view a list of shells, run atuin import."); + Ok(()) + }; + } + + // $XONSH_HISTORY_BACKEND isn't always set, but $XONSH_HISTORY_FILE is + let xonsh_histfile = + env::var("XONSH_HISTORY_FILE").unwrap_or_else(|_| String::new()); + let shell = env::var("SHELL").unwrap_or_else(|_| String::from("NO_SHELL")); + + if xonsh_histfile.to_lowercase().ends_with(".json") { + println!("Detected Xonsh"); + import::<Xonsh, DB>(db).await + } else if xonsh_histfile.to_lowercase().ends_with(".sqlite") { + println!("Detected Xonsh (SQLite backend)"); + import::<XonshSqlite, DB>(db).await + } else if shell.ends_with("/zsh") { + if ZshHistDb::histpath().is_ok() { + println!( + "Detected Zsh-HistDb, using :{}", + ZshHistDb::histpath().unwrap().to_str().unwrap() + ); + import::<ZshHistDb, DB>(db).await + } else { + println!("Detected ZSH"); + import::<Zsh, DB>(db).await + } + } else if shell.ends_with("/fish") { + println!("Detected Fish"); + import::<Fish, DB>(db).await + } else if shell.ends_with("/bash") { + println!("Detected Bash"); + import::<Bash, DB>(db).await + } else if shell.ends_with("/nu") { + if NuHistDb::histpath().is_ok() { + println!( + "Detected Nu-HistDb, using :{}", + NuHistDb::histpath().unwrap().to_str().unwrap() + ); + import::<NuHistDb, DB>(db).await + } else { + println!("Detected Nushell"); + import::<Nu, DB>(db).await + } + } else if shell.ends_with("/pwsh") { + println!("Detected PowerShell"); + import::<PowerShell, DB>(db).await + } else { + println!("cannot import {shell} history"); + Ok(()) + } + } + + Self::Zsh => import::<Zsh, DB>(db).await, + Self::ZshHistDb => import::<ZshHistDb, DB>(db).await, + Self::Bash => import::<Bash, DB>(db).await, + Self::Replxx => import::<Replxx, DB>(db).await, + Self::Resh => import::<Resh, DB>(db).await, + Self::Fish => import::<Fish, DB>(db).await, + Self::Nu => import::<Nu, DB>(db).await, + Self::NuHistDb => import::<NuHistDb, DB>(db).await, + Self::Xonsh => import::<Xonsh, DB>(db).await, + Self::XonshSqlite => import::<XonshSqlite, DB>(db).await, + Self::Powershell => import::<PowerShell, DB>(db).await, + } + } +} + +pub struct HistoryImporter<'db, DB: Database> { + pb: ProgressBar, + buf: Vec<History>, + db: &'db DB, +} + +impl<'db, DB: Database> HistoryImporter<'db, DB> { + fn new(db: &'db DB, len: usize) -> Self { + Self { + pb: ProgressBar::new(len as u64), + buf: Vec::with_capacity(BATCH_SIZE), + db, + } + } + + async fn flush(self) -> Result<()> { + if !self.buf.is_empty() { + self.db.save_bulk(&self.buf).await?; + } + self.pb.finish(); + Ok(()) + } +} + +#[async_trait] +impl<DB: Database> Loader for HistoryImporter<'_, DB> { + async fn push(&mut self, hist: History) -> Result<()> { + self.pb.inc(1); + self.buf.push(hist); + if self.buf.len() == self.buf.capacity() { + self.db.save_bulk(&self.buf).await?; + self.buf.clear(); + } + Ok(()) + } +} + +async fn import<I: Importer + Send, DB: Database>(db: &DB) -> Result<()> { + println!("Importing history from {}", I::NAME); + + let mut importer = I::new().await?; + let len = importer.entries().await.unwrap(); + let mut loader = HistoryImporter::new(db, len); + importer.load(&mut loader).await?; + loader.flush().await?; + + println!("Import complete!"); + Ok(()) +} diff --git a/crates/turtle/src/command/client/info.rs b/crates/turtle/src/command/client/info.rs new file mode 100644 index 00000000..ee24c419 --- /dev/null +++ b/crates/turtle/src/command/client/info.rs @@ -0,0 +1,31 @@ +use crate::atuin_client::settings::Settings;
+
+use crate::{SHA, VERSION};
+
+pub fn run(settings: &Settings) {
+ let config = crate::atuin_common::utils::config_dir();
+ let mut config_file = config.clone();
+ config_file.push("config.toml");
+ let mut sever_config = config;
+ sever_config.push("server.toml");
+
+ let config_paths = format!(
+ "Config files:\nclient config: {:?}\nserver config: {:?}\nclient db path: {:?}\nkey path: {:?}\nmeta db path: {:?}",
+ config_file.to_string_lossy(),
+ sever_config.to_string_lossy(),
+ settings.db_path,
+ settings.key_path,
+ settings.meta.db_path
+ );
+
+ let env_vars = format!(
+ "Env Vars:\nATUIN_CONFIG_DIR = {:?}",
+ std::env::var("ATUIN_CONFIG_DIR").unwrap_or_else(|_| "None".into())
+ );
+
+ let general_info = format!("Version info:\nversion: {VERSION}\ncommit: {SHA}");
+
+ let print_out = format!("{config_paths}\n\n{env_vars}\n\n{general_info}");
+
+ println!("{print_out}");
+}
diff --git a/crates/turtle/src/command/client/init.rs b/crates/turtle/src/command/client/init.rs new file mode 100644 index 00000000..bf9747bb --- /dev/null +++ b/crates/turtle/src/command/client/init.rs @@ -0,0 +1,127 @@ +use crate::atuin_client::settings::{Settings, Tmux}; +use clap::{Parser, ValueEnum}; + +mod bash; +mod fish; +mod powershell; +mod xonsh; +mod zsh; + +#[derive(Parser, Debug)] +pub struct Cmd { + shell: Shell, + + /// Disable the binding of CTRL-R to atuin + #[clap(long)] + disable_ctrl_r: bool, + + /// Disable the binding of the Up Arrow key to atuin + #[clap(long)] + disable_up_arrow: bool, + + /// Disable the binding of ? to Atuin AI + #[clap(long)] + disable_ai: bool, +} + +#[derive(Clone, Copy, ValueEnum, Debug)] +#[value(rename_all = "lower")] +#[expect(clippy::enum_variant_names, clippy::doc_markdown)] +pub enum Shell { + /// Zsh setup + Zsh, + /// Bash setup + Bash, + /// Fish setup + Fish, + /// Nu setup + Nu, + /// Xonsh setup + Xonsh, + /// PowerShell setup + PowerShell, +} + +impl Cmd { + fn init_nu(&self, _tmux: &Tmux) { + let full = include_str!("../../shell/atuin.nu"); + + // TODO: tmux popup for Nu + println!("{full}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + const BIND_CTRL_R: &str = r"$env.config = ( + $env.config | upsert keybindings ( + $env.config.keybindings + | append { + name: atuin + modifier: control + keycode: char_r + mode: [emacs, vi_normal, vi_insert] + event: { send: executehostcommand cmd: (_atuin_search_cmd) } + } + ) +)"; + const BIND_UP_ARROW: &str = r" +$env.config = ( + $env.config | upsert keybindings ( + $env.config.keybindings + | append { + name: atuin + modifier: none + keycode: up + mode: [emacs, vi_normal, vi_insert] + event: { + until: [ + {send: menuup} + {send: executehostcommand cmd: (_atuin_search_cmd '--shell-up-key-binding') } + ] + } + } + ) +) +"; + if !self.disable_ctrl_r { + println!("{BIND_CTRL_R}"); + } + if !self.disable_up_arrow { + println!("{BIND_UP_ARROW}"); + } + } + } + + fn static_init(&self, settings: &Settings) { + let tmux = &settings.tmux; + + match self.shell { + Shell::Zsh => { + zsh::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::Bash => { + bash::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::Fish => { + fish::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::Nu => { + self.init_nu(tmux); + } + Shell::Xonsh => { + xonsh::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::PowerShell => { + powershell::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + } + } + + pub fn run(self, settings: &Settings) { + if !settings.paths_ok() { + eprintln!( + "Atuin settings paths are broken. Disabling atuin shell hooks. Run `atuin doctor` to diagnose." + ); + } + + self.static_init(settings); + } +} diff --git a/crates/turtle/src/command/client/init/bash.rs b/crates/turtle/src/command/client/init/bash.rs new file mode 100644 index 00000000..fd17e37e --- /dev/null +++ b/crates/turtle/src/command/client/init/bash.rs @@ -0,0 +1,25 @@ +use crate::atuin_client::settings::Tmux; + +fn print_tmux_config(tmux: &Tmux) { + if tmux.enabled { + println!("export ATUIN_TMUX_POPUP_WIDTH='{}'", tmux.width); + println!("export ATUIN_TMUX_POPUP_HEIGHT='{}'", tmux.height); + } else { + println!("export ATUIN_TMUX_POPUP=false"); + } +} + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.bash"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } else { + (!disable_ctrl_r, !disable_up_arrow) + }; + + print_tmux_config(tmux); + println!("__atuin_bind_ctrl_r={bind_ctrl_r}"); + println!("__atuin_bind_up_arrow={bind_up_arrow}"); + println!("{base}"); +} diff --git a/crates/turtle/src/command/client/init/fish.rs b/crates/turtle/src/command/client/init/fish.rs new file mode 100644 index 00000000..8a046bfa --- /dev/null +++ b/crates/turtle/src/command/client/init/fish.rs @@ -0,0 +1,86 @@ +use crate::atuin_client::settings::Tmux; + +fn print_tmux_config(tmux: &Tmux) { + if tmux.enabled { + println!("set -gx ATUIN_TMUX_POPUP_WIDTH '{}'", tmux.width); + println!("set -gx ATUIN_TMUX_POPUP_HEIGHT '{}'", tmux.height); + } else { + println!("set -gx ATUIN_TMUX_POPUP false"); + } +} + +fn print_bindings( + indent: &str, + disable_up_arrow: bool, + disable_ctrl_r: bool, + bind_ctrl_r: &str, + bind_up_arrow: &str, + bind_ctrl_r_ins: &str, + bind_up_arrow_ins: &str, +) { + if !disable_ctrl_r { + println!("{indent}{bind_ctrl_r}"); + } + if !disable_up_arrow { + println!("{indent}{bind_up_arrow}"); + } + + println!("{indent}if bind -M insert >/dev/null 2>&1"); + if !disable_ctrl_r { + println!("{indent}{indent}{bind_ctrl_r_ins}"); + } + if !disable_up_arrow { + println!("{indent}{indent}{bind_up_arrow_ins}"); + } + println!("{indent}end"); +} + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { + let indent = " ".repeat(4); + + let base = include_str!("../../../shell/atuin.fish"); + + print_tmux_config(tmux); + println!("{base}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + println!("if string match -q '4.*' $version"); + + // In fish 4.0 and above the option bind -k doesn't exist anymore, + // instead we can use key names and modifiers directly. + print_bindings( + &indent, + disable_up_arrow, + disable_ctrl_r, + "bind ctrl-r _atuin_search", + "bind up _atuin_bind_up", + "bind -M insert ctrl-r _atuin_search", + "bind -M insert up _atuin_bind_up", + ); + + println!("else"); + + // We keep these for compatibility with fish 3.x + print_bindings( + &indent, + disable_up_arrow, + disable_ctrl_r, + r"bind \cr _atuin_search", + &[ + r"bind -k up _atuin_bind_up", + r"bind \eOA _atuin_bind_up", + r"bind \e\[A _atuin_bind_up", + ] + .join("; "), + r"bind -M insert \cr _atuin_search", + &[ + r"bind -M insert -k up _atuin_bind_up", + r"bind -M insert \eOA _atuin_bind_up", + r"bind -M insert \e\[A _atuin_bind_up", + ] + .join("; "), + ); + + println!("end"); + } +} diff --git a/crates/turtle/src/command/client/init/powershell.rs b/crates/turtle/src/command/client/init/powershell.rs new file mode 100644 index 00000000..10c0c461 --- /dev/null +++ b/crates/turtle/src/command/client/init/powershell.rs @@ -0,0 +1,23 @@ +use crate::atuin_client::settings::Tmux; + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.ps1"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } else { + (!disable_ctrl_r, !disable_up_arrow) + }; + + // TODO: tmux popup for Powershell + println!("{base}"); + println!( + "Enable-AtuinSearchKeys -CtrlR {} -UpArrow {}", + ps_bool(bind_ctrl_r), + ps_bool(bind_up_arrow) + ); +} + +fn ps_bool(value: bool) -> &'static str { + if value { "$true" } else { "$false" } +} diff --git a/crates/turtle/src/command/client/init/xonsh.rs b/crates/turtle/src/command/client/init/xonsh.rs new file mode 100644 index 00000000..a17d85d8 --- /dev/null +++ b/crates/turtle/src/command/client/init/xonsh.rs @@ -0,0 +1,22 @@ +use crate::atuin_client::settings::Tmux; + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.xsh"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } else { + (!disable_ctrl_r, !disable_up_arrow) + }; + + // TODO: tmux popup for xonsh + println!( + "_ATUIN_BIND_CTRL_R={}", + if bind_ctrl_r { "True" } else { "False" } + ); + println!( + "_ATUIN_BIND_UP_ARROW={}", + if bind_up_arrow { "True" } else { "False" } + ); + println!("{base}"); +} diff --git a/crates/turtle/src/command/client/init/zsh.rs b/crates/turtle/src/command/client/init/zsh.rs new file mode 100644 index 00000000..38c3086b --- /dev/null +++ b/crates/turtle/src/command/client/init/zsh.rs @@ -0,0 +1,38 @@ +use crate::atuin_client::settings::Tmux; + +fn print_tmux_config(tmux: &Tmux) { + if tmux.enabled { + println!("export ATUIN_TMUX_POPUP_WIDTH='{}'", tmux.width); + println!("export ATUIN_TMUX_POPUP_HEIGHT='{}'", tmux.height); + } else { + println!("export ATUIN_TMUX_POPUP=false"); + } +} + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.zsh"); + + print_tmux_config(tmux); + println!("{base}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + const BIND_CTRL_R: &str = r"bindkey -M emacs '^r' atuin-search +bindkey -M viins '^r' atuin-search-viins +bindkey -M vicmd '/' atuin-search"; + + const BIND_UP_ARROW: &str = r"bindkey -M emacs '^[[A' atuin-up-search +bindkey -M vicmd '^[[A' atuin-up-search-vicmd +bindkey -M viins '^[[A' atuin-up-search-viins +bindkey -M emacs '^[OA' atuin-up-search +bindkey -M vicmd '^[OA' atuin-up-search-vicmd +bindkey -M viins '^[OA' atuin-up-search-viins +bindkey -M vicmd 'k' atuin-up-search-vicmd"; + + if !disable_ctrl_r { + println!("{BIND_CTRL_R}"); + } + if !disable_up_arrow { + println!("{BIND_UP_ARROW}"); + } + } +} diff --git a/crates/turtle/src/command/client/search.rs b/crates/turtle/src/command/client/search.rs new file mode 100644 index 00000000..4a2114d5 --- /dev/null +++ b/crates/turtle/src/command/client/search.rs @@ -0,0 +1,375 @@ +use std::fs::File; +use std::io::{IsTerminal as _, Write, stderr, stdout}; + +use crate::atuin_common::utils::{self, Escapable as _}; +use clap::Parser; +use eyre::Result; + +use crate::atuin_client::{ + database::Database, + database::{OptFilters, current_context}, + encryption, + history::{History, store::HistoryStore}, + record::sqlite_store::SqliteStore, + settings::{FilterMode, KeymapMode, SearchMode, Settings, Timezone}, + theme::Theme, +}; + +use super::history::ListMode; + +mod cursor; +mod duration; +mod engines; +mod history_list; +mod inspector; +mod interactive; +pub mod keybindings; + +pub use duration::format_duration_into; + +#[expect(clippy::struct_excessive_bools, clippy::struct_field_names)] +#[derive(Parser, Debug)] +pub struct Cmd { + /// Filter search result by directory + #[arg(long, short)] + cwd: Option<String>, + + /// Exclude directory from results + #[arg(long = "exclude-cwd")] + exclude_cwd: Option<String>, + + /// Filter search result by exit code + #[arg(long, short)] + exit: Option<i64>, + + /// Exclude results with this exit code + #[arg(long = "exclude-exit")] + exclude_exit: Option<i64>, + + /// Only include results added before this date + #[arg(long, short)] + before: Option<String>, + + /// Only include results after this date + #[arg(long)] + after: Option<String>, + + /// How many entries to return at most + #[arg(long)] + limit: Option<i64>, + + /// Offset from the start of the results + #[arg(long)] + offset: Option<i64>, + + /// Open interactive search UI + #[arg(long, short)] + interactive: bool, + + /// Allow overriding filter mode over config + #[arg(long = "filter-mode")] + filter_mode: Option<FilterMode>, + + /// Allow overriding search mode over config + #[arg(long = "search-mode")] + search_mode: Option<SearchMode>, + + /// Marker argument used to inform atuin that it was invoked from a shell up-key binding (hidden from help to avoid confusion) + #[arg(long = "shell-up-key-binding", hide = true)] + shell_up_key_binding: bool, + + /// Notify the keymap at the shell's side + #[arg(long = "keymap-mode", default_value = "auto")] + keymap_mode: KeymapMode, + + /// Use human-readable formatting for time + #[arg(long)] + human: bool, + + #[arg(allow_hyphen_values = true)] + query: Option<Vec<String>>, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Terminate the output with a null, for better multiline handling + #[arg(long)] + print0: bool, + + /// Delete anything matching this query. Will not print out the match + #[arg(long)] + delete: bool, + + /// Delete EVERYTHING! + #[arg(long)] + delete_it_all: bool, + + /// Reverse the order of results, oldest first + #[arg(long, short)] + reverse: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + #[arg(allow_hyphen_values = true)] + // Clippy warns about `Option<Option<T>>`, but we suppress it because we need + // this distinction for proper argument handling. + #[expect(clippy::option_option)] + timezone: Option<Option<Timezone>>, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {time}, {exit} and + /// {relativetime}. + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option<String>, + + /// Set the maximum number of lines Atuin's interface should take up. + #[arg(long = "inline-height")] + inline_height: Option<u16>, + + /// Filter by author. Supports $all-user (non-agents), $all-agent, or literal names. + /// Can be specified multiple times. + #[arg(long)] + author: Option<Vec<String>>, + + /// Include duplicate commands in the output (non-interactive only) + #[arg(long)] + include_duplicates: bool, + + /// File name to write the result to (hidden from help as this is meant to be used from a script) + #[arg(long = "result-file", hide = true)] + result_file: Option<String>, +} + +impl Cmd { + /// Returns true if this search command will run in interactive (TUI) mode + pub fn is_interactive(&self) -> bool { + self.interactive + } + + // clippy: please write this instead + // clippy: now it has too many lines + // me: I'll do it later OKAY + #[expect(clippy::too_many_lines)] + pub async fn run( + self, + db: impl Database, + settings: &mut Settings, + store: SqliteStore, + theme: &Theme, + ) -> Result<()> { + let query = self.query.unwrap_or_else(|| { + std::env::var("ATUIN_QUERY").map_or_else( + |_| vec![], + |query| { + query + .split(' ') + .map(std::string::ToString::to_string) + .collect() + }, + ) + }); + + if (self.delete_it_all || self.delete) && self.limit.is_some() { + // Because of how deletion is implemented, it will always delete all matches + // and disregard the limit option. It is also not clear what deletion with a + // limit would even mean. Deleting the LIMIT most recent entries that match + // the search query would make sense, but that wouldn't match what's displayed + // when running the equivalent search, but deleting those entries that are + // displayed with the search would leave any duplicates of those lines which may + // or may not have been intended to be deleted. + eprintln!("\"--limit\" is not compatible with deletion."); + return Ok(()); + } + + if self.delete && query.is_empty() { + eprintln!( + "Please specify a query to match the items you wish to delete. If you wish to delete all history, pass --delete-it-all" + ); + return Ok(()); + } + + if self.delete_it_all && !query.is_empty() { + eprintln!( + "--delete-it-all will delete ALL of your history! It does not require a query." + ); + return Ok(()); + } + + if let Some(search_mode) = self.search_mode { + settings.search_mode = search_mode; + } + if let Some(filter_mode) = self.filter_mode { + settings.filter_mode = Some(filter_mode); + } + if let Some(inline_height) = self.inline_height { + settings.inline_height = inline_height; + } + + settings.shell_up_key_binding = self.shell_up_key_binding; + + // `keymap_mode` specified in config.toml overrides the `--keymap-mode` + // option specified in the keybindings. + settings.keymap_mode = match settings.keymap_mode { + KeymapMode::Auto => self.keymap_mode, + value => value, + }; + settings.keymap_mode_shell = self.keymap_mode; + + let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + if self.interactive { + let item = interactive::history(&query, settings, db, &history_store, theme).await?; + + if let Some(result_file) = self.result_file { + let mut file = File::create(result_file)?; + write!(file, "{item}")?; + } else if !stdout().is_terminal() { + // stdout is not a terminal - likely command substitution like VAR=$(atuin search -i) + // Write to stdout so it gets captured. This requires some care on Windows, as the current + // console code page or `[Console]::OutputEncoding` on PowerShell may be different from UTF-8. + println!("{item}"); + } else if stderr().is_terminal() { + eprintln!("{}", item.escape_control()); + } else { + eprintln!("{item}"); + } + } else { + let opt_filter = OptFilters { + exit: self.exit, + exclude_exit: self.exclude_exit, + cwd: self.cwd, + exclude_cwd: self.exclude_cwd, + before: self.before, + after: self.after, + limit: self.limit, + offset: self.offset, + reverse: self.reverse, + include_duplicates: self.include_duplicates, + authors: self.author.clone().unwrap_or_default(), + }; + + let mut entries = + run_non_interactive(settings, opt_filter.clone(), &query, &db).await?; + + if entries.is_empty() { + std::process::exit(1) + } + + // if we aren't deleting, print it all + if self.delete || self.delete_it_all { + // delete it + // it only took me _years_ to add this + // sorry + while !entries.is_empty() { + for entry in &entries { + eprintln!("deleting {}", entry.id); + } + + let ids = history_store.delete_entries(entries).await?; + history_store.incremental_build(&db, &ids).await?; + + entries = + run_non_interactive(settings, opt_filter.clone(), &query, &db).await?; + } + } else { + let format = match self.format { + None => Some(settings.history_format.as_str()), + _ => self.format.as_deref(), + }; + let tz = match self.timezone { + Some(Some(tz)) => tz, // User provided a value + Some(None) | None => settings.timezone, // No value was provided + }; + + super::history::print_list( + &entries, + ListMode::from_flags(self.human, self.cmd_only), + format, + self.print0, + true, + tz, + ); + } + } + Ok(()) + } +} + +// This is supposed to more-or-less mirror the command line version, so ofc +// it is going to have a lot of args +#[expect(clippy::too_many_arguments, clippy::cast_possible_truncation)] +async fn run_non_interactive( + settings: &Settings, + filter_options: OptFilters, + query: &[String], + db: &impl Database, +) -> Result<Vec<History>> { + let dir = if filter_options.cwd.as_deref() == Some(".") { + Some(utils::get_current_dir()) + } else { + filter_options.cwd + }; + + let context = current_context().await?; + + let opt_filter = OptFilters { + cwd: dir.clone(), + ..filter_options + }; + + let filter_mode = settings.default_filter_mode(context.git_root.is_some()); + + let results = db + .search( + settings.search_mode, + filter_mode, + &context, + query.join(" ").as_str(), + opt_filter, + ) + .await?; + + Ok(results) +} + +#[cfg(test)] +mod tests { + use super::Cmd; + use clap::Parser; + + #[test] + fn search_for_triple_dash() { + // Issue #3028: searching for `---` should not be treated as a CLI flag + let cmd = Cmd::try_parse_from(["search", "---"]); + assert!(cmd.is_ok(), "Failed to parse '---' as a query: {cmd:?}"); + let cmd = cmd.unwrap(); + assert_eq!(cmd.query, Some(vec!["---".to_string()])); + } + + #[test] + fn search_for_double_dash_value() { + // Searching for strings starting with -- should also work + let cmd = Cmd::try_parse_from(["search", "--", "--foo"]); + assert!(cmd.is_ok()); + let cmd = cmd.unwrap(); + assert_eq!(cmd.query, Some(vec!["--foo".to_string()])); + } + + #[test] + fn search_author_cli_flag() { + let cmd = + Cmd::try_parse_from(["search", "--author", "codex", "--author", "ellie"]).unwrap(); + assert_eq!( + cmd.author, + Some(vec!["codex".to_string(), "ellie".to_string()]) + ); + } +} diff --git a/crates/turtle/src/command/client/search/cursor.rs b/crates/turtle/src/command/client/search/cursor.rs new file mode 100644 index 00000000..84f94082 --- /dev/null +++ b/crates/turtle/src/command/client/search/cursor.rs @@ -0,0 +1,405 @@ +use crate::atuin_client::settings::WordJumpMode; + +pub struct Cursor { + source: String, + index: usize, +} + +impl From<String> for Cursor { + fn from(source: String) -> Self { + Self { source, index: 0 } + } +} + +pub struct WordJumper<'a> { + word_chars: &'a str, + word_jump_mode: WordJumpMode, +} + +impl WordJumper<'_> { + fn is_word_boundary(&self, c: char, next_c: char) -> bool { + (c.is_whitespace() && !next_c.is_whitespace()) + || (!c.is_whitespace() && next_c.is_whitespace()) + || (self.word_chars.contains(c) && !self.word_chars.contains(next_c)) + || (!self.word_chars.contains(c) && self.word_chars.contains(next_c)) + } + + fn emacs_get_next_word_pos(&self, source: &str, index: usize) -> usize { + let index = (index + 1..source.len().saturating_sub(1)) + .find(|&i| self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(source.len()); + (index + 1..source.len().saturating_sub(1)) + .find(|&i| !self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(source.len()) + } + + fn emacs_get_prev_word_pos(&self, source: &str, index: usize) -> usize { + let index = (1..index) + .rev() + .find(|&i| self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(0); + (1..index) + .rev() + .find(|&i| !self.word_chars.contains(source.chars().nth(i).unwrap())) + .map_or(0, |i| i + 1) + } + + fn subl_get_next_word_pos(&self, source: &str, index: usize) -> usize { + let index = (index..source.len().saturating_sub(1)).find(|&i| { + self.is_word_boundary( + source.chars().nth(i).unwrap(), + source.chars().nth(i + 1).unwrap(), + ) + }); + if index.is_none() { + return source.len(); + } + (index.unwrap() + 1..source.len()) + .find(|&i| !source.chars().nth(i).unwrap().is_whitespace()) + .unwrap_or(source.len()) + } + + fn subl_get_prev_word_pos(&self, source: &str, index: usize) -> usize { + let index = (1..index) + .rev() + .find(|&i| !source.chars().nth(i).unwrap().is_whitespace()); + if index.is_none() { + return 0; + } + (1..index.unwrap()) + .rev() + .find(|&i| { + self.is_word_boundary( + source.chars().nth(i - 1).unwrap(), + source.chars().nth(i).unwrap(), + ) + }) + .unwrap_or(0) + } + + fn get_next_word_pos(&self, source: &str, index: usize) -> usize { + match self.word_jump_mode { + WordJumpMode::Emacs => self.emacs_get_next_word_pos(source, index), + WordJumpMode::Subl => self.subl_get_next_word_pos(source, index), + } + } + + fn get_prev_word_pos(&self, source: &str, index: usize) -> usize { + match self.word_jump_mode { + WordJumpMode::Emacs => self.emacs_get_prev_word_pos(source, index), + WordJumpMode::Subl => self.subl_get_prev_word_pos(source, index), + } + } +} + +impl Cursor { + pub fn as_str(&self) -> &str { + self.source.as_str() + } + + pub fn into_inner(self) -> String { + self.source + } + + /// Returns the string before the cursor + pub fn substring(&self) -> &str { + &self.source[..self.index] + } + + /// Returns the currently selected [`char`] + pub fn char(&self) -> Option<char> { + self.source[self.index..].chars().next() + } + + pub fn right(&mut self) { + if self.index < self.source.len() { + loop { + self.index += 1; + if self.source.is_char_boundary(self.index) { + break; + } + } + } + } + + pub fn left(&mut self) -> bool { + if self.index > 0 { + loop { + self.index -= 1; + if self.source.is_char_boundary(self.index) { + break true; + } + } + } else { + false + } + } + + pub fn next_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + self.index = word_jumper.get_next_word_pos(&self.source, self.index); + } + + pub fn prev_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + self.index = word_jumper.get_prev_word_pos(&self.source, self.index); + } + + /// Move cursor to the end of the current/next word (vim `e` motion). + /// + /// If cursor is in the middle of a word, moves to the end of that word. + /// If cursor is at the end of a word (or on whitespace), moves to the + /// end of the next word. + pub fn word_end(&mut self, word_chars: &str) { + let len = self.source.len(); + if self.index >= len { + return; + } + + let chars: Vec<char> = self.source.chars().collect(); + let mut char_idx = self.source[..self.index].chars().count(); + + if char_idx >= chars.len() { + return; + } + + let current = chars[char_idx]; + + // Check if we're at a word boundary (end of current word or on whitespace) + let at_word_boundary = current.is_whitespace() || char_idx + 1 >= chars.len() || { + let next = chars[char_idx + 1]; + next.is_whitespace() || (word_chars.contains(current) != word_chars.contains(next)) + }; + + // If at word boundary, advance past it and skip whitespace to find next word + if at_word_boundary { + char_idx += 1; + while char_idx < chars.len() && chars[char_idx].is_whitespace() { + char_idx += 1; + } + } + + // If we've gone past end, go to end of string + if char_idx >= chars.len() { + self.index = len; + return; + } + + // Find end of word: advance until next char is whitespace or different word type + let in_word_chars = word_chars.contains(chars[char_idx]); + while char_idx < chars.len() { + let next_idx = char_idx + 1; + if next_idx >= chars.len() { + // At last char, move past it + char_idx = next_idx; + break; + } + let next_c = chars[next_idx]; + if next_c.is_whitespace() || (word_chars.contains(next_c) != in_word_chars) { + // Next char is start of new word/whitespace, so current char is end + char_idx = next_idx; + break; + } + char_idx += 1; + } + + // Convert char index back to byte index + self.index = chars.iter().take(char_idx).map(|c| c.len_utf8()).sum(); + } + + pub fn insert(&mut self, c: char) { + self.source.insert(self.index, c); + self.index += c.len_utf8(); + } + + pub fn remove(&mut self) -> Option<char> { + if self.index < self.source.len() { + Some(self.source.remove(self.index)) + } else { + None + } + } + + pub fn remove_next_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + let next_index = word_jumper.get_next_word_pos(&self.source, self.index); + self.source.replace_range(self.index..next_index, ""); + } + + pub fn remove_prev_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + let next_index = word_jumper.get_prev_word_pos(&self.source, self.index); + self.source.replace_range(next_index..self.index, ""); + self.index = next_index; + } + + pub fn back(&mut self) -> Option<char> { + if self.left() { self.remove() } else { None } + } + + pub fn clear(&mut self) { + self.source.clear(); + self.index = 0; + } + + pub fn clear_to_start(&mut self) { + self.source.replace_range(..self.index, ""); + self.index = 0; + } + + pub fn clear_to_end(&mut self) { + self.source.replace_range(self.index.., ""); + self.index = self.source.len(); + } + + pub fn end(&mut self) { + self.index = self.source.len(); + } + + pub fn start(&mut self) { + self.index = 0; + } + + pub fn position(&self) -> usize { + self.index + } +} + +#[cfg(test)] +mod cursor_tests { + use super::Cursor; + use super::*; + + static EMACS_WORD_JUMPER: WordJumper = WordJumper { + word_chars: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", + word_jump_mode: WordJumpMode::Emacs, + }; + + static SUBL_WORD_JUMPER: WordJumper = WordJumper { + word_chars: "./\\()\"'-:,.;<>~!@#$%^&*|+=[]{}`~?", + word_jump_mode: WordJumpMode::Subl, + }; + + #[test] + fn right() { + // ö is 2 bytes + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + let indices = [0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 20, 20, 20]; + for i in indices { + assert_eq!(c.index, i); + c.right(); + } + } + + #[test] + fn left() { + // ö is 2 bytes + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + c.end(); + let indices = [20, 18, 17, 15, 14, 12, 11, 9, 8, 6, 5, 3, 2, 0, 0, 0, 0]; + for i in indices { + assert_eq!(c.index, i); + c.left(); + } + } + + #[test] + fn test_emacs_get_next_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(0, 6), (3, 6), (7, 18), (19, 30)]; + for (i_src, i_dest) in indices { + assert_eq!(EMACS_WORD_JUMPER.get_next_word_pos(&s, i_src), i_dest); + } + assert_eq!(EMACS_WORD_JUMPER.get_next_word_pos("", 0), 0); + } + + #[test] + fn test_emacs_get_prev_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(30, 15), (29, 15), (15, 3), (3, 0)]; + for (i_src, i_dest) in indices { + assert_eq!(EMACS_WORD_JUMPER.get_prev_word_pos(&s, i_src), i_dest); + } + assert_eq!(EMACS_WORD_JUMPER.get_prev_word_pos("", 0), 0); + } + + #[test] + fn test_subl_get_next_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(0, 3), (1, 3), (3, 9), (9, 15), (15, 21), (21, 30)]; + for (i_src, i_dest) in indices { + assert_eq!(SUBL_WORD_JUMPER.get_next_word_pos(&s, i_src), i_dest); + } + assert_eq!(SUBL_WORD_JUMPER.get_next_word_pos("", 0), 0); + } + + #[test] + fn test_subl_get_prev_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(30, 21), (21, 15), (15, 9), (9, 3), (3, 0)]; + for (i_src, i_dest) in indices { + assert_eq!(SUBL_WORD_JUMPER.get_prev_word_pos(&s, i_src), i_dest); + } + assert_eq!(SUBL_WORD_JUMPER.get_prev_word_pos("", 0), 0); + } + + #[test] + fn pop() { + let mut s = String::from("öaöböcödöeöfö"); + let mut c = Cursor::from(s.clone()); + c.end(); + while !s.is_empty() { + let c1 = s.pop(); + let c2 = c.back(); + assert_eq!(c1, c2); + assert_eq!(s.as_str(), c.substring()); + } + let c1 = s.pop(); + let c2 = c.back(); + assert_eq!(c1, c2); + } + + #[test] + fn back() { + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + // move to ^ + for _ in 0..4 { + c.right(); + } + assert_eq!(c.substring(), "öaöb"); + assert_eq!(c.back(), Some('b')); + assert_eq!(c.back(), Some('ö')); + assert_eq!(c.back(), Some('a')); + assert_eq!(c.back(), Some('ö')); + assert_eq!(c.back(), None); + assert_eq!(c.as_str(), "öcödöeöfö"); + } + + #[test] + fn insert() { + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + // move to ^ + for _ in 0..4 { + c.right(); + } + assert_eq!(c.substring(), "öaöb"); + c.insert('ö'); + c.insert('g'); + c.insert('ö'); + c.insert('h'); + assert_eq!(c.substring(), "öaöbögöh"); + assert_eq!(c.as_str(), "öaöbögöhöcödöeöfö"); + } +} diff --git a/crates/turtle/src/command/client/search/duration.rs b/crates/turtle/src/command/client/search/duration.rs new file mode 100644 index 00000000..54856c87 --- /dev/null +++ b/crates/turtle/src/command/client/search/duration.rs @@ -0,0 +1,65 @@ +use core::fmt; +use std::{ops::ControlFlow, time::Duration}; + +#[expect(clippy::module_name_repetitions)] +pub fn format_duration_into(dur: Duration, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn item(unit: &'static str, value: u64) -> ControlFlow<(&'static str, u64)> { + if value > 0 { + ControlFlow::Break((unit, value)) + } else { + ControlFlow::Continue(()) + } + } + + // impl taken and modified from + // https://github.com/tailhook/humantime/blob/master/src/duration.rs#L295-L331 + // Copyright (c) 2016 The humantime Developers + fn fmt(f: Duration) -> ControlFlow<(&'static str, u64), ()> { + let secs = f.as_secs(); + let nanos = f.subsec_nanos(); + + let years = secs / 31_557_600; // 365.25d + let year_days = secs % 31_557_600; + let months = year_days / 2_630_016; // 30.44d + let month_days = year_days % 2_630_016; + let days = month_days / 86400; + let day_secs = month_days % 86400; + let hours = day_secs / 3600; + let minutes = day_secs % 3600 / 60; + let seconds = day_secs % 60; + + let millis = nanos / 1_000_000; + let micros = nanos / 1_000; + + // a difference from our impl than the original is that + // we only care about the most-significant segment of the duration. + // If the item call returns `Break`, then the `?` will early-return. + // This allows for a very consise impl + item("y", years)?; + item("mo", months)?; + item("d", days)?; + item("h", hours)?; + item("m", minutes)?; + item("s", seconds)?; + item("ms", u64::from(millis))?; + item("us", u64::from(micros))?; + item("ns", u64::from(nanos))?; + ControlFlow::Continue(()) + } + + match fmt(dur) { + ControlFlow::Break((unit, value)) => write!(f, "{value}{unit}"), + ControlFlow::Continue(()) => write!(f, "0s"), + } +} + +#[expect(clippy::module_name_repetitions)] +pub fn format_duration(f: Duration) -> String { + struct F(Duration); + impl fmt::Display for F { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + format_duration_into(self.0, f) + } + } + F(f).to_string() +} diff --git a/crates/turtle/src/command/client/search/engines.rs b/crates/turtle/src/command/client/search/engines.rs new file mode 100644 index 00000000..0f92b4c7 --- /dev/null +++ b/crates/turtle/src/command/client/search/engines.rs @@ -0,0 +1,95 @@ +use async_trait::async_trait; +use crate::atuin_client::{ + database::{Context, Database, OptFilters}, + history::{AUTHOR_FILTER_ALL_USER, History, HistoryId}, + settings::{FilterMode, SearchMode, Settings}, +}; +use eyre::Result; + +use super::cursor::Cursor; + +#[cfg(feature = "daemon")] +pub mod daemon; +pub mod db; +pub mod skim; + +#[expect(unused)] // settings is only used if daemon feature is enabled +pub fn engine(search_mode: SearchMode, settings: &Settings) -> Box<dyn SearchEngine> { + match search_mode { + SearchMode::Skim => Box::new(skim::Search::new()) as Box<_>, + #[cfg(feature = "daemon")] + SearchMode::DaemonFuzzy => Box::new(daemon::Search::new(settings)) as Box<_>, + #[cfg(not(feature = "daemon"))] + SearchMode::DaemonFuzzy => { + // Fall back to fuzzy mode if daemon feature is not enabled + Box::new(db::Search(SearchMode::Fuzzy)) as Box<_> + } + mode => Box::new(db::Search(mode)) as Box<_>, + } +} + +pub struct SearchState { + pub input: Cursor, + pub filter_mode: FilterMode, + pub context: Context, + pub custom_context: Option<HistoryId>, +} + +impl SearchState { + pub(crate) fn rotate_filter_mode(&mut self, settings: &Settings, offset: isize) { + let mut i = settings + .search + .filters + .iter() + .position(|&m| m == self.filter_mode) + .unwrap_or_default(); + for _ in 0..settings.search.filters.len() { + i = (i.wrapping_add_signed(offset)) % settings.search.filters.len(); + let mode = settings.search.filters[i]; + if self.filter_mode_available(mode, settings) { + self.filter_mode = mode; + break; + } + } + } + + fn filter_mode_available(&self, mode: FilterMode, settings: &Settings) -> bool { + match mode { + FilterMode::Global | FilterMode::SessionPreload => self.custom_context.is_none(), + FilterMode::Workspace => settings.workspaces && self.context.git_root.is_some(), + _ => true, + } + } +} + +#[async_trait] +pub trait SearchEngine: Send + Sync + 'static { + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>>; + + async fn query(&mut self, state: &SearchState, db: &mut dyn Database) -> Result<Vec<History>> { + if state.input.as_str().is_empty() { + Ok(db + .search( + SearchMode::FullText, + state.filter_mode, + &state.context, + "", + OptFilters { + limit: Some(200), + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + ..Default::default() + }, + ) + .await? + .into_iter() + .collect::<Vec<_>>()) + } else { + self.full_query(state, db).await + } + } + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize>; +} diff --git a/crates/turtle/src/command/client/search/engines/daemon.rs b/crates/turtle/src/command/client/search/engines/daemon.rs new file mode 100644 index 00000000..b1299c02 --- /dev/null +++ b/crates/turtle/src/command/client/search/engines/daemon.rs @@ -0,0 +1,242 @@ +use crate::atuin_client::{ + database::{Database, OptFilters}, + history::{AUTHOR_FILTER_ALL_USER, History}, + settings::{SearchMode, Settings}, +}; +use crate::atuin_daemon::client::{DaemonClientErrorKind, SearchClient, classify_error}; +use async_trait::async_trait; +use atuin_nucleo_matcher::{ + Config, Matcher, Utf32Str, + pattern::{CaseMatching, Normalization, Pattern}, +}; +use eyre::Result; +use tracing::{Level, debug, instrument, span}; +use uuid::Uuid; + +use super::{SearchEngine, SearchState}; +use crate::command::client::daemon; + +pub struct Search { + client: Option<SearchClient>, + query_id: u64, + settings: Settings, + #[cfg(unix)] + socket_path: String, +} + +impl Search { + pub fn new(settings: &Settings) -> Self { + Search { + client: None, + query_id: 0, + settings: settings.clone(), + #[cfg(unix)] + socket_path: settings.daemon.socket_path.clone(), + } + } + + #[instrument(skip_all, level = Level::TRACE, name = "get_daemon_client")] + async fn get_client(&mut self) -> Result<&mut SearchClient> { + if self.client.is_none() { + self.connect().await?; + } + Ok(self.client.as_mut().unwrap()) + } + + async fn connect(&mut self) -> Result<()> { + #[cfg(unix)] + let client = SearchClient::new(self.socket_path.clone()).await?; + + self.client = Some(client); + Ok(()) + } + + fn should_retry(err: &eyre::Report) -> bool { + matches!( + classify_error(err), + DaemonClientErrorKind::Connect + | DaemonClientErrorKind::Unavailable + | DaemonClientErrorKind::Unimplemented + ) + } + + fn next_query_id(&mut self) -> u64 { + self.query_id += 1; + self.query_id + } + + /// Check if query contains regex pattern (r/.../) + /// Nucleo doesn't support regex, so we fall back to database search + fn contains_regex_pattern(query: &str) -> bool { + query.starts_with("r/") || query.contains(" r/") + } + + #[instrument(skip_all, level = Level::TRACE, name = "daemon_db_fallback")] + async fn fallback_to_db_search( + &self, + state: &SearchState, + db: &dyn Database, + ) -> Result<Vec<History>> { + let results = db + .search( + SearchMode::FullText, + state.filter_mode, + &state.context, + state.input.as_str(), + OptFilters { + limit: Some(200), + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + ..Default::default() + }, + ) + .await + .map_or(Vec::new(), |r| r.into_iter().collect()); + Ok(results) + } + + #[instrument(skip_all, level = Level::TRACE, name = "hydrate_from_db", fields(count = ids.len()))] + async fn hydrate_from_db(&self, db: &dyn Database, ids: &[String]) -> Result<Vec<History>> { + let placeholders: Vec<String> = ids.iter().map(|id| format!("'{id}'")).collect(); + let sql_query = format!( + "SELECT * FROM history WHERE id IN ({}) ORDER BY timestamp DESC", + placeholders.join(",") + ); + Ok(db.query_history(&sql_query).await?) + } +} + +#[async_trait] +impl SearchEngine for Search { + #[instrument(skip_all, level = Level::TRACE, name = "daemon_search", fields(query = %state.input.as_str()))] + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>> { + let query = state.input.as_str().to_string(); + + // Fall back to database for regex queries (Nucleo doesn't support regex) + if Self::contains_regex_pattern(&query) { + debug!(query = %query, "[daemon-client] regex detected, falling back to db"); + return self.fallback_to_db_search(state, db).await; + } + + let query_id = self.next_query_id(); + + let span = + span!(Level::TRACE, "daemon_search.req_resp", query = %query, query_id = query_id); + + // Try to connect and search; if it fails with a retriable error, + // auto-start the daemon and retry once. + let first_attempt = async { + let client = self.get_client().await?; + client + .search( + query.clone(), + query_id, + state.filter_mode, + Some(state.context.clone()), + ) + .await + } + .await; + + let mut stream = match first_attempt { + Ok(stream) => stream, + Err(err) if self.settings.daemon.autostart && Self::should_retry(&err) => { + debug!("daemon not available, attempting auto-start"); + self.client = None; + + daemon::ensure_daemon_running(&self.settings).await?; + + let client = self.get_client().await?; + client + .search( + query.clone(), + query_id, + state.filter_mode, + Some(state.context.clone()), + ) + .await? + } + Err(err) => return Err(err), + }; + + let mut ids = Vec::with_capacity(200); + span!(Level::TRACE, "daemon_search.resp") + .in_scope(async || { + while let Ok(Some(response)) = stream.message().await { + let span2 = span!( + Level::TRACE, + "daemon_search.resp.item", + query_id = response.query_id + ); + let _span2 = span2.enter(); + // Only process if the query_id matches (prevents stale responses) + if response.query_id == query_id { + let uuids = response + .ids + .iter() + .map(|id| { + let bytes: [u8; 16] = + id.as_slice().try_into().expect("id should be 16 bytes"); + Uuid::from_bytes(bytes).as_simple().to_string() + }) + .collect::<Vec<_>>(); + ids.extend(uuids); + } + drop(_span2); + drop(span2); + } + }) + .await; + drop(span); + + if ids.is_empty() { + debug!(query = %query, results = 0, "[daemon-client] empty results"); + return Ok(Vec::new()); + } + + // // Hydrate from local database + let results = self.hydrate_from_db(db, &ids).await?; + + // // Reorder results to match the order from the daemon (which is ranked by relevance) + let ordered_results = span!(Level::TRACE, "reorder_results").in_scope(|| { + let mut ordered_results = Vec::with_capacity(results.len()); + for id in &ids { + if let Some(history) = results.iter().find(|h| h.id.0 == *id) { + ordered_results.push(history.clone()); + } + } + ordered_results + }); + + debug!( + query = %query, + results = results.len(), + "[daemon-client]" + ); + + Ok(ordered_results) + } + + #[instrument(skip_all, level = Level::TRACE, name = "daemon_highlight")] + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> { + // Use fulltext highlighting for regex queries + if Self::contains_regex_pattern(search_input) { + return super::db::get_highlight_indices_fulltext(command, search_input); + } + + let mut matcher = Matcher::new(Config::DEFAULT); + let pattern = Pattern::parse(search_input, CaseMatching::Smart, Normalization::Smart); + + let mut indices: Vec<u32> = Vec::new(); + let mut haystack_buf = Vec::new(); + + let haystack = Utf32Str::new(command, &mut haystack_buf); + pattern.indices(haystack, &mut matcher, &mut indices); + + // Convert u32 indices to usize + indices.into_iter().map(|i| i as usize).collect() + } +} diff --git a/crates/turtle/src/command/client/search/engines/db.rs b/crates/turtle/src/command/client/search/engines/db.rs new file mode 100644 index 00000000..2765faf5 --- /dev/null +++ b/crates/turtle/src/command/client/search/engines/db.rs @@ -0,0 +1,110 @@ +use super::{SearchEngine, SearchState}; +use async_trait::async_trait; +use crate::atuin_client::{ + database::Database, + database::OptFilters, + database::{QueryToken, QueryTokenizer}, + history::{AUTHOR_FILTER_ALL_USER, History}, + settings::SearchMode, +}; +use eyre::Result; +use norm::Metric; +use norm::fzf::{FzfParser, FzfV2}; +use std::ops::Range; +use tracing::{Level, instrument}; + +pub struct Search(pub SearchMode); + +#[async_trait] +impl SearchEngine for Search { + #[instrument(skip_all, level = Level::TRACE, name = "db_search", fields(mode = ?self.0, query = %state.input.as_str()))] + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>> { + let results = db + .search( + self.0, + state.filter_mode, + &state.context, + state.input.as_str(), + OptFilters { + limit: Some(200), + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + ..Default::default() + }, + ) + .await + // ignore errors as it may be caused by incomplete regex + .map_or(Vec::new(), |r| r.into_iter().collect()); + Ok(results) + } + + #[instrument(skip_all, level = Level::TRACE, name = "db_highlight")] + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> { + if self.0 == SearchMode::Prefix { + return vec![]; + } else if self.0 == SearchMode::FullText { + return get_highlight_indices_fulltext(command, search_input); + } + let mut fzf = FzfV2::new(); + let mut parser = FzfParser::new(); + let query = parser.parse(search_input); + let mut ranges: Vec<Range<usize>> = Vec::new(); + let _ = fzf.distance_and_ranges(query, command, &mut ranges); + + // convert ranges to all indices + ranges.into_iter().flatten().collect() + } +} + +#[instrument(skip_all, level = Level::TRACE, name = "db_highlight_fulltext")] +pub fn get_highlight_indices_fulltext(command: &str, search_input: &str) -> Vec<usize> { + let mut ranges = vec![]; + let lower_command = command.to_ascii_lowercase(); + + for token in QueryTokenizer::new(search_input) { + let matchee = if token.has_uppercase() { + command + } else { + &lower_command + }; + + if token.is_inverse() { + continue; + } + + match token { + QueryToken::Or => {} + QueryToken::Regex(r) => { + if let Ok(re) = regex::Regex::new(r) { + for m in re.find_iter(command) { + ranges.push(m.range()); + } + } + } + QueryToken::MatchStart(term, _) => { + if matchee.starts_with(term) { + ranges.push(0..term.len()); + } + } + QueryToken::MatchEnd(term, _) => { + if matchee.ends_with(term) { + let l = matchee.len(); + ranges.push((l - term.len())..l); + } + } + QueryToken::Match(term, _) | QueryToken::MatchFull(term, _) => { + for (idx, m) in matchee.match_indices(term) { + ranges.push(idx..(idx + m.len())); + } + } + } + } + + let mut ret: Vec<_> = ranges.into_iter().flatten().collect(); + ret.sort_unstable(); + ret.dedup(); + ret +} diff --git a/crates/turtle/src/command/client/search/engines/skim.rs b/crates/turtle/src/command/client/search/engines/skim.rs new file mode 100644 index 00000000..96a6574d --- /dev/null +++ b/crates/turtle/src/command/client/search/engines/skim.rs @@ -0,0 +1,229 @@ +use std::path::Path; + +use async_trait::async_trait; +use crate::atuin_client::{ + database::Database, + history::{History, is_known_agent}, + settings::FilterMode, +}; +use eyre::Result; +use fuzzy_matcher::{FuzzyMatcher, skim::SkimMatcherV2}; +use itertools::Itertools; +use time::OffsetDateTime; +use tokio::task::yield_now; +use tracing::{Level, instrument, warn}; +use uuid; + +use super::{SearchEngine, SearchState}; + +pub struct Search { + all_history: Vec<(History, i32)>, + engine: SkimMatcherV2, +} + +impl Search { + pub fn new() -> Self { + Search { + all_history: vec![], + engine: SkimMatcherV2::default(), + } + } +} + +#[async_trait] +impl SearchEngine for Search { + #[instrument(skip_all, level = Level::TRACE, name = "skim_search", fields(query = %state.input.as_str()))] + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>> { + if self.all_history.is_empty() { + self.all_history = load_all_history(db).await; + } + + Ok(fuzzy_search(&self.engine, state, &self.all_history).await) + } + + #[instrument(skip_all, level = Level::TRACE, name = "skim_highlight")] + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> { + let (_, indices) = self + .engine + .fuzzy_indices(command, search_input) + .unwrap_or_default(); + indices + } +} + +#[instrument(skip_all, level = Level::TRACE, name = "load_all_history")] +async fn load_all_history(db: &dyn Database) -> Vec<(History, i32)> { + db.all_with_count().await.unwrap() +} + +#[expect(clippy::too_many_lines)] +#[instrument(skip_all, level = Level::TRACE, name = "fuzzy_match", fields(history_count = all_history.len()))] +async fn fuzzy_search( + engine: &SkimMatcherV2, + state: &SearchState, + all_history: &[(History, i32)], +) -> Vec<History> { + let mut set = Vec::with_capacity(200); + let mut ranks = Vec::with_capacity(200); + let query = state.input.as_str(); + let now = OffsetDateTime::now_utc(); + + for (i, (history, count)) in all_history.iter().enumerate() { + if i % 256 == 0 { + yield_now().await; + } + if is_known_agent(&history.author) { + continue; + } + let context = &state.context; + let git_root = context + .git_root + .as_ref() + .and_then(|git_root| git_root.to_str()) + .unwrap_or(&context.cwd); + match state.filter_mode { + FilterMode::Global => {} + // we aggregate host by ',' separating them + FilterMode::Host + if history + .hostname + .split(',') + .contains(&context.hostname.as_str()) => {} + // we aggregate session by concattenating them. + // sessions are 32 byte simple uuid formats + FilterMode::Session + if history + .session + .as_bytes() + .chunks(32) + .contains(&context.session.as_bytes()) => {} + // SessionPreload: include current session + global history from before session start + FilterMode::SessionPreload => { + let is_current_session = { + history + .session + .as_bytes() + .chunks(32) + .any(|chunk| chunk == context.session.as_bytes()) + }; + + if !is_current_session { + let Ok(uuid) = uuid::Uuid::parse_str(&context.session) else { + warn!("failed to parse session id '{}'", context.session); + continue; + }; + let Some(timestamp) = uuid.get_timestamp() else { + warn!( + "failed to get timestamp from uuid '{}'", + uuid.as_hyphenated() + ); + continue; + }; + let (seconds, nanos) = timestamp.to_unix(); + let Ok(session_start) = time::OffsetDateTime::from_unix_timestamp_nanos( + i128::from(seconds) * 1_000_000_000 + i128::from(nanos), + ) else { + warn!( + "failed to create OffsetDateTime from second: {seconds}, nanosecond: {nanos}" + ); + continue; + }; + + if history.timestamp >= session_start { + continue; + } + } + } + // we aggregate directory by ':' separating them + FilterMode::Directory if history.cwd.split(':').contains(&context.cwd.as_str()) => {} + FilterMode::Workspace if history.cwd.split(':').contains(&git_root) => {} + _ => continue, + } + #[expect(clippy::cast_lossless, clippy::cast_precision_loss)] + if let Some((score, indices)) = engine.fuzzy_indices(&history.command, query) { + let begin = indices.first().copied().unwrap_or_default(); + + let mut duration = (now - history.timestamp).as_seconds_f64().log2(); + if !duration.is_finite() || duration <= 1.0 { + duration = 1.0; + } + // these + X.0 just make the log result a bit smoother. + // log is very spiky towards 1-4, but I want a gradual decay. + // eg: + // log2(4) = 2, log2(5) = 2.3 (16% increase) + // log2(8) = 3, log2(9) = 3.16 (5% increase) + // log2(16) = 4, log2(17) = 4.08 (2% increase) + let count = (*count as f64 + 8.0).log2(); + let begin = (begin as f64 + 16.0).log2(); + let path = path_dist(history.cwd.as_ref(), state.context.cwd.as_ref()); + let path = (path as f64 + 8.0).log2(); + + // reduce longer durations, raise higher counts, raise matches close to the start + let score = (-score as f64) * count / path / duration / begin; + + 'insert: { + // algorithm: + // 1. find either the position that this command ranks + // 2. find the same command positioned better than our rank. + for i in 0..set.len() { + // do we out score the current position? + if ranks[i] > score { + ranks.insert(i, score); + set.insert(i, history.clone()); + let mut j = i + 1; + while j < set.len() { + // remove duplicates that have a worse score + if set[j].command == history.command { + ranks.remove(j); + set.remove(j); + + // break this while loop because there won't be any other + // duplicates. + break; + } + j += 1; + } + + // keep it limited + if ranks.len() > 200 { + ranks.pop(); + set.pop(); + } + + break 'insert; + } + // don't continue if this command has a better score already + if set[i].command == history.command { + break 'insert; + } + } + + if set.len() < 200 { + ranks.push(score); + set.push(history.clone()); + } + } + } + } + + set +} + +fn path_dist(a: &Path, b: &Path) -> usize { + let mut a: Vec<_> = a.components().collect(); + let b: Vec<_> = b.components().collect(); + + let mut dist = 0; + + // pop a until there's a common ancestor + while !b.starts_with(&a) { + dist += 1; + a.pop(); + } + + b.len() - a.len() + dist +} diff --git a/crates/turtle/src/command/client/search/history_list.rs b/crates/turtle/src/command/client/search/history_list.rs new file mode 100644 index 00000000..4c83d7eb --- /dev/null +++ b/crates/turtle/src/command/client/search/history_list.rs @@ -0,0 +1,429 @@ +use std::time::Duration; + +use super::duration::format_duration; +use super::engines::SearchEngine; +use crate::atuin_client::{ + history::History, + settings::{UiColumn, UiColumnType}, + theme::{Meaning, Theme}, +}; +use crate::atuin_common::utils::Escapable as _; +use itertools::Itertools; +use ratatui::{ + backend::FromCrossterm, + buffer::Buffer, + crossterm::style, + layout::Rect, + style::{Modifier, Style}, + widgets::{Block, StatefulWidget, Widget}, +}; +use time::OffsetDateTime; + +pub struct HistoryHighlighter<'a> { + pub engine: &'a dyn SearchEngine, + pub search_input: &'a str, +} + +impl HistoryHighlighter<'_> { + pub fn get_highlight_indices(&self, command: &str) -> Vec<usize> { + self.engine + .get_highlight_indices(command, self.search_input) + } +} + +pub struct HistoryList<'a> { + history: &'a [History], + block: Option<Block<'a>>, + inverted: bool, + /// Apply an alternative highlighting to the selected row + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + /// Columns to display (in order, after the indicator) + columns: &'a [UiColumn], +} + +#[derive(Default)] +pub struct ListState { + offset: usize, + selected: usize, + max_entries: usize, +} + +impl ListState { + pub fn selected(&self) -> usize { + self.selected + } + + pub fn max_entries(&self) -> usize { + self.max_entries + } + + pub fn offset(&self) -> usize { + self.offset + } + + pub fn select(&mut self, index: usize) { + self.selected = index; + } +} + +impl StatefulWidget for HistoryList<'_> { + type State = ListState; + + fn render(mut self, area: Rect, buf: &mut Buffer, state: &mut Self::State) { + let list_area = self.block.take().map_or(area, |b| { + let inner_area = b.inner(area); + b.render(area, buf); + inner_area + }); + + if list_area.width < 1 || list_area.height < 1 || self.history.is_empty() { + return; + } + let list_height = list_area.height as usize; + + let (start, end) = self.get_items_bounds(state.selected, state.offset, list_height); + state.offset = start; + state.max_entries = end - start; + + let mut s = DrawState { + buf, + list_area, + x: 0, + y: 0, + state, + inverted: self.inverted, + alternate_highlight: self.alternate_highlight, + now: &self.now, + indicator: self.indicator, + theme: self.theme, + history_highlighter: self.history_highlighter, + show_numeric_shortcuts: self.show_numeric_shortcuts, + columns: self.columns, + }; + + for item in self.history.iter().skip(state.offset).take(end - start) { + s.render_row(item); + + // reset line + s.y += 1; + s.x = 0; + } + } +} + +impl<'a> HistoryList<'a> { + #[expect(clippy::too_many_arguments)] + pub fn new( + history: &'a [History], + inverted: bool, + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + columns: &'a [UiColumn], + ) -> Self { + Self { + history, + block: None, + inverted, + alternate_highlight, + now, + indicator, + theme, + history_highlighter, + show_numeric_shortcuts, + columns, + } + } + + pub fn block(mut self, block: Block<'a>) -> Self { + self.block = Some(block); + self + } + + fn get_items_bounds(&self, selected: usize, offset: usize, height: usize) -> (usize, usize) { + let offset = offset.min(self.history.len().saturating_sub(1)); + + let max_scroll_space = height.min(10).min(self.history.len() - selected); + if offset + height < selected + max_scroll_space { + let end = selected + max_scroll_space; + (end - height, end) + } else if selected < offset { + (selected, selected + height) + } else { + (offset, offset + height) + } + } +} + +struct DrawState<'a> { + buf: &'a mut Buffer, + list_area: Rect, + x: u16, + y: u16, + state: &'a ListState, + inverted: bool, + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + columns: &'a [UiColumn], +} + +// these encode the slices of `" > "`, `" {n} "`, or `" "` in a compact form. +// Yes, this is a hack, but it makes me feel happy +static SLICES: &str = " > 1 2 3 4 5 6 7 8 9 "; + +impl DrawState<'_> { + /// Render a complete row for a history item based on configured columns. + fn render_row(&mut self, h: &History) { + // Always render the indicator first (width 3) + self.index(); + + // Calculate the width for the expanding column + // Fixed columns use their configured width + 1 (trailing space) + let indicator_width: u16 = 3; + let fixed_width: u16 = self + .columns + .iter() + .filter(|c| !c.expand) + .map(|c| c.width + 1) + .sum(); + let expand_width = self + .list_area + .width + .saturating_sub(indicator_width + fixed_width); + + let style = self.theme.as_style(Meaning::Base); + // Render each configured column + for (idx, column) in self.columns.iter().enumerate() { + if idx != 0 { + self.draw(" ", Style::from_crossterm(style)); + } + let width = if column.expand { + expand_width + } else { + column.width + }; + match column.column_type { + UiColumnType::Duration => self.duration(h, width), + UiColumnType::Time => self.time(h, width), + UiColumnType::Datetime => self.datetime(h, width), + UiColumnType::Directory => self.directory(h, width), + UiColumnType::Host => self.host(h, width), + UiColumnType::User => self.user(h, width), + UiColumnType::Exit => self.exit_code(h, width), + UiColumnType::Command => self.command(h), + } + } + } + + fn index(&mut self) { + if !self.show_numeric_shortcuts { + let i = self.y as usize + self.state.offset; + let is_selected = i == self.state.selected(); + let prompt: &str = if is_selected { self.indicator } else { " " }; + self.draw(prompt, Style::default()); + return; + } + + // these encode the slices of `" > "`, `" {n} "`, or `" "` in a compact form. + // Yes, this is a hack, but it makes me feel happy + + let i = self.y as usize + self.state.offset; + let i = i.checked_sub(self.state.selected); + let i = i.unwrap_or(10).min(10) * 2; + let prompt: &str = if i == 0 { + self.indicator + } else { + &SLICES[i..i + 3] + }; + self.draw(prompt, Style::default()); + } + + fn duration(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(if h.success() { + Meaning::AlertInfo + } else { + Meaning::AlertError + }); + let duration = Duration::from_nanos(u64::try_from(h.duration).unwrap_or(0)); + let formatted = format_duration(duration); + let w = width as usize; + // Right-align duration within its column width, plus trailing space + let display = format!("{formatted:>w$}"); + self.draw(&display, Style::from_crossterm(style)); + } + + fn time(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Guidance); + + // Account for the chance that h.timestamp is "in the future" + // This would mean that "since" is negative, and the unwrap here + // would fail. + // If the timestamp would otherwise be in the future, display + // the time since as 0. + let since = (self.now)() - h.timestamp; + let time = format_duration(since.try_into().unwrap_or_default()); + + // Format as "Xs ago" right-aligned within column width + let w = width as usize; + let time_str = format!("{time} ago"); + + let display = format!("{time_str:>w$}"); + self.draw(&display, Style::from_crossterm(style)); + } + + fn command(&mut self, h: &History) { + let mut style = self.theme.as_style(Meaning::Base); + let mut row_highlighted = false; + if !self.alternate_highlight && (self.y as usize + self.state.offset == self.state.selected) + { + row_highlighted = true; + // if not applying alternative highlighting to the whole row, color the command + style = self.theme.as_style(Meaning::AlertError); + style.attributes.set(style::Attribute::Bold); + } + + let highlight_indices = self.history_highlighter.get_highlight_indices( + h.command + .escape_control() + .split_ascii_whitespace() + .join(" ") + .as_str(), + ); + + let mut pos = 0; + for section in h.command.escape_control().split_ascii_whitespace() { + if pos != 0 { + self.draw(" ", Style::from_crossterm(style)); + } + for ch in section.chars() { + if self.x > self.list_area.width { + // Avoid attempting to draw a command section beyond the width + // of the list + return; + } + let mut style = style; + if highlight_indices.contains(&pos) { + if row_highlighted { + // if the row is highlighted bold is not enough as the whole row is bold + // change the color too + style = self.theme.as_style(Meaning::AlertWarn); + } + style.attributes.set(style::Attribute::Bold); + } + let s = ch.to_string(); + self.draw(&s, Style::from_crossterm(style)); + pos += s.len(); + } + pos += 1; + } + } + + /// Render the absolute datetime column (e.g., "2025-01-22 14:35") + fn datetime(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + // Format: YYYY-MM-DD HH:MM + let formatted = h + .timestamp + .format( + &time::format_description::parse("[year]-[month]-[day] [hour]:[minute]") + .expect("valid format"), + ) + .unwrap_or_else(|_| "????-??-?? ??:??".to_string()); + let w = width as usize; + let display = format!("{formatted:w$}"); + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the directory column (working directory, truncated) + fn directory(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + let w = width as usize; + let cwd = &h.cwd; + let char_count = cwd.chars().count(); + // Truncate from the left with "..." if too long, plus trailing space + // Use character count for comparison and skip for UTF-8 safety + let display = if char_count > w && w >= 4 { + let truncated: String = cwd.chars().skip(char_count - (w - 3)).collect(); + format!("...{truncated}") + } else { + format!("{cwd:w$}") + }; + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the host column (just the hostname) + fn host(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + let w = width as usize; + // Database stores hostname as "hostname:username" + let host = h.hostname.split(':').next().unwrap_or(&h.hostname); + let char_count = host.chars().count(); + // Use character count for comparison and take for UTF-8 safety + let display = if char_count > w && w >= 4 { + let truncated: String = host.chars().take(w.saturating_sub(4)).collect(); + format!("{truncated}...") + } else { + format!("{host:w$}") + }; + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the user column + fn user(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + let w = width as usize; + // Database stores hostname as "hostname:username" + let user = h.hostname.split(':').nth(1).unwrap_or(""); + let char_count = user.chars().count(); + // Use character count for comparison and take for UTF-8 safety + let display = if char_count > w && w >= 4 { + let truncated: String = user.chars().take(w.saturating_sub(4)).collect(); + format!("{truncated}...") + } else { + format!("{user:w$}") + }; + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the exit code column + fn exit_code(&mut self, h: &History, width: u16) { + let style = if h.success() { + self.theme.as_style(Meaning::AlertInfo) + } else { + self.theme.as_style(Meaning::AlertError) + }; + let w = width as usize; + let display = format!("{:>w$}", h.exit); + self.draw(&display, Style::from_crossterm(style)); + } + + fn draw(&mut self, s: &str, mut style: Style) { + let cx = self.list_area.left() + self.x; + + let cy = if self.inverted { + self.list_area.top() + self.y + } else { + self.list_area.bottom() - self.y - 1 + }; + + if self.alternate_highlight && (self.y as usize + self.state.offset == self.state.selected) + { + style = style.add_modifier(Modifier::REVERSED); + } + + let w = (self.list_area.width - self.x) as usize; + self.x += self.buf.set_stringn(cx, cy, s, w, style).0 - cx; + } +} diff --git a/crates/turtle/src/command/client/search/inspector.rs b/crates/turtle/src/command/client/search/inspector.rs new file mode 100644 index 00000000..1ebc4383 --- /dev/null +++ b/crates/turtle/src/command/client/search/inspector.rs @@ -0,0 +1,421 @@ +use std::time::Duration; +use time::macros::format_description; + +use crate::atuin_client::{ + history::{History, HistoryStats}, + settings::{Settings, Timezone}, +}; +use ratatui::{ + Frame, + backend::FromCrossterm, + layout::Rect, + prelude::{Constraint, Direction, Layout}, + style::Style, + text::{Span, Text}, + widgets::{Bar, BarChart, BarGroup, Block, Borders, Padding, Paragraph, Row, Table}, +}; + +use super::duration::format_duration; + +use super::super::theme::{Meaning, Theme}; +use super::interactive::{Compactness, to_compactness}; + +#[expect(clippy::cast_sign_loss)] +fn u64_or_zero(num: i64) -> u64 { + if num < 0 { 0 } else { num as u64 } +} + +pub fn draw_commands( + f: &mut Frame<'_>, + parent: Rect, + history: &History, + stats: &HistoryStats, + compact: bool, + theme: &Theme, +) { + let commands = Layout::default() + .direction(if compact { + Direction::Vertical + } else { + Direction::Horizontal + }) + .constraints(if compact { + [ + Constraint::Length(1), + Constraint::Length(1), + Constraint::Min(0), + ] + } else { + [ + Constraint::Ratio(1, 4), + Constraint::Ratio(1, 2), + Constraint::Ratio(1, 4), + ] + }) + .split(parent); + + let command = Paragraph::new(Text::from(Span::styled( + history.command.clone(), + Style::from_crossterm(theme.as_style(Meaning::Important)), + ))) + .block(if compact { + Block::new() + .borders(Borders::NONE) + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + } else { + Block::new() + .borders(Borders::ALL) + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .title("Command") + .padding(Padding::horizontal(1)) + }); + + let previous = Paragraph::new( + stats + .previous + .clone() + .map_or_else(|| "[No previous command]".to_string(), |prev| prev.command), + ) + .block(if compact { + Block::new() + .borders(Borders::NONE) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + } else { + Block::new() + .borders(Borders::ALL) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + .title("Previous command") + .padding(Padding::horizontal(1)) + }); + + // Add [] around blank text, as when this is shown in a list + // compacted, it makes it more obviously control text. + let next = Paragraph::new( + stats + .next + .clone() + .map_or_else(|| "[No next command]".to_string(), |next| next.command), + ) + .block(if compact { + Block::new() + .borders(Borders::NONE) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + } else { + Block::new() + .borders(Borders::ALL) + .title("Next command") + .padding(Padding::horizontal(1)) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + }); + + f.render_widget(previous, commands[0]); + f.render_widget(command, commands[1]); + f.render_widget(next, commands[2]); +} + +pub fn draw_stats_table( + f: &mut Frame<'_>, + parent: Rect, + history: &History, + tz: Timezone, + stats: &HistoryStats, + theme: &Theme, +) { + let duration = Duration::from_nanos(u64_or_zero(history.duration)); + let avg_duration = Duration::from_nanos(stats.average_duration); + let (host, user) = history.hostname.split_once(':').unwrap_or(("", "")); + + let rows = [ + Row::new(vec!["Host".to_string(), host.to_string()]), + Row::new(vec!["User".to_string(), user.to_string()]), + Row::new(vec![ + "Time".to_string(), + history.timestamp.to_offset(tz.0).to_string(), + ]), + Row::new(vec!["Duration".to_string(), format_duration(duration)]), + Row::new(vec![ + "Avg duration".to_string(), + format_duration(avg_duration), + ]), + Row::new(vec!["Exit".to_string(), history.exit.to_string()]), + Row::new(vec!["Directory".to_string(), history.cwd.clone()]), + Row::new(vec!["Session".to_string(), history.session.clone()]), + Row::new(vec!["Total runs".to_string(), stats.total.to_string()]), + ]; + + let widths = [Constraint::Ratio(1, 5), Constraint::Ratio(4, 5)]; + + let table = Table::new(rows, widths).column_spacing(1).block( + Block::default() + .title("Command stats") + .borders(Borders::ALL) + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .padding(Padding::vertical(1)), + ); + + f.render_widget(table, parent); +} + +fn num_to_day(num: &str) -> String { + match num { + "0" => "Sunday".to_string(), + "1" => "Monday".to_string(), + "2" => "Tuesday".to_string(), + "3" => "Wednesday".to_string(), + "4" => "Thursday".to_string(), + "5" => "Friday".to_string(), + "6" => "Saturday".to_string(), + _ => "Invalid day".to_string(), + } +} + +fn sort_duration_over_time(durations: &[(String, i64)]) -> Vec<(String, i64)> { + let format = format_description!("[day]-[month]-[year]"); + let output = format_description!("[month]/[year repr:last_two]"); + + let mut durations: Vec<(time::Date, i64)> = durations + .iter() + .map(|d| { + ( + time::Date::parse(d.0.as_str(), &format).expect("invalid date string from sqlite"), + d.1, + ) + }) + .collect(); + + durations.sort_by_key(|a| a.0); + + durations + .iter() + .map(|(date, duration)| { + ( + date.format(output).expect("failed to format sqlite date"), + *duration, + ) + }) + .collect() +} + +fn draw_stats_charts(f: &mut Frame<'_>, parent: Rect, stats: &HistoryStats, theme: &Theme) { + let exits: Vec<Bar> = stats + .exits + .iter() + .map(|(exit, count)| { + Bar::default() + .label(exit.to_string()) + .value(u64_or_zero(*count)) + }) + .collect(); + + let exits = BarChart::default() + .block( + Block::default() + .title("Exit code distribution") + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .borders(Borders::ALL), + ) + .bar_width(3) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&exits)); + + let day_of_week: Vec<Bar> = stats + .day_of_week + .iter() + .map(|(day, count)| { + Bar::default() + .label(num_to_day(day.as_str())) + .value(u64_or_zero(*count)) + }) + .collect(); + + let day_of_week = BarChart::default() + .block( + Block::default() + .title("Runs per day") + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .borders(Borders::ALL), + ) + .bar_width(3) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&day_of_week)); + + let duration_over_time = sort_duration_over_time(&stats.duration_over_time); + let duration_over_time: Vec<Bar> = duration_over_time + .iter() + .map(|(date, duration)| { + let d = Duration::from_nanos(u64_or_zero(*duration)); + Bar::default() + .label(date.clone()) + .value(u64_or_zero(*duration)) + .text_value(format_duration(d)) + }) + .collect(); + + let duration_over_time = BarChart::default() + .block( + Block::default() + .title("Duration over time") + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .borders(Borders::ALL), + ) + .bar_width(5) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&duration_over_time)); + + let layout = Layout::default() + .direction(Direction::Vertical) + .constraints([ + Constraint::Ratio(1, 3), + Constraint::Ratio(1, 3), + Constraint::Ratio(1, 3), + ]) + .split(parent); + + f.render_widget(exits, layout[0]); + f.render_widget(day_of_week, layout[1]); + f.render_widget(duration_over_time, layout[2]); +} + +pub fn draw( + f: &mut Frame<'_>, + chunk: Rect, + history: &History, + stats: &HistoryStats, + settings: &Settings, + theme: &Theme, + tz: Timezone, +) { + let compactness = to_compactness(f, settings); + + match compactness { + Compactness::Ultracompact => draw_ultracompact(f, chunk, history, stats, theme), + _ => draw_full(f, chunk, history, stats, theme, tz), + } +} + +pub fn draw_ultracompact( + f: &mut Frame<'_>, + chunk: Rect, + history: &History, + stats: &HistoryStats, + theme: &Theme, +) { + draw_commands(f, chunk, history, stats, true, theme); +} + +pub fn draw_full( + f: &mut Frame<'_>, + chunk: Rect, + history: &History, + stats: &HistoryStats, + theme: &Theme, + tz: Timezone, +) { + let vert_layout = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Ratio(1, 5), Constraint::Ratio(4, 5)]) + .split(chunk); + + let stats_layout = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Ratio(1, 3), Constraint::Ratio(2, 3)]) + .split(vert_layout[1]); + + draw_commands(f, vert_layout[0], history, stats, false, theme); + draw_stats_table(f, stats_layout[0], history, tz, stats, theme); + draw_stats_charts(f, stats_layout[1], stats, theme); +} + +#[cfg(test)] +mod tests { + use super::draw_ultracompact; + use crate::atuin_client::{ + history::{History, HistoryId, HistoryStats}, + theme::ThemeManager, + }; + use ratatui::{backend::TestBackend, prelude::*}; + use time::OffsetDateTime; + + fn mock_history_stats() -> (History, HistoryStats) { + let history = History { + id: HistoryId::from("test1".to_string()), + timestamp: OffsetDateTime::now_utc(), + duration: 3, + exit: 0, + command: "/bin/cmd".to_string(), + cwd: "/toot".to_string(), + session: "sesh1".to_string(), + hostname: "hostn".to_string(), + author: "hostn".to_string(), + intent: None, + deleted_at: None, + }; + let next = History { + id: HistoryId::from("test2".to_string()), + timestamp: OffsetDateTime::now_utc(), + duration: 2, + exit: 0, + command: "/bin/cmd -os".to_string(), + cwd: "/toot".to_string(), + session: "sesh1".to_string(), + hostname: "hostn".to_string(), + author: "hostn".to_string(), + intent: None, + deleted_at: None, + }; + let prev = History { + id: HistoryId::from("test3".to_string()), + timestamp: OffsetDateTime::now_utc(), + duration: 1, + exit: 0, + command: "/bin/cmd -a".to_string(), + cwd: "/toot".to_string(), + session: "sesh1".to_string(), + hostname: "hostn".to_string(), + author: "hostn".to_string(), + intent: None, + deleted_at: None, + }; + let stats = HistoryStats { + next: Some(next.clone()), + previous: Some(prev.clone()), + total: 2, + average_duration: 3, + exits: Vec::new(), + day_of_week: Vec::new(), + duration_over_time: Vec::new(), + }; + (history, stats) + } + + #[test] + fn test_output_looks_correct_for_ultracompact() { + let backend = TestBackend::new(22, 5); + let mut terminal = Terminal::new(backend).expect("Could not create terminal"); + let chunk = Rect::new(0, 0, 22, 5); + let (history, stats) = mock_history_stats(); + let prev = stats.previous.clone().unwrap(); + let next = stats.next.clone().unwrap(); + + let mut manager = ThemeManager::new(Some(true), Some("".to_string())); + let theme = manager.load_theme("(none)", None); + let _ = terminal.draw(|f| draw_ultracompact(f, chunk, &history, &stats, &theme)); + let mut lines = [" "; 5].map(|l| Line::from(l)); + for (n, entry) in [prev, history, next].iter().enumerate() { + let mut l = lines[n].to_string(); + l.replace_range(0..entry.command.len(), &entry.command); + lines[n] = Line::from(l); + } + + terminal.backend().assert_buffer_lines(lines); + } +} diff --git a/crates/turtle/src/command/client/search/interactive.rs b/crates/turtle/src/command/client/search/interactive.rs new file mode 100644 index 00000000..a3d2cb79 --- /dev/null +++ b/crates/turtle/src/command/client/search/interactive.rs @@ -0,0 +1,3041 @@ +use std::{ + io::{IsTerminal, Write, stdout}, + time::Duration, +}; + +#[cfg(unix)] +use std::io::Read as _; + +use crate::atuin_common::{shell::Shell, utils::Escapable as _}; +use eyre::Result; +use time::OffsetDateTime; +use unicode_width::{UnicodeWidthChar, UnicodeWidthStr}; + +use super::{ + cursor::Cursor, + engines::{SearchEngine, SearchState}, + history_list::{HistoryList, ListState}, +}; +use crate::atuin_client::{ + database::{Context, Database, current_context}, + history::{History, HistoryId, HistoryStats, store::HistoryStore}, + settings::{ + CursorStyle, ExitMode, FilterMode, KeymapMode, PreviewStrategy, SearchMode, Settings, + UiColumn, + }, +}; + +use crate::command::client::search::history_list::HistoryHighlighter; +use crate::command::client::search::keybindings::KeymapSet; +use crate::command::client::theme::{Meaning, Theme}; +use crate::{VERSION, command::client::search::engines}; + +use ratatui::{ + Frame, Terminal, TerminalOptions, Viewport, + backend::{CrosstermBackend, FromCrossterm}, + crossterm::{ + cursor::SetCursorStyle, + event::{self, Event, KeyEvent, MouseEvent}, + execute, queue, terminal, + }, + layout::{Alignment, Constraint, Direction, Layout}, + prelude::*, + style::{Modifier, Style}, + text::{Line, Span, Text}, + widgets::{Block, BorderType, Borders, Clear, Padding, Paragraph, Tabs}, +}; + +#[cfg(not(target_os = "windows"))] +use ratatui::crossterm::event::{ + KeyboardEnhancementFlags, PopKeyboardEnhancementFlags, PushKeyboardEnhancementFlags, +}; + +const TAB_TITLES: [&str; 2] = ["Search", "Inspect"]; + +pub enum InputAction { + Accept(usize), + AcceptInspecting, + Copy(usize), + Delete(usize), + DeleteAllMatching(usize), + ReturnOriginal, + ReturnQuery, + Continue, + Redraw, + SwitchContext(Option<usize>), +} + +#[derive(Clone)] +pub struct InspectingState { + current: Option<HistoryId>, + next: Option<HistoryId>, + previous: Option<HistoryId>, +} + +impl InspectingState { + pub fn move_to_previous(&mut self) { + let previous = self.previous.clone(); + self.reset(); + self.current = previous; + } + + pub fn move_to_next(&mut self) { + let next = self.next.clone(); + self.reset(); + self.current = next; + } + + pub fn reset(&mut self) { + self.current = None; + self.next = None; + self.previous = None; + } +} + +pub fn to_compactness(f: &Frame, settings: &Settings) -> Compactness { + if match settings.style { + crate::atuin_client::settings::Style::Auto => f.area().height < 14, + crate::atuin_client::settings::Style::Compact => true, + crate::atuin_client::settings::Style::Full => false, + } { + if settings.auto_hide_height != 0 && f.area().height <= settings.auto_hide_height { + Compactness::Ultracompact + } else { + Compactness::Compact + } + } else { + Compactness::Full + } +} + +#[expect(clippy::struct_field_names)] +#[expect(clippy::struct_excessive_bools)] +pub struct State { + history_count: i64, + results_state: ListState, + switched_search_mode: bool, + search_mode: SearchMode, + results_len: usize, + accept: bool, + keymap_mode: KeymapMode, + prefix: bool, + current_cursor: Option<CursorStyle>, + tab_index: usize, + pending_vim_key: Option<char>, + original_input_empty: bool, + + pub inspecting_state: InspectingState, + + keymaps: KeymapSet, + search: SearchState, + engine: Box<dyn SearchEngine>, + now: Box<dyn Fn() -> OffsetDateTime + Send>, +} + +#[derive(Clone, Copy)] +pub enum Compactness { + Ultracompact, + Compact, + Full, +} + +#[derive(Clone, Copy)] +struct StyleState { + compactness: Compactness, + invert: bool, + inner_width: usize, +} + +impl State { + async fn query_results( + &mut self, + db: &mut dyn Database, + smart_sort: bool, + ) -> Result<Vec<History>> { + let results = self.engine.query(&self.search, db).await?; + + self.inspecting_state = InspectingState { + current: None, + next: None, + previous: None, + }; + self.results_state.select(0); + self.results_len = results.len(); + + if smart_sort { + Ok(crate::atuin_history::sort::sort( + self.search.input.as_str(), + results, + )) + } else { + Ok(results) + } + } + + fn handle_input(&mut self, settings: &Settings, input: &Event) -> InputAction { + match input { + Event::Key(k) => self.handle_key_input(settings, k), + Event::Mouse(m) => self.handle_mouse_input(*m, settings.invert), + Event::Paste(d) => self.handle_paste_input(d), + _ => InputAction::Continue, + } + } + + fn handle_mouse_input(&mut self, input: MouseEvent, inverted: bool) -> InputAction { + match (input.kind, inverted) { + (event::MouseEventKind::ScrollDown, false) + | (event::MouseEventKind::ScrollUp, true) => { + self.scroll_down(1); + } + (event::MouseEventKind::ScrollDown, true) + | (event::MouseEventKind::ScrollUp, false) => { + self.scroll_up(1); + } + _ => {} + } + InputAction::Continue + } + + fn handle_paste_input(&mut self, input: &str) -> InputAction { + for i in input.chars() { + self.search.input.insert(i); + } + InputAction::Continue + } + + fn cast_cursor_style(style: CursorStyle) -> SetCursorStyle { + match style { + CursorStyle::DefaultUserShape => SetCursorStyle::DefaultUserShape, + CursorStyle::BlinkingBlock => SetCursorStyle::BlinkingBlock, + CursorStyle::SteadyBlock => SetCursorStyle::SteadyBlock, + CursorStyle::BlinkingUnderScore => SetCursorStyle::BlinkingUnderScore, + CursorStyle::SteadyUnderScore => SetCursorStyle::SteadyUnderScore, + CursorStyle::BlinkingBar => SetCursorStyle::BlinkingBar, + CursorStyle::SteadyBar => SetCursorStyle::SteadyBar, + } + } + + fn set_keymap_cursor(&mut self, settings: &Settings, keymap_name: &str) { + let cursor_style = if keymap_name == "__clear__" { + None + } else { + settings.keymap_cursor.get(keymap_name).copied() + } + .or_else(|| self.current_cursor.map(|_| CursorStyle::DefaultUserShape)); + + if cursor_style != self.current_cursor + && let Some(style) = cursor_style + { + self.current_cursor = cursor_style; + let _ = execute!(stdout(), Self::cast_cursor_style(style)); + } + } + + pub fn initialize_keymap_cursor(&mut self, settings: &Settings) { + match self.keymap_mode { + KeymapMode::Emacs => self.set_keymap_cursor(settings, "emacs"), + KeymapMode::VimNormal => self.set_keymap_cursor(settings, "vim_normal"), + KeymapMode::VimInsert => self.set_keymap_cursor(settings, "vim_insert"), + KeymapMode::Auto => {} + } + } + + pub fn finalize_keymap_cursor(&mut self, settings: &Settings) { + match settings.keymap_mode_shell { + KeymapMode::Emacs => self.set_keymap_cursor(settings, "emacs"), + KeymapMode::VimNormal => self.set_keymap_cursor(settings, "vim_normal"), + KeymapMode::VimInsert => self.set_keymap_cursor(settings, "vim_insert"), + KeymapMode::Auto => self.set_keymap_cursor(settings, "__clear__"), + } + } + + fn handle_key_exit(settings: &Settings) -> InputAction { + match settings.exit_mode { + ExitMode::ReturnOriginal => InputAction::ReturnOriginal, + ExitMode::ReturnQuery => InputAction::ReturnQuery, + } + } + + /// Select the keymap for the current mode (ignoring prefix). + fn mode_keymap(&self) -> &super::keybindings::Keymap { + if self.tab_index == 1 { + &self.keymaps.inspector + } else { + match self.keymap_mode { + KeymapMode::Emacs | KeymapMode::Auto => &self.keymaps.emacs, + KeymapMode::VimNormal => &self.keymaps.vim_normal, + KeymapMode::VimInsert => &self.keymaps.vim_insert, + } + } + } + + /// Whether the current mode supports character insertion on unmatched keys. + fn is_insert_mode(&self) -> bool { + matches!( + self.keymap_mode, + KeymapMode::Emacs | KeymapMode::Auto | KeymapMode::VimInsert + ) + } + + fn handle_key_input(&mut self, settings: &Settings, input: &KeyEvent) -> InputAction { + use super::keybindings::Action; + use super::keybindings::EvalContext; + use super::keybindings::key::{KeyCodeValue, KeyInput, SingleKey}; + + // Skip release events + if input.kind == event::KeyEventKind::Release { + return InputAction::Continue; + } + + // Reset switched_search_mode at start of each key event + self.switched_search_mode = false; + + // Build evaluation context from current state + let ctx = EvalContext { + cursor_position: self.search.input.position(), + input_width: UnicodeWidthStr::width(self.search.input.as_str()), + input_byte_len: self.search.input.as_str().len(), + selected_index: self.results_state.selected(), + results_len: self.results_len, + original_input_empty: self.original_input_empty, + has_context: self.search.custom_context.is_some(), + }; + + // Convert KeyEvent to SingleKey + let Some(single) = SingleKey::from_event(input) else { + return InputAction::Continue; + }; + + // --- Phase 1: Resolve (take pending key first, then immutable borrows) --- + + // Take pending key before any immutable borrows of self + let pending = self.pending_vim_key.take(); + + // If in prefix mode, try prefix keymap first (single keys only) + let prefix_action = if self.prefix { + let ki = KeyInput::Single(single.clone()); + self.keymaps.prefix.resolve(&ki, &ctx) + } else { + None + }; + + // The if-let/else-if chain here is clearer than map_or_else with nested closures. + #[expect(clippy::option_if_let_else)] + let (action, new_pending) = if prefix_action.is_some() { + (prefix_action, None) + } else { + // Use mode keymap (handles both single and multi-key sequences) + let keymap = self.mode_keymap(); + + if let Some(pending_char) = pending { + // We have a pending key from a previous press (e.g., first 'g' of 'gg') + let pending_single = SingleKey { + code: KeyCodeValue::Char(pending_char), + ctrl: false, + alt: false, + shift: false, + super_key: false, + }; + let seq = KeyInput::Sequence(vec![pending_single, single.clone()]); + let action = keymap + .resolve(&seq, &ctx) + .or_else(|| keymap.resolve(&KeyInput::Single(single.clone()), &ctx)); + (action, None) + } else if keymap.has_sequence_starting_with(&single) + && matches!(single.code, KeyCodeValue::Char(_)) + && !single.ctrl + && !single.alt + { + // This key starts a multi-key sequence; wait for next key + let KeyCodeValue::Char(c) = single.code else { + unreachable!() + }; + (Some(Action::Noop), Some(c)) + } else { + ( + keymap.resolve(&KeyInput::Single(single.clone()), &ctx), + None, + ) + } + }; + + // --- Phase 2: Apply mutations --- + self.pending_vim_key = new_pending; + + // Reset prefix (before execute, so EnterPrefixMode can re-set it) + self.prefix = false; + + if let Some(action) = action { + self.execute_action(&action, settings) + } else { + // No action matched. In insert-capable modes, insert the character. + if self.is_insert_mode() && !single.ctrl && !single.alt { + match single.code { + KeyCodeValue::Char(c) => { + self.search.input.insert(c); + } + KeyCodeValue::Space => { + self.search.input.insert(' '); + } + _ => {} + } + } + InputAction::Continue + } + } + + fn scroll_down(&mut self, scroll_len: usize) { + let i = self.results_state.selected().saturating_sub(scroll_len); + self.inspecting_state.reset(); + self.results_state.select(i); + } + + fn scroll_up(&mut self, scroll_len: usize) { + let i = self.results_state.selected() + scroll_len; + self.results_state + .select(i.min(self.results_len.saturating_sub(1))); + self.inspecting_state.reset(); + } + + /// Execute a resolved action, performing all side effects and returning the + /// appropriate `InputAction` for the event loop. + /// + /// This is the "do it" half of the resolve+execute pipeline. The resolver + /// decides *what* to do (which `Action`), and this function carries it out. + /// + /// Invert handling: scroll actions (`SelectNext`, `ScrollPageDown`, etc.) account + /// for `settings.invert` so that keybindings are always in "visual" terms — + /// users never need to think about invert in their keybinding config. + #[expect(clippy::too_many_lines)] + pub(crate) fn execute_action( + &mut self, + action: &super::keybindings::Action, + settings: &Settings, + ) -> InputAction { + use crate::command::client::search::keybindings::Action; + + match action { + // -- Cursor movement -- + Action::CursorLeft => { + self.search.input.left(); + InputAction::Continue + } + Action::CursorRight => { + self.search.input.right(); + InputAction::Continue + } + Action::CursorWordLeft => { + self.search + .input + .prev_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::CursorWordRight => { + self.search + .input + .next_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::CursorWordEnd => { + self.search.input.word_end(&settings.word_chars); + InputAction::Continue + } + Action::CursorStart => { + self.search.input.start(); + InputAction::Continue + } + Action::CursorEnd => { + self.search.input.end(); + InputAction::Continue + } + + // -- Editing -- + Action::DeleteCharBefore => { + self.search.input.back(); + InputAction::Continue + } + Action::DeleteCharAfter => { + self.search.input.remove(); + InputAction::Continue + } + Action::DeleteWordBefore => { + self.search + .input + .remove_prev_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::DeleteWordAfter => { + self.search + .input + .remove_next_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::DeleteToWordBoundary => { + // ctrl-w: remove trailing whitespace, then delete to word boundary + while matches!(self.search.input.back(), Some(c) if c.is_whitespace()) {} + while self.search.input.left() { + if self.search.input.char().unwrap().is_whitespace() { + self.search.input.right(); + break; + } + self.search.input.remove(); + } + InputAction::Continue + } + Action::ClearLine => { + self.search.input.clear(); + InputAction::Continue + } + Action::ClearToStart => { + self.search.input.clear_to_start(); + InputAction::Continue + } + Action::ClearToEnd => { + self.search.input.clear_to_end(); + InputAction::Continue + } + + // -- List navigation (invert-aware) -- + Action::SelectNext => { + if settings.invert { + self.scroll_up(1); + } else { + self.scroll_down(1); + } + InputAction::Continue + } + Action::SelectPrevious => { + if settings.invert { + self.scroll_down(1); + } else { + self.scroll_up(1); + } + InputAction::Continue + } + // -- Page/half-page scroll (invert-aware) -- + Action::ScrollHalfPageUp => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines) + / 2; + if settings.invert { + self.scroll_down(scroll_len); + } else { + self.scroll_up(scroll_len); + } + InputAction::Continue + } + Action::ScrollHalfPageDown => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines) + / 2; + if settings.invert { + self.scroll_up(scroll_len); + } else { + self.scroll_down(scroll_len); + } + InputAction::Continue + } + Action::ScrollPageUp => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines); + if settings.invert { + self.scroll_down(scroll_len); + } else { + self.scroll_up(scroll_len); + } + InputAction::Continue + } + Action::ScrollPageDown => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines); + if settings.invert { + self.scroll_up(scroll_len); + } else { + self.scroll_down(scroll_len); + } + InputAction::Continue + } + + // -- Absolute jumps (invert-aware) -- + Action::ScrollToTop => { + // Visual top of history + if settings.invert { + self.results_state.select(0); + } else { + let last_idx = self.results_len.saturating_sub(1); + self.results_state.select(last_idx); + } + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToBottom => { + // Visual bottom of history + if settings.invert { + let last_idx = self.results_len.saturating_sub(1); + self.results_state.select(last_idx); + } else { + self.results_state.select(0); + } + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToScreenTop => { + // H — jump to top of visible screen + let top = self.results_state.offset(); + let visible = self.results_state.max_entries().min(self.results_len); + let bottom = top + visible.saturating_sub(1); + self.results_state + .select(bottom.min(self.results_len.saturating_sub(1))); + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToScreenMiddle => { + // M — jump to middle of visible screen + let top = self.results_state.offset(); + let visible = self.results_state.max_entries().min(self.results_len); + let middle = top + visible / 2; + self.results_state + .select(middle.min(self.results_len.saturating_sub(1))); + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToScreenBottom => { + // L — jump to bottom of visible screen + let top_visible = self.results_state.offset(); + self.results_state.select(top_visible); + self.inspecting_state.reset(); + InputAction::Continue + } + + // -- Commands -- + Action::Accept => { + if self.tab_index == 1 { + return InputAction::AcceptInspecting; + } + self.accept = true; + InputAction::Accept(self.results_state.selected()) + } + Action::AcceptNth(n) => { + self.accept = true; + InputAction::Accept(self.results_state.selected() + *n as usize) + } + Action::ReturnSelection => { + if self.tab_index == 1 { + return InputAction::AcceptInspecting; + } + InputAction::Accept(self.results_state.selected()) + } + Action::ReturnSelectionNth(n) => { + InputAction::Accept(self.results_state.selected() + *n as usize) + } + Action::Copy => InputAction::Copy(self.results_state.selected()), + Action::Delete => InputAction::Delete(self.results_state.selected()), + Action::DeleteAll => InputAction::DeleteAllMatching(self.results_state.selected()), + Action::ReturnOriginal => InputAction::ReturnOriginal, + Action::ReturnQuery => InputAction::ReturnQuery, + Action::Exit => Self::handle_key_exit(settings), + Action::Redraw => InputAction::Redraw, + Action::CycleFilterMode => { + self.search.rotate_filter_mode(settings, 1); + InputAction::Continue + } + Action::CycleSearchMode => { + self.switched_search_mode = true; + self.search_mode = self.search_mode.next(settings); + self.engine = engines::engine(self.search_mode, settings); + InputAction::Continue + } + Action::SwitchContext => { + InputAction::SwitchContext(Some(self.results_state.selected())) + } + Action::ClearContext => InputAction::SwitchContext(None), + Action::ToggleTab => { + self.tab_index = (self.tab_index + 1) % TAB_TITLES.len(); + InputAction::Continue + } + + // -- Mode changes -- + Action::VimEnterNormal => { + self.set_keymap_cursor(settings, "vim_normal"); + self.keymap_mode = KeymapMode::VimNormal; + InputAction::Continue + } + Action::VimEnterInsert => { + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimEnterInsertAfter => { + self.search.input.right(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimEnterInsertAtStart => { + self.search.input.start(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimEnterInsertAtEnd => { + self.search.input.end(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimSearchInsert => { + self.search.input.clear(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimChangeToEnd => { + self.search.input.clear_to_end(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::EnterPrefixMode => { + self.prefix = true; + InputAction::Continue + } + + // -- Inspector -- + Action::InspectPrevious => { + self.inspecting_state.move_to_previous(); + InputAction::Redraw + } + Action::InspectNext => { + self.inspecting_state.move_to_next(); + InputAction::Redraw + } + + // -- Special -- + Action::Noop => InputAction::Continue, + } + } + + #[expect(clippy::cast_possible_truncation)] + #[expect(clippy::bool_to_int_with_if)] + fn calc_preview_height( + settings: &Settings, + results: &[History], + selected: usize, + tab_index: usize, + compactness: Compactness, + border_size: u16, + preview_width: u16, + ) -> u16 { + if settings.show_preview + && settings.preview.strategy == PreviewStrategy::Auto + && tab_index == 0 + && !results.is_empty() + { + let length_current_cmd = results[selected].command.len() as u16; + // calculate the number of newlines in the command + let num_newlines = results[selected] + .command + .chars() + .filter(|&c| c == '\n') + .count() as u16; + if num_newlines > 0 { + std::cmp::min( + settings.max_preview_height, + results[selected] + .command + .split('\n') + .map(|line| { + (line.len() as u16 + preview_width - 1 - border_size) + / (preview_width - border_size) + }) + .sum(), + ) + border_size * 2 + } + // The '- 19' takes the characters before the command (duration and time) into account + else if length_current_cmd > preview_width - 19 { + std::cmp::min( + settings.max_preview_height, + (length_current_cmd + preview_width - 1 - border_size) + / (preview_width - border_size), + ) + border_size * 2 + } else { + 1 + } + } else if settings.show_preview + && settings.preview.strategy == PreviewStrategy::Static + && tab_index == 0 + { + let longest_command = results + .iter() + .max_by(|h1, h2| h1.command.len().cmp(&h2.command.len())); + longest_command.map_or(0, |v| { + std::cmp::min( + settings.max_preview_height, + v.command + .split('\n') + .map(|line| { + (line.len() as u16 + preview_width - 1 - border_size) + / (preview_width - border_size) + }) + .sum(), + ) + }) + border_size * 2 + } else if settings.show_preview && settings.preview.strategy == PreviewStrategy::Fixed { + settings.max_preview_height + border_size * 2 + } else if !matches!(compactness, Compactness::Full) || tab_index == 1 { + 0 + } else { + 1 + } + } + + #[expect(clippy::bool_to_int_with_if)] + #[expect(clippy::too_many_lines)] + #[expect(clippy::too_many_arguments)] + fn draw( + &mut self, + f: &mut Frame, + results: &[History], + stats: Option<HistoryStats>, + inspecting: Option<&History>, + settings: &Settings, + theme: &Theme, + popup_mode: bool, + ) { + let area = f.area(); + if popup_mode { + f.render_widget(Clear, area); + } + self.draw_inner(f, area, results, stats, inspecting, settings, theme); + } + + #[expect(clippy::too_many_arguments)] + #[expect(clippy::too_many_lines)] + #[expect(clippy::bool_to_int_with_if)] + fn draw_inner( + &mut self, + f: &mut Frame, + area: Rect, + results: &[History], + stats: Option<HistoryStats>, + inspecting: Option<&History>, + settings: &Settings, + theme: &Theme, + ) { + let compactness = to_compactness(f, settings); + let invert = settings.invert; + let border_size = match compactness { + Compactness::Full => 1, + _ => 0, + }; + let preview_width = area.width.saturating_sub(2); + let preview_height = Self::calc_preview_height( + settings, + results, + self.results_state.selected(), + self.tab_index, + compactness, + border_size, + preview_width, + ); + let show_help = + settings.show_help && (matches!(compactness, Compactness::Full) || area.height > 1); + // This is an OR, as it seems more likely for someone to wish to override + // tabs unexpectedly being missed, than unexpectedly present. + let show_tabs = settings.show_tabs && !matches!(compactness, Compactness::Ultracompact); + let chunks = Layout::default() + .direction(Direction::Vertical) + .margin(0) + .horizontal_margin(1) + .constraints::<&[Constraint]>( + if invert { + [ + Constraint::Length(1 + border_size), // input + Constraint::Min(1), // results list + Constraint::Length(preview_height), // preview + Constraint::Length(if show_tabs { 1 } else { 0 }), // tabs + Constraint::Length(if show_help { 1 } else { 0 }), // header (sic) + ] + } else { + match compactness { + Compactness::Ultracompact => [ + Constraint::Length(if show_help { 1 } else { 0 }), // header + Constraint::Length(0), // tabs + Constraint::Min(1), // results list + Constraint::Length(0), + Constraint::Length(0), + ], + _ => [ + Constraint::Length(if show_help { 1 } else { 0 }), // header + Constraint::Length(if show_tabs { 1 } else { 0 }), // tabs + Constraint::Min(1), // results list + Constraint::Length(1 + border_size), // input + Constraint::Length(preview_height), // preview + ], + } + } + .as_ref(), + ) + .split(area); + + let input_chunk = if invert { chunks[0] } else { chunks[3] }; + let results_list_chunk = if invert { chunks[1] } else { chunks[2] }; + let preview_chunk = if invert { chunks[2] } else { chunks[4] }; + let tabs_chunk = if invert { chunks[3] } else { chunks[1] }; + let header_chunk = if invert { chunks[4] } else { chunks[0] }; + + // TODO: this should be split so that we have one interactive search container that is + // EITHER a search box or an inspector. But I'm not doing that now, way too much atm. + // also allocate less 🙈 + let titles: Vec<_> = TAB_TITLES.iter().copied().map(Line::from).collect(); + + if show_tabs { + let tabs = Tabs::new(titles) + .block(Block::default().borders(Borders::NONE)) + .select(self.tab_index) + .style(Style::default()) + .highlight_style(Style::from_crossterm(theme.as_style(Meaning::Important))); + + f.render_widget(tabs, tabs_chunk); + } + + let style = StyleState { + compactness, + invert, + inner_width: input_chunk.width.into(), + }; + + let header_chunks = Layout::default() + .direction(Direction::Horizontal) + .constraints::<&[Constraint]>( + [ + Constraint::Ratio(1, 5), + Constraint::Ratio(3, 5), + Constraint::Ratio(1, 5), + ] + .as_ref(), + ) + .split(header_chunk); + + let title = Self::build_title(theme); + f.render_widget(title, header_chunks[0]); + + let help = self.build_help(settings, theme); + f.render_widget(help, header_chunks[1]); + + let stats_tab = self.build_stats(theme); + f.render_widget(stats_tab, header_chunks[2]); + + let indicator: String = match compactness { + Compactness::Ultracompact => { + if self.switched_search_mode { + format!("S{}>", self.search_mode.as_str().chars().next().unwrap()) + } else if self.search.custom_context.is_some() { + format!( + "C{}>", + self.search.filter_mode.as_str().chars().next().unwrap() + ) + } else { + format!( + "{}> ", + self.search.filter_mode.as_str().chars().next().unwrap() + ) + } + } + _ => " > ".to_string(), + }; + + match self.tab_index { + 0 => { + let history_highlighter = HistoryHighlighter { + engine: self.engine.as_ref(), + search_input: self.search.input.as_str(), + }; + let results_list = Self::build_results_list( + style, + results, + self.keymap_mode, + &self.now, + indicator.as_str(), + theme, + history_highlighter, + settings.show_numeric_shortcuts, + &settings.ui.columns, + ); + f.render_stateful_widget(results_list, results_list_chunk, &mut self.results_state); + } + + 1 => { + if results.is_empty() { + let message = Paragraph::new("Nothing to inspect") + .block( + Block::new() + .title(Line::from(" Info ".to_string())) + .title_alignment(Alignment::Center) + .borders(Borders::ALL) + .padding(Padding::vertical(2)), + ) + .alignment(Alignment::Center); + f.render_widget(message, results_list_chunk); + } else { + let inspecting = match inspecting { + Some(inspecting) => inspecting, + None => &results[self.results_state.selected()], + }; + super::inspector::draw( + f, + results_list_chunk, + inspecting, + &stats.expect("Drawing inspector, but no stats"), + settings, + theme, + settings.timezone, + ); + } + + // HACK: I'm following up with abstracting this into the UI container, with a + // sub-widget for search + for inspector + let feedback = Paragraph::new( + "The inspector is new - please give feedback (good, or bad) at https://forum.atuin.sh", + ); + f.render_widget(feedback, input_chunk); + + return; + } + + _ => { + panic!("invalid tab index"); + } + } + + if !matches!(compactness, Compactness::Ultracompact) { + let preview_width = match compactness { + Compactness::Full => preview_width - 2, + _ => preview_width, + }; + let preview = self.build_preview( + results, + compactness, + preview_width, + preview_chunk.width.into(), + theme, + ); + #[expect(clippy::cast_possible_truncation)] + let prefix_width = settings + .ui + .columns + .iter() + .take_while(|col| !col.expand) + .map(|col| col.width + 1) + .sum::<u16>() + + " > ".len() as u16; + #[expect(clippy::cast_possible_truncation)] + let min_prefix_width = "[ SRCH: FULLTXT ] ".len() as u16; + self.draw_preview( + f, + style, + input_chunk, + compactness, + preview_chunk, + preview, + std::cmp::max(prefix_width, min_prefix_width), + ); + } + } + + #[expect(clippy::cast_possible_truncation, clippy::too_many_arguments)] + fn draw_preview( + &self, + f: &mut Frame, + style: StyleState, + input_chunk: Rect, + compactness: Compactness, + preview_chunk: Rect, + preview: Paragraph, + prefix_width: u16, + ) { + let input = self.build_input(style, prefix_width); + f.render_widget(input, input_chunk); + + f.render_widget(preview, preview_chunk); + + let extra_width = UnicodeWidthStr::width(self.search.input.substring()); + + let cursor_offset = match compactness { + Compactness::Full => 1, + _ => 0, + }; + f.set_cursor_position(( + // Put cursor past the end of the input text + input_chunk.x + extra_width as u16 + prefix_width + cursor_offset, + input_chunk.y + cursor_offset, + )); + } + + fn build_title(theme: &Theme) -> Paragraph<'_> { + let title = { + let style: Style = Style::from_crossterm(theme.as_style(Meaning::Base)); + Paragraph::new(Text::from(Span::styled( + format!("Atuin v{VERSION}"), + style.add_modifier(Modifier::BOLD), + ))) + }; + title.alignment(Alignment::Left) + } + + #[expect(clippy::unused_self)] + fn build_help(&self, settings: &Settings, theme: &Theme) -> Paragraph<'_> { + match self.tab_index { + // search + 0 => Paragraph::new(Text::from(Line::from(vec![ + Span::styled("<esc>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": exit"), + Span::raw(", "), + Span::styled("<tab>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": edit"), + Span::raw(", "), + Span::styled("<enter>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(if settings.enter_accept { + ": run" + } else { + ": edit" + }), + Span::raw(", "), + Span::styled("<ctrl-o>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": inspect"), + ]))), + + 1 => Paragraph::new(Text::from(Line::from(vec![ + Span::styled("<esc>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": exit"), + Span::raw(", "), + Span::styled("<ctrl-o>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": search"), + Span::raw(", "), + Span::styled("<ctrl-d>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": delete"), + ]))), + + _ => unreachable!("invalid tab index"), + } + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + .alignment(Alignment::Center) + } + + fn build_stats(&self, theme: &Theme) -> Paragraph<'_> { + Paragraph::new(Text::from(Span::raw(format!( + "history count: {}", + self.history_count, + )))) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + .alignment(Alignment::Right) + } + + #[expect(clippy::too_many_arguments)] + fn build_results_list<'a>( + style: StyleState, + results: &'a [History], + keymap_mode: KeymapMode, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + columns: &'a [UiColumn], + ) -> HistoryList<'a> { + let results_list = HistoryList::new( + results, + style.invert, + keymap_mode == KeymapMode::VimNormal, + now, + indicator, + theme, + history_highlighter, + show_numeric_shortcuts, + columns, + ); + + match style.compactness { + Compactness::Full => { + if style.invert { + results_list.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = style.inner_width - 2)), + ) + } else { + results_list.block( + Block::default() + .borders(Borders::TOP | Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded), + ) + } + } + _ => results_list, + } + } + + fn build_input(&self, style: StyleState, prefix_width: u16) -> Paragraph<'_> { + let (pref, mode) = if self.switched_search_mode { + (" SRCH:", self.search_mode.as_str()) + } else if self.search.custom_context.is_some() { + (" CTX:", self.search.filter_mode.as_str()) + } else { + ("", self.search.filter_mode.as_str()) + }; + // 3: surrounding "[" "] " + let mode_width = usize::from(prefix_width) - pref.len() - 3; + // sanity check to ensure we don't exceed the layout limits + debug_assert!(mode_width >= mode.len(), "mode name '{mode}' is too long!"); + let input = format!("[{pref}{mode:^mode_width$}] {}", self.search.input.as_str()); + let input = Paragraph::new(input); + match style.compactness { + Compactness::Full => { + if style.invert { + input.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT | Borders::TOP) + .border_type(BorderType::Rounded), + ) + } else { + input.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = style.inner_width - 2)), + ) + } + } + _ => input, + } + } + + fn build_preview( + &self, + results: &[History], + compactness: Compactness, + preview_width: u16, + chunk_width: usize, + theme: &Theme, + ) -> Paragraph<'_> { + let selected = self.results_state.selected(); + let command = if results.is_empty() { + String::new() + } else { + let s = &results[selected].command; + let mut lines = Vec::new(); + for line in s.split('\n') { + let line = line.escape_control(); + let mut width = 0; + let mut start = 0; + for (idx, ch) in line.char_indices() { + let w = ch.width().unwrap_or(0); // None for control chars which should not happen + if width + w > preview_width.into() { + lines.push(line[start..idx].to_owned()); + start = idx; + width = w; + } else { + width += w; + } + } + if width != 0 { + lines.push(line[start..].to_owned()); + } + } + lines.join("\n") + }; + + match compactness { + Compactness::Full => Paragraph::new(command).block( + Block::default() + .borders(Borders::BOTTOM | Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = chunk_width - 2)), + ), + _ => Paragraph::new(command) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))), + } + } +} + +/// The writer used for terminal output - either stdout or /dev/tty +enum TerminalWriter { + Stdout(std::io::Stdout), + #[cfg(unix)] + Tty(std::fs::File), +} + +impl TerminalWriter { + fn new() -> std::io::Result<Self> { + let stdout = stdout(); + if stdout.is_terminal() { + return Ok(TerminalWriter::Stdout(stdout)); + } + + // If stdout is not a terminal (e.g., captured by command substitution), + // fall back to /dev/tty so the TUI can still render. + // This allows usage like: VAR=$(atuin search -i) + #[cfg(unix)] + { + Ok(TerminalWriter::Tty( + std::fs::File::options() + .read(true) + .write(true) + .open("/dev/tty")?, + )) + } + } +} + +impl Write for TerminalWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { + match self { + TerminalWriter::Stdout(stdout) => stdout.write(buf), + #[cfg(unix)] + TerminalWriter::Tty(file) => file.write(buf), + } + } + + fn flush(&mut self) -> std::io::Result<()> { + match self { + TerminalWriter::Stdout(stdout) => stdout.flush(), + #[cfg(unix)] + TerminalWriter::Tty(file) => file.flush(), + } + } +} + +/// Screen state captured from atuin pty-proxy's screen server. +#[cfg(unix)] +struct SavedScreen { + #[expect(dead_code)] + rows: u16, + #[expect(dead_code)] + cols: u16, + cursor_row: u16, + cursor_col: u16, + /// Pre-formatted ANSI bytes for each screen row, ready to write to stdout. + rows_data: Vec<Vec<u8>>, +} + +/// Connect to atuin pty-proxy's Unix socket and fetch the current screen state. +/// +/// The wire format is: +/// ```text +/// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] +/// [row_0_len: u32 BE][row_0_bytes...] +/// [row_1_len: u32 BE][row_1_bytes...] +/// ... +/// ``` +#[cfg(unix)] +fn fetch_screen_state(socket_path: &str) -> Option<SavedScreen> { + use std::os::unix::net::UnixStream; + + let mut stream = UnixStream::connect(socket_path).ok()?; + stream.set_read_timeout(Some(Duration::from_secs(2))).ok()?; + + let mut data = Vec::new(); + stream.read_to_end(&mut data).ok()?; + + if data.len() < 8 { + return None; + } + + let rows = u16::from_be_bytes([data[0], data[1]]); + let cols = u16::from_be_bytes([data[2], data[3]]); + let cursor_row = u16::from_be_bytes([data[4], data[5]]); + let cursor_col = u16::from_be_bytes([data[6], data[7]]); + + // Parse length-prefixed rows + let mut rows_data = Vec::with_capacity(rows as usize); + let mut offset = 8; + while offset + 4 <= data.len() { + let row_len = u32::from_be_bytes([ + data[offset], + data[offset + 1], + data[offset + 2], + data[offset + 3], + ]) as usize; + offset += 4; + if offset + row_len > data.len() { + break; + } + rows_data.push(data[offset..offset + row_len].to_vec()); + offset += row_len; + } + + Some(SavedScreen { + rows, + cols, + cursor_row, + cursor_col, + rows_data, + }) +} + +/// Restore the screen area that was covered by the popup. +/// +/// Writes the pre-formatted per-row ANSI bytes received from atuin pty-proxy +/// directly to stdout, which correctly handles wide characters, colors, and +/// all text attributes without needing a client-side vt100 parser. +#[cfg(unix)] +fn restore_popup_area(saved: &SavedScreen, popup_rect: Rect, scroll_offset: u16) { + use ratatui::crossterm::cursor::MoveTo; + + let mut stdout = stdout(); + + for dy in 0..popup_rect.height { + let target_row = popup_rect.y + dy; + let source_row = (target_row + scroll_offset) as usize; + + // Clear only the popup region. The server-side rows_formatted() skips + // default cells (spaces with default attributes) using cursor jumps, so + // any popup content at those positions would remain if not cleared + // beforehand. We write `popup_rect.width` spaces instead of + // ClearType::CurrentLine so that only the popup area is cleared, not + // the entire terminal line. + let _ = execute!( + stdout, + MoveTo(popup_rect.x, target_row), + ratatui::crossterm::style::SetAttribute(ratatui::crossterm::style::Attribute::Reset), + ); + let _ = write!(stdout, "{:width$}", "", width = popup_rect.width as usize); + let _ = execute!(stdout, MoveTo(popup_rect.x, target_row)); + + if let Some(row_bytes) = saved.rows_data.get(source_row) { + let _ = stdout.write_all(row_bytes); + } + } + + let _ = execute!( + stdout, + MoveTo( + saved.cursor_col, + saved.cursor_row.saturating_sub(scroll_offset) + ) + ); + let _ = stdout.flush(); +} + +struct Stdout { + writer: TerminalWriter, + inline_mode: bool, + no_mouse: bool, +} + +impl Stdout { + pub fn new(inline_mode: bool, no_mouse: bool) -> std::io::Result<Self> { + terminal::enable_raw_mode()?; + + let mut writer = TerminalWriter::new()?; + + if !inline_mode { + execute!(writer, terminal::EnterAlternateScreen)?; + } + + if !no_mouse { + execute!(writer, event::EnableMouseCapture)?; + } + + execute!(writer, event::EnableBracketedPaste)?; + + #[cfg(not(target_os = "windows"))] + execute!( + writer, + PushKeyboardEnhancementFlags( + KeyboardEnhancementFlags::DISAMBIGUATE_ESCAPE_CODES + | KeyboardEnhancementFlags::REPORT_ALL_KEYS_AS_ESCAPE_CODES + | KeyboardEnhancementFlags::REPORT_ALTERNATE_KEYS + ), + )?; + + Ok(Self { + writer, + inline_mode, + no_mouse, + }) + } +} + +impl Drop for Stdout { + fn drop(&mut self) { + #[cfg(not(target_os = "windows"))] + if let Err(e) = execute!(self.writer, PopKeyboardEnhancementFlags) { + tracing::error!(?e, "Failed to pop keyboard enhancement flags"); + } + + if !self.inline_mode + && let Err(e) = execute!(self.writer, terminal::LeaveAlternateScreen) + { + tracing::error!(?e, "Failed to leave alt screen mode"); + } + + if !self.no_mouse + && let Err(e) = execute!(self.writer, event::DisableMouseCapture) + { + tracing::error!(?e, "Failed to disable mouse capture"); + } + + if let Err(e) = execute!(self.writer, event::DisableBracketedPaste) { + tracing::error!(?e, "Failed to disable bracketed paste"); + } + + if let Err(e) = terminal::disable_raw_mode() { + tracing::error!(?e, "Failed to disable raw mode"); + } + } +} + +impl Write for Stdout { + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { + self.writer.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.writer.flush() + } +} + +// this is a big blob of horrible! clean it up! +/// Compute the popup position and any scroll offset needed to make room. +/// +/// Given the cursor row, terminal dimensions, and desired popup height, +/// returns `(popup_rect, scroll_offset)` where `scroll_offset` is the number +/// of lines the caller should scroll the terminal up before rendering. +/// +/// This function performs no I/O — it is a pure computation. +#[cfg(unix)] +fn compute_popup_placement( + cursor_row: u16, + term_rows: u16, + term_cols: u16, + inline_height: u16, +) -> (Rect, u16) { + let popup_w = term_cols; + let popup_h = inline_height.min(term_rows); + let space_below = term_rows.saturating_sub(cursor_row); + + let (popup_y, scroll) = if popup_h <= space_below { + // Fits below cursor + (cursor_row, 0u16) + } else if cursor_row >= term_rows / 2 { + // Bottom half — render above cursor (overlay on existing text) + (cursor_row.saturating_sub(popup_h), 0u16) + } else { + // Top half, not enough space — scroll terminal to make room + let scroll = popup_h.saturating_sub(space_below); + let popup_y = cursor_row.saturating_sub(scroll); + (popup_y, scroll) + }; + + (Rect::new(0, popup_y, popup_w, popup_h), scroll) +} + +// for now, it works. But it'd be great if it were more easily readable, and +// modular. I'd like to add some more stats and stuff at some point +#[expect( + clippy::cast_possible_truncation, + clippy::too_many_lines, + clippy::cognitive_complexity +)] +pub async fn history( + query: &[String], + settings: &Settings, + mut db: impl Database, + history_store: &HistoryStore, + theme: &Theme, +) -> Result<String> { + let inline_height = if settings.shell_up_key_binding { + settings + .inline_height_shell_up_key_binding + .unwrap_or(settings.inline_height) + } else { + settings.inline_height + }; + + // Use fullscreen mode if the inline height doesn't fit in the terminal, + // this will preserve the scroll position upon exit. + // Also force fullscreen when stdout isn't a terminal (e.g., command substitution + // like VAR=$(atuin search -i)). In that case, we need to use /dev/tty for the TUI and force + // fullscreen mode (inline mode won't work as it requires cursor position queries + // that don't work when stdout is captured). + let inline_height = if !stdout().is_terminal() { + 0 + } else if let Ok(size) = terminal::size() + && inline_height >= size.1 + { + 0 + } else { + inline_height + }; + + // Popup mode: if running under atuin pty-proxy and inline mode is requested, + // fetch the screen state and render as a centered overlay. + #[cfg(unix)] + let (saved_screen, popup_rect, popup_scroll_offset) = { + let socket_path = std::env::var("ATUIN_PTY_PROXY_SOCKET") + .or_else(|_| std::env::var("ATUIN_HEX_SOCKET")) + .ok(); + if let Some(ref path) = socket_path + && inline_height > 0 + { + let saved = fetch_screen_state(path); + if let Some(ref s) = saved { + let (term_cols, term_rows) = terminal::size().unwrap_or((s.cols, s.rows)); + let (popup_rect, scroll) = + compute_popup_placement(s.cursor_row, term_rows, term_cols, inline_height); + + // Scroll terminal content up to make room if needed + if scroll > 0 { + use ratatui::crossterm::cursor::MoveTo; + let mut stdout = stdout(); + let _ = execute!(stdout, MoveTo(0, term_rows - 1)); + for _ in 0..scroll { + let _ = writeln!(stdout); + } + let _ = stdout.flush(); + } + + (saved, popup_rect, scroll) + } else { + (None, Rect::default(), 0u16) + } + } else { + (None, Rect::default(), 0u16) + } + }; + + let popup_mode = saved_screen.is_some(); + + let stdout = Stdout::new(inline_height > 0, settings.no_mouse)?; + + // In popup mode, clear the popup region on the physical terminal before + // ratatui takes over. Ratatui's diff-based rendering compares against an + // initially-empty buffer, so cells that remain "empty" (spaces with default + // style) won't be written — leaving underlying terminal text visible. + // By pre-clearing with spaces, those cells are already correct on screen. + if popup_mode { + use ratatui::crossterm::cursor::MoveTo; + let mut raw_stdout = std::io::stdout(); + // Queue all commands without flushing so the terminal receives them + // as a single write — no intermediate cursor positions are visible. + let _ = queue!( + raw_stdout, + ratatui::crossterm::style::SetAttribute(ratatui::crossterm::style::Attribute::Reset) + ); + for row in popup_rect.y..popup_rect.y.saturating_add(popup_rect.height) { + let _ = queue!(raw_stdout, MoveTo(popup_rect.x, row)); + let _ = write!( + raw_stdout, + "{:width$}", + "", + width = popup_rect.width as usize + ); + } + let _ = raw_stdout.flush(); + } + + let backend = CrosstermBackend::new(stdout); + let mut terminal = Terminal::with_options( + backend, + TerminalOptions { + viewport: if popup_mode { + Viewport::Fixed(popup_rect) + } else if inline_height > 0 { + Viewport::Inline(inline_height) + } else { + Viewport::Fullscreen + }, + }, + )?; + + let original_query = query.join(" "); + + // Check if this is a command chaining scenario + let is_command_chaining = if settings.command_chaining { + let trimmed = original_query.trim_end(); + trimmed.ends_with("&&") || trimmed.ends_with('|') + } else { + false + }; + + // For command chaining, start with empty input to allow searching for new commands + let search_input = if is_command_chaining { + String::new() + } else { + original_query.clone() + }; + + let mut input = Cursor::from(search_input); + // Put the cursor at the end of the query by default + input.end(); + + let initial_context = current_context().await?; + + let history_count = db.history_count(false).await?; + let search_mode = if settings.shell_up_key_binding { + settings + .search_mode_shell_up_key_binding + .unwrap_or(settings.search_mode) + } else { + settings.search_mode + }; + let default_filter_mode = settings + .filter_mode_shell_up_key_binding + .filter(|_| settings.shell_up_key_binding) + .unwrap_or_else(|| settings.default_filter_mode(initial_context.git_root.is_some())); + let mut app = State { + history_count, + results_state: ListState::default(), + switched_search_mode: false, + search_mode, + tab_index: 0, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::from_settings(settings), + search: SearchState { + input, + filter_mode: default_filter_mode, + context: initial_context.clone(), + custom_context: None, + }, + engine: engines::engine(search_mode, settings), + results_len: 0, + accept: false, + keymap_mode: match settings.keymap_mode { + KeymapMode::Auto => KeymapMode::Emacs, + value => value, + }, + current_cursor: None, + now: if settings.prefers_reduced_motion { + let now = OffsetDateTime::now_utc(); + Box::new(move || now) + } else { + Box::new(OffsetDateTime::now_utc) + }, + prefix: false, + pending_vim_key: None, + original_input_empty: original_query.is_empty(), + }; + + app.initialize_keymap_cursor(settings); + + let mut results = app.query_results(&mut db, settings.smart_sort).await?; + + if inline_height > 0 && !popup_mode { + terminal.clear()?; + } + + let mut stats: Option<HistoryStats> = None; + let mut inspecting: Option<History> = None; + let accept; + let result = 'render: loop { + terminal.draw(|f| { + app.draw( + f, + &results, + stats.clone(), + inspecting.as_ref(), + settings, + theme, + popup_mode, + ); + })?; + + let initial_input = app.search.input.as_str().to_owned(); + let initial_filter_mode = app.search.filter_mode; + let initial_search_mode = app.search_mode; + let initial_custom_context = app.search.custom_context.clone(); + + let event_ready = tokio::task::spawn_blocking(|| event::poll(Duration::from_millis(250))); + + tokio::select! { + event_ready = event_ready => { + if event_ready?? { + loop { + match app.handle_input(settings, &event::read()?) { + InputAction::Continue => {}, + InputAction::Delete(index) => { + if results.is_empty() { + break; + } + app.results_len -= 1; + let selected = app.results_state.selected(); + if selected == app.results_len { + app.inspecting_state.reset(); + app.results_state.select(selected - 1); + } + + let entry = results.remove(index); + + let ids = history_store.delete_entries([entry]).await?; + history_store.incremental_build(&db, &ids).await?; + + app.tab_index = 0; + }, + InputAction::DeleteAllMatching(index) => { + if results.is_empty() { + break; + } + + let command = results[index].command.clone(); + + // Remove matching entries from the visible results + results.retain(|e| e.command != command); + + // Query the DB for ALL entries with this command and delete them + let all_matching = db.query_history( + &format!( + "select * from history where command = '{}' and deleted_at is null", + command.replace('\'', "''") + ) + ).await?; + + let ids = history_store.delete_entries(all_matching).await?; + history_store.incremental_build(&db, &ids).await?; + + app.results_len = results.len(); + app.results_state = ListState::default(); + app.inspecting_state.reset(); + app.tab_index = 0; + }, + InputAction::SwitchContext(index) => { + if let Some(index) = index && let Some(entry) = results.get(index) { + app.search.custom_context = Some(entry.id.clone()); + app.search.context = Context::from_history(entry); + app.search.filter_mode = FilterMode::Session; + app.search.input = Cursor::from(String::new()); + app.results_state = ListState::default(); + } else { + app.search.custom_context = None; + app.search.context = initial_context.clone(); + app.search.filter_mode = default_filter_mode; + } + }, + InputAction::Redraw => { + if !popup_mode { + terminal.clear()?; + } + terminal.draw(|f| { + app.draw(f, &results, stats.clone(), inspecting.as_ref(), settings, theme, popup_mode); + })?; + }, + r => { + accept = app.accept; + break 'render r; + }, + } + if !event::poll(Duration::ZERO)? { + break; + } + } + } + } + } + + if initial_input != app.search.input.as_str() + || initial_filter_mode != app.search.filter_mode + || initial_search_mode != app.search_mode + || initial_custom_context != app.search.custom_context + { + results = app.query_results(&mut db, settings.smart_sort).await?; + } + + // In custom context mode, when no filter is applied, highlight the entry which was used + // to enter the context when changing modes. This helps to find your way around. + if app.search.custom_context.is_some() + && app.search.input.as_str().is_empty() + && (initial_custom_context != app.search.custom_context + || initial_filter_mode != app.search.filter_mode) + && let Some(history_id) = app.search.custom_context.clone() + && let Some(pos) = results.iter().position(|entry| entry.id == history_id) + { + app.results_state.select(pos); + } + + let inspecting_id = app.inspecting_state.clone().current; + // If inspecting ID is not the current inspecting History, update it. + match inspecting_id { + Some(inspecting_id) => { + if inspecting.is_none() || inspecting_id != inspecting.clone().unwrap().id { + inspecting = db.load(inspecting_id.0.as_str()).await?; + } + } + _ => { + inspecting = None; + } + } + + stats = if app.tab_index == 0 { + None + } else if !results.is_empty() { + // If we have stats, then we can indicate next available IDs. This avoids passing + // around a database object, or a full stats object. + let selected = match inspecting.clone() { + Some(insp) => insp, + None => results[app.results_state.selected()].clone(), + }; + let stats = db.stats(&selected).await?; + app.inspecting_state.current = Some(selected.id); + app.inspecting_state.previous = match stats.previous.clone() { + Some(p) => Some(p.id), + _ => None, + }; + app.inspecting_state.next = match stats.next.clone() { + Some(p) => Some(p.id), + _ => None, + }; + Some(stats) + } else { + None + }; + }; + + app.finalize_keymap_cursor(settings); + + if popup_mode { + // In popup mode, restore the screen area that was covered by the popup. + // This must happen before Stdout is dropped (which disables raw mode). + #[cfg(unix)] + if let Some(ref saved) = saved_screen { + restore_popup_area(saved, popup_rect, popup_scroll_offset); + } + } else if inline_height > 0 { + terminal.clear()?; + } + + let accept = accept + && matches!( + Shell::from_env(), + Shell::Zsh | Shell::Fish | Shell::Bash | Shell::Xonsh | Shell::Nu | Shell::Powershell + ); + + let accept_prefix = "__atuin_accept__:"; + + match result { + InputAction::AcceptInspecting => { + match inspecting { + Some(result) => { + let mut command = result.command; + + if accept { + command = String::from(accept_prefix) + &command; + } + + // index is in bounds so we return that entry + Ok(command) + } + None => Ok(String::new()), + } + } + InputAction::Accept(index) if index < results.len() => { + let mut command = results.swap_remove(index).command; + + if is_command_chaining { + command = format!("{} {}", original_query.trim_end(), command); + } else if accept { + command = String::from(accept_prefix) + &command; + } + + // index is in bounds so we return that entry + Ok(command) + } + InputAction::ReturnOriginal => Ok(String::new()), + InputAction::Copy(index) => { + let cmd = results.swap_remove(index).command; + set_clipboard(cmd); + Ok(String::new()) + } + InputAction::ReturnQuery | InputAction::Accept(_) => { + // Either: + // * index == RETURN_QUERY, in which case we should return the input + // * out of bounds -> usually implies no selected entry so we return the input + Ok(app.search.input.into_inner()) + } + InputAction::Continue + | InputAction::Redraw + | InputAction::Delete(_) + | InputAction::DeleteAllMatching(_) + | InputAction::SwitchContext(_) => { + unreachable!("should have been handled!") + } + } +} + +// cli-clipboard only works on Windows, Mac, and Linux. + +#[cfg(all( + feature = "clipboard", + any(target_os = "windows", target_os = "macos", target_os = "linux") +))] +fn set_clipboard(s: String) { + let mut ctx = arboard::Clipboard::new().unwrap(); + ctx.set_text(s).unwrap(); + // Use the clipboard context to make sure it is saved + ctx.get_text().unwrap(); +} + +#[cfg(not(all( + feature = "clipboard", + any(target_os = "windows", target_os = "macos", target_os = "linux") +)))] +fn set_clipboard(_s: String) {} + +#[cfg(test)] +mod tests { + use crate::atuin_client::database::Context; + use crate::atuin_client::history::History; + use crate::atuin_client::settings::{ + FilterMode, KeymapMode, Preview, PreviewStrategy, SearchMode, Settings, + }; + use time::OffsetDateTime; + + use crate::command::client::search::engines::{self, SearchState}; + use crate::command::client::search::history_list::ListState; + + use super::{Compactness, InspectingState, KeymapSet, State}; + + #[test] + #[expect(clippy::too_many_lines)] + fn calc_preview_height_test() { + let settings_preview_auto = Settings { + preview: Preview { + strategy: PreviewStrategy::Auto, + }, + show_preview: true, + ..Settings::utc() + }; + + let settings_preview_auto_h2 = Settings { + preview: Preview { + strategy: PreviewStrategy::Auto, + }, + show_preview: true, + max_preview_height: 2, + ..Settings::utc() + }; + + let settings_preview_h4 = Settings { + preview: Preview { + strategy: PreviewStrategy::Static, + }, + show_preview: true, + max_preview_height: 4, + ..Settings::utc() + }; + + let settings_preview_fixed = Settings { + preview: Preview { + strategy: PreviewStrategy::Fixed, + }, + show_preview: true, + max_preview_height: 15, + ..Settings::utc() + }; + + let cmd_60: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("for i in $(seq -w 10); do echo \"item number $i - abcd\"; done") + .cwd("/") + .build() + .into(); + + let cmd_124: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo 'Aurea prima sata est aetas, quae vindice nullo, sponte sua, sine lege fidem rectumque colebat. Poena metusque aberant'") + .cwd("/") + .build() + .into(); + + let cmd_200: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("CREATE USER atuin WITH ENCRYPTED PASSWORD 'supersecretpassword'; CREATE DATABASE atuin WITH OWNER = atuin; \\c atuin; REVOKE ALL PRIVILEGES ON SCHEMA public FROM PUBLIC; echo 'All done. 200 characters'") + .cwd("/") + .build() + .into(); + + let results: Vec<History> = vec![cmd_60, cmd_124, cmd_200]; + + // the selected command does not require a preview + let no_preview = State::calc_preview_height( + &settings_preview_auto, + &results, + 0_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the selected command requires 2 lines + let preview_h2 = State::calc_preview_height( + &settings_preview_auto, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the selected command requires 3 lines + let preview_h3 = State::calc_preview_height( + &settings_preview_auto, + &results, + 2_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the selected command requires a preview of 1 line (happens when the command is between preview_width-19 and preview_width) + let preview_one_line = State::calc_preview_height( + &settings_preview_auto, + &results, + 0_usize, + 0_usize, + Compactness::Full, + 1, + 66, + ); + // the selected command requires 3 lines, but we have a max preview height limit of 2 + let preview_limit_at_2 = State::calc_preview_height( + &settings_preview_auto_h2, + &results, + 2_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the longest command requires 3 lines + let preview_static_h3 = State::calc_preview_height( + &settings_preview_h4, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the longest command requires 10 lines, but we have a max preview height limit of 4 + let preview_static_limit_at_4 = State::calc_preview_height( + &settings_preview_h4, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 20, + ); + // the longest command requires 10 lines, but we have a max preview height of 15 and a fixed preview strategy + let settings_preview_fixed = State::calc_preview_height( + &settings_preview_fixed, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 20, + ); + + assert_eq!(no_preview, 1); + // 1 * 2 is the space for the border + let border_space = 2; + assert_eq!(preview_h2, 2 + border_space); + assert_eq!(preview_h3, 3 + border_space); + assert_eq!(preview_one_line, 1 + border_space); + assert_eq!(preview_limit_at_2, 2 + border_space); + assert_eq!(preview_static_h3, 3 + border_space); + assert_eq!(preview_static_limit_at_4, 4 + border_space); + assert_eq!(settings_preview_fixed, 15 + border_space); + } + + // Test when there's no results, scrolling up or down doesn't underflow + #[test] + fn state_scroll_up_underflow() { + let settings = Settings::utc(); + let mut state = State { + history_count: 0, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 0, + accept: false, + keymap_mode: KeymapMode::Auto, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Directory, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.scroll_up(1); + state.scroll_down(1); + } + + #[test] + fn test_accept_keybindings() { + use crate::atuin_client::settings::Keys; + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let mut settings = Settings::utc(); + settings.keys = Keys { + scroll_exits: true, + exit_past_line_start: false, + accept_past_line_end: true, + accept_past_line_start: false, + accept_with_backspace: false, + prefix: "a".to_string(), + }; + + let mut state = State { + history_count: 1, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 1, + accept: false, + keymap_mode: KeymapMode::Emacs, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + let tab_event = KeyEvent::new(KeyCode::Tab, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &tab_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Tab should always accept" + ); + + // Test left arrow with accept_past_line_start disabled (should continue) + let left_event = KeyEvent::new(KeyCode::Left, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &left_event); + assert!( + matches!(result, super::InputAction::Continue), + "Left arrow should continue when disabled" + ); + + // Test left arrow with accept_past_line_start enabled (should accept at start of line) + settings.keys.accept_past_line_start = true; + state.keymaps = KeymapSet::defaults(&settings); + let result = state.handle_key_input(&settings, &left_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Left arrow should accept at start of line when enabled" + ); + settings.keys.accept_past_line_start = false; + state.keymaps = KeymapSet::defaults(&settings); + + let backspace_event = KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &backspace_event); + assert!( + matches!(result, super::InputAction::Continue), + "Backspace should continue when disabled" + ); + + settings.keys.accept_with_backspace = true; + state.keymaps = KeymapSet::defaults(&settings); + let result = state.handle_key_input(&settings, &backspace_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Backspace should accept at start of line when enabled" + ); + + state.search.input.insert('t'); + state.search.input.insert('e'); + state.search.input.insert('s'); + state.search.input.insert('t'); + state.search.input.end(); + + let right_event = KeyEvent::new(KeyCode::Right, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &right_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Right arrow should accept at end of line when enabled" + ); + + settings.keys.accept_past_line_start = true; + state.keymaps = KeymapSet::defaults(&settings); + let left_event = KeyEvent::new(KeyCode::Left, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &left_event); + assert!( + matches!(result, super::InputAction::Continue), + "Left arrow should continue and end of line, even when enabled" + ); + settings.keys.accept_past_line_start = false; + state.keymaps = KeymapSet::defaults(&settings); + + settings.keys.accept_with_backspace = true; + state.keymaps = KeymapSet::defaults(&settings); + let backspace_event = KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &backspace_event); + assert!( + matches!(result, super::InputAction::Continue), + "Backspace should continue at end of line, even when enabled" + ); + settings.keys.accept_with_backspace = false; + state.keymaps = KeymapSet::defaults(&settings); + } + + #[test] + fn test_vim_gg_multikey_sequence() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + // Start in the middle of the list + state.results_state.select(50); + + // First 'g' should set pending state + let g_event = KeyEvent::new(KeyCode::Char('g'), KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &g_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, Some('g')); + assert_eq!(state.results_state.selected(), 50); // Position unchanged + + // Second 'g' should jump to end (visual top in non-inverted mode) + let result = state.handle_key_input(&settings, &g_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + assert_eq!(state.results_state.selected(), 99); // Jumped to last index (visual top) + } + + #[test] + fn test_vim_g_key_clears_on_other_input() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // Press 'g' to set pending state + let g_event = KeyEvent::new(KeyCode::Char('g'), KeyModifiers::NONE); + state.handle_key_input(&settings, &g_event); + assert_eq!(state.pending_vim_key, Some('g')); + + // Press 'j' - should clear pending state + let j_event = KeyEvent::new(KeyCode::Char('j'), KeyModifiers::NONE); + state.handle_key_input(&settings, &j_event); + assert_eq!(state.pending_vim_key, None); + } + + #[test] + fn test_vim_big_g_jump_to_bottom() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // 'G' should jump to visual bottom (index 0 in non-inverted mode) + let big_g_event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &big_g_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.results_state.selected(), 0); + } + + #[test] + fn test_vim_ctrl_u_d_half_page_scroll() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // Ctrl+d should return Continue and clear pending key + // (scroll amount depends on max_entries which is 0 in tests) + state.pending_vim_key = Some('g'); + let ctrl_d_event = KeyEvent::new(KeyCode::Char('d'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_d_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + + // Ctrl+u should return Continue and clear pending key + state.pending_vim_key = Some('g'); + let ctrl_u_event = KeyEvent::new(KeyCode::Char('u'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_u_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + } + + #[test] + fn test_vim_ctrl_f_b_full_page_scroll() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // Ctrl+f should return Continue and clear pending key + // (scroll amount depends on max_entries which is 0 in tests) + state.pending_vim_key = Some('g'); + let ctrl_f_event = KeyEvent::new(KeyCode::Char('f'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_f_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + + // Ctrl+b should return Continue and clear pending key + state.pending_vim_key = Some('g'); + let ctrl_b_event = KeyEvent::new(KeyCode::Char('b'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_b_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + } + + // ----------------------------------------------------------------------- + // Executor tests (execute_action) + // ----------------------------------------------------------------------- + + /// Helper to build a State for executor tests. + fn make_executor_state(results_len: usize, selected: usize) -> State { + let settings = Settings::utc(); + let mut state = State { + history_count: results_len as i64, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len, + accept: false, + keymap_mode: KeymapMode::Emacs, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + state.results_state.select(selected); + state + } + + #[test] + fn execute_select_next_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::SelectNext, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: SelectNext = scroll_down = selected - 1 + assert_eq!(state.results_state.selected(), 49); + } + + #[test] + fn execute_select_next_with_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let mut settings = Settings::utc(); + settings.invert = true; + let result = state.execute_action(&Action::SelectNext, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Inverted: SelectNext = scroll_up = selected + 1 + assert_eq!(state.results_state.selected(), 51); + } + + #[test] + fn execute_select_previous_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::SelectPrevious, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: SelectPrevious = scroll_up = selected + 1 + assert_eq!(state.results_state.selected(), 51); + } + + #[test] + fn execute_vim_enter_normal() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + let result = state.execute_action(&Action::VimEnterNormal, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.keymap_mode, KeymapMode::VimNormal); + } + + #[test] + fn execute_vim_enter_insert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + state.keymap_mode = KeymapMode::VimNormal; + let settings = Settings::utc(); + let result = state.execute_action(&Action::VimEnterInsert, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.keymap_mode, KeymapMode::VimInsert); + } + + #[test] + fn execute_accept_sets_accept_flag() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + let mut settings = Settings::utc(); + settings.enter_accept = true; + let result = state.execute_action(&Action::Accept, &settings); + assert!(matches!(result, super::InputAction::Accept(5))); + assert!(state.accept); + } + + #[test] + fn execute_return_selection_does_not_set_accept() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ReturnSelection, &settings); + assert!(matches!(result, super::InputAction::Accept(5))); + assert!(!state.accept); + } + + #[test] + fn execute_accept_nth() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + let settings = Settings::utc(); + let result = state.execute_action(&Action::AcceptNth(3), &settings); + assert!(matches!(result, super::InputAction::Accept(8))); + } + + #[test] + fn execute_scroll_to_top_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ScrollToTop, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: visual top = highest index + assert_eq!(state.results_state.selected(), 99); + } + + #[test] + fn execute_scroll_to_top_with_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let mut settings = Settings::utc(); + settings.invert = true; + let result = state.execute_action(&Action::ScrollToTop, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Inverted: visual top = index 0 + assert_eq!(state.results_state.selected(), 0); + } + + #[test] + fn execute_scroll_to_bottom_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ScrollToBottom, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: visual bottom = index 0 + assert_eq!(state.results_state.selected(), 0); + } + + #[test] + fn execute_toggle_tab() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + assert_eq!(state.tab_index, 0); + state.execute_action(&Action::ToggleTab, &settings); + assert_eq!(state.tab_index, 1); + state.execute_action(&Action::ToggleTab, &settings); + assert_eq!(state.tab_index, 0); + } + + #[test] + fn execute_enter_prefix_mode() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + assert!(!state.prefix); + state.execute_action(&Action::EnterPrefixMode, &settings); + assert!(state.prefix); + } + + #[test] + fn execute_exit_returns_based_on_exit_mode() { + use crate::atuin_client::settings::ExitMode; + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let mut settings = Settings::utc(); + + settings.exit_mode = ExitMode::ReturnOriginal; + let result = state.execute_action(&Action::Exit, &settings); + assert!(matches!(result, super::InputAction::ReturnOriginal)); + + settings.exit_mode = ExitMode::ReturnQuery; + let result = state.execute_action(&Action::Exit, &settings); + assert!(matches!(result, super::InputAction::ReturnQuery)); + } + + #[test] + fn execute_return_original() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ReturnOriginal, &settings); + assert!(matches!(result, super::InputAction::ReturnOriginal)); + } + + #[test] + fn execute_copy() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::Copy, &settings); + assert!(matches!(result, super::InputAction::Copy(7))); + } + + #[test] + fn execute_delete() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::Delete, &settings); + assert!(matches!(result, super::InputAction::Delete(7))); + } + + #[test] + fn execute_switch_context() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::SwitchContext, &settings); + assert!(matches!(result, super::InputAction::SwitchContext(Some(7)))); + } + + #[test] + fn execute_clear_context() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ClearContext, &settings); + assert!(matches!(result, super::InputAction::SwitchContext(None))); + } + + #[test] + fn execute_noop() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::Noop, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.results_state.selected(), 50); + } + + #[test] + fn execute_accept_in_inspector_tab() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + state.tab_index = 1; + let settings = Settings::utc(); + let result = state.execute_action(&Action::Accept, &settings); + assert!(matches!(result, super::InputAction::AcceptInspecting)); + } + + #[test] + fn execute_cycle_search_mode() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + let original_mode = state.search_mode; + let result = state.execute_action(&Action::CycleSearchMode, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert!(state.switched_search_mode); + assert_ne!(state.search_mode, original_mode); + } + + #[test] + fn execute_vim_search_insert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + state.search.input.insert('h'); + state.search.input.insert('i'); + state.keymap_mode = KeymapMode::VimNormal; + let settings = Settings::utc(); + let result = state.execute_action(&Action::VimSearchInsert, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Should clear input and switch to insert mode + assert_eq!(state.search.input.as_str(), ""); + assert_eq!(state.keymap_mode, KeymapMode::VimInsert); + } + + #[test] + fn execute_cursor_movement() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + + // Insert some text + state.search.input.insert('h'); + state.search.input.insert('e'); + state.search.input.insert('l'); + state.search.input.insert('l'); + state.search.input.insert('o'); + // cursor is at end (position 5) + + // CursorLeft + state.execute_action(&Action::CursorLeft, &settings); + assert_eq!(state.search.input.position(), 4); + + // CursorStart + state.execute_action(&Action::CursorStart, &settings); + assert_eq!(state.search.input.position(), 0); + + // CursorEnd + state.execute_action(&Action::CursorEnd, &settings); + assert_eq!(state.search.input.position(), 5); + + // CursorRight at end does nothing + state.execute_action(&Action::CursorRight, &settings); + assert_eq!(state.search.input.position(), 5); + } + + #[test] + fn execute_editing() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + + // Insert "hello" + state.search.input.insert('h'); + state.search.input.insert('e'); + state.search.input.insert('l'); + state.search.input.insert('l'); + state.search.input.insert('o'); + + // DeleteCharBefore (backspace) + state.execute_action(&Action::DeleteCharBefore, &settings); + assert_eq!(state.search.input.as_str(), "hell"); + + // ClearLine + state.execute_action(&Action::ClearLine, &settings); + assert_eq!(state.search.input.as_str(), ""); + } + + #[test] + fn keymap_config_return_query() { + use crate::atuin_client::settings::KeyBindingConfig; + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + use std::collections::HashMap; + + let mut settings = Settings::utc(); + // Configure tab to return-query + settings.keymap.emacs = HashMap::from([( + "tab".to_string(), + KeyBindingConfig::Simple("return-query".to_string()), + )]); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::Emacs, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::from_settings(&settings), + search: SearchState { + input: "test query".to_string().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + let tab_event = KeyEvent::new(KeyCode::Tab, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &tab_event); + assert!( + matches!(result, super::InputAction::ReturnQuery), + "Tab configured as return-query should return InputAction::ReturnQuery" + ); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/actions.rs b/crates/turtle/src/command/client/search/keybindings/actions.rs new file mode 100644 index 00000000..ff2ef7de --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/actions.rs @@ -0,0 +1,322 @@ +use std::fmt; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// All possible actions that can be triggered by a keybinding. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Action { + // Cursor movement + CursorLeft, + CursorRight, + CursorWordLeft, + CursorWordRight, + CursorWordEnd, + CursorStart, + CursorEnd, + + // Editing + DeleteCharBefore, + DeleteCharAfter, + DeleteWordBefore, + DeleteWordAfter, + DeleteToWordBoundary, + ClearLine, + ClearToStart, + ClearToEnd, + + // List navigation + SelectNext, + SelectPrevious, + ScrollHalfPageUp, + ScrollHalfPageDown, + ScrollPageUp, + ScrollPageDown, + ScrollToTop, + ScrollToBottom, + ScrollToScreenTop, + ScrollToScreenMiddle, + ScrollToScreenBottom, + + // Commands — accept selection and execute immediately + Accept, + AcceptNth(u8), + // Commands — return selection to command line without executing + ReturnSelection, + ReturnSelectionNth(u8), + // Commands — other + Copy, + Delete, + DeleteAll, + ReturnOriginal, + ReturnQuery, + Exit, + Redraw, + CycleFilterMode, + CycleSearchMode, + SwitchContext, + ClearContext, + ToggleTab, + + // Mode changes + VimEnterNormal, + VimEnterInsert, + VimEnterInsertAfter, + VimEnterInsertAtStart, + VimEnterInsertAtEnd, + VimSearchInsert, + VimChangeToEnd, + EnterPrefixMode, + + // Inspector + InspectPrevious, + InspectNext, + + // Special + Noop, +} + +impl Action { + /// Convert from a kebab-case string. + pub fn from_str(s: &str) -> Result<Self, String> { + // Handle accept-N and return-selection-N patterns + if let Some(rest) = s.strip_prefix("accept-") + && let Ok(n) = rest.parse::<u8>() + && (1..=9).contains(&n) + { + return Ok(Action::AcceptNth(n)); + } + if let Some(rest) = s.strip_prefix("return-selection-") + && let Ok(n) = rest.parse::<u8>() + && (1..=9).contains(&n) + { + return Ok(Action::ReturnSelectionNth(n)); + } + + match s { + "cursor-left" => Ok(Action::CursorLeft), + "cursor-right" => Ok(Action::CursorRight), + "cursor-word-left" => Ok(Action::CursorWordLeft), + "cursor-word-right" => Ok(Action::CursorWordRight), + "cursor-word-end" => Ok(Action::CursorWordEnd), + "cursor-start" => Ok(Action::CursorStart), + "cursor-end" => Ok(Action::CursorEnd), + + "delete-char-before" => Ok(Action::DeleteCharBefore), + "delete-char-after" => Ok(Action::DeleteCharAfter), + "delete-word-before" => Ok(Action::DeleteWordBefore), + "delete-word-after" => Ok(Action::DeleteWordAfter), + "delete-to-word-boundary" => Ok(Action::DeleteToWordBoundary), + "clear-line" => Ok(Action::ClearLine), + "clear-to-start" => Ok(Action::ClearToStart), + "clear-to-end" => Ok(Action::ClearToEnd), + + "select-next" => Ok(Action::SelectNext), + "select-previous" => Ok(Action::SelectPrevious), + "scroll-half-page-up" => Ok(Action::ScrollHalfPageUp), + "scroll-half-page-down" => Ok(Action::ScrollHalfPageDown), + "scroll-page-up" => Ok(Action::ScrollPageUp), + "scroll-page-down" => Ok(Action::ScrollPageDown), + "scroll-to-top" => Ok(Action::ScrollToTop), + "scroll-to-bottom" => Ok(Action::ScrollToBottom), + "scroll-to-screen-top" => Ok(Action::ScrollToScreenTop), + "scroll-to-screen-middle" => Ok(Action::ScrollToScreenMiddle), + "scroll-to-screen-bottom" => Ok(Action::ScrollToScreenBottom), + + "accept" => Ok(Action::Accept), + "return-selection" => Ok(Action::ReturnSelection), + "copy" => Ok(Action::Copy), + "delete" => Ok(Action::Delete), + "delete-all" => Ok(Action::DeleteAll), + "return-original" => Ok(Action::ReturnOriginal), + "return-query" => Ok(Action::ReturnQuery), + "exit" => Ok(Action::Exit), + "redraw" => Ok(Action::Redraw), + "cycle-filter-mode" => Ok(Action::CycleFilterMode), + "cycle-search-mode" => Ok(Action::CycleSearchMode), + "switch-context" => Ok(Action::SwitchContext), + "clear-context" => Ok(Action::ClearContext), + "toggle-tab" => Ok(Action::ToggleTab), + + "vim-enter-normal" => Ok(Action::VimEnterNormal), + "vim-enter-insert" => Ok(Action::VimEnterInsert), + "vim-enter-insert-after" => Ok(Action::VimEnterInsertAfter), + "vim-enter-insert-at-start" => Ok(Action::VimEnterInsertAtStart), + "vim-enter-insert-at-end" => Ok(Action::VimEnterInsertAtEnd), + "vim-search-insert" => Ok(Action::VimSearchInsert), + "vim-change-to-end" => Ok(Action::VimChangeToEnd), + "enter-prefix-mode" => Ok(Action::EnterPrefixMode), + + "inspect-previous" => Ok(Action::InspectPrevious), + "inspect-next" => Ok(Action::InspectNext), + + "noop" => Ok(Action::Noop), + + _ => Err(format!("unknown action: {s}")), + } + } + + /// Convert to a kebab-case string. + pub fn as_str(&self) -> String { + match self { + Action::CursorLeft => "cursor-left".to_string(), + Action::CursorRight => "cursor-right".to_string(), + Action::CursorWordLeft => "cursor-word-left".to_string(), + Action::CursorWordRight => "cursor-word-right".to_string(), + Action::CursorWordEnd => "cursor-word-end".to_string(), + Action::CursorStart => "cursor-start".to_string(), + Action::CursorEnd => "cursor-end".to_string(), + + Action::DeleteCharBefore => "delete-char-before".to_string(), + Action::DeleteCharAfter => "delete-char-after".to_string(), + Action::DeleteWordBefore => "delete-word-before".to_string(), + Action::DeleteWordAfter => "delete-word-after".to_string(), + Action::DeleteToWordBoundary => "delete-to-word-boundary".to_string(), + Action::ClearLine => "clear-line".to_string(), + Action::ClearToStart => "clear-to-start".to_string(), + Action::ClearToEnd => "clear-to-end".to_string(), + + Action::SelectNext => "select-next".to_string(), + Action::SelectPrevious => "select-previous".to_string(), + Action::ScrollHalfPageUp => "scroll-half-page-up".to_string(), + Action::ScrollHalfPageDown => "scroll-half-page-down".to_string(), + Action::ScrollPageUp => "scroll-page-up".to_string(), + Action::ScrollPageDown => "scroll-page-down".to_string(), + Action::ScrollToTop => "scroll-to-top".to_string(), + Action::ScrollToBottom => "scroll-to-bottom".to_string(), + Action::ScrollToScreenTop => "scroll-to-screen-top".to_string(), + Action::ScrollToScreenMiddle => "scroll-to-screen-middle".to_string(), + Action::ScrollToScreenBottom => "scroll-to-screen-bottom".to_string(), + + Action::Accept => "accept".to_string(), + Action::AcceptNth(n) => format!("accept-{n}"), + Action::ReturnSelection => "return-selection".to_string(), + Action::ReturnSelectionNth(n) => format!("return-selection-{n}"), + Action::Copy => "copy".to_string(), + Action::Delete => "delete".to_string(), + Action::DeleteAll => "delete-all".to_string(), + Action::ReturnOriginal => "return-original".to_string(), + Action::ReturnQuery => "return-query".to_string(), + Action::Exit => "exit".to_string(), + Action::Redraw => "redraw".to_string(), + Action::CycleFilterMode => "cycle-filter-mode".to_string(), + Action::CycleSearchMode => "cycle-search-mode".to_string(), + Action::SwitchContext => "switch-context".to_string(), + Action::ClearContext => "clear-context".to_string(), + Action::ToggleTab => "toggle-tab".to_string(), + + Action::VimEnterNormal => "vim-enter-normal".to_string(), + Action::VimEnterInsert => "vim-enter-insert".to_string(), + Action::VimEnterInsertAfter => "vim-enter-insert-after".to_string(), + Action::VimEnterInsertAtStart => "vim-enter-insert-at-start".to_string(), + Action::VimEnterInsertAtEnd => "vim-enter-insert-at-end".to_string(), + Action::VimSearchInsert => "vim-search-insert".to_string(), + Action::VimChangeToEnd => "vim-change-to-end".to_string(), + Action::EnterPrefixMode => "enter-prefix-mode".to_string(), + + Action::InspectPrevious => "inspect-previous".to_string(), + Action::InspectNext => "inspect-next".to_string(), + + Action::Noop => "noop".to_string(), + } + } +} + +impl fmt::Display for Action { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl Serialize for Action { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_str(&self.as_str()) + } +} + +impl<'de> Deserialize<'de> for Action { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { + let s = String::deserialize(deserializer)?; + Action::from_str(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_basic_actions() { + assert_eq!(Action::from_str("cursor-left").unwrap(), Action::CursorLeft); + assert_eq!(Action::from_str("accept").unwrap(), Action::Accept); + assert_eq!(Action::from_str("exit").unwrap(), Action::Exit); + assert_eq!(Action::from_str("noop").unwrap(), Action::Noop); + assert_eq!( + Action::from_str("vim-enter-normal").unwrap(), + Action::VimEnterNormal + ); + } + + #[test] + fn parse_accept_nth() { + assert_eq!(Action::from_str("accept-1").unwrap(), Action::AcceptNth(1)); + assert_eq!(Action::from_str("accept-9").unwrap(), Action::AcceptNth(9)); + } + + #[test] + fn parse_return_selection() { + assert_eq!( + Action::from_str("return-selection").unwrap(), + Action::ReturnSelection + ); + assert_eq!( + Action::from_str("return-selection-1").unwrap(), + Action::ReturnSelectionNth(1) + ); + assert_eq!( + Action::from_str("return-selection-9").unwrap(), + Action::ReturnSelectionNth(9) + ); + } + + #[test] + fn parse_unknown_action() { + assert!(Action::from_str("unknown-action").is_err()); + assert!(Action::from_str("accept-0").is_err()); + assert!(Action::from_str("accept-10").is_err()); + assert!(Action::from_str("return-selection-0").is_err()); + assert!(Action::from_str("return-selection-10").is_err()); + } + + #[test] + fn round_trip() { + let actions = vec![ + Action::CursorLeft, + Action::Accept, + Action::AcceptNth(5), + Action::ReturnSelection, + Action::ReturnSelectionNth(3), + Action::VimSearchInsert, + Action::ScrollToScreenMiddle, + ]; + for action in actions { + let s = action.as_str(); + let parsed = Action::from_str(&s).unwrap(); + assert_eq!(action, parsed); + } + } + + #[test] + fn serde_round_trip() { + let action = Action::CursorLeft; + let json = serde_json::to_string(&action).unwrap(); + assert_eq!(json, "\"cursor-left\""); + let parsed: Action = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, Action::CursorLeft); + + let action = Action::AcceptNth(3); + let json = serde_json::to_string(&action).unwrap(); + assert_eq!(json, "\"accept-3\""); + let parsed: Action = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, Action::AcceptNth(3)); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/conditions.rs b/crates/turtle/src/command/client/search/keybindings/conditions.rs new file mode 100644 index 00000000..055ae905 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/conditions.rs @@ -0,0 +1,801 @@ +use std::fmt; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// Atomic (leaf) conditions that can be evaluated against state. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ConditionAtom { + CursorAtStart, + CursorAtEnd, + InputEmpty, + OriginalInputEmpty, + ListAtEnd, + ListAtStart, + NoResults, + HasResults, + HasContext, +} + +/// Boolean expression tree over condition atoms. +/// +/// Supports negation, conjunction, and disjunction with standard precedence: +/// `!` binds tightest, then `&&`, then `||`. +/// +/// Examples of valid expression strings: +/// - `"cursor-at-start"` (bare atom) +/// - `"!no-results"` (negation) +/// - `"cursor-at-start && input-empty"` (conjunction) +/// - `"list-at-start || no-results"` (disjunction) +/// - `"(cursor-at-start && !input-empty) || no-results"` (grouping) +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ConditionExpr { + Atom(ConditionAtom), + Not(Box<ConditionExpr>), + And(Box<ConditionExpr>, Box<ConditionExpr>), + Or(Box<ConditionExpr>, Box<ConditionExpr>), +} + +/// Context needed to evaluate conditions. This is a pure snapshot of state — +/// no references to mutable data. +pub struct EvalContext { + /// Current cursor position (unicode width units). + pub cursor_position: usize, + /// Width of the input string in unicode width units. + pub input_width: usize, + /// Byte length of the input string. + pub input_byte_len: usize, + /// Currently selected index in the results list. + pub selected_index: usize, + /// Total number of results. + pub results_len: usize, + /// Whether the original input (query passed to the TUI) was empty. + pub original_input_empty: bool, + /// Whether we use a search context of a command from the history. + pub has_context: bool, +} + +// --------------------------------------------------------------------------- +// ConditionAtom +// --------------------------------------------------------------------------- + +impl ConditionAtom { + /// Evaluate this atom against the given context. + pub fn evaluate(&self, ctx: &EvalContext) -> bool { + match self { + ConditionAtom::CursorAtStart => ctx.cursor_position == 0, + ConditionAtom::CursorAtEnd => ctx.cursor_position == ctx.input_width, + ConditionAtom::InputEmpty => ctx.input_byte_len == 0, + ConditionAtom::OriginalInputEmpty => ctx.original_input_empty, + ConditionAtom::ListAtEnd => { + ctx.results_len == 0 || ctx.selected_index >= ctx.results_len.saturating_sub(1) + } + ConditionAtom::ListAtStart => ctx.results_len == 0 || ctx.selected_index == 0, + ConditionAtom::NoResults => ctx.results_len == 0, + ConditionAtom::HasResults => ctx.results_len > 0, + ConditionAtom::HasContext => ctx.has_context, + } + } + + /// Parse from a kebab-case string. + pub fn from_str(s: &str) -> Result<Self, String> { + match s { + "cursor-at-start" => Ok(ConditionAtom::CursorAtStart), + "cursor-at-end" => Ok(ConditionAtom::CursorAtEnd), + "input-empty" => Ok(ConditionAtom::InputEmpty), + "original-input-empty" => Ok(ConditionAtom::OriginalInputEmpty), + "list-at-end" => Ok(ConditionAtom::ListAtEnd), + "list-at-start" => Ok(ConditionAtom::ListAtStart), + "no-results" => Ok(ConditionAtom::NoResults), + "has-results" => Ok(ConditionAtom::HasResults), + "has-context" => Ok(ConditionAtom::HasContext), + _ => Err(format!("unknown condition: {s}")), + } + } + + /// Convert to a kebab-case string. + pub fn as_str(&self) -> &'static str { + match self { + ConditionAtom::CursorAtStart => "cursor-at-start", + ConditionAtom::CursorAtEnd => "cursor-at-end", + ConditionAtom::InputEmpty => "input-empty", + ConditionAtom::OriginalInputEmpty => "original-input-empty", + ConditionAtom::ListAtEnd => "list-at-end", + ConditionAtom::ListAtStart => "list-at-start", + ConditionAtom::NoResults => "no-results", + ConditionAtom::HasResults => "has-results", + ConditionAtom::HasContext => "has-context", + } + } +} + +impl fmt::Display for ConditionAtom { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — evaluation +// --------------------------------------------------------------------------- + +impl ConditionExpr { + /// Evaluate this expression against the given context. + pub fn evaluate(&self, ctx: &EvalContext) -> bool { + match self { + ConditionExpr::Atom(atom) => atom.evaluate(ctx), + ConditionExpr::Not(inner) => !inner.evaluate(ctx), + ConditionExpr::And(lhs, rhs) => lhs.evaluate(ctx) && rhs.evaluate(ctx), + ConditionExpr::Or(lhs, rhs) => lhs.evaluate(ctx) || rhs.evaluate(ctx), + } + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — ergonomic builders +// --------------------------------------------------------------------------- + +impl From<ConditionAtom> for ConditionExpr { + fn from(atom: ConditionAtom) -> Self { + ConditionExpr::Atom(atom) + } +} + +#[expect(dead_code)] +impl ConditionExpr { + /// Negate this expression: `!self`. + pub fn not(self) -> Self { + ConditionExpr::Not(Box::new(self)) + } + + /// Conjoin with another expression: `self && other`. + pub fn and(self, other: ConditionExpr) -> Self { + ConditionExpr::And(Box::new(self), Box::new(other)) + } + + /// Disjoin with another expression: `self || other`. + pub fn or(self, other: ConditionExpr) -> Self { + ConditionExpr::Or(Box::new(self), Box::new(other)) + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — parser +// --------------------------------------------------------------------------- + +/// Recursive descent parser for boolean condition expressions. +/// +/// Grammar (standard boolean precedence): +/// ```text +/// expr = or_expr +/// or_expr = and_expr ("||" and_expr)* +/// and_expr = unary ("&&" unary)* +/// unary = "!" unary | primary +/// primary = atom | "(" expr ")" +/// atom = [a-z][a-z0-9-]* +/// ``` +struct ExprParser<'a> { + input: &'a str, + pos: usize, +} + +impl<'a> ExprParser<'a> { + fn new(input: &'a str) -> Self { + Self { input, pos: 0 } + } + + fn skip_whitespace(&mut self) { + while self.pos < self.input.len() && self.input.as_bytes()[self.pos].is_ascii_whitespace() { + self.pos += 1; + } + } + + fn starts_with(&mut self, s: &str) -> bool { + self.skip_whitespace(); + self.input[self.pos..].starts_with(s) + } + + fn consume(&mut self, s: &str) -> bool { + self.skip_whitespace(); + if self.input[self.pos..].starts_with(s) { + self.pos += s.len(); + true + } else { + false + } + } + + /// Parse a full expression, expecting to consume all input. + fn parse(mut self) -> Result<ConditionExpr, String> { + let expr = self.parse_or()?; + self.skip_whitespace(); + if self.pos < self.input.len() { + return Err(format!( + "unexpected input at position {}: {:?}", + self.pos, + &self.input[self.pos..] + )); + } + Ok(expr) + } + + /// `or_expr` = `and_expr` ("||" `and_expr`)* + fn parse_or(&mut self) -> Result<ConditionExpr, String> { + let mut left = self.parse_and()?; + while self.starts_with("||") { + self.consume("||"); + let right = self.parse_and()?; + left = ConditionExpr::Or(Box::new(left), Box::new(right)); + } + Ok(left) + } + + /// `and_expr` = unary ("&&" unary)* + fn parse_and(&mut self) -> Result<ConditionExpr, String> { + let mut left = self.parse_unary()?; + while self.starts_with("&&") { + self.consume("&&"); + let right = self.parse_unary()?; + left = ConditionExpr::And(Box::new(left), Box::new(right)); + } + Ok(left) + } + + /// unary = "!" unary | primary + fn parse_unary(&mut self) -> Result<ConditionExpr, String> { + if self.consume("!") { + let inner = self.parse_unary()?; + Ok(ConditionExpr::Not(Box::new(inner))) + } else { + self.parse_primary() + } + } + + /// primary = "(" expr ")" | atom + fn parse_primary(&mut self) -> Result<ConditionExpr, String> { + if self.consume("(") { + let expr = self.parse_or()?; + if !self.consume(")") { + return Err(format!("expected ')' at position {}", self.pos)); + } + Ok(expr) + } else { + self.parse_atom() + } + } + + /// atom = [a-z][a-z0-9-]* + fn parse_atom(&mut self) -> Result<ConditionExpr, String> { + self.skip_whitespace(); + let start = self.pos; + while self.pos < self.input.len() { + let b = self.input.as_bytes()[self.pos]; + if b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'-' { + self.pos += 1; + } else { + break; + } + } + if self.pos == start { + return Err(format!("expected condition name at position {}", self.pos)); + } + let name = &self.input[start..self.pos]; + let atom = ConditionAtom::from_str(name)?; + Ok(ConditionExpr::Atom(atom)) + } +} + +impl ConditionExpr { + /// Parse a condition expression from a string. + pub fn parse(s: &str) -> Result<Self, String> { + let parser = ExprParser::new(s); + parser.parse() + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — Display +// --------------------------------------------------------------------------- + +/// Precedence levels for minimal-parentheses display. +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +enum Prec { + Or = 0, + And = 1, + Not = 2, + Atom = 3, +} + +impl ConditionExpr { + fn prec(&self) -> Prec { + match self { + ConditionExpr::Or(..) => Prec::Or, + ConditionExpr::And(..) => Prec::And, + ConditionExpr::Not(..) => Prec::Not, + ConditionExpr::Atom(..) => Prec::Atom, + } + } + + fn fmt_with_prec(&self, f: &mut fmt::Formatter<'_>, parent_prec: Prec) -> fmt::Result { + let needs_parens = self.prec() < parent_prec; + if needs_parens { + write!(f, "(")?; + } + match self { + ConditionExpr::Atom(atom) => write!(f, "{atom}")?, + ConditionExpr::Not(inner) => { + write!(f, "!")?; + inner.fmt_with_prec(f, Prec::Not)?; + } + ConditionExpr::And(lhs, rhs) => { + lhs.fmt_with_prec(f, Prec::And)?; + write!(f, " && ")?; + rhs.fmt_with_prec(f, Prec::And)?; + } + ConditionExpr::Or(lhs, rhs) => { + lhs.fmt_with_prec(f, Prec::Or)?; + write!(f, " || ")?; + rhs.fmt_with_prec(f, Prec::Or)?; + } + } + if needs_parens { + write!(f, ")")?; + } + Ok(()) + } +} + +impl fmt::Display for ConditionExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.fmt_with_prec(f, Prec::Or) + } +} + +// --------------------------------------------------------------------------- +// Serde +// --------------------------------------------------------------------------- + +impl Serialize for ConditionExpr { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for ConditionExpr { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { + let s = String::deserialize(deserializer)?; + ConditionExpr::parse(&s).map_err(serde::de::Error::custom) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn ctx( + cursor: usize, + width: usize, + byte_len: usize, + selected: usize, + len: usize, + ) -> EvalContext { + ctx_with_original(cursor, width, byte_len, selected, len, false) + } + + fn ctx_with_original( + cursor: usize, + width: usize, + byte_len: usize, + selected: usize, + len: usize, + original_input_empty: bool, + ) -> EvalContext { + EvalContext { + cursor_position: cursor, + input_width: width, + input_byte_len: byte_len, + selected_index: selected, + results_len: len, + original_input_empty, + has_context: false, + } + } + + // -- Atom evaluation (carried over from Phase 0) -- + + #[test] + fn atom_cursor_at_start() { + assert!(ConditionAtom::CursorAtStart.evaluate(&ctx(0, 5, 5, 0, 10))); + assert!(!ConditionAtom::CursorAtStart.evaluate(&ctx(3, 5, 5, 0, 10))); + } + + #[test] + fn atom_cursor_at_end() { + assert!(ConditionAtom::CursorAtEnd.evaluate(&ctx(5, 5, 5, 0, 10))); + assert!(!ConditionAtom::CursorAtEnd.evaluate(&ctx(3, 5, 5, 0, 10))); + assert!(ConditionAtom::CursorAtEnd.evaluate(&ctx(0, 0, 0, 0, 10))); + } + + #[test] + fn atom_input_empty() { + assert!(ConditionAtom::InputEmpty.evaluate(&ctx(0, 0, 0, 0, 10))); + assert!(!ConditionAtom::InputEmpty.evaluate(&ctx(0, 5, 5, 0, 10))); + } + + #[test] + fn atom_original_input_empty() { + // original_input_empty = true + assert!( + ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 0, 0, 0, 10, true)) + ); + // original_input_empty = false + assert!( + !ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 0, 0, 0, 10, false)) + ); + // original_input_empty is independent of current input state + assert!( + ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 5, 5, 0, 10, true)) + ); + } + + #[test] + fn atom_list_at_end() { + assert!(ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 99, 100))); + assert!(!ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 50, 100))); + assert!(ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn atom_list_at_start() { + assert!(ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 0, 100))); + assert!(!ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 50, 100))); + assert!(ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn atom_no_results_and_has_results() { + assert!(ConditionAtom::NoResults.evaluate(&ctx(0, 0, 0, 0, 0))); + assert!(!ConditionAtom::NoResults.evaluate(&ctx(0, 0, 0, 0, 5))); + assert!(ConditionAtom::HasResults.evaluate(&ctx(0, 0, 0, 0, 5))); + assert!(!ConditionAtom::HasResults.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn atom_has_context() { + let mut context = ctx(0, 0, 0, 0, 0); + assert!(!ConditionAtom::HasContext.evaluate(&context)); + context.has_context = true; + assert!(ConditionAtom::HasContext.evaluate(&context)); + } + + #[test] + fn atom_parse_round_trip() { + let conditions = [ + "cursor-at-start", + "cursor-at-end", + "input-empty", + "original-input-empty", + "list-at-end", + "list-at-start", + "no-results", + "has-results", + ]; + for s in conditions { + let c = ConditionAtom::from_str(s).unwrap(); + assert_eq!(c.as_str(), s); + } + } + + #[test] + fn atom_parse_unknown() { + assert!(ConditionAtom::from_str("unknown-condition").is_err()); + } + + // -- Parser tests -- + + #[test] + fn parse_bare_atom() { + let expr = ConditionExpr::parse("cursor-at-start").unwrap(); + assert_eq!(expr, ConditionExpr::Atom(ConditionAtom::CursorAtStart)); + } + + #[test] + fn parse_negation() { + let expr = ConditionExpr::parse("!no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Not(Box::new(ConditionExpr::Atom(ConditionAtom::NoResults))) + ); + } + + #[test] + fn parse_double_negation() { + let expr = ConditionExpr::parse("!!no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Not(Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( + ConditionAtom::NoResults + ))))) + ); + } + + #[test] + fn parse_and() { + let expr = ConditionExpr::parse("cursor-at-start && input-empty").unwrap(); + assert_eq!( + expr, + ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), + ) + ); + } + + #[test] + fn parse_or() { + let expr = ConditionExpr::parse("list-at-start || no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Or( + Box::new(ConditionExpr::Atom(ConditionAtom::ListAtStart)), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + ) + ); + } + + #[test] + fn parse_precedence_and_binds_tighter_than_or() { + // "a || b && c" should parse as "a || (b && c)" + let expr = ConditionExpr::parse("cursor-at-start || input-empty && no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Or( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + )), + ) + ); + } + + #[test] + fn parse_parens_override_precedence() { + // "(a || b) && c" + let expr = ConditionExpr::parse("(cursor-at-start || input-empty) && no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::And( + Box::new(ConditionExpr::Or( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), + )), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + ) + ); + } + + #[test] + fn parse_complex_nested() { + // "(a && !b) || c" + let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Or( + Box::new(ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( + ConditionAtom::InputEmpty + )))), + )), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + ) + ); + } + + #[test] + fn parse_whitespace_tolerance() { + let a = ConditionExpr::parse("cursor-at-start||input-empty").unwrap(); + let b = ConditionExpr::parse("cursor-at-start || input-empty").unwrap(); + let c = ConditionExpr::parse(" cursor-at-start || input-empty ").unwrap(); + assert_eq!(a, b); + assert_eq!(b, c); + } + + #[test] + fn parse_error_unknown_atom() { + assert!(ConditionExpr::parse("unknown-thing").is_err()); + } + + #[test] + fn parse_error_trailing_input() { + assert!(ConditionExpr::parse("cursor-at-start blah").is_err()); + } + + #[test] + fn parse_error_unmatched_paren() { + assert!(ConditionExpr::parse("(cursor-at-start").is_err()); + } + + #[test] + fn parse_error_empty() { + assert!(ConditionExpr::parse("").is_err()); + } + + // -- Expression evaluation -- + + #[test] + fn eval_not() { + let expr = ConditionExpr::parse("!no-results").unwrap(); + // Has results → !no-results is true + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 5))); + // No results → !no-results is false + assert!(!expr.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn eval_and() { + let expr = ConditionExpr::parse("cursor-at-start && input-empty").unwrap(); + // Both true + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 10))); + // First true, second false (non-empty input) + assert!(!expr.evaluate(&ctx(0, 5, 5, 0, 10))); + // First false (cursor not at start) + assert!(!expr.evaluate(&ctx(3, 5, 5, 0, 10))); + } + + #[test] + fn eval_or() { + let expr = ConditionExpr::parse("list-at-start || no-results").unwrap(); + // list at bottom (selected=0) + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 10))); + // no results + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 0))); + // neither + assert!(!expr.evaluate(&ctx(0, 0, 0, 5, 10))); + } + + #[test] + fn eval_complex_nested() { + // (cursor-at-start && !input-empty) || no-results + let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); + + // cursor at start, input not empty → true (left branch) + assert!(expr.evaluate(&ctx(0, 5, 5, 0, 10))); + // no results → true (right branch) + assert!(expr.evaluate(&ctx(3, 5, 5, 0, 0))); + // cursor not at start, has results → false + assert!(!expr.evaluate(&ctx(3, 5, 5, 0, 10))); + // cursor at start, input empty → false (left: && fails; right: has results) + assert!(!expr.evaluate(&ctx(0, 0, 0, 0, 10))); + } + + // -- Display -- + + #[test] + fn display_atom() { + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart); + assert_eq!(expr.to_string(), "cursor-at-start"); + } + + #[test] + fn display_not() { + let expr = ConditionExpr::Atom(ConditionAtom::NoResults).not(); + assert_eq!(expr.to_string(), "!no-results"); + } + + #[test] + fn display_and() { + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart) + .and(ConditionExpr::Atom(ConditionAtom::InputEmpty)); + assert_eq!(expr.to_string(), "cursor-at-start && input-empty"); + } + + #[test] + fn display_or() { + let expr = ConditionExpr::Atom(ConditionAtom::ListAtStart) + .or(ConditionExpr::Atom(ConditionAtom::NoResults)); + assert_eq!(expr.to_string(), "list-at-start || no-results"); + } + + #[test] + fn display_parens_when_needed() { + // (a || b) && c — the Or inside And needs parens + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart) + .or(ConditionExpr::Atom(ConditionAtom::InputEmpty)) + .and(ConditionExpr::Atom(ConditionAtom::NoResults)); + assert_eq!( + expr.to_string(), + "(cursor-at-start || input-empty) && no-results" + ); + } + + #[test] + fn display_no_parens_when_not_needed() { + // a || b && c — no parens needed (and binds tighter) + let inner_and = ConditionExpr::Atom(ConditionAtom::InputEmpty) + .and(ConditionExpr::Atom(ConditionAtom::NoResults)); + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart).or(inner_and); + assert_eq!( + expr.to_string(), + "cursor-at-start || input-empty && no-results" + ); + } + + // -- Display round-trip -- + + #[test] + fn display_round_trip() { + let cases = [ + "cursor-at-start", + "!no-results", + "cursor-at-start && input-empty", + "list-at-start || no-results", + "(cursor-at-start || input-empty) && no-results", + "(cursor-at-start && !input-empty) || no-results", + ]; + for s in cases { + let expr = ConditionExpr::parse(s).unwrap(); + let displayed = expr.to_string(); + let reparsed = ConditionExpr::parse(&displayed).unwrap(); + assert_eq!(expr, reparsed, "round-trip failed for: {s}"); + } + } + + // -- Serde -- + + #[test] + fn serde_simple_atom() { + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart); + let json = serde_json::to_string(&expr).unwrap(); + assert_eq!(json, "\"cursor-at-start\""); + let parsed: ConditionExpr = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, expr); + } + + #[test] + fn serde_compound_expression() { + let json = "\"cursor-at-start && !input-empty\""; + let parsed: ConditionExpr = serde_json::from_str(json).unwrap(); + let expected = ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( + ConditionAtom::InputEmpty, + )))), + ); + assert_eq!(parsed, expected); + } + + #[test] + fn serde_round_trip() { + let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); + let json = serde_json::to_string(&expr).unwrap(); + let parsed: ConditionExpr = serde_json::from_str(&json).unwrap(); + assert_eq!(expr, parsed); + } + + // -- From<ConditionAtom> -- + + #[test] + fn from_atom_into_expr() { + let expr: ConditionExpr = ConditionAtom::CursorAtStart.into(); + assert_eq!(expr, ConditionExpr::Atom(ConditionAtom::CursorAtStart)); + } + + // -- Builder helpers -- + + #[test] + fn builder_chain() { + let expr = ConditionExpr::from(ConditionAtom::CursorAtStart) + .and(ConditionExpr::from(ConditionAtom::InputEmpty).not()) + .or(ConditionExpr::from(ConditionAtom::NoResults)); + // And binds tighter than Or, so no parens needed around the And + assert_eq!( + expr.to_string(), + "cursor-at-start && !input-empty || no-results" + ); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/defaults.rs b/crates/turtle/src/command/client/search/keybindings/defaults.rs new file mode 100644 index 00000000..c8401e37 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/defaults.rs @@ -0,0 +1,1286 @@ +use std::collections::HashMap; + +use crate::atuin_client::settings::{KeyBindingConfig, Settings}; +use tracing::warn; + +use super::actions::Action; +use super::conditions::{ConditionAtom, ConditionExpr}; +use super::key::KeyInput; +use super::keymap::{KeyBinding, KeyRule, Keymap}; + +/// Helper to bind a scroll key with optional exit behavior. +/// +/// When `scroll_exits` is true AND the key scrolls toward index 0 (the newest +/// entry), we add a conditional rule: at `ListAtStart` → `Exit`, otherwise → +/// the scroll action. +/// +/// Whether a key scrolls toward index 0 depends on the `invert` setting: +/// - Non-inverted: "down" / "j" move toward index 0, "up" / "k" move away +/// - Inverted: "up" / "k" move toward index 0, "down" / "j" move away +/// +/// If `toward_index_zero` is false, or `scroll_exits` is false, we just bind +/// the key to the plain scroll action (no exit). +fn bind_scroll_key( + km: &mut Keymap, + key_str: &str, + action: Action, + toward_index_zero: bool, + scroll_exits: bool, +) { + let k = key(key_str); + if scroll_exits && toward_index_zero { + km.bind_conditional( + k, + vec![ + KeyRule::when(ConditionAtom::ListAtStart, Action::Exit), + KeyRule::always(action), + ], + ); + } else { + km.bind(k, action); + } +} + +/// Helper to parse a key string, panicking on invalid keys (these are all +/// compile-time-known strings). +fn key(s: &str) -> KeyInput { + KeyInput::parse(s).unwrap_or_else(|e| panic!("invalid default key {s:?}: {e}")) +} + +/// All five keymaps bundled together. +#[derive(Debug, Clone)] +pub struct KeymapSet { + pub emacs: Keymap, + pub vim_normal: Keymap, + pub vim_insert: Keymap, + pub inspector: Keymap, + pub prefix: Keymap, +} + +// --------------------------------------------------------------------------- +// Common bindings shared across search-tab keymaps +// --------------------------------------------------------------------------- + +/// Add the bindings that are common to all search-tab keymaps: +/// ctrl-c, ctrl-g, ctrl-o, and tab. +/// +/// Note: `esc`/`ctrl-[` are NOT included here because their behavior differs +/// between emacs (exit), vim-normal (exit), and vim-insert (enter normal mode). +fn add_common_bindings(km: &mut Keymap) { + km.bind(key("ctrl-c"), Action::ReturnOriginal); + km.bind(key("ctrl-g"), Action::ReturnOriginal); + km.bind(key("ctrl-o"), Action::ToggleTab); + + // Tab: always returns selection without executing (unlike Enter which respects enter_accept) + km.bind(key("tab"), Action::ReturnSelection); +} + +/// Returns `Accept` or `ReturnSelection` based on the `enter_accept` setting. +fn accept_action(settings: &Settings) -> Action { + if settings.enter_accept { + Action::Accept + } else { + Action::ReturnSelection + } +} + +// --------------------------------------------------------------------------- +// Emacs keymap (also base for vim-insert) +// --------------------------------------------------------------------------- + +/// Build the default emacs keymap. This encodes the behavior from +/// `handle_key_input` common section + `handle_search_input` shared section. +/// +/// The `settings` parameter is used for: +/// - `keys.prefix` — which ctrl-key enters prefix mode +/// - `keys.scroll_exits`, `invert` — scroll-at-boundary exit behavior +/// - `keys.accept_past_line_end` — right arrow at end of line accepts +/// - `keys.exit_past_line_start` — left arrow at start of line exits +/// - `keys.accept_past_line_start` — left arrow at start accepts (overrides exit) +/// - `keys.accept_with_backspace` — backspace at start of line accepts +/// - `ctrl_n_shortcuts` — whether alt or ctrl is used for numeric shortcuts +// Keymap builder that enumerates every default binding; not worth splitting. +#[expect(clippy::too_many_lines)] +pub fn default_emacs_keymap(settings: &Settings) -> Keymap { + let mut km = Keymap::new(); + add_common_bindings(&mut km); + + let accept = accept_action(settings); + + // esc / ctrl-[ → exit + km.bind(key("esc"), Action::Exit); + km.bind(key("ctrl-["), Action::Exit); + + // Prefix key: ctrl-<prefix_char> → enter prefix mode + let prefix_char = settings.keys.prefix.chars().next().unwrap_or('a'); + km.bind(key(&format!("ctrl-{prefix_char}")), Action::EnterPrefixMode); + + // --- Accept / navigation edge behaviors (from [keys] settings) --- + + // right: behavior at end of line + if settings.keys.accept_past_line_end { + km.bind_conditional( + key("right"), + vec![ + KeyRule::when(ConditionAtom::CursorAtEnd, Action::ReturnSelection), + KeyRule::always(Action::CursorRight), + ], + ); + } else { + km.bind(key("right"), Action::CursorRight); + } + + // left: behavior at start of line + // accept_past_line_start takes precedence over exit_past_line_start + if settings.keys.accept_past_line_start { + km.bind_conditional( + key("left"), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::ReturnSelection), + KeyRule::always(Action::CursorLeft), + ], + ); + } else if settings.keys.exit_past_line_start { + km.bind_conditional( + key("left"), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit), + KeyRule::always(Action::CursorLeft), + ], + ); + } else { + km.bind(key("left"), Action::CursorLeft); + } + + // down/up: scroll with optional exit at boundary. + // Non-inverted: down moves toward index 0 (can exit); up moves away (no exit). + // Inverted: up moves toward index 0 (can exit); down moves away (no exit). + let scroll_exits = settings.keys.scroll_exits; + let invert = settings.invert; + bind_scroll_key(&mut km, "down", Action::SelectNext, !invert, scroll_exits); + bind_scroll_key(&mut km, "up", Action::SelectPrevious, invert, scroll_exits); + + // backspace: behavior at start of line + if settings.keys.accept_with_backspace { + km.bind_conditional( + key("backspace"), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::ReturnSelection), + KeyRule::always(Action::DeleteCharBefore), + ], + ); + } else { + km.bind(key("backspace"), Action::DeleteCharBefore); + } + + // --- Accept --- + km.bind(key("enter"), accept.clone()); + km.bind(key("ctrl-m"), accept); + + // --- Copy --- + km.bind(key("ctrl-y"), Action::Copy); + + // --- Numeric shortcuts (alt-1..9 by default, ctrl-1..9 if ctrl_n_shortcuts) --- + // These return the selection without executing, regardless of enter_accept. + let num_mod = if settings.ctrl_n_shortcuts { + "ctrl" + } else { + "alt" + }; + for n in 1..=9u8 { + km.bind( + key(&format!("{num_mod}-{n}")), + Action::ReturnSelectionNth(n), + ); + } + + // --- Cursor movement --- + km.bind(key("ctrl-left"), Action::CursorWordLeft); + km.bind(key("alt-b"), Action::CursorWordLeft); + km.bind(key("ctrl-b"), Action::CursorLeft); + km.bind(key("ctrl-right"), Action::CursorWordRight); + km.bind(key("alt-f"), Action::CursorWordRight); + km.bind(key("ctrl-f"), Action::CursorRight); + km.bind(key("home"), Action::CursorStart); + // ctrl-a → CursorStart only if prefix char is NOT 'a' + // (otherwise ctrl-a is already bound to EnterPrefixMode above) + if prefix_char != 'a' { + km.bind(key("ctrl-a"), Action::CursorStart); + } + km.bind(key("ctrl-e"), Action::CursorEnd); + km.bind(key("end"), Action::CursorEnd); + + // --- Editing --- + km.bind(key("ctrl-backspace"), Action::DeleteWordBefore); + km.bind(key("ctrl-h"), Action::DeleteCharBefore); + km.bind(key("ctrl-?"), Action::DeleteCharBefore); + km.bind(key("ctrl-delete"), Action::DeleteWordAfter); + km.bind(key("delete"), Action::DeleteCharAfter); + // ctrl-d: if input empty → return original, otherwise delete char + km.bind_conditional( + key("ctrl-d"), + vec![ + KeyRule::when(ConditionAtom::InputEmpty, Action::ReturnOriginal), + KeyRule::always(Action::DeleteCharAfter), + ], + ); + km.bind(key("ctrl-w"), Action::DeleteToWordBoundary); + km.bind(key("ctrl-u"), Action::ClearLine); + + // --- Search mode --- + km.bind(key("ctrl-r"), Action::CycleFilterMode); + km.bind(key("ctrl-s"), Action::CycleSearchMode); + + // --- Scroll (no exit) --- + km.bind(key("ctrl-n"), Action::SelectNext); + km.bind(key("ctrl-j"), Action::SelectNext); + km.bind(key("ctrl-p"), Action::SelectPrevious); + km.bind(key("ctrl-k"), Action::SelectPrevious); + + // --- Redraw --- + km.bind(key("ctrl-l"), Action::Redraw); + + // --- Page scroll --- + km.bind(key("pagedown"), Action::ScrollPageDown); + km.bind(key("pageup"), Action::ScrollPageUp); + + km +} + +// --------------------------------------------------------------------------- +// Vim Normal keymap +// --------------------------------------------------------------------------- + +/// Build the default vim-normal keymap. +pub fn default_vim_normal_keymap(settings: &Settings) -> Keymap { + let mut km = Keymap::new(); + add_common_bindings(&mut km); + + // esc / ctrl-[ → exit (vim-normal exits, unlike vim-insert) + km.bind(key("esc"), Action::Exit); + km.bind(key("ctrl-["), Action::Exit); + + // Prefix key + let prefix_char = settings.keys.prefix.chars().next().unwrap_or('a'); + km.bind(key(&format!("ctrl-{prefix_char}")), Action::EnterPrefixMode); + + // --- Vim navigation --- + // j/k: scroll with optional exit at boundary. + let scroll_exits = settings.keys.scroll_exits; + let invert = settings.invert; + bind_scroll_key(&mut km, "j", Action::SelectNext, !invert, scroll_exits); + bind_scroll_key(&mut km, "k", Action::SelectPrevious, invert, scroll_exits); + km.bind(key("h"), Action::CursorLeft); + km.bind(key("l"), Action::CursorRight); + + // --- Vim cursor movement --- + km.bind(key("0"), Action::CursorStart); + km.bind(key("$"), Action::CursorEnd); + km.bind(key("w"), Action::CursorWordRight); + km.bind(key("b"), Action::CursorWordLeft); + km.bind(key("e"), Action::CursorWordEnd); + + // --- Vim editing --- + km.bind(key("x"), Action::DeleteCharAfter); + km.bind(key("d d"), Action::ClearLine); + km.bind(key("D"), Action::ClearToEnd); + km.bind(key("C"), Action::VimChangeToEnd); + + // --- Mode switching --- + km.bind(key("?"), Action::VimSearchInsert); + km.bind(key("/"), Action::VimSearchInsert); + km.bind(key("a"), Action::VimEnterInsertAfter); + km.bind(key("A"), Action::VimEnterInsertAtEnd); + km.bind(key("i"), Action::VimEnterInsert); + km.bind(key("I"), Action::VimEnterInsertAtStart); + + // --- Numeric shortcuts (return selection without executing) --- + for n in 1..=9u8 { + km.bind(key(&n.to_string()), Action::ReturnSelectionNth(n)); + } + + // --- Half/full page scroll --- + km.bind(key("ctrl-u"), Action::ScrollHalfPageUp); + km.bind(key("ctrl-d"), Action::ScrollHalfPageDown); + km.bind(key("ctrl-b"), Action::ScrollPageUp); + km.bind(key("ctrl-f"), Action::ScrollPageDown); + + // --- Jump --- + km.bind(key("G"), Action::ScrollToBottom); + km.bind(key("g g"), Action::ScrollToTop); + km.bind(key("H"), Action::ScrollToScreenTop); + km.bind(key("M"), Action::ScrollToScreenMiddle); + km.bind(key("L"), Action::ScrollToScreenBottom); + + // --- Arrow keys (same as emacs for convenience) --- + bind_scroll_key(&mut km, "down", Action::SelectNext, !invert, scroll_exits); + bind_scroll_key(&mut km, "up", Action::SelectPrevious, invert, scroll_exits); + + // --- Page scroll --- + km.bind(key("pagedown"), Action::ScrollPageDown); + km.bind(key("pageup"), Action::ScrollPageUp); + + // --- Accept --- + let accept = accept_action(settings); + km.bind(key("enter"), accept); + + km +} + +// --------------------------------------------------------------------------- +// Vim Insert keymap +// --------------------------------------------------------------------------- + +/// Build the default vim-insert keymap. This clones the emacs keymap and +/// overlays vim-insert-specific bindings (esc → enter normal mode). +pub fn default_vim_insert_keymap(settings: &Settings) -> Keymap { + let mut km = default_emacs_keymap(settings); + + // Override esc and ctrl-[ to enter normal mode instead of exiting + km.bind(key("esc"), Action::VimEnterNormal); + km.bind(key("ctrl-["), Action::VimEnterNormal); + + km +} + +// --------------------------------------------------------------------------- +// Inspector keymap +// --------------------------------------------------------------------------- + +/// Build the default inspector keymap (tab index 1). +/// +/// The inspector shows details about the selected history item and has no +/// text input, so we build a minimal keymap with only inspector-relevant +/// bindings. We respect the user's `keymap_mode` to provide vim-style j/k +/// navigation for vim users. +pub fn default_inspector_keymap(settings: &Settings) -> Keymap { + use crate::atuin_client::settings::KeymapMode; + + let mut km = Keymap::new(); + + // Common bindings (same as search tab) + km.bind(key("ctrl-c"), Action::ReturnOriginal); + km.bind(key("ctrl-g"), Action::ReturnOriginal); + km.bind(key("esc"), Action::Exit); + km.bind(key("ctrl-["), Action::Exit); + km.bind(key("tab"), Action::ReturnSelection); + km.bind(key("ctrl-o"), Action::ToggleTab); + + // Accept behavior respects enter_accept setting + let accept = if settings.enter_accept { + Action::Accept + } else { + Action::ReturnSelection + }; + km.bind(key("enter"), accept); + + // Inspector-specific: delete history entry + km.bind(key("ctrl-d"), Action::Delete); + + // Inspector navigation + km.bind(key("up"), Action::InspectPrevious); + km.bind(key("down"), Action::InspectNext); + km.bind(key("pageup"), Action::InspectPrevious); + km.bind(key("pagedown"), Action::InspectNext); + + // For vim users, add j/k navigation + if matches!( + settings.keymap_mode, + KeymapMode::VimNormal | KeymapMode::VimInsert + ) { + km.bind(key("j"), Action::InspectNext); + km.bind(key("k"), Action::InspectPrevious); + } + + km +} + +// --------------------------------------------------------------------------- +// Prefix keymap +// --------------------------------------------------------------------------- + +/// Build the default prefix keymap (active after ctrl-a prefix). +pub fn default_prefix_keymap() -> Keymap { + let mut km = Keymap::new(); + + km.bind(key("d"), Action::Delete); + km.bind(key("D"), Action::DeleteAll); + km.bind(key("a"), Action::CursorStart); + km.bind_conditional( + key("c"), + vec![ + KeyRule::when(ConditionAtom::HasContext, Action::ClearContext), + KeyRule::always(Action::SwitchContext), + ], + ); + + km +} + +// --------------------------------------------------------------------------- +// KeymapSet construction +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// Config → Keymap conversion +// --------------------------------------------------------------------------- + +/// Convert a `KeyBindingConfig` (from TOML) into a `KeyBinding`. +/// Returns `Err` if an action name or condition expression is invalid. +fn parse_binding_config(config: &KeyBindingConfig) -> Result<KeyBinding, String> { + match config { + KeyBindingConfig::Simple(action_str) => { + let action = Action::from_str(action_str)?; + Ok(KeyBinding::simple(action)) + } + KeyBindingConfig::Rules(rules) => { + let mut parsed_rules = Vec::with_capacity(rules.len()); + for rule_cfg in rules { + let action = Action::from_str(&rule_cfg.action)?; + let rule = match &rule_cfg.when { + None => KeyRule::always(action), + Some(cond_str) => { + let cond = ConditionExpr::parse(cond_str)?; + KeyRule::when(cond, action) + } + }; + parsed_rules.push(rule); + } + Ok(KeyBinding::conditional(parsed_rules)) + } + } +} + +/// Apply a map of key-string → binding-config overrides to a keymap. +/// Per-key override replaces the entire rule list for that key. +/// Invalid keys or action names are logged and skipped. +fn apply_config_to_keymap(keymap: &mut Keymap, overrides: &HashMap<String, KeyBindingConfig>) { + for (key_str, binding_cfg) in overrides { + let key = match KeyInput::parse(key_str) { + Ok(k) => k, + Err(e) => { + warn!("invalid key in keymap config: {key_str:?}: {e}"); + continue; + } + }; + match parse_binding_config(binding_cfg) { + Ok(binding) => { + keymap.bindings.insert(key, binding); + } + Err(e) => { + warn!("invalid binding for {key_str:?} in keymap config: {e}"); + } + } + } +} + +impl KeymapSet { + /// Build the complete set of default keymaps from settings. + pub fn defaults(settings: &Settings) -> Self { + KeymapSet { + emacs: default_emacs_keymap(settings), + vim_normal: default_vim_normal_keymap(settings), + vim_insert: default_vim_insert_keymap(settings), + inspector: default_inspector_keymap(settings), + prefix: default_prefix_keymap(), + } + } + + /// Build keymaps from settings, applying any user `[keymap]` overrides. + /// + /// Precedence rules: + /// - If `[keymap]` has any entries, `[keys]` is **ignored entirely**. + /// Defaults are built with standard `[keys]` values, then `[keymap]` + /// overrides are applied per-key. + /// - If `[keymap]` is empty/absent, `[keys]` customizes the defaults + /// (current behavior for backward compatibility). + pub fn from_settings(settings: &Settings) -> Self { + use crate::atuin_client::settings::Keys; + + if settings.keymap.is_empty() { + // No [keymap] section → use [keys] to customize defaults + Self::defaults(settings) + } else { + // [keymap] present → ignore [keys], use standard defaults as base + let mut base_settings = settings.clone(); + base_settings.keys = Keys::standard_defaults(); + let mut set = Self::defaults(&base_settings); + set.apply_config(settings); + set + } + } + + /// Apply user keymap config overrides to all modes. + fn apply_config(&mut self, settings: &Settings) { + let config = &settings.keymap; + apply_config_to_keymap(&mut self.emacs, &config.emacs); + apply_config_to_keymap(&mut self.vim_normal, &config.vim_normal); + apply_config_to_keymap(&mut self.vim_insert, &config.vim_insert); + apply_config_to_keymap(&mut self.inspector, &config.inspector); + apply_config_to_keymap(&mut self.prefix, &config.prefix); + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::command::client::search::keybindings::conditions::EvalContext; + + fn make_ctx(cursor: usize, width: usize, selected: usize, len: usize) -> EvalContext { + EvalContext { + cursor_position: cursor, + input_width: width, + input_byte_len: width, + selected_index: selected, + results_len: len, + original_input_empty: false, + has_context: false, + } + } + + fn default_settings() -> Settings { + Settings::utc() + } + + // -- Emacs keymap tests -- + + #[test] + fn emacs_ctrl_c_returns_original() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-c"), &ctx), + Some(Action::ReturnOriginal) + ); + } + + #[test] + fn emacs_esc_exits() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::Exit)); + } + + #[test] + fn emacs_tab_returns_selection() { + // enter_accept=false in test defaults → ReturnSelection + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); + } + + #[test] + fn emacs_enter_returns_selection() { + // enter_accept=false in test defaults → ReturnSelection + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn emacs_enter_accept_true_uses_accept() { + let mut settings = default_settings(); + settings.enter_accept = true; + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("enter"), &ctx), Some(Action::Accept)); + assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); + } + + #[test] + fn emacs_right_at_end_returns_selection() { + let km = default_emacs_keymap(&default_settings()); + // cursor at end of "hello" (width 5) + let ctx = make_ctx(5, 5, 0, 10); + assert_eq!( + km.resolve(&key("right"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn emacs_right_not_at_end_moves() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(2, 5, 0, 10); + assert_eq!(km.resolve(&key("right"), &ctx), Some(Action::CursorRight)); + } + + #[test] + fn emacs_left_at_start_exits() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 5, 0, 10); + assert_eq!(km.resolve(&key("left"), &ctx), Some(Action::Exit)); + } + + #[test] + fn emacs_left_not_at_start_moves() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(3, 5, 0, 10); + assert_eq!(km.resolve(&key("left"), &ctx), Some(Action::CursorLeft)); + } + + #[test] + fn emacs_down_at_start_exits() { + let km = default_emacs_keymap(&default_settings()); + // selected=0 → ListAtStart → Exit + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::Exit)); + } + + #[test] + fn emacs_down_not_at_start_selects_next() { + let km = default_emacs_keymap(&default_settings()); + // selected=5 → not at start → SelectNext + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::SelectNext)); + } + + #[test] + fn emacs_up_selects_previous() { + let km = default_emacs_keymap(&default_settings()); + // Non-inverted: up never exits (moves away from index 0) + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("up"), &ctx), Some(Action::SelectPrevious)); + } + + #[test] + fn emacs_ctrl_d_empty_returns_original() { + let km = default_emacs_keymap(&default_settings()); + // input empty (byte_len = 0) + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-d"), &ctx), + Some(Action::ReturnOriginal) + ); + } + + #[test] + fn emacs_ctrl_d_nonempty_deletes() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(2, 5, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-d"), &ctx), + Some(Action::DeleteCharAfter) + ); + } + + #[test] + fn emacs_ctrl_n_selects_next_no_exit_condition() { + let km = default_emacs_keymap(&default_settings()); + // at start, but ctrl-n should NOT exit (no exit condition bound) + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("ctrl-n"), &ctx), Some(Action::SelectNext)); + } + + #[test] + fn emacs_prefix_key_enters_prefix() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-a"), &ctx), + Some(Action::EnterPrefixMode) + ); + } + + #[test] + fn emacs_home_cursor_start() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(5, 10, 0, 10); + assert_eq!(km.resolve(&key("home"), &ctx), Some(Action::CursorStart)); + } + + // -- Vim Normal keymap tests -- + + #[test] + fn vim_normal_j_at_start_exits() { + let km = default_vim_normal_keymap(&default_settings()); + // selected=0 → ListAtStart → Exit (non-inverted: j moves toward index 0) + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("j"), &ctx), Some(Action::Exit)); + } + + #[test] + fn vim_normal_j_not_at_start_selects_next() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("j"), &ctx), Some(Action::SelectNext)); + } + + #[test] + fn vim_normal_k_selects_previous() { + let km = default_vim_normal_keymap(&default_settings()); + // Non-inverted: k never exits (moves away from index 0) + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("k"), &ctx), Some(Action::SelectPrevious)); + } + + #[test] + fn vim_normal_i_enters_insert() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("i"), &ctx), Some(Action::VimEnterInsert)); + } + + #[test] + fn vim_normal_slash_search_insert() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("/"), &ctx), Some(Action::VimSearchInsert)); + } + + #[test] + fn vim_normal_gg_scroll_to_top() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!(km.resolve(&key("g g"), &ctx), Some(Action::ScrollToTop)); + } + + #[test] + fn vim_normal_big_g_scroll_to_bottom() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!(km.resolve(&key("G"), &ctx), Some(Action::ScrollToBottom)); + } + + #[test] + fn vim_normal_numeric_returns_selection() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("3"), &ctx), + Some(Action::ReturnSelectionNth(3)) + ); + } + + #[test] + fn vim_normal_ctrl_u_half_page_up() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!( + km.resolve(&key("ctrl-u"), &ctx), + Some(Action::ScrollHalfPageUp) + ); + } + + #[test] + fn vim_normal_screen_jumps() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!(km.resolve(&key("H"), &ctx), Some(Action::ScrollToScreenTop)); + assert_eq!( + km.resolve(&key("M"), &ctx), + Some(Action::ScrollToScreenMiddle) + ); + assert_eq!( + km.resolve(&key("L"), &ctx), + Some(Action::ScrollToScreenBottom) + ); + } + + #[test] + fn vim_normal_enter_returns_selection() { + // enter_accept=false in test defaults → ReturnSelection + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn vim_normal_enter_accept_true_uses_accept() { + let mut settings = default_settings(); + settings.enter_accept = true; + let km = default_vim_normal_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("enter"), &ctx), Some(Action::Accept)); + } + + // -- Vim Insert keymap tests -- + + #[test] + fn vim_insert_inherits_emacs_enter() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + // enter_accept=false → ReturnSelection + assert_eq!( + km.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn vim_insert_esc_enters_normal() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::VimEnterNormal)); + } + + #[test] + fn vim_insert_ctrl_bracket_enters_normal() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-["), &ctx), + Some(Action::VimEnterNormal) + ); + } + + #[test] + fn vim_insert_inherits_emacs_ctrl_d() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + // input empty → return original + assert_eq!( + km.resolve(&key("ctrl-d"), &ctx), + Some(Action::ReturnOriginal) + ); + } + + // -- Inspector keymap tests -- + + #[test] + fn inspector_ctrl_d_deletes() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("ctrl-d"), &ctx), Some(Action::Delete)); + } + + #[test] + fn inspector_up_inspects_previous() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("up"), &ctx), Some(Action::InspectPrevious)); + } + + #[test] + fn inspector_down_inspects_next() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::InspectNext)); + } + + #[test] + fn inspector_esc_exits() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::Exit)); + } + + #[test] + fn inspector_tab_returns_selection() { + // enter_accept=false → ReturnSelection + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); + } + + // -- Prefix keymap tests -- + + #[test] + fn prefix_d_deletes() { + let km = default_prefix_keymap(); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("d"), &ctx), Some(Action::Delete)); + } + + #[test] + fn prefix_a_cursor_start() { + let km = default_prefix_keymap(); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("a"), &ctx), Some(Action::CursorStart)); + } + + #[test] + fn prefix_unknown_key_returns_none() { + let km = default_prefix_keymap(); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("x"), &ctx), None); + } + + // -- KeymapSet tests -- + + #[test] + fn keymap_set_defaults_builds() { + let settings = default_settings(); + let set = KeymapSet::defaults(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // Sanity check each keymap has bindings + assert!(set.emacs.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.vim_normal.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.vim_insert.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.inspector.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.prefix.resolve(&key("d"), &ctx).is_some()); + } + + // -- Settings-dependent behavior -- + + #[test] + fn custom_prefix_char() { + let mut settings = default_settings(); + settings.keys.prefix = "x".to_string(); + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // ctrl-x should be prefix mode + assert_eq!( + km.resolve(&key("ctrl-x"), &ctx), + Some(Action::EnterPrefixMode) + ); + // ctrl-a should now be CursorStart (not prefix) + assert_eq!(km.resolve(&key("ctrl-a"), &ctx), Some(Action::CursorStart)); + } + + #[test] + fn ctrl_n_shortcuts_changes_numeric_modifier() { + let mut settings = default_settings(); + settings.ctrl_n_shortcuts = true; + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // ctrl-1 should work + assert_eq!( + km.resolve(&key("ctrl-1"), &ctx), + Some(Action::ReturnSelectionNth(1)) + ); + // alt-1 should NOT be bound + assert_eq!(km.resolve(&key("alt-1"), &ctx), None); + } + + #[test] + fn default_alt_numeric_shortcuts() { + let settings = default_settings(); + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // alt-1 should work by default + assert_eq!( + km.resolve(&key("alt-1"), &ctx), + Some(Action::ReturnSelectionNth(1)) + ); + } + + // ----------------------------------------------------------------------- + // Config parsing and merging tests + // ----------------------------------------------------------------------- + + #[test] + fn parse_simple_binding_config() { + use crate::atuin_client::settings::KeyBindingConfig; + let cfg = KeyBindingConfig::Simple("accept".to_string()); + let binding = super::parse_binding_config(&cfg).unwrap(); + assert_eq!(binding.rules.len(), 1); + assert!(binding.rules[0].condition.is_none()); + assert_eq!(binding.rules[0].action, Action::Accept); + } + + #[test] + fn parse_conditional_binding_config() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + let cfg = KeyBindingConfig::Rules(vec![ + KeyRuleConfig { + when: Some("cursor-at-start".to_string()), + action: "exit".to_string(), + }, + KeyRuleConfig { + when: None, + action: "cursor-left".to_string(), + }, + ]); + let binding = super::parse_binding_config(&cfg).unwrap(); + assert_eq!(binding.rules.len(), 2); + assert!(binding.rules[0].condition.is_some()); + assert_eq!(binding.rules[0].action, Action::Exit); + assert!(binding.rules[1].condition.is_none()); + assert_eq!(binding.rules[1].action, Action::CursorLeft); + } + + #[test] + fn parse_binding_config_invalid_action() { + use crate::atuin_client::settings::KeyBindingConfig; + let cfg = KeyBindingConfig::Simple("not-a-real-action".to_string()); + assert!(super::parse_binding_config(&cfg).is_err()); + } + + #[test] + fn parse_binding_config_invalid_condition() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + let cfg = KeyBindingConfig::Rules(vec![KeyRuleConfig { + when: Some("not-a-real-condition".to_string()), + action: "exit".to_string(), + }]); + assert!(super::parse_binding_config(&cfg).is_err()); + } + + #[test] + fn config_override_replaces_key() { + use crate::atuin_client::settings::KeyBindingConfig; + use std::collections::HashMap; + + let mut settings = default_settings(); + let set = KeymapSet::defaults(&settings); + + // Default: ctrl-c → ReturnOriginal + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + set.emacs.resolve(&key("ctrl-c"), &ctx), + Some(Action::ReturnOriginal) + ); + + // Override ctrl-c → Exit via config + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + + let set = KeymapSet::from_settings(&settings); + assert_eq!(set.emacs.resolve(&key("ctrl-c"), &ctx), Some(Action::Exit)); + } + + #[test] + fn config_override_preserves_unoverridden_keys() { + use crate::atuin_client::settings::KeyBindingConfig; + use std::collections::HashMap; + + let mut settings = default_settings(); + // Override only ctrl-c; enter should keep its default + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + + let set = KeymapSet::from_settings(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // ctrl-c overridden + assert_eq!(set.emacs.resolve(&key("ctrl-c"), &ctx), Some(Action::Exit)); + // enter still has default (enter_accept=false → ReturnSelection) + assert_eq!( + set.emacs.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn config_conditional_override() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + use std::collections::HashMap; + + let mut settings = default_settings(); + // Override "up" with a custom conditional + settings.keymap.emacs = HashMap::from([( + "up".to_string(), + KeyBindingConfig::Rules(vec![ + KeyRuleConfig { + when: Some("no-results".to_string()), + action: "exit".to_string(), + }, + KeyRuleConfig { + when: None, + action: "select-previous".to_string(), + }, + ]), + )]); + + let set = KeymapSet::from_settings(&settings); + + // With no results → exit + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(set.emacs.resolve(&key("up"), &ctx), Some(Action::Exit)); + + // With results → select-previous + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + set.emacs.resolve(&key("up"), &ctx), + Some(Action::SelectPrevious) + ); + } + + #[test] + fn from_settings_with_empty_config_equals_defaults() { + let settings = default_settings(); + let defaults = KeymapSet::defaults(&settings); + let from_settings = KeymapSet::from_settings(&settings); + + // Verify a sample of keys produce the same results + let ctx = make_ctx(0, 0, 0, 10); + let test_keys = [ + "ctrl-c", "enter", "esc", "tab", "up", "down", "left", "right", + ]; + for k in &test_keys { + assert_eq!( + defaults.emacs.resolve(&key(k), &ctx), + from_settings.emacs.resolve(&key(k), &ctx), + "mismatch for emacs key {k}" + ); + } + } + + // ----------------------------------------------------------------------- + // Phase 5: [keys] vs [keymap] backward compatibility + // ----------------------------------------------------------------------- + + #[test] + fn keymap_overrides_ignore_keys_section() { + use crate::atuin_client::settings::KeyBindingConfig; + + // Set up: [keys] disables scroll_exits, but [keymap] is present + let mut settings = default_settings(); + settings.keys.scroll_exits = false; + + // Without [keymap], scroll_exits=false means no exit condition on down + let set_legacy = KeymapSet::defaults(&settings); + // At list-at-start (selected=0), down should still be SelectNext (no exit) + let ctx_at_boundary = make_ctx(0, 0, 0, 10); + assert_eq!( + set_legacy.emacs.resolve(&key("down"), &ctx_at_boundary), + Some(Action::SelectNext), + "legacy: down at boundary should be SelectNext with scroll_exits=false" + ); + + // With [keymap] present (even just one override), [keys] is ignored + // so the standard defaults (scroll_exits=true) apply + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + let set_keymap = KeymapSet::from_settings(&settings); + + // Not at boundary (selected=5): should SelectNext normally + let ctx_not_at_boundary = make_ctx(0, 0, 5, 10); + assert_eq!( + set_keymap.emacs.resolve(&key("down"), &ctx_not_at_boundary), + Some(Action::SelectNext), + "keymap: down not at boundary should SelectNext" + ); + // At list-at-start (selected=0): should Exit (standard scroll_exits=true) + assert_eq!( + set_keymap.emacs.resolve(&key("down"), &ctx_at_boundary), + Some(Action::Exit), + "keymap: down at boundary should Exit (standard defaults restored)" + ); + } + + #[test] + fn keymap_present_resets_to_standard_keys_defaults() { + use crate::atuin_client::settings::KeyBindingConfig; + + let mut settings = default_settings(); + // Disable all [keys] behaviors + settings.keys.exit_past_line_start = false; + settings.keys.accept_past_line_end = false; + + // Without [keymap], left should be plain CursorLeft + let set_legacy = KeymapSet::defaults(&settings); + let ctx_at_start = make_ctx(0, 5, 0, 10); + assert_eq!( + set_legacy.emacs.resolve(&key("left"), &ctx_at_start), + Some(Action::CursorLeft), + "legacy: left should be plain CursorLeft without exit_past_line_start" + ); + + // Add a [keymap] entry (for a different key) + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + let set_keymap = KeymapSet::from_settings(&settings); + + // Now left should use standard defaults (exit_past_line_start=true) + // At cursor start → Exit + assert_eq!( + set_keymap.emacs.resolve(&key("left"), &ctx_at_start), + Some(Action::Exit), + "keymap: left at cursor start should exit (standard defaults)" + ); + + // Right at cursor end should return selection (standard defaults: accept_past_line_end=true, enter_accept=false) + let ctx_at_end = make_ctx(5, 5, 0, 10); + assert_eq!( + set_keymap.emacs.resolve(&key("right"), &ctx_at_end), + Some(Action::ReturnSelection), + "keymap: right at cursor end should return selection (standard defaults)" + ); + } + + #[test] + fn keys_has_non_default_values_detection() { + use crate::atuin_client::settings::Keys; + + let standard = Keys::standard_defaults(); + assert!(!standard.has_non_default_values()); + + let mut modified = Keys::standard_defaults(); + modified.scroll_exits = false; + assert!(modified.has_non_default_values()); + + let mut modified = Keys::standard_defaults(); + modified.prefix = "x".to_string(); + assert!(modified.has_non_default_values()); + } + + #[test] + fn original_input_empty_condition_in_config() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + use std::collections::HashMap; + + let mut settings = default_settings(); + // Configure esc to: if original-input-empty -> return-query, else return-original + settings.keymap.emacs = HashMap::from([( + "esc".to_string(), + KeyBindingConfig::Rules(vec![ + KeyRuleConfig { + when: Some("original-input-empty".to_string()), + action: "return-query".to_string(), + }, + KeyRuleConfig { + when: None, + action: "return-original".to_string(), + }, + ]), + )]); + + let set = KeymapSet::from_settings(&settings); + + // When original input was empty, should return-query + let ctx_original_empty = EvalContext { + cursor_position: 0, + input_width: 5, + input_byte_len: 5, + selected_index: 0, + results_len: 10, + original_input_empty: true, + has_context: false, + }; + assert_eq!( + set.emacs.resolve(&key("esc"), &ctx_original_empty), + Some(Action::ReturnQuery), + "esc with original_input_empty=true should return-query" + ); + + // When original input was not empty, should return-original + let ctx_original_not_empty = EvalContext { + cursor_position: 0, + input_width: 5, + input_byte_len: 5, + selected_index: 0, + results_len: 10, + original_input_empty: false, + has_context: false, + }; + assert_eq!( + set.emacs.resolve(&key("esc"), &ctx_original_not_empty), + Some(Action::ReturnOriginal), + "esc with original_input_empty=false should return-original" + ); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/key.rs b/crates/turtle/src/command/client/search/keybindings/key.rs new file mode 100644 index 00000000..c2eb31c6 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/key.rs @@ -0,0 +1,629 @@ +use std::fmt; + +use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers, MediaKeyCode}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// A single key press with modifiers (e.g. `ctrl-c`, `alt-f`, `enter`). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[expect(clippy::struct_excessive_bools)] +pub struct SingleKey { + pub code: KeyCodeValue, + pub ctrl: bool, + pub alt: bool, + pub shift: bool, + pub super_key: bool, +} + +/// The key code portion of a key press. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum KeyCodeValue { + Char(char), + Enter, + Esc, + Tab, + Backspace, + Delete, + Insert, + Up, + Down, + Left, + Right, + Home, + End, + PageUp, + PageDown, + Space, + F(u8), + Media(MediaKeyCode), +} + +/// A key input that may be a single key or a multi-key sequence (e.g. `g g`). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum KeyInput { + Single(SingleKey), + Sequence(Vec<SingleKey>), +} + +impl SingleKey { + /// Convert a crossterm `KeyEvent` into a `SingleKey`. + pub fn from_event(event: &KeyEvent) -> Option<Self> { + let ctrl = event.modifiers.contains(KeyModifiers::CONTROL); + let alt = event.modifiers.contains(KeyModifiers::ALT); + let shift = event.modifiers.contains(KeyModifiers::SHIFT); + let super_key = event.modifiers.contains(KeyModifiers::SUPER); + + let code = match event.code { + KeyCode::Char(' ') => KeyCodeValue::Space, + KeyCode::Char(c) => { + // If shift is the only modifier and it's an uppercase letter, + // we store the uppercase char directly and clear the shift flag + // since the case already encodes it. + if shift && !ctrl && !alt && !super_key && c.is_ascii_uppercase() { + return Some(SingleKey { + code: KeyCodeValue::Char(c), + ctrl: false, + alt: false, + shift: false, + super_key: false, + }); + } + KeyCodeValue::Char(c) + } + KeyCode::Enter => KeyCodeValue::Enter, + KeyCode::Esc => KeyCodeValue::Esc, + KeyCode::Tab => KeyCodeValue::Tab, + // BackTab is sent by many terminals for Shift+Tab + KeyCode::BackTab => { + return Some(SingleKey { + code: KeyCodeValue::Tab, + ctrl, + alt, + shift: true, + super_key, + }); + } + KeyCode::Backspace => KeyCodeValue::Backspace, + KeyCode::Delete => KeyCodeValue::Delete, + KeyCode::Insert => KeyCodeValue::Insert, + KeyCode::Up => KeyCodeValue::Up, + KeyCode::Down => KeyCodeValue::Down, + KeyCode::Left => KeyCodeValue::Left, + KeyCode::Right => KeyCodeValue::Right, + KeyCode::Home => KeyCodeValue::Home, + KeyCode::End => KeyCodeValue::End, + KeyCode::PageUp => KeyCodeValue::PageUp, + KeyCode::PageDown => KeyCodeValue::PageDown, + KeyCode::F(n) => KeyCodeValue::F(n), + KeyCode::Media(m) => KeyCodeValue::Media(m), + _ => return None, + }; + + Some(SingleKey { + code, + ctrl, + alt, + shift: if matches!(code, KeyCodeValue::Char(_)) { + false + } else { + shift + }, + super_key, + }) + } + + /// Parse a key string like `"ctrl-c"`, `"alt-f"`, `"enter"`, `"G"`. + pub fn parse(s: &str) -> Result<Self, String> { + let s = s.trim(); + let parts: Vec<&str> = s.split('-').collect(); + + let mut ctrl = false; + let mut alt = false; + let mut shift = false; + let mut super_key = false; + + // All parts except the last are modifiers + for &part in &parts[..parts.len() - 1] { + match part.to_lowercase().as_str() { + "ctrl" => ctrl = true, + "alt" => alt = true, + "shift" => shift = true, + "super" | "cmd" | "win" => super_key = true, + _ => return Err(format!("unknown modifier: {part}")), + } + } + + let key_part = parts[parts.len() - 1]; + let code = match key_part.to_lowercase().as_str() { + "enter" | "return" => KeyCodeValue::Enter, + "esc" | "escape" => KeyCodeValue::Esc, + "tab" => KeyCodeValue::Tab, + "backspace" => KeyCodeValue::Backspace, + "delete" | "del" => KeyCodeValue::Delete, + "insert" | "ins" => KeyCodeValue::Insert, + "up" => KeyCodeValue::Up, + "down" => KeyCodeValue::Down, + "left" => KeyCodeValue::Left, + "right" => KeyCodeValue::Right, + "home" => KeyCodeValue::Home, + "end" => KeyCodeValue::End, + "pageup" => KeyCodeValue::PageUp, + "pagedown" => KeyCodeValue::PageDown, + "space" => KeyCodeValue::Space, + s if s.starts_with('f') && s.len() > 1 => { + // Parse function keys like "f1", "f12" + if let Ok(n) = s[1..].parse::<u8>() { + if (1..=24).contains(&n) { + KeyCodeValue::F(n) + } else { + return Err(format!("function key out of range: {key_part}")); + } + } else { + return Err(format!("unknown key: {key_part}")); + } + } + "[" => KeyCodeValue::Char('['), + "]" => KeyCodeValue::Char(']'), + "?" => KeyCodeValue::Char('?'), + "/" => KeyCodeValue::Char('/'), + "$" => KeyCodeValue::Char('$'), + // Media keys (no dashes - the parser splits on dash for modifiers) + "play" => KeyCodeValue::Media(MediaKeyCode::Play), + "pause" => KeyCodeValue::Media(MediaKeyCode::Pause), + "playpause" => KeyCodeValue::Media(MediaKeyCode::PlayPause), + "stop" => KeyCodeValue::Media(MediaKeyCode::Stop), + "fastforward" => KeyCodeValue::Media(MediaKeyCode::FastForward), + "rewind" => KeyCodeValue::Media(MediaKeyCode::Rewind), + "tracknext" => KeyCodeValue::Media(MediaKeyCode::TrackNext), + "trackprevious" => KeyCodeValue::Media(MediaKeyCode::TrackPrevious), + "record" => KeyCodeValue::Media(MediaKeyCode::Record), + "lowervolume" => KeyCodeValue::Media(MediaKeyCode::LowerVolume), + "raisevolume" => KeyCodeValue::Media(MediaKeyCode::RaiseVolume), + "mutevolume" | "mute" => KeyCodeValue::Media(MediaKeyCode::MuteVolume), + _ => { + let chars: Vec<char> = key_part.chars().collect(); + if chars.len() == 1 { + let c = chars[0]; + // An uppercase letter implies shift (unless shift already specified) + if c.is_ascii_uppercase() && !ctrl && !alt && !super_key { + return Ok(SingleKey { + code: KeyCodeValue::Char(c), + ctrl: false, + alt: false, + shift: false, + super_key: false, + }); + } + KeyCodeValue::Char(c) + } else { + return Err(format!("unknown key: {key_part}")); + } + } + }; + + Ok(SingleKey { + code, + ctrl, + alt, + shift, + super_key, + }) + } +} + +impl fmt::Display for SingleKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.super_key { + write!(f, "super-")?; + } + if self.ctrl { + write!(f, "ctrl-")?; + } + if self.alt { + write!(f, "alt-")?; + } + if self.shift { + write!(f, "shift-")?; + } + match &self.code { + KeyCodeValue::Char(c) => write!(f, "{c}"), + KeyCodeValue::Enter => write!(f, "enter"), + KeyCodeValue::Esc => write!(f, "esc"), + KeyCodeValue::Tab => write!(f, "tab"), + KeyCodeValue::Backspace => write!(f, "backspace"), + KeyCodeValue::Delete => write!(f, "delete"), + KeyCodeValue::Insert => write!(f, "insert"), + KeyCodeValue::Up => write!(f, "up"), + KeyCodeValue::Down => write!(f, "down"), + KeyCodeValue::Left => write!(f, "left"), + KeyCodeValue::Right => write!(f, "right"), + KeyCodeValue::Home => write!(f, "home"), + KeyCodeValue::End => write!(f, "end"), + KeyCodeValue::PageUp => write!(f, "pageup"), + KeyCodeValue::PageDown => write!(f, "pagedown"), + KeyCodeValue::Space => write!(f, "space"), + KeyCodeValue::F(n) => write!(f, "f{n}"), + KeyCodeValue::Media(m) => match m { + MediaKeyCode::Play => write!(f, "play"), + MediaKeyCode::Pause => write!(f, "media-pause"), + MediaKeyCode::PlayPause => write!(f, "playpause"), + MediaKeyCode::Stop => write!(f, "stop"), + MediaKeyCode::FastForward => write!(f, "fastforward"), + MediaKeyCode::Rewind => write!(f, "rewind"), + MediaKeyCode::TrackNext => write!(f, "tracknext"), + MediaKeyCode::TrackPrevious => write!(f, "trackprevious"), + MediaKeyCode::Record => write!(f, "record"), + MediaKeyCode::LowerVolume => write!(f, "lowervolume"), + MediaKeyCode::RaiseVolume => write!(f, "raisevolume"), + MediaKeyCode::MuteVolume => write!(f, "mutevolume"), + MediaKeyCode::Reverse => write!(f, "reverse"), + }, + } + } +} + +impl KeyInput { + /// Parse a key input string. Supports multi-key sequences separated by spaces + /// (e.g. `"g g"`). + pub fn parse(s: &str) -> Result<Self, String> { + let s = s.trim(); + // Check for space-separated multi-key sequences + // But don't split "space" or modifier combos like "ctrl-a" + let parts: Vec<&str> = s.split_whitespace().collect(); + if parts.len() > 1 { + let keys: Result<Vec<SingleKey>, String> = + parts.iter().map(|p| SingleKey::parse(p)).collect(); + Ok(KeyInput::Sequence(keys?)) + } else { + Ok(KeyInput::Single(SingleKey::parse(s)?)) + } + } +} + +impl fmt::Display for KeyInput { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + KeyInput::Single(k) => write!(f, "{k}"), + KeyInput::Sequence(keys) => { + for (i, k) in keys.iter().enumerate() { + if i > 0 { + write!(f, " ")?; + } + write!(f, "{k}")?; + } + Ok(()) + } + } + } +} + +impl Serialize for KeyInput { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for KeyInput { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { + let s = String::deserialize(deserializer)?; + KeyInput::parse(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + #[test] + fn parse_simple_keys() { + let k = SingleKey::parse("a").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('a')); + assert!(!k.ctrl && !k.alt && !k.shift); + + let k = SingleKey::parse("enter").unwrap(); + assert_eq!(k.code, KeyCodeValue::Enter); + + let k = SingleKey::parse("esc").unwrap(); + assert_eq!(k.code, KeyCodeValue::Esc); + + let k = SingleKey::parse("tab").unwrap(); + assert_eq!(k.code, KeyCodeValue::Tab); + + let k = SingleKey::parse("space").unwrap(); + assert_eq!(k.code, KeyCodeValue::Space); + } + + #[test] + fn parse_modifiers() { + let k = SingleKey::parse("ctrl-c").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('c')); + assert!(k.ctrl); + assert!(!k.alt); + + let k = SingleKey::parse("alt-f").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('f')); + assert!(k.alt); + assert!(!k.ctrl); + + let k = SingleKey::parse("ctrl-alt-x").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('x')); + assert!(k.ctrl && k.alt); + } + + #[test] + fn parse_uppercase_implies_no_shift_flag() { + let k = SingleKey::parse("G").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('G')); + assert!(!k.shift); + assert!(!k.ctrl); + } + + #[test] + fn parse_special_chars() { + let k = SingleKey::parse("ctrl-[").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('[')); + assert!(k.ctrl); + + let k = SingleKey::parse("?").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('?')); + + let k = SingleKey::parse("/").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('/')); + } + + #[test] + fn parse_multi_key_sequence() { + let ki = KeyInput::parse("g g").unwrap(); + match ki { + KeyInput::Sequence(keys) => { + assert_eq!(keys.len(), 2); + assert_eq!(keys[0].code, KeyCodeValue::Char('g')); + assert_eq!(keys[1].code, KeyCodeValue::Char('g')); + } + _ => panic!("expected sequence"), + } + } + + #[test] + fn display_round_trip() { + let cases = ["ctrl-c", "alt-f", "enter", "G", "tab", "pageup"]; + for s in cases { + let k = KeyInput::parse(s).unwrap(); + let display = k.to_string(); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2, "round-trip failed for {s}"); + } + + let ki = KeyInput::parse("g g").unwrap(); + assert_eq!(ki.to_string(), "g g"); + } + + #[test] + fn from_event_basic() { + let event = KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('c')); + assert!(k.ctrl); + assert!(!k.alt); + + let event = KeyEvent::new(KeyCode::Enter, KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Enter); + } + + #[test] + fn from_event_uppercase() { + // Crossterm sends uppercase chars with SHIFT modifier + let event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::SHIFT); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('G')); + // shift flag should be cleared since the case encodes it + assert!(!k.shift); + } + + #[test] + fn from_event_matches_parsed() { + // Verify that from_event and parse produce the same SingleKey + let event = KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("ctrl-c").unwrap(); + assert_eq!(from_event, parsed); + + let event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::SHIFT); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("G").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn parse_super_modifier() { + let k = SingleKey::parse("super-a").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('a')); + assert!(k.super_key); + assert!(!k.ctrl && !k.alt && !k.shift); + + // "cmd" is an alias for "super" + let k2 = SingleKey::parse("cmd-a").unwrap(); + assert_eq!(k, k2); + + // "win" is an alias for "super" + let k3 = SingleKey::parse("win-a").unwrap(); + assert_eq!(k, k3); + } + + #[test] + fn parse_super_with_other_modifiers() { + let k = SingleKey::parse("super-ctrl-c").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('c')); + assert!(k.super_key && k.ctrl); + assert!(!k.alt && !k.shift); + } + + #[test] + fn display_super_modifier() { + let k = SingleKey::parse("super-a").unwrap(); + assert_eq!(k.to_string(), "super-a"); + + let k = SingleKey::parse("super-ctrl-x").unwrap(); + assert_eq!(k.to_string(), "super-ctrl-x"); + } + + #[test] + fn display_round_trip_super() { + let k = KeyInput::parse("super-a").unwrap(); + let display = k.to_string(); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2, "round-trip failed for super-a"); + } + + #[test] + fn from_event_super() { + let event = KeyEvent::new(KeyCode::Char('a'), KeyModifiers::SUPER); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('a')); + assert!(k.super_key); + assert!(!k.ctrl && !k.alt && !k.shift); + } + + #[test] + fn from_event_super_matches_parsed() { + let event = KeyEvent::new(KeyCode::Char('a'), KeyModifiers::SUPER); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("super-a").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn super_uppercase_preserves_super() { + // super-G should keep the super flag (unlike bare "G" which clears shift) + let k = SingleKey::parse("super-G").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('G')); + assert!(k.super_key); + } + + #[test] + fn parse_errors() { + assert!(SingleKey::parse("ctrl-alt-shift-xxx").is_err()); + assert!(SingleKey::parse("foobar-a").is_err()); + } + + #[test] + fn parse_function_keys() { + let k = SingleKey::parse("f1").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(1)); + assert!(!k.ctrl && !k.alt && !k.shift); + + let k = SingleKey::parse("F12").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(12)); + + let k = SingleKey::parse("ctrl-f5").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(5)); + assert!(k.ctrl); + + // F24 is valid (some keyboards have extended function keys) + let k = SingleKey::parse("f24").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(24)); + + // F0 and F25+ are invalid + assert!(SingleKey::parse("f0").is_err()); + assert!(SingleKey::parse("f25").is_err()); + } + + #[test] + fn from_event_function_keys() { + let event = KeyEvent::new(KeyCode::F(1), KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::F(1)); + + let event = KeyEvent::new(KeyCode::F(12), KeyModifiers::CONTROL); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::F(12)); + assert!(k.ctrl); + } + + #[test] + fn display_function_keys() { + let k = SingleKey::parse("f1").unwrap(); + assert_eq!(k.to_string(), "f1"); + + let k = SingleKey::parse("ctrl-f12").unwrap(); + assert_eq!(k.to_string(), "ctrl-f12"); + } + + #[test] + fn function_key_round_trip() { + let cases = ["f1", "f12", "ctrl-f5", "alt-f10"]; + for s in cases { + let k = KeyInput::parse(s).unwrap(); + let display = k.to_string(); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2, "round-trip failed for {s}"); + } + } + + #[test] + fn from_event_function_key_matches_parsed() { + let event = KeyEvent::new(KeyCode::F(12), KeyModifiers::NONE); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("f12").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn from_event_backtab_becomes_shift_tab() { + // Many terminals send BackTab for Shift+Tab + let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Tab); + assert!(k.shift); + assert!(!k.ctrl && !k.alt); + } + + #[test] + fn from_event_backtab_matches_parsed_shift_tab() { + let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::NONE); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("shift-tab").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn from_event_backtab_with_ctrl() { + // BackTab with ctrl modifier + let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::CONTROL); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Tab); + assert!(k.shift); + assert!(k.ctrl); + } + + #[test] + fn parse_insert_key() { + let k = SingleKey::parse("insert").unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + assert!(!k.ctrl && !k.alt && !k.shift); + + let k = SingleKey::parse("ins").unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + + let k = SingleKey::parse("ctrl-insert").unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + assert!(k.ctrl); + } + + #[test] + fn from_event_insert_key() { + let event = KeyEvent::new(KeyCode::Insert, KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + } + + #[test] + fn insert_key_round_trip() { + let k = KeyInput::parse("insert").unwrap(); + let display = k.to_string(); + assert_eq!(display, "insert"); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/keymap.rs b/crates/turtle/src/command/client/search/keybindings/keymap.rs new file mode 100644 index 00000000..0d362863 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/keymap.rs @@ -0,0 +1,233 @@ +use std::collections::HashMap; + +use super::actions::Action; +use super::conditions::{ConditionExpr, EvalContext}; +use super::key::{KeyInput, SingleKey}; + +/// A single rule within a keybinding: an optional condition and an action. +/// If the condition is `None`, the rule always matches. +#[derive(Debug, Clone)] +pub struct KeyRule { + pub condition: Option<ConditionExpr>, + pub action: Action, +} + +/// A keybinding is an ordered list of rules. The first rule whose condition +/// matches (or has no condition) wins. +#[derive(Debug, Clone)] +pub struct KeyBinding { + pub rules: Vec<KeyRule>, +} + +/// A keymap is a collection of keybindings indexed by key input. +#[derive(Debug, Clone)] +pub struct Keymap { + pub bindings: HashMap<KeyInput, KeyBinding>, +} + +impl KeyRule { + /// Create an unconditional rule. + pub fn always(action: Action) -> Self { + KeyRule { + condition: None, + action, + } + } + + /// Create a conditional rule. Accepts any type convertible to `ConditionExpr`, + /// including bare `ConditionAtom` values. + pub fn when(condition: impl Into<ConditionExpr>, action: Action) -> Self { + KeyRule { + condition: Some(condition.into()), + action, + } + } +} + +impl KeyBinding { + /// Create a simple (unconditional) binding. + pub fn simple(action: Action) -> Self { + KeyBinding { + rules: vec![KeyRule::always(action)], + } + } + + /// Create a conditional binding from a list of rules. + pub fn conditional(rules: Vec<KeyRule>) -> Self { + KeyBinding { rules } + } +} + +impl Keymap { + /// Create an empty keymap. + pub fn new() -> Self { + Keymap { + bindings: HashMap::new(), + } + } + + /// Bind a key input to a simple (unconditional) action. + pub fn bind(&mut self, key: KeyInput, action: Action) { + self.bindings.insert(key, KeyBinding::simple(action)); + } + + /// Bind a key input to a conditional set of rules. + pub fn bind_conditional(&mut self, key: KeyInput, rules: Vec<KeyRule>) { + self.bindings.insert(key, KeyBinding::conditional(rules)); + } + + /// Resolve a key input to an action given the current evaluation context. + /// Returns `None` if the key has no binding or no rule's condition matches. + pub fn resolve(&self, key: &KeyInput, ctx: &EvalContext) -> Option<Action> { + let binding = self.bindings.get(key)?; + for rule in &binding.rules { + match &rule.condition { + None => return Some(rule.action.clone()), + Some(cond) if cond.evaluate(ctx) => return Some(rule.action.clone()), + Some(_) => {} + } + } + None + } + + /// Check if any binding starts with the given single key as the first key + /// of a multi-key sequence. Used to detect pending multi-key sequences. + pub fn has_sequence_starting_with(&self, prefix: &SingleKey) -> bool { + self.bindings.keys().any(|ki| match ki { + KeyInput::Sequence(keys) => keys.first() == Some(prefix), + KeyInput::Single(_) => false, + }) + } + + /// Merge another keymap into this one. Keys from `other` override keys in `self`. + #[expect(dead_code)] + pub fn merge(&mut self, other: &Keymap) { + for (key, binding) in &other.bindings { + self.bindings.insert(key.clone(), binding.clone()); + } + } +} + +impl Default for Keymap { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::super::conditions::ConditionAtom; + use super::*; + + fn make_ctx(cursor: usize, width: usize, selected: usize, len: usize) -> EvalContext { + EvalContext { + cursor_position: cursor, + input_width: width, + input_byte_len: width, + selected_index: selected, + results_len: len, + original_input_empty: false, + has_context: false, + } + } + + #[test] + fn simple_binding_resolves() { + let mut keymap = Keymap::new(); + let key = KeyInput::parse("ctrl-c").unwrap(); + keymap.bind(key.clone(), Action::ReturnOriginal); + + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), Some(Action::ReturnOriginal)); + } + + #[test] + fn conditional_first_match_wins() { + let mut keymap = Keymap::new(); + let key = KeyInput::parse("left").unwrap(); + keymap.bind_conditional( + key.clone(), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit), + KeyRule::always(Action::CursorLeft), + ], + ); + + // Cursor at start → Exit + let ctx = make_ctx(0, 5, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), Some(Action::Exit)); + + // Cursor not at start → CursorLeft + let ctx = make_ctx(3, 5, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), Some(Action::CursorLeft)); + } + + #[test] + fn no_match_returns_none() { + let keymap = Keymap::new(); + let key = KeyInput::parse("ctrl-c").unwrap(); + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(keymap.resolve(&key, &ctx), None); + } + + #[test] + fn conditional_no_condition_matches_returns_none() { + let mut keymap = Keymap::new(); + let key = KeyInput::parse("left").unwrap(); + // Only one rule with a condition that won't match + keymap.bind_conditional( + key.clone(), + vec![KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit)], + ); + + // Cursor not at start → no match + let ctx = make_ctx(3, 5, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), None); + } + + #[test] + fn has_sequence_starting_with() { + let mut keymap = Keymap::new(); + let seq = KeyInput::parse("g g").unwrap(); + keymap.bind(seq, Action::ScrollToTop); + + let g = SingleKey::parse("g").unwrap(); + assert!(keymap.has_sequence_starting_with(&g)); + + let h = SingleKey::parse("h").unwrap(); + assert!(!keymap.has_sequence_starting_with(&h)); + } + + #[test] + fn merge_overrides() { + let mut base = Keymap::new(); + let key = KeyInput::parse("ctrl-c").unwrap(); + base.bind(key.clone(), Action::ReturnOriginal); + + let mut overlay = Keymap::new(); + overlay.bind(key.clone(), Action::Exit); + + base.merge(&overlay); + + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(base.resolve(&key, &ctx), Some(Action::Exit)); + } + + #[test] + fn merge_preserves_unoverridden() { + let mut base = Keymap::new(); + let key1 = KeyInput::parse("ctrl-c").unwrap(); + let key2 = KeyInput::parse("ctrl-d").unwrap(); + base.bind(key1.clone(), Action::ReturnOriginal); + base.bind(key2.clone(), Action::DeleteCharAfter); + + let mut overlay = Keymap::new(); + overlay.bind(key1.clone(), Action::Exit); + + base.merge(&overlay); + + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(base.resolve(&key1, &ctx), Some(Action::Exit)); + assert_eq!(base.resolve(&key2, &ctx), Some(Action::DeleteCharAfter)); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/mod.rs b/crates/turtle/src/command/client/search/keybindings/mod.rs new file mode 100644 index 00000000..3b6eb2b2 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/mod.rs @@ -0,0 +1,14 @@ +pub mod actions; +pub mod conditions; +pub mod defaults; +pub mod key; +pub mod keymap; + +pub use actions::Action; +#[expect(unused_imports)] +pub use conditions::{ConditionAtom, ConditionExpr, EvalContext}; +pub use defaults::KeymapSet; +#[expect(unused_imports)] +pub use key::{KeyCodeValue, KeyInput, SingleKey}; +#[expect(unused_imports)] +pub use keymap::{KeyBinding, KeyRule, Keymap}; diff --git a/crates/turtle/src/command/client/server.rs b/crates/turtle/src/command/client/server.rs new file mode 100644 index 00000000..7de27551 --- /dev/null +++ b/crates/turtle/src/command/client/server.rs @@ -0,0 +1,61 @@ +use std::net::SocketAddr; + +use crate::atuin_server::{Settings, launch, launch_metrics_server}; +use crate::atuin_server_database::DbType; +use crate::atuin_server_postgres::Postgres; +use crate::atuin_server_sqlite::Sqlite; + +use clap::Subcommand; +use eyre::{Context, Result, eyre}; + +#[derive(Subcommand, Clone, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Start the server + Start { + /// The host address to bind + #[clap(long)] + host: Option<String>, + + /// The port to bind + #[clap(long, short)] + port: Option<u16>, + }, + + /// Print server example configuration + DefaultConfig, +} + +impl Cmd { + #[expect(clippy::too_many_lines)] + pub async fn run(self) -> Result<()> { + match self { + Cmd::Start { host, port } => { + let settings = Settings::new().wrap_err("could not load server settings")?; + let host = host.as_ref().unwrap_or(&settings.host).clone(); + let port = port.unwrap_or(settings.port); + let addr = SocketAddr::new(host.parse()?, port); + + if settings.metrics.enable { + tokio::spawn(launch_metrics_server( + settings.metrics.host.clone(), + settings.metrics.port, + )); + } + + match settings.db_settings.db_type() { + DbType::Postgres => launch::<Postgres>(settings, addr).await, + DbType::Sqlite => launch::<Sqlite>(settings, addr).await, + DbType::Unknown => { + Err(eyre!("db_uri must start with postgres:// or sqlite://")) + } + } + } + Cmd::DefaultConfig => { + // TODO(@bpeetz): Add this back <2026-06-11> + println!("TODO"); + Ok(()) + } + } + } +} diff --git a/crates/turtle/src/command/client/setup.rs b/crates/turtle/src/command/client/setup.rs new file mode 100644 index 00000000..b32ceb97 --- /dev/null +++ b/crates/turtle/src/command/client/setup.rs @@ -0,0 +1,81 @@ +use crate::atuin_client::settings::Settings; + +use colored::Colorize; +use eyre::Result; +use std::io::{self, Write}; +use toml_edit::{DocumentMut, value}; + +pub async fn run(_settings: &Settings) -> Result<()> { + let enable_ai = prompt( + "Atuin AI", + "This will enable command generation and other AI features via the question mark key", + Some( + "By default, Atuin AI only has access to the name and version of your operating system and shell - your shell history is not sent to the AI.", + ), + )?; + + let enable_daemon = prompt( + "Atuin Daemon", + "This will enable improved search and history sync using a persistent background process", + None, + )?; + + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let mut doc = config_str.parse::<DocumentMut>()?; + + let mut changed = false; + if enable_ai { + changed = true; + if !doc.contains_key("ai") { + doc["ai"] = toml_edit::table(); + } + doc["ai"]["enabled"] = value(true); + } + + if enable_daemon { + changed = true; + if !doc.contains_key("daemon") { + doc["daemon"] = toml_edit::table(); + } + doc["daemon"]["enabled"] = value(true); + doc["daemon"]["autostart"] = value(true); + doc["search_mode"] = value("daemon-fuzzy"); + } + + if changed { + tokio::fs::write(config_file, doc.to_string()).await?; + + println!( + "{check} Settings updated successfully", + check = "✓".bold().bright_green() + ); + } else { + println!( + "{check} No settings changed", + check = "✓".bold().bright_green() + ); + } + + Ok(()) +} + +pub fn prompt(feature: &str, description: &str, note: Option<&str>) -> Result<bool> { + println!( + "> Enable {feature}?", + feature = feature.bold().bright_blue() + ); + if let Some(note) = note { + println!(" {description}"); + print!(" {note} {q} ", q = "[Y/n]".bold()); + } else { + print!(" {description} {q} ", q = "[Y/n]".bold()); + } + + io::stdout().flush().ok(); + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let answer = input.trim().to_lowercase(); + Ok(answer.is_empty() || answer == "y" || answer == "yes") +} diff --git a/crates/turtle/src/command/client/stats.rs b/crates/turtle/src/command/client/stats.rs new file mode 100644 index 00000000..fc10e949 --- /dev/null +++ b/crates/turtle/src/command/client/stats.rs @@ -0,0 +1,85 @@ +use clap::Parser; +use eyre::Result; +use interim::parse_date_string; +use time::{Duration, OffsetDateTime, Time}; + +use crate::atuin_client::{ + database::{Database, current_context}, + settings::Settings, + theme::Theme, +}; + +use crate::atuin_history::stats::{compute, pretty_print}; + +fn parse_ngram_size(s: &str) -> Result<usize, String> { + let value = s + .parse::<usize>() + .map_err(|_| format!("'{s}' is not a valid window size"))?; + + if value == 0 { + return Err("ngram window size must be at least 1".to_string()); + } + + Ok(value) +} + +#[derive(Parser, Debug)] +#[command(infer_subcommands = true)] +pub struct Cmd { + /// Compute statistics for the specified period, leave blank for statistics since the beginning. See [this](https://docs.atuin.sh/reference/stats/) for more details. + period: Vec<String>, + + /// How many top commands to list + #[arg(long, short, default_value = "10")] + count: usize, + + /// The number of consecutive commands to consider + #[arg(long, short, default_value = "1", value_parser = parse_ngram_size)] + ngram_size: usize, +} + +impl Cmd { + pub async fn run(&self, db: &impl Database, settings: &Settings, theme: &Theme) -> Result<()> { + let context = current_context().await?; + let words = if self.period.is_empty() { + String::from("all") + } else { + self.period.join(" ") + }; + + let now = OffsetDateTime::now_utc().to_offset(settings.timezone.0); + let last_night = now.replace_time(Time::MIDNIGHT); + + let history = if words.as_str() == "all" { + db.list(&[], &context, None, false, false).await? + } else if words.trim() == "today" { + let start = last_night; + let end = start + Duration::days(1); + db.range(start, end).await? + } else if words.trim() == "month" { + let end = last_night; + let start = end - Duration::days(31); + db.range(start, end).await? + } else if words.trim() == "week" { + let end = last_night; + let start = end - Duration::days(7); + db.range(start, end).await? + } else if words.trim() == "year" { + let end = last_night; + let start = end - Duration::days(365); + db.range(start, end).await? + } else { + let start = parse_date_string(&words, now, settings.dialect.into())?; + let end = start + Duration::days(1); + db.range(start, end).await? + }; + + let stats = compute(settings, &history, self.count, self.ngram_size); + + if let Some(stats) = stats { + pretty_print(stats, self.ngram_size, theme); + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store.rs b/crates/turtle/src/command/client/store.rs new file mode 100644 index 00000000..dfa3b66c --- /dev/null +++ b/crates/turtle/src/command/client/store.rs @@ -0,0 +1,120 @@ +use clap::Subcommand; +use eyre::Result; + +use crate::atuin_client::{ + database::Database, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; +use itertools::Itertools; +use time::{OffsetDateTime, UtcOffset}; + +#[cfg(feature = "sync")] +mod push; + +#[cfg(feature = "sync")] +mod pull; + +mod purge; +mod rebuild; +mod rekey; +mod verify; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Print the current status of the record store + Status, + + /// Rebuild a store (eg atuin store rebuild history) + Rebuild(rebuild::Rebuild), + + /// Re-encrypt the store with a new key (potential for data loss!) + Rekey(rekey::Rekey), + + /// Delete all records in the store that cannot be decrypted with the current key + Purge(purge::Purge), + + /// Verify that all records in the store can be decrypted with the current key + Verify(verify::Verify), + + /// Push all records to the remote sync server (one way sync) + #[cfg(feature = "sync")] + Push(push::Push), + + /// Pull records from the remote sync server (one way sync) + #[cfg(feature = "sync")] + Pull(pull::Pull), +} + +impl Cmd { + pub async fn run( + &self, + settings: &Settings, + database: &dyn Database, + store: SqliteStore, + ) -> Result<()> { + match self { + Self::Status => self.status(store).await, + Self::Rebuild(rebuild) => rebuild.run(settings, store, database).await, + Self::Rekey(rekey) => rekey.run(settings, store).await, + Self::Verify(verify) => verify.run(settings, store).await, + Self::Purge(purge) => purge.run(settings, store).await, + + #[cfg(feature = "sync")] + Self::Push(push) => push.run(settings, store).await, + + #[cfg(feature = "sync")] + Self::Pull(pull) => pull.run(settings, store, database).await, + } + } + + pub async fn status(&self, store: SqliteStore) -> Result<()> { + let host_id = Settings::host_id().await?; + let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + + let status = store.status().await?; + + // TODO: should probs build some data structure and then pretty-print it or smth + for (host, st) in status.hosts.iter().sorted_by_key(|(h, _)| *h) { + let host_string = if host == &host_id { + format!("host: {} <- CURRENT HOST", host.0.as_hyphenated()) + } else { + format!("host: {}", host.0.as_hyphenated()) + }; + + println!("{host_string}"); + + for (tag, idx) in st.iter().sorted_by_key(|(tag, _)| *tag) { + println!("\tstore: {tag}"); + + let first = store.first(*host, tag).await?; + let last = store.last(*host, tag).await?; + + println!("\t\tidx: {idx}"); + + if let Some(first) = first { + println!("\t\tfirst: {}", first.id.0.as_hyphenated()); + + let time = + OffsetDateTime::from_unix_timestamp_nanos(i128::from(first.timestamp))? + .to_offset(offset); + println!("\t\t\tcreated: {time}"); + } + + if let Some(last) = last { + println!("\t\tlast: {}", last.id.0.as_hyphenated()); + + let time = + OffsetDateTime::from_unix_timestamp_nanos(i128::from(last.timestamp))? + .to_offset(offset); + println!("\t\t\tcreated: {time}"); + } + } + + println!(); + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/pull.rs b/crates/turtle/src/command/client/store/pull.rs new file mode 100644 index 00000000..c9c9c379 --- /dev/null +++ b/crates/turtle/src/command/client/store/pull.rs @@ -0,0 +1,94 @@ +use clap::Args; +use eyre::Result; + +use crate::atuin_client::{ + database::Database, + encryption::load_key, + record::store::Store, + record::sync::Operation, + record::{sqlite_store::SqliteStore, sync}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Pull { + /// The tag to push (eg, 'history'). Defaults to all tags + #[arg(long, short)] + pub tag: Option<String>, + + /// Force push records + /// This will first wipe the local store, and then download all records from the remote + #[arg(long, default_value = "false")] + pub force: bool, + + /// Page Size + /// How many records to download at once. Defaults to 100 + #[arg(long, default_value = "100")] + pub page: u64, +} + +impl Pull { + pub async fn run( + &self, + settings: &Settings, + store: SqliteStore, + db: &dyn Database, + ) -> Result<()> { + if self.force { + println!("Forcing local overwrite!"); + println!("Clearing local store"); + + store.delete_all().await?; + } + + // We can actually just use the existing diff/etc to push + // 1. Diff + // 2. Get operations + // 3. Filter operations by + // a) are they a download op? + // b) are they for the host/tag we are pushing here? + let client = sync::build_client(settings).await?; + let (diff, remote_index) = sync::diff(&client, &store).await?; + + // Skip on --force: local was already wiped above, mismatch is the user's call. + if !self.force { + let key: [u8; 32] = load_key(settings)?.into(); + sync::check_encryption_key(&client, &remote_index, &key) + .await + .map_err(crate::print_error::format_sync_error)?; + } + + let operations = sync::operations(diff, &store).await?; + + let operations = operations + .into_iter() + .filter(|op| match op { + // No noops or downloads thx + Operation::Noop { .. } | Operation::Upload { .. } => false, + + // pull, so yes plz to downloads! + Operation::Download { tag, .. } => { + if self.force { + return true; + } + + if let Some(t) = self.tag.clone() + && t != *tag + { + return false; + } + + true + } + }) + .collect(); + + let (_, downloaded) = sync::sync_remote(&client, operations, &store, self.page).await?; + + println!("Downloaded {} records", downloaded.len()); + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/purge.rs b/crates/turtle/src/command/client/store/purge.rs new file mode 100644 index 00000000..f7996c4b --- /dev/null +++ b/crates/turtle/src/command/client/store/purge.rs @@ -0,0 +1,26 @@ +use clap::Args; +use eyre::Result; + +use crate::atuin_client::{ + encryption::load_key, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Purge {} + +impl Purge { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + println!("Purging local records that cannot be decrypted"); + + let key = load_key(settings)?; + + match store.purge(&key.into()).await { + Ok(()) => println!("Local store purge completed OK"), + Err(e) => println!("Failed to purge local store: {e:?}"), + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/push.rs b/crates/turtle/src/command/client/store/push.rs new file mode 100644 index 00000000..724dfbef --- /dev/null +++ b/crates/turtle/src/command/client/store/push.rs @@ -0,0 +1,112 @@ +use crate::atuin_common::record::HostId; +use clap::Args; +use eyre::Result; +use uuid::Uuid; + +use crate::atuin_client::{ + api_client::Client, + encryption::load_key, + record::sync::Operation, + record::{sqlite_store::SqliteStore, sync}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Push { + /// The tag to push (eg, 'history'). Defaults to all tags + #[arg(long, short)] + pub tag: Option<String>, + + /// The host to push, in the form of a UUID host ID. Defaults to the current host. + #[arg(long)] + pub host: Option<Uuid>, + + /// Force push records + /// This will override both host and tag, to be all hosts and all tags. First clear the remote store, then upload all of the + /// local store + #[arg(long, default_value = "false")] + pub force: bool, + + /// Page Size + /// How many records to upload at once. Defaults to 100 + #[arg(long, default_value = "100")] + pub page: u64, +} + +impl Push { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + let host_id = Settings::host_id().await?; + + if self.force { + println!("Forcing remote store overwrite!"); + println!("Clearing remote store"); + + let client = Client::new( + &settings.sync_address, + settings.sync_auth_token().await?, + settings.network_connect_timeout, + settings.network_timeout * 10, // we may be deleting a lot of data... so up the + // timeout + ) + .expect("failed to create client"); + + client.delete_store().await?; + } + + // We can actually just use the existing diff/etc to push + // 1. Diff + // 2. Get operations + // 3. Filter operations by + // a) are they an upload op? + // b) are they for the host/tag we are pushing here? + let client = sync::build_client(settings).await?; + let (diff, remote_index) = sync::diff(&client, &store).await?; + + // Skip on --force: that path intentionally replaces remote with local. + if !self.force { + let key: [u8; 32] = load_key(settings)?.into(); + sync::check_encryption_key(&client, &remote_index, &key) + .await + .map_err(crate::print_error::format_sync_error)?; + } + + let operations = sync::operations(diff, &store).await?; + + let operations = operations + .into_iter() + .filter(|op| match op { + // No noops or downloads thx + Operation::Noop { .. } | Operation::Download { .. } => false, + + // push, so yes plz to uploads! + Operation::Upload { host, tag, .. } => { + if self.force { + return true; + } + + if let Some(h) = self.host { + if HostId(h) != *host { + return false; + } + } else if *host != host_id { + return false; + } + + if let Some(t) = self.tag.clone() + && t != *tag + { + return false; + } + + true + } + }) + .collect(); + + let (uploaded, _) = sync::sync_remote(&client, operations, &store, self.page).await?; + + println!("Uploaded {uploaded} records"); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/rebuild.rs b/crates/turtle/src/command/client/store/rebuild.rs new file mode 100644 index 00000000..80e201c2 --- /dev/null +++ b/crates/turtle/src/command/client/store/rebuild.rs @@ -0,0 +1,58 @@ +use clap::Args; +use eyre::{Result, bail}; + +#[cfg(feature = "daemon")] +use crate::command::client::daemon as daemon_cmd; + +use crate::atuin_client::{ + database::Database, encryption, history::store::HistoryStore, + record::sqlite_store::SqliteStore, settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Rebuild { + pub tag: String, +} + +impl Rebuild { + pub async fn run( + &self, + settings: &Settings, + store: SqliteStore, + database: &dyn Database, + ) -> Result<()> { + // keep it as a string and not an enum atm + // would be super cool to build this dynamically in the future + // eg register handles for rebuilding various tags without having to make this part of the + // binary big + match self.tag.as_str() { + "history" => { + self.rebuild_history(settings, store.clone(), database) + .await?; + } + + tag => bail!("unknown tag: {tag}"), + } + + Ok(()) + } + + async fn rebuild_history( + &self, + settings: &Settings, + store: SqliteStore, + database: &dyn Database, + ) -> Result<()> { + let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store, host_id, encryption_key); + + history_store.build(database).await?; + + #[cfg(feature = "daemon")] + daemon_cmd::emit_event(settings, crate::atuin_daemon::DaemonEvent::HistoryRebuilt).await; + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/rekey.rs b/crates/turtle/src/command/client/store/rekey.rs new file mode 100644 index 00000000..e63be447 --- /dev/null +++ b/crates/turtle/src/command/client/store/rekey.rs @@ -0,0 +1,41 @@ +use clap::Args; +use eyre::Result; +use tokio::{fs::File, io::AsyncWriteExt}; + +use crate::atuin_client::{ + encryption::{decode_key, generate_encoded_key, load_key}, + record::sqlite_store::SqliteStore, + record::store::Store, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Rekey { + /// The new key to use for encryption. Omit for a randomly-generated key + key: Option<String>, +} + +impl Rekey { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + let key = if let Some(key) = self.key.clone() { + println!("Re-encrypting store with specified key"); + + key + } else { + println!("Re-encrypting store with freshly-generated key"); + let (_, encoded) = generate_encoded_key()?; + encoded + }; + + let current_key: [u8; 32] = load_key(settings)?.into(); + let new_key: [u8; 32] = decode_key(key.clone())?.into(); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Store rewritten. Saving new key"); + let mut file = File::create(settings.key_path.clone()).await?; + file.write_all(key.as_bytes()).await?; + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/verify.rs b/crates/turtle/src/command/client/store/verify.rs new file mode 100644 index 00000000..5aa1dc70 --- /dev/null +++ b/crates/turtle/src/command/client/store/verify.rs @@ -0,0 +1,26 @@ +use clap::Args; +use eyre::Result; + +use crate::atuin_client::{ + encryption::load_key, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Verify {} + +impl Verify { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + println!("Verifying local store can be decrypted with the current key"); + + let key = load_key(settings)?; + + match store.verify(&key.into()).await { + Ok(()) => println!("Local store encryption verified OK"), + Err(e) => println!("Failed to verify local store encryption: {e:?}"), + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/sync.rs b/crates/turtle/src/command/client/sync.rs new file mode 100644 index 00000000..a4839b5f --- /dev/null +++ b/crates/turtle/src/command/client/sync.rs @@ -0,0 +1,120 @@ +use clap::Subcommand; +use eyre::{Result, WrapErr}; + +use crate::atuin_client::{ + database::Database, + encryption, + history::store::HistoryStore, + record::{sqlite_store::SqliteStore, store::Store, sync}, + settings::Settings, +}; + +mod status; + +use crate::command::client::account; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Sync with the configured server + Sync { + /// Force re-download everything + #[arg(long, short)] + force: bool, + }, + + /// Login to the configured server + Login(account::login::Cmd), + + /// Log out + Logout, + + /// Register with the configured server + Register(account::register::Cmd), + + /// Print the encryption key for transfer to another machine + Key {}, + + /// Display the sync status + Status, +} + +impl Cmd { + pub async fn run( + self, + settings: Settings, + db: &impl Database, + store: SqliteStore, + ) -> Result<()> { + match self { + Self::Sync { force } => run(&settings, force, db, store).await, + Self::Login(l) => l.run(&settings, &store).await, + Self::Logout => account::logout::run().await, + Self::Register(r) => r.run(&settings).await, + Self::Status => status::run(&settings).await, + Self::Key {} => { + use crate::atuin_client::encryption::{encode_key, load_key}; + let key = load_key(&settings).wrap_err("could not load encryption key")?; + + let encode = encode_key(&key).wrap_err("could not encode encryption key")?; + println!("{encode}"); + + Ok(()) + } + } + } +} + +async fn run( + settings: &Settings, + force: bool, + db: &impl Database, + store: SqliteStore, +) -> Result<()> { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + let (uploaded, downloaded) = sync::sync(settings, &store, &encryption_key) + .await + .map_err(crate::print_error::format_sync_error)?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + println!("{uploaded}/{} up/down to record store", downloaded.len()); + + let history_length = db.history_count(true).await?; + let store_history_length = store.len_tag("history").await?; + + #[expect(clippy::cast_sign_loss)] + if history_length as u64 > store_history_length { + println!("{history_length} in history index, but {store_history_length} in history store"); + println!("Running automatic history store init..."); + + // Internally we use the global filter mode, so this context is ignored. + // don't recurse or loop here. + history_store.init_store(db).await?; + + println!("Re-running sync due to new records locally"); + + // we'll want to run sync once more, as there will now be stuff to upload + let (uploaded, downloaded) = sync::sync(settings, &store, &encryption_key) + .await + .map_err(crate::print_error::format_sync_error)?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + println!("{uploaded}/{} up/down to record store", downloaded.len()); + } + + println!( + "Sync complete! {} items in history database, force: {}", + db.history_count(true).await?, + force + ); + + Ok(()) +} diff --git a/crates/turtle/src/command/client/sync/status.rs b/crates/turtle/src/command/client/sync/status.rs new file mode 100644 index 00000000..00088b59 --- /dev/null +++ b/crates/turtle/src/command/client/sync/status.rs @@ -0,0 +1,37 @@ +use crate::{SHA, VERSION}; +use crate::atuin_client::{api_client, settings::Settings}; +use colored::Colorize; +use eyre::{Result, bail}; + +pub async fn run(settings: &Settings) -> Result<()> { + if !settings.logged_in().await? { + bail!("You are not logged in to a sync server - cannot show sync status"); + } + + let client = api_client::Client::new( + &settings.sync_address, + settings.sync_auth_token().await?, + settings.network_connect_timeout, + settings.network_timeout, + )?; + + let me = client.me().await?; + let last_sync = Settings::last_sync().await?; + + println!("Atuin v{VERSION} - Build rev {SHA}\n"); + + println!("{}", "[Local]".green()); + + if settings.auto_sync { + println!("Sync frequency: {}", settings.sync_frequency); + println!("Last sync: {}", last_sync.to_offset(settings.timezone.0)); + } + + if settings.auto_sync { + println!("{}", "[Remote]".green()); + println!("Address: {}", settings.sync_address); + println!("Username: {}", me.username); + } + + Ok(()) +} diff --git a/crates/turtle/src/command/client/wrapped.rs b/crates/turtle/src/command/client/wrapped.rs new file mode 100644 index 00000000..694157c2 --- /dev/null +++ b/crates/turtle/src/command/client/wrapped.rs @@ -0,0 +1,326 @@ +use crossterm::style::{ResetColor, SetAttribute}; +use eyre::Result; +use std::collections::{HashMap, HashSet}; +use time::{Date, Duration, Month, OffsetDateTime, Time}; + +use crate::atuin_client::{database::Database, settings::Settings, theme::Theme}; + +use crate::atuin_history::stats::{Stats, compute}; + +#[derive(Debug)] +struct WrappedStats { + nav_commands: usize, + pkg_commands: usize, + error_rate: f64, + first_half_commands: Vec<(String, usize)>, + second_half_commands: Vec<(String, usize)>, + git_percentage: f64, + busiest_hour: Option<(String, usize)>, +} + +impl WrappedStats { + #[expect(clippy::too_many_lines, clippy::cast_precision_loss)] + fn new( + settings: &Settings, + stats: &Stats, + history: &[crate::atuin_client::history::History], + ) -> Self { + let nav_commands = stats + .top + .iter() + .filter(|(cmd, _)| { + let cmd = &cmd[0]; + cmd == "cd" || cmd == "ls" || cmd == "pwd" || cmd == "pushd" || cmd == "popd" + }) + .map(|(_, count)| count) + .sum(); + + let pkg_managers = [ + "cargo", + "npm", + "pnpm", + "yarn", + "pip", + "pip3", + "pipenv", + "poetry", + "pipx", + "uv", + "brew", + "apt", + "apt-get", + "apk", + "pacman", + "yay", + "paru", + "yum", + "dnf", + "dnf5", + "rpm", + "rpm-ostree", + "zypper", + "pkg", + "chocolatey", + "choco", + "scoop", + "winget", + "gem", + "bundle", + "shards", + "composer", + "gradle", + "maven", + "mvn", + "go get", + "nuget", + "dotnet", + "mix", + "hex", + "rebar3", + "nix", + "nix-env", + "cabal", + "opam", + ]; + + let pkg_commands = history + .iter() + .filter(|h| { + let cmd = h.command.clone(); + pkg_managers.iter().any(|pm| cmd.starts_with(pm)) + }) + .count(); + + // Error analysis + let mut command_errors: HashMap<String, (usize, usize)> = HashMap::new(); // (total_uses, errors) + let midyear = history[0].timestamp + Duration::days(182); // Split year in half + + let mut first_half_commands: HashMap<String, usize> = HashMap::new(); + let mut second_half_commands: HashMap<String, usize> = HashMap::new(); + let mut hours: HashMap<String, usize> = HashMap::new(); + + for entry in history { + let cmd = entry + .command + .split_whitespace() + .next() + .unwrap_or("") + .to_string(); + let (total, errors) = command_errors.entry(cmd.clone()).or_insert((0, 0)); + *total += 1; + if entry.exit != 0 { + *errors += 1; + } + + // Track command evolution + if entry.timestamp < midyear { + *first_half_commands.entry(cmd.clone()).or_default() += 1; + } else { + *second_half_commands.entry(cmd).or_default() += 1; + } + + // Track hourly distribution + let local_time = entry + .timestamp + .to_offset(time::UtcOffset::current_local_offset().unwrap_or(settings.timezone.0)); + let hour = format!("{:02}:00", local_time.time().hour()); + *hours.entry(hour).or_default() += 1; + } + + let total_errors: usize = command_errors.values().map(|(_, errors)| errors).sum(); + let total_commands: usize = command_errors.values().map(|(total, _)| total).sum(); + let error_rate = total_errors as f64 / total_commands as f64; + + // Process command evolution data + let mut first_half: Vec<_> = first_half_commands.into_iter().collect(); + let mut second_half: Vec<_> = second_half_commands.into_iter().collect(); + first_half.sort_by_key(|(_, count)| std::cmp::Reverse(*count)); + second_half.sort_by_key(|(_, count)| std::cmp::Reverse(*count)); + first_half.truncate(5); + second_half.truncate(5); + + // Calculate git percentage + let git_commands: usize = stats + .top + .iter() + .filter(|(cmd, _)| cmd[0].starts_with("git")) + .map(|(_, count)| count) + .sum(); + let git_percentage = git_commands as f64 / stats.total_commands as f64; + + // Find busiest hour + let busiest_hour = hours.into_iter().max_by_key(|(_, count)| *count); + + Self { + nav_commands, + pkg_commands, + error_rate, + first_half_commands: first_half, + second_half_commands: second_half, + git_percentage, + busiest_hour, + } + } +} + +pub fn print_wrapped_header(year: i32) { + let reset = ResetColor; + let bold = SetAttribute(crossterm::style::Attribute::Bold); + + println!("{bold}╭────────────────────────────────────╮{reset}"); + println!("{bold}│ ATUIN WRAPPED {year} │{reset}"); + println!("{bold}│ Your Year in Shell History │{reset}"); + println!("{bold}╰────────────────────────────────────╯{reset}"); + println!(); +} + +#[expect(clippy::cast_precision_loss)] +fn print_fun_facts(wrapped_stats: &WrappedStats, stats: &Stats, year: i32) { + let reset = ResetColor; + let bold = SetAttribute(crossterm::style::Attribute::Bold); + + if wrapped_stats.git_percentage > 0.05 { + println!( + "{bold}🌟 You're a Git Power User!{reset} {bold}{:.1}%{reset} of your commands were Git operations\n", + wrapped_stats.git_percentage * 100.0 + ); + } + // Navigation patterns + let nav_percentage = wrapped_stats.nav_commands as f64 / stats.total_commands as f64 * 100.0; + if nav_percentage > 0.05 { + println!( + "{bold}🚀 You're a Navigator!{reset} {bold}{nav_percentage:.1}%{reset} of your time was spent navigating directories\n", + ); + } + + // Command vocabulary + println!( + "{bold}📚 Command Vocabulary{reset}: You know {bold}{}{reset} unique commands\n", + stats.unique_commands + ); + + // Package management + println!( + "{bold}📦 Package Management{reset}: You ran {bold}{}{reset} package-related commands\n", + wrapped_stats.pkg_commands + ); + + // Error patterns + let error_percentage = wrapped_stats.error_rate * 100.0; + println!( + "{bold}🚨 Error Analysis{reset}: Your commands failed {bold}{error_percentage:.1}%{reset} of the time\n", + ); + + // Command evolution + println!("🔍 Command Evolution:"); + + // print stats for each half and compare + println!(" {bold}Top Commands{reset} in the first half of {year}:"); + for (cmd, count) in wrapped_stats.first_half_commands.iter().take(3) { + println!(" {bold}{cmd}{reset} ({count} times)"); + } + + println!(" {bold}Top Commands{reset} in the second half of {year}:"); + for (cmd, count) in wrapped_stats.second_half_commands.iter().take(3) { + println!(" {bold}{cmd}{reset} ({count} times)"); + } + + // Find new favorite commands (in top 5 of second half but not in first half) + let first_half_set: HashSet<_> = wrapped_stats + .first_half_commands + .iter() + .map(|(cmd, _)| cmd) + .collect(); + let new_favorites: Vec<_> = wrapped_stats + .second_half_commands + .iter() + .filter(|(cmd, _)| !first_half_set.contains(cmd)) + .take(2) + .collect(); + + if !new_favorites.is_empty() { + println!(" {bold}New favorites{reset} in the second half:"); + for (cmd, count) in new_favorites { + println!(" {bold}{cmd}{reset} ({count} times)"); + } + } + + // Time patterns + if let Some((hour, count)) = &wrapped_stats.busiest_hour { + println!("\n🕘 Most Productive Hour: {bold}{hour}{reset} ({count} commands)"); + + // Night owl or early bird + let hour_num = hour + .split(':') + .next() + .unwrap_or("0") + .parse::<u32>() + .unwrap_or(0); + if hour_num >= 22 || hour_num <= 4 { + println!(" You're quite the night owl! 🦉"); + } else if (5..=7).contains(&hour_num) { + println!(" Early bird gets the worm! 🐦"); + } + } + + println!(); +} + +pub async fn run( + year: Option<i32>, + db: &impl Database, + settings: &Settings, + theme: &Theme, +) -> Result<()> { + let now = OffsetDateTime::now_utc().to_offset(settings.timezone.0); + let month = now.month(); + + // If we're in December, then wrapped is for the current year. If not, it's for the previous year + let year = year.unwrap_or_else(|| { + if month == Month::December { + now.year() + } else { + now.year() - 1 + } + }); + + let start = OffsetDateTime::new_in_offset( + Date::from_calendar_date(year, Month::January, 1).unwrap(), + Time::MIDNIGHT, + now.offset(), + ); + let end = OffsetDateTime::new_in_offset( + Date::from_calendar_date(year, Month::December, 31).unwrap(), + Time::MIDNIGHT + Duration::days(1) - Duration::nanoseconds(1), + now.offset(), + ); + + let history = db.range(start, end).await?; + if history.is_empty() { + println!( + "Your history for {year} is empty!\nMaybe 'atuin import' could help you import your previous history 🪄" + ); + return Ok(()); + } + + // Compute overall stats using existing functionality + let stats = compute(settings, &history, 10, 1).expect("Failed to compute stats"); + let wrapped_stats = WrappedStats::new(settings, &stats, &history); + + // Print wrapped format + print_wrapped_header(year); + + println!("🎉 In {year}, you typed {} commands!", stats.total_commands); + println!( + " That's ~{} commands every day\n", + stats.total_commands / 365 + ); + + println!("Your Top Commands:"); + crate::atuin_history::stats::pretty_print(stats.clone(), 1, theme); + println!(); + + print_fun_facts(&wrapped_stats, &stats, year); + + Ok(()) +} diff --git a/crates/turtle/src/command/contributors.rs b/crates/turtle/src/command/contributors.rs new file mode 100644 index 00000000..452fd335 --- /dev/null +++ b/crates/turtle/src/command/contributors.rs @@ -0,0 +1,5 @@ +static CONTRIBUTORS: &str = include_str!("CONTRIBUTORS"); + +pub fn run() { + println!("\n{CONTRIBUTORS}"); +} diff --git a/crates/turtle/src/command/external.rs b/crates/turtle/src/command/external.rs new file mode 100644 index 00000000..e1f0cddd --- /dev/null +++ b/crates/turtle/src/command/external.rs @@ -0,0 +1,102 @@ +use std::fmt::Write as _; +use std::process::Command; +use std::{io, process}; + +#[cfg(feature = "client")] +use crate::atuin_client::plugin::{OfficialPluginRegistry, PluginContext}; +use clap::CommandFactory; +use clap::builder::{StyledStr, Styles}; +use eyre::Result; + +use crate::Atuin; + +pub fn run(args: &[String]) -> Result<()> { + let subcommand = &args[0]; + let bin = format!("atuin-{subcommand}"); + let mut cmd = Command::new(&bin); + cmd.args(&args[1..]); + + #[cfg(feature = "client")] + let context = PluginContext::new(subcommand); + + let spawn_result = match cmd.spawn() { + Ok(child) => Ok(child), + Err(e) => match e.kind() { + io::ErrorKind::NotFound => { + let output = render_not_found(subcommand, &bin); + Err(output) + } + _ => Err(e.to_string().into()), + }, + }; + + match spawn_result { + Ok(mut child) => { + let status = child.wait()?; + if status.success() { + Ok(()) + } else { + #[cfg(feature = "client")] + drop(context); + + process::exit(status.code().unwrap_or(1)); + } + } + Err(e) => { + eprintln!("{}", e.ansi()); + + #[cfg(feature = "client")] + drop(context); + + process::exit(1); + } + } +} + +fn render_not_found(subcommand: &str, bin: &str) -> StyledStr { + let mut output = StyledStr::new(); + let styles = Styles::styled(); + + let error = styles.get_error(); + let invalid = styles.get_invalid(); + let literal = styles.get_literal(); + + #[cfg(feature = "client")] + { + let registry = OfficialPluginRegistry::new(); + + // Check if this is an official plugin + if let Some(install_message) = registry.get_install_message(subcommand) { + let _ = write!(output, "{error}error:{error:#} "); + let _ = write!( + output, + "'{invalid}{subcommand}{invalid:#}' is an official atuin plugin, but it's not installed" + ); + let _ = write!(output, "\n\n"); + let _ = write!(output, "{install_message}"); + return output; + } + } + + let mut atuin_cmd = Atuin::command(); + let usage = atuin_cmd.render_usage(); + + let _ = write!(output, "{error}error:{error:#} "); + let _ = write!( + output, + "unrecognized subcommand '{invalid}{subcommand}{invalid:#}' " + ); + let _ = write!( + output, + "and no executable named '{invalid}{bin}{invalid:#}' found in your PATH" + ); + let _ = write!(output, "\n\n"); + let _ = write!(output, "{usage}"); + let _ = write!(output, "\n\n"); + let _ = write!( + output, + "For more information, try '{literal}--help{literal:#}'." + ); + + output +} diff --git a/crates/turtle/src/command/gen_completions.rs b/crates/turtle/src/command/gen_completions.rs new file mode 100644 index 00000000..10d4f689 --- /dev/null +++ b/crates/turtle/src/command/gen_completions.rs @@ -0,0 +1,84 @@ +use clap::{CommandFactory, Parser, ValueEnum}; +use clap_complete::{Generator, Shell, generate, generate_to}; +use clap_complete_nushell::Nushell; +use eyre::Result; + +// clap put nushell completions into a separate package due to the maintainers +// being a little less committed to support them. +// This means we have to do a tiny bit of legwork to combine these completions +// into one command. +#[derive(Debug, Clone, ValueEnum)] +#[value(rename_all = "lower")] +pub enum GenShell { + Bash, + Elvish, + Fish, + Nushell, + PowerShell, + Zsh, +} + +impl Generator for GenShell { + fn file_name(&self, name: &str) -> String { + match self { + // clap_complete + Self::Bash => Shell::Bash.file_name(name), + Self::Elvish => Shell::Elvish.file_name(name), + Self::Fish => Shell::Fish.file_name(name), + Self::PowerShell => Shell::PowerShell.file_name(name), + Self::Zsh => Shell::Zsh.file_name(name), + + // clap_complete_nushell + Self::Nushell => Nushell.file_name(name), + } + } + + fn generate(&self, cmd: &clap::Command, buf: &mut dyn std::io::prelude::Write) { + match self { + // clap_complete + Self::Bash => Shell::Bash.generate(cmd, buf), + Self::Elvish => Shell::Elvish.generate(cmd, buf), + Self::Fish => Shell::Fish.generate(cmd, buf), + Self::PowerShell => Shell::PowerShell.generate(cmd, buf), + Self::Zsh => Shell::Zsh.generate(cmd, buf), + + // clap_complete_nushell + Self::Nushell => Nushell.generate(cmd, buf), + } + } +} + +#[derive(Debug, Parser)] +pub struct Cmd { + /// Set the shell for generating completions + #[arg(long, short)] + shell: GenShell, + + /// Set the output directory + #[arg(long, short)] + out_dir: Option<String>, +} + +impl Cmd { + pub fn run(self) -> Result<()> { + let Cmd { shell, out_dir } = self; + + let mut cli = crate::Atuin::command(); + + match out_dir { + Some(out_dir) => { + generate_to(shell, &mut cli, env!("CARGO_PKG_NAME"), &out_dir)?; + } + None => { + generate( + shell, + &mut cli, + env!("CARGO_PKG_NAME"), + &mut std::io::stdout(), + ); + } + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/mod.rs b/crates/turtle/src/command/mod.rs new file mode 100644 index 00000000..e58bfe72 --- /dev/null +++ b/crates/turtle/src/command/mod.rs @@ -0,0 +1,156 @@ +use clap::Subcommand; +use eyre::Result; + +#[cfg(not(windows))] +use rustix::{fs::Mode, process::umask}; + +#[cfg(feature = "client")] +mod client; + +mod contributors; + +mod gen_completions; + +mod external; + +#[derive(Subcommand)] +#[command(infer_subcommands = true)] +#[expect(clippy::large_enum_variant)] +pub enum AtuinCmd { + #[cfg(feature = "client")] + #[command(flatten)] + Client(client::Cmd), + + /// PTY proxy for atuin + #[cfg(feature = "pty-proxy")] + #[command(alias = "hex")] + PtyProxy(crate::atuin_pty_proxy::PtyProxy), + + /// Generate a UUID + Uuid, + + Contributors, + + /// Generate shell completions + GenCompletions(gen_completions::Cmd), + + #[command(external_subcommand)] + External(Vec<String>), +} + +impl AtuinCmd { + pub fn run(self) -> Result<()> { + #[cfg(not(windows))] + { + // set umask before we potentially open/create files + // or in other words, 077. Do not allow any access to any other user + let mode = Mode::RWXG | Mode::RWXO; + umask(mode); + } + + match self { + #[cfg(feature = "client")] + Self::Client(client) => client.run(), + + #[cfg(feature = "pty-proxy")] + Self::PtyProxy(proxy) => { + run_pty_proxy(proxy); + Ok(()) + } + + Self::Contributors => { + contributors::run(); + Ok(()) + } + Self::Uuid => { + println!("{}", crate::atuin_common::utils::uuid_v7().as_simple()); + Ok(()) + } + Self::GenCompletions(gen_completions) => gen_completions.run(), + Self::External(args) => external::run(&args), + } + } +} + +#[cfg(all(feature = "pty-proxy", unix))] +fn run_pty_proxy(proxy: crate::atuin_pty_proxy::PtyProxy) { + #[cfg(feature = "daemon")] + proxy.run(semantic_command_capture_sink()); + + #[cfg(not(feature = "daemon"))] + proxy.run(None); +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +fn semantic_command_capture_sink() -> Option<crate::atuin_pty_proxy::CommandCaptureSink> { + use std::sync::mpsc; + use std::time::Duration; + + if is_truthy_env("ATUIN_TERMINAL") { + return None; + } + + let settings = crate::atuin_client::settings::Settings::new().ok()?; + let (tx, rx) = mpsc::sync_channel::<crate::atuin_pty_proxy::CommandCapture>(128); + + std::thread::spawn(move || { + let Ok(runtime) = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + else { + return; + }; + + while let Ok(first) = rx.recv() { + let mut batch = vec![first]; + + while batch.len() < 64 { + match rx.recv_timeout(Duration::from_millis(25)) { + Ok(capture) => batch.push(capture), + Err(mpsc::RecvTimeoutError::Timeout | mpsc::RecvTimeoutError::Disconnected) => { + break; + } + } + } + + runtime.block_on(send_semantic_command_captures(&settings, batch)); + } + }); + + Some(Box::new(move |capture| { + let _ = tx.try_send(capture); + })) +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +#[inline] +fn is_truthy_env(name: &str) -> bool { + std::env::var(name) + .ok() + .as_ref() + .is_some_and(|value| !value.trim().is_empty() && value.trim() != "false") +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +async fn send_semantic_command_captures( + settings: &crate::atuin_client::settings::Settings, + batch: Vec<crate::atuin_pty_proxy::CommandCapture>, +) { + let captures = batch + .into_iter() + .map(|capture| crate::atuin_daemon::semantic::CommandCapture { + prompt: capture.prompt, + command: capture.command, + output: capture.output, + exit_code: capture.exit_code, + history_id: capture.history_id, + session_id: capture.session_id, + output_truncated: capture.output_truncated, + output_observed_bytes: capture.output_observed_bytes, + }) + .collect(); + + if let Ok(mut client) = crate::atuin_daemon::SemanticClient::from_settings(settings).await { + let _ = client.record_commands(captures).await; + } +} diff --git a/crates/turtle/src/main.rs b/crates/turtle/src/main.rs new file mode 100644 index 00000000..e5b80ee8 --- /dev/null +++ b/crates/turtle/src/main.rs @@ -0,0 +1,73 @@ +#![warn(clippy::pedantic, clippy::nursery)] +#![allow(clippy::use_self, clippy::missing_const_for_fn)] // not 100% reliable +// #![deny(unsafe_code)] +#![forbid(unsafe_code)] + +use clap::Parser; +use clap::builder::Styles; +use clap::builder::styling::{AnsiColor, Effects}; +use eyre::Result; + +use command::AtuinCmd; + +mod command; + +mod atuin_client; +mod atuin_common; +mod atuin_daemon; +mod atuin_history; +mod atuin_pty_proxy; +mod atuin_server; +mod atuin_server_database; +mod atuin_server_postgres; +mod atuin_server_sqlite; + +#[cfg(feature = "sync")] +mod print_error; +#[cfg(feature = "sync")] +mod sync; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); +const SHA: &str = env!("GIT_HASH"); + +const LONG_VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), " (", env!("GIT_HASH"), ")"); + +static HELP_TEMPLATE: &str = "\ +{before-help}{name} {version} +{author} +{about} + +{usage-heading} + {usage} + +{all-args}{after-help}"; + +const STYLES: Styles = Styles::styled() + .header(AnsiColor::Yellow.on_default().effects(Effects::BOLD)) + .usage(AnsiColor::Green.on_default().effects(Effects::BOLD)) + .literal(AnsiColor::Green.on_default().effects(Effects::BOLD)) + .placeholder(AnsiColor::Green.on_default()); + +/// Magical shell history +#[derive(Parser)] +#[command( + author = "Ellie Huxtable <ellie@atuin.sh>", + version = VERSION, + long_version = LONG_VERSION, + help_template(HELP_TEMPLATE), + styles = STYLES, +)] +struct Atuin { + #[command(subcommand)] + atuin: AtuinCmd, +} + +impl Atuin { + fn run(self) -> Result<()> { + self.atuin.run() + } +} + +fn main() -> Result<()> { + Atuin::parse().run() +} diff --git a/crates/turtle/src/print_error.rs b/crates/turtle/src/print_error.rs new file mode 100644 index 00000000..4d4724bc --- /dev/null +++ b/crates/turtle/src/print_error.rs @@ -0,0 +1,123 @@ +use std::io::IsTerminal; + +use crate::atuin_client::record::sync::SyncError; +use colored::Colorize; +use crossterm::terminal; + +/// Print a prominent error to stderr. Colored and box-bordered when stderr is +/// a TTY, plain "Error: ..." header otherwise. The description is word-wrapped +/// to the terminal width (capped at 100 columns) so the message stays readable. +pub fn print_error(title: &str, description: &str) { + let is_tty = std::io::stderr().is_terminal(); + let width = if is_tty { + terminal::size().map_or(80, |(w, _)| w as usize) + } else { + 80 + } + .min(100); + + eprintln!(); + if is_tty { + let bar = "━".repeat(width).red().bold().to_string(); + eprintln!("{bar}"); + eprintln!(" {} {}", "✗".red().bold(), title.red().bold()); + eprintln!("{bar}"); + } else { + eprintln!("Error: {title}"); + eprintln!("{}", "-".repeat(width)); + } + eprintln!(); + + for line in wrap_text(description, width.saturating_sub(2)) { + eprintln!(" {line}"); + } + eprintln!(); +} + +/// Convert a `SyncError` into an `eyre::Report`, exiting on `WrongKey` after +/// painting the prominent banner. +pub fn format_sync_error(e: SyncError) -> eyre::Report { + if matches!(e, SyncError::WrongKey) { + print_error( + "Wrong encryption key", + "Your local encryption key cannot decrypt the data on the server. \ + This usually means another machine wrote records with a different key.\n\n\ + To fix this, find the correct key by running `atuin key` on a machine that \ + already syncs successfully, then run `atuin store rekey <key>` here.", + ); + std::process::exit(1); + } + e.into() +} + +fn wrap_text(text: &str, width: usize) -> Vec<String> { + let mut out = Vec::new(); + for paragraph in text.split('\n') { + let mut line = String::new(); + let mut line_len = 0; + for word in paragraph.split_whitespace() { + let word_len = word.chars().count(); + if !line.is_empty() && line_len + 1 + word_len > width { + out.push(std::mem::take(&mut line)); + line_len = 0; + } + if !line.is_empty() { + line.push(' '); + line_len += 1; + } + line.push_str(word); + line_len += word_len; + } + // Push every paragraph's final line (even empty) so `\n\n` in the + // input becomes a blank line in the output. + out.push(line); + } + while out.first().is_some_and(String::is_empty) { + out.remove(0); + } + while out.last().is_some_and(String::is_empty) { + out.pop(); + } + out +} + +#[cfg(test)] +mod tests { + use super::wrap_text; + + #[test] + fn wraps_long_text() { + let lines = wrap_text("the quick brown fox jumps over the lazy dog", 20); + for line in &lines { + assert!(line.chars().count() <= 20, "line too long: {line:?}"); + } + assert_eq!( + lines.join(" "), + "the quick brown fox jumps over the lazy dog" + ); + } + + #[test] + fn preserves_explicit_newlines() { + let lines = wrap_text("first line\nsecond line", 80); + assert_eq!(lines, vec!["first line", "second line"]); + } + + #[test] + fn handles_word_longer_than_width() { + let lines = wrap_text("short superlongword more", 5); + assert_eq!(lines, vec!["short", "superlongword", "more"]); + } + + #[test] + fn preserves_blank_lines_between_paragraphs() { + let lines = wrap_text("first paragraph\n\nsecond paragraph", 80); + assert_eq!(lines, vec!["first paragraph", "", "second paragraph"]); + } + + #[test] + fn trims_leading_and_trailing_blank_lines() { + let lines = wrap_text("\n\nbody\n\n", 80); + assert_eq!(lines, vec!["body"]); + } +} diff --git a/crates/turtle/src/shell/.gitattributes b/crates/turtle/src/shell/.gitattributes new file mode 100644 index 00000000..fae8897c --- /dev/null +++ b/crates/turtle/src/shell/.gitattributes @@ -0,0 +1 @@ +* eol=lf diff --git a/crates/turtle/src/shell/atuin.bash b/crates/turtle/src/shell/atuin.bash new file mode 100644 index 00000000..8b540bd7 --- /dev/null +++ b/crates/turtle/src/shell/atuin.bash @@ -0,0 +1,725 @@ +# Include guard +if [[ ${__atuin_initialized-} == true ]]; then + false +elif [[ $- != *i* ]]; then + # Enable only in interactive shells + false +elif ((BASH_VERSINFO[0] < 3 || BASH_VERSINFO[0] == 3 && BASH_VERSINFO[1] < 1)); then + # Require bash >= 3.1 + [[ -t 2 ]] && printf 'atuin: requires bash >= 3.1 for the integration.\n' >&2 + false +else # (include guard) beginning of main content +#------------------------------------------------------------------------------ +__atuin_initialized=true + +if [[ -z "${ATUIN_SESSION:-}" || "${ATUIN_SHLVL:-}" != "$SHLVL" ]]; then + ATUIN_SESSION=$(atuin uuid) + export ATUIN_SESSION + export ATUIN_SHLVL=$SHLVL +fi +ATUIN_STTY=$(stty -g) +ATUIN_HISTORY_ID="" + +__atuin_osc133_command_executed() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" && "$ATUIN_HISTORY_ID" != "__bash_preexec_failure__" ]] || return + + printf '\033]133;C\a' +} + +__atuin_osc133_command_finished() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" && "$ATUIN_HISTORY_ID" != "__bash_preexec_failure__" ]] || return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$1" "$ATUIN_HISTORY_ID" "${ATUIN_SESSION:-}" +} + +__atuin_osc133_prompt_start=$'\001\033]133;A;cl=line\a\002' +__atuin_osc133_prompt_end=$'\001\033]133;B\a\002' + +__atuin_osc133_wrap_prompt() { + local __atuin_prompt="${PS1-}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_start/}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_end/}" + + if [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]]; then + PS1="${__atuin_osc133_prompt_start}${__atuin_prompt}${__atuin_osc133_prompt_end}" + else + PS1="$__atuin_prompt" + fi +} + +export ATUIN_PREEXEC_BACKEND=$SHLVL:none +__atuin_update_preexec_backend() { + if [[ ${BLE_ATTACHED-} ]]; then + ATUIN_PREEXEC_BACKEND=$SHLVL:blesh-${BLE_VERSION-} + elif [[ ${bash_preexec_imported-} ]]; then + ATUIN_PREEXEC_BACKEND=$SHLVL:bash-preexec + elif [[ ${__bp_imported-} ]]; then + ATUIN_PREEXEC_BACKEND="$SHLVL:bash-preexec (old)" + else + ATUIN_PREEXEC_BACKEND=$SHLVL:unknown + fi +} + +__atuin_preexec() { + # Workaround for old versions of bash-preexec + if [[ ! ${BLE_ATTACHED-} ]]; then + # In older versions of bash-preexec, the preexec hook may be called + # even for the commands run by keybindings. There is no general and + # robust way to detect the command for keybindings, but at least we + # want to exclude Atuin's keybindings. When the preexec hook is called + # for a keybinding, the preexec hook for the user command will not + # fire, so we instead set a fake ATUIN_HISTORY_ID here to notify + # __atuin_precmd of this failure. + if [[ $BASH_COMMAND != "$1" ]]; then + case $BASH_COMMAND in + '__atuin_history'* | '__atuin_widget_run'* | '__atuin_bash42_dispatch'*) + ATUIN_HISTORY_ID=__bash_preexec_failure__ + return 0 ;; + esac + fi + fi + + # Note: We update ATUIN_PREEXEC_BACKEND on every preexec because blesh's + # attaching state can dynamically change. + __atuin_update_preexec_backend + + local id + id=$(atuin history start -- "$1" 2>/dev/null) + export ATUIN_HISTORY_ID=$id + [[ -n ${__atuin_skip_osc133:-} ]] || __atuin_osc133_command_executed + __atuin_preexec_time=${EPOCHREALTIME-} +} + +__atuin_precmd() { + local EXIT=$? __atuin_precmd_time=${EPOCHREALTIME-} + + __atuin_osc133_wrap_prompt + + [[ ! $ATUIN_HISTORY_ID ]] && return + + # If the previous preexec hook failed, we manually call __atuin_preexec + local __atuin_skip_osc133="" + if [[ $ATUIN_HISTORY_ID == __bash_preexec_failure__ ]]; then + # This is the command extraction code taken from bash-preexec + local previous_command + previous_command=$( + export LC_ALL=C HISTTIMEFORMAT='' + builtin history 1 | sed '1 s/^ *[0-9][0-9]*[* ] //' + ) + __atuin_skip_osc133=1 + __atuin_preexec "$previous_command" + fi + + local duration="" + # shellcheck disable=SC2154,SC2309 + if [[ ${BLE_ATTACHED-} && ${_ble_exec_time_ata-} ]]; then + # With ble.sh, we utilize the shell variable `_ble_exec_time_ata` + # recorded by ble.sh. It is more accurate than the measurements by + # Atuin, which includes the spawn cost of Atuin. ble.sh uses the + # special shell variable `EPOCHREALTIME` in bash >= 5.0 with the + # microsecond resolution, or the builtin `time` in bash < 5.0 with the + # millisecond resolution. + duration=${_ble_exec_time_ata}000 + elif ((BASH_VERSINFO[0] >= 5)); then + # We calculate the high-resolution duration based on EPOCHREALTIME + # (bash >= 5.0) recorded by precmd/preexec, though it might not be as + # accurate as `_ble_exec_time_ata` provided by ble.sh because it + # includes the extra time of the precmd/preexec handling. Since Bash + # does not offer floating-point arithmetic, we remove the non-digit + # characters and perform the integral arithmetic. The fraction part of + # EPOCHREALTIME is fixed to have 6 digits in Bash. We remove all the + # non-digit characters because the decimal point is not necessarily a + # period depending on the locale. + duration=$((${__atuin_precmd_time//[!0-9]} - ${__atuin_preexec_time//[!0-9]})) + if ((duration >= 0)); then + duration=${duration}000 + else + duration="" # clear the result on overflow + fi + fi + + [[ -n ${__atuin_skip_osc133:-} ]] || __atuin_osc133_command_finished "$EXIT" + (ATUIN_LOG=error atuin history end --exit "$EXIT" ${duration:+"--duration=$duration"} -- "$ATUIN_HISTORY_ID" &) >/dev/null 2>&1 + export ATUIN_HISTORY_ID="" +} + +__atuin_set_ret_value() { + return ${1:+"$1"} +} + +#------------------------------------------------------------------------------ +# section: __atuin_accept_line +# +# The function "__atuin_accept_line" is kept for backward compatibility of the +# direct use of __atuin_history in keybindings by users. + +# The shell function `__atuin_evaluate_prompt` evaluates prompt sequences in +# $PS1. We switch the implementation of the shell function +# `__atuin_evaluate_prompt` based on the Bash version because the expansion +# ${PS1@P} is only available in bash >= 4.4. +if ((BASH_VERSINFO[0] >= 5 || BASH_VERSINFO[0] == 4 && BASH_VERSINFO[1] >= 4)); then + __atuin_evaluate_prompt() { + __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" + __atuin_prompt=${PS1@P} + + # Note: Strip the control characters ^A (\001) and ^B (\002), which + # Bash internally uses to enclose the escape sequences. They are + # produced by '\[' and '\]', respectively, in $PS1 and used to tell + # Bash that the strings inbetween do not contribute to the prompt + # width. After the prompt width calculation, Bash strips those control + # characters before outputting it to the terminal. We here strip these + # characters following Bash's behavior. + __atuin_prompt=${__atuin_prompt//[$'\001\002']} + + # Count the number of newlines contained in $__atuin_prompt + __atuin_prompt_offset=${__atuin_prompt//[!$'\n']} + __atuin_prompt_offset=${#__atuin_prompt_offset} + } +else + __atuin_evaluate_prompt() { + __atuin_prompt='$ ' + __atuin_prompt_offset=0 + } +fi + +# The shell function `__atuin_clear_prompt N` outputs terminal control +# sequences to clear the contents of the current and N previous lines. After +# clearing, the cursor is placed at the beginning of the N-th previous line. +__atuin_clear_prompt_cache=() +__atuin_clear_prompt() { + local offset=$1 + if [[ ! ${__atuin_clear_prompt_cache[offset]+set} ]]; then + if [[ ! ${__atuin_clear_prompt_cache[0]+set} ]]; then + __atuin_clear_prompt_cache[0]=$'\r'$(tput el 2>/dev/null || tput ce 2>/dev/null) + fi + if ((offset > 0)); then + __atuin_clear_prompt_cache[offset]=${__atuin_clear_prompt_cache[0]}$( + tput cuu "$offset" 2>/dev/null || tput UP "$offset" 2>/dev/null + tput dl "$offset" 2>/dev/null || tput DL "$offset" 2>/dev/null + tput il "$offset" 2>/dev/null || tput AL "$offset" 2>/dev/null + ) + fi + fi + printf '%s' "${__atuin_clear_prompt_cache[offset]}" +} + +__atuin_accept_line() { + local __atuin_command=$1 + + # Reprint the prompt, accounting for multiple lines + local __atuin_prompt __atuin_prompt_offset + __atuin_evaluate_prompt + __atuin_clear_prompt "$__atuin_prompt_offset" + printf '%s\n' "$__atuin_prompt$__atuin_command" + + # Add it to the bash history + history -s "$__atuin_command" + + # Assuming bash-preexec + # Invoke every function in the preexec array + local __atuin_preexec_function + local __atuin_preexec_function_ret_value + local __atuin_preexec_ret_value=0 + for __atuin_preexec_function in "${preexec_functions[@]:-}"; do + if type -t "$__atuin_preexec_function" 1>/dev/null; then + __atuin_set_ret_value "${__bp_last_ret_value:-}" + "$__atuin_preexec_function" "$__atuin_command" + __atuin_preexec_function_ret_value=$? + if [[ $__atuin_preexec_function_ret_value != 0 ]]; then + __atuin_preexec_ret_value=$__atuin_preexec_function_ret_value + fi + fi + done + + # If extdebug is turned on and any preexec function returns non-zero + # exit status, we do not run the user command. + if ! { shopt -q extdebug && ((__atuin_preexec_ret_value)); }; then + # Note: When a child Bash session is started by enter_accept, if the + # environment variable READLINE_POINT is present, bash-preexec in the + # child session does not fire preexec at all because it considers we + # are inside Atuin's keybinding of the current session. To avoid + # propagating the environment variable to the child session, we remove + # the export attribute of READLINE_LINE and READLINE_POINT. + export -n READLINE_LINE READLINE_POINT + + # Juggle the terminal settings so that the command can be interacted + # with + local __atuin_stty_backup + __atuin_stty_backup=$(stty -g) + stty "$ATUIN_STTY" + + # Execute the command. Note: We need to record $? and $_ after the + # user command within the same call of "eval" because $_ is otherwise + # overwritten by the last argument of "eval". + __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" + eval -- "$__atuin_command"$'\n__bp_last_ret_value=$? __bp_last_argument_prev_command=$_' + + stty "$__atuin_stty_backup" + fi + + # Execute preprompt commands + local __atuin_prompt_command + for __atuin_prompt_command in "${PROMPT_COMMAND[@]}"; do + __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" + eval -- "$__atuin_prompt_command" + done + # Bash will redraw only the line with the prompt after we finish, + # so to work for a multiline prompt we need to print it ourselves, + # then go to the beginning of the last line. + __atuin_evaluate_prompt + printf '%s' "$__atuin_prompt" + __atuin_clear_prompt 0 +} + +#------------------------------------------------------------------------------ + +# Check if tmux popup is available (tmux >= 3.2) +__atuin_tmux_popup_check() { + [[ -n "${TMUX-}" ]] || return 1 + [[ "${ATUIN_TMUX_POPUP:-true}" != "false" ]] || return 1 + + # https://github.com/tmux/tmux/wiki/FAQ#how-often-is-tmux-released-what-is-the-version-number-scheme + local tmux_version + tmux_version=$(tmux -V 2>/dev/null | sed -n 's/^[^0-9]*\([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p') # Could have used grep... + [[ -z "$tmux_version" ]] && return 1 + + local m1 m2 + m1=${tmux_version%%.*} + m2=${tmux_version#*.} + m2=${m2%%.*} + [[ "$m1" =~ ^[0-9]+$ ]] || return 1 + [[ "$m2" =~ ^[0-9]+$ ]] || m2=0 + (( m1 > 3 || (m1 == 3 && m2 >= 2) )) +} + +# Use global variable to fix scope issues with traps +__atuin_popup_tmpdir="" +__atuin_tmux_popup_cleanup() { + [[ -n "$__atuin_popup_tmpdir" && -d "$__atuin_popup_tmpdir" ]] && command rm -rf "$__atuin_popup_tmpdir" + __atuin_popup_tmpdir="" +} + +__atuin_search_cmd() { + local -a search_args=("$@") + + if __atuin_tmux_popup_check; then + __atuin_popup_tmpdir=$(mktemp -d) || return 1 + local result_file="$__atuin_popup_tmpdir/result" + + trap '__atuin_tmux_popup_cleanup' EXIT HUP INT TERM + + local escaped_query escaped_args + escaped_query=$(printf '%s' "$READLINE_LINE" | sed "s/'/'\\\\''/g") + escaped_args="" + for arg in "${search_args[@]}"; do + escaped_args+=" '$(printf '%s' "$arg" | sed "s/'/'\\\\''/g")'" + done + + # In the popup, atuin goes to terminal, stderr goes to file + local cdir popup_width popup_height + cdir=$(pwd) + popup_width="${ATUIN_TMUX_POPUP_WIDTH:-80%}" # Keep default value anyways + popup_height="${ATUIN_TMUX_POPUP_HEIGHT:-60%}" + tmux display-popup -d "$cdir" -w "$popup_width" -h "$popup_height" -E -E -- \ + sh -c "PATH='$PATH' ATUIN_SESSION='$ATUIN_SESSION' ATUIN_SHELL=bash ATUIN_LOG=error ATUIN_QUERY='$escaped_query' atuin search $escaped_args -i 2>'$result_file'" + + if [[ -f "$result_file" ]]; then + cat "$result_file" + fi + + __atuin_tmux_popup_cleanup + trap - EXIT HUP INT TERM + else + ATUIN_SHELL=bash ATUIN_LOG=error ATUIN_QUERY=$READLINE_LINE atuin search "${search_args[@]}" -i 3>&1 1>&2 2>&3 3>&- + fi +} + +__atuin_history() { + # Default action of the up key: When this function is called with the first + # argument `--shell-up-key-binding`, we perform Atuin's history search only + # when the up key is supposed to cause the history movement in the original + # binding. We do this only for ble.sh because the up key always invokes + # the history movement in the plain Bash. + if [[ ${BLE_ATTACHED-} && ${1-} == --shell-up-key-binding ]]; then + # When the current cursor position is not in the first line, the up key + # should move the cursor to the previous line. While the selection is + # performed, the up key should not start the history search. + # shellcheck disable=SC2154 # Note: these variables are set by ble.sh + if [[ ${_ble_edit_str::_ble_edit_ind} == *$'\n'* || $_ble_edit_mark_active ]]; then + ble/widget/@nomarked backward-line + local status=$? + READLINE_LINE=$_ble_edit_str + READLINE_POINT=$_ble_edit_ind + READLINE_MARK=$_ble_edit_mark + return "$status" + fi + fi + + # READLINE_LINE and READLINE_POINT are only supported by bash >= 4.0 or + # ble.sh. When it is not supported, we clear them to suppress strange + # behaviors. + [[ ${BLE_ATTACHED-} ]] || ((BASH_VERSINFO[0] >= 4)) || + READLINE_LINE="" READLINE_POINT=0 + + local __atuin_output + if ! __atuin_output=$(__atuin_search_cmd "$@"); then + [[ $__atuin_output ]] && printf '%s\n' "$__atuin_output" >&2 + return 1 + fi + + # We do nothing when the search is canceled. + [[ $__atuin_output ]] || return 0 + + if [[ $__atuin_output == __atuin_accept__:* ]]; then + __atuin_output=${__atuin_output#__atuin_accept__:} + + if [[ ${BLE_ATTACHED-} ]]; then + ble-edit/content/reset-and-check-dirty "$__atuin_output" + ble/widget/accept-line + READLINE_LINE="" + elif [[ ${__atuin_macro_chain_keymap-} ]]; then + READLINE_LINE=$__atuin_output + bind -m "$__atuin_macro_chain_keymap" '"'"$__atuin_macro_chain"'": '"$__atuin_macro_accept_line" + else + __atuin_accept_line "$__atuin_output" + READLINE_LINE="" + fi + + READLINE_POINT=${#READLINE_LINE} + else + READLINE_LINE=$__atuin_output + READLINE_POINT=${#READLINE_LINE} + if [[ ! ${BLE_ATTACHED-} ]] && ((BASH_VERSINFO[0] < 4)) && [[ ${__atuin_macro_chain_keymap-} ]]; then + bind -m "$__atuin_macro_chain_keymap" '"'"$__atuin_macro_chain"'": '"$__atuin_macro_insert_line" + fi + fi +} + +__atuin_initialize_blesh() { + # shellcheck disable=SC2154 + [[ ${BLE_VERSION-} ]] && ((_ble_version >= 400)) || return 0 + + ble-import contrib/integration/bash-preexec + + # Define and register an autosuggestion source for ble.sh's auto-complete. + # If you'd like to overwrite this, define the same name of shell function + # after the $(atuin init bash) line in your .bashrc. If you do not need + # the auto-complete source by Atuin, please add the following code to + # remove the entry after the $(atuin init bash) line in your .bashrc: + # + # ble/util/import/eval-after-load core-complete ' + # ble/array#remove _ble_complete_auto_source atuin-history' + # + function ble/complete/auto-complete/source:atuin-history { + local suggestion + suggestion=$(ATUIN_QUERY="$_ble_edit_str" atuin search --cmd-only --limit 1 --search-mode prefix 2>/dev/null) + [[ $suggestion == "$_ble_edit_str"?* ]] || return 1 + ble/complete/auto-complete/enter h 0 "${suggestion:${#_ble_edit_str}}" '' "$suggestion" + } + ble/util/import/eval-after-load core-complete ' + ble/array#unshift _ble_complete_auto_source atuin-history' + + # @env BLE_SESSION_ID: `atuin doctor` references the environment variable + # BLE_SESSION_ID. We explicitly export the variable because it was not + # exported in older versions of ble.sh. + [[ ${BLE_SESSION_ID-} ]] && export BLE_SESSION_ID +} +__atuin_initialize_blesh +BLE_ONLOAD+=(__atuin_initialize_blesh) +precmd_functions+=(__atuin_precmd) +preexec_functions+=(__atuin_preexec) + +#------------------------------------------------------------------------------ +# section: atuin-bind + +__atuin_widget=() + +__atuin_widget_save() { + local data=$1 + for REPLY in "${!__atuin_widget[@]}"; do + if [[ ${__atuin_widget[REPLY]} == "$data" ]]; then + return 0 + fi + done + # shellcheck disable=SC2154 + REPLY=${#__atuin_widget[*]} + __atuin_widget[REPLY]=$data +} + +__atuin_widget_run() { + local data=${__atuin_widget[$1]} + local keymap=${data%%:*} widget=${data#*:} + local __atuin_macro_chain_keymap=$keymap + bind -m "$keymap" '"'"$__atuin_macro_chain"'": ""' + builtin eval -- "$widget" +} + +# To realize the enter_accept feature in a robust way, we need to call the +# readline bindable function `accept-line'. However, there is no way to call +# `accept-line' from the shell script. To call the bindable function +# `accept-line', we may utilize string macros of readline. When we bind KEYSEQ +# to a WIDGET that wants to conditionally call `accept-line' at the end, we +# perform two-step dispatching: +# +# 1. [KEYSEQ -> IKEYSEQ1 IKEYSEQ2]---We first translate KEYSEQ to two +# intermediate key sequences IKEYSEQ1 and IKEYSEQ2 using string macros. For +# example, when we bind `__atuin_history` to \C-r, this step can be set up by +# `bind '"\C-r": "IKEYSEQ1IKEYSEQ2"'`. +# +# 2. [IKEYSEQ1 -> WIDGET]---Then, IKEYSEQ1 is bound to the WIDGET, and the +# binding of IKEYSEQ2 is dynamically determined by WIDGET. For example, when +# we bind `__atuin_history` to \C-r, this step can be set up by `bind -x +# '"IKEYSEQ1": WIDGET'`. +# +# 3. [IKEYSEQ2 -> accept-line] or [IKEYSEQ2 -> ""]---To request the execution +# of `accept-line', WIDGET can change the binding of IKEYSEQ2 by running +# `bind '"IKEYSEQ2": accept-line''. Otherwise, WIDGET can change the binding +# of IKEYSEQ2 to no-op by running `bind '"IKEYSEQ2": ""'`. +# +# For the choice of the intermediate key sequences, we want to choose key +# sequences that are unlikely to conflict with others. In addition, we want to +# avoid a key sequence containing \e because keymap "vi-insert" stops +# processing key sequences containing \e in older versions of Bash. We have +# used \e[0;<m>A (a variant of the [up] key with modifier <m>) in Atuin 3.10.0 +# for intermediate key sequences, but this contains \e and caused a problem. +# Instead, we use \C-x\C-_A<n>\a, which starts with \C-x\C-_ (an unlikely +# two-byte combination) and A (represents the initial letter of Atuin), +# followed by the payload <n> and the terminator \a (BEL, \C-g). + +__atuin_macro_chain='\C-x\C-_A0\a' +for __atuin_keymap in emacs vi-insert vi-command; do + bind -m "$__atuin_keymap" "\"$__atuin_macro_chain\": \"\"" +done +unset -v __atuin_keymap + +if ((BASH_VERSINFO[0] >= 5 || BASH_VERSINFO[0] == 4 && BASH_VERSINFO[1] >= 3)); then + # In Bash >= 4.3 + + __atuin_macro_accept_line=accept-line + + __atuin_bind_impl() { + local keymap=$1 keyseq=$2 command=$3 + + # Note: In Bash <= 5.0, the table for `bind -x` from the keyseq to the + # command is shared by all the keymaps (emacs, vi-insert, and + # vi-command), so one cannot safely bind different command strings to + # the same keyseq in different keymaps. Therefore, the command string + # and the keyseq need to be globally in one-to-one correspondence in + # all the keymaps. + local REPLY + __atuin_widget_save "$keymap:$command" + local widget=$REPLY + local ikeyseq1='\C-x\C-_A'$((1 + widget))'\a' + local ikeyseq2=$__atuin_macro_chain + + if ((BASH_VERSINFO[0] == 5 && BASH_VERSINFO[1] == 1)); then + # Workaround for Bash 5.1: Bash 5.1 has a bug that overwriting an + # existing "bind -x" keybinding breaks other existing "bind -x" + # keybindings [1,2]. To work around the problem, we explicitly + # unbind an existing keybinding before overwriting it. + # + # [1] https://lists.gnu.org/archive/html/bug-bash/2021-04/msg00135.html + # [2] https://github.com/atuinsh/atuin/issues/962#issuecomment-3451132291 + bind -m "$keymap" -r "$keyseq" + fi + + bind -m "$keymap" "\"$keyseq\": \"$ikeyseq1$ikeyseq2\"" + bind -m "$keymap" -x "\"$ikeyseq1\": __atuin_widget_run $widget" + } + + __atuin_bind_blesh_onload() { + # In ble.sh, we need to enable unrecognized CSI sequences like \e[0;0A, + # which are discarded by ble.sh by default. Note: In Bash <= 4.2, we + # do not need to unset "decode_error_cseq_discard" because \e[0;<m>A is + # used only for the macro chaining (which is unused by ble.sh) in Bash + # <= 4.2. + bleopt decode_error_cseq_discard= + } + if [[ ${BLE_VERSION-} ]]; then + __atuin_bind_blesh_onload + fi + BLE_ONLOAD+=(__atuin_bind_blesh_onload) +else + # In Bash <= 4.2, "bind -x" cannot bind a shell command to a keyseq having + # more than two bytes, so we need to work with only two-byte sequences. + # + # However, the number of available combinations of two-byte sequences is + # limited. To minimize the number of key sequences used by Atuin, instead + # of specifying a widget by its own intermediate sequence, we specify a + # widget by a fixed-length sequence of multiple two-byte sequences. More + # specifically, instead of IKEYSEQ1, we use IKS1 IKS2 IKS3 [IKS4 IKS5] + # IKSX, where IKS1..IKS5 just stores its information to a global variable, + # and IKSX collects all the information and determine and call the actual + # widget based on the stored information. Each of IKn (n=1..5) is one of + # the two reserved sequences, $__atuin_bash42_code0 and + # $__atuin_bash42_code1. IKSX is fixed to be $__atuin_bash42_code2. + # + # For the choices of the special key sequences, we consider \C-xQ, \C-xR, + # and \C-xS. In the emacs editing mode of Bash, \C-x is used as a prefix + # key, i.e., it is used for the beginning key of the keybindings with + # multiple keys, so \C-x is unlikely to be used for a single-key binding by + # the user. Also, \C-x is not used in the vi editing mode by default. The + # combinations \C-xQ..\C-xS are also unlikely be used because we need to + # switch the modifier keys from Control to Shift to input these sequences, + # and these are not easy to input. + __atuin_bash42_code0='\C-xQ' + __atuin_bash42_code1='\C-xR' + __atuin_bash42_code2='\C-xS' + + __atuin_bash42_encode() { + REPLY= + local n=$1 min_width=${2-} + while + if ((n % 2 == 0)); then + REPLY=$__atuin_bash42_code0$REPLY + else + REPLY=$__atuin_bash42_code1$REPLY + fi + (((n /= 2) || ${#REPLY} / ${#__atuin_bash42_code0} < min_width)) + do :; done + } + + __atuin_bash42_bind() { + local __atuin_keymap + for __atuin_keymap in emacs vi-insert vi-command; do + bind -m "$__atuin_keymap" -x '"'"$__atuin_bash42_code0"'": __atuin_bash42_dispatch_selector+=0' + bind -m "$__atuin_keymap" -x '"'"$__atuin_bash42_code1"'": __atuin_bash42_dispatch_selector+=1' + bind -m "$__atuin_keymap" -x '"'"$__atuin_bash42_code2"'": __atuin_bash42_dispatch' + done + } + __atuin_bash42_bind + # In Bash <= 4.2, there is no way to read users' "bind -x" settings, so we + # need to explicitly perform "bind -x" when ble.sh is loaded. + BLE_ONLOAD+=(__atuin_bash42_bind) + + if ((BASH_VERSINFO[0] >= 4)); then + __atuin_macro_accept_line=accept-line + else + # Note: We rewrite the command line and invoke `accept-line'. In + # bash <= 3.2, there is no way to rewrite the command line from the + # shell script, so we rewrite it using a macro and + # `shell-expand-line'. + # + # Note: Concerning the key sequences to invoke bindable functions + # such as "\C-x\C-_A1\a", another option is to use + # "\exbegginning-of-line\r", etc. to make it consistent with bash + # >= 5.3. However, an older Bash configuration can still conflict + # on [M-x]. The conflict is more likely than \C-x\C-_A1\a. + for __atuin_keymap in emacs vi-insert vi-command; do + bind -m "$__atuin_keymap" '"\C-x\C-_A1\a": beginning-of-line' + bind -m "$__atuin_keymap" '"\C-x\C-_A2\a": kill-line' + # shellcheck disable=SC2016 + bind -m "$__atuin_keymap" '"\C-x\C-_A3\a": "$READLINE_LINE"' + bind -m "$__atuin_keymap" '"\C-x\C-_A4\a": shell-expand-line' + bind -m "$__atuin_keymap" '"\C-x\C-_A5\a": accept-line' + bind -m "$__atuin_keymap" '"\C-x\C-_A6\a": end-of-line' + done + unset -v __atuin_keymap + + bind -m vi-command '"\C-x\C-_A7\a": vi-insertion-mode' + bind -m vi-insert '"\C-x\C-_A7\a": vi-movement-mode' + + # "\C-x\C-_A10\a": Replace the command line with READLINE_LINE. When we are + # in the vi-command keymap, we go to vi-insert, input + # "$READLINE_LINE", and come back to vi-command. + bind -m emacs '"\C-x\C-_A10\a": "\C-x\C-_A1\a\C-x\C-_A2\a\C-x\C-_A3\a\C-x\C-_A4\a"' + bind -m vi-insert '"\C-x\C-_A10\a": "\C-x\C-_A1\a\C-x\C-_A2\a\C-x\C-_A3\a\C-x\C-_A4\a"' + bind -m vi-command '"\C-x\C-_A10\a": "\C-x\C-_A1\a\C-x\C-_A2\a\C-x\C-_A7\a\C-x\C-_A3\a\C-x\C-_A7\a\C-x\C-_A4\a"' + + __atuin_macro_accept_line='"\C-x\C-_A10\a\C-x\C-_A5\a"' + __atuin_macro_insert_line='"\C-x\C-_A10\a\C-x\C-_A6\a"' + fi + + __atuin_bash42_dispatch_selector= + + __atuin_bash42_dispatch() { + local s=$__atuin_bash42_dispatch_selector + __atuin_bash42_dispatch_selector= + __atuin_widget_run "$((2#0$s))" + } + + __atuin_bind_impl() { + local keymap=$1 keyseq=$2 command=$3 + + __atuin_widget_save "$keymap:$command" + __atuin_bash42_encode "$REPLY" + local macro=$REPLY$__atuin_bash42_code2$__atuin_macro_chain + + bind -m "$keymap" "\"$keyseq\": \"$macro\"" + } +fi + +atuin-bind() { + local keymap= + local OPTIND=1 OPTARG="" OPTERR=0 flag + while getopts ':m:' flag "$@"; do + case $flag in + m) keymap=$OPTARG ;; + *) + printf '%s\n' "atuin-bind: unrecognized option '-$flag'" >&2 + return 2 + ;; + esac + done + shift "$((OPTIND - 1))" + + if (($# != 2)); then + printf '%s\n' 'usage: atuin-bind [-m keymap] keyseq widget' >&2 + return 2 + fi + + local keyseq=$1 + [[ $keymap ]] || keymap=$(bind -v | awk '$2 == "keymap" { print $3 }') + case $keymap in + emacs-meta) keymap=emacs keyseq='\e'$keyseq ;; + emacs-ctlx) keymap=emacs keyseq='\C-x'$keyseq ;; + emacs*) keymap=emacs ;; + vi-insert) ;; + vi*) keymap=vi-command ;; + *) + printf '%s\n' "atuin-bind: unknown keymap '$keymap'" >&2 + return 2 ;; + esac + + local command=$2 widget=${2%%[[:blank:]]*} + case $widget in + atuin-search) command=${2/#"$widget"/__atuin_history} ;; + atuin-search-emacs) command=${2/#"$widget"/__atuin_history --keymap-mode=emacs} ;; + atuin-search-viins) command=${2/#"$widget"/__atuin_history --keymap-mode=vim-insert} ;; + atuin-search-vicmd) command=${2/#"$widget"/__atuin_history --keymap-mode=vim-normal} ;; + atuin-up-search) command=${2/#"$widget"/__atuin_history --shell-up-key-binding} ;; + atuin-up-search-emacs) command=${2/#"$widget"/__atuin_history --shell-up-key-binding --keymap-mode=emacs} ;; + atuin-up-search-viins) command=${2/#"$widget"/__atuin_history --shell-up-key-binding --keymap-mode=vim-insert} ;; + atuin-up-search-vicmd) command=${2/#"$widget"/__atuin_history --shell-up-key-binding --keymap-mode=vim-normal} ;; + esac + + __atuin_bind_impl "$keymap" "$keyseq" "$command" +} + +#------------------------------------------------------------------------------ + +# shellcheck disable=SC2154 +if [[ $__atuin_bind_ctrl_r == true ]]; then + # Note: We do not overwrite [C-r] in the vi-command keymap because we do + # not want to overwrite "redo", which is already bound to [C-r] in the + # vi_nmap keymap in ble.sh. + atuin-bind -m emacs '\C-r' atuin-search-emacs + atuin-bind -m vi-insert '\C-r' atuin-search-viins + atuin-bind -m vi-command '/' atuin-search-emacs +fi + +# shellcheck disable=SC2154 +if [[ $__atuin_bind_up_arrow == true ]]; then + atuin-bind -m emacs '\e[A' atuin-up-search-emacs + atuin-bind -m emacs '\eOA' atuin-up-search-emacs + atuin-bind -m vi-insert '\e[A' atuin-up-search-viins + atuin-bind -m vi-insert '\eOA' atuin-up-search-viins + atuin-bind -m vi-command '\e[A' atuin-up-search-vicmd + atuin-bind -m vi-command '\eOA' atuin-up-search-vicmd + atuin-bind -m vi-command 'k' atuin-up-search-vicmd +fi + +#------------------------------------------------------------------------------ +fi # (include guard) end of main content diff --git a/crates/turtle/src/shell/atuin.fish b/crates/turtle/src/shell/atuin.fish new file mode 100644 index 00000000..15b33451 --- /dev/null +++ b/crates/turtle/src/shell/atuin.fish @@ -0,0 +1,178 @@ +if not set -q ATUIN_SESSION; or test "$ATUIN_SHLVL" != "$SHLVL" + set -gx ATUIN_SESSION (atuin uuid) + set -gx ATUIN_SHLVL $SHLVL +end +set --erase ATUIN_HISTORY_ID + +function _atuin_osc133_command_executed + set -q ATUIN_PTY_PROXY_ACTIVE; or return + test -n "$ATUIN_HISTORY_ID"; or return + + printf '\033]133;C\a' +end + +function _atuin_osc133_command_finished --argument-names exit_code + set -q ATUIN_PTY_PROXY_ACTIVE; or return + test -n "$ATUIN_HISTORY_ID"; or return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$exit_code" "$ATUIN_HISTORY_ID" "$ATUIN_SESSION" +end + +function _atuin_preexec --on-event fish_preexec + if not test -n "$fish_private_mode" + set -g ATUIN_HISTORY_ID (atuin history start -- "$argv[1]" 2>/dev/null) + _atuin_osc133_command_executed + end +end + +function _atuin_postexec --on-event fish_postexec + set -l s $status + + if test -n "$ATUIN_HISTORY_ID" + _atuin_osc133_command_finished $s + ATUIN_LOG=error atuin history end --exit $s -- $ATUIN_HISTORY_ID &>/dev/null & + disown + end + + set --erase ATUIN_HISTORY_ID +end + +# Check if tmux popup is available (tmux >= 3.2) +function _atuin_tmux_popup_check + if not test -n "$TMUX" + echo 0 + return + end + + if test "$ATUIN_TMUX_POPUP" = false + echo 0 + return + end + + set -l tmux_version (tmux -V 2>/dev/null | string match -r '\d+\.\d+') + if not test -n "$tmux_version" + echo 0 + return + end + + set -l parts (string split '.' $tmux_version) + set -l m1 $parts[1] + set -l m2 0 + if test (count $parts) -ge 2 + set m2 $parts[2] + end + + if not string match -rq '^[0-9]+$' -- "$m1" + echo 0 + return + end + + if not string match -rq '^[0-9]+$' -- "$m2" + set m2 0 + end + + if test "$m1" -gt 3 2>/dev/null; or begin + test "$m1" -eq 3 2>/dev/null; and test "$m2" -ge 2 2>/dev/null + end + echo 1 + else + echo 0 + end +end + +function _atuin_search + set -l keymap_mode + switch $fish_key_bindings + case fish_vi_key_bindings fish_hybrid_key_bindings + switch $fish_bind_mode + case default + set keymap_mode vim-normal + case insert + set keymap_mode vim-insert + end + case '*' + set keymap_mode emacs + end + + set -l use_tmux_popup (_atuin_tmux_popup_check) + + set -l ATUIN_H + set -l ATUIN_STATUS 0 + if test "$use_tmux_popup" -eq 1 + set -l tmpdir (mktemp -d) + if not test -d "$tmpdir" + # if mktemp got errors + set ATUIN_H (ATUIN_SHELL=fish ATUIN_LOG=error ATUIN_QUERY=(commandline -b) atuin search --keymap-mode=$keymap_mode $argv -i 3>&1 1>&2 2>&3 3>&- | string collect) + set ATUIN_STATUS $pipestatus[1] + else + set -l result_file "$tmpdir/result" + + set -l query (commandline -b | string replace -a "'" "'\\''") + set -l escaped_args "" + for arg in $argv + set escaped_args "$escaped_args '"(string replace -a "'" "'\\''" -- $arg)"'" + end + + # In the popup, atuin goes to terminal, stderr goes to file + set -l cdir (pwd) + # Keep default value anyways + set -l popup_width (test -n "$ATUIN_TMUX_POPUP_WIDTH" && echo "$ATUIN_TMUX_POPUP_WIDTH" || echo "80%") + set -l popup_height (test -n "$ATUIN_TMUX_POPUP_HEIGHT" && echo "$ATUIN_TMUX_POPUP_HEIGHT" || echo "60%") + tmux display-popup -d "$cdir" -w "$popup_width" -h "$popup_height" -E -E -- \ + sh -c "PATH='$PATH' ATUIN_SESSION='$ATUIN_SESSION' ATUIN_SHELL=fish ATUIN_LOG=error ATUIN_QUERY='$query' atuin search --keymap-mode=$keymap_mode$escaped_args -i 2>'$result_file'" + set ATUIN_STATUS $status + + if test -f "$result_file" + set ATUIN_H (cat "$result_file" | string collect) + end + + command rm -rf "$tmpdir" + end + else + # In fish 3.4 and above we can use `"$(some command)"` to keep multiple lines separate; + # but to support fish 3.3 we need to use `(some command | string collect)`. + # https://fishshell.com/docs/current/relnotes.html#id24 (fish 3.4 "Notable improvements and fixes") + set ATUIN_H (ATUIN_SHELL=fish ATUIN_LOG=error ATUIN_QUERY=(commandline -b) atuin search --keymap-mode=$keymap_mode $argv -i 3>&1 1>&2 2>&3 3>&- | string collect) + set ATUIN_STATUS $pipestatus[1] + end + + if test "$ATUIN_STATUS" -ne 0 + test -n "$ATUIN_H"; and printf '%s\n' "$ATUIN_H" >&2 + commandline -f repaint + return "$ATUIN_STATUS" + end + + set ATUIN_H (string trim -- $ATUIN_H | string collect) # trim whitespace + + if test -n "$ATUIN_H" + if string match --quiet '__atuin_accept__:*' "$ATUIN_H" + set -l ATUIN_HIST (string replace "__atuin_accept__:" "" -- "$ATUIN_H" | string collect) + commandline -r "$ATUIN_HIST" + commandline -f repaint + commandline -f execute + return + else + commandline -r "$ATUIN_H" + end + end + + commandline -f repaint +end + +function _atuin_bind_up + # Fallback to fish's builtin up-or-search if we're in search or paging mode + if commandline --search-mode; or commandline --paging-mode + up-or-search + return + end + + # Only invoke atuin if we're on the top line of the command + set -l lineno (commandline --line) + + switch $lineno + case 1 + _atuin_search --shell-up-key-binding + case '*' + up-or-search + end +end diff --git a/crates/turtle/src/shell/atuin.nu b/crates/turtle/src/shell/atuin.nu new file mode 100644 index 00000000..d37457e4 --- /dev/null +++ b/crates/turtle/src/shell/atuin.nu @@ -0,0 +1,121 @@ +# Source this in your ~/.config/nushell/config.nu +# minimum supported version = 0.93.0 +module compat { + export def --wrapped "random uuid -v 7" [...rest] { atuin uuid } +} +use (if not ( + (version).major > 0 or + (version).minor >= 103 +) { "compat" }) * + +if 'ATUIN_SESSION' not-in $env or ('ATUIN_SHLVL' not-in $env) or ($env.ATUIN_SHLVL != ($env.SHLVL? | default "")) { + $env.ATUIN_SESSION = (random uuid -v 7 | str replace -a "-" "") + $env.ATUIN_SHLVL = ($env.SHLVL? | default "") +} +hide-env -i ATUIN_HISTORY_ID + +def _atuin_osc133_command_executed [] { + if 'ATUIN_PTY_PROXY_ACTIVE' not-in $env { + return + } + if 'ATUIN_HISTORY_ID' not-in $env or ($env.ATUIN_HISTORY_ID | is-empty) { + return + } + + print -n $"(char esc)]133;C(char bel)" +} + +def _atuin_osc133_command_finished [exit_code: int] { + if 'ATUIN_PTY_PROXY_ACTIVE' not-in $env { + return + } + if 'ATUIN_HISTORY_ID' not-in $env or ($env.ATUIN_HISTORY_ID | is-empty) { + return + } + + print -n $"(char esc)]133;D;($exit_code);history_id=($env.ATUIN_HISTORY_ID);session_id=($env.ATUIN_SESSION)(char bel)" +} + +# Magic token to make sure we don't record commands run by keybindings +let ATUIN_KEYBINDING_TOKEN = $"# (random uuid)" + +let _atuin_pre_execution = {|| + if ($nu | get history-enabled?) == false { + return + } + let cmd = (commandline) + if ($cmd | is-empty) { + return + } + if not ($cmd | str starts-with $ATUIN_KEYBINDING_TOKEN) { + $env.ATUIN_HISTORY_ID = (atuin history start -- $cmd | complete | get stdout | str trim) + _atuin_osc133_command_executed + } +} + +let _atuin_pre_prompt = {|| + let last_exit = $env.LAST_EXIT_CODE + if 'ATUIN_HISTORY_ID' not-in $env { + return + } + _atuin_osc133_command_finished $last_exit + with-env { ATUIN_LOG: error } { + if (version).minor >= 104 or (version).major > 0 { + job spawn { + ^atuin history end $'--exit=($env.LAST_EXIT_CODE)' -- $env.ATUIN_HISTORY_ID | complete + } | ignore + } else { + do { atuin history end $'--exit=($last_exit)' -- $env.ATUIN_HISTORY_ID } | complete + } + + } + hide-env ATUIN_HISTORY_ID +} + +def _atuin_search_cmd [...flags: string] { + if (version).minor >= 106 or (version).major > 0 { + [ + $ATUIN_KEYBINDING_TOKEN, + ([ + `with-env { ATUIN_LOG: error, ATUIN_QUERY: (commandline), ATUIN_SHELL: nu } {`, + ([ + 'let output = (run-external atuin search', + ($flags | append [--interactive] | each {|e| $'"($e)"'}), + 'e>| str trim)', + ] | flatten | str join ' '), + 'if ($output | str starts-with "__atuin_accept__:") {', + 'commandline edit --accept ($output | str replace "__atuin_accept__:" "")', + '} else {', + 'commandline edit $output', + '}', + `}`, + ] | flatten | str join "\n"), + ] + } else { + [ + $ATUIN_KEYBINDING_TOKEN, + ([ + `with-env { ATUIN_LOG: error, ATUIN_QUERY: (commandline) } {`, + 'commandline edit', + '(run-external atuin search', + ($flags | append [--interactive] | each {|e| $'"($e)"'}), + ' e>| str trim)', + `}`, + ] | flatten | str join ' '), + ] + } | str join "\n" +} + +$env.config = ($env | default {} config).config +$env.config = ($env.config | default {} hooks) +$env.config = ( + $env.config | upsert hooks ( + $env.config.hooks + | upsert pre_execution ( + $env.config.hooks | get pre_execution? | default [] | append $_atuin_pre_execution) + | upsert pre_prompt ( + $env.config.hooks | get pre_prompt? | default [] | append $_atuin_pre_prompt) + ) +) + +$env.config = ($env.config | default [] keybindings) diff --git a/crates/turtle/src/shell/atuin.ps1 b/crates/turtle/src/shell/atuin.ps1 new file mode 100644 index 00000000..431ee2c3 --- /dev/null +++ b/crates/turtle/src/shell/atuin.ps1 @@ -0,0 +1,240 @@ +# Atuin PowerShell module +# +# This should support PowerShell 5.1 (which is shipped with Windows) and later versions, on Windows and Linux. +# +# Usage: atuin init powershell | Out-String | Invoke-Expression +# +# Settings: +# - $env:ATUIN_POWERSHELL_PROMPT_OFFSET - Number of lines to offset the prompt position after exiting search. +# This is useful when using a multi-line prompt: e.g. set this to -1 when using a 2-line prompt. +# It is initialized from the current prompt line count if not set when the first Atuin search is performed. + +if (Get-Module Atuin -ErrorAction Ignore) { + if ($PSVersionTable.PSVersion.Major -ge 7) { + Write-Warning "The Atuin module is already loaded, replacing it." + Remove-Module Atuin + } else { + Write-Warning "The Atuin module is already loaded, skipping." + return + } +} + +if (!(Get-Command atuin -ErrorAction Ignore)) { + Write-Error "The 'atuin' executable needs to be available in the PATH." + return +} + +if (!(Get-Module PSReadLine -ErrorAction Ignore)) { + Write-Error "Atuin requires the PSReadLine module to be installed." + return +} + +New-Module -Name Atuin -ScriptBlock { + if (-not $env:ATUIN_SESSION -or $env:ATUIN_PID -ne $PID) { + $env:ATUIN_SESSION = atuin uuid + $env:ATUIN_PID = $PID + } + + $script:atuinHistoryId = $null + $script:previousPSConsoleHostReadLine = $Function:PSConsoleHostReadLine + + # The ReadLine overloads changed with breaking changes over time, make sure the one we expect is available. + $script:hasExpectedReadLineOverload = ([Microsoft.PowerShell.PSConsoleReadLine]::ReadLine).OverloadDefinitions.Contains("static string ReadLine(runspace runspace, System.Management.Automation.EngineIntrinsics engineIntrinsics, System.Threading.CancellationToken cancellationToken, System.Nullable[bool] lastRunStatus)") + + function Get-CommandLine { + $commandLine = "" + [Microsoft.PowerShell.PSConsoleReadLine]::GetBufferState([ref]$commandLine, [ref]$null) + return $commandLine + } + + function Set-CommandLine { + param([string]$Text) + + $commandLine = Get-CommandLine + [Microsoft.PowerShell.PSConsoleReadLine]::Replace(0, $commandLine.Length, $Text) + } + + # This function name is called by PSReadLine to read the next command line to execute. + # We replace it with a custom implementation which adds Atuin support. + function PSConsoleHostReadLine { + ## 1. Collect the exit code of the previous command. + + # This needs to be done as the first thing because any script run will flush $?. + $lastRunStatus = $? + + # Exit statuses are maintained separately for native and PowerShell commands, this needs to be taken into account. + $lastNativeExitCode = $global:LASTEXITCODE + $exitCode = if ($lastRunStatus) { 0 } elseif ($lastNativeExitCode) { $lastNativeExitCode } else { 1 } + + ## 2. Report the status of the previous command to Atuin (atuin history end). + + if ($script:atuinHistoryId) { + try { + # The duration is not recorded in old PowerShell versions, let Atuin handle it. $null arguments are ignored. + $duration = (Get-History -Count 1).Duration.Ticks * 100 + $durationArg = if ($duration) { "--duration=$duration" } else { $null } + + # Fire and forget the atuin history end command to avoid blocking the shell during a potential sync. + $process = New-Object System.Diagnostics.Process + $process.StartInfo.FileName = "atuin" + $process.StartInfo.Arguments = "history end --exit=$exitCode $durationArg -- $script:atuinHistoryId" + $process.StartInfo.UseShellExecute = $false + $process.StartInfo.CreateNoWindow = $true + $process.StartInfo.RedirectStandardInput = $true + $process.StartInfo.RedirectStandardOutput = $true + $process.StartInfo.RedirectStandardError = $true + $process.Start() | Out-Null + $process.StandardInput.Close() + $process.BeginOutputReadLine() + $process.BeginErrorReadLine() + } + catch { + # Ignore errors to avoid breaking the shell. + # An error would occur if the user removes atuin from the PATH, for instance. + } + finally { + $script:atuinHistoryId = $null + } + } + + ## 3. Read the next command line to execute. + + # PSConsoleHostReadLine implementation from PSReadLine, adjusted to support old versions. + Microsoft.PowerShell.Core\Set-StrictMode -Off + + $line = if ($script:hasExpectedReadLineOverload) { + # When the overload we expect is available, we can pass $lastRunStatus to it. + [Microsoft.PowerShell.PSConsoleReadLine]::ReadLine($Host.Runspace, $ExecutionContext, [System.Threading.CancellationToken]::None, $lastRunStatus) + } else { + # Either PSReadLine is older than v2.2.0-beta3, or maybe newer than we expect, so use the function from PSReadLine as-is. + & $script:previousPSConsoleHostReadLine + } + + ## 4. Report the next command line to Atuin (atuin history start). + + # PowerShell doesn't handle double quotes in native command line arguments the same way depending on its version, + # and the value of $PSNativeCommandArgumentPassing - see the about_Parsing help page which explains the breaking changes. + # This makes it unreliable, so we go through an environment variable, which should always be consistent across versions. + try { + $env:ATUIN_COMMAND_LINE = $line + $script:atuinHistoryId = atuin history start --command-from-env + } + catch { + # Ignore errors to avoid breaking the shell, see above. + } + finally { + $env:ATUIN_COMMAND_LINE = $null + } + + $global:LASTEXITCODE = $lastNativeExitCode + return $line + } + + function Invoke-AtuinSearch { + param([string]$ExtraArgs = "") + + $previousOutputEncoding = [System.Console]::OutputEncoding + $resultFile = New-TemporaryFile + $suggestion = "" + $errorOutput = "" + + try { + [System.Console]::OutputEncoding = [System.Text.Encoding]::UTF8 + + # Start-Process does some crazy stuff, just use the Process class directly to have more control. + $process = New-Object System.Diagnostics.Process + $process.StartInfo.FileName = "atuin" + $process.StartInfo.Arguments = "search -i --result-file ""$($resultFile.FullName)"" $ExtraArgs" + $process.StartInfo.UseShellExecute = $false + $process.StartInfo.RedirectStandardError = $true + $process.StartInfo.StandardErrorEncoding = [System.Text.Encoding]::UTF8 + $process.StartInfo.EnvironmentVariables["ATUIN_SHELL"] = "powershell" + $process.StartInfo.EnvironmentVariables["ATUIN_QUERY"] = Get-CommandLine + # PowerShell's Set-Location (cd) doesn't update the process-level working directory, set it explicitly + $process.StartInfo.WorkingDirectory = (Get-Location -PSProvider FileSystem).ProviderPath + + try { + $process.Start() | Out-Null + + # A single stream is redirected, so we can read it synchronously, but we have to start reading it + # before waiting for the process to exit, otherwise the buffer could fill up and cause a deadlock. + $errorOutput = $process.StandardError.ReadToEnd().Trim() + $process.WaitForExit() + + $suggestion = (Get-Content -LiteralPath $resultFile.FullName -Raw -Encoding UTF8 | Out-String).Trim() + } + catch { + $errorOutput = $_ + } + + if ($errorOutput) { + Write-Host -ForegroundColor Red "Atuin error:" + Write-Host -ForegroundColor DarkRed $errorOutput + } + + # If no shell prompt offset is set, initialize it from the current prompt line count. + if ($null -eq $env:ATUIN_POWERSHELL_PROMPT_OFFSET) { + try { + $promptLines = (& $Function:prompt | Out-String | Measure-Object -Line).Lines + $env:ATUIN_POWERSHELL_PROMPT_OFFSET = -1 * ($promptLines - 1) + } + catch { + $env:ATUIN_POWERSHELL_PROMPT_OFFSET = 0 + } + } + + # PSReadLine maintains its own cursor position, which will no longer be valid if Atuin scrolls the display in inline mode. + # Fortunately, InvokePrompt can receive a new Y position and reset the internal state. + $y = $Host.UI.RawUI.CursorPosition.Y + [int]$env:ATUIN_POWERSHELL_PROMPT_OFFSET + $y = [System.Math]::Max([System.Math]::Min($y, [System.Console]::BufferHeight - 1), 0) + [Microsoft.PowerShell.PSConsoleReadLine]::InvokePrompt($null, $y) + + if ($suggestion -eq "") { + # The previous input was already rendered by InvokePrompt + return + } + + $acceptPrefix = "__atuin_accept__:" + + if ( $suggestion.StartsWith($acceptPrefix)) { + Set-CommandLine $suggestion.Substring($acceptPrefix.Length) + [Microsoft.PowerShell.PSConsoleReadLine]::AcceptLine() + } else { + Set-CommandLine $suggestion + } + } + finally { + [System.Console]::OutputEncoding = $previousOutputEncoding + $resultFile.Delete() + } + } + + function Enable-AtuinSearchKeys { + param([bool]$CtrlR = $true, [bool]$UpArrow = $true) + + if ($CtrlR) { + Set-PSReadLineKeyHandler -Chord "Ctrl+r" -BriefDescription "Runs Atuin search" -ScriptBlock { + Invoke-AtuinSearch + } + } + + if ($UpArrow) { + Set-PSReadLineKeyHandler -Chord "UpArrow" -BriefDescription "Runs Atuin search" -ScriptBlock { + $line = Get-CommandLine + + if (!$line.Contains("`n")) { + Invoke-AtuinSearch -ExtraArgs "--shell-up-key-binding" + } else { + [Microsoft.PowerShell.PSConsoleReadLine]::PreviousLine() + } + } + } + } + + $ExecutionContext.SessionState.Module.OnRemove += { + $env:ATUIN_SESSION = $null + $Function:PSConsoleHostReadLine = $script:previousPSConsoleHostReadLine + } + + Export-ModuleMember -Function @("Enable-AtuinSearchKeys", "PSConsoleHostReadLine") +} | Import-Module -Global diff --git a/crates/turtle/src/shell/atuin.xsh b/crates/turtle/src/shell/atuin.xsh new file mode 100644 index 00000000..a0283402 --- /dev/null +++ b/crates/turtle/src/shell/atuin.xsh @@ -0,0 +1,86 @@ +import os +import subprocess + +from prompt_toolkit.application.current import get_app +from prompt_toolkit.filters import Condition +from prompt_toolkit.keys import Keys + + +if "ATUIN_SESSION" not in ${...} or ${...}.get("ATUIN_SHLVL", "") != ${...}.get("SHLVL", ""): + $ATUIN_SESSION=$(atuin uuid).rstrip('\n') + $ATUIN_SHLVL = ${...}.get("SHLVL", "") + +@events.on_precommand +def _atuin_precommand(cmd: str): + cmd = cmd.rstrip("\n") + try: + $ATUIN_HISTORY_ID = $(atuin history start -- @(cmd) 2>@(os.devnull)).rstrip("\n") + except: + $ATUIN_HISTORY_ID = "" + + +@events.on_postcommand +def _atuin_postcommand(cmd: str, rtn: int, out, ts): + if "ATUIN_HISTORY_ID" not in ${...}: + return + + duration = ts[1] - ts[0] + # Duration is float representing seconds, but atuin expects integer of nanoseconds + nanos = round(duration * 10 ** 9) + with ${...}.swap(ATUIN_LOG="error"): + # This causes the entire .xonshrc to be re-executed, which is incredibly slow + # This happens when using a subshell and using output redirection at the same time + # For more details, see https://github.com/xonsh/xonsh/issues/5224 + # (atuin history end --exit @(rtn) -- $ATUIN_HISTORY_ID &) > /dev/null 2>&1 + atuin history end --exit @(rtn) --duration @(nanos) -- $ATUIN_HISTORY_ID > @(os.devnull) 2>&1 + del $ATUIN_HISTORY_ID + + +def _search(event, extra_args: list[str]): + buffer = event.current_buffer + cmd = ["atuin", "search", "--interactive", *extra_args] + # We need to explicitly pass in xonsh env, in case user has set XDG_HOME or something else that matters + env = ${...}.detype() + env["ATUIN_SHELL"] = "xonsh" + env["ATUIN_QUERY"] = buffer.text + + p = subprocess.run(cmd, stderr=subprocess.PIPE, encoding="utf-8", env=env) + result = p.stderr.rstrip("\n") + # redraw prompt - necessary if atuin is configured to run inline, rather than fullscreen + event.cli.renderer.erase() + + if not result: + return + + buffer.reset() + if result.startswith("__atuin_accept__:"): + buffer.insert_text(result[17:]) + buffer.validate_and_handle() + else: + buffer.insert_text(result) + + +@events.on_ptk_create +def _custom_keybindings(bindings, **kw): + if _ATUIN_BIND_CTRL_R: + @bindings.add(Keys.ControlR) + def r_search(event): + _search(event, extra_args=[]) + + if _ATUIN_BIND_UP_ARROW: + @Condition + def should_search(): + buffer = get_app().current_buffer + # disable keybind when there is an active completion, so + # that up arrow can be used to navigate completion menu + if buffer.complete_state is not None: + return False + # similarly, disable when buffer text contains multiple lines + if '\n' in buffer.text: + return False + + return True + + @bindings.add(Keys.Up, filter=should_search) + def up_search(event): + _search(event, extra_args=["--shell-up-key-binding"]) diff --git a/crates/turtle/src/shell/atuin.zsh b/crates/turtle/src/shell/atuin.zsh new file mode 100644 index 00000000..7a7375aa --- /dev/null +++ b/crates/turtle/src/shell/atuin.zsh @@ -0,0 +1,221 @@ +# shellcheck disable=SC2034,SC2153,SC2086,SC2155 + +# Above line is because shellcheck doesn't support zsh, per +# https://github.com/koalaman/shellcheck/wiki/SC1071, and the ignore: param in +# ludeeus/action-shellcheck only supports _directories_, not _files_. So +# instead, we manually add any error the shellcheck step finds in the file to +# the above line ... + +# Source this in your ~/.zshrc +autoload -U add-zsh-hook + +zmodload zsh/datetime 2>/dev/null + +# If zsh-autosuggestions is installed, configure it to use Atuin's search. If +# you'd like to override this, then add your config after the $(atuin init zsh) +# in your .zshrc +_zsh_autosuggest_strategy_atuin() { + # silence errors, since we don't want to spam the terminal prompt while typing. + suggestion=$(ATUIN_QUERY="$1" atuin search --cmd-only --limit 1 --search-mode prefix 2>/dev/null) +} + +if [ -n "${ZSH_AUTOSUGGEST_STRATEGY:-}" ]; then + ZSH_AUTOSUGGEST_STRATEGY=("atuin" "${ZSH_AUTOSUGGEST_STRATEGY[@]}") +else + ZSH_AUTOSUGGEST_STRATEGY=("atuin") +fi + +if [[ -z "${ATUIN_SESSION:-}" || "${ATUIN_SHLVL:-}" != "$SHLVL" ]]; then + export ATUIN_SESSION=$(atuin uuid) + export ATUIN_SHLVL=$SHLVL +fi +ATUIN_HISTORY_ID="" + +__atuin_osc133_command_executed() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" ]] || return + + printf '\033]133;C\a' +} + +__atuin_osc133_command_finished() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" ]] || return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$1" "$ATUIN_HISTORY_ID" "${ATUIN_SESSION:-}" +} + +__atuin_osc133_prompt_start=$'%{\033]133;A;cl=line\a%}' +__atuin_osc133_prompt_end=$'%{\033]133;B\a%}' + +__atuin_osc133_wrap_prompt() { + local __atuin_prompt="${PROMPT-}" + local __atuin_rprompt="${RPROMPT-}" + + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_start/}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_end/}" + __atuin_rprompt="${__atuin_rprompt//$__atuin_osc133_prompt_start/}" + __atuin_rprompt="${__atuin_rprompt//$__atuin_osc133_prompt_end/}" + + if [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]]; then + PROMPT="${__atuin_osc133_prompt_start}${__atuin_prompt}" + RPROMPT="${__atuin_rprompt}${__atuin_osc133_prompt_end}" + else + PROMPT="$__atuin_prompt" + RPROMPT="$__atuin_rprompt" + fi +} + +_atuin_preexec() { + local id + id=$(atuin history start -- "$1" 2>/dev/null) + export ATUIN_HISTORY_ID="$id" + __atuin_osc133_command_executed + __atuin_preexec_time=${EPOCHREALTIME-} +} + +_atuin_precmd() { + local EXIT="$?" __atuin_precmd_time=${EPOCHREALTIME-} + + __atuin_osc133_wrap_prompt + + [[ -z "${ATUIN_HISTORY_ID:-}" ]] && return + + local duration="" + if [[ -n $__atuin_preexec_time && -n $__atuin_precmd_time ]]; then + printf -v duration %.0f $(((__atuin_precmd_time - __atuin_preexec_time) * 1000000000)) + fi + + __atuin_osc133_command_finished "$EXIT" + (ATUIN_LOG=error atuin history end --exit $EXIT ${duration:+--duration=$duration} -- $ATUIN_HISTORY_ID &) >/dev/null 2>&1 + export ATUIN_HISTORY_ID="" +} + +# Check if tmux popup is available (tmux >= 3.2) +__atuin_tmux_popup_check() { + [[ -n "${TMUX-}" ]] || return 1 + [[ "${ATUIN_TMUX_POPUP:-true}" != "false" ]] || return 1 + + # https://github.com/tmux/tmux/wiki/FAQ#how-often-is-tmux-released-what-is-the-version-number-scheme + local tmux_version + tmux_version=$(tmux -V 2>/dev/null | sed -n 's/^[^0-9]*\([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p') # Could have used grep... + [[ -z "$tmux_version" ]] && return 1 + + local m1 m2 + m1=${tmux_version%%.*} + m2=${tmux_version#*.} + m2=${m2%%.*} + [[ "$m1" =~ ^[0-9]+$ ]] || return 1 + [[ "$m2" =~ ^[0-9]+$ ]] || m2=0 + (( m1 > 3 || (m1 == 3 && m2 >= 2) )) +} + +# Use global variable to fix scope issues with traps +__atuin_popup_tmpdir="" +__atuin_tmux_popup_cleanup() { + [[ -n "$__atuin_popup_tmpdir" && -d "$__atuin_popup_tmpdir" ]] && command rm -rf "$__atuin_popup_tmpdir" + __atuin_popup_tmpdir="" +} + +__atuin_search_cmd() { + local -a search_args=("$@") + + if __atuin_tmux_popup_check; then + __atuin_popup_tmpdir=$(mktemp -d) || return 1 + local result_file="$__atuin_popup_tmpdir/result" + + trap '__atuin_tmux_popup_cleanup' EXIT HUP INT TERM + + local escaped_query escaped_args + escaped_query=$(printf '%s' "$BUFFER" | sed "s/'/'\\\\''/g") + escaped_args="" + for arg in "${search_args[@]}"; do + escaped_args+=" '$(printf '%s' "$arg" | sed "s/'/'\\\\''/g")'" + done + + # In the popup, atuin goes to terminal, stderr goes to file + local cdir popup_width popup_height + cdir=$(pwd) + popup_width="${ATUIN_TMUX_POPUP_WIDTH:-80%}" # Keep default value anyways + popup_height="${ATUIN_TMUX_POPUP_HEIGHT:-60%}" + tmux display-popup -d "$cdir" -w "$popup_width" -h "$popup_height" -E -E -- \ + sh -c "PATH='$PATH' ATUIN_SESSION='$ATUIN_SESSION' ATUIN_SHELL=zsh ATUIN_LOG=error ATUIN_QUERY='$escaped_query' atuin search $escaped_args -i 2>'$result_file'" + + if [[ -f "$result_file" ]]; then + cat "$result_file" + fi + + __atuin_tmux_popup_cleanup + trap - EXIT HUP INT TERM + else + ATUIN_SHELL=zsh ATUIN_LOG=error ATUIN_QUERY=$BUFFER atuin search "${search_args[@]}" -i 3>&1 1>&2 2>&3 3>&- + fi +} + +_atuin_search() { + emulate -L zsh + zle -I + + # swap stderr and stdout, so that the tui stuff works + # TODO: not this + local output __atuin_status + # shellcheck disable=SC2048 + output=$(__atuin_search_cmd $*) + __atuin_status=$? + + zle reset-prompt + # re-enable bracketed paste + # shellcheck disable=SC2154 + echo -n ${zle_bracketed_paste[1]} >/dev/tty + + if (( __atuin_status != 0 )); then + [[ -n $output ]] && print -r -- "$output" >/dev/tty + return $__atuin_status + fi + + if [[ -n $output ]]; then + RBUFFER="" + LBUFFER=$output + + if [[ $LBUFFER == __atuin_accept__:* ]] + then + LBUFFER=${LBUFFER#__atuin_accept__:} + zle accept-line + fi + fi +} +_atuin_search_vicmd() { + _atuin_search --keymap-mode=vim-normal +} +_atuin_search_viins() { + _atuin_search --keymap-mode=vim-insert +} + +_atuin_up_search() { + # Only trigger if the buffer is a single line + if [[ ! $BUFFER == *$'\n'* ]]; then + _atuin_search --shell-up-key-binding "$@" + else + zle up-line + fi +} +_atuin_up_search_vicmd() { + _atuin_up_search --keymap-mode=vim-normal +} +_atuin_up_search_viins() { + _atuin_up_search --keymap-mode=vim-insert +} + +add-zsh-hook preexec _atuin_preexec +add-zsh-hook precmd _atuin_precmd + +zle -N atuin-search _atuin_search +zle -N atuin-search-vicmd _atuin_search_vicmd +zle -N atuin-search-viins _atuin_search_viins +zle -N atuin-up-search _atuin_up_search +zle -N atuin-up-search-vicmd _atuin_up_search_vicmd +zle -N atuin-up-search-viins _atuin_up_search_viins + +# These are compatibility widget names for "atuin <= 17.2.1" users. +zle -N _atuin_search_widget _atuin_search +zle -N _atuin_up_search_widget _atuin_up_search diff --git a/crates/turtle/src/sync.rs b/crates/turtle/src/sync.rs new file mode 100644 index 00000000..56aef615 --- /dev/null +++ b/crates/turtle/src/sync.rs @@ -0,0 +1,34 @@ +use eyre::{Context, Result}; + +use crate::atuin_client::{ + database::Database, history::store::HistoryStore, record::sqlite_store::SqliteStore, + settings::Settings, +}; +use crate::atuin_common::record::RecordId; + +// This is the only crate that ties together all other crates. +// Therefore, it's the only crate where functions tying together all stores can live + +/// Rebuild all stores after a sync +/// Note: for history, this only does an _incremental_ sync. Hence the need to specify downloaded +/// records. +pub async fn build( + settings: &Settings, + store: &SqliteStore, + db: &dyn Database, + downloaded: Option<&[RecordId]>, +) -> Result<()> { + let encryption_key: [u8; 32] = crate::atuin_client::encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + + let host_id = Settings::host_id().await?; + + let downloaded = downloaded.unwrap_or(&[]); + + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + history_store.incremental_build(db, downloaded).await?; + + Ok(()) +} |
