diff options
| author | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-11 00:54:30 +0200 |
|---|---|---|
| committer | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-11 00:54:30 +0200 |
| commit | 5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8 (patch) | |
| tree | c64baa8d5866c8e339eaf660dd3f94f30a3f7d8a /crates/turtle/src/command | |
| parent | chore: Somewhat simplify sync code (diff) | |
| download | atuin-5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8.zip | |
chore: Move everything into one big crate
That helps remove duplicated code and rustc/cargo will now also show
dead code correctly.
Diffstat (limited to 'crates/turtle/src/command')
54 files changed, 14467 insertions, 0 deletions
diff --git a/crates/turtle/src/command/CONTRIBUTORS b/crates/turtle/src/command/CONTRIBUTORS new file mode 120000 index 00000000..1ca4115a --- /dev/null +++ b/crates/turtle/src/command/CONTRIBUTORS @@ -0,0 +1 @@ +../../../../CONTRIBUTORS
\ No newline at end of file diff --git a/crates/turtle/src/command/client.rs b/crates/turtle/src/command/client.rs new file mode 100644 index 00000000..20d85303 --- /dev/null +++ b/crates/turtle/src/command/client.rs @@ -0,0 +1,371 @@ +use std::fs::{self, OpenOptions}; +use std::path::{Path, PathBuf}; + +use clap::Subcommand; +use eyre::{Result, WrapErr}; + +use crate::atuin_client::{ + database::Sqlite, record::sqlite_store::SqliteStore, settings::Settings, theme, +}; +use tracing_appender::rolling::{RollingFileAppender, Rotation}; +use tracing_subscriber::{ + Layer, filter::EnvFilter, filter::LevelFilter, fmt, fmt::format::FmtSpan, prelude::*, +}; + +fn cleanup_old_logs(log_dir: &Path, prefix: &str, retention_days: u64) { + let cutoff = std::time::SystemTime::now() + - std::time::Duration::from_secs(retention_days * 24 * 60 * 60); + + let Ok(entries) = fs::read_dir(log_dir) else { + return; + }; + + for entry in entries.flatten() { + let path = entry.path(); + let Some(name) = path.file_name().and_then(|n| n.to_str()) else { + continue; + }; + + // Match files like "search.log.2024-02-23" or "daemon.log.2024-02-23" + if !name.starts_with(prefix) || name == prefix { + continue; + } + + if let Ok(metadata) = entry.metadata() + && let Ok(modified) = metadata.modified() + && modified < cutoff + { + let _ = fs::remove_file(&path); + } + } +} + +#[cfg(feature = "sync")] +mod sync; + +#[cfg(feature = "sync")] +mod account; + +#[cfg(feature = "daemon")] +mod daemon; + +mod config; +mod default_config; +mod doctor; +mod history; +mod import; +mod info; +mod init; +mod search; +mod server; +mod setup; +mod stats; +mod store; +mod wrapped; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Setup Atuin features + #[command()] + Setup, + + /// Manipulate shell history + #[command(subcommand)] + History(history::Cmd), + + /// Import shell history from file + #[command(subcommand)] + Import(import::Cmd), + + /// Calculate statistics for your history + Stats(stats::Cmd), + + /// Interactive history search + Search(search::Cmd), + + #[cfg(feature = "sync")] + #[command(flatten)] + Sync(sync::Cmd), + + /// Manage the atuin server + #[command(subcommand)] + Server(server::Cmd), + + /// Manage your sync account + #[cfg(feature = "sync")] + Account(account::Cmd), + + /// Manage the atuin data store + #[command(subcommand)] + Store(store::Cmd), + + /// Print Atuin's shell init script + #[command()] + Init(init::Cmd), + + /// Information about dotfiles locations and ENV vars + #[command()] + Info, + + /// Run the doctor to check for common issues + #[command()] + Doctor, + + #[command()] + Wrapped { year: Option<i32> }, + + /// *Experimental* Manage the background daemon + #[cfg(feature = "daemon")] + #[command()] + Daemon(daemon::Cmd), + + /// Print the default atuin configuration (config.toml) + #[command()] + DefaultConfig, + + #[command(subcommand)] + Config(config::Cmd), +} + +impl Cmd { + pub fn run(self) -> Result<()> { + // Daemonize before creating the async runtime – fork() inside a live + // tokio runtime corrupts its internal state. + #[cfg(all(unix, feature = "daemon"))] + if let Self::Daemon(ref cmd) = self + && cmd.should_daemonize() + { + daemon::daemonize_current_process()?; + } + + let mut runtime = tokio::runtime::Builder::new_current_thread(); + + let runtime = runtime.enable_all().build().unwrap(); + + // For non-history commands, we want to initialize logging and the theme manager before + // doing anything else. History commands are performance-sensitive and run before and after + // every shell command, so we want to skip any unnecessary initialization for them. + let settings = Settings::new().wrap_err("could not load client settings")?; + let theme_manager = theme::ThemeManager::new(settings.theme.debug, None); + let res = runtime.block_on(self.run_inner(settings, theme_manager)); + + runtime.shutdown_timeout(std::time::Duration::from_millis(50)); + + res + } + + #[expect(clippy::too_many_lines, clippy::future_not_send)] + async fn run_inner( + self, + mut settings: Settings, + mut theme_manager: theme::ThemeManager, + ) -> Result<()> { + // ATUIN_LOG env var overrides config file level settings + let env_log_set = std::env::var("ATUIN_LOG").is_ok(); + + // Base filter from env var (or empty if not set) + let base_filter = + EnvFilter::from_env("ATUIN_LOG").add_directive("sqlx_sqlite::regexp=off".parse()?); + + let is_interactive_search = matches!(&self, Self::Search(cmd) if cmd.is_interactive()); + // Use file-based logging for interactive search (TUI mode) + let use_search_logging = is_interactive_search && settings.logs.search_enabled(); + + // Use file-based logging for daemon + #[cfg(feature = "daemon")] + let use_daemon_logging = matches!(&self, Self::Daemon(_)) && settings.logs.daemon_enabled(); + + #[cfg(not(feature = "daemon"))] + let use_daemon_logging = false; + + // Check if daemon should also log to console + #[cfg(feature = "daemon")] + let daemon_show_logs = matches!(&self, Self::Daemon(cmd) if cmd.show_logs()); + + #[cfg(not(feature = "daemon"))] + let daemon_show_logs = false; + + // Set up span timing JSON logs if ATUIN_SPAN is set + let span_path = std::env::var("ATUIN_SPAN").ok().map(|p| { + if p.is_empty() { + "atuin-spans.json".to_string() + } else { + p + } + }); + + // Helper to create span timing layer + macro_rules! make_span_layer { + ($path:expr) => {{ + let span_file = OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open($path)?; + Some( + fmt::layer() + .json() + .with_writer(span_file) + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) + .with_filter(LevelFilter::TRACE), + ) + }}; + } + + // Build the subscriber with all configured layers + if use_search_logging { + let search_filename = settings.logs.search.file.clone(); + let log_dir = PathBuf::from(&settings.logs.dir); + fs::create_dir_all(&log_dir)?; + + // Clean up old log files + cleanup_old_logs(&log_dir, &search_filename, settings.logs.search_retention()); + + let file_appender = + RollingFileAppender::new(Rotation::DAILY, &log_dir, &search_filename); + + // Use config level unless ATUIN_LOG is set + let filter = if env_log_set { + base_filter + } else { + EnvFilter::default() + .add_directive(settings.logs.search_level().as_directive().parse()?) + .add_directive("sqlx_sqlite::regexp=off".parse()?) + }; + + let base = tracing_subscriber::registry().with( + fmt::layer() + .with_writer(file_appender) + .with_ansi(false) + .with_filter(filter), + ); + + match &span_path { + Some(sp) => { + base.with(make_span_layer!(sp)).init(); + } + None => { + base.init(); + } + } + } else if use_daemon_logging { + let daemon_filename = settings.logs.daemon.file.clone(); + let log_dir = PathBuf::from(&settings.logs.dir); + fs::create_dir_all(&log_dir)?; + + // Clean up old log files + cleanup_old_logs(&log_dir, &daemon_filename, settings.logs.daemon_retention()); + + let file_appender = + RollingFileAppender::new(Rotation::DAILY, &log_dir, &daemon_filename); + + // Use config level unless ATUIN_LOG is set + let file_filter = if env_log_set { + base_filter + } else { + EnvFilter::default() + .add_directive(settings.logs.daemon_level().as_directive().parse()?) + .add_directive("sqlx_sqlite::regexp=off".parse()?) + }; + + let file_layer = fmt::layer() + .with_writer(file_appender) + .with_ansi(false) + .with_filter(file_filter); + + // Optionally add console layer for --show-logs + if daemon_show_logs { + let console_filter = EnvFilter::from_env("ATUIN_LOG") + .add_directive("sqlx_sqlite::regexp=off".parse()?); + + let console_layer = fmt::layer().with_filter(console_filter); + + let base = tracing_subscriber::registry() + .with(file_layer) + .with(console_layer); + + match &span_path { + Some(sp) => { + base.with(make_span_layer!(sp)).init(); + } + None => { + base.init(); + } + } + } else { + let base = tracing_subscriber::registry().with(file_layer); + + match &span_path { + Some(sp) => { + base.with(make_span_layer!(sp)).init(); + } + None => { + base.init(); + } + } + } + } + + tracing::trace!(command = ?self, "client command"); + + // Skip initializing any databases for history + // This is a pretty hot path, as it runs before and after every single command the user + // runs + match self { + Self::History(history) => return history.run(&settings).await, + Self::Init(init) => { + init.run(&settings); + return Ok(()); + } + Self::Doctor => return doctor::run(&settings).await, + Self::Config(config) => return config.run(&settings).await, + _ => {} + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let sqlite_store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let theme_name = settings.theme.name.clone(); + let theme = theme_manager.load_theme(theme_name.as_str(), settings.theme.max_depth); + + match self { + Self::Setup => setup::run(&settings).await, + Self::Import(import) => import.run(&db).await, + Self::Stats(stats) => stats.run(&db, &settings, theme).await, + Self::Search(search) => search.run(db, &mut settings, sqlite_store, theme).await, + + #[cfg(feature = "sync")] + Self::Sync(sync) => sync.run(settings, &db, sqlite_store).await, + + #[cfg(feature = "sync")] + Self::Account(account) => account.run(settings, sqlite_store).await, + + Self::Store(store) => store.run(&settings, &db, sqlite_store).await, + + Self::Server(server) => server.run().await, + + Self::Info => { + info::run(&settings); + Ok(()) + } + + Self::DefaultConfig => { + default_config::run(); + Ok(()) + } + + Self::Wrapped { year } => wrapped::run(year, &db, &settings, theme).await, + + #[cfg(feature = "daemon")] + Self::Daemon(cmd) => cmd.run(settings, sqlite_store, db).await, + + Self::History(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => { + unreachable!() + } + } + } +} diff --git a/crates/turtle/src/command/client/account.rs b/crates/turtle/src/command/client/account.rs new file mode 100644 index 00000000..898f1ac4 --- /dev/null +++ b/crates/turtle/src/command/client/account.rs @@ -0,0 +1,47 @@ +use clap::{Args, Subcommand}; +use eyre::Result; + +use crate::atuin_client::record::sqlite_store::SqliteStore; +use crate::atuin_client::settings::Settings; + +pub mod change_password; +pub mod delete; +pub mod login; +pub mod logout; +pub mod register; + +#[derive(Args, Debug)] +pub struct Cmd { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand, Debug)] +pub enum Commands { + /// Login to the configured server + Login(login::Cmd), + + /// Register a new account + Register(register::Cmd), + + /// Log out + Logout, + + /// Delete your account, and all synced data + Delete(delete::Cmd), + + /// Change your password + ChangePassword(change_password::Cmd), +} + +impl Cmd { + pub async fn run(self, settings: Settings, store: SqliteStore) -> Result<()> { + match self.command { + Commands::Login(l) => l.run(&settings, &store).await, + Commands::Register(r) => r.run(&settings).await, + Commands::Logout => logout::run().await, + Commands::Delete(d) => d.run(&settings).await, + Commands::ChangePassword(c) => c.run(&settings).await, + } + } +} diff --git a/crates/turtle/src/command/client/account/change_password.rs b/crates/turtle/src/command/client/account/change_password.rs new file mode 100644 index 00000000..6112b0df --- /dev/null +++ b/crates/turtle/src/command/client/account/change_password.rs @@ -0,0 +1,67 @@ +use clap::Parser; +use eyre::{Result, bail}; + +use crate::atuin_client::{ + auth::{self, MutateResponse}, + settings::Settings, +}; +use rpassword::prompt_password; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub current_password: Option<String>, + + #[clap(long, short)] + pub new_password: Option<String>, + + /// The two-factor authentication code for your account, if any + #[clap(long, short)] + pub totp_code: Option<String>, +} + +impl Cmd { + pub async fn run(&self, settings: &Settings) -> Result<()> { + if !settings.logged_in().await? { + bail!("You are not logged in"); + } + + let client = auth::auth_client(settings).await; + + let current_password = self.current_password.clone().unwrap_or_else(|| { + prompt_password("Please enter the current password: ") + .expect("Failed to read from input") + }); + + if current_password.is_empty() { + bail!("please provide the current password"); + } + + let new_password = self.new_password.clone().unwrap_or_else(|| { + prompt_password("Please enter the new password: ").expect("Failed to read from input") + }); + + if new_password.is_empty() { + bail!("please provide a new password"); + } + + let mut totp_code = self.totp_code.clone(); + + loop { + let response = client + .change_password(¤t_password, &new_password, totp_code.as_deref()) + .await?; + + match response { + MutateResponse::Success => break, + MutateResponse::TwoFactorRequired => { + totp_code = Some(super::login::or_user_input(None, "two-factor code")); + } + } + } + + println!("Account password successfully changed!"); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/account/delete.rs b/crates/turtle/src/command/client/account/delete.rs new file mode 100644 index 00000000..bcb40bc3 --- /dev/null +++ b/crates/turtle/src/command/client/account/delete.rs @@ -0,0 +1,57 @@ +use crate::atuin_client::{ + auth::{self, MutateResponse}, + settings::Settings, +}; +use clap::Parser; +use eyre::{Result, bail}; + +use super::login::{or_user_input, read_user_password}; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub password: Option<String>, + + /// The two-factor authentication code for your account, if any + #[clap(long, short)] + pub totp_code: Option<String>, +} + +impl Cmd { + pub async fn run(&self, settings: &Settings) -> Result<()> { + if !settings.logged_in().await? { + bail!("You are not logged in"); + } + + let client = auth::auth_client(settings).await; + + let password = self.password.clone().unwrap_or_else(read_user_password); + + if password.is_empty() { + bail!("please provide your password"); + } + + let mut totp_code = self.totp_code.clone(); + + loop { + let response = client + .delete_account(&password, totp_code.as_deref()) + .await?; + + match response { + MutateResponse::Success => break, + MutateResponse::TwoFactorRequired => { + totp_code = Some(or_user_input(None, "two-factor code")); + } + } + } + + // Clean up sessions from meta store + let meta = Settings::meta_store().await?; + meta.delete_session().await?; + + println!("Your account is deleted"); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/account/login.rs b/crates/turtle/src/command/client/account/login.rs new file mode 100644 index 00000000..0c5b66f5 --- /dev/null +++ b/crates/turtle/src/command/client/account/login.rs @@ -0,0 +1,206 @@ +use std::{io, path::PathBuf}; + +use clap::Parser; +use eyre::{Context, Result, bail}; +use tokio::{fs::File, io::AsyncWriteExt}; + +use crate::atuin_client::{ + auth::{self, AuthResponse}, + encryption::{decode_key, load_key}, + record::sqlite_store::SqliteStore, + record::store::Store, + record::sync::{self, SyncError}, + settings::{Settings, SyncAuth}, +}; +use rpassword::prompt_password; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub username: Option<String>, + + #[clap(long, short)] + pub password: Option<String>, + + /// The encryption key for your account + #[clap(long, short)] + pub key: Option<String>, + + /// The two-factor authentication code for your account, if any + #[clap(long, short)] + pub totp_code: Option<String>, + + #[clap(long, hide = true)] + pub from_registration: bool, +} + +fn get_input() -> Result<String> { + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + Ok(input.trim_end_matches(&['\r', '\n'][..]).to_string()) +} + +impl Cmd { + pub async fn run(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { + match settings.resolve_sync_auth().await { + SyncAuth::Legacy { .. } => { + println!("You are logged in to your sync server."); + println!("Run 'atuin logout' to log out."); + return Ok(()); + } + SyncAuth::NotLoggedIn { .. } => {} + } + + self.run_legacy_login(settings, store).await?; + + verify_key_against_remote(settings).await + } + + /// Legacy login: always prompt for username/password interactively + /// (or accept them via flags). + async fn run_legacy_login(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { + let username = or_user_input(self.username.clone(), "username"); + let password = self.password.clone().unwrap_or_else(read_user_password); + + self.prompt_and_store_key(settings, store).await?; + + let client = auth::auth_client(settings).await; + let response = client.login(&username, &password).await?; + + match response { + AuthResponse::Success { session, .. } => { + Settings::meta_store().await?.save_session(&session).await?; + } + AuthResponse::TwoFactorRequired => { + // Legacy server doesn't support 2FA, so this shouldn't happen. + bail!("unexpected two-factor requirement from legacy server"); + } + } + + println!("Logged in!"); + Ok(()) + } + + async fn prompt_and_store_key(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { + let key_path = settings.key_path.as_str(); + let key_path = PathBuf::from(key_path); + + println!("IMPORTANT"); + println!( + "If you are already logged in on another machine, you must ensure that the key you use here is the same as the key you used there." + ); + println!("You can find your key by running 'atuin key' on the other machine."); + println!("Do not share this key with anyone."); + println!("\nRead more here: https://docs.atuin.sh/guide/sync/#login \n"); + + let key = or_user_input( + self.key.clone(), + "encryption key [blank to use existing key file]", + ); + + if key.is_empty() { + if key_path.exists() { + let bytes = fs_err::read_to_string(&key_path).context(format!( + "Existing key file at '{}' could not be read", + key_path.to_string_lossy() + ))?; + if decode_key(bytes).is_err() { + bail!(format!( + "The key in existing key file at '{}' is invalid", + key_path.to_string_lossy() + )); + } + } else { + panic!( + "No key provided and no existing key file found. Please use 'atuin key' on your other machine, or recover your key from a backup" + ) + } + } else if !key_path.exists() { + if decode_key(key.clone()).is_err() { + bail!("The specified key is invalid"); + } + + let mut file = File::create(&key_path).await?; + file.write_all(key.as_bytes()).await?; + } else { + // we now know that the user has logged in specifying a key, AND that the key path + // exists + + // 1. check if the saved key and the provided key match. if so, nothing to do. + // 2. if not, re-encrypt the local history and overwrite the key + let current_key: [u8; 32] = load_key(settings)?.into(); + + let encoded = key.clone(); // gonna want to save it in a bit + let new_key: [u8; 32] = decode_key(key) + .context("Could not decode provided key; is not valid base64-encoded key")? + .into(); + + if new_key != current_key { + println!("\nRe-encrypting local store with new key"); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Writing new key"); + let mut file = File::create(&key_path).await?; + file.write_all(encoded.as_bytes()).await?; + } + } + + Ok(()) + } +} + +async fn verify_key_against_remote(settings: &Settings) -> Result<()> { + let key: [u8; 32] = load_key(settings) + .context("could not load encryption key for verification")? + .into(); + + let client = sync::build_client(settings).await?; + let remote_index = match client.record_status().await { + Ok(idx) => idx, + Err(e) => { + tracing::warn!("could not fetch remote status to verify key: {e}"); + return Ok(()); + } + }; + + match sync::check_encryption_key(&client, &remote_index, &key).await { + Ok(()) => Ok(()), + Err(SyncError::WrongKey) => { + // Roll back the saved session so the user is not left in a + // half-authenticated state with a key that can't read the data. + if let Ok(meta) = Settings::meta_store().await { + let _ = meta.delete_session().await; + } + crate::print_error::print_error( + "Wrong encryption key", + "The encryption key on this machine does not match the data on the server. \ + You have been logged out.\n\n\ + To fix this, find your existing key by running `atuin key` on a machine that \ + already syncs successfully, then run `atuin login` again here with that key.", + ); + std::process::exit(1); + } + Err(e) => { + // Non-key error (e.g. transient network issue). Don't fail the + // login — the user is authenticated and can sync later when the + // network recovers. + tracing::warn!("could not verify encryption key against remote: {e}"); + Ok(()) + } + } +} + +pub(super) fn or_user_input(value: Option<String>, name: &'static str) -> String { + value.unwrap_or_else(|| read_user_input(name)) +} + +pub(super) fn read_user_password() -> String { + let password = prompt_password("Please enter password: "); + password.expect("Failed to read from input") +} + +fn read_user_input(name: &'static str) -> String { + eprint!("Please enter {name}: "); + get_input().expect("Failed to read from input") +} diff --git a/crates/turtle/src/command/client/account/logout.rs b/crates/turtle/src/command/client/account/logout.rs new file mode 100644 index 00000000..6150a52b --- /dev/null +++ b/crates/turtle/src/command/client/account/logout.rs @@ -0,0 +1,5 @@ +use eyre::Result; + +pub async fn run() -> Result<()> { + crate::atuin_client::logout::logout().await +} diff --git a/crates/turtle/src/command/client/account/register.rs b/crates/turtle/src/command/client/account/register.rs new file mode 100644 index 00000000..548c2739 --- /dev/null +++ b/crates/turtle/src/command/client/account/register.rs @@ -0,0 +1,67 @@ +use clap::Parser; +use eyre::{Result, bail}; + +use super::login::or_user_input; +use crate::atuin_client::settings::{Settings, SyncAuth}; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub username: Option<String>, + + #[clap(long, short)] + pub password: Option<String>, + + #[clap(long, short)] + pub email: Option<String>, +} + +impl Cmd { + pub async fn run(&self, settings: &Settings) -> Result<()> { + match settings.resolve_sync_auth().await { + SyncAuth::Legacy { .. } => { + println!("You are already logged in."); + println!("Run 'atuin logout' to log out."); + return Ok(()); + } + + SyncAuth::NotLoggedIn { .. } => {} + } + + // Legacy registration flow + println!("Registering for an Atuin Sync account"); + + let username = or_user_input(self.username.clone(), "username"); + let email = or_user_input(self.email.clone(), "email"); + let password = self + .password + .clone() + .unwrap_or_else(super::login::read_user_password); + + if password.is_empty() { + bail!("please provide a password"); + } + + let session = crate::atuin_client::api_client::register( + settings.sync_address.as_str(), + &username, + &email, + &password, + ) + .await?; + + let meta = Settings::meta_store().await?; + meta.save_session(&session.session).await?; + + let _key = crate::atuin_client::encryption::load_key(settings)?; + + println!( + "Registration successful! Please make a note of your key (run 'atuin key') and keep it safe." + ); + println!( + "You will need it to log in on other devices, and we cannot help recover it if you lose it." + ); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/config.rs b/crates/turtle/src/command/client/config.rs new file mode 100644 index 00000000..1597a8d6 --- /dev/null +++ b/crates/turtle/src/command/client/config.rs @@ -0,0 +1,352 @@ +use crate::atuin_client::settings::Settings; +use clap::{Args, Subcommand, ValueEnum}; +use eyre::Result; +use toml_edit::{Document, DocumentMut, Item, Table, TableLike, Value}; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Get a configuration value from your config.toml file + /// or after defaults and overrides are applied + #[command()] + Get(GetCmd), + + /// Set a configuration value in your config.toml file + #[command()] + Set(SetCmd), + + /// Print all configuration values from your config.toml file + /// in TOML format + /// + /// If a key is provided, only print the value of that key and all its children + #[command()] + Print(PrintCmd), +} + +impl Cmd { + pub async fn run(self, settings: &Settings) -> Result<()> { + match self { + Self::Get(get) => get.run(settings).await, + Self::Set(set) => set.run(settings).await, + Self::Print(print) => print.run(settings).await, + } + } +} + +/// Get a configuration value from your config.toml file, +/// or optionally the effective value after defaults and overrides are applied. +#[derive(Args, Debug)] +pub struct GetCmd { + /// The configuration key to get + pub key: String, + + /// Print the value after defaults and overrides are applied + #[arg(long, short)] + pub resolved: bool, + + /// Print both the config file value and the resolved value + #[arg(long, short)] + pub verbose: bool, +} + +impl GetCmd { + pub async fn run(&self, _settings: &Settings) -> Result<()> { + let key = self.key.trim(); + if key.is_empty() || key.contains(char::is_whitespace) { + eyre::bail!("Config key must be non-empty and must not contain whitespace"); + } + + if self.verbose { + println!("Config file:"); + self.print_current_value(key, " ").await?; + println!("\nResolved:"); + Self::print_effective_value(key, " "); + return Ok(()); + } + + if self.resolved { + Self::print_effective_value(key, ""); + } else { + self.print_current_value(key, "").await?; + } + + Ok(()) + } + + async fn print_current_value(&self, key: &str, prefix: &str) -> Result<()> { + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let doc = config_str.parse::<Document<_>>()?; + + let current = get_deep_key(&doc, key); + + match current { + Some(item) if item.is_table() || item.is_inline_table() => { + let table = item + .as_table_like() + .expect("is_table()/is_inline_table() but no table"); + println!("{prefix}[{key}]"); + dump_table(table, prefix, &mut vec![key.to_string()])?; + } + Some(item) => { + let val = item.to_string(); + let val = val.trim().trim_matches('"'); + println!("{prefix}{val}"); + } + None => { + println!("{prefix}(not set in config file)"); + } + } + + Ok(()) + } + + fn print_effective_value(key: &str, prefix: &str) { + match Settings::get_config_value(key) { + Ok(value) => { + for line in value.lines() { + println!("{prefix}{line}"); + } + } + Err(_) => { + println!("{prefix}(unknown key)"); + } + } + } +} + +#[derive(Args, Debug)] +pub struct SetCmd { + /// The configuration key to set + pub key: String, + + /// The value to set + pub value: String, + + /// Store value as an explicit type + #[arg(long = "type", short, value_enum, default_value_t = ValueType::Auto, value_name = "TYPE")] + pub the_type: ValueType, +} + +#[derive(ValueEnum, Debug, Clone, PartialEq, Eq)] +pub enum ValueType { + /// Automatically determine the type of the value + Auto, + /// Store value as a string + String, + /// Store value as a boolean + Boolean, + /// Store value as an integer + Integer, + /// Store the value as a float + Float, +} + +impl SetCmd { + pub async fn run(self, _settings: &Settings) -> Result<()> { + let key = self.key.trim(); + if key.is_empty() || key.contains(char::is_whitespace) { + eyre::bail!("Config key must be non-empty and must not contain whitespace"); + } + + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let mut doc: DocumentMut = config_str.parse()?; + + // When using auto type detection, try to match the existing value's type + // so we don't accidentally change e.g. "300" (string) to 300 (integer) + let existing_type = detect_existing_type(&doc, key); + let value = self.parse_value(existing_type.as_ref())?; + set_deep_key(&mut doc, key, value)?; + + tokio::fs::write(&config_file, doc.to_string()).await?; + + Ok(()) + } + + fn parse_value(&self, existing_type: Option<&ValueType>) -> Result<Value> { + let raw = &self.value; + + // Explicit --type takes priority, then existing value type, then auto-detect + let effective_type = if self.the_type != ValueType::Auto { + &self.the_type + } else if let Some(existing) = existing_type { + existing + } else { + &ValueType::Auto + }; + + match effective_type { + ValueType::String => Ok(Value::from(raw.as_str())), + ValueType::Boolean => { + let b: bool = raw + .parse() + .map_err(|_| eyre::eyre!("invalid boolean value: {raw}"))?; + Ok(Value::from(b)) + } + ValueType::Integer => { + let i: i64 = raw + .parse() + .map_err(|_| eyre::eyre!("invalid integer value: {raw}"))?; + Ok(Value::from(i)) + } + ValueType::Float => { + let f: f64 = raw + .parse() + .map_err(|_| eyre::eyre!("invalid float value: {raw}"))?; + Ok(Value::from(f)) + } + ValueType::Auto => { + if raw == "true" || raw == "false" { + return Ok(Value::from(raw == "true")); + } + if let Ok(i) = raw.parse::<i64>() { + return Ok(Value::from(i)); + } + if let Ok(f) = raw.parse::<f64>() { + return Ok(Value::from(f)); + } + Ok(Value::from(raw.as_str())) + } + } + } +} + +#[derive(Args, Debug)] +pub struct PrintCmd { + /// Print the value of a specific key and all its children + pub key: Option<String>, +} + +impl PrintCmd { + pub async fn run(&self, _settings: &Settings) -> Result<()> { + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let doc = config_str.parse::<Document<_>>()?; + + if let Some(key) = &self.key { + let current = get_deep_key(&doc, key); + + if let Some(current) = current { + if current.is_table() || current.is_inline_table() { + println!("[{key}]"); + dump_table( + current + .as_table_like() + .expect("is_table()/is_inline_table() but no table"), + "", + &mut vec![key.clone()], + )?; + } else { + println!("{}", current.to_string().trim().trim_matches('"')); + } + } else { + println!("key not found"); + } + } else { + dump_table(doc.as_table(), "", &mut Vec::new())?; + } + + Ok(()) + } +} + +fn dump_table(table: &dyn TableLike, prefix: &str, stack: &mut Vec<String>) -> Result<()> { + for (key, value) in table.iter() { + if value.is_table() || value.is_inline_table() { + stack.push(key.to_string()); + + let table = value + .as_table_like() + .expect("is_table()/is_inline_table() but no table"); + + println!("\n{}[{}]", prefix, stack.join(".")); + + dump_table(table, prefix, stack)?; + + stack.pop(); + } else { + println!("{prefix}{key} = {value}"); + } + } + + Ok(()) +} + +fn get_deep_key<'doc>(doc: &'doc Document<String>, key: &str) -> Option<&'doc Item> { + let parts = key.split('.'); + let mut current: Option<&Item> = Some(doc.as_item()); + + for part in parts { + current = current + .and_then(|item| item.as_table_like()) + .and_then(|table| table.get(part)); + } + + current +} + +/// Detect the TOML type of an existing key in the document, so `set` with auto +/// type detection preserves the original type rather than guessing from the value string. +fn detect_existing_type(doc: &DocumentMut, key: &str) -> Option<ValueType> { + let parts: Vec<&str> = key.split('.').collect(); + let mut current: &dyn TableLike = doc.as_table(); + + for &part in &parts[..parts.len().saturating_sub(1)] { + current = current.get(part)?.as_table_like()?; + } + + let last = parts.last()?; + let v = current.get(last)?.as_value()?; + + if v.is_str() { + Some(ValueType::String) + } else if v.is_bool() { + Some(ValueType::Boolean) + } else if v.is_integer() { + Some(ValueType::Integer) + } else if v.is_float() { + Some(ValueType::Float) + } else { + None + } +} + +fn set_deep_key(doc: &mut DocumentMut, key: &str, value: Value) -> Result<()> { + let parts: Vec<&str> = key.split('.').collect(); + + if parts.is_empty() { + eyre::bail!("empty config key"); + } + + let mut current: &mut dyn TableLike = doc.as_table_mut(); + + // Navigate/create intermediate tables + for &part in &parts[..parts.len() - 1] { + if !current.contains_key(part) { + current.insert(part, Item::Table(Table::new())); + } + current = current + .get_mut(part) + .expect("just inserted or already exists") + .as_table_like_mut() + .ok_or_else(|| eyre::eyre!("'{}' exists but is not a table", part))?; + } + + let last = *parts.last().unwrap(); + + // Don't silently overwrite a table with a scalar value + if let Some(existing) = current.get(last) + && (existing.is_table() || existing.is_inline_table()) + { + eyre::bail!( + "'{}' is a table; use a dotted key like '{}.key' to set a value within it", + key, + key + ); + } + + current.insert(last, Item::Value(value)); + + Ok(()) +} diff --git a/crates/turtle/src/command/client/daemon.rs b/crates/turtle/src/command/client/daemon.rs new file mode 100644 index 00000000..2ee9b759 --- /dev/null +++ b/crates/turtle/src/command/client/daemon.rs @@ -0,0 +1,769 @@ +use std::fs::{self, File, OpenOptions}; +use std::io::{ErrorKind, Write}; +#[cfg(unix)] +use std::os::unix::net::UnixStream as StdUnixStream; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use std::time::{Duration, Instant}; + +use crate::atuin_client::{ + database::Sqlite, history::History, record::sqlite_store::SqliteStore, settings::Settings, +}; +use crate::atuin_daemon::DaemonEvent; +use crate::atuin_daemon::client::{ + ControlClient, DaemonClientErrorKind, HistoryClient, classify_error, +}; +use clap::Subcommand; +#[cfg(unix)] +use daemonize::Daemonize; +use eyre::{Result, WrapErr, bail, eyre}; +use fs4::fs_std::FileExt; +use tokio::time::sleep; + +#[derive(clap::Args, Debug)] +pub struct Cmd { + /// Internal flag for daemonization + #[arg(long, hide = true)] + daemonize: bool, + + /// Also write daemon logs to the console (useful for debugging) + #[arg(long)] + show_logs: bool, + + #[command(subcommand)] + subcmd: Option<SubCmd>, +} + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum SubCmd { + /// Start the daemon server + Start { + #[arg(long, hide = true)] + daemonize: bool, + + /// Also write daemon logs to the console (useful for debugging) + #[arg(long)] + show_logs: bool, + + /// Force start: kill existing daemon process and reset the socket + #[arg(long)] + force: bool, + }, + + /// Show the daemon's current status + Status, + + /// Stop the daemon gracefully + Stop, + + /// Restart the daemon (stop, then start in background) + Restart, +} + +impl Cmd { + /// Returns `true` when the process should daemonize before creating the + /// async runtime or opening any database connections. + #[cfg(unix)] + pub fn should_daemonize(&self) -> bool { + match &self.subcmd { + Some(SubCmd::Start { daemonize, .. }) => *daemonize, + None => self.daemonize, + _ => false, + } + } + + /// Returns `true` when logs should also be written to the console. + pub fn show_logs(&self) -> bool { + match &self.subcmd { + Some(SubCmd::Start { show_logs, .. }) => *show_logs, + None => self.show_logs, + _ => false, + } + } + + pub async fn run( + self, + settings: Settings, + store: SqliteStore, + history_db: Sqlite, + ) -> Result<()> { + match self.subcmd { + None => { + eprintln!("Warning: `atuin daemon` is deprecated, use `atuin daemon start`"); + run(settings, store, history_db, false).await + } + Some(SubCmd::Start { force, .. }) => run(settings, store, history_db, force).await, + Some(SubCmd::Status) => status_cmd(&settings).await, + Some(SubCmd::Stop) => stop_cmd(&settings).await, + Some(SubCmd::Restart) => restart_cmd(&settings).await, + } + } +} + +const DAEMON_VERSION: &str = env!("CARGO_PKG_VERSION"); +const DAEMON_PROTOCOL_VERSION: u32 = 1; +const STARTUP_POLL: Duration = Duration::from_millis(40); +const LOCK_POLL: Duration = Duration::from_millis(20); +const LEGACY_DAEMON_RESTART_MESSAGE: &str = "legacy daemon detected; restart daemon manually"; + +struct PidfileGuard { + file: File, +} + +impl PidfileGuard { + fn acquire(path: &Path) -> Result<Self> { + let mut file = open_lock_file(path)?; + + if !file.try_lock_exclusive()? { + bail!( + "daemon already running (pidfile lock busy at {})", + path.display() + ); + } + + file.set_len(0) + .wrap_err_with(|| format!("could not truncate daemon pidfile {}", path.display()))?; + writeln!(file, "{}", std::process::id()) + .and_then(|()| writeln!(file, "{DAEMON_VERSION}")) + .wrap_err_with(|| format!("could not write daemon pidfile {}", path.display()))?; + + Ok(Self { file }) + } +} + +impl Drop for PidfileGuard { + fn drop(&mut self) { + let _ = self.file.unlock(); + } +} + +enum Probe { + Ready(HistoryClient), + NeedsRestart(String), + Unreachable(eyre::Report), +} + +fn daemon_matches_expected(version: &str, protocol: u32) -> bool { + version == DAEMON_VERSION && protocol == DAEMON_PROTOCOL_VERSION +} + +fn daemon_mismatch_message(version: &str, protocol: u32) -> String { + if protocol == DAEMON_PROTOCOL_VERSION { + format!("daemon is out of date: expected {DAEMON_VERSION}, got {version}") + } else { + format!("daemon protocol mismatch: expected {DAEMON_PROTOCOL_VERSION}, got {protocol}") + } +} + +fn is_legacy_daemon_error(err: &eyre::Report) -> bool { + matches!(classify_error(err), DaemonClientErrorKind::Unimplemented) +} + +fn should_retry_after_error(err: &eyre::Report) -> bool { + matches!( + classify_error(err), + DaemonClientErrorKind::Connect + | DaemonClientErrorKind::Unavailable + | DaemonClientErrorKind::Unimplemented + ) +} + +fn daemon_startup_lock_path(pidfile_path: &Path) -> PathBuf { + let mut os = pidfile_path.as_os_str().to_os_string(); + os.push(".startup.lock"); + PathBuf::from(os) +} + +fn open_lock_file(path: &Path) -> Result<File> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .wrap_err_with(|| format!("could not create lock directory {}", parent.display()))?; + } + + OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(path) + .wrap_err_with(|| format!("could not open lock file {}", path.display())) +} + +async fn wait_for_lock(path: &Path, timeout: Duration) -> Result<File> { + let file = open_lock_file(path)?; + let start = Instant::now(); + + loop { + match file.try_lock_exclusive() { + Ok(true) => return Ok(file), + Ok(false) => { + if start.elapsed() >= timeout { + bail!("timed out waiting for lock at {}", path.display()); + } + + sleep(LOCK_POLL).await; + } + Err(err) => { + return Err(eyre!("could not lock {}: {err}", path.display())); + } + } + } +} + +async fn wait_for_pidfile_available(path: &Path, timeout: Duration) -> Result<()> { + let file = wait_for_lock(path, timeout).await?; + file.unlock() + .wrap_err_with(|| format!("failed to unlock {}", path.display()))?; + Ok(()) +} + +async fn connect_client(settings: &Settings) -> Result<HistoryClient> { + HistoryClient::new( + #[cfg(unix)] + settings.daemon.socket_path.clone(), + ) + .await +} + +async fn probe(settings: &Settings) -> Probe { + let mut client = match connect_client(settings).await { + Ok(client) => client, + Err(err) => return Probe::Unreachable(err), + }; + + match client.status().await { + Ok(status) => { + if daemon_matches_expected(&status.version, status.protocol) { + Probe::Ready(client) + } else { + Probe::NeedsRestart(daemon_mismatch_message(&status.version, status.protocol)) + } + } + Err(err) => Probe::Unreachable(err), + } +} + +async fn request_shutdown(settings: &Settings) { + if let Ok(mut client) = connect_client(settings).await { + let _ = client.shutdown().await; + } +} + +fn spawn_daemon_process() -> Result<()> { + let exe = std::env::current_exe().wrap_err("could not locate atuin executable")?; + + let mut cmd = Command::new(exe); + cmd.arg("daemon") + .arg("start") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + + #[cfg(unix)] + cmd.arg("--daemonize"); + + cmd.spawn().wrap_err("failed to spawn daemon process")?; + + Ok(()) +} + +fn startup_timeout(settings: &Settings) -> Duration { + Duration::from_secs_f64(settings.local_timeout.max(0.5) + 2.0) +} + +#[cfg(unix)] +fn remove_stale_socket_if_present(settings: &Settings) -> Result<()> { + if settings.daemon.systemd_socket { + return Ok(()); + } + + let socket_path = Path::new(&settings.daemon.socket_path); + if !socket_path.exists() { + return Ok(()); + } + + match StdUnixStream::connect(socket_path) { + Ok(stream) => { + drop(stream); + Ok(()) + } + Err(err) if err.kind() == ErrorKind::ConnectionRefused => { + fs::remove_file(socket_path).wrap_err_with(|| { + format!( + "failed to remove stale daemon socket {}", + socket_path.display() + ) + })?; + Ok(()) + } + Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), + Err(_) => Ok(()), + } +} + +async fn wait_until_ready(settings: &Settings, timeout: Duration) -> Result<HistoryClient> { + let start = Instant::now(); + let mut last_error = eyre!("daemon did not become ready"); + + loop { + match probe(settings).await { + Probe::Ready(client) => return Ok(client), + Probe::NeedsRestart(reason) => { + last_error = eyre!(reason); + } + Probe::Unreachable(err) => { + if is_legacy_daemon_error(&err) { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + last_error = err; + } + } + + if start.elapsed() >= timeout { + return Err(last_error.wrap_err(format!( + "timed out waiting for daemon startup after {}ms", + timeout.as_millis() + ))); + } + + sleep(STARTUP_POLL).await; + } +} + +#[expect(clippy::unnecessary_wraps)] +fn ensure_autostart_supported(settings: &Settings) -> Result<()> { + #[cfg(unix)] + if settings.daemon.systemd_socket { + bail!( + "daemon autostart is incompatible with `daemon.systemd_socket = true`; use systemd to manage the daemon" + ); + } + + Ok(()) +} + +/// Ensure the daemon is running, starting it if necessary. +/// +/// If the daemon is already running and up-to-date, this is a no-op. +/// If it is not running or needs a restart, this will spawn a new daemon +/// process and wait for it to become ready. +/// +/// Returns an error if the daemon could not be started. +pub async fn ensure_daemon_running(settings: &Settings) -> Result<()> { + ensure_autostart_supported(settings)?; + + let timeout = startup_timeout(settings); + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let startup_lock_path = daemon_startup_lock_path(&pidfile_path); + let startup_lock = wait_for_lock(&startup_lock_path, timeout).await?; + + match probe(settings).await { + Probe::Ready(_) => { + drop(startup_lock); + return Ok(()); + } + Probe::NeedsRestart(_) => { + request_shutdown(settings).await; + } + Probe::Unreachable(err) => { + if is_legacy_daemon_error(&err) { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + } + } + + // This prevents rapid-fire hook invocations from racing daemon restart. + wait_for_pidfile_available(&pidfile_path, timeout).await?; + + #[cfg(unix)] + remove_stale_socket_if_present(settings)?; + + spawn_daemon_process()?; + let _ = wait_until_ready(settings, timeout).await?; + + drop(startup_lock); + Ok(()) +} + +async fn restart_daemon(settings: &Settings) -> Result<HistoryClient> { + ensure_daemon_running(settings).await?; + connect_client(settings).await +} + +fn ensure_reply_compatible(settings: &Settings, version: &str, protocol: u32) -> Result<()> { + if daemon_matches_expected(version, protocol) { + return Ok(()); + } + + let message = daemon_mismatch_message(version, protocol); + if settings.daemon.autostart { + bail!("{message}"); + } + + bail!("{message}. Enable `daemon.autostart = true` or restart the daemon manually"); +} + +pub async fn start_history(settings: &Settings, history: History) -> Result<String> { + match async { + connect_client(settings) + .await? + .start_history(history.clone()) + .await + } + .await + { + Ok(resp) => { + if daemon_matches_expected(&resp.version, resp.protocol) { + return Ok(resp.id); + } + + if !settings.daemon.autostart { + return Err(eyre!( + "{}. Enable `daemon.autostart = true` or restart the daemon manually", + daemon_mismatch_message(&resp.version, resp.protocol) + )); + } + } + Err(err) if !settings.daemon.autostart => return Err(err), + Err(err) if !should_retry_after_error(&err) => return Err(err), + Err(_) => {} + } + + let resp = restart_daemon(settings) + .await? + .start_history(history) + .await?; + ensure_reply_compatible(settings, &resp.version, resp.protocol)?; + Ok(resp.id) +} + +pub async fn end_history(settings: &Settings, id: String, duration: u64, exit: i64) -> Result<()> { + match async { + connect_client(settings) + .await? + .end_history(id.clone(), duration, exit) + .await + } + .await + { + Ok(resp) => { + if daemon_matches_expected(&resp.version, resp.protocol) { + return Ok(()); + } + + if !settings.daemon.autostart { + return Err(eyre!( + "{}. Enable `daemon.autostart = true` or restart the daemon manually", + daemon_mismatch_message(&resp.version, resp.protocol) + )); + } + + // End succeeded on the running daemon, so avoid replaying it. + // We only restart to make subsequent hook calls target the expected version. + let _ = restart_daemon(settings).await; + return Ok(()); + } + Err(err) if !settings.daemon.autostart => return Err(err), + Err(err) if !should_retry_after_error(&err) => return Err(err), + Err(_) => {} + } + + let resp = restart_daemon(settings) + .await? + .end_history(id, duration, exit) + .await?; + ensure_reply_compatible(settings, &resp.version, resp.protocol)?; + Ok(()) +} + +/// Emit a daemon event, auto-starting the daemon if it is not running. +/// +/// If the daemon is not reachable and `daemon.autostart` is enabled, this +/// will start the daemon and retry the event. If the daemon cannot be +/// started or the retry fails, a warning is printed to stderr. +pub async fn emit_event(settings: &Settings, event: DaemonEvent) { + // Try to connect and send + match ControlClient::from_settings(settings).await { + Ok(mut client) => { + if let Err(e) = client.send_event(event).await { + tracing::debug!(?e, "failed to send event to daemon"); + } + return; + } + Err(e) if !settings.daemon.autostart || !should_retry_after_error(&e) => { + tracing::debug!(?e, "daemon not available, skipping event emission"); + return; + } + Err(_) => {} + } + + // Auto-start the daemon and retry + if let Err(e) = ensure_daemon_running(settings).await { + eprintln!("Could not start daemon: {e}"); + return; + } + + match ControlClient::from_settings(settings).await { + Ok(mut client) => { + if let Err(e) = client.send_event(event).await { + eprintln!("Daemon started but failed to send event: {e}"); + } + } + Err(e) => { + eprintln!("Daemon started but failed to connect: {e}"); + } + } +} + +pub async fn tail_client(settings: &Settings) -> Result<HistoryClient> { + match probe(settings).await { + Probe::Ready(client) => return Ok(client), + Probe::NeedsRestart(reason) if !settings.daemon.autostart => { + bail!("{reason}. Enable `daemon.autostart = true` or restart the daemon manually"); + } + Probe::Unreachable(err) if is_legacy_daemon_error(&err) => { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + Probe::Unreachable(err) if !settings.daemon.autostart => return Err(err), + Probe::Unreachable(err) if !should_retry_after_error(&err) => return Err(err), + Probe::NeedsRestart(_) | Probe::Unreachable(_) => {} + } + + restart_daemon(settings).await +} + +async fn status_cmd(settings: &Settings) -> Result<()> { + match probe(settings).await { + Probe::Ready(mut client) => { + let status = client.status().await?; + println!("Daemon running"); + println!(" PID: {}", status.pid); + println!(" Version: {}", status.version); + println!(" Protocol: {}", status.protocol); + println!(" Healthy: {}", status.healthy); + #[cfg(unix)] + println!(" Socket: {}", settings.daemon.socket_path); + } + Probe::NeedsRestart(reason) => { + println!("Daemon running (needs restart)"); + println!(" Reason: {reason}"); + } + Probe::Unreachable(_) => { + println!("Daemon is not running"); + } + } + + Ok(()) +} + +async fn stop_cmd(settings: &Settings) -> Result<()> { + let Ok(mut client) = connect_client(settings).await else { + println!("Daemon is not running"); + return Ok(()); + }; + + match client.shutdown().await { + Ok(true) => { + println!("Shutdown requested"); + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let timeout = Duration::from_secs(5); + match wait_for_pidfile_available(&pidfile_path, timeout).await { + Ok(()) => println!("Daemon stopped"), + Err(_) => println!("Daemon may still be shutting down"), + } + + Ok(()) + } + Ok(false) => bail!("Daemon rejected shutdown request"), + Err(err) => Err(err.wrap_err("Failed to send shutdown request")), + } +} + +async fn restart_cmd(settings: &Settings) -> Result<()> { + // Stop if running + match probe(settings).await { + Probe::Ready(_) | Probe::NeedsRestart(_) => { + request_shutdown(settings).await; + println!("Stopping daemon..."); + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let timeout = Duration::from_secs(5); + wait_for_pidfile_available(&pidfile_path, timeout) + .await + .wrap_err("Timed out waiting for old daemon to stop")?; + } + Probe::Unreachable(_) => { + println!("No daemon running"); + } + } + + #[cfg(unix)] + remove_stale_socket_if_present(settings)?; + + spawn_daemon_process()?; + println!("Starting daemon..."); + + let timeout = startup_timeout(settings); + let status = wait_until_ready(settings, timeout).await?.status().await?; + + println!("Daemon restarted"); + println!(" PID: {}", status.pid); + println!(" Version: {}", status.version); + + Ok(()) +} + +/// Daemonize the current process. Must be called before creating the tokio +/// runtime or opening database connections, since `fork()` inside an async +/// runtime corrupts its internal state. +#[cfg(unix)] +pub fn daemonize_current_process() -> Result<()> { + let cwd = + std::env::current_dir().wrap_err("could not determine current directory for daemon")?; + + Daemonize::new() + .working_directory(cwd) + .start() + .wrap_err("failed to daemonize process")?; + + Ok(()) +} + +async fn run( + settings: Settings, + store: SqliteStore, + history_db: Sqlite, + force: bool, +) -> Result<()> { + if force { + force_cleanup(&settings); + } + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let _pidfile_guard = PidfileGuard::acquire(&pidfile_path)?; + + crate::atuin_daemon::boot(settings, store, history_db).await?; + + Ok(()) +} + +/// Force cleanup: kill existing daemon process and remove socket. +fn force_cleanup(settings: &Settings) { + let pidfile_path = Path::new(&settings.daemon.pidfile_path); + + // Read and kill the existing process if pidfile exists + if pidfile_path.exists() { + if let Ok(contents) = fs::read_to_string(pidfile_path) + && let Some(pid_str) = contents.lines().next() + && let Ok(pid) = pid_str.parse::<u32>() + { + kill_process(pid); + // Give it a moment to release resources + std::thread::sleep(Duration::from_millis(100)); + } + + // Remove the pidfile + if let Err(e) = fs::remove_file(pidfile_path) + && e.kind() != ErrorKind::NotFound + { + tracing::warn!("failed to remove pidfile: {e}"); + } + } + + // Remove the socket file + #[cfg(unix)] + { + let socket_path = Path::new(&settings.daemon.socket_path); + if socket_path.exists() + && let Err(e) = fs::remove_file(socket_path) + && e.kind() != ErrorKind::NotFound + { + tracing::warn!("failed to remove socket: {e}"); + } + } +} + +/// Kill a process by PID. +#[cfg(unix)] +fn kill_process(pid: u32) { + // Use kill command to send SIGTERM for graceful shutdown + let _ = Command::new("kill") + .args(["-TERM", &pid.to_string()]) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_version_matches() { + assert!(daemon_matches_expected( + DAEMON_VERSION, + DAEMON_PROTOCOL_VERSION + )); + } + + #[test] + fn test_version_mismatch() { + assert!(!daemon_matches_expected("0.0.0", DAEMON_PROTOCOL_VERSION)); + assert!(!daemon_matches_expected(DAEMON_VERSION, 999)); + assert!(!daemon_matches_expected("0.0.0", 999)); + } + + #[test] + fn test_mismatch_message_version() { + let msg = daemon_mismatch_message("0.0.0", DAEMON_PROTOCOL_VERSION); + assert!(msg.contains("out of date"), "got: {msg}"); + assert!(msg.contains("0.0.0")); + assert!(msg.contains(DAEMON_VERSION)); + } + + #[test] + fn test_mismatch_message_protocol() { + let msg = daemon_mismatch_message(DAEMON_VERSION, 999); + assert!(msg.contains("protocol mismatch"), "got: {msg}"); + } + + #[test] + fn test_startup_lock_path() { + let pidfile = Path::new("/tmp/atuin-daemon.pid"); + let lock = daemon_startup_lock_path(pidfile); + assert_eq!(lock, PathBuf::from("/tmp/atuin-daemon.pid.startup.lock")); + } + + #[test] + fn test_pidfile_guard_acquire_and_drop() { + let tmp = tempfile::tempdir().unwrap(); + let pidfile = tmp.path().join("daemon.pid"); + + { + let _guard = PidfileGuard::acquire(&pidfile).unwrap(); + // Guard holds an exclusive lock — on Windows other handles cannot + // read the file, so we verify contents after the guard is dropped. + } + + let contents = std::fs::read_to_string(&pidfile).unwrap(); + let lines: Vec<&str> = contents.lines().collect(); + assert_eq!(lines.len(), 2); + assert_eq!(lines[0], std::process::id().to_string()); + assert_eq!(lines[1], DAEMON_VERSION); + + // After guard is dropped, lock should be released — acquiring again must succeed. + let _guard2 = PidfileGuard::acquire(&pidfile).unwrap(); + } + + #[test] + fn test_pidfile_guard_prevents_double_acquire() { + let tmp = tempfile::tempdir().unwrap(); + let pidfile = tmp.path().join("daemon.pid"); + + let _guard = PidfileGuard::acquire(&pidfile).unwrap(); + let result = PidfileGuard::acquire(&pidfile); + assert!(result.is_err()); + } +} diff --git a/crates/turtle/src/command/client/default_config.rs b/crates/turtle/src/command/client/default_config.rs new file mode 100644 index 00000000..e8cc15f9 --- /dev/null +++ b/crates/turtle/src/command/client/default_config.rs @@ -0,0 +1,4 @@ +pub fn run() { + // TODO(@bpeetz): Re-add the default settings option back (Settings::example_config()) <2026-06-11> + println!("TODO"); +} diff --git a/crates/turtle/src/command/client/doctor.rs b/crates/turtle/src/command/client/doctor.rs new file mode 100644 index 00000000..09fa6e77 --- /dev/null +++ b/crates/turtle/src/command/client/doctor.rs @@ -0,0 +1,412 @@ +use std::process::Command; +use std::{env, str::FromStr}; + +use crate::atuin_client::database::Sqlite; +use crate::atuin_client::settings::Settings; +use crate::atuin_common::shell::{Shell, shell_name}; +use crate::atuin_common::utils; +use colored::Colorize; +use eyre::Result; +use serde::Serialize; + +use sysinfo::{Disks, System, get_current_pid}; + +#[derive(Debug, Serialize)] +struct ShellInfo { + pub name: String, + + // best-effort, not supported on all OSes + pub default: String, + + // Detect some shell plugins that the user has installed. + // I'm just going to start with preexec/blesh + pub plugins: Vec<String>, + + // The preexec framework used in the current session, if Atuin is loaded. + pub preexec: Option<String>, +} + +impl ShellInfo { + // HACK ALERT! + // Many of the shell vars we need to detect are not exported :( + // So, we're going to run a interactive session and directly check the + // variable. There's a chance this won't work, so it should not be fatal. + // + // Every shell we support handles `shell -ic 'command'` + fn shellvar_exists(shell: &str, var: &str) -> bool { + let cmd = Command::new(shell) + .args([ + "-ic", + format!("[ -z ${var} ] || echo ATUIN_DOCTOR_ENV_FOUND").as_str(), + ]) + .output() + .map_or(String::new(), |v| { + let out = v.stdout; + String::from_utf8(out).unwrap_or_default() + }); + + cmd.contains("ATUIN_DOCTOR_ENV_FOUND") + } + + fn detect_preexec_framework(shell: &str) -> Option<String> { + if env::var("ATUIN_SESSION").ok().is_none() { + None + } else if shell.starts_with("bash") || shell == "sh" { + env::var("ATUIN_PREEXEC_BACKEND") + .ok() + .filter(|value| !value.is_empty()) + .and_then(|atuin_preexec_backend| { + atuin_preexec_backend.rfind(':').and_then(|pos_colon| { + u32::from_str(&atuin_preexec_backend[..pos_colon]) + .ok() + .is_some_and(|preexec_shlvl| { + env::var("SHLVL") + .ok() + .and_then(|shlvl| u32::from_str(&shlvl).ok()) + .is_some_and(|shlvl| shlvl == preexec_shlvl) + }) + .then(|| atuin_preexec_backend[pos_colon + 1..].to_string()) + }) + }) + } else { + Some("built-in".to_string()) + } + } + + fn validate_plugin_blesh( + _shell: &str, + shell_process: &sysinfo::Process, + ble_session_id: &str, + ) -> Option<String> { + ble_session_id + .split('/') + .nth(1) + .and_then(|field| u32::from_str(field).ok()) + .filter(|&blesh_pid| blesh_pid == shell_process.pid().as_u32()) + .map(|_| "blesh".to_string()) + } + + pub fn plugins(shell: &str, shell_process: &sysinfo::Process) -> Vec<String> { + // consider a different detection approach if there are plugins + // that don't set shell vars + + enum PluginShellType { + Any, + Bash, + + // Note: these are currently unused + #[expect(dead_code)] + Zsh, + #[expect(dead_code)] + Fish, + #[expect(dead_code)] + Nushell, + #[expect(dead_code)] + Xonsh, + } + + enum PluginProbeType { + EnvironmentVariable(&'static str), + InteractiveShellVariable(&'static str), + } + + type PluginValidator = fn(&str, &sysinfo::Process, &str) -> Option<String>; + + let plugin_list: [( + &str, + PluginShellType, + PluginProbeType, + Option<PluginValidator>, + ); 3] = [ + ( + "atuin", + PluginShellType::Any, + PluginProbeType::EnvironmentVariable("ATUIN_SESSION"), + None, + ), + ( + "blesh", + PluginShellType::Bash, + PluginProbeType::EnvironmentVariable("BLE_SESSION_ID"), + Some(Self::validate_plugin_blesh), + ), + ( + "bash-preexec", + PluginShellType::Bash, + PluginProbeType::InteractiveShellVariable("bash_preexec_imported"), + None, + ), + ]; + + plugin_list + .into_iter() + .filter(|(_, shell_type, _, _)| match shell_type { + PluginShellType::Any => true, + PluginShellType::Bash => shell.starts_with("bash") || shell == "sh", + PluginShellType::Zsh => shell.starts_with("zsh"), + PluginShellType::Fish => shell.starts_with("fish"), + PluginShellType::Nushell => shell.starts_with("nu"), + PluginShellType::Xonsh => shell.starts_with("xonsh"), + }) + .filter_map(|(plugin, _, probe_type, validator)| -> Option<String> { + match probe_type { + PluginProbeType::EnvironmentVariable(env) => { + env::var(env).ok().filter(|value| !value.is_empty()) + } + PluginProbeType::InteractiveShellVariable(shellvar) => { + ShellInfo::shellvar_exists(shell, shellvar).then_some(String::default()) + } + } + .and_then(|value| { + validator.map_or_else( + || Some(plugin.to_string()), + |validator| validator(shell, shell_process, &value), + ) + }) + }) + .collect() + } + + pub fn new() -> Self { + // TODO: rework to use crate::atuin_common::Shell + + let sys = System::new_all(); + + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + let parent = sys + .process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist"); + + let name = shell_name(Some(parent)); + + let plugins = ShellInfo::plugins(name.as_str(), parent); + + let default = Shell::default_shell().unwrap_or(Shell::Unknown).to_string(); + + let preexec = Self::detect_preexec_framework(name.as_str()); + + Self { + name, + default, + plugins, + preexec, + } + } +} + +#[derive(Debug, Serialize)] +struct DiskInfo { + pub name: String, + pub filesystem: String, +} + +#[derive(Debug, Serialize)] +struct SystemInfo { + pub os: String, + + pub arch: String, + + pub version: String, + pub disks: Vec<DiskInfo>, +} + +impl SystemInfo { + pub fn new() -> Self { + let disks = Disks::new_with_refreshed_list(); + let disks = disks + .list() + .iter() + .map(|d| DiskInfo { + name: d.name().to_os_string().into_string().unwrap(), + filesystem: d.file_system().to_os_string().into_string().unwrap(), + }) + .collect(); + + Self { + os: System::name().unwrap_or_else(|| "unknown".to_string()), + arch: System::cpu_arch().unwrap_or_else(|| "unknown".to_string()), + version: System::os_version().unwrap_or_else(|| "unknown".to_string()), + disks, + } + } +} + +#[derive(Debug, Serialize)] +struct SyncInfo { + pub auth_state: String, + pub auto_sync: bool, + + pub last_sync: String, +} + +impl SyncInfo { + pub async fn new(settings: &Settings) -> Self { + // Build auth state description from raw token state without calling + // resolve_sync_auth(), which has side effects (token migration cleanup) + // that a diagnostic command should not trigger. + let meta = Settings::meta_store().await.ok(); + let has_cli_token = match &meta { + Some(m) => m.session_token().await.ok().flatten().is_some(), + None => false, + }; + + let auth_state = if has_cli_token { + "Self-hosted (authenticated)".into() + } else { + "Not authenticated".into() + }; + + Self { + auth_state, + auto_sync: settings.auto_sync, + last_sync: Settings::last_sync() + .await + .map_or_else(|_| "no last sync".to_string(), |v| v.to_string()), + } + } +} + +#[derive(Debug)] +struct SettingPaths { + db: String, + record_store: String, + key: String, +} + +impl SettingPaths { + pub fn new(settings: &Settings) -> Self { + Self { + db: settings.db_path.clone(), + record_store: settings.record_store_path.clone(), + key: settings.key_path.clone(), + } + } + + pub fn verify(&self) { + let paths = vec![ + ("ATUIN_DB_PATH", &self.db), + ("ATUIN_RECORD_STORE", &self.record_store), + ("ATUIN_KEY", &self.key), + ]; + + for (path_env_var, path) in paths { + if utils::broken_symlink(path) { + eprintln!( + "{path} (${path_env_var}) is a broken symlink. This may cause issues with Atuin." + ); + } + } + } +} + +#[derive(Debug, Serialize)] +struct AtuinInfo { + pub version: String, + pub commit: String, + + /// Whether the main Atuin sync server is in use + /// I'm just calling it Atuin Cloud for lack of a better name atm + pub sync: Option<SyncInfo>, + + pub sqlite_version: String, + + #[serde(skip)] // probably unnecessary to expose this + pub setting_paths: SettingPaths, +} + +impl AtuinInfo { + pub async fn new(settings: &Settings) -> Self { + let logged_in = settings.logged_in().await.unwrap_or(false); + + let sync = if logged_in { + Some(SyncInfo::new(settings).await) + } else { + None + }; + + let sqlite_version = match Sqlite::new("sqlite::memory:", 0.1).await { + Ok(db) => db + .sqlite_version() + .await + .unwrap_or_else(|_| "unknown".to_string()), + Err(_) => "error".to_string(), + }; + + Self { + version: crate::VERSION.to_string(), + commit: crate::SHA.to_string(), + sync, + sqlite_version, + setting_paths: SettingPaths::new(settings), + } + } +} + +#[derive(Debug, Serialize)] +struct DoctorDump { + pub atuin: AtuinInfo, + pub shell: ShellInfo, + pub system: SystemInfo, +} + +impl DoctorDump { + pub async fn new(settings: &Settings) -> Self { + Self { + atuin: AtuinInfo::new(settings).await, + shell: ShellInfo::new(), + system: SystemInfo::new(), + } + } +} + +fn checks(info: &DoctorDump) { + println!(); // spacing + // + let zfs_error = "[Filesystem] ZFS is known to have some issues with SQLite. Atuin uses SQLite heavily. If you are having poor performance, there are some workarounds here: https://github.com/atuinsh/atuin/issues/952".bold().red(); + let bash_plugin_error = "[Shell] If you are using Bash, Atuin requires that either bash-preexec or ble.sh (>= 0.4) be installed. An older ble.sh may not be detected. so ignore this if you have ble.sh >= 0.4 set up! Read more here: https://docs.atuin.sh/guide/installation/#bash".bold().red(); + let blesh_integration_error = "[Shell] Atuin and ble.sh seem to be loaded in the session, but the integration does not seem to be working. Please check the setup in .bashrc.".bold().red(); + + // ZFS: https://github.com/atuinsh/atuin/issues/952 + if info.system.disks.iter().any(|d| d.filesystem == "zfs") { + println!("{zfs_error}"); + } + + info.atuin.setting_paths.verify(); + + // Shell + if info.shell.name == "bash" { + if !info + .shell + .plugins + .iter() + .any(|p| p == "blesh" || p == "bash-preexec") + { + println!("{bash_plugin_error}"); + } + + if info.shell.plugins.iter().any(|plugin| plugin == "atuin") + && info.shell.plugins.iter().any(|plugin| plugin == "blesh") + && info.shell.preexec.as_ref().is_some_and(|val| val == "none") + { + println!("{blesh_integration_error}"); + } + } +} + +pub async fn run(settings: &Settings) -> Result<()> { + println!("{}", "Atuin Doctor".bold()); + println!("Checking for diagnostics"); + let dump = DoctorDump::new(settings).await; + + checks(&dump); + + let dump = serde_json::to_string_pretty(&dump)?; + + println!("\nPlease include the output below with any bug reports or issues\n"); + println!("{dump}"); + + Ok(()) +} diff --git a/crates/turtle/src/command/client/history.rs b/crates/turtle/src/command/client/history.rs new file mode 100644 index 00000000..0c61392c --- /dev/null +++ b/crates/turtle/src/command/client/history.rs @@ -0,0 +1,1340 @@ +use std::{ + fmt::{self, Display}, + io::{self, IsTerminal, Write}, + path::PathBuf, + time::Duration, +}; + +use crate::atuin_common::utils::{self, Escapable as _}; +use clap::Subcommand; +use eyre::{Context, Result, bail}; +use runtime_format::{FormatKey, FormatKeyError, ParseSegment, ParsedFmt}; + +#[cfg(feature = "daemon")] +use super::daemon as daemon_cmd; +#[cfg(feature = "daemon")] +use colored::Colorize; +#[cfg(feature = "daemon")] +use serde::Serialize; + +#[cfg(feature = "daemon")] +use crate::atuin_daemon::history::{HistoryEventKind, TailHistoryReply}; + +use crate::atuin_client::{ + database::{Database, Sqlite, current_context}, + encryption, + history::{History, store::HistoryStore}, + record::sqlite_store::SqliteStore, + settings::{ + FilterMode::{Directory, Global, Session}, + Settings, Timezone, + }, +}; + +#[cfg(feature = "sync")] +use crate::atuin_client::record; + +use log::{debug, warn}; +use time::{OffsetDateTime, macros::format_description}; + +#[cfg(feature = "daemon")] +use super::daemon; +use super::search::format_duration_into; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Begins a new command in the history + Start { + /// Collects the command from the `ATUIN_COMMAND_LINE` environment variable, + /// which does not need escaping and is more compatible between OS and shells + #[arg(long = "command-from-env", hide = true)] + cmd_env: bool, + + /// Author of this command, eg `ellie`, `claude`, or `copilot` + #[arg(long)] + author: Option<String>, + + /// Optional intent/rationale for running this command + #[arg(long)] + intent: Option<String>, + + command: Vec<String>, + }, + + /// Finishes a new command in the history (adds time, exit code) + End { + id: String, + #[arg(long, short)] + exit: i64, + #[arg(long, short)] + duration: Option<u64>, + }, + + /// Stream history events from the daemon as they are received + Tail, + + /// List all items in history + List { + #[arg(long, short)] + cwd: bool, + + #[arg(long, short)] + session: bool, + + #[arg(long)] + human: bool, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Terminate the output with a null, for better multiline support + #[arg(long)] + print0: bool, + + #[arg(long, short, default_value = "true")] + // accept no value + #[arg(num_args(0..=1), default_missing_value("true"))] + // accept a value + #[arg(action = clap::ArgAction::Set)] + reverse: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + timezone: Option<Timezone>, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {author}, {intent}, {exit}, {time}, {session}, and {uuid} + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option<String>, + }, + + /// Get the last command ran + Last { + #[arg(long)] + human: bool, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + timezone: Option<Timezone>, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {author}, {intent}, {time}, {session}, {uuid} and {relativetime}. + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option<String>, + }, + + InitStore, + + /// Delete history entries matching the configured exclusion filters + Prune { + /// List matching history lines without performing the actual deletion. + #[arg(short = 'n', long)] + dry_run: bool, + }, + + /// Delete duplicate history entries (that have the same command, cwd and hostname) + Dedup { + /// List matching history lines without performing the actual deletion. + #[arg(short = 'n', long)] + dry_run: bool, + + /// Only delete results added before this date + #[arg(long, short)] + before: String, + + /// How many recent duplicates to keep + #[arg(long)] + dupkeep: u32, + }, +} + +#[derive(Clone, Copy, Debug)] +pub enum ListMode { + Human, + CmdOnly, + Regular, +} + +impl ListMode { + pub const fn from_flags(human: bool, cmd_only: bool) -> Self { + if human { + ListMode::Human + } else if cmd_only { + ListMode::CmdOnly + } else { + ListMode::Regular + } + } +} + +#[expect(clippy::cast_sign_loss)] +pub fn print_list( + h: &[History], + list_mode: ListMode, + format: Option<&str>, + print0: bool, + reverse: bool, + tz: Timezone, +) { + let w = std::io::stdout(); + let mut w = w.lock(); + + let fmt_str = match list_mode { + ListMode::Human => format + .unwrap_or("{time} · {duration}\t{command}") + .replace("\\t", "\t"), + ListMode::Regular => format + .unwrap_or("{time}\t{command}\t{duration}") + .replace("\\t", "\t"), + // not used + ListMode::CmdOnly => String::new(), + }; + + let parsed_fmt = match list_mode { + ListMode::Human | ListMode::Regular => parse_fmt(&fmt_str), + ListMode::CmdOnly => std::iter::once(ParseSegment::Key("command")).collect(), + }; + + let iterator = if reverse { + Box::new(h.iter().rev()) as Box<dyn Iterator<Item = &History>> + } else { + Box::new(h.iter()) as Box<dyn Iterator<Item = &History>> + }; + + let entry_terminator = if print0 { "\0" } else { "\n" }; + let flush_each_line = print0; + + for history in iterator { + let fh = FmtHistory { + history, + cmd_format: CmdFormat::for_output(&w), + tz: &tz, + }; + let args = parsed_fmt.with_args(&fh); + + // Check for formatting errors before attempting to write + if let Err(err) = args.status() { + eprintln!("ERROR: history output failed with: {err}"); + std::process::exit(1); + } + + let write_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + write!(w, "{args}{entry_terminator}") + })); + + match write_result { + Ok(Ok(())) => { + // Write succeeded + } + Ok(Err(err)) => { + if err.kind() != io::ErrorKind::BrokenPipe { + eprintln!("ERROR: Failed to write history output: {err}"); + std::process::exit(1); + } + } + Err(_) => { + eprintln!("ERROR: Format string caused a formatting error."); + eprintln!( + "This may be due to an unsupported format string containing special characters." + ); + eprintln!( + "Please check your format string syntax and ensure literal braces are properly escaped." + ); + std::process::exit(1); + } + } + if flush_each_line { + check_for_write_errors(w.flush()); + } + } + + if !flush_each_line { + check_for_write_errors(w.flush()); + } +} + +fn check_for_write_errors(write: Result<(), io::Error>) { + if let Err(err) = write { + // Ignore broken pipe (issue #626) + if err.kind() != io::ErrorKind::BrokenPipe { + eprintln!("ERROR: History output failed with the following error: {err}"); + std::process::exit(1); + } + } +} + +/// Type wrapper around `History` with formatting settings. +#[derive(Clone, Copy, Debug)] +struct FmtHistory<'a> { + history: &'a History, + cmd_format: CmdFormat, + tz: &'a Timezone, +} + +#[derive(Clone, Copy, Debug)] +enum CmdFormat { + Literal, + Escaped, +} +impl CmdFormat { + fn for_output<O: IsTerminal>(out: &O) -> Self { + if out.is_terminal() { + Self::Escaped + } else { + Self::Literal + } + } +} + +static TIME_FMT: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day] [hour repr:24]:[minute]:[second]"); + +/// defines how to format the history +impl FormatKey for FmtHistory<'_> { + #[expect(clippy::cast_sign_loss)] + fn fmt(&self, key: &str, f: &mut fmt::Formatter<'_>) -> Result<(), FormatKeyError> { + match key { + "command" => match self.cmd_format { + CmdFormat::Literal => f.write_str(self.history.command.trim()), + CmdFormat::Escaped => f.write_str(&self.history.command.trim().escape_control()), + }?, + "directory" => f.write_str(self.history.cwd.trim())?, + "exit" => f.write_str(&self.history.exit.to_string())?, + "duration" => { + let dur = Duration::from_nanos(std::cmp::max(self.history.duration, 0) as u64); + format_duration_into(dur, f)?; + } + "time" => { + self.history + .timestamp + .to_offset(self.tz.0) + .format(TIME_FMT) + .map_err(|_| fmt::Error)? + .fmt(f)?; + } + "relativetime" => { + let since = OffsetDateTime::now_utc() - self.history.timestamp; + let d = Duration::try_from(since).unwrap_or_default(); + format_duration_into(d, f)?; + } + "host" => f.write_str( + self.history + .hostname + .split_once(':') + .map_or(&self.history.hostname, |(host, _)| host), + )?, + "author" => f.write_str(&self.history.author)?, + "intent" => f.write_str(self.history.intent.as_deref().unwrap_or_default())?, + "user" => f.write_str( + self.history + .hostname + .split_once(':') + .map_or("", |(_, user)| user), + )?, + "session" => f.write_str(&self.history.session)?, + "uuid" => f.write_str(&self.history.id.0)?, + _ => return Err(FormatKeyError::UnknownKey), + } + Ok(()) + } +} + +fn parse_fmt(format: &str) -> ParsedFmt<'_> { + match ParsedFmt::new(format) { + Ok(fmt) => fmt, + Err(err) => { + eprintln!("ERROR: History formatting failed with the following error: {err}"); + + if format.contains('"') && (format.contains(":{") || format.contains(",{")) { + eprintln!("It looks like you're trying to create JSON output."); + eprintln!("For JSON, you need to escape literal braces by doubling them:"); + eprintln!("Example: '{{\"command\":\"{{command}}\",\"time\":\"{{time}}\"}}'"); + } else { + eprintln!( + "If your formatting string contains literal curly braces, you need to escape them by doubling:" + ); + eprintln!("Use {{{{ for literal {{ and }}}} for literal }}"); + } + std::process::exit(1) + } + } +} + +fn apply_start_metadata(history: &mut History, author: Option<&str>, intent: Option<&str>) { + if let Some(author) = author.map(str::trim).filter(|author| !author.is_empty()) { + author.clone_into(&mut history.author); + } + + if let Some(intent) = intent.map(str::trim).filter(|intent| !intent.is_empty()) { + history.intent = Some(intent.to_owned()); + } else if intent.is_some() { + history.intent = None; + } +} + +fn normalize_command_for_storage<'a>(command: &'a str, settings: &Settings) -> &'a str { + if !settings.strip_trailing_whitespace { + return command; + } + + let trimmed = command.trim_end_matches([' ', '\t']); + if trimmed.len() == command.len() { + return command; + } + + let trailing_backslashes = trimmed + .as_bytes() + .iter() + .rev() + .take_while(|&&byte| byte == b'\\') + .count(); + + if trailing_backslashes % 2 == 1 { + command + } else { + trimmed + } +} + +async fn handle_start( + db: &impl Database, + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + let command = normalize_command_for_storage(command, settings); + + let mut h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + apply_start_metadata(&mut h, author, intent); + + if !h.should_save(settings) { + return Ok(None); + } + + let id = h.id.0.clone(); + + // Silently ignore database errors to avoid breaking the shell + // This is important when disk is full or database is locked + if let Err(e) = db.save(&h).await { + debug!("failed to save history: {e}"); + } + + Ok(Some(id)) +} + +#[cfg(feature = "daemon")] +async fn handle_daemon_start( + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + let command = normalize_command_for_storage(command, settings); + + let mut h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + apply_start_metadata(&mut h, author, intent); + + if !h.should_save(settings) { + return Ok(None); + } + + // Attempt to start history via daemon, but silently ignore errors + // to avoid breaking the shell when the daemon is unavailable or disk is full + let resp = match daemon::start_history(settings, h.clone()).await { + Ok(id) => id, + Err(e) => { + debug!("failed to start history via daemon: {e}"); + h.id.0.clone() + } + }; + + Ok(Some(resp)) +} + +#[expect(unused_variables)] +async fn handle_end( + db: &impl Database, + store: SqliteStore, + history_store: HistoryStore, + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + if id.trim() == "" { + return Ok(()); + } + + let Some(mut h) = db.load(id).await? else { + warn!("history entry is missing"); + return Ok(()); + }; + + if h.duration > 0 { + debug!("cannot end history - already has duration"); + + // returning OK as this can occur if someone Ctrl-c a prompt + return Ok(()); + } + + if !settings.store_failed && exit > 0 { + debug!("history has non-zero exit code, and store_failed is false"); + + // the history has already been inserted half complete. remove it + db.delete(h).await?; + + return Ok(()); + } + + h.exit = exit; + h.duration = match duration { + Some(value) => i64::try_from(value).context("command took over 292 years")?, + None => i64::try_from((OffsetDateTime::now_utc() - h.timestamp).whole_nanoseconds()) + .context("command took over 292 years")?, + }; + + db.update(&h).await?; + history_store.push(h).await?; + + if settings.should_sync().await? { + let (_, downloaded) = + record::sync::sync(settings, &store, &history_store.encryption_key).await?; + Settings::save_sync_time().await?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + } else { + debug!("sync disabled! not syncing"); + } + + Ok(()) +} + +#[cfg(feature = "daemon")] +async fn handle_daemon_end( + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + daemon::end_history(settings, id.to_string(), duration.unwrap_or(0), exit).await?; + + Ok(()) +} + +pub(super) async fn start_history_entry( + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result<Option<String>> { + #[cfg(feature = "daemon")] + if settings.daemon.enabled { + return handle_daemon_start(settings, command, author, intent).await; + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let db = Sqlite::new(db_path, settings.local_timeout).await?; + handle_start(&db, settings, command, author, intent).await +} + +pub(super) async fn end_history_entry( + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, +) -> Result<()> { + #[cfg(feature = "daemon")] + if settings.daemon.enabled { + return handle_daemon_end(settings, id, exit, duration).await; + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + handle_end(&db, store, history_store, settings, id, exit, duration).await +} + +#[cfg(feature = "daemon")] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum TailKind { + Started, + Ended, +} + +#[cfg(feature = "daemon")] +#[derive(Clone, Debug, Eq, PartialEq)] +struct TailEvent { + kind: TailKind, + history: History, +} + +#[cfg(feature = "daemon")] +#[derive(Serialize)] +struct TailJsonEvent<'a> { + event: &'static str, + history: TailJsonHistory<'a>, +} + +#[cfg(feature = "daemon")] +#[derive(Serialize)] +struct TailJsonHistory<'a> { + id: &'a str, + timestamp: String, + timestamp_unix_ns: u64, + command: &'a str, + cwd: &'a str, + session: &'a str, + hostname: &'a str, + host: &'a str, + user: &'a str, + author: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + intent: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + exit: Option<i64>, + #[serde(skip_serializing_if = "Option::is_none")] + duration_ns: Option<i64>, + #[serde(skip_serializing_if = "Option::is_none")] + duration: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + success: Option<bool>, + #[serde(skip_serializing_if = "Option::is_none")] + finished_at: Option<String>, +} + +#[cfg(feature = "daemon")] +impl TailEvent { + fn from_proto(reply: TailHistoryReply) -> Result<Self> { + let history = reply + .history + .ok_or_else(|| eyre::eyre!("daemon sent a history tail event without history"))?; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(i128::from(history.timestamp)) + .context("invalid daemon history timestamp")?; + let kind = match HistoryEventKind::try_from(reply.kind) + .unwrap_or(HistoryEventKind::Unspecified) + { + HistoryEventKind::Started => TailKind::Started, + HistoryEventKind::Ended => TailKind::Ended, + HistoryEventKind::Unspecified => bail!("daemon sent an unspecified history tail event"), + }; + + Ok(Self { + kind, + history: History { + id: history.id.into(), + timestamp, + duration: history.duration, + exit: history.exit, + command: history.command, + cwd: history.cwd, + session: history.session, + hostname: history.hostname, + author: history.author, + intent: normalize_optional_field(&history.intent), + deleted_at: None, + }, + }) + } + + fn render(&self, tty: bool, tz: Timezone) -> Result<String> { + if tty { + Ok(self.render_pretty(tz)) + } else { + let mut json = self.render_json(tz)?; + json.push('\n'); + Ok(json) + } + } + + fn render_json(&self, tz: Timezone) -> Result<String> { + let payload = TailJsonEvent { + event: self.kind.as_str(), + history: TailJsonHistory { + id: &self.history.id.0, + timestamp: format_history_time(self.history.timestamp, tz)?, + timestamp_unix_ns: u64::try_from(self.history.timestamp.unix_timestamp_nanos()) + .context("history timestamp predates unix epoch")?, + command: &self.history.command, + cwd: &self.history.cwd, + session: &self.history.session, + hostname: &self.history.hostname, + host: self.host(), + user: self.user(), + author: &self.history.author, + intent: self.history.intent.as_deref(), + exit: self.exit_value(), + duration_ns: self.duration_value(), + duration: self.duration_value().map(format_duration_ns), + success: self.success_value(), + finished_at: self + .finished_at() + .map(|time| format_history_time(time, tz)) + .transpose()?, + }, + }; + + Ok(serde_json::to_string(&payload)?) + } + + fn render_pretty(&self, tz: Timezone) -> String { + let mut out = String::new(); + let border = match self.kind { + TailKind::Started => "-".repeat(72).bright_blue().to_string(), + TailKind::Ended if self.history.exit == 0 => "-".repeat(72).bright_green().to_string(), + TailKind::Ended => "-".repeat(72).bright_red().to_string(), + }; + + out.push_str(&border); + out.push('\n'); + + let command = self.history.command.trim(); + let escaped_command = command.escape_control(); + let mut command_lines = escaped_command.lines(); + let header = format!( + "{} {}", + self.kind.badge(self.history.exit), + command_lines.next().unwrap_or_default().bold() + ); + out.push_str(&header); + out.push('\n'); + + for line in command_lines { + out.push_str(" "); + out.push_str(line); + out.push('\n'); + } + + push_pretty_field( + &mut out, + "start", + &format_history_time(self.history.timestamp, tz) + .unwrap_or_else(|_| "invalid".to_owned()), + ); + push_pretty_field(&mut out, "history", &self.history.id.0); + push_pretty_field(&mut out, "session", &self.history.session); + push_pretty_field(&mut out, "exit", &self.exit_display()); + push_pretty_field(&mut out, "duration", &self.duration_display()); + + out.push('\n'); + + push_pretty_field(&mut out, "cwd", &self.history.cwd); + push_pretty_field(&mut out, "hostname", &self.history.hostname); + push_pretty_field(&mut out, "host", self.host()); + push_pretty_field(&mut out, "user", self.user()); + push_pretty_field(&mut out, "author", &self.history.author); + + if let Some(intent) = self.history.intent.as_deref() { + push_pretty_field(&mut out, "intent", intent); + } + + if let Some(finished) = self.finished_at() { + let finished = + format_history_time(finished, tz).unwrap_or_else(|_| "invalid".to_owned()); + push_pretty_field(&mut out, "finished", &finished); + } + + out.push_str(&border); + out.push_str("\n\n"); + out + } + + fn host(&self) -> &str { + self.history + .hostname + .split_once(':') + .map_or(self.history.hostname.as_str(), |(host, _)| host) + } + + fn user(&self) -> &str { + self.history + .hostname + .split_once(':') + .map_or("", |(_, user)| user) + } + + fn exit_value(&self) -> Option<i64> { + matches!(self.kind, TailKind::Ended).then_some(self.history.exit) + } + + fn duration_value(&self) -> Option<i64> { + matches!(self.kind, TailKind::Ended).then_some(self.history.duration) + } + + fn success_value(&self) -> Option<bool> { + matches!(self.kind, TailKind::Ended).then_some(self.history.exit == 0) + } + + fn finished_at(&self) -> Option<OffsetDateTime> { + self.duration_value() + .filter(|duration| *duration >= 0) + .map(time::Duration::nanoseconds) + .and_then(|duration| self.history.timestamp.checked_add(duration)) + } + + fn exit_display(&self) -> String { + match self.exit_value() { + Some(0) => "0 (success)".bright_green().to_string(), + Some(code) => format!("{code} (failure)").bright_red().to_string(), + None => "pending".bright_yellow().to_string(), + } + } + + fn duration_display(&self) -> String { + match self.duration_value() { + Some(duration) if duration >= 0 => format_duration_ns(duration), + Some(_) => "unknown".bright_yellow().to_string(), + None => "running".bright_yellow().to_string(), + } + } +} + +#[cfg(feature = "daemon")] +impl TailKind { + const fn as_str(self) -> &'static str { + match self { + Self::Started => "started", + Self::Ended => "ended", + } + } + + fn badge(self, exit: i64) -> colored::ColoredString { + match self { + Self::Started => "STARTED".bold().bright_blue(), + Self::Ended if exit == 0 => "ENDED".bold().bright_green(), + Self::Ended => "ENDED".bold().bright_red(), + } + } +} + +#[cfg(feature = "daemon")] +fn format_history_time(timestamp: OffsetDateTime, tz: Timezone) -> Result<String> { + Ok(timestamp.to_offset(tz.0).format(TIME_FMT)?) +} + +#[cfg(feature = "daemon")] +fn format_duration_ns(duration_ns: i64) -> String { + struct F(Duration); + impl Display for F { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + format_duration_into(self.0, f) + } + } + + F(Duration::from_nanos(duration_ns.max(0).cast_unsigned())).to_string() +} + +#[cfg(feature = "daemon")] +fn push_pretty_field(out: &mut String, label: &str, value: &str) { + out.push_str(" "); + let label = format!("{label}:"); + out.push_str(&label.bright_cyan().bold().to_string()); + if label.len() < 10 { + out.push_str(&" ".repeat(10 - label.len())); + } + + let mut lines = value.lines(); + if let Some(first) = lines.next() { + out.push_str(first); + } + out.push('\n'); + + for line in lines { + out.push_str(" "); + out.push_str(line); + out.push('\n'); + } +} + +#[cfg(feature = "daemon")] +fn normalize_optional_field(value: &str) -> Option<String> { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_owned()) + } +} + +impl Cmd { + #[cfg(feature = "daemon")] + async fn handle_tail(settings: &Settings) -> Result<()> { + let tty = std::io::stdout().is_terminal(); + let mut client = daemon::tail_client(settings).await?; + let mut stream = client.tail_history().await?; + let stdout = std::io::stdout(); + + while let Some(reply) = stream.message().await? { + let event = TailEvent::from_proto(reply)?; + let rendered = event.render(tty, settings.timezone)?; + let mut out = stdout.lock(); + + match out.write_all(rendered.as_bytes()) { + Ok(()) => out.flush()?, + Err(err) if err.kind() == io::ErrorKind::BrokenPipe => break, + Err(err) => return Err(err.into()), + } + } + + Ok(()) + } + + #[expect(clippy::too_many_lines, clippy::cast_possible_truncation)] + #[expect(clippy::too_many_arguments)] + #[expect(clippy::fn_params_excessive_bools)] + async fn handle_list( + db: &impl Database, + settings: &Settings, + context: crate::atuin_client::database::Context, + session: bool, + cwd: bool, + mode: ListMode, + format: Option<String>, + include_deleted: bool, + print0: bool, + reverse: bool, + tz: Timezone, + ) -> Result<()> { + let filters = match (session, cwd) { + (true, true) => [Session, Directory], + (true, false) => [Session, Global], + (false, true) => [Global, Directory], + (false, false) => [ + settings.default_filter_mode(context.git_root.is_some()), + Global, + ], + }; + + let history = db + .list(&filters, &context, None, false, include_deleted) + .await?; + + print_list( + &history, + mode, + match format { + None => Some(settings.history_format.as_str()), + _ => format.as_deref(), + }, + print0, + reverse, + tz, + ); + + Ok(()) + } + + async fn handle_prune( + db: &impl Database, + settings: &Settings, + store: SqliteStore, + context: crate::atuin_client::database::Context, + dry_run: bool, + ) -> Result<()> { + // Grab all executed commands and filter them using History::should_save. + // We could iterate or paginate here if memory usage becomes an issue. + let matches: Vec<History> = db + .list(&[Global], &context, None, false, false) + .await? + .into_iter() + .filter(|h| !h.should_save(settings)) + .collect(); + + match matches.len() { + 0 => { + println!("No entries to prune."); + return Ok(()); + } + 1 => println!("Found 1 entry to prune."), + n => println!("Found {n} entries to prune."), + } + + if dry_run { + print_list( + &matches, + ListMode::Human, + Some(settings.history_format.as_str()), + false, + false, + settings.timezone, + ); + } else { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + for entry in matches { + eprintln!("deleting {}", entry.id); + let (id, _) = history_store.delete(entry.id.clone()).await?; + history_store.incremental_build(db, &[id]).await?; + } + + #[cfg(feature = "daemon")] + daemon_cmd::emit_event(settings, crate::atuin_daemon::DaemonEvent::HistoryPruned).await; + } + Ok(()) + } + + async fn handle_dedup( + db: &impl Database, + settings: &Settings, + store: SqliteStore, + before: i64, + dupkeep: u32, + dry_run: bool, + ) -> Result<()> { + if dupkeep == 0 { + eprintln!( + "\"--dupkeep 0\" would keep 0 copies of duplicate commands and thus delete all of them! Use \"atuin search --delete ...\" if you really want that." + ); + std::process::exit(1); + } + + let matches: Vec<History> = db.get_dups(before, dupkeep).await?; + + match matches.len() { + 0 => { + println!("No duplicates to delete."); + return Ok(()); + } + 1 => println!("Found 1 duplicate to delete."), + n => println!("Found {n} duplicates to delete."), + } + + if dry_run { + print_list( + &matches, + ListMode::Human, + Some(settings.history_format.as_str()), + false, + false, + settings.timezone, + ); + } else { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + #[cfg(feature = "daemon")] + let ids = matches.iter().map(|h| h.id.clone()).collect::<Vec<_>>(); + + for entry in matches { + eprintln!("deleting {}", entry.id); + let (id, _) = history_store.delete(entry.id).await?; + history_store.incremental_build(db, &[id]).await?; + } + + #[cfg(feature = "daemon")] + daemon_cmd::emit_event( + settings, + crate::atuin_daemon::DaemonEvent::HistoryDeleted { ids }, + ) + .await; + } + Ok(()) + } + + #[expect(clippy::too_many_lines)] + pub async fn run(self, settings: &Settings) -> Result<()> { + match self { + Self::Start { + cmd_env, + author, + intent, + command, + } => { + let command = if cmd_env { + std::env::var("ATUIN_COMMAND_LINE").unwrap_or_default() + } else { + command.join(" ") + }; + + if let Some(id) = + start_history_entry(settings, &command, author.as_deref(), intent.as_deref()) + .await? + { + println!("{id}"); + } + + Ok(()) + } + Self::End { id, exit, duration } => { + end_history_entry(settings, &id, exit, duration).await + } + Self::Tail => { + #[cfg(feature = "daemon")] + { + return Self::handle_tail(settings).await; + } + + #[cfg(not(feature = "daemon"))] + bail!("`atuin history tail` requires Atuin to be built with the `daemon` feature"); + } + cmd => { + let context = current_context().await?; + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + match cmd { + Self::List { + session, + cwd, + human, + cmd_only, + print0, + reverse, + timezone, + format, + } => { + let mode = ListMode::from_flags(human, cmd_only); + let tz = timezone.unwrap_or(settings.timezone); + Self::handle_list( + &db, settings, context, session, cwd, mode, format, false, print0, + reverse, tz, + ) + .await + } + + Self::Last { + human, + cmd_only, + timezone, + format, + } => { + let last = db.last().await?; + let last = last.as_slice(); + let tz = timezone.unwrap_or(settings.timezone); + print_list( + last, + ListMode::from_flags(human, cmd_only), + match format { + None => Some(settings.history_format.as_str()), + _ => format.as_deref(), + }, + false, + true, + tz, + ); + + Ok(()) + } + + Self::InitStore => history_store.init_store(&db).await, + + Self::Prune { dry_run } => { + Self::handle_prune(&db, settings, store, context, dry_run).await + } + + Self::Dedup { + dry_run, + before, + dupkeep, + } => { + let before = i64::try_from( + interim::parse_date_string( + before.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + )? + .unix_timestamp_nanos(), + )?; + Self::handle_dedup(&db, settings, store, before, dupkeep, dry_run).await + } + + Self::Start { .. } | Self::End { .. } | Self::Tail => unreachable!(), + } + } + } + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "daemon")] + use time::macros::datetime; + + use super::*; + + #[test] + fn normalize_command_strips_trailing_spaces_and_tabs() { + let settings = Settings::utc(); + + assert!(settings.strip_trailing_whitespace); + assert_eq!(normalize_command_for_storage("ls \t", &settings), "ls"); + } + + #[test] + fn normalize_command_preserves_escaped_trailing_space() { + let settings = Settings::utc(); + + assert_eq!( + normalize_command_for_storage("printf foo\\ ", &settings), + "printf foo\\ " + ); + assert_eq!( + normalize_command_for_storage("printf foo\\\\ ", &settings), + "printf foo\\\\" + ); + } + + #[tokio::test] + async fn handle_start_saves_trimmed_command() { + let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); + let settings = Settings::utc(); + + handle_start(&db, &settings, "ls \t", None, None) + .await + .unwrap(); + + let history = db + .before(OffsetDateTime::now_utc() + time::Duration::SECOND, 1) + .await + .unwrap() + .pop() + .unwrap(); + assert_eq!(history.command, "ls"); + } + + #[tokio::test] + async fn handle_start_can_keep_trailing_whitespace() { + let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); + let settings = Settings { + strip_trailing_whitespace: false, + ..Settings::utc() + }; + + handle_start(&db, &settings, "ls \t", None, None) + .await + .unwrap(); + + let history = db + .before(OffsetDateTime::now_utc() + time::Duration::SECOND, 1) + .await + .unwrap() + .pop() + .unwrap(); + assert_eq!(history.command, "ls \t"); + } + + #[test] + fn test_format_string_no_panic() { + // Don't panic but provide helpful output (issue #2776) + let malformed_json = r#"{"command":"{command}","key":"value"}"#; + + let result = std::panic::catch_unwind(|| parse_fmt(malformed_json)); + + assert!(result.is_ok()); + } + + #[test] + fn test_valid_formats_still_work() { + assert!(std::panic::catch_unwind(|| parse_fmt("{command}")).is_ok()); + assert!(std::panic::catch_unwind(|| parse_fmt("{time} - {command}")).is_ok()); + } + + #[cfg(feature = "daemon")] + fn sample_tail_event(kind: TailKind) -> TailEvent { + TailEvent { + kind, + history: History { + id: "history-id".to_owned().into(), + timestamp: datetime!(2026-04-09 17:18:19 UTC), + duration: 12_345_678, + exit: 0, + command: "git status".to_owned(), + cwd: "/tmp/repo".to_owned(), + session: "session-id".to_owned(), + hostname: "host:ellie".to_owned(), + author: "claude".to_owned(), + intent: Some("inspect repository state".to_owned()), + deleted_at: None, + }, + } + } + + #[cfg(feature = "daemon")] + #[test] + fn test_tail_json_output_contains_history_fields() { + let json = sample_tail_event(TailKind::Ended) + .render(false, Timezone(time::UtcOffset::UTC)) + .unwrap(); + let value: serde_json::Value = serde_json::from_str(&json).unwrap(); + + assert_eq!(value["event"], "ended"); + assert_eq!(value["history"]["id"], "history-id"); + assert_eq!(value["history"]["duration_ns"], 12_345_678); + assert_eq!(value["history"]["success"], true); + assert!(value.get("record").is_none()); + } + + #[cfg(feature = "daemon")] + #[test] + fn test_tail_pretty_output_shows_pending_fields_for_started_events() { + let rendered = sample_tail_event(TailKind::Started) + .render(true, Timezone(time::UtcOffset::UTC)) + .unwrap(); + let plain = regex::Regex::new(r"\x1b\[[0-9;]*m") + .unwrap() + .replace_all(&rendered, ""); + + assert!(plain.contains("STARTED git status")); + assert!(plain.contains("exit:")); + assert!(plain.contains("pending")); + assert!(plain.contains("duration:")); + assert!(plain.contains("running")); + } +} diff --git a/crates/turtle/src/command/client/import.rs b/crates/turtle/src/command/client/import.rs new file mode 100644 index 00000000..363e6405 --- /dev/null +++ b/crates/turtle/src/command/client/import.rs @@ -0,0 +1,186 @@ +use std::env; + +use async_trait::async_trait; +use clap::Parser; +use eyre::Result; +use indicatif::ProgressBar; + +use crate::atuin_client::{ + database::Database, + history::History, + import::{ + Importer, Loader, bash::Bash, fish::Fish, nu::Nu, nu_histdb::NuHistDb, + powershell::PowerShell, replxx::Replxx, resh::Resh, xonsh::Xonsh, + xonsh_sqlite::XonshSqlite, zsh::Zsh, zsh_histdb::ZshHistDb, + }, +}; + +#[derive(Parser, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Import history for the current shell + Auto, + + /// Import history from the zsh history file + Zsh, + /// Import history from the zsh history file + ZshHistDb, + /// Import history from the bash history file + Bash, + /// Import history from the replxx history file + Replxx, + /// Import history from the resh history file + Resh, + /// Import history from the fish history file + Fish, + /// Import history from the nu history file + Nu, + /// Import history from the nu history file + NuHistDb, + /// Import history from xonsh json files + Xonsh, + /// Import history from xonsh sqlite db + XonshSqlite, + /// Import history from the powershell history file + Powershell, +} + +const BATCH_SIZE: usize = 100; + +impl Cmd { + #[expect(clippy::cognitive_complexity)] + pub async fn run<DB: Database>(&self, db: &DB) -> Result<()> { + println!(" Atuin "); + println!("======================"); + println!(" \u{1f30d} "); + println!(" \u{1f418}\u{1f418}\u{1f418}\u{1f418} "); + println!(" \u{1f422} "); + println!("======================"); + println!("Importing history..."); + + match self { + Self::Auto => { + if cfg!(windows) { + return if env::var("PSModulePath").is_ok() { + println!("Detected PowerShell"); + import::<PowerShell, DB>(db).await + } else { + println!("Could not detect the current shell."); + println!("Please run atuin import <SHELL>."); + println!("To view a list of shells, run atuin import."); + Ok(()) + }; + } + + // $XONSH_HISTORY_BACKEND isn't always set, but $XONSH_HISTORY_FILE is + let xonsh_histfile = + env::var("XONSH_HISTORY_FILE").unwrap_or_else(|_| String::new()); + let shell = env::var("SHELL").unwrap_or_else(|_| String::from("NO_SHELL")); + + if xonsh_histfile.to_lowercase().ends_with(".json") { + println!("Detected Xonsh"); + import::<Xonsh, DB>(db).await + } else if xonsh_histfile.to_lowercase().ends_with(".sqlite") { + println!("Detected Xonsh (SQLite backend)"); + import::<XonshSqlite, DB>(db).await + } else if shell.ends_with("/zsh") { + if ZshHistDb::histpath().is_ok() { + println!( + "Detected Zsh-HistDb, using :{}", + ZshHistDb::histpath().unwrap().to_str().unwrap() + ); + import::<ZshHistDb, DB>(db).await + } else { + println!("Detected ZSH"); + import::<Zsh, DB>(db).await + } + } else if shell.ends_with("/fish") { + println!("Detected Fish"); + import::<Fish, DB>(db).await + } else if shell.ends_with("/bash") { + println!("Detected Bash"); + import::<Bash, DB>(db).await + } else if shell.ends_with("/nu") { + if NuHistDb::histpath().is_ok() { + println!( + "Detected Nu-HistDb, using :{}", + NuHistDb::histpath().unwrap().to_str().unwrap() + ); + import::<NuHistDb, DB>(db).await + } else { + println!("Detected Nushell"); + import::<Nu, DB>(db).await + } + } else if shell.ends_with("/pwsh") { + println!("Detected PowerShell"); + import::<PowerShell, DB>(db).await + } else { + println!("cannot import {shell} history"); + Ok(()) + } + } + + Self::Zsh => import::<Zsh, DB>(db).await, + Self::ZshHistDb => import::<ZshHistDb, DB>(db).await, + Self::Bash => import::<Bash, DB>(db).await, + Self::Replxx => import::<Replxx, DB>(db).await, + Self::Resh => import::<Resh, DB>(db).await, + Self::Fish => import::<Fish, DB>(db).await, + Self::Nu => import::<Nu, DB>(db).await, + Self::NuHistDb => import::<NuHistDb, DB>(db).await, + Self::Xonsh => import::<Xonsh, DB>(db).await, + Self::XonshSqlite => import::<XonshSqlite, DB>(db).await, + Self::Powershell => import::<PowerShell, DB>(db).await, + } + } +} + +pub struct HistoryImporter<'db, DB: Database> { + pb: ProgressBar, + buf: Vec<History>, + db: &'db DB, +} + +impl<'db, DB: Database> HistoryImporter<'db, DB> { + fn new(db: &'db DB, len: usize) -> Self { + Self { + pb: ProgressBar::new(len as u64), + buf: Vec::with_capacity(BATCH_SIZE), + db, + } + } + + async fn flush(self) -> Result<()> { + if !self.buf.is_empty() { + self.db.save_bulk(&self.buf).await?; + } + self.pb.finish(); + Ok(()) + } +} + +#[async_trait] +impl<DB: Database> Loader for HistoryImporter<'_, DB> { + async fn push(&mut self, hist: History) -> Result<()> { + self.pb.inc(1); + self.buf.push(hist); + if self.buf.len() == self.buf.capacity() { + self.db.save_bulk(&self.buf).await?; + self.buf.clear(); + } + Ok(()) + } +} + +async fn import<I: Importer + Send, DB: Database>(db: &DB) -> Result<()> { + println!("Importing history from {}", I::NAME); + + let mut importer = I::new().await?; + let len = importer.entries().await.unwrap(); + let mut loader = HistoryImporter::new(db, len); + importer.load(&mut loader).await?; + loader.flush().await?; + + println!("Import complete!"); + Ok(()) +} diff --git a/crates/turtle/src/command/client/info.rs b/crates/turtle/src/command/client/info.rs new file mode 100644 index 00000000..ee24c419 --- /dev/null +++ b/crates/turtle/src/command/client/info.rs @@ -0,0 +1,31 @@ +use crate::atuin_client::settings::Settings;
+
+use crate::{SHA, VERSION};
+
+pub fn run(settings: &Settings) {
+ let config = crate::atuin_common::utils::config_dir();
+ let mut config_file = config.clone();
+ config_file.push("config.toml");
+ let mut sever_config = config;
+ sever_config.push("server.toml");
+
+ let config_paths = format!(
+ "Config files:\nclient config: {:?}\nserver config: {:?}\nclient db path: {:?}\nkey path: {:?}\nmeta db path: {:?}",
+ config_file.to_string_lossy(),
+ sever_config.to_string_lossy(),
+ settings.db_path,
+ settings.key_path,
+ settings.meta.db_path
+ );
+
+ let env_vars = format!(
+ "Env Vars:\nATUIN_CONFIG_DIR = {:?}",
+ std::env::var("ATUIN_CONFIG_DIR").unwrap_or_else(|_| "None".into())
+ );
+
+ let general_info = format!("Version info:\nversion: {VERSION}\ncommit: {SHA}");
+
+ let print_out = format!("{config_paths}\n\n{env_vars}\n\n{general_info}");
+
+ println!("{print_out}");
+}
diff --git a/crates/turtle/src/command/client/init.rs b/crates/turtle/src/command/client/init.rs new file mode 100644 index 00000000..bf9747bb --- /dev/null +++ b/crates/turtle/src/command/client/init.rs @@ -0,0 +1,127 @@ +use crate::atuin_client::settings::{Settings, Tmux}; +use clap::{Parser, ValueEnum}; + +mod bash; +mod fish; +mod powershell; +mod xonsh; +mod zsh; + +#[derive(Parser, Debug)] +pub struct Cmd { + shell: Shell, + + /// Disable the binding of CTRL-R to atuin + #[clap(long)] + disable_ctrl_r: bool, + + /// Disable the binding of the Up Arrow key to atuin + #[clap(long)] + disable_up_arrow: bool, + + /// Disable the binding of ? to Atuin AI + #[clap(long)] + disable_ai: bool, +} + +#[derive(Clone, Copy, ValueEnum, Debug)] +#[value(rename_all = "lower")] +#[expect(clippy::enum_variant_names, clippy::doc_markdown)] +pub enum Shell { + /// Zsh setup + Zsh, + /// Bash setup + Bash, + /// Fish setup + Fish, + /// Nu setup + Nu, + /// Xonsh setup + Xonsh, + /// PowerShell setup + PowerShell, +} + +impl Cmd { + fn init_nu(&self, _tmux: &Tmux) { + let full = include_str!("../../shell/atuin.nu"); + + // TODO: tmux popup for Nu + println!("{full}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + const BIND_CTRL_R: &str = r"$env.config = ( + $env.config | upsert keybindings ( + $env.config.keybindings + | append { + name: atuin + modifier: control + keycode: char_r + mode: [emacs, vi_normal, vi_insert] + event: { send: executehostcommand cmd: (_atuin_search_cmd) } + } + ) +)"; + const BIND_UP_ARROW: &str = r" +$env.config = ( + $env.config | upsert keybindings ( + $env.config.keybindings + | append { + name: atuin + modifier: none + keycode: up + mode: [emacs, vi_normal, vi_insert] + event: { + until: [ + {send: menuup} + {send: executehostcommand cmd: (_atuin_search_cmd '--shell-up-key-binding') } + ] + } + } + ) +) +"; + if !self.disable_ctrl_r { + println!("{BIND_CTRL_R}"); + } + if !self.disable_up_arrow { + println!("{BIND_UP_ARROW}"); + } + } + } + + fn static_init(&self, settings: &Settings) { + let tmux = &settings.tmux; + + match self.shell { + Shell::Zsh => { + zsh::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::Bash => { + bash::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::Fish => { + fish::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::Nu => { + self.init_nu(tmux); + } + Shell::Xonsh => { + xonsh::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::PowerShell => { + powershell::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + } + } + + pub fn run(self, settings: &Settings) { + if !settings.paths_ok() { + eprintln!( + "Atuin settings paths are broken. Disabling atuin shell hooks. Run `atuin doctor` to diagnose." + ); + } + + self.static_init(settings); + } +} diff --git a/crates/turtle/src/command/client/init/bash.rs b/crates/turtle/src/command/client/init/bash.rs new file mode 100644 index 00000000..fd17e37e --- /dev/null +++ b/crates/turtle/src/command/client/init/bash.rs @@ -0,0 +1,25 @@ +use crate::atuin_client::settings::Tmux; + +fn print_tmux_config(tmux: &Tmux) { + if tmux.enabled { + println!("export ATUIN_TMUX_POPUP_WIDTH='{}'", tmux.width); + println!("export ATUIN_TMUX_POPUP_HEIGHT='{}'", tmux.height); + } else { + println!("export ATUIN_TMUX_POPUP=false"); + } +} + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.bash"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } else { + (!disable_ctrl_r, !disable_up_arrow) + }; + + print_tmux_config(tmux); + println!("__atuin_bind_ctrl_r={bind_ctrl_r}"); + println!("__atuin_bind_up_arrow={bind_up_arrow}"); + println!("{base}"); +} diff --git a/crates/turtle/src/command/client/init/fish.rs b/crates/turtle/src/command/client/init/fish.rs new file mode 100644 index 00000000..8a046bfa --- /dev/null +++ b/crates/turtle/src/command/client/init/fish.rs @@ -0,0 +1,86 @@ +use crate::atuin_client::settings::Tmux; + +fn print_tmux_config(tmux: &Tmux) { + if tmux.enabled { + println!("set -gx ATUIN_TMUX_POPUP_WIDTH '{}'", tmux.width); + println!("set -gx ATUIN_TMUX_POPUP_HEIGHT '{}'", tmux.height); + } else { + println!("set -gx ATUIN_TMUX_POPUP false"); + } +} + +fn print_bindings( + indent: &str, + disable_up_arrow: bool, + disable_ctrl_r: bool, + bind_ctrl_r: &str, + bind_up_arrow: &str, + bind_ctrl_r_ins: &str, + bind_up_arrow_ins: &str, +) { + if !disable_ctrl_r { + println!("{indent}{bind_ctrl_r}"); + } + if !disable_up_arrow { + println!("{indent}{bind_up_arrow}"); + } + + println!("{indent}if bind -M insert >/dev/null 2>&1"); + if !disable_ctrl_r { + println!("{indent}{indent}{bind_ctrl_r_ins}"); + } + if !disable_up_arrow { + println!("{indent}{indent}{bind_up_arrow_ins}"); + } + println!("{indent}end"); +} + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { + let indent = " ".repeat(4); + + let base = include_str!("../../../shell/atuin.fish"); + + print_tmux_config(tmux); + println!("{base}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + println!("if string match -q '4.*' $version"); + + // In fish 4.0 and above the option bind -k doesn't exist anymore, + // instead we can use key names and modifiers directly. + print_bindings( + &indent, + disable_up_arrow, + disable_ctrl_r, + "bind ctrl-r _atuin_search", + "bind up _atuin_bind_up", + "bind -M insert ctrl-r _atuin_search", + "bind -M insert up _atuin_bind_up", + ); + + println!("else"); + + // We keep these for compatibility with fish 3.x + print_bindings( + &indent, + disable_up_arrow, + disable_ctrl_r, + r"bind \cr _atuin_search", + &[ + r"bind -k up _atuin_bind_up", + r"bind \eOA _atuin_bind_up", + r"bind \e\[A _atuin_bind_up", + ] + .join("; "), + r"bind -M insert \cr _atuin_search", + &[ + r"bind -M insert -k up _atuin_bind_up", + r"bind -M insert \eOA _atuin_bind_up", + r"bind -M insert \e\[A _atuin_bind_up", + ] + .join("; "), + ); + + println!("end"); + } +} diff --git a/crates/turtle/src/command/client/init/powershell.rs b/crates/turtle/src/command/client/init/powershell.rs new file mode 100644 index 00000000..10c0c461 --- /dev/null +++ b/crates/turtle/src/command/client/init/powershell.rs @@ -0,0 +1,23 @@ +use crate::atuin_client::settings::Tmux; + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.ps1"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } else { + (!disable_ctrl_r, !disable_up_arrow) + }; + + // TODO: tmux popup for Powershell + println!("{base}"); + println!( + "Enable-AtuinSearchKeys -CtrlR {} -UpArrow {}", + ps_bool(bind_ctrl_r), + ps_bool(bind_up_arrow) + ); +} + +fn ps_bool(value: bool) -> &'static str { + if value { "$true" } else { "$false" } +} diff --git a/crates/turtle/src/command/client/init/xonsh.rs b/crates/turtle/src/command/client/init/xonsh.rs new file mode 100644 index 00000000..a17d85d8 --- /dev/null +++ b/crates/turtle/src/command/client/init/xonsh.rs @@ -0,0 +1,22 @@ +use crate::atuin_client::settings::Tmux; + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.xsh"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } else { + (!disable_ctrl_r, !disable_up_arrow) + }; + + // TODO: tmux popup for xonsh + println!( + "_ATUIN_BIND_CTRL_R={}", + if bind_ctrl_r { "True" } else { "False" } + ); + println!( + "_ATUIN_BIND_UP_ARROW={}", + if bind_up_arrow { "True" } else { "False" } + ); + println!("{base}"); +} diff --git a/crates/turtle/src/command/client/init/zsh.rs b/crates/turtle/src/command/client/init/zsh.rs new file mode 100644 index 00000000..38c3086b --- /dev/null +++ b/crates/turtle/src/command/client/init/zsh.rs @@ -0,0 +1,38 @@ +use crate::atuin_client::settings::Tmux; + +fn print_tmux_config(tmux: &Tmux) { + if tmux.enabled { + println!("export ATUIN_TMUX_POPUP_WIDTH='{}'", tmux.width); + println!("export ATUIN_TMUX_POPUP_HEIGHT='{}'", tmux.height); + } else { + println!("export ATUIN_TMUX_POPUP=false"); + } +} + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.zsh"); + + print_tmux_config(tmux); + println!("{base}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + const BIND_CTRL_R: &str = r"bindkey -M emacs '^r' atuin-search +bindkey -M viins '^r' atuin-search-viins +bindkey -M vicmd '/' atuin-search"; + + const BIND_UP_ARROW: &str = r"bindkey -M emacs '^[[A' atuin-up-search +bindkey -M vicmd '^[[A' atuin-up-search-vicmd +bindkey -M viins '^[[A' atuin-up-search-viins +bindkey -M emacs '^[OA' atuin-up-search +bindkey -M vicmd '^[OA' atuin-up-search-vicmd +bindkey -M viins '^[OA' atuin-up-search-viins +bindkey -M vicmd 'k' atuin-up-search-vicmd"; + + if !disable_ctrl_r { + println!("{BIND_CTRL_R}"); + } + if !disable_up_arrow { + println!("{BIND_UP_ARROW}"); + } + } +} diff --git a/crates/turtle/src/command/client/search.rs b/crates/turtle/src/command/client/search.rs new file mode 100644 index 00000000..4a2114d5 --- /dev/null +++ b/crates/turtle/src/command/client/search.rs @@ -0,0 +1,375 @@ +use std::fs::File; +use std::io::{IsTerminal as _, Write, stderr, stdout}; + +use crate::atuin_common::utils::{self, Escapable as _}; +use clap::Parser; +use eyre::Result; + +use crate::atuin_client::{ + database::Database, + database::{OptFilters, current_context}, + encryption, + history::{History, store::HistoryStore}, + record::sqlite_store::SqliteStore, + settings::{FilterMode, KeymapMode, SearchMode, Settings, Timezone}, + theme::Theme, +}; + +use super::history::ListMode; + +mod cursor; +mod duration; +mod engines; +mod history_list; +mod inspector; +mod interactive; +pub mod keybindings; + +pub use duration::format_duration_into; + +#[expect(clippy::struct_excessive_bools, clippy::struct_field_names)] +#[derive(Parser, Debug)] +pub struct Cmd { + /// Filter search result by directory + #[arg(long, short)] + cwd: Option<String>, + + /// Exclude directory from results + #[arg(long = "exclude-cwd")] + exclude_cwd: Option<String>, + + /// Filter search result by exit code + #[arg(long, short)] + exit: Option<i64>, + + /// Exclude results with this exit code + #[arg(long = "exclude-exit")] + exclude_exit: Option<i64>, + + /// Only include results added before this date + #[arg(long, short)] + before: Option<String>, + + /// Only include results after this date + #[arg(long)] + after: Option<String>, + + /// How many entries to return at most + #[arg(long)] + limit: Option<i64>, + + /// Offset from the start of the results + #[arg(long)] + offset: Option<i64>, + + /// Open interactive search UI + #[arg(long, short)] + interactive: bool, + + /// Allow overriding filter mode over config + #[arg(long = "filter-mode")] + filter_mode: Option<FilterMode>, + + /// Allow overriding search mode over config + #[arg(long = "search-mode")] + search_mode: Option<SearchMode>, + + /// Marker argument used to inform atuin that it was invoked from a shell up-key binding (hidden from help to avoid confusion) + #[arg(long = "shell-up-key-binding", hide = true)] + shell_up_key_binding: bool, + + /// Notify the keymap at the shell's side + #[arg(long = "keymap-mode", default_value = "auto")] + keymap_mode: KeymapMode, + + /// Use human-readable formatting for time + #[arg(long)] + human: bool, + + #[arg(allow_hyphen_values = true)] + query: Option<Vec<String>>, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Terminate the output with a null, for better multiline handling + #[arg(long)] + print0: bool, + + /// Delete anything matching this query. Will not print out the match + #[arg(long)] + delete: bool, + + /// Delete EVERYTHING! + #[arg(long)] + delete_it_all: bool, + + /// Reverse the order of results, oldest first + #[arg(long, short)] + reverse: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + #[arg(allow_hyphen_values = true)] + // Clippy warns about `Option<Option<T>>`, but we suppress it because we need + // this distinction for proper argument handling. + #[expect(clippy::option_option)] + timezone: Option<Option<Timezone>>, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {time}, {exit} and + /// {relativetime}. + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option<String>, + + /// Set the maximum number of lines Atuin's interface should take up. + #[arg(long = "inline-height")] + inline_height: Option<u16>, + + /// Filter by author. Supports $all-user (non-agents), $all-agent, or literal names. + /// Can be specified multiple times. + #[arg(long)] + author: Option<Vec<String>>, + + /// Include duplicate commands in the output (non-interactive only) + #[arg(long)] + include_duplicates: bool, + + /// File name to write the result to (hidden from help as this is meant to be used from a script) + #[arg(long = "result-file", hide = true)] + result_file: Option<String>, +} + +impl Cmd { + /// Returns true if this search command will run in interactive (TUI) mode + pub fn is_interactive(&self) -> bool { + self.interactive + } + + // clippy: please write this instead + // clippy: now it has too many lines + // me: I'll do it later OKAY + #[expect(clippy::too_many_lines)] + pub async fn run( + self, + db: impl Database, + settings: &mut Settings, + store: SqliteStore, + theme: &Theme, + ) -> Result<()> { + let query = self.query.unwrap_or_else(|| { + std::env::var("ATUIN_QUERY").map_or_else( + |_| vec![], + |query| { + query + .split(' ') + .map(std::string::ToString::to_string) + .collect() + }, + ) + }); + + if (self.delete_it_all || self.delete) && self.limit.is_some() { + // Because of how deletion is implemented, it will always delete all matches + // and disregard the limit option. It is also not clear what deletion with a + // limit would even mean. Deleting the LIMIT most recent entries that match + // the search query would make sense, but that wouldn't match what's displayed + // when running the equivalent search, but deleting those entries that are + // displayed with the search would leave any duplicates of those lines which may + // or may not have been intended to be deleted. + eprintln!("\"--limit\" is not compatible with deletion."); + return Ok(()); + } + + if self.delete && query.is_empty() { + eprintln!( + "Please specify a query to match the items you wish to delete. If you wish to delete all history, pass --delete-it-all" + ); + return Ok(()); + } + + if self.delete_it_all && !query.is_empty() { + eprintln!( + "--delete-it-all will delete ALL of your history! It does not require a query." + ); + return Ok(()); + } + + if let Some(search_mode) = self.search_mode { + settings.search_mode = search_mode; + } + if let Some(filter_mode) = self.filter_mode { + settings.filter_mode = Some(filter_mode); + } + if let Some(inline_height) = self.inline_height { + settings.inline_height = inline_height; + } + + settings.shell_up_key_binding = self.shell_up_key_binding; + + // `keymap_mode` specified in config.toml overrides the `--keymap-mode` + // option specified in the keybindings. + settings.keymap_mode = match settings.keymap_mode { + KeymapMode::Auto => self.keymap_mode, + value => value, + }; + settings.keymap_mode_shell = self.keymap_mode; + + let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + if self.interactive { + let item = interactive::history(&query, settings, db, &history_store, theme).await?; + + if let Some(result_file) = self.result_file { + let mut file = File::create(result_file)?; + write!(file, "{item}")?; + } else if !stdout().is_terminal() { + // stdout is not a terminal - likely command substitution like VAR=$(atuin search -i) + // Write to stdout so it gets captured. This requires some care on Windows, as the current + // console code page or `[Console]::OutputEncoding` on PowerShell may be different from UTF-8. + println!("{item}"); + } else if stderr().is_terminal() { + eprintln!("{}", item.escape_control()); + } else { + eprintln!("{item}"); + } + } else { + let opt_filter = OptFilters { + exit: self.exit, + exclude_exit: self.exclude_exit, + cwd: self.cwd, + exclude_cwd: self.exclude_cwd, + before: self.before, + after: self.after, + limit: self.limit, + offset: self.offset, + reverse: self.reverse, + include_duplicates: self.include_duplicates, + authors: self.author.clone().unwrap_or_default(), + }; + + let mut entries = + run_non_interactive(settings, opt_filter.clone(), &query, &db).await?; + + if entries.is_empty() { + std::process::exit(1) + } + + // if we aren't deleting, print it all + if self.delete || self.delete_it_all { + // delete it + // it only took me _years_ to add this + // sorry + while !entries.is_empty() { + for entry in &entries { + eprintln!("deleting {}", entry.id); + } + + let ids = history_store.delete_entries(entries).await?; + history_store.incremental_build(&db, &ids).await?; + + entries = + run_non_interactive(settings, opt_filter.clone(), &query, &db).await?; + } + } else { + let format = match self.format { + None => Some(settings.history_format.as_str()), + _ => self.format.as_deref(), + }; + let tz = match self.timezone { + Some(Some(tz)) => tz, // User provided a value + Some(None) | None => settings.timezone, // No value was provided + }; + + super::history::print_list( + &entries, + ListMode::from_flags(self.human, self.cmd_only), + format, + self.print0, + true, + tz, + ); + } + } + Ok(()) + } +} + +// This is supposed to more-or-less mirror the command line version, so ofc +// it is going to have a lot of args +#[expect(clippy::too_many_arguments, clippy::cast_possible_truncation)] +async fn run_non_interactive( + settings: &Settings, + filter_options: OptFilters, + query: &[String], + db: &impl Database, +) -> Result<Vec<History>> { + let dir = if filter_options.cwd.as_deref() == Some(".") { + Some(utils::get_current_dir()) + } else { + filter_options.cwd + }; + + let context = current_context().await?; + + let opt_filter = OptFilters { + cwd: dir.clone(), + ..filter_options + }; + + let filter_mode = settings.default_filter_mode(context.git_root.is_some()); + + let results = db + .search( + settings.search_mode, + filter_mode, + &context, + query.join(" ").as_str(), + opt_filter, + ) + .await?; + + Ok(results) +} + +#[cfg(test)] +mod tests { + use super::Cmd; + use clap::Parser; + + #[test] + fn search_for_triple_dash() { + // Issue #3028: searching for `---` should not be treated as a CLI flag + let cmd = Cmd::try_parse_from(["search", "---"]); + assert!(cmd.is_ok(), "Failed to parse '---' as a query: {cmd:?}"); + let cmd = cmd.unwrap(); + assert_eq!(cmd.query, Some(vec!["---".to_string()])); + } + + #[test] + fn search_for_double_dash_value() { + // Searching for strings starting with -- should also work + let cmd = Cmd::try_parse_from(["search", "--", "--foo"]); + assert!(cmd.is_ok()); + let cmd = cmd.unwrap(); + assert_eq!(cmd.query, Some(vec!["--foo".to_string()])); + } + + #[test] + fn search_author_cli_flag() { + let cmd = + Cmd::try_parse_from(["search", "--author", "codex", "--author", "ellie"]).unwrap(); + assert_eq!( + cmd.author, + Some(vec!["codex".to_string(), "ellie".to_string()]) + ); + } +} diff --git a/crates/turtle/src/command/client/search/cursor.rs b/crates/turtle/src/command/client/search/cursor.rs new file mode 100644 index 00000000..84f94082 --- /dev/null +++ b/crates/turtle/src/command/client/search/cursor.rs @@ -0,0 +1,405 @@ +use crate::atuin_client::settings::WordJumpMode; + +pub struct Cursor { + source: String, + index: usize, +} + +impl From<String> for Cursor { + fn from(source: String) -> Self { + Self { source, index: 0 } + } +} + +pub struct WordJumper<'a> { + word_chars: &'a str, + word_jump_mode: WordJumpMode, +} + +impl WordJumper<'_> { + fn is_word_boundary(&self, c: char, next_c: char) -> bool { + (c.is_whitespace() && !next_c.is_whitespace()) + || (!c.is_whitespace() && next_c.is_whitespace()) + || (self.word_chars.contains(c) && !self.word_chars.contains(next_c)) + || (!self.word_chars.contains(c) && self.word_chars.contains(next_c)) + } + + fn emacs_get_next_word_pos(&self, source: &str, index: usize) -> usize { + let index = (index + 1..source.len().saturating_sub(1)) + .find(|&i| self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(source.len()); + (index + 1..source.len().saturating_sub(1)) + .find(|&i| !self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(source.len()) + } + + fn emacs_get_prev_word_pos(&self, source: &str, index: usize) -> usize { + let index = (1..index) + .rev() + .find(|&i| self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(0); + (1..index) + .rev() + .find(|&i| !self.word_chars.contains(source.chars().nth(i).unwrap())) + .map_or(0, |i| i + 1) + } + + fn subl_get_next_word_pos(&self, source: &str, index: usize) -> usize { + let index = (index..source.len().saturating_sub(1)).find(|&i| { + self.is_word_boundary( + source.chars().nth(i).unwrap(), + source.chars().nth(i + 1).unwrap(), + ) + }); + if index.is_none() { + return source.len(); + } + (index.unwrap() + 1..source.len()) + .find(|&i| !source.chars().nth(i).unwrap().is_whitespace()) + .unwrap_or(source.len()) + } + + fn subl_get_prev_word_pos(&self, source: &str, index: usize) -> usize { + let index = (1..index) + .rev() + .find(|&i| !source.chars().nth(i).unwrap().is_whitespace()); + if index.is_none() { + return 0; + } + (1..index.unwrap()) + .rev() + .find(|&i| { + self.is_word_boundary( + source.chars().nth(i - 1).unwrap(), + source.chars().nth(i).unwrap(), + ) + }) + .unwrap_or(0) + } + + fn get_next_word_pos(&self, source: &str, index: usize) -> usize { + match self.word_jump_mode { + WordJumpMode::Emacs => self.emacs_get_next_word_pos(source, index), + WordJumpMode::Subl => self.subl_get_next_word_pos(source, index), + } + } + + fn get_prev_word_pos(&self, source: &str, index: usize) -> usize { + match self.word_jump_mode { + WordJumpMode::Emacs => self.emacs_get_prev_word_pos(source, index), + WordJumpMode::Subl => self.subl_get_prev_word_pos(source, index), + } + } +} + +impl Cursor { + pub fn as_str(&self) -> &str { + self.source.as_str() + } + + pub fn into_inner(self) -> String { + self.source + } + + /// Returns the string before the cursor + pub fn substring(&self) -> &str { + &self.source[..self.index] + } + + /// Returns the currently selected [`char`] + pub fn char(&self) -> Option<char> { + self.source[self.index..].chars().next() + } + + pub fn right(&mut self) { + if self.index < self.source.len() { + loop { + self.index += 1; + if self.source.is_char_boundary(self.index) { + break; + } + } + } + } + + pub fn left(&mut self) -> bool { + if self.index > 0 { + loop { + self.index -= 1; + if self.source.is_char_boundary(self.index) { + break true; + } + } + } else { + false + } + } + + pub fn next_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + self.index = word_jumper.get_next_word_pos(&self.source, self.index); + } + + pub fn prev_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + self.index = word_jumper.get_prev_word_pos(&self.source, self.index); + } + + /// Move cursor to the end of the current/next word (vim `e` motion). + /// + /// If cursor is in the middle of a word, moves to the end of that word. + /// If cursor is at the end of a word (or on whitespace), moves to the + /// end of the next word. + pub fn word_end(&mut self, word_chars: &str) { + let len = self.source.len(); + if self.index >= len { + return; + } + + let chars: Vec<char> = self.source.chars().collect(); + let mut char_idx = self.source[..self.index].chars().count(); + + if char_idx >= chars.len() { + return; + } + + let current = chars[char_idx]; + + // Check if we're at a word boundary (end of current word or on whitespace) + let at_word_boundary = current.is_whitespace() || char_idx + 1 >= chars.len() || { + let next = chars[char_idx + 1]; + next.is_whitespace() || (word_chars.contains(current) != word_chars.contains(next)) + }; + + // If at word boundary, advance past it and skip whitespace to find next word + if at_word_boundary { + char_idx += 1; + while char_idx < chars.len() && chars[char_idx].is_whitespace() { + char_idx += 1; + } + } + + // If we've gone past end, go to end of string + if char_idx >= chars.len() { + self.index = len; + return; + } + + // Find end of word: advance until next char is whitespace or different word type + let in_word_chars = word_chars.contains(chars[char_idx]); + while char_idx < chars.len() { + let next_idx = char_idx + 1; + if next_idx >= chars.len() { + // At last char, move past it + char_idx = next_idx; + break; + } + let next_c = chars[next_idx]; + if next_c.is_whitespace() || (word_chars.contains(next_c) != in_word_chars) { + // Next char is start of new word/whitespace, so current char is end + char_idx = next_idx; + break; + } + char_idx += 1; + } + + // Convert char index back to byte index + self.index = chars.iter().take(char_idx).map(|c| c.len_utf8()).sum(); + } + + pub fn insert(&mut self, c: char) { + self.source.insert(self.index, c); + self.index += c.len_utf8(); + } + + pub fn remove(&mut self) -> Option<char> { + if self.index < self.source.len() { + Some(self.source.remove(self.index)) + } else { + None + } + } + + pub fn remove_next_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + let next_index = word_jumper.get_next_word_pos(&self.source, self.index); + self.source.replace_range(self.index..next_index, ""); + } + + pub fn remove_prev_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + let next_index = word_jumper.get_prev_word_pos(&self.source, self.index); + self.source.replace_range(next_index..self.index, ""); + self.index = next_index; + } + + pub fn back(&mut self) -> Option<char> { + if self.left() { self.remove() } else { None } + } + + pub fn clear(&mut self) { + self.source.clear(); + self.index = 0; + } + + pub fn clear_to_start(&mut self) { + self.source.replace_range(..self.index, ""); + self.index = 0; + } + + pub fn clear_to_end(&mut self) { + self.source.replace_range(self.index.., ""); + self.index = self.source.len(); + } + + pub fn end(&mut self) { + self.index = self.source.len(); + } + + pub fn start(&mut self) { + self.index = 0; + } + + pub fn position(&self) -> usize { + self.index + } +} + +#[cfg(test)] +mod cursor_tests { + use super::Cursor; + use super::*; + + static EMACS_WORD_JUMPER: WordJumper = WordJumper { + word_chars: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", + word_jump_mode: WordJumpMode::Emacs, + }; + + static SUBL_WORD_JUMPER: WordJumper = WordJumper { + word_chars: "./\\()\"'-:,.;<>~!@#$%^&*|+=[]{}`~?", + word_jump_mode: WordJumpMode::Subl, + }; + + #[test] + fn right() { + // ö is 2 bytes + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + let indices = [0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 20, 20, 20]; + for i in indices { + assert_eq!(c.index, i); + c.right(); + } + } + + #[test] + fn left() { + // ö is 2 bytes + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + c.end(); + let indices = [20, 18, 17, 15, 14, 12, 11, 9, 8, 6, 5, 3, 2, 0, 0, 0, 0]; + for i in indices { + assert_eq!(c.index, i); + c.left(); + } + } + + #[test] + fn test_emacs_get_next_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(0, 6), (3, 6), (7, 18), (19, 30)]; + for (i_src, i_dest) in indices { + assert_eq!(EMACS_WORD_JUMPER.get_next_word_pos(&s, i_src), i_dest); + } + assert_eq!(EMACS_WORD_JUMPER.get_next_word_pos("", 0), 0); + } + + #[test] + fn test_emacs_get_prev_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(30, 15), (29, 15), (15, 3), (3, 0)]; + for (i_src, i_dest) in indices { + assert_eq!(EMACS_WORD_JUMPER.get_prev_word_pos(&s, i_src), i_dest); + } + assert_eq!(EMACS_WORD_JUMPER.get_prev_word_pos("", 0), 0); + } + + #[test] + fn test_subl_get_next_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(0, 3), (1, 3), (3, 9), (9, 15), (15, 21), (21, 30)]; + for (i_src, i_dest) in indices { + assert_eq!(SUBL_WORD_JUMPER.get_next_word_pos(&s, i_src), i_dest); + } + assert_eq!(SUBL_WORD_JUMPER.get_next_word_pos("", 0), 0); + } + + #[test] + fn test_subl_get_prev_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(30, 21), (21, 15), (15, 9), (9, 3), (3, 0)]; + for (i_src, i_dest) in indices { + assert_eq!(SUBL_WORD_JUMPER.get_prev_word_pos(&s, i_src), i_dest); + } + assert_eq!(SUBL_WORD_JUMPER.get_prev_word_pos("", 0), 0); + } + + #[test] + fn pop() { + let mut s = String::from("öaöböcödöeöfö"); + let mut c = Cursor::from(s.clone()); + c.end(); + while !s.is_empty() { + let c1 = s.pop(); + let c2 = c.back(); + assert_eq!(c1, c2); + assert_eq!(s.as_str(), c.substring()); + } + let c1 = s.pop(); + let c2 = c.back(); + assert_eq!(c1, c2); + } + + #[test] + fn back() { + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + // move to ^ + for _ in 0..4 { + c.right(); + } + assert_eq!(c.substring(), "öaöb"); + assert_eq!(c.back(), Some('b')); + assert_eq!(c.back(), Some('ö')); + assert_eq!(c.back(), Some('a')); + assert_eq!(c.back(), Some('ö')); + assert_eq!(c.back(), None); + assert_eq!(c.as_str(), "öcödöeöfö"); + } + + #[test] + fn insert() { + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + // move to ^ + for _ in 0..4 { + c.right(); + } + assert_eq!(c.substring(), "öaöb"); + c.insert('ö'); + c.insert('g'); + c.insert('ö'); + c.insert('h'); + assert_eq!(c.substring(), "öaöbögöh"); + assert_eq!(c.as_str(), "öaöbögöhöcödöeöfö"); + } +} diff --git a/crates/turtle/src/command/client/search/duration.rs b/crates/turtle/src/command/client/search/duration.rs new file mode 100644 index 00000000..54856c87 --- /dev/null +++ b/crates/turtle/src/command/client/search/duration.rs @@ -0,0 +1,65 @@ +use core::fmt; +use std::{ops::ControlFlow, time::Duration}; + +#[expect(clippy::module_name_repetitions)] +pub fn format_duration_into(dur: Duration, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn item(unit: &'static str, value: u64) -> ControlFlow<(&'static str, u64)> { + if value > 0 { + ControlFlow::Break((unit, value)) + } else { + ControlFlow::Continue(()) + } + } + + // impl taken and modified from + // https://github.com/tailhook/humantime/blob/master/src/duration.rs#L295-L331 + // Copyright (c) 2016 The humantime Developers + fn fmt(f: Duration) -> ControlFlow<(&'static str, u64), ()> { + let secs = f.as_secs(); + let nanos = f.subsec_nanos(); + + let years = secs / 31_557_600; // 365.25d + let year_days = secs % 31_557_600; + let months = year_days / 2_630_016; // 30.44d + let month_days = year_days % 2_630_016; + let days = month_days / 86400; + let day_secs = month_days % 86400; + let hours = day_secs / 3600; + let minutes = day_secs % 3600 / 60; + let seconds = day_secs % 60; + + let millis = nanos / 1_000_000; + let micros = nanos / 1_000; + + // a difference from our impl than the original is that + // we only care about the most-significant segment of the duration. + // If the item call returns `Break`, then the `?` will early-return. + // This allows for a very consise impl + item("y", years)?; + item("mo", months)?; + item("d", days)?; + item("h", hours)?; + item("m", minutes)?; + item("s", seconds)?; + item("ms", u64::from(millis))?; + item("us", u64::from(micros))?; + item("ns", u64::from(nanos))?; + ControlFlow::Continue(()) + } + + match fmt(dur) { + ControlFlow::Break((unit, value)) => write!(f, "{value}{unit}"), + ControlFlow::Continue(()) => write!(f, "0s"), + } +} + +#[expect(clippy::module_name_repetitions)] +pub fn format_duration(f: Duration) -> String { + struct F(Duration); + impl fmt::Display for F { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + format_duration_into(self.0, f) + } + } + F(f).to_string() +} diff --git a/crates/turtle/src/command/client/search/engines.rs b/crates/turtle/src/command/client/search/engines.rs new file mode 100644 index 00000000..0f92b4c7 --- /dev/null +++ b/crates/turtle/src/command/client/search/engines.rs @@ -0,0 +1,95 @@ +use async_trait::async_trait; +use crate::atuin_client::{ + database::{Context, Database, OptFilters}, + history::{AUTHOR_FILTER_ALL_USER, History, HistoryId}, + settings::{FilterMode, SearchMode, Settings}, +}; +use eyre::Result; + +use super::cursor::Cursor; + +#[cfg(feature = "daemon")] +pub mod daemon; +pub mod db; +pub mod skim; + +#[expect(unused)] // settings is only used if daemon feature is enabled +pub fn engine(search_mode: SearchMode, settings: &Settings) -> Box<dyn SearchEngine> { + match search_mode { + SearchMode::Skim => Box::new(skim::Search::new()) as Box<_>, + #[cfg(feature = "daemon")] + SearchMode::DaemonFuzzy => Box::new(daemon::Search::new(settings)) as Box<_>, + #[cfg(not(feature = "daemon"))] + SearchMode::DaemonFuzzy => { + // Fall back to fuzzy mode if daemon feature is not enabled + Box::new(db::Search(SearchMode::Fuzzy)) as Box<_> + } + mode => Box::new(db::Search(mode)) as Box<_>, + } +} + +pub struct SearchState { + pub input: Cursor, + pub filter_mode: FilterMode, + pub context: Context, + pub custom_context: Option<HistoryId>, +} + +impl SearchState { + pub(crate) fn rotate_filter_mode(&mut self, settings: &Settings, offset: isize) { + let mut i = settings + .search + .filters + .iter() + .position(|&m| m == self.filter_mode) + .unwrap_or_default(); + for _ in 0..settings.search.filters.len() { + i = (i.wrapping_add_signed(offset)) % settings.search.filters.len(); + let mode = settings.search.filters[i]; + if self.filter_mode_available(mode, settings) { + self.filter_mode = mode; + break; + } + } + } + + fn filter_mode_available(&self, mode: FilterMode, settings: &Settings) -> bool { + match mode { + FilterMode::Global | FilterMode::SessionPreload => self.custom_context.is_none(), + FilterMode::Workspace => settings.workspaces && self.context.git_root.is_some(), + _ => true, + } + } +} + +#[async_trait] +pub trait SearchEngine: Send + Sync + 'static { + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>>; + + async fn query(&mut self, state: &SearchState, db: &mut dyn Database) -> Result<Vec<History>> { + if state.input.as_str().is_empty() { + Ok(db + .search( + SearchMode::FullText, + state.filter_mode, + &state.context, + "", + OptFilters { + limit: Some(200), + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + ..Default::default() + }, + ) + .await? + .into_iter() + .collect::<Vec<_>>()) + } else { + self.full_query(state, db).await + } + } + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize>; +} diff --git a/crates/turtle/src/command/client/search/engines/daemon.rs b/crates/turtle/src/command/client/search/engines/daemon.rs new file mode 100644 index 00000000..b1299c02 --- /dev/null +++ b/crates/turtle/src/command/client/search/engines/daemon.rs @@ -0,0 +1,242 @@ +use crate::atuin_client::{ + database::{Database, OptFilters}, + history::{AUTHOR_FILTER_ALL_USER, History}, + settings::{SearchMode, Settings}, +}; +use crate::atuin_daemon::client::{DaemonClientErrorKind, SearchClient, classify_error}; +use async_trait::async_trait; +use atuin_nucleo_matcher::{ + Config, Matcher, Utf32Str, + pattern::{CaseMatching, Normalization, Pattern}, +}; +use eyre::Result; +use tracing::{Level, debug, instrument, span}; +use uuid::Uuid; + +use super::{SearchEngine, SearchState}; +use crate::command::client::daemon; + +pub struct Search { + client: Option<SearchClient>, + query_id: u64, + settings: Settings, + #[cfg(unix)] + socket_path: String, +} + +impl Search { + pub fn new(settings: &Settings) -> Self { + Search { + client: None, + query_id: 0, + settings: settings.clone(), + #[cfg(unix)] + socket_path: settings.daemon.socket_path.clone(), + } + } + + #[instrument(skip_all, level = Level::TRACE, name = "get_daemon_client")] + async fn get_client(&mut self) -> Result<&mut SearchClient> { + if self.client.is_none() { + self.connect().await?; + } + Ok(self.client.as_mut().unwrap()) + } + + async fn connect(&mut self) -> Result<()> { + #[cfg(unix)] + let client = SearchClient::new(self.socket_path.clone()).await?; + + self.client = Some(client); + Ok(()) + } + + fn should_retry(err: &eyre::Report) -> bool { + matches!( + classify_error(err), + DaemonClientErrorKind::Connect + | DaemonClientErrorKind::Unavailable + | DaemonClientErrorKind::Unimplemented + ) + } + + fn next_query_id(&mut self) -> u64 { + self.query_id += 1; + self.query_id + } + + /// Check if query contains regex pattern (r/.../) + /// Nucleo doesn't support regex, so we fall back to database search + fn contains_regex_pattern(query: &str) -> bool { + query.starts_with("r/") || query.contains(" r/") + } + + #[instrument(skip_all, level = Level::TRACE, name = "daemon_db_fallback")] + async fn fallback_to_db_search( + &self, + state: &SearchState, + db: &dyn Database, + ) -> Result<Vec<History>> { + let results = db + .search( + SearchMode::FullText, + state.filter_mode, + &state.context, + state.input.as_str(), + OptFilters { + limit: Some(200), + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + ..Default::default() + }, + ) + .await + .map_or(Vec::new(), |r| r.into_iter().collect()); + Ok(results) + } + + #[instrument(skip_all, level = Level::TRACE, name = "hydrate_from_db", fields(count = ids.len()))] + async fn hydrate_from_db(&self, db: &dyn Database, ids: &[String]) -> Result<Vec<History>> { + let placeholders: Vec<String> = ids.iter().map(|id| format!("'{id}'")).collect(); + let sql_query = format!( + "SELECT * FROM history WHERE id IN ({}) ORDER BY timestamp DESC", + placeholders.join(",") + ); + Ok(db.query_history(&sql_query).await?) + } +} + +#[async_trait] +impl SearchEngine for Search { + #[instrument(skip_all, level = Level::TRACE, name = "daemon_search", fields(query = %state.input.as_str()))] + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>> { + let query = state.input.as_str().to_string(); + + // Fall back to database for regex queries (Nucleo doesn't support regex) + if Self::contains_regex_pattern(&query) { + debug!(query = %query, "[daemon-client] regex detected, falling back to db"); + return self.fallback_to_db_search(state, db).await; + } + + let query_id = self.next_query_id(); + + let span = + span!(Level::TRACE, "daemon_search.req_resp", query = %query, query_id = query_id); + + // Try to connect and search; if it fails with a retriable error, + // auto-start the daemon and retry once. + let first_attempt = async { + let client = self.get_client().await?; + client + .search( + query.clone(), + query_id, + state.filter_mode, + Some(state.context.clone()), + ) + .await + } + .await; + + let mut stream = match first_attempt { + Ok(stream) => stream, + Err(err) if self.settings.daemon.autostart && Self::should_retry(&err) => { + debug!("daemon not available, attempting auto-start"); + self.client = None; + + daemon::ensure_daemon_running(&self.settings).await?; + + let client = self.get_client().await?; + client + .search( + query.clone(), + query_id, + state.filter_mode, + Some(state.context.clone()), + ) + .await? + } + Err(err) => return Err(err), + }; + + let mut ids = Vec::with_capacity(200); + span!(Level::TRACE, "daemon_search.resp") + .in_scope(async || { + while let Ok(Some(response)) = stream.message().await { + let span2 = span!( + Level::TRACE, + "daemon_search.resp.item", + query_id = response.query_id + ); + let _span2 = span2.enter(); + // Only process if the query_id matches (prevents stale responses) + if response.query_id == query_id { + let uuids = response + .ids + .iter() + .map(|id| { + let bytes: [u8; 16] = + id.as_slice().try_into().expect("id should be 16 bytes"); + Uuid::from_bytes(bytes).as_simple().to_string() + }) + .collect::<Vec<_>>(); + ids.extend(uuids); + } + drop(_span2); + drop(span2); + } + }) + .await; + drop(span); + + if ids.is_empty() { + debug!(query = %query, results = 0, "[daemon-client] empty results"); + return Ok(Vec::new()); + } + + // // Hydrate from local database + let results = self.hydrate_from_db(db, &ids).await?; + + // // Reorder results to match the order from the daemon (which is ranked by relevance) + let ordered_results = span!(Level::TRACE, "reorder_results").in_scope(|| { + let mut ordered_results = Vec::with_capacity(results.len()); + for id in &ids { + if let Some(history) = results.iter().find(|h| h.id.0 == *id) { + ordered_results.push(history.clone()); + } + } + ordered_results + }); + + debug!( + query = %query, + results = results.len(), + "[daemon-client]" + ); + + Ok(ordered_results) + } + + #[instrument(skip_all, level = Level::TRACE, name = "daemon_highlight")] + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> { + // Use fulltext highlighting for regex queries + if Self::contains_regex_pattern(search_input) { + return super::db::get_highlight_indices_fulltext(command, search_input); + } + + let mut matcher = Matcher::new(Config::DEFAULT); + let pattern = Pattern::parse(search_input, CaseMatching::Smart, Normalization::Smart); + + let mut indices: Vec<u32> = Vec::new(); + let mut haystack_buf = Vec::new(); + + let haystack = Utf32Str::new(command, &mut haystack_buf); + pattern.indices(haystack, &mut matcher, &mut indices); + + // Convert u32 indices to usize + indices.into_iter().map(|i| i as usize).collect() + } +} diff --git a/crates/turtle/src/command/client/search/engines/db.rs b/crates/turtle/src/command/client/search/engines/db.rs new file mode 100644 index 00000000..2765faf5 --- /dev/null +++ b/crates/turtle/src/command/client/search/engines/db.rs @@ -0,0 +1,110 @@ +use super::{SearchEngine, SearchState}; +use async_trait::async_trait; +use crate::atuin_client::{ + database::Database, + database::OptFilters, + database::{QueryToken, QueryTokenizer}, + history::{AUTHOR_FILTER_ALL_USER, History}, + settings::SearchMode, +}; +use eyre::Result; +use norm::Metric; +use norm::fzf::{FzfParser, FzfV2}; +use std::ops::Range; +use tracing::{Level, instrument}; + +pub struct Search(pub SearchMode); + +#[async_trait] +impl SearchEngine for Search { + #[instrument(skip_all, level = Level::TRACE, name = "db_search", fields(mode = ?self.0, query = %state.input.as_str()))] + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>> { + let results = db + .search( + self.0, + state.filter_mode, + &state.context, + state.input.as_str(), + OptFilters { + limit: Some(200), + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + ..Default::default() + }, + ) + .await + // ignore errors as it may be caused by incomplete regex + .map_or(Vec::new(), |r| r.into_iter().collect()); + Ok(results) + } + + #[instrument(skip_all, level = Level::TRACE, name = "db_highlight")] + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> { + if self.0 == SearchMode::Prefix { + return vec![]; + } else if self.0 == SearchMode::FullText { + return get_highlight_indices_fulltext(command, search_input); + } + let mut fzf = FzfV2::new(); + let mut parser = FzfParser::new(); + let query = parser.parse(search_input); + let mut ranges: Vec<Range<usize>> = Vec::new(); + let _ = fzf.distance_and_ranges(query, command, &mut ranges); + + // convert ranges to all indices + ranges.into_iter().flatten().collect() + } +} + +#[instrument(skip_all, level = Level::TRACE, name = "db_highlight_fulltext")] +pub fn get_highlight_indices_fulltext(command: &str, search_input: &str) -> Vec<usize> { + let mut ranges = vec![]; + let lower_command = command.to_ascii_lowercase(); + + for token in QueryTokenizer::new(search_input) { + let matchee = if token.has_uppercase() { + command + } else { + &lower_command + }; + + if token.is_inverse() { + continue; + } + + match token { + QueryToken::Or => {} + QueryToken::Regex(r) => { + if let Ok(re) = regex::Regex::new(r) { + for m in re.find_iter(command) { + ranges.push(m.range()); + } + } + } + QueryToken::MatchStart(term, _) => { + if matchee.starts_with(term) { + ranges.push(0..term.len()); + } + } + QueryToken::MatchEnd(term, _) => { + if matchee.ends_with(term) { + let l = matchee.len(); + ranges.push((l - term.len())..l); + } + } + QueryToken::Match(term, _) | QueryToken::MatchFull(term, _) => { + for (idx, m) in matchee.match_indices(term) { + ranges.push(idx..(idx + m.len())); + } + } + } + } + + let mut ret: Vec<_> = ranges.into_iter().flatten().collect(); + ret.sort_unstable(); + ret.dedup(); + ret +} diff --git a/crates/turtle/src/command/client/search/engines/skim.rs b/crates/turtle/src/command/client/search/engines/skim.rs new file mode 100644 index 00000000..96a6574d --- /dev/null +++ b/crates/turtle/src/command/client/search/engines/skim.rs @@ -0,0 +1,229 @@ +use std::path::Path; + +use async_trait::async_trait; +use crate::atuin_client::{ + database::Database, + history::{History, is_known_agent}, + settings::FilterMode, +}; +use eyre::Result; +use fuzzy_matcher::{FuzzyMatcher, skim::SkimMatcherV2}; +use itertools::Itertools; +use time::OffsetDateTime; +use tokio::task::yield_now; +use tracing::{Level, instrument, warn}; +use uuid; + +use super::{SearchEngine, SearchState}; + +pub struct Search { + all_history: Vec<(History, i32)>, + engine: SkimMatcherV2, +} + +impl Search { + pub fn new() -> Self { + Search { + all_history: vec![], + engine: SkimMatcherV2::default(), + } + } +} + +#[async_trait] +impl SearchEngine for Search { + #[instrument(skip_all, level = Level::TRACE, name = "skim_search", fields(query = %state.input.as_str()))] + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>> { + if self.all_history.is_empty() { + self.all_history = load_all_history(db).await; + } + + Ok(fuzzy_search(&self.engine, state, &self.all_history).await) + } + + #[instrument(skip_all, level = Level::TRACE, name = "skim_highlight")] + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> { + let (_, indices) = self + .engine + .fuzzy_indices(command, search_input) + .unwrap_or_default(); + indices + } +} + +#[instrument(skip_all, level = Level::TRACE, name = "load_all_history")] +async fn load_all_history(db: &dyn Database) -> Vec<(History, i32)> { + db.all_with_count().await.unwrap() +} + +#[expect(clippy::too_many_lines)] +#[instrument(skip_all, level = Level::TRACE, name = "fuzzy_match", fields(history_count = all_history.len()))] +async fn fuzzy_search( + engine: &SkimMatcherV2, + state: &SearchState, + all_history: &[(History, i32)], +) -> Vec<History> { + let mut set = Vec::with_capacity(200); + let mut ranks = Vec::with_capacity(200); + let query = state.input.as_str(); + let now = OffsetDateTime::now_utc(); + + for (i, (history, count)) in all_history.iter().enumerate() { + if i % 256 == 0 { + yield_now().await; + } + if is_known_agent(&history.author) { + continue; + } + let context = &state.context; + let git_root = context + .git_root + .as_ref() + .and_then(|git_root| git_root.to_str()) + .unwrap_or(&context.cwd); + match state.filter_mode { + FilterMode::Global => {} + // we aggregate host by ',' separating them + FilterMode::Host + if history + .hostname + .split(',') + .contains(&context.hostname.as_str()) => {} + // we aggregate session by concattenating them. + // sessions are 32 byte simple uuid formats + FilterMode::Session + if history + .session + .as_bytes() + .chunks(32) + .contains(&context.session.as_bytes()) => {} + // SessionPreload: include current session + global history from before session start + FilterMode::SessionPreload => { + let is_current_session = { + history + .session + .as_bytes() + .chunks(32) + .any(|chunk| chunk == context.session.as_bytes()) + }; + + if !is_current_session { + let Ok(uuid) = uuid::Uuid::parse_str(&context.session) else { + warn!("failed to parse session id '{}'", context.session); + continue; + }; + let Some(timestamp) = uuid.get_timestamp() else { + warn!( + "failed to get timestamp from uuid '{}'", + uuid.as_hyphenated() + ); + continue; + }; + let (seconds, nanos) = timestamp.to_unix(); + let Ok(session_start) = time::OffsetDateTime::from_unix_timestamp_nanos( + i128::from(seconds) * 1_000_000_000 + i128::from(nanos), + ) else { + warn!( + "failed to create OffsetDateTime from second: {seconds}, nanosecond: {nanos}" + ); + continue; + }; + + if history.timestamp >= session_start { + continue; + } + } + } + // we aggregate directory by ':' separating them + FilterMode::Directory if history.cwd.split(':').contains(&context.cwd.as_str()) => {} + FilterMode::Workspace if history.cwd.split(':').contains(&git_root) => {} + _ => continue, + } + #[expect(clippy::cast_lossless, clippy::cast_precision_loss)] + if let Some((score, indices)) = engine.fuzzy_indices(&history.command, query) { + let begin = indices.first().copied().unwrap_or_default(); + + let mut duration = (now - history.timestamp).as_seconds_f64().log2(); + if !duration.is_finite() || duration <= 1.0 { + duration = 1.0; + } + // these + X.0 just make the log result a bit smoother. + // log is very spiky towards 1-4, but I want a gradual decay. + // eg: + // log2(4) = 2, log2(5) = 2.3 (16% increase) + // log2(8) = 3, log2(9) = 3.16 (5% increase) + // log2(16) = 4, log2(17) = 4.08 (2% increase) + let count = (*count as f64 + 8.0).log2(); + let begin = (begin as f64 + 16.0).log2(); + let path = path_dist(history.cwd.as_ref(), state.context.cwd.as_ref()); + let path = (path as f64 + 8.0).log2(); + + // reduce longer durations, raise higher counts, raise matches close to the start + let score = (-score as f64) * count / path / duration / begin; + + 'insert: { + // algorithm: + // 1. find either the position that this command ranks + // 2. find the same command positioned better than our rank. + for i in 0..set.len() { + // do we out score the current position? + if ranks[i] > score { + ranks.insert(i, score); + set.insert(i, history.clone()); + let mut j = i + 1; + while j < set.len() { + // remove duplicates that have a worse score + if set[j].command == history.command { + ranks.remove(j); + set.remove(j); + + // break this while loop because there won't be any other + // duplicates. + break; + } + j += 1; + } + + // keep it limited + if ranks.len() > 200 { + ranks.pop(); + set.pop(); + } + + break 'insert; + } + // don't continue if this command has a better score already + if set[i].command == history.command { + break 'insert; + } + } + + if set.len() < 200 { + ranks.push(score); + set.push(history.clone()); + } + } + } + } + + set +} + +fn path_dist(a: &Path, b: &Path) -> usize { + let mut a: Vec<_> = a.components().collect(); + let b: Vec<_> = b.components().collect(); + + let mut dist = 0; + + // pop a until there's a common ancestor + while !b.starts_with(&a) { + dist += 1; + a.pop(); + } + + b.len() - a.len() + dist +} diff --git a/crates/turtle/src/command/client/search/history_list.rs b/crates/turtle/src/command/client/search/history_list.rs new file mode 100644 index 00000000..4c83d7eb --- /dev/null +++ b/crates/turtle/src/command/client/search/history_list.rs @@ -0,0 +1,429 @@ +use std::time::Duration; + +use super::duration::format_duration; +use super::engines::SearchEngine; +use crate::atuin_client::{ + history::History, + settings::{UiColumn, UiColumnType}, + theme::{Meaning, Theme}, +}; +use crate::atuin_common::utils::Escapable as _; +use itertools::Itertools; +use ratatui::{ + backend::FromCrossterm, + buffer::Buffer, + crossterm::style, + layout::Rect, + style::{Modifier, Style}, + widgets::{Block, StatefulWidget, Widget}, +}; +use time::OffsetDateTime; + +pub struct HistoryHighlighter<'a> { + pub engine: &'a dyn SearchEngine, + pub search_input: &'a str, +} + +impl HistoryHighlighter<'_> { + pub fn get_highlight_indices(&self, command: &str) -> Vec<usize> { + self.engine + .get_highlight_indices(command, self.search_input) + } +} + +pub struct HistoryList<'a> { + history: &'a [History], + block: Option<Block<'a>>, + inverted: bool, + /// Apply an alternative highlighting to the selected row + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + /// Columns to display (in order, after the indicator) + columns: &'a [UiColumn], +} + +#[derive(Default)] +pub struct ListState { + offset: usize, + selected: usize, + max_entries: usize, +} + +impl ListState { + pub fn selected(&self) -> usize { + self.selected + } + + pub fn max_entries(&self) -> usize { + self.max_entries + } + + pub fn offset(&self) -> usize { + self.offset + } + + pub fn select(&mut self, index: usize) { + self.selected = index; + } +} + +impl StatefulWidget for HistoryList<'_> { + type State = ListState; + + fn render(mut self, area: Rect, buf: &mut Buffer, state: &mut Self::State) { + let list_area = self.block.take().map_or(area, |b| { + let inner_area = b.inner(area); + b.render(area, buf); + inner_area + }); + + if list_area.width < 1 || list_area.height < 1 || self.history.is_empty() { + return; + } + let list_height = list_area.height as usize; + + let (start, end) = self.get_items_bounds(state.selected, state.offset, list_height); + state.offset = start; + state.max_entries = end - start; + + let mut s = DrawState { + buf, + list_area, + x: 0, + y: 0, + state, + inverted: self.inverted, + alternate_highlight: self.alternate_highlight, + now: &self.now, + indicator: self.indicator, + theme: self.theme, + history_highlighter: self.history_highlighter, + show_numeric_shortcuts: self.show_numeric_shortcuts, + columns: self.columns, + }; + + for item in self.history.iter().skip(state.offset).take(end - start) { + s.render_row(item); + + // reset line + s.y += 1; + s.x = 0; + } + } +} + +impl<'a> HistoryList<'a> { + #[expect(clippy::too_many_arguments)] + pub fn new( + history: &'a [History], + inverted: bool, + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + columns: &'a [UiColumn], + ) -> Self { + Self { + history, + block: None, + inverted, + alternate_highlight, + now, + indicator, + theme, + history_highlighter, + show_numeric_shortcuts, + columns, + } + } + + pub fn block(mut self, block: Block<'a>) -> Self { + self.block = Some(block); + self + } + + fn get_items_bounds(&self, selected: usize, offset: usize, height: usize) -> (usize, usize) { + let offset = offset.min(self.history.len().saturating_sub(1)); + + let max_scroll_space = height.min(10).min(self.history.len() - selected); + if offset + height < selected + max_scroll_space { + let end = selected + max_scroll_space; + (end - height, end) + } else if selected < offset { + (selected, selected + height) + } else { + (offset, offset + height) + } + } +} + +struct DrawState<'a> { + buf: &'a mut Buffer, + list_area: Rect, + x: u16, + y: u16, + state: &'a ListState, + inverted: bool, + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + columns: &'a [UiColumn], +} + +// these encode the slices of `" > "`, `" {n} "`, or `" "` in a compact form. +// Yes, this is a hack, but it makes me feel happy +static SLICES: &str = " > 1 2 3 4 5 6 7 8 9 "; + +impl DrawState<'_> { + /// Render a complete row for a history item based on configured columns. + fn render_row(&mut self, h: &History) { + // Always render the indicator first (width 3) + self.index(); + + // Calculate the width for the expanding column + // Fixed columns use their configured width + 1 (trailing space) + let indicator_width: u16 = 3; + let fixed_width: u16 = self + .columns + .iter() + .filter(|c| !c.expand) + .map(|c| c.width + 1) + .sum(); + let expand_width = self + .list_area + .width + .saturating_sub(indicator_width + fixed_width); + + let style = self.theme.as_style(Meaning::Base); + // Render each configured column + for (idx, column) in self.columns.iter().enumerate() { + if idx != 0 { + self.draw(" ", Style::from_crossterm(style)); + } + let width = if column.expand { + expand_width + } else { + column.width + }; + match column.column_type { + UiColumnType::Duration => self.duration(h, width), + UiColumnType::Time => self.time(h, width), + UiColumnType::Datetime => self.datetime(h, width), + UiColumnType::Directory => self.directory(h, width), + UiColumnType::Host => self.host(h, width), + UiColumnType::User => self.user(h, width), + UiColumnType::Exit => self.exit_code(h, width), + UiColumnType::Command => self.command(h), + } + } + } + + fn index(&mut self) { + if !self.show_numeric_shortcuts { + let i = self.y as usize + self.state.offset; + let is_selected = i == self.state.selected(); + let prompt: &str = if is_selected { self.indicator } else { " " }; + self.draw(prompt, Style::default()); + return; + } + + // these encode the slices of `" > "`, `" {n} "`, or `" "` in a compact form. + // Yes, this is a hack, but it makes me feel happy + + let i = self.y as usize + self.state.offset; + let i = i.checked_sub(self.state.selected); + let i = i.unwrap_or(10).min(10) * 2; + let prompt: &str = if i == 0 { + self.indicator + } else { + &SLICES[i..i + 3] + }; + self.draw(prompt, Style::default()); + } + + fn duration(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(if h.success() { + Meaning::AlertInfo + } else { + Meaning::AlertError + }); + let duration = Duration::from_nanos(u64::try_from(h.duration).unwrap_or(0)); + let formatted = format_duration(duration); + let w = width as usize; + // Right-align duration within its column width, plus trailing space + let display = format!("{formatted:>w$}"); + self.draw(&display, Style::from_crossterm(style)); + } + + fn time(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Guidance); + + // Account for the chance that h.timestamp is "in the future" + // This would mean that "since" is negative, and the unwrap here + // would fail. + // If the timestamp would otherwise be in the future, display + // the time since as 0. + let since = (self.now)() - h.timestamp; + let time = format_duration(since.try_into().unwrap_or_default()); + + // Format as "Xs ago" right-aligned within column width + let w = width as usize; + let time_str = format!("{time} ago"); + + let display = format!("{time_str:>w$}"); + self.draw(&display, Style::from_crossterm(style)); + } + + fn command(&mut self, h: &History) { + let mut style = self.theme.as_style(Meaning::Base); + let mut row_highlighted = false; + if !self.alternate_highlight && (self.y as usize + self.state.offset == self.state.selected) + { + row_highlighted = true; + // if not applying alternative highlighting to the whole row, color the command + style = self.theme.as_style(Meaning::AlertError); + style.attributes.set(style::Attribute::Bold); + } + + let highlight_indices = self.history_highlighter.get_highlight_indices( + h.command + .escape_control() + .split_ascii_whitespace() + .join(" ") + .as_str(), + ); + + let mut pos = 0; + for section in h.command.escape_control().split_ascii_whitespace() { + if pos != 0 { + self.draw(" ", Style::from_crossterm(style)); + } + for ch in section.chars() { + if self.x > self.list_area.width { + // Avoid attempting to draw a command section beyond the width + // of the list + return; + } + let mut style = style; + if highlight_indices.contains(&pos) { + if row_highlighted { + // if the row is highlighted bold is not enough as the whole row is bold + // change the color too + style = self.theme.as_style(Meaning::AlertWarn); + } + style.attributes.set(style::Attribute::Bold); + } + let s = ch.to_string(); + self.draw(&s, Style::from_crossterm(style)); + pos += s.len(); + } + pos += 1; + } + } + + /// Render the absolute datetime column (e.g., "2025-01-22 14:35") + fn datetime(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + // Format: YYYY-MM-DD HH:MM + let formatted = h + .timestamp + .format( + &time::format_description::parse("[year]-[month]-[day] [hour]:[minute]") + .expect("valid format"), + ) + .unwrap_or_else(|_| "????-??-?? ??:??".to_string()); + let w = width as usize; + let display = format!("{formatted:w$}"); + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the directory column (working directory, truncated) + fn directory(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + let w = width as usize; + let cwd = &h.cwd; + let char_count = cwd.chars().count(); + // Truncate from the left with "..." if too long, plus trailing space + // Use character count for comparison and skip for UTF-8 safety + let display = if char_count > w && w >= 4 { + let truncated: String = cwd.chars().skip(char_count - (w - 3)).collect(); + format!("...{truncated}") + } else { + format!("{cwd:w$}") + }; + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the host column (just the hostname) + fn host(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + let w = width as usize; + // Database stores hostname as "hostname:username" + let host = h.hostname.split(':').next().unwrap_or(&h.hostname); + let char_count = host.chars().count(); + // Use character count for comparison and take for UTF-8 safety + let display = if char_count > w && w >= 4 { + let truncated: String = host.chars().take(w.saturating_sub(4)).collect(); + format!("{truncated}...") + } else { + format!("{host:w$}") + }; + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the user column + fn user(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + let w = width as usize; + // Database stores hostname as "hostname:username" + let user = h.hostname.split(':').nth(1).unwrap_or(""); + let char_count = user.chars().count(); + // Use character count for comparison and take for UTF-8 safety + let display = if char_count > w && w >= 4 { + let truncated: String = user.chars().take(w.saturating_sub(4)).collect(); + format!("{truncated}...") + } else { + format!("{user:w$}") + }; + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the exit code column + fn exit_code(&mut self, h: &History, width: u16) { + let style = if h.success() { + self.theme.as_style(Meaning::AlertInfo) + } else { + self.theme.as_style(Meaning::AlertError) + }; + let w = width as usize; + let display = format!("{:>w$}", h.exit); + self.draw(&display, Style::from_crossterm(style)); + } + + fn draw(&mut self, s: &str, mut style: Style) { + let cx = self.list_area.left() + self.x; + + let cy = if self.inverted { + self.list_area.top() + self.y + } else { + self.list_area.bottom() - self.y - 1 + }; + + if self.alternate_highlight && (self.y as usize + self.state.offset == self.state.selected) + { + style = style.add_modifier(Modifier::REVERSED); + } + + let w = (self.list_area.width - self.x) as usize; + self.x += self.buf.set_stringn(cx, cy, s, w, style).0 - cx; + } +} diff --git a/crates/turtle/src/command/client/search/inspector.rs b/crates/turtle/src/command/client/search/inspector.rs new file mode 100644 index 00000000..1ebc4383 --- /dev/null +++ b/crates/turtle/src/command/client/search/inspector.rs @@ -0,0 +1,421 @@ +use std::time::Duration; +use time::macros::format_description; + +use crate::atuin_client::{ + history::{History, HistoryStats}, + settings::{Settings, Timezone}, +}; +use ratatui::{ + Frame, + backend::FromCrossterm, + layout::Rect, + prelude::{Constraint, Direction, Layout}, + style::Style, + text::{Span, Text}, + widgets::{Bar, BarChart, BarGroup, Block, Borders, Padding, Paragraph, Row, Table}, +}; + +use super::duration::format_duration; + +use super::super::theme::{Meaning, Theme}; +use super::interactive::{Compactness, to_compactness}; + +#[expect(clippy::cast_sign_loss)] +fn u64_or_zero(num: i64) -> u64 { + if num < 0 { 0 } else { num as u64 } +} + +pub fn draw_commands( + f: &mut Frame<'_>, + parent: Rect, + history: &History, + stats: &HistoryStats, + compact: bool, + theme: &Theme, +) { + let commands = Layout::default() + .direction(if compact { + Direction::Vertical + } else { + Direction::Horizontal + }) + .constraints(if compact { + [ + Constraint::Length(1), + Constraint::Length(1), + Constraint::Min(0), + ] + } else { + [ + Constraint::Ratio(1, 4), + Constraint::Ratio(1, 2), + Constraint::Ratio(1, 4), + ] + }) + .split(parent); + + let command = Paragraph::new(Text::from(Span::styled( + history.command.clone(), + Style::from_crossterm(theme.as_style(Meaning::Important)), + ))) + .block(if compact { + Block::new() + .borders(Borders::NONE) + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + } else { + Block::new() + .borders(Borders::ALL) + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .title("Command") + .padding(Padding::horizontal(1)) + }); + + let previous = Paragraph::new( + stats + .previous + .clone() + .map_or_else(|| "[No previous command]".to_string(), |prev| prev.command), + ) + .block(if compact { + Block::new() + .borders(Borders::NONE) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + } else { + Block::new() + .borders(Borders::ALL) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + .title("Previous command") + .padding(Padding::horizontal(1)) + }); + + // Add [] around blank text, as when this is shown in a list + // compacted, it makes it more obviously control text. + let next = Paragraph::new( + stats + .next + .clone() + .map_or_else(|| "[No next command]".to_string(), |next| next.command), + ) + .block(if compact { + Block::new() + .borders(Borders::NONE) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + } else { + Block::new() + .borders(Borders::ALL) + .title("Next command") + .padding(Padding::horizontal(1)) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + }); + + f.render_widget(previous, commands[0]); + f.render_widget(command, commands[1]); + f.render_widget(next, commands[2]); +} + +pub fn draw_stats_table( + f: &mut Frame<'_>, + parent: Rect, + history: &History, + tz: Timezone, + stats: &HistoryStats, + theme: &Theme, +) { + let duration = Duration::from_nanos(u64_or_zero(history.duration)); + let avg_duration = Duration::from_nanos(stats.average_duration); + let (host, user) = history.hostname.split_once(':').unwrap_or(("", "")); + + let rows = [ + Row::new(vec!["Host".to_string(), host.to_string()]), + Row::new(vec!["User".to_string(), user.to_string()]), + Row::new(vec![ + "Time".to_string(), + history.timestamp.to_offset(tz.0).to_string(), + ]), + Row::new(vec!["Duration".to_string(), format_duration(duration)]), + Row::new(vec![ + "Avg duration".to_string(), + format_duration(avg_duration), + ]), + Row::new(vec!["Exit".to_string(), history.exit.to_string()]), + Row::new(vec!["Directory".to_string(), history.cwd.clone()]), + Row::new(vec!["Session".to_string(), history.session.clone()]), + Row::new(vec!["Total runs".to_string(), stats.total.to_string()]), + ]; + + let widths = [Constraint::Ratio(1, 5), Constraint::Ratio(4, 5)]; + + let table = Table::new(rows, widths).column_spacing(1).block( + Block::default() + .title("Command stats") + .borders(Borders::ALL) + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .padding(Padding::vertical(1)), + ); + + f.render_widget(table, parent); +} + +fn num_to_day(num: &str) -> String { + match num { + "0" => "Sunday".to_string(), + "1" => "Monday".to_string(), + "2" => "Tuesday".to_string(), + "3" => "Wednesday".to_string(), + "4" => "Thursday".to_string(), + "5" => "Friday".to_string(), + "6" => "Saturday".to_string(), + _ => "Invalid day".to_string(), + } +} + +fn sort_duration_over_time(durations: &[(String, i64)]) -> Vec<(String, i64)> { + let format = format_description!("[day]-[month]-[year]"); + let output = format_description!("[month]/[year repr:last_two]"); + + let mut durations: Vec<(time::Date, i64)> = durations + .iter() + .map(|d| { + ( + time::Date::parse(d.0.as_str(), &format).expect("invalid date string from sqlite"), + d.1, + ) + }) + .collect(); + + durations.sort_by_key(|a| a.0); + + durations + .iter() + .map(|(date, duration)| { + ( + date.format(output).expect("failed to format sqlite date"), + *duration, + ) + }) + .collect() +} + +fn draw_stats_charts(f: &mut Frame<'_>, parent: Rect, stats: &HistoryStats, theme: &Theme) { + let exits: Vec<Bar> = stats + .exits + .iter() + .map(|(exit, count)| { + Bar::default() + .label(exit.to_string()) + .value(u64_or_zero(*count)) + }) + .collect(); + + let exits = BarChart::default() + .block( + Block::default() + .title("Exit code distribution") + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .borders(Borders::ALL), + ) + .bar_width(3) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&exits)); + + let day_of_week: Vec<Bar> = stats + .day_of_week + .iter() + .map(|(day, count)| { + Bar::default() + .label(num_to_day(day.as_str())) + .value(u64_or_zero(*count)) + }) + .collect(); + + let day_of_week = BarChart::default() + .block( + Block::default() + .title("Runs per day") + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .borders(Borders::ALL), + ) + .bar_width(3) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&day_of_week)); + + let duration_over_time = sort_duration_over_time(&stats.duration_over_time); + let duration_over_time: Vec<Bar> = duration_over_time + .iter() + .map(|(date, duration)| { + let d = Duration::from_nanos(u64_or_zero(*duration)); + Bar::default() + .label(date.clone()) + .value(u64_or_zero(*duration)) + .text_value(format_duration(d)) + }) + .collect(); + + let duration_over_time = BarChart::default() + .block( + Block::default() + .title("Duration over time") + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .borders(Borders::ALL), + ) + .bar_width(5) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&duration_over_time)); + + let layout = Layout::default() + .direction(Direction::Vertical) + .constraints([ + Constraint::Ratio(1, 3), + Constraint::Ratio(1, 3), + Constraint::Ratio(1, 3), + ]) + .split(parent); + + f.render_widget(exits, layout[0]); + f.render_widget(day_of_week, layout[1]); + f.render_widget(duration_over_time, layout[2]); +} + +pub fn draw( + f: &mut Frame<'_>, + chunk: Rect, + history: &History, + stats: &HistoryStats, + settings: &Settings, + theme: &Theme, + tz: Timezone, +) { + let compactness = to_compactness(f, settings); + + match compactness { + Compactness::Ultracompact => draw_ultracompact(f, chunk, history, stats, theme), + _ => draw_full(f, chunk, history, stats, theme, tz), + } +} + +pub fn draw_ultracompact( + f: &mut Frame<'_>, + chunk: Rect, + history: &History, + stats: &HistoryStats, + theme: &Theme, +) { + draw_commands(f, chunk, history, stats, true, theme); +} + +pub fn draw_full( + f: &mut Frame<'_>, + chunk: Rect, + history: &History, + stats: &HistoryStats, + theme: &Theme, + tz: Timezone, +) { + let vert_layout = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Ratio(1, 5), Constraint::Ratio(4, 5)]) + .split(chunk); + + let stats_layout = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Ratio(1, 3), Constraint::Ratio(2, 3)]) + .split(vert_layout[1]); + + draw_commands(f, vert_layout[0], history, stats, false, theme); + draw_stats_table(f, stats_layout[0], history, tz, stats, theme); + draw_stats_charts(f, stats_layout[1], stats, theme); +} + +#[cfg(test)] +mod tests { + use super::draw_ultracompact; + use crate::atuin_client::{ + history::{History, HistoryId, HistoryStats}, + theme::ThemeManager, + }; + use ratatui::{backend::TestBackend, prelude::*}; + use time::OffsetDateTime; + + fn mock_history_stats() -> (History, HistoryStats) { + let history = History { + id: HistoryId::from("test1".to_string()), + timestamp: OffsetDateTime::now_utc(), + duration: 3, + exit: 0, + command: "/bin/cmd".to_string(), + cwd: "/toot".to_string(), + session: "sesh1".to_string(), + hostname: "hostn".to_string(), + author: "hostn".to_string(), + intent: None, + deleted_at: None, + }; + let next = History { + id: HistoryId::from("test2".to_string()), + timestamp: OffsetDateTime::now_utc(), + duration: 2, + exit: 0, + command: "/bin/cmd -os".to_string(), + cwd: "/toot".to_string(), + session: "sesh1".to_string(), + hostname: "hostn".to_string(), + author: "hostn".to_string(), + intent: None, + deleted_at: None, + }; + let prev = History { + id: HistoryId::from("test3".to_string()), + timestamp: OffsetDateTime::now_utc(), + duration: 1, + exit: 0, + command: "/bin/cmd -a".to_string(), + cwd: "/toot".to_string(), + session: "sesh1".to_string(), + hostname: "hostn".to_string(), + author: "hostn".to_string(), + intent: None, + deleted_at: None, + }; + let stats = HistoryStats { + next: Some(next.clone()), + previous: Some(prev.clone()), + total: 2, + average_duration: 3, + exits: Vec::new(), + day_of_week: Vec::new(), + duration_over_time: Vec::new(), + }; + (history, stats) + } + + #[test] + fn test_output_looks_correct_for_ultracompact() { + let backend = TestBackend::new(22, 5); + let mut terminal = Terminal::new(backend).expect("Could not create terminal"); + let chunk = Rect::new(0, 0, 22, 5); + let (history, stats) = mock_history_stats(); + let prev = stats.previous.clone().unwrap(); + let next = stats.next.clone().unwrap(); + + let mut manager = ThemeManager::new(Some(true), Some("".to_string())); + let theme = manager.load_theme("(none)", None); + let _ = terminal.draw(|f| draw_ultracompact(f, chunk, &history, &stats, &theme)); + let mut lines = [" "; 5].map(|l| Line::from(l)); + for (n, entry) in [prev, history, next].iter().enumerate() { + let mut l = lines[n].to_string(); + l.replace_range(0..entry.command.len(), &entry.command); + lines[n] = Line::from(l); + } + + terminal.backend().assert_buffer_lines(lines); + } +} diff --git a/crates/turtle/src/command/client/search/interactive.rs b/crates/turtle/src/command/client/search/interactive.rs new file mode 100644 index 00000000..a3d2cb79 --- /dev/null +++ b/crates/turtle/src/command/client/search/interactive.rs @@ -0,0 +1,3041 @@ +use std::{ + io::{IsTerminal, Write, stdout}, + time::Duration, +}; + +#[cfg(unix)] +use std::io::Read as _; + +use crate::atuin_common::{shell::Shell, utils::Escapable as _}; +use eyre::Result; +use time::OffsetDateTime; +use unicode_width::{UnicodeWidthChar, UnicodeWidthStr}; + +use super::{ + cursor::Cursor, + engines::{SearchEngine, SearchState}, + history_list::{HistoryList, ListState}, +}; +use crate::atuin_client::{ + database::{Context, Database, current_context}, + history::{History, HistoryId, HistoryStats, store::HistoryStore}, + settings::{ + CursorStyle, ExitMode, FilterMode, KeymapMode, PreviewStrategy, SearchMode, Settings, + UiColumn, + }, +}; + +use crate::command::client::search::history_list::HistoryHighlighter; +use crate::command::client::search::keybindings::KeymapSet; +use crate::command::client::theme::{Meaning, Theme}; +use crate::{VERSION, command::client::search::engines}; + +use ratatui::{ + Frame, Terminal, TerminalOptions, Viewport, + backend::{CrosstermBackend, FromCrossterm}, + crossterm::{ + cursor::SetCursorStyle, + event::{self, Event, KeyEvent, MouseEvent}, + execute, queue, terminal, + }, + layout::{Alignment, Constraint, Direction, Layout}, + prelude::*, + style::{Modifier, Style}, + text::{Line, Span, Text}, + widgets::{Block, BorderType, Borders, Clear, Padding, Paragraph, Tabs}, +}; + +#[cfg(not(target_os = "windows"))] +use ratatui::crossterm::event::{ + KeyboardEnhancementFlags, PopKeyboardEnhancementFlags, PushKeyboardEnhancementFlags, +}; + +const TAB_TITLES: [&str; 2] = ["Search", "Inspect"]; + +pub enum InputAction { + Accept(usize), + AcceptInspecting, + Copy(usize), + Delete(usize), + DeleteAllMatching(usize), + ReturnOriginal, + ReturnQuery, + Continue, + Redraw, + SwitchContext(Option<usize>), +} + +#[derive(Clone)] +pub struct InspectingState { + current: Option<HistoryId>, + next: Option<HistoryId>, + previous: Option<HistoryId>, +} + +impl InspectingState { + pub fn move_to_previous(&mut self) { + let previous = self.previous.clone(); + self.reset(); + self.current = previous; + } + + pub fn move_to_next(&mut self) { + let next = self.next.clone(); + self.reset(); + self.current = next; + } + + pub fn reset(&mut self) { + self.current = None; + self.next = None; + self.previous = None; + } +} + +pub fn to_compactness(f: &Frame, settings: &Settings) -> Compactness { + if match settings.style { + crate::atuin_client::settings::Style::Auto => f.area().height < 14, + crate::atuin_client::settings::Style::Compact => true, + crate::atuin_client::settings::Style::Full => false, + } { + if settings.auto_hide_height != 0 && f.area().height <= settings.auto_hide_height { + Compactness::Ultracompact + } else { + Compactness::Compact + } + } else { + Compactness::Full + } +} + +#[expect(clippy::struct_field_names)] +#[expect(clippy::struct_excessive_bools)] +pub struct State { + history_count: i64, + results_state: ListState, + switched_search_mode: bool, + search_mode: SearchMode, + results_len: usize, + accept: bool, + keymap_mode: KeymapMode, + prefix: bool, + current_cursor: Option<CursorStyle>, + tab_index: usize, + pending_vim_key: Option<char>, + original_input_empty: bool, + + pub inspecting_state: InspectingState, + + keymaps: KeymapSet, + search: SearchState, + engine: Box<dyn SearchEngine>, + now: Box<dyn Fn() -> OffsetDateTime + Send>, +} + +#[derive(Clone, Copy)] +pub enum Compactness { + Ultracompact, + Compact, + Full, +} + +#[derive(Clone, Copy)] +struct StyleState { + compactness: Compactness, + invert: bool, + inner_width: usize, +} + +impl State { + async fn query_results( + &mut self, + db: &mut dyn Database, + smart_sort: bool, + ) -> Result<Vec<History>> { + let results = self.engine.query(&self.search, db).await?; + + self.inspecting_state = InspectingState { + current: None, + next: None, + previous: None, + }; + self.results_state.select(0); + self.results_len = results.len(); + + if smart_sort { + Ok(crate::atuin_history::sort::sort( + self.search.input.as_str(), + results, + )) + } else { + Ok(results) + } + } + + fn handle_input(&mut self, settings: &Settings, input: &Event) -> InputAction { + match input { + Event::Key(k) => self.handle_key_input(settings, k), + Event::Mouse(m) => self.handle_mouse_input(*m, settings.invert), + Event::Paste(d) => self.handle_paste_input(d), + _ => InputAction::Continue, + } + } + + fn handle_mouse_input(&mut self, input: MouseEvent, inverted: bool) -> InputAction { + match (input.kind, inverted) { + (event::MouseEventKind::ScrollDown, false) + | (event::MouseEventKind::ScrollUp, true) => { + self.scroll_down(1); + } + (event::MouseEventKind::ScrollDown, true) + | (event::MouseEventKind::ScrollUp, false) => { + self.scroll_up(1); + } + _ => {} + } + InputAction::Continue + } + + fn handle_paste_input(&mut self, input: &str) -> InputAction { + for i in input.chars() { + self.search.input.insert(i); + } + InputAction::Continue + } + + fn cast_cursor_style(style: CursorStyle) -> SetCursorStyle { + match style { + CursorStyle::DefaultUserShape => SetCursorStyle::DefaultUserShape, + CursorStyle::BlinkingBlock => SetCursorStyle::BlinkingBlock, + CursorStyle::SteadyBlock => SetCursorStyle::SteadyBlock, + CursorStyle::BlinkingUnderScore => SetCursorStyle::BlinkingUnderScore, + CursorStyle::SteadyUnderScore => SetCursorStyle::SteadyUnderScore, + CursorStyle::BlinkingBar => SetCursorStyle::BlinkingBar, + CursorStyle::SteadyBar => SetCursorStyle::SteadyBar, + } + } + + fn set_keymap_cursor(&mut self, settings: &Settings, keymap_name: &str) { + let cursor_style = if keymap_name == "__clear__" { + None + } else { + settings.keymap_cursor.get(keymap_name).copied() + } + .or_else(|| self.current_cursor.map(|_| CursorStyle::DefaultUserShape)); + + if cursor_style != self.current_cursor + && let Some(style) = cursor_style + { + self.current_cursor = cursor_style; + let _ = execute!(stdout(), Self::cast_cursor_style(style)); + } + } + + pub fn initialize_keymap_cursor(&mut self, settings: &Settings) { + match self.keymap_mode { + KeymapMode::Emacs => self.set_keymap_cursor(settings, "emacs"), + KeymapMode::VimNormal => self.set_keymap_cursor(settings, "vim_normal"), + KeymapMode::VimInsert => self.set_keymap_cursor(settings, "vim_insert"), + KeymapMode::Auto => {} + } + } + + pub fn finalize_keymap_cursor(&mut self, settings: &Settings) { + match settings.keymap_mode_shell { + KeymapMode::Emacs => self.set_keymap_cursor(settings, "emacs"), + KeymapMode::VimNormal => self.set_keymap_cursor(settings, "vim_normal"), + KeymapMode::VimInsert => self.set_keymap_cursor(settings, "vim_insert"), + KeymapMode::Auto => self.set_keymap_cursor(settings, "__clear__"), + } + } + + fn handle_key_exit(settings: &Settings) -> InputAction { + match settings.exit_mode { + ExitMode::ReturnOriginal => InputAction::ReturnOriginal, + ExitMode::ReturnQuery => InputAction::ReturnQuery, + } + } + + /// Select the keymap for the current mode (ignoring prefix). + fn mode_keymap(&self) -> &super::keybindings::Keymap { + if self.tab_index == 1 { + &self.keymaps.inspector + } else { + match self.keymap_mode { + KeymapMode::Emacs | KeymapMode::Auto => &self.keymaps.emacs, + KeymapMode::VimNormal => &self.keymaps.vim_normal, + KeymapMode::VimInsert => &self.keymaps.vim_insert, + } + } + } + + /// Whether the current mode supports character insertion on unmatched keys. + fn is_insert_mode(&self) -> bool { + matches!( + self.keymap_mode, + KeymapMode::Emacs | KeymapMode::Auto | KeymapMode::VimInsert + ) + } + + fn handle_key_input(&mut self, settings: &Settings, input: &KeyEvent) -> InputAction { + use super::keybindings::Action; + use super::keybindings::EvalContext; + use super::keybindings::key::{KeyCodeValue, KeyInput, SingleKey}; + + // Skip release events + if input.kind == event::KeyEventKind::Release { + return InputAction::Continue; + } + + // Reset switched_search_mode at start of each key event + self.switched_search_mode = false; + + // Build evaluation context from current state + let ctx = EvalContext { + cursor_position: self.search.input.position(), + input_width: UnicodeWidthStr::width(self.search.input.as_str()), + input_byte_len: self.search.input.as_str().len(), + selected_index: self.results_state.selected(), + results_len: self.results_len, + original_input_empty: self.original_input_empty, + has_context: self.search.custom_context.is_some(), + }; + + // Convert KeyEvent to SingleKey + let Some(single) = SingleKey::from_event(input) else { + return InputAction::Continue; + }; + + // --- Phase 1: Resolve (take pending key first, then immutable borrows) --- + + // Take pending key before any immutable borrows of self + let pending = self.pending_vim_key.take(); + + // If in prefix mode, try prefix keymap first (single keys only) + let prefix_action = if self.prefix { + let ki = KeyInput::Single(single.clone()); + self.keymaps.prefix.resolve(&ki, &ctx) + } else { + None + }; + + // The if-let/else-if chain here is clearer than map_or_else with nested closures. + #[expect(clippy::option_if_let_else)] + let (action, new_pending) = if prefix_action.is_some() { + (prefix_action, None) + } else { + // Use mode keymap (handles both single and multi-key sequences) + let keymap = self.mode_keymap(); + + if let Some(pending_char) = pending { + // We have a pending key from a previous press (e.g., first 'g' of 'gg') + let pending_single = SingleKey { + code: KeyCodeValue::Char(pending_char), + ctrl: false, + alt: false, + shift: false, + super_key: false, + }; + let seq = KeyInput::Sequence(vec![pending_single, single.clone()]); + let action = keymap + .resolve(&seq, &ctx) + .or_else(|| keymap.resolve(&KeyInput::Single(single.clone()), &ctx)); + (action, None) + } else if keymap.has_sequence_starting_with(&single) + && matches!(single.code, KeyCodeValue::Char(_)) + && !single.ctrl + && !single.alt + { + // This key starts a multi-key sequence; wait for next key + let KeyCodeValue::Char(c) = single.code else { + unreachable!() + }; + (Some(Action::Noop), Some(c)) + } else { + ( + keymap.resolve(&KeyInput::Single(single.clone()), &ctx), + None, + ) + } + }; + + // --- Phase 2: Apply mutations --- + self.pending_vim_key = new_pending; + + // Reset prefix (before execute, so EnterPrefixMode can re-set it) + self.prefix = false; + + if let Some(action) = action { + self.execute_action(&action, settings) + } else { + // No action matched. In insert-capable modes, insert the character. + if self.is_insert_mode() && !single.ctrl && !single.alt { + match single.code { + KeyCodeValue::Char(c) => { + self.search.input.insert(c); + } + KeyCodeValue::Space => { + self.search.input.insert(' '); + } + _ => {} + } + } + InputAction::Continue + } + } + + fn scroll_down(&mut self, scroll_len: usize) { + let i = self.results_state.selected().saturating_sub(scroll_len); + self.inspecting_state.reset(); + self.results_state.select(i); + } + + fn scroll_up(&mut self, scroll_len: usize) { + let i = self.results_state.selected() + scroll_len; + self.results_state + .select(i.min(self.results_len.saturating_sub(1))); + self.inspecting_state.reset(); + } + + /// Execute a resolved action, performing all side effects and returning the + /// appropriate `InputAction` for the event loop. + /// + /// This is the "do it" half of the resolve+execute pipeline. The resolver + /// decides *what* to do (which `Action`), and this function carries it out. + /// + /// Invert handling: scroll actions (`SelectNext`, `ScrollPageDown`, etc.) account + /// for `settings.invert` so that keybindings are always in "visual" terms — + /// users never need to think about invert in their keybinding config. + #[expect(clippy::too_many_lines)] + pub(crate) fn execute_action( + &mut self, + action: &super::keybindings::Action, + settings: &Settings, + ) -> InputAction { + use crate::command::client::search::keybindings::Action; + + match action { + // -- Cursor movement -- + Action::CursorLeft => { + self.search.input.left(); + InputAction::Continue + } + Action::CursorRight => { + self.search.input.right(); + InputAction::Continue + } + Action::CursorWordLeft => { + self.search + .input + .prev_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::CursorWordRight => { + self.search + .input + .next_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::CursorWordEnd => { + self.search.input.word_end(&settings.word_chars); + InputAction::Continue + } + Action::CursorStart => { + self.search.input.start(); + InputAction::Continue + } + Action::CursorEnd => { + self.search.input.end(); + InputAction::Continue + } + + // -- Editing -- + Action::DeleteCharBefore => { + self.search.input.back(); + InputAction::Continue + } + Action::DeleteCharAfter => { + self.search.input.remove(); + InputAction::Continue + } + Action::DeleteWordBefore => { + self.search + .input + .remove_prev_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::DeleteWordAfter => { + self.search + .input + .remove_next_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::DeleteToWordBoundary => { + // ctrl-w: remove trailing whitespace, then delete to word boundary + while matches!(self.search.input.back(), Some(c) if c.is_whitespace()) {} + while self.search.input.left() { + if self.search.input.char().unwrap().is_whitespace() { + self.search.input.right(); + break; + } + self.search.input.remove(); + } + InputAction::Continue + } + Action::ClearLine => { + self.search.input.clear(); + InputAction::Continue + } + Action::ClearToStart => { + self.search.input.clear_to_start(); + InputAction::Continue + } + Action::ClearToEnd => { + self.search.input.clear_to_end(); + InputAction::Continue + } + + // -- List navigation (invert-aware) -- + Action::SelectNext => { + if settings.invert { + self.scroll_up(1); + } else { + self.scroll_down(1); + } + InputAction::Continue + } + Action::SelectPrevious => { + if settings.invert { + self.scroll_down(1); + } else { + self.scroll_up(1); + } + InputAction::Continue + } + // -- Page/half-page scroll (invert-aware) -- + Action::ScrollHalfPageUp => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines) + / 2; + if settings.invert { + self.scroll_down(scroll_len); + } else { + self.scroll_up(scroll_len); + } + InputAction::Continue + } + Action::ScrollHalfPageDown => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines) + / 2; + if settings.invert { + self.scroll_up(scroll_len); + } else { + self.scroll_down(scroll_len); + } + InputAction::Continue + } + Action::ScrollPageUp => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines); + if settings.invert { + self.scroll_down(scroll_len); + } else { + self.scroll_up(scroll_len); + } + InputAction::Continue + } + Action::ScrollPageDown => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines); + if settings.invert { + self.scroll_up(scroll_len); + } else { + self.scroll_down(scroll_len); + } + InputAction::Continue + } + + // -- Absolute jumps (invert-aware) -- + Action::ScrollToTop => { + // Visual top of history + if settings.invert { + self.results_state.select(0); + } else { + let last_idx = self.results_len.saturating_sub(1); + self.results_state.select(last_idx); + } + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToBottom => { + // Visual bottom of history + if settings.invert { + let last_idx = self.results_len.saturating_sub(1); + self.results_state.select(last_idx); + } else { + self.results_state.select(0); + } + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToScreenTop => { + // H — jump to top of visible screen + let top = self.results_state.offset(); + let visible = self.results_state.max_entries().min(self.results_len); + let bottom = top + visible.saturating_sub(1); + self.results_state + .select(bottom.min(self.results_len.saturating_sub(1))); + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToScreenMiddle => { + // M — jump to middle of visible screen + let top = self.results_state.offset(); + let visible = self.results_state.max_entries().min(self.results_len); + let middle = top + visible / 2; + self.results_state + .select(middle.min(self.results_len.saturating_sub(1))); + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToScreenBottom => { + // L — jump to bottom of visible screen + let top_visible = self.results_state.offset(); + self.results_state.select(top_visible); + self.inspecting_state.reset(); + InputAction::Continue + } + + // -- Commands -- + Action::Accept => { + if self.tab_index == 1 { + return InputAction::AcceptInspecting; + } + self.accept = true; + InputAction::Accept(self.results_state.selected()) + } + Action::AcceptNth(n) => { + self.accept = true; + InputAction::Accept(self.results_state.selected() + *n as usize) + } + Action::ReturnSelection => { + if self.tab_index == 1 { + return InputAction::AcceptInspecting; + } + InputAction::Accept(self.results_state.selected()) + } + Action::ReturnSelectionNth(n) => { + InputAction::Accept(self.results_state.selected() + *n as usize) + } + Action::Copy => InputAction::Copy(self.results_state.selected()), + Action::Delete => InputAction::Delete(self.results_state.selected()), + Action::DeleteAll => InputAction::DeleteAllMatching(self.results_state.selected()), + Action::ReturnOriginal => InputAction::ReturnOriginal, + Action::ReturnQuery => InputAction::ReturnQuery, + Action::Exit => Self::handle_key_exit(settings), + Action::Redraw => InputAction::Redraw, + Action::CycleFilterMode => { + self.search.rotate_filter_mode(settings, 1); + InputAction::Continue + } + Action::CycleSearchMode => { + self.switched_search_mode = true; + self.search_mode = self.search_mode.next(settings); + self.engine = engines::engine(self.search_mode, settings); + InputAction::Continue + } + Action::SwitchContext => { + InputAction::SwitchContext(Some(self.results_state.selected())) + } + Action::ClearContext => InputAction::SwitchContext(None), + Action::ToggleTab => { + self.tab_index = (self.tab_index + 1) % TAB_TITLES.len(); + InputAction::Continue + } + + // -- Mode changes -- + Action::VimEnterNormal => { + self.set_keymap_cursor(settings, "vim_normal"); + self.keymap_mode = KeymapMode::VimNormal; + InputAction::Continue + } + Action::VimEnterInsert => { + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimEnterInsertAfter => { + self.search.input.right(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimEnterInsertAtStart => { + self.search.input.start(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimEnterInsertAtEnd => { + self.search.input.end(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimSearchInsert => { + self.search.input.clear(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimChangeToEnd => { + self.search.input.clear_to_end(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::EnterPrefixMode => { + self.prefix = true; + InputAction::Continue + } + + // -- Inspector -- + Action::InspectPrevious => { + self.inspecting_state.move_to_previous(); + InputAction::Redraw + } + Action::InspectNext => { + self.inspecting_state.move_to_next(); + InputAction::Redraw + } + + // -- Special -- + Action::Noop => InputAction::Continue, + } + } + + #[expect(clippy::cast_possible_truncation)] + #[expect(clippy::bool_to_int_with_if)] + fn calc_preview_height( + settings: &Settings, + results: &[History], + selected: usize, + tab_index: usize, + compactness: Compactness, + border_size: u16, + preview_width: u16, + ) -> u16 { + if settings.show_preview + && settings.preview.strategy == PreviewStrategy::Auto + && tab_index == 0 + && !results.is_empty() + { + let length_current_cmd = results[selected].command.len() as u16; + // calculate the number of newlines in the command + let num_newlines = results[selected] + .command + .chars() + .filter(|&c| c == '\n') + .count() as u16; + if num_newlines > 0 { + std::cmp::min( + settings.max_preview_height, + results[selected] + .command + .split('\n') + .map(|line| { + (line.len() as u16 + preview_width - 1 - border_size) + / (preview_width - border_size) + }) + .sum(), + ) + border_size * 2 + } + // The '- 19' takes the characters before the command (duration and time) into account + else if length_current_cmd > preview_width - 19 { + std::cmp::min( + settings.max_preview_height, + (length_current_cmd + preview_width - 1 - border_size) + / (preview_width - border_size), + ) + border_size * 2 + } else { + 1 + } + } else if settings.show_preview + && settings.preview.strategy == PreviewStrategy::Static + && tab_index == 0 + { + let longest_command = results + .iter() + .max_by(|h1, h2| h1.command.len().cmp(&h2.command.len())); + longest_command.map_or(0, |v| { + std::cmp::min( + settings.max_preview_height, + v.command + .split('\n') + .map(|line| { + (line.len() as u16 + preview_width - 1 - border_size) + / (preview_width - border_size) + }) + .sum(), + ) + }) + border_size * 2 + } else if settings.show_preview && settings.preview.strategy == PreviewStrategy::Fixed { + settings.max_preview_height + border_size * 2 + } else if !matches!(compactness, Compactness::Full) || tab_index == 1 { + 0 + } else { + 1 + } + } + + #[expect(clippy::bool_to_int_with_if)] + #[expect(clippy::too_many_lines)] + #[expect(clippy::too_many_arguments)] + fn draw( + &mut self, + f: &mut Frame, + results: &[History], + stats: Option<HistoryStats>, + inspecting: Option<&History>, + settings: &Settings, + theme: &Theme, + popup_mode: bool, + ) { + let area = f.area(); + if popup_mode { + f.render_widget(Clear, area); + } + self.draw_inner(f, area, results, stats, inspecting, settings, theme); + } + + #[expect(clippy::too_many_arguments)] + #[expect(clippy::too_many_lines)] + #[expect(clippy::bool_to_int_with_if)] + fn draw_inner( + &mut self, + f: &mut Frame, + area: Rect, + results: &[History], + stats: Option<HistoryStats>, + inspecting: Option<&History>, + settings: &Settings, + theme: &Theme, + ) { + let compactness = to_compactness(f, settings); + let invert = settings.invert; + let border_size = match compactness { + Compactness::Full => 1, + _ => 0, + }; + let preview_width = area.width.saturating_sub(2); + let preview_height = Self::calc_preview_height( + settings, + results, + self.results_state.selected(), + self.tab_index, + compactness, + border_size, + preview_width, + ); + let show_help = + settings.show_help && (matches!(compactness, Compactness::Full) || area.height > 1); + // This is an OR, as it seems more likely for someone to wish to override + // tabs unexpectedly being missed, than unexpectedly present. + let show_tabs = settings.show_tabs && !matches!(compactness, Compactness::Ultracompact); + let chunks = Layout::default() + .direction(Direction::Vertical) + .margin(0) + .horizontal_margin(1) + .constraints::<&[Constraint]>( + if invert { + [ + Constraint::Length(1 + border_size), // input + Constraint::Min(1), // results list + Constraint::Length(preview_height), // preview + Constraint::Length(if show_tabs { 1 } else { 0 }), // tabs + Constraint::Length(if show_help { 1 } else { 0 }), // header (sic) + ] + } else { + match compactness { + Compactness::Ultracompact => [ + Constraint::Length(if show_help { 1 } else { 0 }), // header + Constraint::Length(0), // tabs + Constraint::Min(1), // results list + Constraint::Length(0), + Constraint::Length(0), + ], + _ => [ + Constraint::Length(if show_help { 1 } else { 0 }), // header + Constraint::Length(if show_tabs { 1 } else { 0 }), // tabs + Constraint::Min(1), // results list + Constraint::Length(1 + border_size), // input + Constraint::Length(preview_height), // preview + ], + } + } + .as_ref(), + ) + .split(area); + + let input_chunk = if invert { chunks[0] } else { chunks[3] }; + let results_list_chunk = if invert { chunks[1] } else { chunks[2] }; + let preview_chunk = if invert { chunks[2] } else { chunks[4] }; + let tabs_chunk = if invert { chunks[3] } else { chunks[1] }; + let header_chunk = if invert { chunks[4] } else { chunks[0] }; + + // TODO: this should be split so that we have one interactive search container that is + // EITHER a search box or an inspector. But I'm not doing that now, way too much atm. + // also allocate less 🙈 + let titles: Vec<_> = TAB_TITLES.iter().copied().map(Line::from).collect(); + + if show_tabs { + let tabs = Tabs::new(titles) + .block(Block::default().borders(Borders::NONE)) + .select(self.tab_index) + .style(Style::default()) + .highlight_style(Style::from_crossterm(theme.as_style(Meaning::Important))); + + f.render_widget(tabs, tabs_chunk); + } + + let style = StyleState { + compactness, + invert, + inner_width: input_chunk.width.into(), + }; + + let header_chunks = Layout::default() + .direction(Direction::Horizontal) + .constraints::<&[Constraint]>( + [ + Constraint::Ratio(1, 5), + Constraint::Ratio(3, 5), + Constraint::Ratio(1, 5), + ] + .as_ref(), + ) + .split(header_chunk); + + let title = Self::build_title(theme); + f.render_widget(title, header_chunks[0]); + + let help = self.build_help(settings, theme); + f.render_widget(help, header_chunks[1]); + + let stats_tab = self.build_stats(theme); + f.render_widget(stats_tab, header_chunks[2]); + + let indicator: String = match compactness { + Compactness::Ultracompact => { + if self.switched_search_mode { + format!("S{}>", self.search_mode.as_str().chars().next().unwrap()) + } else if self.search.custom_context.is_some() { + format!( + "C{}>", + self.search.filter_mode.as_str().chars().next().unwrap() + ) + } else { + format!( + "{}> ", + self.search.filter_mode.as_str().chars().next().unwrap() + ) + } + } + _ => " > ".to_string(), + }; + + match self.tab_index { + 0 => { + let history_highlighter = HistoryHighlighter { + engine: self.engine.as_ref(), + search_input: self.search.input.as_str(), + }; + let results_list = Self::build_results_list( + style, + results, + self.keymap_mode, + &self.now, + indicator.as_str(), + theme, + history_highlighter, + settings.show_numeric_shortcuts, + &settings.ui.columns, + ); + f.render_stateful_widget(results_list, results_list_chunk, &mut self.results_state); + } + + 1 => { + if results.is_empty() { + let message = Paragraph::new("Nothing to inspect") + .block( + Block::new() + .title(Line::from(" Info ".to_string())) + .title_alignment(Alignment::Center) + .borders(Borders::ALL) + .padding(Padding::vertical(2)), + ) + .alignment(Alignment::Center); + f.render_widget(message, results_list_chunk); + } else { + let inspecting = match inspecting { + Some(inspecting) => inspecting, + None => &results[self.results_state.selected()], + }; + super::inspector::draw( + f, + results_list_chunk, + inspecting, + &stats.expect("Drawing inspector, but no stats"), + settings, + theme, + settings.timezone, + ); + } + + // HACK: I'm following up with abstracting this into the UI container, with a + // sub-widget for search + for inspector + let feedback = Paragraph::new( + "The inspector is new - please give feedback (good, or bad) at https://forum.atuin.sh", + ); + f.render_widget(feedback, input_chunk); + + return; + } + + _ => { + panic!("invalid tab index"); + } + } + + if !matches!(compactness, Compactness::Ultracompact) { + let preview_width = match compactness { + Compactness::Full => preview_width - 2, + _ => preview_width, + }; + let preview = self.build_preview( + results, + compactness, + preview_width, + preview_chunk.width.into(), + theme, + ); + #[expect(clippy::cast_possible_truncation)] + let prefix_width = settings + .ui + .columns + .iter() + .take_while(|col| !col.expand) + .map(|col| col.width + 1) + .sum::<u16>() + + " > ".len() as u16; + #[expect(clippy::cast_possible_truncation)] + let min_prefix_width = "[ SRCH: FULLTXT ] ".len() as u16; + self.draw_preview( + f, + style, + input_chunk, + compactness, + preview_chunk, + preview, + std::cmp::max(prefix_width, min_prefix_width), + ); + } + } + + #[expect(clippy::cast_possible_truncation, clippy::too_many_arguments)] + fn draw_preview( + &self, + f: &mut Frame, + style: StyleState, + input_chunk: Rect, + compactness: Compactness, + preview_chunk: Rect, + preview: Paragraph, + prefix_width: u16, + ) { + let input = self.build_input(style, prefix_width); + f.render_widget(input, input_chunk); + + f.render_widget(preview, preview_chunk); + + let extra_width = UnicodeWidthStr::width(self.search.input.substring()); + + let cursor_offset = match compactness { + Compactness::Full => 1, + _ => 0, + }; + f.set_cursor_position(( + // Put cursor past the end of the input text + input_chunk.x + extra_width as u16 + prefix_width + cursor_offset, + input_chunk.y + cursor_offset, + )); + } + + fn build_title(theme: &Theme) -> Paragraph<'_> { + let title = { + let style: Style = Style::from_crossterm(theme.as_style(Meaning::Base)); + Paragraph::new(Text::from(Span::styled( + format!("Atuin v{VERSION}"), + style.add_modifier(Modifier::BOLD), + ))) + }; + title.alignment(Alignment::Left) + } + + #[expect(clippy::unused_self)] + fn build_help(&self, settings: &Settings, theme: &Theme) -> Paragraph<'_> { + match self.tab_index { + // search + 0 => Paragraph::new(Text::from(Line::from(vec![ + Span::styled("<esc>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": exit"), + Span::raw(", "), + Span::styled("<tab>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": edit"), + Span::raw(", "), + Span::styled("<enter>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(if settings.enter_accept { + ": run" + } else { + ": edit" + }), + Span::raw(", "), + Span::styled("<ctrl-o>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": inspect"), + ]))), + + 1 => Paragraph::new(Text::from(Line::from(vec![ + Span::styled("<esc>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": exit"), + Span::raw(", "), + Span::styled("<ctrl-o>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": search"), + Span::raw(", "), + Span::styled("<ctrl-d>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": delete"), + ]))), + + _ => unreachable!("invalid tab index"), + } + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + .alignment(Alignment::Center) + } + + fn build_stats(&self, theme: &Theme) -> Paragraph<'_> { + Paragraph::new(Text::from(Span::raw(format!( + "history count: {}", + self.history_count, + )))) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + .alignment(Alignment::Right) + } + + #[expect(clippy::too_many_arguments)] + fn build_results_list<'a>( + style: StyleState, + results: &'a [History], + keymap_mode: KeymapMode, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + columns: &'a [UiColumn], + ) -> HistoryList<'a> { + let results_list = HistoryList::new( + results, + style.invert, + keymap_mode == KeymapMode::VimNormal, + now, + indicator, + theme, + history_highlighter, + show_numeric_shortcuts, + columns, + ); + + match style.compactness { + Compactness::Full => { + if style.invert { + results_list.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = style.inner_width - 2)), + ) + } else { + results_list.block( + Block::default() + .borders(Borders::TOP | Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded), + ) + } + } + _ => results_list, + } + } + + fn build_input(&self, style: StyleState, prefix_width: u16) -> Paragraph<'_> { + let (pref, mode) = if self.switched_search_mode { + (" SRCH:", self.search_mode.as_str()) + } else if self.search.custom_context.is_some() { + (" CTX:", self.search.filter_mode.as_str()) + } else { + ("", self.search.filter_mode.as_str()) + }; + // 3: surrounding "[" "] " + let mode_width = usize::from(prefix_width) - pref.len() - 3; + // sanity check to ensure we don't exceed the layout limits + debug_assert!(mode_width >= mode.len(), "mode name '{mode}' is too long!"); + let input = format!("[{pref}{mode:^mode_width$}] {}", self.search.input.as_str()); + let input = Paragraph::new(input); + match style.compactness { + Compactness::Full => { + if style.invert { + input.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT | Borders::TOP) + .border_type(BorderType::Rounded), + ) + } else { + input.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = style.inner_width - 2)), + ) + } + } + _ => input, + } + } + + fn build_preview( + &self, + results: &[History], + compactness: Compactness, + preview_width: u16, + chunk_width: usize, + theme: &Theme, + ) -> Paragraph<'_> { + let selected = self.results_state.selected(); + let command = if results.is_empty() { + String::new() + } else { + let s = &results[selected].command; + let mut lines = Vec::new(); + for line in s.split('\n') { + let line = line.escape_control(); + let mut width = 0; + let mut start = 0; + for (idx, ch) in line.char_indices() { + let w = ch.width().unwrap_or(0); // None for control chars which should not happen + if width + w > preview_width.into() { + lines.push(line[start..idx].to_owned()); + start = idx; + width = w; + } else { + width += w; + } + } + if width != 0 { + lines.push(line[start..].to_owned()); + } + } + lines.join("\n") + }; + + match compactness { + Compactness::Full => Paragraph::new(command).block( + Block::default() + .borders(Borders::BOTTOM | Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = chunk_width - 2)), + ), + _ => Paragraph::new(command) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))), + } + } +} + +/// The writer used for terminal output - either stdout or /dev/tty +enum TerminalWriter { + Stdout(std::io::Stdout), + #[cfg(unix)] + Tty(std::fs::File), +} + +impl TerminalWriter { + fn new() -> std::io::Result<Self> { + let stdout = stdout(); + if stdout.is_terminal() { + return Ok(TerminalWriter::Stdout(stdout)); + } + + // If stdout is not a terminal (e.g., captured by command substitution), + // fall back to /dev/tty so the TUI can still render. + // This allows usage like: VAR=$(atuin search -i) + #[cfg(unix)] + { + Ok(TerminalWriter::Tty( + std::fs::File::options() + .read(true) + .write(true) + .open("/dev/tty")?, + )) + } + } +} + +impl Write for TerminalWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { + match self { + TerminalWriter::Stdout(stdout) => stdout.write(buf), + #[cfg(unix)] + TerminalWriter::Tty(file) => file.write(buf), + } + } + + fn flush(&mut self) -> std::io::Result<()> { + match self { + TerminalWriter::Stdout(stdout) => stdout.flush(), + #[cfg(unix)] + TerminalWriter::Tty(file) => file.flush(), + } + } +} + +/// Screen state captured from atuin pty-proxy's screen server. +#[cfg(unix)] +struct SavedScreen { + #[expect(dead_code)] + rows: u16, + #[expect(dead_code)] + cols: u16, + cursor_row: u16, + cursor_col: u16, + /// Pre-formatted ANSI bytes for each screen row, ready to write to stdout. + rows_data: Vec<Vec<u8>>, +} + +/// Connect to atuin pty-proxy's Unix socket and fetch the current screen state. +/// +/// The wire format is: +/// ```text +/// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] +/// [row_0_len: u32 BE][row_0_bytes...] +/// [row_1_len: u32 BE][row_1_bytes...] +/// ... +/// ``` +#[cfg(unix)] +fn fetch_screen_state(socket_path: &str) -> Option<SavedScreen> { + use std::os::unix::net::UnixStream; + + let mut stream = UnixStream::connect(socket_path).ok()?; + stream.set_read_timeout(Some(Duration::from_secs(2))).ok()?; + + let mut data = Vec::new(); + stream.read_to_end(&mut data).ok()?; + + if data.len() < 8 { + return None; + } + + let rows = u16::from_be_bytes([data[0], data[1]]); + let cols = u16::from_be_bytes([data[2], data[3]]); + let cursor_row = u16::from_be_bytes([data[4], data[5]]); + let cursor_col = u16::from_be_bytes([data[6], data[7]]); + + // Parse length-prefixed rows + let mut rows_data = Vec::with_capacity(rows as usize); + let mut offset = 8; + while offset + 4 <= data.len() { + let row_len = u32::from_be_bytes([ + data[offset], + data[offset + 1], + data[offset + 2], + data[offset + 3], + ]) as usize; + offset += 4; + if offset + row_len > data.len() { + break; + } + rows_data.push(data[offset..offset + row_len].to_vec()); + offset += row_len; + } + + Some(SavedScreen { + rows, + cols, + cursor_row, + cursor_col, + rows_data, + }) +} + +/// Restore the screen area that was covered by the popup. +/// +/// Writes the pre-formatted per-row ANSI bytes received from atuin pty-proxy +/// directly to stdout, which correctly handles wide characters, colors, and +/// all text attributes without needing a client-side vt100 parser. +#[cfg(unix)] +fn restore_popup_area(saved: &SavedScreen, popup_rect: Rect, scroll_offset: u16) { + use ratatui::crossterm::cursor::MoveTo; + + let mut stdout = stdout(); + + for dy in 0..popup_rect.height { + let target_row = popup_rect.y + dy; + let source_row = (target_row + scroll_offset) as usize; + + // Clear only the popup region. The server-side rows_formatted() skips + // default cells (spaces with default attributes) using cursor jumps, so + // any popup content at those positions would remain if not cleared + // beforehand. We write `popup_rect.width` spaces instead of + // ClearType::CurrentLine so that only the popup area is cleared, not + // the entire terminal line. + let _ = execute!( + stdout, + MoveTo(popup_rect.x, target_row), + ratatui::crossterm::style::SetAttribute(ratatui::crossterm::style::Attribute::Reset), + ); + let _ = write!(stdout, "{:width$}", "", width = popup_rect.width as usize); + let _ = execute!(stdout, MoveTo(popup_rect.x, target_row)); + + if let Some(row_bytes) = saved.rows_data.get(source_row) { + let _ = stdout.write_all(row_bytes); + } + } + + let _ = execute!( + stdout, + MoveTo( + saved.cursor_col, + saved.cursor_row.saturating_sub(scroll_offset) + ) + ); + let _ = stdout.flush(); +} + +struct Stdout { + writer: TerminalWriter, + inline_mode: bool, + no_mouse: bool, +} + +impl Stdout { + pub fn new(inline_mode: bool, no_mouse: bool) -> std::io::Result<Self> { + terminal::enable_raw_mode()?; + + let mut writer = TerminalWriter::new()?; + + if !inline_mode { + execute!(writer, terminal::EnterAlternateScreen)?; + } + + if !no_mouse { + execute!(writer, event::EnableMouseCapture)?; + } + + execute!(writer, event::EnableBracketedPaste)?; + + #[cfg(not(target_os = "windows"))] + execute!( + writer, + PushKeyboardEnhancementFlags( + KeyboardEnhancementFlags::DISAMBIGUATE_ESCAPE_CODES + | KeyboardEnhancementFlags::REPORT_ALL_KEYS_AS_ESCAPE_CODES + | KeyboardEnhancementFlags::REPORT_ALTERNATE_KEYS + ), + )?; + + Ok(Self { + writer, + inline_mode, + no_mouse, + }) + } +} + +impl Drop for Stdout { + fn drop(&mut self) { + #[cfg(not(target_os = "windows"))] + if let Err(e) = execute!(self.writer, PopKeyboardEnhancementFlags) { + tracing::error!(?e, "Failed to pop keyboard enhancement flags"); + } + + if !self.inline_mode + && let Err(e) = execute!(self.writer, terminal::LeaveAlternateScreen) + { + tracing::error!(?e, "Failed to leave alt screen mode"); + } + + if !self.no_mouse + && let Err(e) = execute!(self.writer, event::DisableMouseCapture) + { + tracing::error!(?e, "Failed to disable mouse capture"); + } + + if let Err(e) = execute!(self.writer, event::DisableBracketedPaste) { + tracing::error!(?e, "Failed to disable bracketed paste"); + } + + if let Err(e) = terminal::disable_raw_mode() { + tracing::error!(?e, "Failed to disable raw mode"); + } + } +} + +impl Write for Stdout { + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { + self.writer.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.writer.flush() + } +} + +// this is a big blob of horrible! clean it up! +/// Compute the popup position and any scroll offset needed to make room. +/// +/// Given the cursor row, terminal dimensions, and desired popup height, +/// returns `(popup_rect, scroll_offset)` where `scroll_offset` is the number +/// of lines the caller should scroll the terminal up before rendering. +/// +/// This function performs no I/O — it is a pure computation. +#[cfg(unix)] +fn compute_popup_placement( + cursor_row: u16, + term_rows: u16, + term_cols: u16, + inline_height: u16, +) -> (Rect, u16) { + let popup_w = term_cols; + let popup_h = inline_height.min(term_rows); + let space_below = term_rows.saturating_sub(cursor_row); + + let (popup_y, scroll) = if popup_h <= space_below { + // Fits below cursor + (cursor_row, 0u16) + } else if cursor_row >= term_rows / 2 { + // Bottom half — render above cursor (overlay on existing text) + (cursor_row.saturating_sub(popup_h), 0u16) + } else { + // Top half, not enough space — scroll terminal to make room + let scroll = popup_h.saturating_sub(space_below); + let popup_y = cursor_row.saturating_sub(scroll); + (popup_y, scroll) + }; + + (Rect::new(0, popup_y, popup_w, popup_h), scroll) +} + +// for now, it works. But it'd be great if it were more easily readable, and +// modular. I'd like to add some more stats and stuff at some point +#[expect( + clippy::cast_possible_truncation, + clippy::too_many_lines, + clippy::cognitive_complexity +)] +pub async fn history( + query: &[String], + settings: &Settings, + mut db: impl Database, + history_store: &HistoryStore, + theme: &Theme, +) -> Result<String> { + let inline_height = if settings.shell_up_key_binding { + settings + .inline_height_shell_up_key_binding + .unwrap_or(settings.inline_height) + } else { + settings.inline_height + }; + + // Use fullscreen mode if the inline height doesn't fit in the terminal, + // this will preserve the scroll position upon exit. + // Also force fullscreen when stdout isn't a terminal (e.g., command substitution + // like VAR=$(atuin search -i)). In that case, we need to use /dev/tty for the TUI and force + // fullscreen mode (inline mode won't work as it requires cursor position queries + // that don't work when stdout is captured). + let inline_height = if !stdout().is_terminal() { + 0 + } else if let Ok(size) = terminal::size() + && inline_height >= size.1 + { + 0 + } else { + inline_height + }; + + // Popup mode: if running under atuin pty-proxy and inline mode is requested, + // fetch the screen state and render as a centered overlay. + #[cfg(unix)] + let (saved_screen, popup_rect, popup_scroll_offset) = { + let socket_path = std::env::var("ATUIN_PTY_PROXY_SOCKET") + .or_else(|_| std::env::var("ATUIN_HEX_SOCKET")) + .ok(); + if let Some(ref path) = socket_path + && inline_height > 0 + { + let saved = fetch_screen_state(path); + if let Some(ref s) = saved { + let (term_cols, term_rows) = terminal::size().unwrap_or((s.cols, s.rows)); + let (popup_rect, scroll) = + compute_popup_placement(s.cursor_row, term_rows, term_cols, inline_height); + + // Scroll terminal content up to make room if needed + if scroll > 0 { + use ratatui::crossterm::cursor::MoveTo; + let mut stdout = stdout(); + let _ = execute!(stdout, MoveTo(0, term_rows - 1)); + for _ in 0..scroll { + let _ = writeln!(stdout); + } + let _ = stdout.flush(); + } + + (saved, popup_rect, scroll) + } else { + (None, Rect::default(), 0u16) + } + } else { + (None, Rect::default(), 0u16) + } + }; + + let popup_mode = saved_screen.is_some(); + + let stdout = Stdout::new(inline_height > 0, settings.no_mouse)?; + + // In popup mode, clear the popup region on the physical terminal before + // ratatui takes over. Ratatui's diff-based rendering compares against an + // initially-empty buffer, so cells that remain "empty" (spaces with default + // style) won't be written — leaving underlying terminal text visible. + // By pre-clearing with spaces, those cells are already correct on screen. + if popup_mode { + use ratatui::crossterm::cursor::MoveTo; + let mut raw_stdout = std::io::stdout(); + // Queue all commands without flushing so the terminal receives them + // as a single write — no intermediate cursor positions are visible. + let _ = queue!( + raw_stdout, + ratatui::crossterm::style::SetAttribute(ratatui::crossterm::style::Attribute::Reset) + ); + for row in popup_rect.y..popup_rect.y.saturating_add(popup_rect.height) { + let _ = queue!(raw_stdout, MoveTo(popup_rect.x, row)); + let _ = write!( + raw_stdout, + "{:width$}", + "", + width = popup_rect.width as usize + ); + } + let _ = raw_stdout.flush(); + } + + let backend = CrosstermBackend::new(stdout); + let mut terminal = Terminal::with_options( + backend, + TerminalOptions { + viewport: if popup_mode { + Viewport::Fixed(popup_rect) + } else if inline_height > 0 { + Viewport::Inline(inline_height) + } else { + Viewport::Fullscreen + }, + }, + )?; + + let original_query = query.join(" "); + + // Check if this is a command chaining scenario + let is_command_chaining = if settings.command_chaining { + let trimmed = original_query.trim_end(); + trimmed.ends_with("&&") || trimmed.ends_with('|') + } else { + false + }; + + // For command chaining, start with empty input to allow searching for new commands + let search_input = if is_command_chaining { + String::new() + } else { + original_query.clone() + }; + + let mut input = Cursor::from(search_input); + // Put the cursor at the end of the query by default + input.end(); + + let initial_context = current_context().await?; + + let history_count = db.history_count(false).await?; + let search_mode = if settings.shell_up_key_binding { + settings + .search_mode_shell_up_key_binding + .unwrap_or(settings.search_mode) + } else { + settings.search_mode + }; + let default_filter_mode = settings + .filter_mode_shell_up_key_binding + .filter(|_| settings.shell_up_key_binding) + .unwrap_or_else(|| settings.default_filter_mode(initial_context.git_root.is_some())); + let mut app = State { + history_count, + results_state: ListState::default(), + switched_search_mode: false, + search_mode, + tab_index: 0, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::from_settings(settings), + search: SearchState { + input, + filter_mode: default_filter_mode, + context: initial_context.clone(), + custom_context: None, + }, + engine: engines::engine(search_mode, settings), + results_len: 0, + accept: false, + keymap_mode: match settings.keymap_mode { + KeymapMode::Auto => KeymapMode::Emacs, + value => value, + }, + current_cursor: None, + now: if settings.prefers_reduced_motion { + let now = OffsetDateTime::now_utc(); + Box::new(move || now) + } else { + Box::new(OffsetDateTime::now_utc) + }, + prefix: false, + pending_vim_key: None, + original_input_empty: original_query.is_empty(), + }; + + app.initialize_keymap_cursor(settings); + + let mut results = app.query_results(&mut db, settings.smart_sort).await?; + + if inline_height > 0 && !popup_mode { + terminal.clear()?; + } + + let mut stats: Option<HistoryStats> = None; + let mut inspecting: Option<History> = None; + let accept; + let result = 'render: loop { + terminal.draw(|f| { + app.draw( + f, + &results, + stats.clone(), + inspecting.as_ref(), + settings, + theme, + popup_mode, + ); + })?; + + let initial_input = app.search.input.as_str().to_owned(); + let initial_filter_mode = app.search.filter_mode; + let initial_search_mode = app.search_mode; + let initial_custom_context = app.search.custom_context.clone(); + + let event_ready = tokio::task::spawn_blocking(|| event::poll(Duration::from_millis(250))); + + tokio::select! { + event_ready = event_ready => { + if event_ready?? { + loop { + match app.handle_input(settings, &event::read()?) { + InputAction::Continue => {}, + InputAction::Delete(index) => { + if results.is_empty() { + break; + } + app.results_len -= 1; + let selected = app.results_state.selected(); + if selected == app.results_len { + app.inspecting_state.reset(); + app.results_state.select(selected - 1); + } + + let entry = results.remove(index); + + let ids = history_store.delete_entries([entry]).await?; + history_store.incremental_build(&db, &ids).await?; + + app.tab_index = 0; + }, + InputAction::DeleteAllMatching(index) => { + if results.is_empty() { + break; + } + + let command = results[index].command.clone(); + + // Remove matching entries from the visible results + results.retain(|e| e.command != command); + + // Query the DB for ALL entries with this command and delete them + let all_matching = db.query_history( + &format!( + "select * from history where command = '{}' and deleted_at is null", + command.replace('\'', "''") + ) + ).await?; + + let ids = history_store.delete_entries(all_matching).await?; + history_store.incremental_build(&db, &ids).await?; + + app.results_len = results.len(); + app.results_state = ListState::default(); + app.inspecting_state.reset(); + app.tab_index = 0; + }, + InputAction::SwitchContext(index) => { + if let Some(index) = index && let Some(entry) = results.get(index) { + app.search.custom_context = Some(entry.id.clone()); + app.search.context = Context::from_history(entry); + app.search.filter_mode = FilterMode::Session; + app.search.input = Cursor::from(String::new()); + app.results_state = ListState::default(); + } else { + app.search.custom_context = None; + app.search.context = initial_context.clone(); + app.search.filter_mode = default_filter_mode; + } + }, + InputAction::Redraw => { + if !popup_mode { + terminal.clear()?; + } + terminal.draw(|f| { + app.draw(f, &results, stats.clone(), inspecting.as_ref(), settings, theme, popup_mode); + })?; + }, + r => { + accept = app.accept; + break 'render r; + }, + } + if !event::poll(Duration::ZERO)? { + break; + } + } + } + } + } + + if initial_input != app.search.input.as_str() + || initial_filter_mode != app.search.filter_mode + || initial_search_mode != app.search_mode + || initial_custom_context != app.search.custom_context + { + results = app.query_results(&mut db, settings.smart_sort).await?; + } + + // In custom context mode, when no filter is applied, highlight the entry which was used + // to enter the context when changing modes. This helps to find your way around. + if app.search.custom_context.is_some() + && app.search.input.as_str().is_empty() + && (initial_custom_context != app.search.custom_context + || initial_filter_mode != app.search.filter_mode) + && let Some(history_id) = app.search.custom_context.clone() + && let Some(pos) = results.iter().position(|entry| entry.id == history_id) + { + app.results_state.select(pos); + } + + let inspecting_id = app.inspecting_state.clone().current; + // If inspecting ID is not the current inspecting History, update it. + match inspecting_id { + Some(inspecting_id) => { + if inspecting.is_none() || inspecting_id != inspecting.clone().unwrap().id { + inspecting = db.load(inspecting_id.0.as_str()).await?; + } + } + _ => { + inspecting = None; + } + } + + stats = if app.tab_index == 0 { + None + } else if !results.is_empty() { + // If we have stats, then we can indicate next available IDs. This avoids passing + // around a database object, or a full stats object. + let selected = match inspecting.clone() { + Some(insp) => insp, + None => results[app.results_state.selected()].clone(), + }; + let stats = db.stats(&selected).await?; + app.inspecting_state.current = Some(selected.id); + app.inspecting_state.previous = match stats.previous.clone() { + Some(p) => Some(p.id), + _ => None, + }; + app.inspecting_state.next = match stats.next.clone() { + Some(p) => Some(p.id), + _ => None, + }; + Some(stats) + } else { + None + }; + }; + + app.finalize_keymap_cursor(settings); + + if popup_mode { + // In popup mode, restore the screen area that was covered by the popup. + // This must happen before Stdout is dropped (which disables raw mode). + #[cfg(unix)] + if let Some(ref saved) = saved_screen { + restore_popup_area(saved, popup_rect, popup_scroll_offset); + } + } else if inline_height > 0 { + terminal.clear()?; + } + + let accept = accept + && matches!( + Shell::from_env(), + Shell::Zsh | Shell::Fish | Shell::Bash | Shell::Xonsh | Shell::Nu | Shell::Powershell + ); + + let accept_prefix = "__atuin_accept__:"; + + match result { + InputAction::AcceptInspecting => { + match inspecting { + Some(result) => { + let mut command = result.command; + + if accept { + command = String::from(accept_prefix) + &command; + } + + // index is in bounds so we return that entry + Ok(command) + } + None => Ok(String::new()), + } + } + InputAction::Accept(index) if index < results.len() => { + let mut command = results.swap_remove(index).command; + + if is_command_chaining { + command = format!("{} {}", original_query.trim_end(), command); + } else if accept { + command = String::from(accept_prefix) + &command; + } + + // index is in bounds so we return that entry + Ok(command) + } + InputAction::ReturnOriginal => Ok(String::new()), + InputAction::Copy(index) => { + let cmd = results.swap_remove(index).command; + set_clipboard(cmd); + Ok(String::new()) + } + InputAction::ReturnQuery | InputAction::Accept(_) => { + // Either: + // * index == RETURN_QUERY, in which case we should return the input + // * out of bounds -> usually implies no selected entry so we return the input + Ok(app.search.input.into_inner()) + } + InputAction::Continue + | InputAction::Redraw + | InputAction::Delete(_) + | InputAction::DeleteAllMatching(_) + | InputAction::SwitchContext(_) => { + unreachable!("should have been handled!") + } + } +} + +// cli-clipboard only works on Windows, Mac, and Linux. + +#[cfg(all( + feature = "clipboard", + any(target_os = "windows", target_os = "macos", target_os = "linux") +))] +fn set_clipboard(s: String) { + let mut ctx = arboard::Clipboard::new().unwrap(); + ctx.set_text(s).unwrap(); + // Use the clipboard context to make sure it is saved + ctx.get_text().unwrap(); +} + +#[cfg(not(all( + feature = "clipboard", + any(target_os = "windows", target_os = "macos", target_os = "linux") +)))] +fn set_clipboard(_s: String) {} + +#[cfg(test)] +mod tests { + use crate::atuin_client::database::Context; + use crate::atuin_client::history::History; + use crate::atuin_client::settings::{ + FilterMode, KeymapMode, Preview, PreviewStrategy, SearchMode, Settings, + }; + use time::OffsetDateTime; + + use crate::command::client::search::engines::{self, SearchState}; + use crate::command::client::search::history_list::ListState; + + use super::{Compactness, InspectingState, KeymapSet, State}; + + #[test] + #[expect(clippy::too_many_lines)] + fn calc_preview_height_test() { + let settings_preview_auto = Settings { + preview: Preview { + strategy: PreviewStrategy::Auto, + }, + show_preview: true, + ..Settings::utc() + }; + + let settings_preview_auto_h2 = Settings { + preview: Preview { + strategy: PreviewStrategy::Auto, + }, + show_preview: true, + max_preview_height: 2, + ..Settings::utc() + }; + + let settings_preview_h4 = Settings { + preview: Preview { + strategy: PreviewStrategy::Static, + }, + show_preview: true, + max_preview_height: 4, + ..Settings::utc() + }; + + let settings_preview_fixed = Settings { + preview: Preview { + strategy: PreviewStrategy::Fixed, + }, + show_preview: true, + max_preview_height: 15, + ..Settings::utc() + }; + + let cmd_60: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("for i in $(seq -w 10); do echo \"item number $i - abcd\"; done") + .cwd("/") + .build() + .into(); + + let cmd_124: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo 'Aurea prima sata est aetas, quae vindice nullo, sponte sua, sine lege fidem rectumque colebat. Poena metusque aberant'") + .cwd("/") + .build() + .into(); + + let cmd_200: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("CREATE USER atuin WITH ENCRYPTED PASSWORD 'supersecretpassword'; CREATE DATABASE atuin WITH OWNER = atuin; \\c atuin; REVOKE ALL PRIVILEGES ON SCHEMA public FROM PUBLIC; echo 'All done. 200 characters'") + .cwd("/") + .build() + .into(); + + let results: Vec<History> = vec![cmd_60, cmd_124, cmd_200]; + + // the selected command does not require a preview + let no_preview = State::calc_preview_height( + &settings_preview_auto, + &results, + 0_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the selected command requires 2 lines + let preview_h2 = State::calc_preview_height( + &settings_preview_auto, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the selected command requires 3 lines + let preview_h3 = State::calc_preview_height( + &settings_preview_auto, + &results, + 2_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the selected command requires a preview of 1 line (happens when the command is between preview_width-19 and preview_width) + let preview_one_line = State::calc_preview_height( + &settings_preview_auto, + &results, + 0_usize, + 0_usize, + Compactness::Full, + 1, + 66, + ); + // the selected command requires 3 lines, but we have a max preview height limit of 2 + let preview_limit_at_2 = State::calc_preview_height( + &settings_preview_auto_h2, + &results, + 2_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the longest command requires 3 lines + let preview_static_h3 = State::calc_preview_height( + &settings_preview_h4, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the longest command requires 10 lines, but we have a max preview height limit of 4 + let preview_static_limit_at_4 = State::calc_preview_height( + &settings_preview_h4, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 20, + ); + // the longest command requires 10 lines, but we have a max preview height of 15 and a fixed preview strategy + let settings_preview_fixed = State::calc_preview_height( + &settings_preview_fixed, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 20, + ); + + assert_eq!(no_preview, 1); + // 1 * 2 is the space for the border + let border_space = 2; + assert_eq!(preview_h2, 2 + border_space); + assert_eq!(preview_h3, 3 + border_space); + assert_eq!(preview_one_line, 1 + border_space); + assert_eq!(preview_limit_at_2, 2 + border_space); + assert_eq!(preview_static_h3, 3 + border_space); + assert_eq!(preview_static_limit_at_4, 4 + border_space); + assert_eq!(settings_preview_fixed, 15 + border_space); + } + + // Test when there's no results, scrolling up or down doesn't underflow + #[test] + fn state_scroll_up_underflow() { + let settings = Settings::utc(); + let mut state = State { + history_count: 0, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 0, + accept: false, + keymap_mode: KeymapMode::Auto, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Directory, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.scroll_up(1); + state.scroll_down(1); + } + + #[test] + fn test_accept_keybindings() { + use crate::atuin_client::settings::Keys; + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let mut settings = Settings::utc(); + settings.keys = Keys { + scroll_exits: true, + exit_past_line_start: false, + accept_past_line_end: true, + accept_past_line_start: false, + accept_with_backspace: false, + prefix: "a".to_string(), + }; + + let mut state = State { + history_count: 1, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 1, + accept: false, + keymap_mode: KeymapMode::Emacs, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + let tab_event = KeyEvent::new(KeyCode::Tab, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &tab_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Tab should always accept" + ); + + // Test left arrow with accept_past_line_start disabled (should continue) + let left_event = KeyEvent::new(KeyCode::Left, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &left_event); + assert!( + matches!(result, super::InputAction::Continue), + "Left arrow should continue when disabled" + ); + + // Test left arrow with accept_past_line_start enabled (should accept at start of line) + settings.keys.accept_past_line_start = true; + state.keymaps = KeymapSet::defaults(&settings); + let result = state.handle_key_input(&settings, &left_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Left arrow should accept at start of line when enabled" + ); + settings.keys.accept_past_line_start = false; + state.keymaps = KeymapSet::defaults(&settings); + + let backspace_event = KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &backspace_event); + assert!( + matches!(result, super::InputAction::Continue), + "Backspace should continue when disabled" + ); + + settings.keys.accept_with_backspace = true; + state.keymaps = KeymapSet::defaults(&settings); + let result = state.handle_key_input(&settings, &backspace_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Backspace should accept at start of line when enabled" + ); + + state.search.input.insert('t'); + state.search.input.insert('e'); + state.search.input.insert('s'); + state.search.input.insert('t'); + state.search.input.end(); + + let right_event = KeyEvent::new(KeyCode::Right, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &right_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Right arrow should accept at end of line when enabled" + ); + + settings.keys.accept_past_line_start = true; + state.keymaps = KeymapSet::defaults(&settings); + let left_event = KeyEvent::new(KeyCode::Left, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &left_event); + assert!( + matches!(result, super::InputAction::Continue), + "Left arrow should continue and end of line, even when enabled" + ); + settings.keys.accept_past_line_start = false; + state.keymaps = KeymapSet::defaults(&settings); + + settings.keys.accept_with_backspace = true; + state.keymaps = KeymapSet::defaults(&settings); + let backspace_event = KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &backspace_event); + assert!( + matches!(result, super::InputAction::Continue), + "Backspace should continue at end of line, even when enabled" + ); + settings.keys.accept_with_backspace = false; + state.keymaps = KeymapSet::defaults(&settings); + } + + #[test] + fn test_vim_gg_multikey_sequence() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + // Start in the middle of the list + state.results_state.select(50); + + // First 'g' should set pending state + let g_event = KeyEvent::new(KeyCode::Char('g'), KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &g_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, Some('g')); + assert_eq!(state.results_state.selected(), 50); // Position unchanged + + // Second 'g' should jump to end (visual top in non-inverted mode) + let result = state.handle_key_input(&settings, &g_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + assert_eq!(state.results_state.selected(), 99); // Jumped to last index (visual top) + } + + #[test] + fn test_vim_g_key_clears_on_other_input() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // Press 'g' to set pending state + let g_event = KeyEvent::new(KeyCode::Char('g'), KeyModifiers::NONE); + state.handle_key_input(&settings, &g_event); + assert_eq!(state.pending_vim_key, Some('g')); + + // Press 'j' - should clear pending state + let j_event = KeyEvent::new(KeyCode::Char('j'), KeyModifiers::NONE); + state.handle_key_input(&settings, &j_event); + assert_eq!(state.pending_vim_key, None); + } + + #[test] + fn test_vim_big_g_jump_to_bottom() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // 'G' should jump to visual bottom (index 0 in non-inverted mode) + let big_g_event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &big_g_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.results_state.selected(), 0); + } + + #[test] + fn test_vim_ctrl_u_d_half_page_scroll() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // Ctrl+d should return Continue and clear pending key + // (scroll amount depends on max_entries which is 0 in tests) + state.pending_vim_key = Some('g'); + let ctrl_d_event = KeyEvent::new(KeyCode::Char('d'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_d_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + + // Ctrl+u should return Continue and clear pending key + state.pending_vim_key = Some('g'); + let ctrl_u_event = KeyEvent::new(KeyCode::Char('u'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_u_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + } + + #[test] + fn test_vim_ctrl_f_b_full_page_scroll() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // Ctrl+f should return Continue and clear pending key + // (scroll amount depends on max_entries which is 0 in tests) + state.pending_vim_key = Some('g'); + let ctrl_f_event = KeyEvent::new(KeyCode::Char('f'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_f_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + + // Ctrl+b should return Continue and clear pending key + state.pending_vim_key = Some('g'); + let ctrl_b_event = KeyEvent::new(KeyCode::Char('b'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_b_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + } + + // ----------------------------------------------------------------------- + // Executor tests (execute_action) + // ----------------------------------------------------------------------- + + /// Helper to build a State for executor tests. + fn make_executor_state(results_len: usize, selected: usize) -> State { + let settings = Settings::utc(); + let mut state = State { + history_count: results_len as i64, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len, + accept: false, + keymap_mode: KeymapMode::Emacs, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + state.results_state.select(selected); + state + } + + #[test] + fn execute_select_next_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::SelectNext, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: SelectNext = scroll_down = selected - 1 + assert_eq!(state.results_state.selected(), 49); + } + + #[test] + fn execute_select_next_with_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let mut settings = Settings::utc(); + settings.invert = true; + let result = state.execute_action(&Action::SelectNext, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Inverted: SelectNext = scroll_up = selected + 1 + assert_eq!(state.results_state.selected(), 51); + } + + #[test] + fn execute_select_previous_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::SelectPrevious, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: SelectPrevious = scroll_up = selected + 1 + assert_eq!(state.results_state.selected(), 51); + } + + #[test] + fn execute_vim_enter_normal() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + let result = state.execute_action(&Action::VimEnterNormal, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.keymap_mode, KeymapMode::VimNormal); + } + + #[test] + fn execute_vim_enter_insert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + state.keymap_mode = KeymapMode::VimNormal; + let settings = Settings::utc(); + let result = state.execute_action(&Action::VimEnterInsert, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.keymap_mode, KeymapMode::VimInsert); + } + + #[test] + fn execute_accept_sets_accept_flag() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + let mut settings = Settings::utc(); + settings.enter_accept = true; + let result = state.execute_action(&Action::Accept, &settings); + assert!(matches!(result, super::InputAction::Accept(5))); + assert!(state.accept); + } + + #[test] + fn execute_return_selection_does_not_set_accept() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ReturnSelection, &settings); + assert!(matches!(result, super::InputAction::Accept(5))); + assert!(!state.accept); + } + + #[test] + fn execute_accept_nth() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + let settings = Settings::utc(); + let result = state.execute_action(&Action::AcceptNth(3), &settings); + assert!(matches!(result, super::InputAction::Accept(8))); + } + + #[test] + fn execute_scroll_to_top_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ScrollToTop, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: visual top = highest index + assert_eq!(state.results_state.selected(), 99); + } + + #[test] + fn execute_scroll_to_top_with_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let mut settings = Settings::utc(); + settings.invert = true; + let result = state.execute_action(&Action::ScrollToTop, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Inverted: visual top = index 0 + assert_eq!(state.results_state.selected(), 0); + } + + #[test] + fn execute_scroll_to_bottom_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ScrollToBottom, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: visual bottom = index 0 + assert_eq!(state.results_state.selected(), 0); + } + + #[test] + fn execute_toggle_tab() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + assert_eq!(state.tab_index, 0); + state.execute_action(&Action::ToggleTab, &settings); + assert_eq!(state.tab_index, 1); + state.execute_action(&Action::ToggleTab, &settings); + assert_eq!(state.tab_index, 0); + } + + #[test] + fn execute_enter_prefix_mode() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + assert!(!state.prefix); + state.execute_action(&Action::EnterPrefixMode, &settings); + assert!(state.prefix); + } + + #[test] + fn execute_exit_returns_based_on_exit_mode() { + use crate::atuin_client::settings::ExitMode; + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let mut settings = Settings::utc(); + + settings.exit_mode = ExitMode::ReturnOriginal; + let result = state.execute_action(&Action::Exit, &settings); + assert!(matches!(result, super::InputAction::ReturnOriginal)); + + settings.exit_mode = ExitMode::ReturnQuery; + let result = state.execute_action(&Action::Exit, &settings); + assert!(matches!(result, super::InputAction::ReturnQuery)); + } + + #[test] + fn execute_return_original() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ReturnOriginal, &settings); + assert!(matches!(result, super::InputAction::ReturnOriginal)); + } + + #[test] + fn execute_copy() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::Copy, &settings); + assert!(matches!(result, super::InputAction::Copy(7))); + } + + #[test] + fn execute_delete() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::Delete, &settings); + assert!(matches!(result, super::InputAction::Delete(7))); + } + + #[test] + fn execute_switch_context() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::SwitchContext, &settings); + assert!(matches!(result, super::InputAction::SwitchContext(Some(7)))); + } + + #[test] + fn execute_clear_context() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ClearContext, &settings); + assert!(matches!(result, super::InputAction::SwitchContext(None))); + } + + #[test] + fn execute_noop() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::Noop, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.results_state.selected(), 50); + } + + #[test] + fn execute_accept_in_inspector_tab() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + state.tab_index = 1; + let settings = Settings::utc(); + let result = state.execute_action(&Action::Accept, &settings); + assert!(matches!(result, super::InputAction::AcceptInspecting)); + } + + #[test] + fn execute_cycle_search_mode() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + let original_mode = state.search_mode; + let result = state.execute_action(&Action::CycleSearchMode, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert!(state.switched_search_mode); + assert_ne!(state.search_mode, original_mode); + } + + #[test] + fn execute_vim_search_insert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + state.search.input.insert('h'); + state.search.input.insert('i'); + state.keymap_mode = KeymapMode::VimNormal; + let settings = Settings::utc(); + let result = state.execute_action(&Action::VimSearchInsert, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Should clear input and switch to insert mode + assert_eq!(state.search.input.as_str(), ""); + assert_eq!(state.keymap_mode, KeymapMode::VimInsert); + } + + #[test] + fn execute_cursor_movement() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + + // Insert some text + state.search.input.insert('h'); + state.search.input.insert('e'); + state.search.input.insert('l'); + state.search.input.insert('l'); + state.search.input.insert('o'); + // cursor is at end (position 5) + + // CursorLeft + state.execute_action(&Action::CursorLeft, &settings); + assert_eq!(state.search.input.position(), 4); + + // CursorStart + state.execute_action(&Action::CursorStart, &settings); + assert_eq!(state.search.input.position(), 0); + + // CursorEnd + state.execute_action(&Action::CursorEnd, &settings); + assert_eq!(state.search.input.position(), 5); + + // CursorRight at end does nothing + state.execute_action(&Action::CursorRight, &settings); + assert_eq!(state.search.input.position(), 5); + } + + #[test] + fn execute_editing() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + + // Insert "hello" + state.search.input.insert('h'); + state.search.input.insert('e'); + state.search.input.insert('l'); + state.search.input.insert('l'); + state.search.input.insert('o'); + + // DeleteCharBefore (backspace) + state.execute_action(&Action::DeleteCharBefore, &settings); + assert_eq!(state.search.input.as_str(), "hell"); + + // ClearLine + state.execute_action(&Action::ClearLine, &settings); + assert_eq!(state.search.input.as_str(), ""); + } + + #[test] + fn keymap_config_return_query() { + use crate::atuin_client::settings::KeyBindingConfig; + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + use std::collections::HashMap; + + let mut settings = Settings::utc(); + // Configure tab to return-query + settings.keymap.emacs = HashMap::from([( + "tab".to_string(), + KeyBindingConfig::Simple("return-query".to_string()), + )]); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::Emacs, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::from_settings(&settings), + search: SearchState { + input: "test query".to_string().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + let tab_event = KeyEvent::new(KeyCode::Tab, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &tab_event); + assert!( + matches!(result, super::InputAction::ReturnQuery), + "Tab configured as return-query should return InputAction::ReturnQuery" + ); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/actions.rs b/crates/turtle/src/command/client/search/keybindings/actions.rs new file mode 100644 index 00000000..ff2ef7de --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/actions.rs @@ -0,0 +1,322 @@ +use std::fmt; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// All possible actions that can be triggered by a keybinding. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Action { + // Cursor movement + CursorLeft, + CursorRight, + CursorWordLeft, + CursorWordRight, + CursorWordEnd, + CursorStart, + CursorEnd, + + // Editing + DeleteCharBefore, + DeleteCharAfter, + DeleteWordBefore, + DeleteWordAfter, + DeleteToWordBoundary, + ClearLine, + ClearToStart, + ClearToEnd, + + // List navigation + SelectNext, + SelectPrevious, + ScrollHalfPageUp, + ScrollHalfPageDown, + ScrollPageUp, + ScrollPageDown, + ScrollToTop, + ScrollToBottom, + ScrollToScreenTop, + ScrollToScreenMiddle, + ScrollToScreenBottom, + + // Commands — accept selection and execute immediately + Accept, + AcceptNth(u8), + // Commands — return selection to command line without executing + ReturnSelection, + ReturnSelectionNth(u8), + // Commands — other + Copy, + Delete, + DeleteAll, + ReturnOriginal, + ReturnQuery, + Exit, + Redraw, + CycleFilterMode, + CycleSearchMode, + SwitchContext, + ClearContext, + ToggleTab, + + // Mode changes + VimEnterNormal, + VimEnterInsert, + VimEnterInsertAfter, + VimEnterInsertAtStart, + VimEnterInsertAtEnd, + VimSearchInsert, + VimChangeToEnd, + EnterPrefixMode, + + // Inspector + InspectPrevious, + InspectNext, + + // Special + Noop, +} + +impl Action { + /// Convert from a kebab-case string. + pub fn from_str(s: &str) -> Result<Self, String> { + // Handle accept-N and return-selection-N patterns + if let Some(rest) = s.strip_prefix("accept-") + && let Ok(n) = rest.parse::<u8>() + && (1..=9).contains(&n) + { + return Ok(Action::AcceptNth(n)); + } + if let Some(rest) = s.strip_prefix("return-selection-") + && let Ok(n) = rest.parse::<u8>() + && (1..=9).contains(&n) + { + return Ok(Action::ReturnSelectionNth(n)); + } + + match s { + "cursor-left" => Ok(Action::CursorLeft), + "cursor-right" => Ok(Action::CursorRight), + "cursor-word-left" => Ok(Action::CursorWordLeft), + "cursor-word-right" => Ok(Action::CursorWordRight), + "cursor-word-end" => Ok(Action::CursorWordEnd), + "cursor-start" => Ok(Action::CursorStart), + "cursor-end" => Ok(Action::CursorEnd), + + "delete-char-before" => Ok(Action::DeleteCharBefore), + "delete-char-after" => Ok(Action::DeleteCharAfter), + "delete-word-before" => Ok(Action::DeleteWordBefore), + "delete-word-after" => Ok(Action::DeleteWordAfter), + "delete-to-word-boundary" => Ok(Action::DeleteToWordBoundary), + "clear-line" => Ok(Action::ClearLine), + "clear-to-start" => Ok(Action::ClearToStart), + "clear-to-end" => Ok(Action::ClearToEnd), + + "select-next" => Ok(Action::SelectNext), + "select-previous" => Ok(Action::SelectPrevious), + "scroll-half-page-up" => Ok(Action::ScrollHalfPageUp), + "scroll-half-page-down" => Ok(Action::ScrollHalfPageDown), + "scroll-page-up" => Ok(Action::ScrollPageUp), + "scroll-page-down" => Ok(Action::ScrollPageDown), + "scroll-to-top" => Ok(Action::ScrollToTop), + "scroll-to-bottom" => Ok(Action::ScrollToBottom), + "scroll-to-screen-top" => Ok(Action::ScrollToScreenTop), + "scroll-to-screen-middle" => Ok(Action::ScrollToScreenMiddle), + "scroll-to-screen-bottom" => Ok(Action::ScrollToScreenBottom), + + "accept" => Ok(Action::Accept), + "return-selection" => Ok(Action::ReturnSelection), + "copy" => Ok(Action::Copy), + "delete" => Ok(Action::Delete), + "delete-all" => Ok(Action::DeleteAll), + "return-original" => Ok(Action::ReturnOriginal), + "return-query" => Ok(Action::ReturnQuery), + "exit" => Ok(Action::Exit), + "redraw" => Ok(Action::Redraw), + "cycle-filter-mode" => Ok(Action::CycleFilterMode), + "cycle-search-mode" => Ok(Action::CycleSearchMode), + "switch-context" => Ok(Action::SwitchContext), + "clear-context" => Ok(Action::ClearContext), + "toggle-tab" => Ok(Action::ToggleTab), + + "vim-enter-normal" => Ok(Action::VimEnterNormal), + "vim-enter-insert" => Ok(Action::VimEnterInsert), + "vim-enter-insert-after" => Ok(Action::VimEnterInsertAfter), + "vim-enter-insert-at-start" => Ok(Action::VimEnterInsertAtStart), + "vim-enter-insert-at-end" => Ok(Action::VimEnterInsertAtEnd), + "vim-search-insert" => Ok(Action::VimSearchInsert), + "vim-change-to-end" => Ok(Action::VimChangeToEnd), + "enter-prefix-mode" => Ok(Action::EnterPrefixMode), + + "inspect-previous" => Ok(Action::InspectPrevious), + "inspect-next" => Ok(Action::InspectNext), + + "noop" => Ok(Action::Noop), + + _ => Err(format!("unknown action: {s}")), + } + } + + /// Convert to a kebab-case string. + pub fn as_str(&self) -> String { + match self { + Action::CursorLeft => "cursor-left".to_string(), + Action::CursorRight => "cursor-right".to_string(), + Action::CursorWordLeft => "cursor-word-left".to_string(), + Action::CursorWordRight => "cursor-word-right".to_string(), + Action::CursorWordEnd => "cursor-word-end".to_string(), + Action::CursorStart => "cursor-start".to_string(), + Action::CursorEnd => "cursor-end".to_string(), + + Action::DeleteCharBefore => "delete-char-before".to_string(), + Action::DeleteCharAfter => "delete-char-after".to_string(), + Action::DeleteWordBefore => "delete-word-before".to_string(), + Action::DeleteWordAfter => "delete-word-after".to_string(), + Action::DeleteToWordBoundary => "delete-to-word-boundary".to_string(), + Action::ClearLine => "clear-line".to_string(), + Action::ClearToStart => "clear-to-start".to_string(), + Action::ClearToEnd => "clear-to-end".to_string(), + + Action::SelectNext => "select-next".to_string(), + Action::SelectPrevious => "select-previous".to_string(), + Action::ScrollHalfPageUp => "scroll-half-page-up".to_string(), + Action::ScrollHalfPageDown => "scroll-half-page-down".to_string(), + Action::ScrollPageUp => "scroll-page-up".to_string(), + Action::ScrollPageDown => "scroll-page-down".to_string(), + Action::ScrollToTop => "scroll-to-top".to_string(), + Action::ScrollToBottom => "scroll-to-bottom".to_string(), + Action::ScrollToScreenTop => "scroll-to-screen-top".to_string(), + Action::ScrollToScreenMiddle => "scroll-to-screen-middle".to_string(), + Action::ScrollToScreenBottom => "scroll-to-screen-bottom".to_string(), + + Action::Accept => "accept".to_string(), + Action::AcceptNth(n) => format!("accept-{n}"), + Action::ReturnSelection => "return-selection".to_string(), + Action::ReturnSelectionNth(n) => format!("return-selection-{n}"), + Action::Copy => "copy".to_string(), + Action::Delete => "delete".to_string(), + Action::DeleteAll => "delete-all".to_string(), + Action::ReturnOriginal => "return-original".to_string(), + Action::ReturnQuery => "return-query".to_string(), + Action::Exit => "exit".to_string(), + Action::Redraw => "redraw".to_string(), + Action::CycleFilterMode => "cycle-filter-mode".to_string(), + Action::CycleSearchMode => "cycle-search-mode".to_string(), + Action::SwitchContext => "switch-context".to_string(), + Action::ClearContext => "clear-context".to_string(), + Action::ToggleTab => "toggle-tab".to_string(), + + Action::VimEnterNormal => "vim-enter-normal".to_string(), + Action::VimEnterInsert => "vim-enter-insert".to_string(), + Action::VimEnterInsertAfter => "vim-enter-insert-after".to_string(), + Action::VimEnterInsertAtStart => "vim-enter-insert-at-start".to_string(), + Action::VimEnterInsertAtEnd => "vim-enter-insert-at-end".to_string(), + Action::VimSearchInsert => "vim-search-insert".to_string(), + Action::VimChangeToEnd => "vim-change-to-end".to_string(), + Action::EnterPrefixMode => "enter-prefix-mode".to_string(), + + Action::InspectPrevious => "inspect-previous".to_string(), + Action::InspectNext => "inspect-next".to_string(), + + Action::Noop => "noop".to_string(), + } + } +} + +impl fmt::Display for Action { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl Serialize for Action { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_str(&self.as_str()) + } +} + +impl<'de> Deserialize<'de> for Action { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { + let s = String::deserialize(deserializer)?; + Action::from_str(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_basic_actions() { + assert_eq!(Action::from_str("cursor-left").unwrap(), Action::CursorLeft); + assert_eq!(Action::from_str("accept").unwrap(), Action::Accept); + assert_eq!(Action::from_str("exit").unwrap(), Action::Exit); + assert_eq!(Action::from_str("noop").unwrap(), Action::Noop); + assert_eq!( + Action::from_str("vim-enter-normal").unwrap(), + Action::VimEnterNormal + ); + } + + #[test] + fn parse_accept_nth() { + assert_eq!(Action::from_str("accept-1").unwrap(), Action::AcceptNth(1)); + assert_eq!(Action::from_str("accept-9").unwrap(), Action::AcceptNth(9)); + } + + #[test] + fn parse_return_selection() { + assert_eq!( + Action::from_str("return-selection").unwrap(), + Action::ReturnSelection + ); + assert_eq!( + Action::from_str("return-selection-1").unwrap(), + Action::ReturnSelectionNth(1) + ); + assert_eq!( + Action::from_str("return-selection-9").unwrap(), + Action::ReturnSelectionNth(9) + ); + } + + #[test] + fn parse_unknown_action() { + assert!(Action::from_str("unknown-action").is_err()); + assert!(Action::from_str("accept-0").is_err()); + assert!(Action::from_str("accept-10").is_err()); + assert!(Action::from_str("return-selection-0").is_err()); + assert!(Action::from_str("return-selection-10").is_err()); + } + + #[test] + fn round_trip() { + let actions = vec![ + Action::CursorLeft, + Action::Accept, + Action::AcceptNth(5), + Action::ReturnSelection, + Action::ReturnSelectionNth(3), + Action::VimSearchInsert, + Action::ScrollToScreenMiddle, + ]; + for action in actions { + let s = action.as_str(); + let parsed = Action::from_str(&s).unwrap(); + assert_eq!(action, parsed); + } + } + + #[test] + fn serde_round_trip() { + let action = Action::CursorLeft; + let json = serde_json::to_string(&action).unwrap(); + assert_eq!(json, "\"cursor-left\""); + let parsed: Action = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, Action::CursorLeft); + + let action = Action::AcceptNth(3); + let json = serde_json::to_string(&action).unwrap(); + assert_eq!(json, "\"accept-3\""); + let parsed: Action = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, Action::AcceptNth(3)); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/conditions.rs b/crates/turtle/src/command/client/search/keybindings/conditions.rs new file mode 100644 index 00000000..055ae905 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/conditions.rs @@ -0,0 +1,801 @@ +use std::fmt; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// Atomic (leaf) conditions that can be evaluated against state. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ConditionAtom { + CursorAtStart, + CursorAtEnd, + InputEmpty, + OriginalInputEmpty, + ListAtEnd, + ListAtStart, + NoResults, + HasResults, + HasContext, +} + +/// Boolean expression tree over condition atoms. +/// +/// Supports negation, conjunction, and disjunction with standard precedence: +/// `!` binds tightest, then `&&`, then `||`. +/// +/// Examples of valid expression strings: +/// - `"cursor-at-start"` (bare atom) +/// - `"!no-results"` (negation) +/// - `"cursor-at-start && input-empty"` (conjunction) +/// - `"list-at-start || no-results"` (disjunction) +/// - `"(cursor-at-start && !input-empty) || no-results"` (grouping) +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ConditionExpr { + Atom(ConditionAtom), + Not(Box<ConditionExpr>), + And(Box<ConditionExpr>, Box<ConditionExpr>), + Or(Box<ConditionExpr>, Box<ConditionExpr>), +} + +/// Context needed to evaluate conditions. This is a pure snapshot of state — +/// no references to mutable data. +pub struct EvalContext { + /// Current cursor position (unicode width units). + pub cursor_position: usize, + /// Width of the input string in unicode width units. + pub input_width: usize, + /// Byte length of the input string. + pub input_byte_len: usize, + /// Currently selected index in the results list. + pub selected_index: usize, + /// Total number of results. + pub results_len: usize, + /// Whether the original input (query passed to the TUI) was empty. + pub original_input_empty: bool, + /// Whether we use a search context of a command from the history. + pub has_context: bool, +} + +// --------------------------------------------------------------------------- +// ConditionAtom +// --------------------------------------------------------------------------- + +impl ConditionAtom { + /// Evaluate this atom against the given context. + pub fn evaluate(&self, ctx: &EvalContext) -> bool { + match self { + ConditionAtom::CursorAtStart => ctx.cursor_position == 0, + ConditionAtom::CursorAtEnd => ctx.cursor_position == ctx.input_width, + ConditionAtom::InputEmpty => ctx.input_byte_len == 0, + ConditionAtom::OriginalInputEmpty => ctx.original_input_empty, + ConditionAtom::ListAtEnd => { + ctx.results_len == 0 || ctx.selected_index >= ctx.results_len.saturating_sub(1) + } + ConditionAtom::ListAtStart => ctx.results_len == 0 || ctx.selected_index == 0, + ConditionAtom::NoResults => ctx.results_len == 0, + ConditionAtom::HasResults => ctx.results_len > 0, + ConditionAtom::HasContext => ctx.has_context, + } + } + + /// Parse from a kebab-case string. + pub fn from_str(s: &str) -> Result<Self, String> { + match s { + "cursor-at-start" => Ok(ConditionAtom::CursorAtStart), + "cursor-at-end" => Ok(ConditionAtom::CursorAtEnd), + "input-empty" => Ok(ConditionAtom::InputEmpty), + "original-input-empty" => Ok(ConditionAtom::OriginalInputEmpty), + "list-at-end" => Ok(ConditionAtom::ListAtEnd), + "list-at-start" => Ok(ConditionAtom::ListAtStart), + "no-results" => Ok(ConditionAtom::NoResults), + "has-results" => Ok(ConditionAtom::HasResults), + "has-context" => Ok(ConditionAtom::HasContext), + _ => Err(format!("unknown condition: {s}")), + } + } + + /// Convert to a kebab-case string. + pub fn as_str(&self) -> &'static str { + match self { + ConditionAtom::CursorAtStart => "cursor-at-start", + ConditionAtom::CursorAtEnd => "cursor-at-end", + ConditionAtom::InputEmpty => "input-empty", + ConditionAtom::OriginalInputEmpty => "original-input-empty", + ConditionAtom::ListAtEnd => "list-at-end", + ConditionAtom::ListAtStart => "list-at-start", + ConditionAtom::NoResults => "no-results", + ConditionAtom::HasResults => "has-results", + ConditionAtom::HasContext => "has-context", + } + } +} + +impl fmt::Display for ConditionAtom { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — evaluation +// --------------------------------------------------------------------------- + +impl ConditionExpr { + /// Evaluate this expression against the given context. + pub fn evaluate(&self, ctx: &EvalContext) -> bool { + match self { + ConditionExpr::Atom(atom) => atom.evaluate(ctx), + ConditionExpr::Not(inner) => !inner.evaluate(ctx), + ConditionExpr::And(lhs, rhs) => lhs.evaluate(ctx) && rhs.evaluate(ctx), + ConditionExpr::Or(lhs, rhs) => lhs.evaluate(ctx) || rhs.evaluate(ctx), + } + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — ergonomic builders +// --------------------------------------------------------------------------- + +impl From<ConditionAtom> for ConditionExpr { + fn from(atom: ConditionAtom) -> Self { + ConditionExpr::Atom(atom) + } +} + +#[expect(dead_code)] +impl ConditionExpr { + /// Negate this expression: `!self`. + pub fn not(self) -> Self { + ConditionExpr::Not(Box::new(self)) + } + + /// Conjoin with another expression: `self && other`. + pub fn and(self, other: ConditionExpr) -> Self { + ConditionExpr::And(Box::new(self), Box::new(other)) + } + + /// Disjoin with another expression: `self || other`. + pub fn or(self, other: ConditionExpr) -> Self { + ConditionExpr::Or(Box::new(self), Box::new(other)) + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — parser +// --------------------------------------------------------------------------- + +/// Recursive descent parser for boolean condition expressions. +/// +/// Grammar (standard boolean precedence): +/// ```text +/// expr = or_expr +/// or_expr = and_expr ("||" and_expr)* +/// and_expr = unary ("&&" unary)* +/// unary = "!" unary | primary +/// primary = atom | "(" expr ")" +/// atom = [a-z][a-z0-9-]* +/// ``` +struct ExprParser<'a> { + input: &'a str, + pos: usize, +} + +impl<'a> ExprParser<'a> { + fn new(input: &'a str) -> Self { + Self { input, pos: 0 } + } + + fn skip_whitespace(&mut self) { + while self.pos < self.input.len() && self.input.as_bytes()[self.pos].is_ascii_whitespace() { + self.pos += 1; + } + } + + fn starts_with(&mut self, s: &str) -> bool { + self.skip_whitespace(); + self.input[self.pos..].starts_with(s) + } + + fn consume(&mut self, s: &str) -> bool { + self.skip_whitespace(); + if self.input[self.pos..].starts_with(s) { + self.pos += s.len(); + true + } else { + false + } + } + + /// Parse a full expression, expecting to consume all input. + fn parse(mut self) -> Result<ConditionExpr, String> { + let expr = self.parse_or()?; + self.skip_whitespace(); + if self.pos < self.input.len() { + return Err(format!( + "unexpected input at position {}: {:?}", + self.pos, + &self.input[self.pos..] + )); + } + Ok(expr) + } + + /// `or_expr` = `and_expr` ("||" `and_expr`)* + fn parse_or(&mut self) -> Result<ConditionExpr, String> { + let mut left = self.parse_and()?; + while self.starts_with("||") { + self.consume("||"); + let right = self.parse_and()?; + left = ConditionExpr::Or(Box::new(left), Box::new(right)); + } + Ok(left) + } + + /// `and_expr` = unary ("&&" unary)* + fn parse_and(&mut self) -> Result<ConditionExpr, String> { + let mut left = self.parse_unary()?; + while self.starts_with("&&") { + self.consume("&&"); + let right = self.parse_unary()?; + left = ConditionExpr::And(Box::new(left), Box::new(right)); + } + Ok(left) + } + + /// unary = "!" unary | primary + fn parse_unary(&mut self) -> Result<ConditionExpr, String> { + if self.consume("!") { + let inner = self.parse_unary()?; + Ok(ConditionExpr::Not(Box::new(inner))) + } else { + self.parse_primary() + } + } + + /// primary = "(" expr ")" | atom + fn parse_primary(&mut self) -> Result<ConditionExpr, String> { + if self.consume("(") { + let expr = self.parse_or()?; + if !self.consume(")") { + return Err(format!("expected ')' at position {}", self.pos)); + } + Ok(expr) + } else { + self.parse_atom() + } + } + + /// atom = [a-z][a-z0-9-]* + fn parse_atom(&mut self) -> Result<ConditionExpr, String> { + self.skip_whitespace(); + let start = self.pos; + while self.pos < self.input.len() { + let b = self.input.as_bytes()[self.pos]; + if b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'-' { + self.pos += 1; + } else { + break; + } + } + if self.pos == start { + return Err(format!("expected condition name at position {}", self.pos)); + } + let name = &self.input[start..self.pos]; + let atom = ConditionAtom::from_str(name)?; + Ok(ConditionExpr::Atom(atom)) + } +} + +impl ConditionExpr { + /// Parse a condition expression from a string. + pub fn parse(s: &str) -> Result<Self, String> { + let parser = ExprParser::new(s); + parser.parse() + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — Display +// --------------------------------------------------------------------------- + +/// Precedence levels for minimal-parentheses display. +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +enum Prec { + Or = 0, + And = 1, + Not = 2, + Atom = 3, +} + +impl ConditionExpr { + fn prec(&self) -> Prec { + match self { + ConditionExpr::Or(..) => Prec::Or, + ConditionExpr::And(..) => Prec::And, + ConditionExpr::Not(..) => Prec::Not, + ConditionExpr::Atom(..) => Prec::Atom, + } + } + + fn fmt_with_prec(&self, f: &mut fmt::Formatter<'_>, parent_prec: Prec) -> fmt::Result { + let needs_parens = self.prec() < parent_prec; + if needs_parens { + write!(f, "(")?; + } + match self { + ConditionExpr::Atom(atom) => write!(f, "{atom}")?, + ConditionExpr::Not(inner) => { + write!(f, "!")?; + inner.fmt_with_prec(f, Prec::Not)?; + } + ConditionExpr::And(lhs, rhs) => { + lhs.fmt_with_prec(f, Prec::And)?; + write!(f, " && ")?; + rhs.fmt_with_prec(f, Prec::And)?; + } + ConditionExpr::Or(lhs, rhs) => { + lhs.fmt_with_prec(f, Prec::Or)?; + write!(f, " || ")?; + rhs.fmt_with_prec(f, Prec::Or)?; + } + } + if needs_parens { + write!(f, ")")?; + } + Ok(()) + } +} + +impl fmt::Display for ConditionExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.fmt_with_prec(f, Prec::Or) + } +} + +// --------------------------------------------------------------------------- +// Serde +// --------------------------------------------------------------------------- + +impl Serialize for ConditionExpr { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for ConditionExpr { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { + let s = String::deserialize(deserializer)?; + ConditionExpr::parse(&s).map_err(serde::de::Error::custom) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn ctx( + cursor: usize, + width: usize, + byte_len: usize, + selected: usize, + len: usize, + ) -> EvalContext { + ctx_with_original(cursor, width, byte_len, selected, len, false) + } + + fn ctx_with_original( + cursor: usize, + width: usize, + byte_len: usize, + selected: usize, + len: usize, + original_input_empty: bool, + ) -> EvalContext { + EvalContext { + cursor_position: cursor, + input_width: width, + input_byte_len: byte_len, + selected_index: selected, + results_len: len, + original_input_empty, + has_context: false, + } + } + + // -- Atom evaluation (carried over from Phase 0) -- + + #[test] + fn atom_cursor_at_start() { + assert!(ConditionAtom::CursorAtStart.evaluate(&ctx(0, 5, 5, 0, 10))); + assert!(!ConditionAtom::CursorAtStart.evaluate(&ctx(3, 5, 5, 0, 10))); + } + + #[test] + fn atom_cursor_at_end() { + assert!(ConditionAtom::CursorAtEnd.evaluate(&ctx(5, 5, 5, 0, 10))); + assert!(!ConditionAtom::CursorAtEnd.evaluate(&ctx(3, 5, 5, 0, 10))); + assert!(ConditionAtom::CursorAtEnd.evaluate(&ctx(0, 0, 0, 0, 10))); + } + + #[test] + fn atom_input_empty() { + assert!(ConditionAtom::InputEmpty.evaluate(&ctx(0, 0, 0, 0, 10))); + assert!(!ConditionAtom::InputEmpty.evaluate(&ctx(0, 5, 5, 0, 10))); + } + + #[test] + fn atom_original_input_empty() { + // original_input_empty = true + assert!( + ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 0, 0, 0, 10, true)) + ); + // original_input_empty = false + assert!( + !ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 0, 0, 0, 10, false)) + ); + // original_input_empty is independent of current input state + assert!( + ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 5, 5, 0, 10, true)) + ); + } + + #[test] + fn atom_list_at_end() { + assert!(ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 99, 100))); + assert!(!ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 50, 100))); + assert!(ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn atom_list_at_start() { + assert!(ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 0, 100))); + assert!(!ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 50, 100))); + assert!(ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn atom_no_results_and_has_results() { + assert!(ConditionAtom::NoResults.evaluate(&ctx(0, 0, 0, 0, 0))); + assert!(!ConditionAtom::NoResults.evaluate(&ctx(0, 0, 0, 0, 5))); + assert!(ConditionAtom::HasResults.evaluate(&ctx(0, 0, 0, 0, 5))); + assert!(!ConditionAtom::HasResults.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn atom_has_context() { + let mut context = ctx(0, 0, 0, 0, 0); + assert!(!ConditionAtom::HasContext.evaluate(&context)); + context.has_context = true; + assert!(ConditionAtom::HasContext.evaluate(&context)); + } + + #[test] + fn atom_parse_round_trip() { + let conditions = [ + "cursor-at-start", + "cursor-at-end", + "input-empty", + "original-input-empty", + "list-at-end", + "list-at-start", + "no-results", + "has-results", + ]; + for s in conditions { + let c = ConditionAtom::from_str(s).unwrap(); + assert_eq!(c.as_str(), s); + } + } + + #[test] + fn atom_parse_unknown() { + assert!(ConditionAtom::from_str("unknown-condition").is_err()); + } + + // -- Parser tests -- + + #[test] + fn parse_bare_atom() { + let expr = ConditionExpr::parse("cursor-at-start").unwrap(); + assert_eq!(expr, ConditionExpr::Atom(ConditionAtom::CursorAtStart)); + } + + #[test] + fn parse_negation() { + let expr = ConditionExpr::parse("!no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Not(Box::new(ConditionExpr::Atom(ConditionAtom::NoResults))) + ); + } + + #[test] + fn parse_double_negation() { + let expr = ConditionExpr::parse("!!no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Not(Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( + ConditionAtom::NoResults + ))))) + ); + } + + #[test] + fn parse_and() { + let expr = ConditionExpr::parse("cursor-at-start && input-empty").unwrap(); + assert_eq!( + expr, + ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), + ) + ); + } + + #[test] + fn parse_or() { + let expr = ConditionExpr::parse("list-at-start || no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Or( + Box::new(ConditionExpr::Atom(ConditionAtom::ListAtStart)), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + ) + ); + } + + #[test] + fn parse_precedence_and_binds_tighter_than_or() { + // "a || b && c" should parse as "a || (b && c)" + let expr = ConditionExpr::parse("cursor-at-start || input-empty && no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Or( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + )), + ) + ); + } + + #[test] + fn parse_parens_override_precedence() { + // "(a || b) && c" + let expr = ConditionExpr::parse("(cursor-at-start || input-empty) && no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::And( + Box::new(ConditionExpr::Or( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), + )), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + ) + ); + } + + #[test] + fn parse_complex_nested() { + // "(a && !b) || c" + let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Or( + Box::new(ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( + ConditionAtom::InputEmpty + )))), + )), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + ) + ); + } + + #[test] + fn parse_whitespace_tolerance() { + let a = ConditionExpr::parse("cursor-at-start||input-empty").unwrap(); + let b = ConditionExpr::parse("cursor-at-start || input-empty").unwrap(); + let c = ConditionExpr::parse(" cursor-at-start || input-empty ").unwrap(); + assert_eq!(a, b); + assert_eq!(b, c); + } + + #[test] + fn parse_error_unknown_atom() { + assert!(ConditionExpr::parse("unknown-thing").is_err()); + } + + #[test] + fn parse_error_trailing_input() { + assert!(ConditionExpr::parse("cursor-at-start blah").is_err()); + } + + #[test] + fn parse_error_unmatched_paren() { + assert!(ConditionExpr::parse("(cursor-at-start").is_err()); + } + + #[test] + fn parse_error_empty() { + assert!(ConditionExpr::parse("").is_err()); + } + + // -- Expression evaluation -- + + #[test] + fn eval_not() { + let expr = ConditionExpr::parse("!no-results").unwrap(); + // Has results → !no-results is true + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 5))); + // No results → !no-results is false + assert!(!expr.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn eval_and() { + let expr = ConditionExpr::parse("cursor-at-start && input-empty").unwrap(); + // Both true + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 10))); + // First true, second false (non-empty input) + assert!(!expr.evaluate(&ctx(0, 5, 5, 0, 10))); + // First false (cursor not at start) + assert!(!expr.evaluate(&ctx(3, 5, 5, 0, 10))); + } + + #[test] + fn eval_or() { + let expr = ConditionExpr::parse("list-at-start || no-results").unwrap(); + // list at bottom (selected=0) + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 10))); + // no results + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 0))); + // neither + assert!(!expr.evaluate(&ctx(0, 0, 0, 5, 10))); + } + + #[test] + fn eval_complex_nested() { + // (cursor-at-start && !input-empty) || no-results + let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); + + // cursor at start, input not empty → true (left branch) + assert!(expr.evaluate(&ctx(0, 5, 5, 0, 10))); + // no results → true (right branch) + assert!(expr.evaluate(&ctx(3, 5, 5, 0, 0))); + // cursor not at start, has results → false + assert!(!expr.evaluate(&ctx(3, 5, 5, 0, 10))); + // cursor at start, input empty → false (left: && fails; right: has results) + assert!(!expr.evaluate(&ctx(0, 0, 0, 0, 10))); + } + + // -- Display -- + + #[test] + fn display_atom() { + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart); + assert_eq!(expr.to_string(), "cursor-at-start"); + } + + #[test] + fn display_not() { + let expr = ConditionExpr::Atom(ConditionAtom::NoResults).not(); + assert_eq!(expr.to_string(), "!no-results"); + } + + #[test] + fn display_and() { + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart) + .and(ConditionExpr::Atom(ConditionAtom::InputEmpty)); + assert_eq!(expr.to_string(), "cursor-at-start && input-empty"); + } + + #[test] + fn display_or() { + let expr = ConditionExpr::Atom(ConditionAtom::ListAtStart) + .or(ConditionExpr::Atom(ConditionAtom::NoResults)); + assert_eq!(expr.to_string(), "list-at-start || no-results"); + } + + #[test] + fn display_parens_when_needed() { + // (a || b) && c — the Or inside And needs parens + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart) + .or(ConditionExpr::Atom(ConditionAtom::InputEmpty)) + .and(ConditionExpr::Atom(ConditionAtom::NoResults)); + assert_eq!( + expr.to_string(), + "(cursor-at-start || input-empty) && no-results" + ); + } + + #[test] + fn display_no_parens_when_not_needed() { + // a || b && c — no parens needed (and binds tighter) + let inner_and = ConditionExpr::Atom(ConditionAtom::InputEmpty) + .and(ConditionExpr::Atom(ConditionAtom::NoResults)); + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart).or(inner_and); + assert_eq!( + expr.to_string(), + "cursor-at-start || input-empty && no-results" + ); + } + + // -- Display round-trip -- + + #[test] + fn display_round_trip() { + let cases = [ + "cursor-at-start", + "!no-results", + "cursor-at-start && input-empty", + "list-at-start || no-results", + "(cursor-at-start || input-empty) && no-results", + "(cursor-at-start && !input-empty) || no-results", + ]; + for s in cases { + let expr = ConditionExpr::parse(s).unwrap(); + let displayed = expr.to_string(); + let reparsed = ConditionExpr::parse(&displayed).unwrap(); + assert_eq!(expr, reparsed, "round-trip failed for: {s}"); + } + } + + // -- Serde -- + + #[test] + fn serde_simple_atom() { + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart); + let json = serde_json::to_string(&expr).unwrap(); + assert_eq!(json, "\"cursor-at-start\""); + let parsed: ConditionExpr = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, expr); + } + + #[test] + fn serde_compound_expression() { + let json = "\"cursor-at-start && !input-empty\""; + let parsed: ConditionExpr = serde_json::from_str(json).unwrap(); + let expected = ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( + ConditionAtom::InputEmpty, + )))), + ); + assert_eq!(parsed, expected); + } + + #[test] + fn serde_round_trip() { + let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); + let json = serde_json::to_string(&expr).unwrap(); + let parsed: ConditionExpr = serde_json::from_str(&json).unwrap(); + assert_eq!(expr, parsed); + } + + // -- From<ConditionAtom> -- + + #[test] + fn from_atom_into_expr() { + let expr: ConditionExpr = ConditionAtom::CursorAtStart.into(); + assert_eq!(expr, ConditionExpr::Atom(ConditionAtom::CursorAtStart)); + } + + // -- Builder helpers -- + + #[test] + fn builder_chain() { + let expr = ConditionExpr::from(ConditionAtom::CursorAtStart) + .and(ConditionExpr::from(ConditionAtom::InputEmpty).not()) + .or(ConditionExpr::from(ConditionAtom::NoResults)); + // And binds tighter than Or, so no parens needed around the And + assert_eq!( + expr.to_string(), + "cursor-at-start && !input-empty || no-results" + ); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/defaults.rs b/crates/turtle/src/command/client/search/keybindings/defaults.rs new file mode 100644 index 00000000..c8401e37 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/defaults.rs @@ -0,0 +1,1286 @@ +use std::collections::HashMap; + +use crate::atuin_client::settings::{KeyBindingConfig, Settings}; +use tracing::warn; + +use super::actions::Action; +use super::conditions::{ConditionAtom, ConditionExpr}; +use super::key::KeyInput; +use super::keymap::{KeyBinding, KeyRule, Keymap}; + +/// Helper to bind a scroll key with optional exit behavior. +/// +/// When `scroll_exits` is true AND the key scrolls toward index 0 (the newest +/// entry), we add a conditional rule: at `ListAtStart` → `Exit`, otherwise → +/// the scroll action. +/// +/// Whether a key scrolls toward index 0 depends on the `invert` setting: +/// - Non-inverted: "down" / "j" move toward index 0, "up" / "k" move away +/// - Inverted: "up" / "k" move toward index 0, "down" / "j" move away +/// +/// If `toward_index_zero` is false, or `scroll_exits` is false, we just bind +/// the key to the plain scroll action (no exit). +fn bind_scroll_key( + km: &mut Keymap, + key_str: &str, + action: Action, + toward_index_zero: bool, + scroll_exits: bool, +) { + let k = key(key_str); + if scroll_exits && toward_index_zero { + km.bind_conditional( + k, + vec![ + KeyRule::when(ConditionAtom::ListAtStart, Action::Exit), + KeyRule::always(action), + ], + ); + } else { + km.bind(k, action); + } +} + +/// Helper to parse a key string, panicking on invalid keys (these are all +/// compile-time-known strings). +fn key(s: &str) -> KeyInput { + KeyInput::parse(s).unwrap_or_else(|e| panic!("invalid default key {s:?}: {e}")) +} + +/// All five keymaps bundled together. +#[derive(Debug, Clone)] +pub struct KeymapSet { + pub emacs: Keymap, + pub vim_normal: Keymap, + pub vim_insert: Keymap, + pub inspector: Keymap, + pub prefix: Keymap, +} + +// --------------------------------------------------------------------------- +// Common bindings shared across search-tab keymaps +// --------------------------------------------------------------------------- + +/// Add the bindings that are common to all search-tab keymaps: +/// ctrl-c, ctrl-g, ctrl-o, and tab. +/// +/// Note: `esc`/`ctrl-[` are NOT included here because their behavior differs +/// between emacs (exit), vim-normal (exit), and vim-insert (enter normal mode). +fn add_common_bindings(km: &mut Keymap) { + km.bind(key("ctrl-c"), Action::ReturnOriginal); + km.bind(key("ctrl-g"), Action::ReturnOriginal); + km.bind(key("ctrl-o"), Action::ToggleTab); + + // Tab: always returns selection without executing (unlike Enter which respects enter_accept) + km.bind(key("tab"), Action::ReturnSelection); +} + +/// Returns `Accept` or `ReturnSelection` based on the `enter_accept` setting. +fn accept_action(settings: &Settings) -> Action { + if settings.enter_accept { + Action::Accept + } else { + Action::ReturnSelection + } +} + +// --------------------------------------------------------------------------- +// Emacs keymap (also base for vim-insert) +// --------------------------------------------------------------------------- + +/// Build the default emacs keymap. This encodes the behavior from +/// `handle_key_input` common section + `handle_search_input` shared section. +/// +/// The `settings` parameter is used for: +/// - `keys.prefix` — which ctrl-key enters prefix mode +/// - `keys.scroll_exits`, `invert` — scroll-at-boundary exit behavior +/// - `keys.accept_past_line_end` — right arrow at end of line accepts +/// - `keys.exit_past_line_start` — left arrow at start of line exits +/// - `keys.accept_past_line_start` — left arrow at start accepts (overrides exit) +/// - `keys.accept_with_backspace` — backspace at start of line accepts +/// - `ctrl_n_shortcuts` — whether alt or ctrl is used for numeric shortcuts +// Keymap builder that enumerates every default binding; not worth splitting. +#[expect(clippy::too_many_lines)] +pub fn default_emacs_keymap(settings: &Settings) -> Keymap { + let mut km = Keymap::new(); + add_common_bindings(&mut km); + + let accept = accept_action(settings); + + // esc / ctrl-[ → exit + km.bind(key("esc"), Action::Exit); + km.bind(key("ctrl-["), Action::Exit); + + // Prefix key: ctrl-<prefix_char> → enter prefix mode + let prefix_char = settings.keys.prefix.chars().next().unwrap_or('a'); + km.bind(key(&format!("ctrl-{prefix_char}")), Action::EnterPrefixMode); + + // --- Accept / navigation edge behaviors (from [keys] settings) --- + + // right: behavior at end of line + if settings.keys.accept_past_line_end { + km.bind_conditional( + key("right"), + vec![ + KeyRule::when(ConditionAtom::CursorAtEnd, Action::ReturnSelection), + KeyRule::always(Action::CursorRight), + ], + ); + } else { + km.bind(key("right"), Action::CursorRight); + } + + // left: behavior at start of line + // accept_past_line_start takes precedence over exit_past_line_start + if settings.keys.accept_past_line_start { + km.bind_conditional( + key("left"), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::ReturnSelection), + KeyRule::always(Action::CursorLeft), + ], + ); + } else if settings.keys.exit_past_line_start { + km.bind_conditional( + key("left"), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit), + KeyRule::always(Action::CursorLeft), + ], + ); + } else { + km.bind(key("left"), Action::CursorLeft); + } + + // down/up: scroll with optional exit at boundary. + // Non-inverted: down moves toward index 0 (can exit); up moves away (no exit). + // Inverted: up moves toward index 0 (can exit); down moves away (no exit). + let scroll_exits = settings.keys.scroll_exits; + let invert = settings.invert; + bind_scroll_key(&mut km, "down", Action::SelectNext, !invert, scroll_exits); + bind_scroll_key(&mut km, "up", Action::SelectPrevious, invert, scroll_exits); + + // backspace: behavior at start of line + if settings.keys.accept_with_backspace { + km.bind_conditional( + key("backspace"), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::ReturnSelection), + KeyRule::always(Action::DeleteCharBefore), + ], + ); + } else { + km.bind(key("backspace"), Action::DeleteCharBefore); + } + + // --- Accept --- + km.bind(key("enter"), accept.clone()); + km.bind(key("ctrl-m"), accept); + + // --- Copy --- + km.bind(key("ctrl-y"), Action::Copy); + + // --- Numeric shortcuts (alt-1..9 by default, ctrl-1..9 if ctrl_n_shortcuts) --- + // These return the selection without executing, regardless of enter_accept. + let num_mod = if settings.ctrl_n_shortcuts { + "ctrl" + } else { + "alt" + }; + for n in 1..=9u8 { + km.bind( + key(&format!("{num_mod}-{n}")), + Action::ReturnSelectionNth(n), + ); + } + + // --- Cursor movement --- + km.bind(key("ctrl-left"), Action::CursorWordLeft); + km.bind(key("alt-b"), Action::CursorWordLeft); + km.bind(key("ctrl-b"), Action::CursorLeft); + km.bind(key("ctrl-right"), Action::CursorWordRight); + km.bind(key("alt-f"), Action::CursorWordRight); + km.bind(key("ctrl-f"), Action::CursorRight); + km.bind(key("home"), Action::CursorStart); + // ctrl-a → CursorStart only if prefix char is NOT 'a' + // (otherwise ctrl-a is already bound to EnterPrefixMode above) + if prefix_char != 'a' { + km.bind(key("ctrl-a"), Action::CursorStart); + } + km.bind(key("ctrl-e"), Action::CursorEnd); + km.bind(key("end"), Action::CursorEnd); + + // --- Editing --- + km.bind(key("ctrl-backspace"), Action::DeleteWordBefore); + km.bind(key("ctrl-h"), Action::DeleteCharBefore); + km.bind(key("ctrl-?"), Action::DeleteCharBefore); + km.bind(key("ctrl-delete"), Action::DeleteWordAfter); + km.bind(key("delete"), Action::DeleteCharAfter); + // ctrl-d: if input empty → return original, otherwise delete char + km.bind_conditional( + key("ctrl-d"), + vec![ + KeyRule::when(ConditionAtom::InputEmpty, Action::ReturnOriginal), + KeyRule::always(Action::DeleteCharAfter), + ], + ); + km.bind(key("ctrl-w"), Action::DeleteToWordBoundary); + km.bind(key("ctrl-u"), Action::ClearLine); + + // --- Search mode --- + km.bind(key("ctrl-r"), Action::CycleFilterMode); + km.bind(key("ctrl-s"), Action::CycleSearchMode); + + // --- Scroll (no exit) --- + km.bind(key("ctrl-n"), Action::SelectNext); + km.bind(key("ctrl-j"), Action::SelectNext); + km.bind(key("ctrl-p"), Action::SelectPrevious); + km.bind(key("ctrl-k"), Action::SelectPrevious); + + // --- Redraw --- + km.bind(key("ctrl-l"), Action::Redraw); + + // --- Page scroll --- + km.bind(key("pagedown"), Action::ScrollPageDown); + km.bind(key("pageup"), Action::ScrollPageUp); + + km +} + +// --------------------------------------------------------------------------- +// Vim Normal keymap +// --------------------------------------------------------------------------- + +/// Build the default vim-normal keymap. +pub fn default_vim_normal_keymap(settings: &Settings) -> Keymap { + let mut km = Keymap::new(); + add_common_bindings(&mut km); + + // esc / ctrl-[ → exit (vim-normal exits, unlike vim-insert) + km.bind(key("esc"), Action::Exit); + km.bind(key("ctrl-["), Action::Exit); + + // Prefix key + let prefix_char = settings.keys.prefix.chars().next().unwrap_or('a'); + km.bind(key(&format!("ctrl-{prefix_char}")), Action::EnterPrefixMode); + + // --- Vim navigation --- + // j/k: scroll with optional exit at boundary. + let scroll_exits = settings.keys.scroll_exits; + let invert = settings.invert; + bind_scroll_key(&mut km, "j", Action::SelectNext, !invert, scroll_exits); + bind_scroll_key(&mut km, "k", Action::SelectPrevious, invert, scroll_exits); + km.bind(key("h"), Action::CursorLeft); + km.bind(key("l"), Action::CursorRight); + + // --- Vim cursor movement --- + km.bind(key("0"), Action::CursorStart); + km.bind(key("$"), Action::CursorEnd); + km.bind(key("w"), Action::CursorWordRight); + km.bind(key("b"), Action::CursorWordLeft); + km.bind(key("e"), Action::CursorWordEnd); + + // --- Vim editing --- + km.bind(key("x"), Action::DeleteCharAfter); + km.bind(key("d d"), Action::ClearLine); + km.bind(key("D"), Action::ClearToEnd); + km.bind(key("C"), Action::VimChangeToEnd); + + // --- Mode switching --- + km.bind(key("?"), Action::VimSearchInsert); + km.bind(key("/"), Action::VimSearchInsert); + km.bind(key("a"), Action::VimEnterInsertAfter); + km.bind(key("A"), Action::VimEnterInsertAtEnd); + km.bind(key("i"), Action::VimEnterInsert); + km.bind(key("I"), Action::VimEnterInsertAtStart); + + // --- Numeric shortcuts (return selection without executing) --- + for n in 1..=9u8 { + km.bind(key(&n.to_string()), Action::ReturnSelectionNth(n)); + } + + // --- Half/full page scroll --- + km.bind(key("ctrl-u"), Action::ScrollHalfPageUp); + km.bind(key("ctrl-d"), Action::ScrollHalfPageDown); + km.bind(key("ctrl-b"), Action::ScrollPageUp); + km.bind(key("ctrl-f"), Action::ScrollPageDown); + + // --- Jump --- + km.bind(key("G"), Action::ScrollToBottom); + km.bind(key("g g"), Action::ScrollToTop); + km.bind(key("H"), Action::ScrollToScreenTop); + km.bind(key("M"), Action::ScrollToScreenMiddle); + km.bind(key("L"), Action::ScrollToScreenBottom); + + // --- Arrow keys (same as emacs for convenience) --- + bind_scroll_key(&mut km, "down", Action::SelectNext, !invert, scroll_exits); + bind_scroll_key(&mut km, "up", Action::SelectPrevious, invert, scroll_exits); + + // --- Page scroll --- + km.bind(key("pagedown"), Action::ScrollPageDown); + km.bind(key("pageup"), Action::ScrollPageUp); + + // --- Accept --- + let accept = accept_action(settings); + km.bind(key("enter"), accept); + + km +} + +// --------------------------------------------------------------------------- +// Vim Insert keymap +// --------------------------------------------------------------------------- + +/// Build the default vim-insert keymap. This clones the emacs keymap and +/// overlays vim-insert-specific bindings (esc → enter normal mode). +pub fn default_vim_insert_keymap(settings: &Settings) -> Keymap { + let mut km = default_emacs_keymap(settings); + + // Override esc and ctrl-[ to enter normal mode instead of exiting + km.bind(key("esc"), Action::VimEnterNormal); + km.bind(key("ctrl-["), Action::VimEnterNormal); + + km +} + +// --------------------------------------------------------------------------- +// Inspector keymap +// --------------------------------------------------------------------------- + +/// Build the default inspector keymap (tab index 1). +/// +/// The inspector shows details about the selected history item and has no +/// text input, so we build a minimal keymap with only inspector-relevant +/// bindings. We respect the user's `keymap_mode` to provide vim-style j/k +/// navigation for vim users. +pub fn default_inspector_keymap(settings: &Settings) -> Keymap { + use crate::atuin_client::settings::KeymapMode; + + let mut km = Keymap::new(); + + // Common bindings (same as search tab) + km.bind(key("ctrl-c"), Action::ReturnOriginal); + km.bind(key("ctrl-g"), Action::ReturnOriginal); + km.bind(key("esc"), Action::Exit); + km.bind(key("ctrl-["), Action::Exit); + km.bind(key("tab"), Action::ReturnSelection); + km.bind(key("ctrl-o"), Action::ToggleTab); + + // Accept behavior respects enter_accept setting + let accept = if settings.enter_accept { + Action::Accept + } else { + Action::ReturnSelection + }; + km.bind(key("enter"), accept); + + // Inspector-specific: delete history entry + km.bind(key("ctrl-d"), Action::Delete); + + // Inspector navigation + km.bind(key("up"), Action::InspectPrevious); + km.bind(key("down"), Action::InspectNext); + km.bind(key("pageup"), Action::InspectPrevious); + km.bind(key("pagedown"), Action::InspectNext); + + // For vim users, add j/k navigation + if matches!( + settings.keymap_mode, + KeymapMode::VimNormal | KeymapMode::VimInsert + ) { + km.bind(key("j"), Action::InspectNext); + km.bind(key("k"), Action::InspectPrevious); + } + + km +} + +// --------------------------------------------------------------------------- +// Prefix keymap +// --------------------------------------------------------------------------- + +/// Build the default prefix keymap (active after ctrl-a prefix). +pub fn default_prefix_keymap() -> Keymap { + let mut km = Keymap::new(); + + km.bind(key("d"), Action::Delete); + km.bind(key("D"), Action::DeleteAll); + km.bind(key("a"), Action::CursorStart); + km.bind_conditional( + key("c"), + vec![ + KeyRule::when(ConditionAtom::HasContext, Action::ClearContext), + KeyRule::always(Action::SwitchContext), + ], + ); + + km +} + +// --------------------------------------------------------------------------- +// KeymapSet construction +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// Config → Keymap conversion +// --------------------------------------------------------------------------- + +/// Convert a `KeyBindingConfig` (from TOML) into a `KeyBinding`. +/// Returns `Err` if an action name or condition expression is invalid. +fn parse_binding_config(config: &KeyBindingConfig) -> Result<KeyBinding, String> { + match config { + KeyBindingConfig::Simple(action_str) => { + let action = Action::from_str(action_str)?; + Ok(KeyBinding::simple(action)) + } + KeyBindingConfig::Rules(rules) => { + let mut parsed_rules = Vec::with_capacity(rules.len()); + for rule_cfg in rules { + let action = Action::from_str(&rule_cfg.action)?; + let rule = match &rule_cfg.when { + None => KeyRule::always(action), + Some(cond_str) => { + let cond = ConditionExpr::parse(cond_str)?; + KeyRule::when(cond, action) + } + }; + parsed_rules.push(rule); + } + Ok(KeyBinding::conditional(parsed_rules)) + } + } +} + +/// Apply a map of key-string → binding-config overrides to a keymap. +/// Per-key override replaces the entire rule list for that key. +/// Invalid keys or action names are logged and skipped. +fn apply_config_to_keymap(keymap: &mut Keymap, overrides: &HashMap<String, KeyBindingConfig>) { + for (key_str, binding_cfg) in overrides { + let key = match KeyInput::parse(key_str) { + Ok(k) => k, + Err(e) => { + warn!("invalid key in keymap config: {key_str:?}: {e}"); + continue; + } + }; + match parse_binding_config(binding_cfg) { + Ok(binding) => { + keymap.bindings.insert(key, binding); + } + Err(e) => { + warn!("invalid binding for {key_str:?} in keymap config: {e}"); + } + } + } +} + +impl KeymapSet { + /// Build the complete set of default keymaps from settings. + pub fn defaults(settings: &Settings) -> Self { + KeymapSet { + emacs: default_emacs_keymap(settings), + vim_normal: default_vim_normal_keymap(settings), + vim_insert: default_vim_insert_keymap(settings), + inspector: default_inspector_keymap(settings), + prefix: default_prefix_keymap(), + } + } + + /// Build keymaps from settings, applying any user `[keymap]` overrides. + /// + /// Precedence rules: + /// - If `[keymap]` has any entries, `[keys]` is **ignored entirely**. + /// Defaults are built with standard `[keys]` values, then `[keymap]` + /// overrides are applied per-key. + /// - If `[keymap]` is empty/absent, `[keys]` customizes the defaults + /// (current behavior for backward compatibility). + pub fn from_settings(settings: &Settings) -> Self { + use crate::atuin_client::settings::Keys; + + if settings.keymap.is_empty() { + // No [keymap] section → use [keys] to customize defaults + Self::defaults(settings) + } else { + // [keymap] present → ignore [keys], use standard defaults as base + let mut base_settings = settings.clone(); + base_settings.keys = Keys::standard_defaults(); + let mut set = Self::defaults(&base_settings); + set.apply_config(settings); + set + } + } + + /// Apply user keymap config overrides to all modes. + fn apply_config(&mut self, settings: &Settings) { + let config = &settings.keymap; + apply_config_to_keymap(&mut self.emacs, &config.emacs); + apply_config_to_keymap(&mut self.vim_normal, &config.vim_normal); + apply_config_to_keymap(&mut self.vim_insert, &config.vim_insert); + apply_config_to_keymap(&mut self.inspector, &config.inspector); + apply_config_to_keymap(&mut self.prefix, &config.prefix); + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::command::client::search::keybindings::conditions::EvalContext; + + fn make_ctx(cursor: usize, width: usize, selected: usize, len: usize) -> EvalContext { + EvalContext { + cursor_position: cursor, + input_width: width, + input_byte_len: width, + selected_index: selected, + results_len: len, + original_input_empty: false, + has_context: false, + } + } + + fn default_settings() -> Settings { + Settings::utc() + } + + // -- Emacs keymap tests -- + + #[test] + fn emacs_ctrl_c_returns_original() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-c"), &ctx), + Some(Action::ReturnOriginal) + ); + } + + #[test] + fn emacs_esc_exits() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::Exit)); + } + + #[test] + fn emacs_tab_returns_selection() { + // enter_accept=false in test defaults → ReturnSelection + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); + } + + #[test] + fn emacs_enter_returns_selection() { + // enter_accept=false in test defaults → ReturnSelection + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn emacs_enter_accept_true_uses_accept() { + let mut settings = default_settings(); + settings.enter_accept = true; + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("enter"), &ctx), Some(Action::Accept)); + assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); + } + + #[test] + fn emacs_right_at_end_returns_selection() { + let km = default_emacs_keymap(&default_settings()); + // cursor at end of "hello" (width 5) + let ctx = make_ctx(5, 5, 0, 10); + assert_eq!( + km.resolve(&key("right"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn emacs_right_not_at_end_moves() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(2, 5, 0, 10); + assert_eq!(km.resolve(&key("right"), &ctx), Some(Action::CursorRight)); + } + + #[test] + fn emacs_left_at_start_exits() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 5, 0, 10); + assert_eq!(km.resolve(&key("left"), &ctx), Some(Action::Exit)); + } + + #[test] + fn emacs_left_not_at_start_moves() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(3, 5, 0, 10); + assert_eq!(km.resolve(&key("left"), &ctx), Some(Action::CursorLeft)); + } + + #[test] + fn emacs_down_at_start_exits() { + let km = default_emacs_keymap(&default_settings()); + // selected=0 → ListAtStart → Exit + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::Exit)); + } + + #[test] + fn emacs_down_not_at_start_selects_next() { + let km = default_emacs_keymap(&default_settings()); + // selected=5 → not at start → SelectNext + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::SelectNext)); + } + + #[test] + fn emacs_up_selects_previous() { + let km = default_emacs_keymap(&default_settings()); + // Non-inverted: up never exits (moves away from index 0) + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("up"), &ctx), Some(Action::SelectPrevious)); + } + + #[test] + fn emacs_ctrl_d_empty_returns_original() { + let km = default_emacs_keymap(&default_settings()); + // input empty (byte_len = 0) + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-d"), &ctx), + Some(Action::ReturnOriginal) + ); + } + + #[test] + fn emacs_ctrl_d_nonempty_deletes() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(2, 5, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-d"), &ctx), + Some(Action::DeleteCharAfter) + ); + } + + #[test] + fn emacs_ctrl_n_selects_next_no_exit_condition() { + let km = default_emacs_keymap(&default_settings()); + // at start, but ctrl-n should NOT exit (no exit condition bound) + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("ctrl-n"), &ctx), Some(Action::SelectNext)); + } + + #[test] + fn emacs_prefix_key_enters_prefix() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-a"), &ctx), + Some(Action::EnterPrefixMode) + ); + } + + #[test] + fn emacs_home_cursor_start() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(5, 10, 0, 10); + assert_eq!(km.resolve(&key("home"), &ctx), Some(Action::CursorStart)); + } + + // -- Vim Normal keymap tests -- + + #[test] + fn vim_normal_j_at_start_exits() { + let km = default_vim_normal_keymap(&default_settings()); + // selected=0 → ListAtStart → Exit (non-inverted: j moves toward index 0) + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("j"), &ctx), Some(Action::Exit)); + } + + #[test] + fn vim_normal_j_not_at_start_selects_next() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("j"), &ctx), Some(Action::SelectNext)); + } + + #[test] + fn vim_normal_k_selects_previous() { + let km = default_vim_normal_keymap(&default_settings()); + // Non-inverted: k never exits (moves away from index 0) + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("k"), &ctx), Some(Action::SelectPrevious)); + } + + #[test] + fn vim_normal_i_enters_insert() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("i"), &ctx), Some(Action::VimEnterInsert)); + } + + #[test] + fn vim_normal_slash_search_insert() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("/"), &ctx), Some(Action::VimSearchInsert)); + } + + #[test] + fn vim_normal_gg_scroll_to_top() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!(km.resolve(&key("g g"), &ctx), Some(Action::ScrollToTop)); + } + + #[test] + fn vim_normal_big_g_scroll_to_bottom() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!(km.resolve(&key("G"), &ctx), Some(Action::ScrollToBottom)); + } + + #[test] + fn vim_normal_numeric_returns_selection() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("3"), &ctx), + Some(Action::ReturnSelectionNth(3)) + ); + } + + #[test] + fn vim_normal_ctrl_u_half_page_up() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!( + km.resolve(&key("ctrl-u"), &ctx), + Some(Action::ScrollHalfPageUp) + ); + } + + #[test] + fn vim_normal_screen_jumps() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!(km.resolve(&key("H"), &ctx), Some(Action::ScrollToScreenTop)); + assert_eq!( + km.resolve(&key("M"), &ctx), + Some(Action::ScrollToScreenMiddle) + ); + assert_eq!( + km.resolve(&key("L"), &ctx), + Some(Action::ScrollToScreenBottom) + ); + } + + #[test] + fn vim_normal_enter_returns_selection() { + // enter_accept=false in test defaults → ReturnSelection + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn vim_normal_enter_accept_true_uses_accept() { + let mut settings = default_settings(); + settings.enter_accept = true; + let km = default_vim_normal_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("enter"), &ctx), Some(Action::Accept)); + } + + // -- Vim Insert keymap tests -- + + #[test] + fn vim_insert_inherits_emacs_enter() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + // enter_accept=false → ReturnSelection + assert_eq!( + km.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn vim_insert_esc_enters_normal() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::VimEnterNormal)); + } + + #[test] + fn vim_insert_ctrl_bracket_enters_normal() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-["), &ctx), + Some(Action::VimEnterNormal) + ); + } + + #[test] + fn vim_insert_inherits_emacs_ctrl_d() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + // input empty → return original + assert_eq!( + km.resolve(&key("ctrl-d"), &ctx), + Some(Action::ReturnOriginal) + ); + } + + // -- Inspector keymap tests -- + + #[test] + fn inspector_ctrl_d_deletes() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("ctrl-d"), &ctx), Some(Action::Delete)); + } + + #[test] + fn inspector_up_inspects_previous() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("up"), &ctx), Some(Action::InspectPrevious)); + } + + #[test] + fn inspector_down_inspects_next() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::InspectNext)); + } + + #[test] + fn inspector_esc_exits() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::Exit)); + } + + #[test] + fn inspector_tab_returns_selection() { + // enter_accept=false → ReturnSelection + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); + } + + // -- Prefix keymap tests -- + + #[test] + fn prefix_d_deletes() { + let km = default_prefix_keymap(); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("d"), &ctx), Some(Action::Delete)); + } + + #[test] + fn prefix_a_cursor_start() { + let km = default_prefix_keymap(); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("a"), &ctx), Some(Action::CursorStart)); + } + + #[test] + fn prefix_unknown_key_returns_none() { + let km = default_prefix_keymap(); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("x"), &ctx), None); + } + + // -- KeymapSet tests -- + + #[test] + fn keymap_set_defaults_builds() { + let settings = default_settings(); + let set = KeymapSet::defaults(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // Sanity check each keymap has bindings + assert!(set.emacs.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.vim_normal.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.vim_insert.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.inspector.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.prefix.resolve(&key("d"), &ctx).is_some()); + } + + // -- Settings-dependent behavior -- + + #[test] + fn custom_prefix_char() { + let mut settings = default_settings(); + settings.keys.prefix = "x".to_string(); + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // ctrl-x should be prefix mode + assert_eq!( + km.resolve(&key("ctrl-x"), &ctx), + Some(Action::EnterPrefixMode) + ); + // ctrl-a should now be CursorStart (not prefix) + assert_eq!(km.resolve(&key("ctrl-a"), &ctx), Some(Action::CursorStart)); + } + + #[test] + fn ctrl_n_shortcuts_changes_numeric_modifier() { + let mut settings = default_settings(); + settings.ctrl_n_shortcuts = true; + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // ctrl-1 should work + assert_eq!( + km.resolve(&key("ctrl-1"), &ctx), + Some(Action::ReturnSelectionNth(1)) + ); + // alt-1 should NOT be bound + assert_eq!(km.resolve(&key("alt-1"), &ctx), None); + } + + #[test] + fn default_alt_numeric_shortcuts() { + let settings = default_settings(); + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // alt-1 should work by default + assert_eq!( + km.resolve(&key("alt-1"), &ctx), + Some(Action::ReturnSelectionNth(1)) + ); + } + + // ----------------------------------------------------------------------- + // Config parsing and merging tests + // ----------------------------------------------------------------------- + + #[test] + fn parse_simple_binding_config() { + use crate::atuin_client::settings::KeyBindingConfig; + let cfg = KeyBindingConfig::Simple("accept".to_string()); + let binding = super::parse_binding_config(&cfg).unwrap(); + assert_eq!(binding.rules.len(), 1); + assert!(binding.rules[0].condition.is_none()); + assert_eq!(binding.rules[0].action, Action::Accept); + } + + #[test] + fn parse_conditional_binding_config() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + let cfg = KeyBindingConfig::Rules(vec![ + KeyRuleConfig { + when: Some("cursor-at-start".to_string()), + action: "exit".to_string(), + }, + KeyRuleConfig { + when: None, + action: "cursor-left".to_string(), + }, + ]); + let binding = super::parse_binding_config(&cfg).unwrap(); + assert_eq!(binding.rules.len(), 2); + assert!(binding.rules[0].condition.is_some()); + assert_eq!(binding.rules[0].action, Action::Exit); + assert!(binding.rules[1].condition.is_none()); + assert_eq!(binding.rules[1].action, Action::CursorLeft); + } + + #[test] + fn parse_binding_config_invalid_action() { + use crate::atuin_client::settings::KeyBindingConfig; + let cfg = KeyBindingConfig::Simple("not-a-real-action".to_string()); + assert!(super::parse_binding_config(&cfg).is_err()); + } + + #[test] + fn parse_binding_config_invalid_condition() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + let cfg = KeyBindingConfig::Rules(vec![KeyRuleConfig { + when: Some("not-a-real-condition".to_string()), + action: "exit".to_string(), + }]); + assert!(super::parse_binding_config(&cfg).is_err()); + } + + #[test] + fn config_override_replaces_key() { + use crate::atuin_client::settings::KeyBindingConfig; + use std::collections::HashMap; + + let mut settings = default_settings(); + let set = KeymapSet::defaults(&settings); + + // Default: ctrl-c → ReturnOriginal + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + set.emacs.resolve(&key("ctrl-c"), &ctx), + Some(Action::ReturnOriginal) + ); + + // Override ctrl-c → Exit via config + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + + let set = KeymapSet::from_settings(&settings); + assert_eq!(set.emacs.resolve(&key("ctrl-c"), &ctx), Some(Action::Exit)); + } + + #[test] + fn config_override_preserves_unoverridden_keys() { + use crate::atuin_client::settings::KeyBindingConfig; + use std::collections::HashMap; + + let mut settings = default_settings(); + // Override only ctrl-c; enter should keep its default + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + + let set = KeymapSet::from_settings(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // ctrl-c overridden + assert_eq!(set.emacs.resolve(&key("ctrl-c"), &ctx), Some(Action::Exit)); + // enter still has default (enter_accept=false → ReturnSelection) + assert_eq!( + set.emacs.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn config_conditional_override() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + use std::collections::HashMap; + + let mut settings = default_settings(); + // Override "up" with a custom conditional + settings.keymap.emacs = HashMap::from([( + "up".to_string(), + KeyBindingConfig::Rules(vec![ + KeyRuleConfig { + when: Some("no-results".to_string()), + action: "exit".to_string(), + }, + KeyRuleConfig { + when: None, + action: "select-previous".to_string(), + }, + ]), + )]); + + let set = KeymapSet::from_settings(&settings); + + // With no results → exit + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(set.emacs.resolve(&key("up"), &ctx), Some(Action::Exit)); + + // With results → select-previous + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + set.emacs.resolve(&key("up"), &ctx), + Some(Action::SelectPrevious) + ); + } + + #[test] + fn from_settings_with_empty_config_equals_defaults() { + let settings = default_settings(); + let defaults = KeymapSet::defaults(&settings); + let from_settings = KeymapSet::from_settings(&settings); + + // Verify a sample of keys produce the same results + let ctx = make_ctx(0, 0, 0, 10); + let test_keys = [ + "ctrl-c", "enter", "esc", "tab", "up", "down", "left", "right", + ]; + for k in &test_keys { + assert_eq!( + defaults.emacs.resolve(&key(k), &ctx), + from_settings.emacs.resolve(&key(k), &ctx), + "mismatch for emacs key {k}" + ); + } + } + + // ----------------------------------------------------------------------- + // Phase 5: [keys] vs [keymap] backward compatibility + // ----------------------------------------------------------------------- + + #[test] + fn keymap_overrides_ignore_keys_section() { + use crate::atuin_client::settings::KeyBindingConfig; + + // Set up: [keys] disables scroll_exits, but [keymap] is present + let mut settings = default_settings(); + settings.keys.scroll_exits = false; + + // Without [keymap], scroll_exits=false means no exit condition on down + let set_legacy = KeymapSet::defaults(&settings); + // At list-at-start (selected=0), down should still be SelectNext (no exit) + let ctx_at_boundary = make_ctx(0, 0, 0, 10); + assert_eq!( + set_legacy.emacs.resolve(&key("down"), &ctx_at_boundary), + Some(Action::SelectNext), + "legacy: down at boundary should be SelectNext with scroll_exits=false" + ); + + // With [keymap] present (even just one override), [keys] is ignored + // so the standard defaults (scroll_exits=true) apply + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + let set_keymap = KeymapSet::from_settings(&settings); + + // Not at boundary (selected=5): should SelectNext normally + let ctx_not_at_boundary = make_ctx(0, 0, 5, 10); + assert_eq!( + set_keymap.emacs.resolve(&key("down"), &ctx_not_at_boundary), + Some(Action::SelectNext), + "keymap: down not at boundary should SelectNext" + ); + // At list-at-start (selected=0): should Exit (standard scroll_exits=true) + assert_eq!( + set_keymap.emacs.resolve(&key("down"), &ctx_at_boundary), + Some(Action::Exit), + "keymap: down at boundary should Exit (standard defaults restored)" + ); + } + + #[test] + fn keymap_present_resets_to_standard_keys_defaults() { + use crate::atuin_client::settings::KeyBindingConfig; + + let mut settings = default_settings(); + // Disable all [keys] behaviors + settings.keys.exit_past_line_start = false; + settings.keys.accept_past_line_end = false; + + // Without [keymap], left should be plain CursorLeft + let set_legacy = KeymapSet::defaults(&settings); + let ctx_at_start = make_ctx(0, 5, 0, 10); + assert_eq!( + set_legacy.emacs.resolve(&key("left"), &ctx_at_start), + Some(Action::CursorLeft), + "legacy: left should be plain CursorLeft without exit_past_line_start" + ); + + // Add a [keymap] entry (for a different key) + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + let set_keymap = KeymapSet::from_settings(&settings); + + // Now left should use standard defaults (exit_past_line_start=true) + // At cursor start → Exit + assert_eq!( + set_keymap.emacs.resolve(&key("left"), &ctx_at_start), + Some(Action::Exit), + "keymap: left at cursor start should exit (standard defaults)" + ); + + // Right at cursor end should return selection (standard defaults: accept_past_line_end=true, enter_accept=false) + let ctx_at_end = make_ctx(5, 5, 0, 10); + assert_eq!( + set_keymap.emacs.resolve(&key("right"), &ctx_at_end), + Some(Action::ReturnSelection), + "keymap: right at cursor end should return selection (standard defaults)" + ); + } + + #[test] + fn keys_has_non_default_values_detection() { + use crate::atuin_client::settings::Keys; + + let standard = Keys::standard_defaults(); + assert!(!standard.has_non_default_values()); + + let mut modified = Keys::standard_defaults(); + modified.scroll_exits = false; + assert!(modified.has_non_default_values()); + + let mut modified = Keys::standard_defaults(); + modified.prefix = "x".to_string(); + assert!(modified.has_non_default_values()); + } + + #[test] + fn original_input_empty_condition_in_config() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + use std::collections::HashMap; + + let mut settings = default_settings(); + // Configure esc to: if original-input-empty -> return-query, else return-original + settings.keymap.emacs = HashMap::from([( + "esc".to_string(), + KeyBindingConfig::Rules(vec![ + KeyRuleConfig { + when: Some("original-input-empty".to_string()), + action: "return-query".to_string(), + }, + KeyRuleConfig { + when: None, + action: "return-original".to_string(), + }, + ]), + )]); + + let set = KeymapSet::from_settings(&settings); + + // When original input was empty, should return-query + let ctx_original_empty = EvalContext { + cursor_position: 0, + input_width: 5, + input_byte_len: 5, + selected_index: 0, + results_len: 10, + original_input_empty: true, + has_context: false, + }; + assert_eq!( + set.emacs.resolve(&key("esc"), &ctx_original_empty), + Some(Action::ReturnQuery), + "esc with original_input_empty=true should return-query" + ); + + // When original input was not empty, should return-original + let ctx_original_not_empty = EvalContext { + cursor_position: 0, + input_width: 5, + input_byte_len: 5, + selected_index: 0, + results_len: 10, + original_input_empty: false, + has_context: false, + }; + assert_eq!( + set.emacs.resolve(&key("esc"), &ctx_original_not_empty), + Some(Action::ReturnOriginal), + "esc with original_input_empty=false should return-original" + ); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/key.rs b/crates/turtle/src/command/client/search/keybindings/key.rs new file mode 100644 index 00000000..c2eb31c6 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/key.rs @@ -0,0 +1,629 @@ +use std::fmt; + +use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers, MediaKeyCode}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// A single key press with modifiers (e.g. `ctrl-c`, `alt-f`, `enter`). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[expect(clippy::struct_excessive_bools)] +pub struct SingleKey { + pub code: KeyCodeValue, + pub ctrl: bool, + pub alt: bool, + pub shift: bool, + pub super_key: bool, +} + +/// The key code portion of a key press. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum KeyCodeValue { + Char(char), + Enter, + Esc, + Tab, + Backspace, + Delete, + Insert, + Up, + Down, + Left, + Right, + Home, + End, + PageUp, + PageDown, + Space, + F(u8), + Media(MediaKeyCode), +} + +/// A key input that may be a single key or a multi-key sequence (e.g. `g g`). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum KeyInput { + Single(SingleKey), + Sequence(Vec<SingleKey>), +} + +impl SingleKey { + /// Convert a crossterm `KeyEvent` into a `SingleKey`. + pub fn from_event(event: &KeyEvent) -> Option<Self> { + let ctrl = event.modifiers.contains(KeyModifiers::CONTROL); + let alt = event.modifiers.contains(KeyModifiers::ALT); + let shift = event.modifiers.contains(KeyModifiers::SHIFT); + let super_key = event.modifiers.contains(KeyModifiers::SUPER); + + let code = match event.code { + KeyCode::Char(' ') => KeyCodeValue::Space, + KeyCode::Char(c) => { + // If shift is the only modifier and it's an uppercase letter, + // we store the uppercase char directly and clear the shift flag + // since the case already encodes it. + if shift && !ctrl && !alt && !super_key && c.is_ascii_uppercase() { + return Some(SingleKey { + code: KeyCodeValue::Char(c), + ctrl: false, + alt: false, + shift: false, + super_key: false, + }); + } + KeyCodeValue::Char(c) + } + KeyCode::Enter => KeyCodeValue::Enter, + KeyCode::Esc => KeyCodeValue::Esc, + KeyCode::Tab => KeyCodeValue::Tab, + // BackTab is sent by many terminals for Shift+Tab + KeyCode::BackTab => { + return Some(SingleKey { + code: KeyCodeValue::Tab, + ctrl, + alt, + shift: true, + super_key, + }); + } + KeyCode::Backspace => KeyCodeValue::Backspace, + KeyCode::Delete => KeyCodeValue::Delete, + KeyCode::Insert => KeyCodeValue::Insert, + KeyCode::Up => KeyCodeValue::Up, + KeyCode::Down => KeyCodeValue::Down, + KeyCode::Left => KeyCodeValue::Left, + KeyCode::Right => KeyCodeValue::Right, + KeyCode::Home => KeyCodeValue::Home, + KeyCode::End => KeyCodeValue::End, + KeyCode::PageUp => KeyCodeValue::PageUp, + KeyCode::PageDown => KeyCodeValue::PageDown, + KeyCode::F(n) => KeyCodeValue::F(n), + KeyCode::Media(m) => KeyCodeValue::Media(m), + _ => return None, + }; + + Some(SingleKey { + code, + ctrl, + alt, + shift: if matches!(code, KeyCodeValue::Char(_)) { + false + } else { + shift + }, + super_key, + }) + } + + /// Parse a key string like `"ctrl-c"`, `"alt-f"`, `"enter"`, `"G"`. + pub fn parse(s: &str) -> Result<Self, String> { + let s = s.trim(); + let parts: Vec<&str> = s.split('-').collect(); + + let mut ctrl = false; + let mut alt = false; + let mut shift = false; + let mut super_key = false; + + // All parts except the last are modifiers + for &part in &parts[..parts.len() - 1] { + match part.to_lowercase().as_str() { + "ctrl" => ctrl = true, + "alt" => alt = true, + "shift" => shift = true, + "super" | "cmd" | "win" => super_key = true, + _ => return Err(format!("unknown modifier: {part}")), + } + } + + let key_part = parts[parts.len() - 1]; + let code = match key_part.to_lowercase().as_str() { + "enter" | "return" => KeyCodeValue::Enter, + "esc" | "escape" => KeyCodeValue::Esc, + "tab" => KeyCodeValue::Tab, + "backspace" => KeyCodeValue::Backspace, + "delete" | "del" => KeyCodeValue::Delete, + "insert" | "ins" => KeyCodeValue::Insert, + "up" => KeyCodeValue::Up, + "down" => KeyCodeValue::Down, + "left" => KeyCodeValue::Left, + "right" => KeyCodeValue::Right, + "home" => KeyCodeValue::Home, + "end" => KeyCodeValue::End, + "pageup" => KeyCodeValue::PageUp, + "pagedown" => KeyCodeValue::PageDown, + "space" => KeyCodeValue::Space, + s if s.starts_with('f') && s.len() > 1 => { + // Parse function keys like "f1", "f12" + if let Ok(n) = s[1..].parse::<u8>() { + if (1..=24).contains(&n) { + KeyCodeValue::F(n) + } else { + return Err(format!("function key out of range: {key_part}")); + } + } else { + return Err(format!("unknown key: {key_part}")); + } + } + "[" => KeyCodeValue::Char('['), + "]" => KeyCodeValue::Char(']'), + "?" => KeyCodeValue::Char('?'), + "/" => KeyCodeValue::Char('/'), + "$" => KeyCodeValue::Char('$'), + // Media keys (no dashes - the parser splits on dash for modifiers) + "play" => KeyCodeValue::Media(MediaKeyCode::Play), + "pause" => KeyCodeValue::Media(MediaKeyCode::Pause), + "playpause" => KeyCodeValue::Media(MediaKeyCode::PlayPause), + "stop" => KeyCodeValue::Media(MediaKeyCode::Stop), + "fastforward" => KeyCodeValue::Media(MediaKeyCode::FastForward), + "rewind" => KeyCodeValue::Media(MediaKeyCode::Rewind), + "tracknext" => KeyCodeValue::Media(MediaKeyCode::TrackNext), + "trackprevious" => KeyCodeValue::Media(MediaKeyCode::TrackPrevious), + "record" => KeyCodeValue::Media(MediaKeyCode::Record), + "lowervolume" => KeyCodeValue::Media(MediaKeyCode::LowerVolume), + "raisevolume" => KeyCodeValue::Media(MediaKeyCode::RaiseVolume), + "mutevolume" | "mute" => KeyCodeValue::Media(MediaKeyCode::MuteVolume), + _ => { + let chars: Vec<char> = key_part.chars().collect(); + if chars.len() == 1 { + let c = chars[0]; + // An uppercase letter implies shift (unless shift already specified) + if c.is_ascii_uppercase() && !ctrl && !alt && !super_key { + return Ok(SingleKey { + code: KeyCodeValue::Char(c), + ctrl: false, + alt: false, + shift: false, + super_key: false, + }); + } + KeyCodeValue::Char(c) + } else { + return Err(format!("unknown key: {key_part}")); + } + } + }; + + Ok(SingleKey { + code, + ctrl, + alt, + shift, + super_key, + }) + } +} + +impl fmt::Display for SingleKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.super_key { + write!(f, "super-")?; + } + if self.ctrl { + write!(f, "ctrl-")?; + } + if self.alt { + write!(f, "alt-")?; + } + if self.shift { + write!(f, "shift-")?; + } + match &self.code { + KeyCodeValue::Char(c) => write!(f, "{c}"), + KeyCodeValue::Enter => write!(f, "enter"), + KeyCodeValue::Esc => write!(f, "esc"), + KeyCodeValue::Tab => write!(f, "tab"), + KeyCodeValue::Backspace => write!(f, "backspace"), + KeyCodeValue::Delete => write!(f, "delete"), + KeyCodeValue::Insert => write!(f, "insert"), + KeyCodeValue::Up => write!(f, "up"), + KeyCodeValue::Down => write!(f, "down"), + KeyCodeValue::Left => write!(f, "left"), + KeyCodeValue::Right => write!(f, "right"), + KeyCodeValue::Home => write!(f, "home"), + KeyCodeValue::End => write!(f, "end"), + KeyCodeValue::PageUp => write!(f, "pageup"), + KeyCodeValue::PageDown => write!(f, "pagedown"), + KeyCodeValue::Space => write!(f, "space"), + KeyCodeValue::F(n) => write!(f, "f{n}"), + KeyCodeValue::Media(m) => match m { + MediaKeyCode::Play => write!(f, "play"), + MediaKeyCode::Pause => write!(f, "media-pause"), + MediaKeyCode::PlayPause => write!(f, "playpause"), + MediaKeyCode::Stop => write!(f, "stop"), + MediaKeyCode::FastForward => write!(f, "fastforward"), + MediaKeyCode::Rewind => write!(f, "rewind"), + MediaKeyCode::TrackNext => write!(f, "tracknext"), + MediaKeyCode::TrackPrevious => write!(f, "trackprevious"), + MediaKeyCode::Record => write!(f, "record"), + MediaKeyCode::LowerVolume => write!(f, "lowervolume"), + MediaKeyCode::RaiseVolume => write!(f, "raisevolume"), + MediaKeyCode::MuteVolume => write!(f, "mutevolume"), + MediaKeyCode::Reverse => write!(f, "reverse"), + }, + } + } +} + +impl KeyInput { + /// Parse a key input string. Supports multi-key sequences separated by spaces + /// (e.g. `"g g"`). + pub fn parse(s: &str) -> Result<Self, String> { + let s = s.trim(); + // Check for space-separated multi-key sequences + // But don't split "space" or modifier combos like "ctrl-a" + let parts: Vec<&str> = s.split_whitespace().collect(); + if parts.len() > 1 { + let keys: Result<Vec<SingleKey>, String> = + parts.iter().map(|p| SingleKey::parse(p)).collect(); + Ok(KeyInput::Sequence(keys?)) + } else { + Ok(KeyInput::Single(SingleKey::parse(s)?)) + } + } +} + +impl fmt::Display for KeyInput { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + KeyInput::Single(k) => write!(f, "{k}"), + KeyInput::Sequence(keys) => { + for (i, k) in keys.iter().enumerate() { + if i > 0 { + write!(f, " ")?; + } + write!(f, "{k}")?; + } + Ok(()) + } + } + } +} + +impl Serialize for KeyInput { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for KeyInput { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { + let s = String::deserialize(deserializer)?; + KeyInput::parse(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + #[test] + fn parse_simple_keys() { + let k = SingleKey::parse("a").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('a')); + assert!(!k.ctrl && !k.alt && !k.shift); + + let k = SingleKey::parse("enter").unwrap(); + assert_eq!(k.code, KeyCodeValue::Enter); + + let k = SingleKey::parse("esc").unwrap(); + assert_eq!(k.code, KeyCodeValue::Esc); + + let k = SingleKey::parse("tab").unwrap(); + assert_eq!(k.code, KeyCodeValue::Tab); + + let k = SingleKey::parse("space").unwrap(); + assert_eq!(k.code, KeyCodeValue::Space); + } + + #[test] + fn parse_modifiers() { + let k = SingleKey::parse("ctrl-c").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('c')); + assert!(k.ctrl); + assert!(!k.alt); + + let k = SingleKey::parse("alt-f").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('f')); + assert!(k.alt); + assert!(!k.ctrl); + + let k = SingleKey::parse("ctrl-alt-x").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('x')); + assert!(k.ctrl && k.alt); + } + + #[test] + fn parse_uppercase_implies_no_shift_flag() { + let k = SingleKey::parse("G").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('G')); + assert!(!k.shift); + assert!(!k.ctrl); + } + + #[test] + fn parse_special_chars() { + let k = SingleKey::parse("ctrl-[").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('[')); + assert!(k.ctrl); + + let k = SingleKey::parse("?").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('?')); + + let k = SingleKey::parse("/").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('/')); + } + + #[test] + fn parse_multi_key_sequence() { + let ki = KeyInput::parse("g g").unwrap(); + match ki { + KeyInput::Sequence(keys) => { + assert_eq!(keys.len(), 2); + assert_eq!(keys[0].code, KeyCodeValue::Char('g')); + assert_eq!(keys[1].code, KeyCodeValue::Char('g')); + } + _ => panic!("expected sequence"), + } + } + + #[test] + fn display_round_trip() { + let cases = ["ctrl-c", "alt-f", "enter", "G", "tab", "pageup"]; + for s in cases { + let k = KeyInput::parse(s).unwrap(); + let display = k.to_string(); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2, "round-trip failed for {s}"); + } + + let ki = KeyInput::parse("g g").unwrap(); + assert_eq!(ki.to_string(), "g g"); + } + + #[test] + fn from_event_basic() { + let event = KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('c')); + assert!(k.ctrl); + assert!(!k.alt); + + let event = KeyEvent::new(KeyCode::Enter, KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Enter); + } + + #[test] + fn from_event_uppercase() { + // Crossterm sends uppercase chars with SHIFT modifier + let event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::SHIFT); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('G')); + // shift flag should be cleared since the case encodes it + assert!(!k.shift); + } + + #[test] + fn from_event_matches_parsed() { + // Verify that from_event and parse produce the same SingleKey + let event = KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("ctrl-c").unwrap(); + assert_eq!(from_event, parsed); + + let event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::SHIFT); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("G").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn parse_super_modifier() { + let k = SingleKey::parse("super-a").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('a')); + assert!(k.super_key); + assert!(!k.ctrl && !k.alt && !k.shift); + + // "cmd" is an alias for "super" + let k2 = SingleKey::parse("cmd-a").unwrap(); + assert_eq!(k, k2); + + // "win" is an alias for "super" + let k3 = SingleKey::parse("win-a").unwrap(); + assert_eq!(k, k3); + } + + #[test] + fn parse_super_with_other_modifiers() { + let k = SingleKey::parse("super-ctrl-c").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('c')); + assert!(k.super_key && k.ctrl); + assert!(!k.alt && !k.shift); + } + + #[test] + fn display_super_modifier() { + let k = SingleKey::parse("super-a").unwrap(); + assert_eq!(k.to_string(), "super-a"); + + let k = SingleKey::parse("super-ctrl-x").unwrap(); + assert_eq!(k.to_string(), "super-ctrl-x"); + } + + #[test] + fn display_round_trip_super() { + let k = KeyInput::parse("super-a").unwrap(); + let display = k.to_string(); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2, "round-trip failed for super-a"); + } + + #[test] + fn from_event_super() { + let event = KeyEvent::new(KeyCode::Char('a'), KeyModifiers::SUPER); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('a')); + assert!(k.super_key); + assert!(!k.ctrl && !k.alt && !k.shift); + } + + #[test] + fn from_event_super_matches_parsed() { + let event = KeyEvent::new(KeyCode::Char('a'), KeyModifiers::SUPER); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("super-a").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn super_uppercase_preserves_super() { + // super-G should keep the super flag (unlike bare "G" which clears shift) + let k = SingleKey::parse("super-G").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('G')); + assert!(k.super_key); + } + + #[test] + fn parse_errors() { + assert!(SingleKey::parse("ctrl-alt-shift-xxx").is_err()); + assert!(SingleKey::parse("foobar-a").is_err()); + } + + #[test] + fn parse_function_keys() { + let k = SingleKey::parse("f1").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(1)); + assert!(!k.ctrl && !k.alt && !k.shift); + + let k = SingleKey::parse("F12").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(12)); + + let k = SingleKey::parse("ctrl-f5").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(5)); + assert!(k.ctrl); + + // F24 is valid (some keyboards have extended function keys) + let k = SingleKey::parse("f24").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(24)); + + // F0 and F25+ are invalid + assert!(SingleKey::parse("f0").is_err()); + assert!(SingleKey::parse("f25").is_err()); + } + + #[test] + fn from_event_function_keys() { + let event = KeyEvent::new(KeyCode::F(1), KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::F(1)); + + let event = KeyEvent::new(KeyCode::F(12), KeyModifiers::CONTROL); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::F(12)); + assert!(k.ctrl); + } + + #[test] + fn display_function_keys() { + let k = SingleKey::parse("f1").unwrap(); + assert_eq!(k.to_string(), "f1"); + + let k = SingleKey::parse("ctrl-f12").unwrap(); + assert_eq!(k.to_string(), "ctrl-f12"); + } + + #[test] + fn function_key_round_trip() { + let cases = ["f1", "f12", "ctrl-f5", "alt-f10"]; + for s in cases { + let k = KeyInput::parse(s).unwrap(); + let display = k.to_string(); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2, "round-trip failed for {s}"); + } + } + + #[test] + fn from_event_function_key_matches_parsed() { + let event = KeyEvent::new(KeyCode::F(12), KeyModifiers::NONE); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("f12").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn from_event_backtab_becomes_shift_tab() { + // Many terminals send BackTab for Shift+Tab + let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Tab); + assert!(k.shift); + assert!(!k.ctrl && !k.alt); + } + + #[test] + fn from_event_backtab_matches_parsed_shift_tab() { + let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::NONE); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("shift-tab").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn from_event_backtab_with_ctrl() { + // BackTab with ctrl modifier + let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::CONTROL); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Tab); + assert!(k.shift); + assert!(k.ctrl); + } + + #[test] + fn parse_insert_key() { + let k = SingleKey::parse("insert").unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + assert!(!k.ctrl && !k.alt && !k.shift); + + let k = SingleKey::parse("ins").unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + + let k = SingleKey::parse("ctrl-insert").unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + assert!(k.ctrl); + } + + #[test] + fn from_event_insert_key() { + let event = KeyEvent::new(KeyCode::Insert, KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + } + + #[test] + fn insert_key_round_trip() { + let k = KeyInput::parse("insert").unwrap(); + let display = k.to_string(); + assert_eq!(display, "insert"); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/keymap.rs b/crates/turtle/src/command/client/search/keybindings/keymap.rs new file mode 100644 index 00000000..0d362863 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/keymap.rs @@ -0,0 +1,233 @@ +use std::collections::HashMap; + +use super::actions::Action; +use super::conditions::{ConditionExpr, EvalContext}; +use super::key::{KeyInput, SingleKey}; + +/// A single rule within a keybinding: an optional condition and an action. +/// If the condition is `None`, the rule always matches. +#[derive(Debug, Clone)] +pub struct KeyRule { + pub condition: Option<ConditionExpr>, + pub action: Action, +} + +/// A keybinding is an ordered list of rules. The first rule whose condition +/// matches (or has no condition) wins. +#[derive(Debug, Clone)] +pub struct KeyBinding { + pub rules: Vec<KeyRule>, +} + +/// A keymap is a collection of keybindings indexed by key input. +#[derive(Debug, Clone)] +pub struct Keymap { + pub bindings: HashMap<KeyInput, KeyBinding>, +} + +impl KeyRule { + /// Create an unconditional rule. + pub fn always(action: Action) -> Self { + KeyRule { + condition: None, + action, + } + } + + /// Create a conditional rule. Accepts any type convertible to `ConditionExpr`, + /// including bare `ConditionAtom` values. + pub fn when(condition: impl Into<ConditionExpr>, action: Action) -> Self { + KeyRule { + condition: Some(condition.into()), + action, + } + } +} + +impl KeyBinding { + /// Create a simple (unconditional) binding. + pub fn simple(action: Action) -> Self { + KeyBinding { + rules: vec![KeyRule::always(action)], + } + } + + /// Create a conditional binding from a list of rules. + pub fn conditional(rules: Vec<KeyRule>) -> Self { + KeyBinding { rules } + } +} + +impl Keymap { + /// Create an empty keymap. + pub fn new() -> Self { + Keymap { + bindings: HashMap::new(), + } + } + + /// Bind a key input to a simple (unconditional) action. + pub fn bind(&mut self, key: KeyInput, action: Action) { + self.bindings.insert(key, KeyBinding::simple(action)); + } + + /// Bind a key input to a conditional set of rules. + pub fn bind_conditional(&mut self, key: KeyInput, rules: Vec<KeyRule>) { + self.bindings.insert(key, KeyBinding::conditional(rules)); + } + + /// Resolve a key input to an action given the current evaluation context. + /// Returns `None` if the key has no binding or no rule's condition matches. + pub fn resolve(&self, key: &KeyInput, ctx: &EvalContext) -> Option<Action> { + let binding = self.bindings.get(key)?; + for rule in &binding.rules { + match &rule.condition { + None => return Some(rule.action.clone()), + Some(cond) if cond.evaluate(ctx) => return Some(rule.action.clone()), + Some(_) => {} + } + } + None + } + + /// Check if any binding starts with the given single key as the first key + /// of a multi-key sequence. Used to detect pending multi-key sequences. + pub fn has_sequence_starting_with(&self, prefix: &SingleKey) -> bool { + self.bindings.keys().any(|ki| match ki { + KeyInput::Sequence(keys) => keys.first() == Some(prefix), + KeyInput::Single(_) => false, + }) + } + + /// Merge another keymap into this one. Keys from `other` override keys in `self`. + #[expect(dead_code)] + pub fn merge(&mut self, other: &Keymap) { + for (key, binding) in &other.bindings { + self.bindings.insert(key.clone(), binding.clone()); + } + } +} + +impl Default for Keymap { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::super::conditions::ConditionAtom; + use super::*; + + fn make_ctx(cursor: usize, width: usize, selected: usize, len: usize) -> EvalContext { + EvalContext { + cursor_position: cursor, + input_width: width, + input_byte_len: width, + selected_index: selected, + results_len: len, + original_input_empty: false, + has_context: false, + } + } + + #[test] + fn simple_binding_resolves() { + let mut keymap = Keymap::new(); + let key = KeyInput::parse("ctrl-c").unwrap(); + keymap.bind(key.clone(), Action::ReturnOriginal); + + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), Some(Action::ReturnOriginal)); + } + + #[test] + fn conditional_first_match_wins() { + let mut keymap = Keymap::new(); + let key = KeyInput::parse("left").unwrap(); + keymap.bind_conditional( + key.clone(), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit), + KeyRule::always(Action::CursorLeft), + ], + ); + + // Cursor at start → Exit + let ctx = make_ctx(0, 5, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), Some(Action::Exit)); + + // Cursor not at start → CursorLeft + let ctx = make_ctx(3, 5, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), Some(Action::CursorLeft)); + } + + #[test] + fn no_match_returns_none() { + let keymap = Keymap::new(); + let key = KeyInput::parse("ctrl-c").unwrap(); + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(keymap.resolve(&key, &ctx), None); + } + + #[test] + fn conditional_no_condition_matches_returns_none() { + let mut keymap = Keymap::new(); + let key = KeyInput::parse("left").unwrap(); + // Only one rule with a condition that won't match + keymap.bind_conditional( + key.clone(), + vec![KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit)], + ); + + // Cursor not at start → no match + let ctx = make_ctx(3, 5, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), None); + } + + #[test] + fn has_sequence_starting_with() { + let mut keymap = Keymap::new(); + let seq = KeyInput::parse("g g").unwrap(); + keymap.bind(seq, Action::ScrollToTop); + + let g = SingleKey::parse("g").unwrap(); + assert!(keymap.has_sequence_starting_with(&g)); + + let h = SingleKey::parse("h").unwrap(); + assert!(!keymap.has_sequence_starting_with(&h)); + } + + #[test] + fn merge_overrides() { + let mut base = Keymap::new(); + let key = KeyInput::parse("ctrl-c").unwrap(); + base.bind(key.clone(), Action::ReturnOriginal); + + let mut overlay = Keymap::new(); + overlay.bind(key.clone(), Action::Exit); + + base.merge(&overlay); + + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(base.resolve(&key, &ctx), Some(Action::Exit)); + } + + #[test] + fn merge_preserves_unoverridden() { + let mut base = Keymap::new(); + let key1 = KeyInput::parse("ctrl-c").unwrap(); + let key2 = KeyInput::parse("ctrl-d").unwrap(); + base.bind(key1.clone(), Action::ReturnOriginal); + base.bind(key2.clone(), Action::DeleteCharAfter); + + let mut overlay = Keymap::new(); + overlay.bind(key1.clone(), Action::Exit); + + base.merge(&overlay); + + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(base.resolve(&key1, &ctx), Some(Action::Exit)); + assert_eq!(base.resolve(&key2, &ctx), Some(Action::DeleteCharAfter)); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/mod.rs b/crates/turtle/src/command/client/search/keybindings/mod.rs new file mode 100644 index 00000000..3b6eb2b2 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/mod.rs @@ -0,0 +1,14 @@ +pub mod actions; +pub mod conditions; +pub mod defaults; +pub mod key; +pub mod keymap; + +pub use actions::Action; +#[expect(unused_imports)] +pub use conditions::{ConditionAtom, ConditionExpr, EvalContext}; +pub use defaults::KeymapSet; +#[expect(unused_imports)] +pub use key::{KeyCodeValue, KeyInput, SingleKey}; +#[expect(unused_imports)] +pub use keymap::{KeyBinding, KeyRule, Keymap}; diff --git a/crates/turtle/src/command/client/server.rs b/crates/turtle/src/command/client/server.rs new file mode 100644 index 00000000..7de27551 --- /dev/null +++ b/crates/turtle/src/command/client/server.rs @@ -0,0 +1,61 @@ +use std::net::SocketAddr; + +use crate::atuin_server::{Settings, launch, launch_metrics_server}; +use crate::atuin_server_database::DbType; +use crate::atuin_server_postgres::Postgres; +use crate::atuin_server_sqlite::Sqlite; + +use clap::Subcommand; +use eyre::{Context, Result, eyre}; + +#[derive(Subcommand, Clone, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Start the server + Start { + /// The host address to bind + #[clap(long)] + host: Option<String>, + + /// The port to bind + #[clap(long, short)] + port: Option<u16>, + }, + + /// Print server example configuration + DefaultConfig, +} + +impl Cmd { + #[expect(clippy::too_many_lines)] + pub async fn run(self) -> Result<()> { + match self { + Cmd::Start { host, port } => { + let settings = Settings::new().wrap_err("could not load server settings")?; + let host = host.as_ref().unwrap_or(&settings.host).clone(); + let port = port.unwrap_or(settings.port); + let addr = SocketAddr::new(host.parse()?, port); + + if settings.metrics.enable { + tokio::spawn(launch_metrics_server( + settings.metrics.host.clone(), + settings.metrics.port, + )); + } + + match settings.db_settings.db_type() { + DbType::Postgres => launch::<Postgres>(settings, addr).await, + DbType::Sqlite => launch::<Sqlite>(settings, addr).await, + DbType::Unknown => { + Err(eyre!("db_uri must start with postgres:// or sqlite://")) + } + } + } + Cmd::DefaultConfig => { + // TODO(@bpeetz): Add this back <2026-06-11> + println!("TODO"); + Ok(()) + } + } + } +} diff --git a/crates/turtle/src/command/client/setup.rs b/crates/turtle/src/command/client/setup.rs new file mode 100644 index 00000000..b32ceb97 --- /dev/null +++ b/crates/turtle/src/command/client/setup.rs @@ -0,0 +1,81 @@ +use crate::atuin_client::settings::Settings; + +use colored::Colorize; +use eyre::Result; +use std::io::{self, Write}; +use toml_edit::{DocumentMut, value}; + +pub async fn run(_settings: &Settings) -> Result<()> { + let enable_ai = prompt( + "Atuin AI", + "This will enable command generation and other AI features via the question mark key", + Some( + "By default, Atuin AI only has access to the name and version of your operating system and shell - your shell history is not sent to the AI.", + ), + )?; + + let enable_daemon = prompt( + "Atuin Daemon", + "This will enable improved search and history sync using a persistent background process", + None, + )?; + + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let mut doc = config_str.parse::<DocumentMut>()?; + + let mut changed = false; + if enable_ai { + changed = true; + if !doc.contains_key("ai") { + doc["ai"] = toml_edit::table(); + } + doc["ai"]["enabled"] = value(true); + } + + if enable_daemon { + changed = true; + if !doc.contains_key("daemon") { + doc["daemon"] = toml_edit::table(); + } + doc["daemon"]["enabled"] = value(true); + doc["daemon"]["autostart"] = value(true); + doc["search_mode"] = value("daemon-fuzzy"); + } + + if changed { + tokio::fs::write(config_file, doc.to_string()).await?; + + println!( + "{check} Settings updated successfully", + check = "✓".bold().bright_green() + ); + } else { + println!( + "{check} No settings changed", + check = "✓".bold().bright_green() + ); + } + + Ok(()) +} + +pub fn prompt(feature: &str, description: &str, note: Option<&str>) -> Result<bool> { + println!( + "> Enable {feature}?", + feature = feature.bold().bright_blue() + ); + if let Some(note) = note { + println!(" {description}"); + print!(" {note} {q} ", q = "[Y/n]".bold()); + } else { + print!(" {description} {q} ", q = "[Y/n]".bold()); + } + + io::stdout().flush().ok(); + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let answer = input.trim().to_lowercase(); + Ok(answer.is_empty() || answer == "y" || answer == "yes") +} diff --git a/crates/turtle/src/command/client/stats.rs b/crates/turtle/src/command/client/stats.rs new file mode 100644 index 00000000..fc10e949 --- /dev/null +++ b/crates/turtle/src/command/client/stats.rs @@ -0,0 +1,85 @@ +use clap::Parser; +use eyre::Result; +use interim::parse_date_string; +use time::{Duration, OffsetDateTime, Time}; + +use crate::atuin_client::{ + database::{Database, current_context}, + settings::Settings, + theme::Theme, +}; + +use crate::atuin_history::stats::{compute, pretty_print}; + +fn parse_ngram_size(s: &str) -> Result<usize, String> { + let value = s + .parse::<usize>() + .map_err(|_| format!("'{s}' is not a valid window size"))?; + + if value == 0 { + return Err("ngram window size must be at least 1".to_string()); + } + + Ok(value) +} + +#[derive(Parser, Debug)] +#[command(infer_subcommands = true)] +pub struct Cmd { + /// Compute statistics for the specified period, leave blank for statistics since the beginning. See [this](https://docs.atuin.sh/reference/stats/) for more details. + period: Vec<String>, + + /// How many top commands to list + #[arg(long, short, default_value = "10")] + count: usize, + + /// The number of consecutive commands to consider + #[arg(long, short, default_value = "1", value_parser = parse_ngram_size)] + ngram_size: usize, +} + +impl Cmd { + pub async fn run(&self, db: &impl Database, settings: &Settings, theme: &Theme) -> Result<()> { + let context = current_context().await?; + let words = if self.period.is_empty() { + String::from("all") + } else { + self.period.join(" ") + }; + + let now = OffsetDateTime::now_utc().to_offset(settings.timezone.0); + let last_night = now.replace_time(Time::MIDNIGHT); + + let history = if words.as_str() == "all" { + db.list(&[], &context, None, false, false).await? + } else if words.trim() == "today" { + let start = last_night; + let end = start + Duration::days(1); + db.range(start, end).await? + } else if words.trim() == "month" { + let end = last_night; + let start = end - Duration::days(31); + db.range(start, end).await? + } else if words.trim() == "week" { + let end = last_night; + let start = end - Duration::days(7); + db.range(start, end).await? + } else if words.trim() == "year" { + let end = last_night; + let start = end - Duration::days(365); + db.range(start, end).await? + } else { + let start = parse_date_string(&words, now, settings.dialect.into())?; + let end = start + Duration::days(1); + db.range(start, end).await? + }; + + let stats = compute(settings, &history, self.count, self.ngram_size); + + if let Some(stats) = stats { + pretty_print(stats, self.ngram_size, theme); + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store.rs b/crates/turtle/src/command/client/store.rs new file mode 100644 index 00000000..dfa3b66c --- /dev/null +++ b/crates/turtle/src/command/client/store.rs @@ -0,0 +1,120 @@ +use clap::Subcommand; +use eyre::Result; + +use crate::atuin_client::{ + database::Database, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; +use itertools::Itertools; +use time::{OffsetDateTime, UtcOffset}; + +#[cfg(feature = "sync")] +mod push; + +#[cfg(feature = "sync")] +mod pull; + +mod purge; +mod rebuild; +mod rekey; +mod verify; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Print the current status of the record store + Status, + + /// Rebuild a store (eg atuin store rebuild history) + Rebuild(rebuild::Rebuild), + + /// Re-encrypt the store with a new key (potential for data loss!) + Rekey(rekey::Rekey), + + /// Delete all records in the store that cannot be decrypted with the current key + Purge(purge::Purge), + + /// Verify that all records in the store can be decrypted with the current key + Verify(verify::Verify), + + /// Push all records to the remote sync server (one way sync) + #[cfg(feature = "sync")] + Push(push::Push), + + /// Pull records from the remote sync server (one way sync) + #[cfg(feature = "sync")] + Pull(pull::Pull), +} + +impl Cmd { + pub async fn run( + &self, + settings: &Settings, + database: &dyn Database, + store: SqliteStore, + ) -> Result<()> { + match self { + Self::Status => self.status(store).await, + Self::Rebuild(rebuild) => rebuild.run(settings, store, database).await, + Self::Rekey(rekey) => rekey.run(settings, store).await, + Self::Verify(verify) => verify.run(settings, store).await, + Self::Purge(purge) => purge.run(settings, store).await, + + #[cfg(feature = "sync")] + Self::Push(push) => push.run(settings, store).await, + + #[cfg(feature = "sync")] + Self::Pull(pull) => pull.run(settings, store, database).await, + } + } + + pub async fn status(&self, store: SqliteStore) -> Result<()> { + let host_id = Settings::host_id().await?; + let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + + let status = store.status().await?; + + // TODO: should probs build some data structure and then pretty-print it or smth + for (host, st) in status.hosts.iter().sorted_by_key(|(h, _)| *h) { + let host_string = if host == &host_id { + format!("host: {} <- CURRENT HOST", host.0.as_hyphenated()) + } else { + format!("host: {}", host.0.as_hyphenated()) + }; + + println!("{host_string}"); + + for (tag, idx) in st.iter().sorted_by_key(|(tag, _)| *tag) { + println!("\tstore: {tag}"); + + let first = store.first(*host, tag).await?; + let last = store.last(*host, tag).await?; + + println!("\t\tidx: {idx}"); + + if let Some(first) = first { + println!("\t\tfirst: {}", first.id.0.as_hyphenated()); + + let time = + OffsetDateTime::from_unix_timestamp_nanos(i128::from(first.timestamp))? + .to_offset(offset); + println!("\t\t\tcreated: {time}"); + } + + if let Some(last) = last { + println!("\t\tlast: {}", last.id.0.as_hyphenated()); + + let time = + OffsetDateTime::from_unix_timestamp_nanos(i128::from(last.timestamp))? + .to_offset(offset); + println!("\t\t\tcreated: {time}"); + } + } + + println!(); + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/pull.rs b/crates/turtle/src/command/client/store/pull.rs new file mode 100644 index 00000000..c9c9c379 --- /dev/null +++ b/crates/turtle/src/command/client/store/pull.rs @@ -0,0 +1,94 @@ +use clap::Args; +use eyre::Result; + +use crate::atuin_client::{ + database::Database, + encryption::load_key, + record::store::Store, + record::sync::Operation, + record::{sqlite_store::SqliteStore, sync}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Pull { + /// The tag to push (eg, 'history'). Defaults to all tags + #[arg(long, short)] + pub tag: Option<String>, + + /// Force push records + /// This will first wipe the local store, and then download all records from the remote + #[arg(long, default_value = "false")] + pub force: bool, + + /// Page Size + /// How many records to download at once. Defaults to 100 + #[arg(long, default_value = "100")] + pub page: u64, +} + +impl Pull { + pub async fn run( + &self, + settings: &Settings, + store: SqliteStore, + db: &dyn Database, + ) -> Result<()> { + if self.force { + println!("Forcing local overwrite!"); + println!("Clearing local store"); + + store.delete_all().await?; + } + + // We can actually just use the existing diff/etc to push + // 1. Diff + // 2. Get operations + // 3. Filter operations by + // a) are they a download op? + // b) are they for the host/tag we are pushing here? + let client = sync::build_client(settings).await?; + let (diff, remote_index) = sync::diff(&client, &store).await?; + + // Skip on --force: local was already wiped above, mismatch is the user's call. + if !self.force { + let key: [u8; 32] = load_key(settings)?.into(); + sync::check_encryption_key(&client, &remote_index, &key) + .await + .map_err(crate::print_error::format_sync_error)?; + } + + let operations = sync::operations(diff, &store).await?; + + let operations = operations + .into_iter() + .filter(|op| match op { + // No noops or downloads thx + Operation::Noop { .. } | Operation::Upload { .. } => false, + + // pull, so yes plz to downloads! + Operation::Download { tag, .. } => { + if self.force { + return true; + } + + if let Some(t) = self.tag.clone() + && t != *tag + { + return false; + } + + true + } + }) + .collect(); + + let (_, downloaded) = sync::sync_remote(&client, operations, &store, self.page).await?; + + println!("Downloaded {} records", downloaded.len()); + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/purge.rs b/crates/turtle/src/command/client/store/purge.rs new file mode 100644 index 00000000..f7996c4b --- /dev/null +++ b/crates/turtle/src/command/client/store/purge.rs @@ -0,0 +1,26 @@ +use clap::Args; +use eyre::Result; + +use crate::atuin_client::{ + encryption::load_key, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Purge {} + +impl Purge { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + println!("Purging local records that cannot be decrypted"); + + let key = load_key(settings)?; + + match store.purge(&key.into()).await { + Ok(()) => println!("Local store purge completed OK"), + Err(e) => println!("Failed to purge local store: {e:?}"), + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/push.rs b/crates/turtle/src/command/client/store/push.rs new file mode 100644 index 00000000..724dfbef --- /dev/null +++ b/crates/turtle/src/command/client/store/push.rs @@ -0,0 +1,112 @@ +use crate::atuin_common::record::HostId; +use clap::Args; +use eyre::Result; +use uuid::Uuid; + +use crate::atuin_client::{ + api_client::Client, + encryption::load_key, + record::sync::Operation, + record::{sqlite_store::SqliteStore, sync}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Push { + /// The tag to push (eg, 'history'). Defaults to all tags + #[arg(long, short)] + pub tag: Option<String>, + + /// The host to push, in the form of a UUID host ID. Defaults to the current host. + #[arg(long)] + pub host: Option<Uuid>, + + /// Force push records + /// This will override both host and tag, to be all hosts and all tags. First clear the remote store, then upload all of the + /// local store + #[arg(long, default_value = "false")] + pub force: bool, + + /// Page Size + /// How many records to upload at once. Defaults to 100 + #[arg(long, default_value = "100")] + pub page: u64, +} + +impl Push { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + let host_id = Settings::host_id().await?; + + if self.force { + println!("Forcing remote store overwrite!"); + println!("Clearing remote store"); + + let client = Client::new( + &settings.sync_address, + settings.sync_auth_token().await?, + settings.network_connect_timeout, + settings.network_timeout * 10, // we may be deleting a lot of data... so up the + // timeout + ) + .expect("failed to create client"); + + client.delete_store().await?; + } + + // We can actually just use the existing diff/etc to push + // 1. Diff + // 2. Get operations + // 3. Filter operations by + // a) are they an upload op? + // b) are they for the host/tag we are pushing here? + let client = sync::build_client(settings).await?; + let (diff, remote_index) = sync::diff(&client, &store).await?; + + // Skip on --force: that path intentionally replaces remote with local. + if !self.force { + let key: [u8; 32] = load_key(settings)?.into(); + sync::check_encryption_key(&client, &remote_index, &key) + .await + .map_err(crate::print_error::format_sync_error)?; + } + + let operations = sync::operations(diff, &store).await?; + + let operations = operations + .into_iter() + .filter(|op| match op { + // No noops or downloads thx + Operation::Noop { .. } | Operation::Download { .. } => false, + + // push, so yes plz to uploads! + Operation::Upload { host, tag, .. } => { + if self.force { + return true; + } + + if let Some(h) = self.host { + if HostId(h) != *host { + return false; + } + } else if *host != host_id { + return false; + } + + if let Some(t) = self.tag.clone() + && t != *tag + { + return false; + } + + true + } + }) + .collect(); + + let (uploaded, _) = sync::sync_remote(&client, operations, &store, self.page).await?; + + println!("Uploaded {uploaded} records"); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/rebuild.rs b/crates/turtle/src/command/client/store/rebuild.rs new file mode 100644 index 00000000..80e201c2 --- /dev/null +++ b/crates/turtle/src/command/client/store/rebuild.rs @@ -0,0 +1,58 @@ +use clap::Args; +use eyre::{Result, bail}; + +#[cfg(feature = "daemon")] +use crate::command::client::daemon as daemon_cmd; + +use crate::atuin_client::{ + database::Database, encryption, history::store::HistoryStore, + record::sqlite_store::SqliteStore, settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Rebuild { + pub tag: String, +} + +impl Rebuild { + pub async fn run( + &self, + settings: &Settings, + store: SqliteStore, + database: &dyn Database, + ) -> Result<()> { + // keep it as a string and not an enum atm + // would be super cool to build this dynamically in the future + // eg register handles for rebuilding various tags without having to make this part of the + // binary big + match self.tag.as_str() { + "history" => { + self.rebuild_history(settings, store.clone(), database) + .await?; + } + + tag => bail!("unknown tag: {tag}"), + } + + Ok(()) + } + + async fn rebuild_history( + &self, + settings: &Settings, + store: SqliteStore, + database: &dyn Database, + ) -> Result<()> { + let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store, host_id, encryption_key); + + history_store.build(database).await?; + + #[cfg(feature = "daemon")] + daemon_cmd::emit_event(settings, crate::atuin_daemon::DaemonEvent::HistoryRebuilt).await; + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/rekey.rs b/crates/turtle/src/command/client/store/rekey.rs new file mode 100644 index 00000000..e63be447 --- /dev/null +++ b/crates/turtle/src/command/client/store/rekey.rs @@ -0,0 +1,41 @@ +use clap::Args; +use eyre::Result; +use tokio::{fs::File, io::AsyncWriteExt}; + +use crate::atuin_client::{ + encryption::{decode_key, generate_encoded_key, load_key}, + record::sqlite_store::SqliteStore, + record::store::Store, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Rekey { + /// The new key to use for encryption. Omit for a randomly-generated key + key: Option<String>, +} + +impl Rekey { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + let key = if let Some(key) = self.key.clone() { + println!("Re-encrypting store with specified key"); + + key + } else { + println!("Re-encrypting store with freshly-generated key"); + let (_, encoded) = generate_encoded_key()?; + encoded + }; + + let current_key: [u8; 32] = load_key(settings)?.into(); + let new_key: [u8; 32] = decode_key(key.clone())?.into(); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Store rewritten. Saving new key"); + let mut file = File::create(settings.key_path.clone()).await?; + file.write_all(key.as_bytes()).await?; + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/verify.rs b/crates/turtle/src/command/client/store/verify.rs new file mode 100644 index 00000000..5aa1dc70 --- /dev/null +++ b/crates/turtle/src/command/client/store/verify.rs @@ -0,0 +1,26 @@ +use clap::Args; +use eyre::Result; + +use crate::atuin_client::{ + encryption::load_key, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Verify {} + +impl Verify { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + println!("Verifying local store can be decrypted with the current key"); + + let key = load_key(settings)?; + + match store.verify(&key.into()).await { + Ok(()) => println!("Local store encryption verified OK"), + Err(e) => println!("Failed to verify local store encryption: {e:?}"), + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/sync.rs b/crates/turtle/src/command/client/sync.rs new file mode 100644 index 00000000..a4839b5f --- /dev/null +++ b/crates/turtle/src/command/client/sync.rs @@ -0,0 +1,120 @@ +use clap::Subcommand; +use eyre::{Result, WrapErr}; + +use crate::atuin_client::{ + database::Database, + encryption, + history::store::HistoryStore, + record::{sqlite_store::SqliteStore, store::Store, sync}, + settings::Settings, +}; + +mod status; + +use crate::command::client::account; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Sync with the configured server + Sync { + /// Force re-download everything + #[arg(long, short)] + force: bool, + }, + + /// Login to the configured server + Login(account::login::Cmd), + + /// Log out + Logout, + + /// Register with the configured server + Register(account::register::Cmd), + + /// Print the encryption key for transfer to another machine + Key {}, + + /// Display the sync status + Status, +} + +impl Cmd { + pub async fn run( + self, + settings: Settings, + db: &impl Database, + store: SqliteStore, + ) -> Result<()> { + match self { + Self::Sync { force } => run(&settings, force, db, store).await, + Self::Login(l) => l.run(&settings, &store).await, + Self::Logout => account::logout::run().await, + Self::Register(r) => r.run(&settings).await, + Self::Status => status::run(&settings).await, + Self::Key {} => { + use crate::atuin_client::encryption::{encode_key, load_key}; + let key = load_key(&settings).wrap_err("could not load encryption key")?; + + let encode = encode_key(&key).wrap_err("could not encode encryption key")?; + println!("{encode}"); + + Ok(()) + } + } + } +} + +async fn run( + settings: &Settings, + force: bool, + db: &impl Database, + store: SqliteStore, +) -> Result<()> { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + let (uploaded, downloaded) = sync::sync(settings, &store, &encryption_key) + .await + .map_err(crate::print_error::format_sync_error)?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + println!("{uploaded}/{} up/down to record store", downloaded.len()); + + let history_length = db.history_count(true).await?; + let store_history_length = store.len_tag("history").await?; + + #[expect(clippy::cast_sign_loss)] + if history_length as u64 > store_history_length { + println!("{history_length} in history index, but {store_history_length} in history store"); + println!("Running automatic history store init..."); + + // Internally we use the global filter mode, so this context is ignored. + // don't recurse or loop here. + history_store.init_store(db).await?; + + println!("Re-running sync due to new records locally"); + + // we'll want to run sync once more, as there will now be stuff to upload + let (uploaded, downloaded) = sync::sync(settings, &store, &encryption_key) + .await + .map_err(crate::print_error::format_sync_error)?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + println!("{uploaded}/{} up/down to record store", downloaded.len()); + } + + println!( + "Sync complete! {} items in history database, force: {}", + db.history_count(true).await?, + force + ); + + Ok(()) +} diff --git a/crates/turtle/src/command/client/sync/status.rs b/crates/turtle/src/command/client/sync/status.rs new file mode 100644 index 00000000..00088b59 --- /dev/null +++ b/crates/turtle/src/command/client/sync/status.rs @@ -0,0 +1,37 @@ +use crate::{SHA, VERSION}; +use crate::atuin_client::{api_client, settings::Settings}; +use colored::Colorize; +use eyre::{Result, bail}; + +pub async fn run(settings: &Settings) -> Result<()> { + if !settings.logged_in().await? { + bail!("You are not logged in to a sync server - cannot show sync status"); + } + + let client = api_client::Client::new( + &settings.sync_address, + settings.sync_auth_token().await?, + settings.network_connect_timeout, + settings.network_timeout, + )?; + + let me = client.me().await?; + let last_sync = Settings::last_sync().await?; + + println!("Atuin v{VERSION} - Build rev {SHA}\n"); + + println!("{}", "[Local]".green()); + + if settings.auto_sync { + println!("Sync frequency: {}", settings.sync_frequency); + println!("Last sync: {}", last_sync.to_offset(settings.timezone.0)); + } + + if settings.auto_sync { + println!("{}", "[Remote]".green()); + println!("Address: {}", settings.sync_address); + println!("Username: {}", me.username); + } + + Ok(()) +} diff --git a/crates/turtle/src/command/client/wrapped.rs b/crates/turtle/src/command/client/wrapped.rs new file mode 100644 index 00000000..694157c2 --- /dev/null +++ b/crates/turtle/src/command/client/wrapped.rs @@ -0,0 +1,326 @@ +use crossterm::style::{ResetColor, SetAttribute}; +use eyre::Result; +use std::collections::{HashMap, HashSet}; +use time::{Date, Duration, Month, OffsetDateTime, Time}; + +use crate::atuin_client::{database::Database, settings::Settings, theme::Theme}; + +use crate::atuin_history::stats::{Stats, compute}; + +#[derive(Debug)] +struct WrappedStats { + nav_commands: usize, + pkg_commands: usize, + error_rate: f64, + first_half_commands: Vec<(String, usize)>, + second_half_commands: Vec<(String, usize)>, + git_percentage: f64, + busiest_hour: Option<(String, usize)>, +} + +impl WrappedStats { + #[expect(clippy::too_many_lines, clippy::cast_precision_loss)] + fn new( + settings: &Settings, + stats: &Stats, + history: &[crate::atuin_client::history::History], + ) -> Self { + let nav_commands = stats + .top + .iter() + .filter(|(cmd, _)| { + let cmd = &cmd[0]; + cmd == "cd" || cmd == "ls" || cmd == "pwd" || cmd == "pushd" || cmd == "popd" + }) + .map(|(_, count)| count) + .sum(); + + let pkg_managers = [ + "cargo", + "npm", + "pnpm", + "yarn", + "pip", + "pip3", + "pipenv", + "poetry", + "pipx", + "uv", + "brew", + "apt", + "apt-get", + "apk", + "pacman", + "yay", + "paru", + "yum", + "dnf", + "dnf5", + "rpm", + "rpm-ostree", + "zypper", + "pkg", + "chocolatey", + "choco", + "scoop", + "winget", + "gem", + "bundle", + "shards", + "composer", + "gradle", + "maven", + "mvn", + "go get", + "nuget", + "dotnet", + "mix", + "hex", + "rebar3", + "nix", + "nix-env", + "cabal", + "opam", + ]; + + let pkg_commands = history + .iter() + .filter(|h| { + let cmd = h.command.clone(); + pkg_managers.iter().any(|pm| cmd.starts_with(pm)) + }) + .count(); + + // Error analysis + let mut command_errors: HashMap<String, (usize, usize)> = HashMap::new(); // (total_uses, errors) + let midyear = history[0].timestamp + Duration::days(182); // Split year in half + + let mut first_half_commands: HashMap<String, usize> = HashMap::new(); + let mut second_half_commands: HashMap<String, usize> = HashMap::new(); + let mut hours: HashMap<String, usize> = HashMap::new(); + + for entry in history { + let cmd = entry + .command + .split_whitespace() + .next() + .unwrap_or("") + .to_string(); + let (total, errors) = command_errors.entry(cmd.clone()).or_insert((0, 0)); + *total += 1; + if entry.exit != 0 { + *errors += 1; + } + + // Track command evolution + if entry.timestamp < midyear { + *first_half_commands.entry(cmd.clone()).or_default() += 1; + } else { + *second_half_commands.entry(cmd).or_default() += 1; + } + + // Track hourly distribution + let local_time = entry + .timestamp + .to_offset(time::UtcOffset::current_local_offset().unwrap_or(settings.timezone.0)); + let hour = format!("{:02}:00", local_time.time().hour()); + *hours.entry(hour).or_default() += 1; + } + + let total_errors: usize = command_errors.values().map(|(_, errors)| errors).sum(); + let total_commands: usize = command_errors.values().map(|(total, _)| total).sum(); + let error_rate = total_errors as f64 / total_commands as f64; + + // Process command evolution data + let mut first_half: Vec<_> = first_half_commands.into_iter().collect(); + let mut second_half: Vec<_> = second_half_commands.into_iter().collect(); + first_half.sort_by_key(|(_, count)| std::cmp::Reverse(*count)); + second_half.sort_by_key(|(_, count)| std::cmp::Reverse(*count)); + first_half.truncate(5); + second_half.truncate(5); + + // Calculate git percentage + let git_commands: usize = stats + .top + .iter() + .filter(|(cmd, _)| cmd[0].starts_with("git")) + .map(|(_, count)| count) + .sum(); + let git_percentage = git_commands as f64 / stats.total_commands as f64; + + // Find busiest hour + let busiest_hour = hours.into_iter().max_by_key(|(_, count)| *count); + + Self { + nav_commands, + pkg_commands, + error_rate, + first_half_commands: first_half, + second_half_commands: second_half, + git_percentage, + busiest_hour, + } + } +} + +pub fn print_wrapped_header(year: i32) { + let reset = ResetColor; + let bold = SetAttribute(crossterm::style::Attribute::Bold); + + println!("{bold}╭────────────────────────────────────╮{reset}"); + println!("{bold}│ ATUIN WRAPPED {year} │{reset}"); + println!("{bold}│ Your Year in Shell History │{reset}"); + println!("{bold}╰────────────────────────────────────╯{reset}"); + println!(); +} + +#[expect(clippy::cast_precision_loss)] +fn print_fun_facts(wrapped_stats: &WrappedStats, stats: &Stats, year: i32) { + let reset = ResetColor; + let bold = SetAttribute(crossterm::style::Attribute::Bold); + + if wrapped_stats.git_percentage > 0.05 { + println!( + "{bold}🌟 You're a Git Power User!{reset} {bold}{:.1}%{reset} of your commands were Git operations\n", + wrapped_stats.git_percentage * 100.0 + ); + } + // Navigation patterns + let nav_percentage = wrapped_stats.nav_commands as f64 / stats.total_commands as f64 * 100.0; + if nav_percentage > 0.05 { + println!( + "{bold}🚀 You're a Navigator!{reset} {bold}{nav_percentage:.1}%{reset} of your time was spent navigating directories\n", + ); + } + + // Command vocabulary + println!( + "{bold}📚 Command Vocabulary{reset}: You know {bold}{}{reset} unique commands\n", + stats.unique_commands + ); + + // Package management + println!( + "{bold}📦 Package Management{reset}: You ran {bold}{}{reset} package-related commands\n", + wrapped_stats.pkg_commands + ); + + // Error patterns + let error_percentage = wrapped_stats.error_rate * 100.0; + println!( + "{bold}🚨 Error Analysis{reset}: Your commands failed {bold}{error_percentage:.1}%{reset} of the time\n", + ); + + // Command evolution + println!("🔍 Command Evolution:"); + + // print stats for each half and compare + println!(" {bold}Top Commands{reset} in the first half of {year}:"); + for (cmd, count) in wrapped_stats.first_half_commands.iter().take(3) { + println!(" {bold}{cmd}{reset} ({count} times)"); + } + + println!(" {bold}Top Commands{reset} in the second half of {year}:"); + for (cmd, count) in wrapped_stats.second_half_commands.iter().take(3) { + println!(" {bold}{cmd}{reset} ({count} times)"); + } + + // Find new favorite commands (in top 5 of second half but not in first half) + let first_half_set: HashSet<_> = wrapped_stats + .first_half_commands + .iter() + .map(|(cmd, _)| cmd) + .collect(); + let new_favorites: Vec<_> = wrapped_stats + .second_half_commands + .iter() + .filter(|(cmd, _)| !first_half_set.contains(cmd)) + .take(2) + .collect(); + + if !new_favorites.is_empty() { + println!(" {bold}New favorites{reset} in the second half:"); + for (cmd, count) in new_favorites { + println!(" {bold}{cmd}{reset} ({count} times)"); + } + } + + // Time patterns + if let Some((hour, count)) = &wrapped_stats.busiest_hour { + println!("\n🕘 Most Productive Hour: {bold}{hour}{reset} ({count} commands)"); + + // Night owl or early bird + let hour_num = hour + .split(':') + .next() + .unwrap_or("0") + .parse::<u32>() + .unwrap_or(0); + if hour_num >= 22 || hour_num <= 4 { + println!(" You're quite the night owl! 🦉"); + } else if (5..=7).contains(&hour_num) { + println!(" Early bird gets the worm! 🐦"); + } + } + + println!(); +} + +pub async fn run( + year: Option<i32>, + db: &impl Database, + settings: &Settings, + theme: &Theme, +) -> Result<()> { + let now = OffsetDateTime::now_utc().to_offset(settings.timezone.0); + let month = now.month(); + + // If we're in December, then wrapped is for the current year. If not, it's for the previous year + let year = year.unwrap_or_else(|| { + if month == Month::December { + now.year() + } else { + now.year() - 1 + } + }); + + let start = OffsetDateTime::new_in_offset( + Date::from_calendar_date(year, Month::January, 1).unwrap(), + Time::MIDNIGHT, + now.offset(), + ); + let end = OffsetDateTime::new_in_offset( + Date::from_calendar_date(year, Month::December, 31).unwrap(), + Time::MIDNIGHT + Duration::days(1) - Duration::nanoseconds(1), + now.offset(), + ); + + let history = db.range(start, end).await?; + if history.is_empty() { + println!( + "Your history for {year} is empty!\nMaybe 'atuin import' could help you import your previous history 🪄" + ); + return Ok(()); + } + + // Compute overall stats using existing functionality + let stats = compute(settings, &history, 10, 1).expect("Failed to compute stats"); + let wrapped_stats = WrappedStats::new(settings, &stats, &history); + + // Print wrapped format + print_wrapped_header(year); + + println!("🎉 In {year}, you typed {} commands!", stats.total_commands); + println!( + " That's ~{} commands every day\n", + stats.total_commands / 365 + ); + + println!("Your Top Commands:"); + crate::atuin_history::stats::pretty_print(stats.clone(), 1, theme); + println!(); + + print_fun_facts(&wrapped_stats, &stats, year); + + Ok(()) +} diff --git a/crates/turtle/src/command/contributors.rs b/crates/turtle/src/command/contributors.rs new file mode 100644 index 00000000..452fd335 --- /dev/null +++ b/crates/turtle/src/command/contributors.rs @@ -0,0 +1,5 @@ +static CONTRIBUTORS: &str = include_str!("CONTRIBUTORS"); + +pub fn run() { + println!("\n{CONTRIBUTORS}"); +} diff --git a/crates/turtle/src/command/external.rs b/crates/turtle/src/command/external.rs new file mode 100644 index 00000000..e1f0cddd --- /dev/null +++ b/crates/turtle/src/command/external.rs @@ -0,0 +1,102 @@ +use std::fmt::Write as _; +use std::process::Command; +use std::{io, process}; + +#[cfg(feature = "client")] +use crate::atuin_client::plugin::{OfficialPluginRegistry, PluginContext}; +use clap::CommandFactory; +use clap::builder::{StyledStr, Styles}; +use eyre::Result; + +use crate::Atuin; + +pub fn run(args: &[String]) -> Result<()> { + let subcommand = &args[0]; + let bin = format!("atuin-{subcommand}"); + let mut cmd = Command::new(&bin); + cmd.args(&args[1..]); + + #[cfg(feature = "client")] + let context = PluginContext::new(subcommand); + + let spawn_result = match cmd.spawn() { + Ok(child) => Ok(child), + Err(e) => match e.kind() { + io::ErrorKind::NotFound => { + let output = render_not_found(subcommand, &bin); + Err(output) + } + _ => Err(e.to_string().into()), + }, + }; + + match spawn_result { + Ok(mut child) => { + let status = child.wait()?; + if status.success() { + Ok(()) + } else { + #[cfg(feature = "client")] + drop(context); + + process::exit(status.code().unwrap_or(1)); + } + } + Err(e) => { + eprintln!("{}", e.ansi()); + + #[cfg(feature = "client")] + drop(context); + + process::exit(1); + } + } +} + +fn render_not_found(subcommand: &str, bin: &str) -> StyledStr { + let mut output = StyledStr::new(); + let styles = Styles::styled(); + + let error = styles.get_error(); + let invalid = styles.get_invalid(); + let literal = styles.get_literal(); + + #[cfg(feature = "client")] + { + let registry = OfficialPluginRegistry::new(); + + // Check if this is an official plugin + if let Some(install_message) = registry.get_install_message(subcommand) { + let _ = write!(output, "{error}error:{error:#} "); + let _ = write!( + output, + "'{invalid}{subcommand}{invalid:#}' is an official atuin plugin, but it's not installed" + ); + let _ = write!(output, "\n\n"); + let _ = write!(output, "{install_message}"); + return output; + } + } + + let mut atuin_cmd = Atuin::command(); + let usage = atuin_cmd.render_usage(); + + let _ = write!(output, "{error}error:{error:#} "); + let _ = write!( + output, + "unrecognized subcommand '{invalid}{subcommand}{invalid:#}' " + ); + let _ = write!( + output, + "and no executable named '{invalid}{bin}{invalid:#}' found in your PATH" + ); + let _ = write!(output, "\n\n"); + let _ = write!(output, "{usage}"); + let _ = write!(output, "\n\n"); + let _ = write!( + output, + "For more information, try '{literal}--help{literal:#}'." + ); + + output +} diff --git a/crates/turtle/src/command/gen_completions.rs b/crates/turtle/src/command/gen_completions.rs new file mode 100644 index 00000000..10d4f689 --- /dev/null +++ b/crates/turtle/src/command/gen_completions.rs @@ -0,0 +1,84 @@ +use clap::{CommandFactory, Parser, ValueEnum}; +use clap_complete::{Generator, Shell, generate, generate_to}; +use clap_complete_nushell::Nushell; +use eyre::Result; + +// clap put nushell completions into a separate package due to the maintainers +// being a little less committed to support them. +// This means we have to do a tiny bit of legwork to combine these completions +// into one command. +#[derive(Debug, Clone, ValueEnum)] +#[value(rename_all = "lower")] +pub enum GenShell { + Bash, + Elvish, + Fish, + Nushell, + PowerShell, + Zsh, +} + +impl Generator for GenShell { + fn file_name(&self, name: &str) -> String { + match self { + // clap_complete + Self::Bash => Shell::Bash.file_name(name), + Self::Elvish => Shell::Elvish.file_name(name), + Self::Fish => Shell::Fish.file_name(name), + Self::PowerShell => Shell::PowerShell.file_name(name), + Self::Zsh => Shell::Zsh.file_name(name), + + // clap_complete_nushell + Self::Nushell => Nushell.file_name(name), + } + } + + fn generate(&self, cmd: &clap::Command, buf: &mut dyn std::io::prelude::Write) { + match self { + // clap_complete + Self::Bash => Shell::Bash.generate(cmd, buf), + Self::Elvish => Shell::Elvish.generate(cmd, buf), + Self::Fish => Shell::Fish.generate(cmd, buf), + Self::PowerShell => Shell::PowerShell.generate(cmd, buf), + Self::Zsh => Shell::Zsh.generate(cmd, buf), + + // clap_complete_nushell + Self::Nushell => Nushell.generate(cmd, buf), + } + } +} + +#[derive(Debug, Parser)] +pub struct Cmd { + /// Set the shell for generating completions + #[arg(long, short)] + shell: GenShell, + + /// Set the output directory + #[arg(long, short)] + out_dir: Option<String>, +} + +impl Cmd { + pub fn run(self) -> Result<()> { + let Cmd { shell, out_dir } = self; + + let mut cli = crate::Atuin::command(); + + match out_dir { + Some(out_dir) => { + generate_to(shell, &mut cli, env!("CARGO_PKG_NAME"), &out_dir)?; + } + None => { + generate( + shell, + &mut cli, + env!("CARGO_PKG_NAME"), + &mut std::io::stdout(), + ); + } + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/mod.rs b/crates/turtle/src/command/mod.rs new file mode 100644 index 00000000..e58bfe72 --- /dev/null +++ b/crates/turtle/src/command/mod.rs @@ -0,0 +1,156 @@ +use clap::Subcommand; +use eyre::Result; + +#[cfg(not(windows))] +use rustix::{fs::Mode, process::umask}; + +#[cfg(feature = "client")] +mod client; + +mod contributors; + +mod gen_completions; + +mod external; + +#[derive(Subcommand)] +#[command(infer_subcommands = true)] +#[expect(clippy::large_enum_variant)] +pub enum AtuinCmd { + #[cfg(feature = "client")] + #[command(flatten)] + Client(client::Cmd), + + /// PTY proxy for atuin + #[cfg(feature = "pty-proxy")] + #[command(alias = "hex")] + PtyProxy(crate::atuin_pty_proxy::PtyProxy), + + /// Generate a UUID + Uuid, + + Contributors, + + /// Generate shell completions + GenCompletions(gen_completions::Cmd), + + #[command(external_subcommand)] + External(Vec<String>), +} + +impl AtuinCmd { + pub fn run(self) -> Result<()> { + #[cfg(not(windows))] + { + // set umask before we potentially open/create files + // or in other words, 077. Do not allow any access to any other user + let mode = Mode::RWXG | Mode::RWXO; + umask(mode); + } + + match self { + #[cfg(feature = "client")] + Self::Client(client) => client.run(), + + #[cfg(feature = "pty-proxy")] + Self::PtyProxy(proxy) => { + run_pty_proxy(proxy); + Ok(()) + } + + Self::Contributors => { + contributors::run(); + Ok(()) + } + Self::Uuid => { + println!("{}", crate::atuin_common::utils::uuid_v7().as_simple()); + Ok(()) + } + Self::GenCompletions(gen_completions) => gen_completions.run(), + Self::External(args) => external::run(&args), + } + } +} + +#[cfg(all(feature = "pty-proxy", unix))] +fn run_pty_proxy(proxy: crate::atuin_pty_proxy::PtyProxy) { + #[cfg(feature = "daemon")] + proxy.run(semantic_command_capture_sink()); + + #[cfg(not(feature = "daemon"))] + proxy.run(None); +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +fn semantic_command_capture_sink() -> Option<crate::atuin_pty_proxy::CommandCaptureSink> { + use std::sync::mpsc; + use std::time::Duration; + + if is_truthy_env("ATUIN_TERMINAL") { + return None; + } + + let settings = crate::atuin_client::settings::Settings::new().ok()?; + let (tx, rx) = mpsc::sync_channel::<crate::atuin_pty_proxy::CommandCapture>(128); + + std::thread::spawn(move || { + let Ok(runtime) = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + else { + return; + }; + + while let Ok(first) = rx.recv() { + let mut batch = vec![first]; + + while batch.len() < 64 { + match rx.recv_timeout(Duration::from_millis(25)) { + Ok(capture) => batch.push(capture), + Err(mpsc::RecvTimeoutError::Timeout | mpsc::RecvTimeoutError::Disconnected) => { + break; + } + } + } + + runtime.block_on(send_semantic_command_captures(&settings, batch)); + } + }); + + Some(Box::new(move |capture| { + let _ = tx.try_send(capture); + })) +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +#[inline] +fn is_truthy_env(name: &str) -> bool { + std::env::var(name) + .ok() + .as_ref() + .is_some_and(|value| !value.trim().is_empty() && value.trim() != "false") +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +async fn send_semantic_command_captures( + settings: &crate::atuin_client::settings::Settings, + batch: Vec<crate::atuin_pty_proxy::CommandCapture>, +) { + let captures = batch + .into_iter() + .map(|capture| crate::atuin_daemon::semantic::CommandCapture { + prompt: capture.prompt, + command: capture.command, + output: capture.output, + exit_code: capture.exit_code, + history_id: capture.history_id, + session_id: capture.session_id, + output_truncated: capture.output_truncated, + output_observed_bytes: capture.output_observed_bytes, + }) + .collect(); + + if let Ok(mut client) = crate::atuin_daemon::SemanticClient::from_settings(settings).await { + let _ = client.record_commands(captures).await; + } +} |
