diff options
| author | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-11 00:54:30 +0200 |
|---|---|---|
| committer | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-11 00:54:30 +0200 |
| commit | 5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8 (patch) | |
| tree | c64baa8d5866c8e339eaf660dd3f94f30a3f7d8a /crates/turtle/src/command/client/search/engines | |
| parent | chore: Somewhat simplify sync code (diff) | |
| download | atuin-5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8.zip | |
chore: Move everything into one big crate
That helps remove duplicated code and rustc/cargo will now also show
dead code correctly.
Diffstat (limited to 'crates/turtle/src/command/client/search/engines')
3 files changed, 665 insertions, 0 deletions
diff --git a/crates/turtle/src/command/client/search/engines/daemon.rs b/crates/turtle/src/command/client/search/engines/daemon.rs
new file mode 100644
index 00000000..b1299c02
--- /dev/null
+++ b/crates/turtle/src/command/client/search/engines/daemon.rs
@@ -0,0 +1,291 @@
+use crate::atuin_client::{
+    database::{Database, OptFilters},
+    history::{AUTHOR_FILTER_ALL_USER, History},
+    settings::{SearchMode, Settings},
+};
+use crate::atuin_daemon::client::{DaemonClientErrorKind, SearchClient, classify_error};
+use async_trait::async_trait;
+use atuin_nucleo_matcher::{
+    Config, Matcher, Utf32Str,
+    pattern::{CaseMatching, Normalization, Pattern},
+};
+use eyre::Result;
+use tracing::{Level, debug, instrument, span};
+use uuid::Uuid;
+
+use super::{SearchEngine, SearchState};
+use crate::command::client::daemon;
+
+/// Search engine that asks the daemon for matching history ids and loads the
+/// matching rows from the local database.
+pub struct Search {
+    /// Cached daemon connection; `None` until first use, reset before an auto-start retry.
+    client: Option<SearchClient>,
+    /// Id of the most recent request, used to ignore responses to older queries.
+    query_id: u64,
+    settings: Settings,
+    #[cfg(unix)]
+    socket_path: String,
+}
+
+impl Search {
+    /// Builds an engine from `settings` without opening a connection yet.
+    pub fn new(settings: &Settings) -> Self {
+        Search {
+            client: None,
+            query_id: 0,
+            settings: settings.clone(),
+            #[cfg(unix)]
+            socket_path: settings.daemon.socket_path.clone(),
+        }
+    }
+
+    /// Returns the cached client, connecting first if there is none.
+    #[instrument(skip_all, level = Level::TRACE, name = "get_daemon_client")]
+    async fn get_client(&mut self) -> Result<&mut SearchClient> {
+        if self.client.is_none() {
+            self.connect().await?;
+        }
+        // `connect` stores a client before returning `Ok`, so a miss here is a logic
+        // error; surface it as an error instead of panicking.
+        self.client
+            .as_mut()
+            .ok_or_else(|| eyre::eyre!("daemon client missing after successful connect"))
+    }
+
+    /// Opens a connection to the daemon and caches it in `self.client`.
+    async fn connect(&mut self) -> Result<()> {
+        // NOTE(review): `client` is only bound under `cfg(unix)`; confirm whether this
+        // module is meant to build on other targets.
+        #[cfg(unix)]
+        let client = SearchClient::new(self.socket_path.clone()).await?;
+
+        self.client = Some(client);
+        Ok(())
+    }
+
+    /// Whether `err` is classified as a kind that warrants auto-starting the daemon.
+    fn should_retry(err: &eyre::Report) -> bool {
+        matches!(
+            classify_error(err),
+            DaemonClientErrorKind::Connect
+                | DaemonClientErrorKind::Unavailable
+                | DaemonClientErrorKind::Unimplemented
+        )
+    }
+
+    /// Increments the query counter and returns the new value.
+    fn next_query_id(&mut self) -> u64 {
+        self.query_id += 1;
+        self.query_id
+    }
+
+    /// Check if query contains regex pattern (r/.../)
+    /// Nucleo doesn't support regex, so we fall back to database search
+    fn contains_regex_pattern(query: &str) -> bool {
+        query.starts_with("r/") || query.contains(" r/")
+    }
+
+    /// Runs the query as a full-text database search instead of asking the daemon.
+    ///
+    /// A failing database search yields an empty result list rather than an error.
+    #[instrument(skip_all, level = Level::TRACE, name = "daemon_db_fallback")]
+    async fn fallback_to_db_search(
+        &self,
+        state: &SearchState,
+        db: &dyn Database,
+    ) -> Result<Vec<History>> {
+        let results = db
+            .search(
+                SearchMode::FullText,
+                state.filter_mode,
+                &state.context,
+                state.input.as_str(),
+                OptFilters {
+                    limit: Some(200),
+                    authors: vec![AUTHOR_FILTER_ALL_USER.to_string()],
+                    ..Default::default()
+                },
+            )
+            .await
+            .map_or(Vec::new(), |r| r.into_iter().collect());
+        Ok(results)
+    }
+
+    /// Whether `id` consists only of ASCII hex digits and `-`, i.e. can be quoted into
+    /// SQL text without escaping.
+    fn is_plain_id(id: &str) -> bool {
+        !id.is_empty() && id.bytes().all(|b| b.is_ascii_hexdigit() || b == b'-')
+    }
+
+    /// Loads the history rows for `ids`, ordered by timestamp descending.
+    ///
+    /// The ids are interpolated into the SQL text, so ids that fail [`Self::is_plain_id`]
+    /// are left out of the statement; with nothing left, no query is issued.
+    #[instrument(skip_all, level = Level::TRACE, name = "hydrate_from_db", fields(count = ids.len()))]
+    async fn hydrate_from_db(&self, db: &dyn Database, ids: &[String]) -> Result<Vec<History>> {
+        let quoted: Vec<String> = ids
+            .iter()
+            .filter(|id| Self::is_plain_id(id))
+            .map(|id| format!("'{id}'"))
+            .collect();
+        if quoted.is_empty() {
+            return Ok(Vec::new());
+        }
+        let sql_query = format!(
+            "SELECT * FROM history WHERE id IN ({}) ORDER BY timestamp DESC",
+            quoted.join(",")
+        );
+        Ok(db.query_history(&sql_query).await?)
+    }
+}
+
+#[async_trait]
+impl SearchEngine for Search {
+    /// Searches via the daemon and returns the matches in the order the daemon sent them.
+    ///
+    /// Regex queries go to the database instead. When the first attempt fails with a
+    /// retriable error and `daemon.autostart` is set, the daemon is started and the
+    /// request is sent once more.
+    #[instrument(skip_all, level = Level::TRACE, name = "daemon_search", fields(query = %state.input.as_str()))]
+    async fn full_query(
+        &mut self,
+        state: &SearchState,
+        db: &mut dyn Database,
+    ) -> Result<Vec<History>> {
+        let query = state.input.as_str().to_string();
+
+        // Fall back to database for regex queries (Nucleo doesn't support regex)
+        if Self::contains_regex_pattern(&query) {
+            debug!(query = %query, "[daemon-client] regex detected, falling back to db");
+            return self.fallback_to_db_search(state, db).await;
+        }
+
+        let query_id = self.next_query_id();
+
+        let span =
+            span!(Level::TRACE, "daemon_search.req_resp", query = %query, query_id = query_id);
+
+        // Try to connect and search; if it fails with a retriable error,
+        // auto-start the daemon and retry once.
+        let first_attempt = async {
+            let client = self.get_client().await?;
+            client
+                .search(
+                    query.clone(),
+                    query_id,
+                    state.filter_mode,
+                    Some(state.context.clone()),
+                )
+                .await
+        }
+        .await;
+
+        let mut stream = match first_attempt {
+            Ok(stream) => stream,
+            Err(err) if self.settings.daemon.autostart && Self::should_retry(&err) => {
+                debug!("daemon not available, attempting auto-start");
+                self.client = None;
+
+                daemon::ensure_daemon_running(&self.settings).await?;
+
+                let client = self.get_client().await?;
+                client
+                    .search(
+                        query.clone(),
+                        query_id,
+                        state.filter_mode,
+                        Some(state.context.clone()),
+                    )
+                    .await?
+            }
+            Err(err) => return Err(err),
+        };
+
+        let mut ids = Vec::with_capacity(200);
+        span!(Level::TRACE, "daemon_search.resp")
+            .in_scope(async || {
+                loop {
+                    let response = match stream.message().await {
+                        Ok(Some(response)) => response,
+                        Ok(None) => break,
+                        Err(err) => {
+                            // Keep whatever arrived before the failure, but leave a trace of it.
+                            debug!(error = ?err, "[daemon-client] response stream failed");
+                            break;
+                        }
+                    };
+                    let span2 = span!(
+                        Level::TRACE,
+                        "daemon_search.resp.item",
+                        query_id = response.query_id
+                    );
+                    let _span2 = span2.enter();
+                    // Only process if the query_id matches (prevents stale responses)
+                    if response.query_id == query_id {
+                        // A malformed id is skipped rather than aborting the whole search.
+                        ids.extend(response.ids.iter().filter_map(|id| {
+                            match <[u8; 16]>::try_from(id.as_slice()) {
+                                Ok(bytes) => Some(Uuid::from_bytes(bytes).as_simple().to_string()),
+                                Err(_) => {
+                                    debug!(len = id.len(), "[daemon-client] id is not 16 bytes");
+                                    None
+                                }
+                            }
+                        }));
+                    }
+                    drop(_span2);
+                    drop(span2);
+                }
+            })
+            .await;
+        drop(span);
+
+        if ids.is_empty() {
+            debug!(query = %query, results = 0, "[daemon-client] empty results");
+            return Ok(Vec::new());
+        }
+
+        // Hydrate from local database
+        let results = self.hydrate_from_db(db, &ids).await?;
+
+        // Reorder results to match the order from the daemon (which is ranked by relevance).
+        // Each row is moved out of the map, so an id the daemon repeats is returned once.
+        let ordered_results = span!(Level::TRACE, "reorder_results").in_scope(|| {
+            let mut by_id: std::collections::HashMap<String, History> = results
+                .into_iter()
+                .map(|history| (history.id.0.clone(), history))
+                .collect();
+            ids.iter().filter_map(|id| by_id.remove(id)).collect::<Vec<_>>()
+        });
+
+        debug!(
+            query = %query,
+            results = ordered_results.len(),
+            "[daemon-client]"
+        );
+
+        Ok(ordered_results)
+    }
+
+    /// Returns the indices of `command` to highlight for `search_input`.
+    #[instrument(skip_all, level = Level::TRACE, name = "daemon_highlight")]
+    fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> {
+        // Use fulltext highlighting for regex queries
+        if Self::contains_regex_pattern(search_input) {
+            return super::db::get_highlight_indices_fulltext(command, search_input);
+        }
+
+        let mut matcher = Matcher::new(Config::DEFAULT);
+        let pattern = Pattern::parse(search_input, CaseMatching::Smart, Normalization::Smart);
+
+        let mut indices: Vec<u32> = Vec::new();
+        let mut haystack_buf = Vec::new();
+
+        let haystack = Utf32Str::new(command, &mut haystack_buf);
+        pattern.indices(haystack, &mut matcher, &mut indices);
+
+        // Convert u32 indices to usize
+        indices.into_iter().map(|i| i as usize).collect()
+    }
+}
diff --git a/crates/turtle/src/command/client/search/engines/db.rs b/crates/turtle/src/command/client/search/engines/db.rs
new file mode 100644
index 00000000..2765faf5
--- /dev/null
+++ b/crates/turtle/src/command/client/search/engines/db.rs
@@ -0,0 +1,122 @@
+use super::{SearchEngine, SearchState};
+use async_trait::async_trait;
+use crate::atuin_client::{
+    database::Database,
+    database::OptFilters,
+    database::{QueryToken, QueryTokenizer},
+    history::{AUTHOR_FILTER_ALL_USER, History},
+    settings::SearchMode,
+};
+use eyre::Result;
+use norm::Metric;
+use norm::fzf::{FzfParser, FzfV2};
+use std::ops::Range;
+use tracing::{Level, instrument};
+
+/// Search engine that delegates to the database, using the wrapped search mode.
+pub struct Search(pub SearchMode);
+
+#[async_trait]
+impl SearchEngine for Search {
+    /// Runs `state.input` through `Database::search` with a limit of 200 rows.
+    #[instrument(skip_all, level = Level::TRACE, name = "db_search", fields(mode = ?self.0, query = %state.input.as_str()))]
+    async fn full_query(
+        &mut self,
+        state: &SearchState,
+        db: &mut dyn Database,
+    ) -> Result<Vec<History>> {
+        let results = db
+            .search(
+                self.0,
+                state.filter_mode,
+                &state.context,
+                state.input.as_str(),
+                OptFilters {
+                    limit: Some(200),
+                    authors: vec![AUTHOR_FILTER_ALL_USER.to_string()],
+                    ..Default::default()
+                },
+            )
+            .await
+            // ignore errors as it may be caused by incomplete regex
+            .map_or(Vec::new(), |r| r.into_iter().collect());
+        Ok(results)
+    }
+
+    /// Returns the indices of `command` to highlight for `search_input`.
+    ///
+    /// Prefix mode highlights nothing, full-text mode uses
+    /// [`get_highlight_indices_fulltext`], and every other mode uses fzf ranges.
+    #[instrument(skip_all, level = Level::TRACE, name = "db_highlight")]
+    fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> {
+        if self.0 == SearchMode::Prefix {
+            return vec![];
+        } else if self.0 == SearchMode::FullText {
+            return get_highlight_indices_fulltext(command, search_input);
+        }
+        let mut fzf = FzfV2::new();
+        let mut parser = FzfParser::new();
+        let query = parser.parse(search_input);
+        let mut ranges: Vec<Range<usize>> = Vec::new();
+        let _ = fzf.distance_and_ranges(query, command, &mut ranges);
+
+        // convert ranges to all indices
+        ranges.into_iter().flatten().collect()
+    }
+}
+
+/// Returns the sorted, de-duplicated byte indices of `command` matched by the
+/// tokens of `search_input` for which `is_inverse()` is false.
+///
+/// Tokens for which `has_uppercase()` is false are matched against an ASCII-lowercased
+/// copy of `command`; regex tokens are always matched against `command` itself, and a
+/// regex that fails to compile contributes nothing.
+#[instrument(skip_all, level = Level::TRACE, name = "db_highlight_fulltext")]
+pub fn get_highlight_indices_fulltext(command: &str, search_input: &str) -> Vec<usize> {
+    let mut ranges = vec![];
+    let lower_command = command.to_ascii_lowercase();
+
+    for token in QueryTokenizer::new(search_input) {
+        let matchee = if token.has_uppercase() {
+            command
+        } else {
+            &lower_command
+        };
+
+        if token.is_inverse() {
+            continue;
+        }
+
+        match token {
+            QueryToken::Or => {}
+            QueryToken::Regex(r) => {
+                if let Ok(re) = regex::Regex::new(r) {
+                    for m in re.find_iter(command) {
+                        ranges.push(m.range());
+                    }
+                }
+            }
+            QueryToken::MatchStart(term, _) => {
+                if matchee.starts_with(term) {
+                    ranges.push(0..term.len());
+                }
+            }
+            QueryToken::MatchEnd(term, _) => {
+                if matchee.ends_with(term) {
+                    let l = matchee.len();
+                    ranges.push((l - term.len())..l);
+                }
+            }
+            QueryToken::Match(term, _) | QueryToken::MatchFull(term, _) => {
+                for (idx, m) in matchee.match_indices(term) {
+                    ranges.push(idx..(idx + m.len()));
+                }
+            }
+        }
+    }
+
+    let mut ret: Vec<_> = ranges.into_iter().flatten().collect();
+    ret.sort_unstable();
+    ret.dedup();
+    ret
+}
diff --git a/crates/turtle/src/command/client/search/engines/skim.rs b/crates/turtle/src/command/client/search/engines/skim.rs
new file mode 100644
index 00000000..96a6574d
--- /dev/null
+++ b/crates/turtle/src/command/client/search/engines/skim.rs
@@ -0,0 +1,252 @@
+use std::path::Path;
+
+use async_trait::async_trait;
+use crate::atuin_client::{
+    database::Database,
+    history::{History, is_known_agent},
+    settings::FilterMode,
+};
+use eyre::Result;
+use fuzzy_matcher::{FuzzyMatcher, skim::SkimMatcherV2};
+use itertools::Itertools;
+use time::OffsetDateTime;
+use tokio::task::yield_now;
+use tracing::{Level, instrument, warn};
+use uuid;
+
+use super::{SearchEngine, SearchState};
+
+/// Fuzzy search engine that ranks an in-memory copy of the history with skim.
+pub struct Search {
+    /// Result of `Database::all_with_count`, loaded on the first query.
+    all_history: Vec<(History, i32)>,
+    engine: SkimMatcherV2,
+}
+
+impl Search {
+    /// Creates an engine with an empty history cache and the default matcher.
+    pub fn new() -> Self {
+        Search {
+            all_history: vec![],
+            engine: SkimMatcherV2::default(),
+        }
+    }
+}
+
+impl Default for Search {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[async_trait]
+impl SearchEngine for Search {
+    /// Fuzzy-matches `state.input` against the cached history, loading the cache first
+    /// while it is empty.
+    ///
+    /// # Errors
+    ///
+    /// Returns the database error when loading the history fails.
+    #[instrument(skip_all, level = Level::TRACE, name = "skim_search", fields(query = %state.input.as_str()))]
+    async fn full_query(
+        &mut self,
+        state: &SearchState,
+        db: &mut dyn Database,
+    ) -> Result<Vec<History>> {
+        if self.all_history.is_empty() {
+            self.all_history = load_all_history(db).await?;
+        }
+
+        Ok(fuzzy_search(&self.engine, state, &self.all_history).await)
+    }
+
+    /// Returns the indices skim matched in `command`, or none when there is no match.
+    #[instrument(skip_all, level = Level::TRACE, name = "skim_highlight")]
+    fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> {
+        let (_, indices) = self
+            .engine
+            .fuzzy_indices(command, search_input)
+            .unwrap_or_default();
+        indices
+    }
+}
+
+/// Returns the result of `Database::all_with_count`, propagating database errors
+/// to the caller instead of panicking.
+#[instrument(skip_all, level = Level::TRACE, name = "load_all_history")]
+async fn load_all_history(db: &dyn Database) -> Result<Vec<(History, i32)>> {
+    Ok(db.all_with_count().await?)
+}
+
+/// Scores the entries of `all_history` that pass `state.filter_mode` against
+/// `state.input` and returns at most 200 of them in ascending rank order. Entries
+/// whose author is a known agent are skipped.
+#[expect(clippy::too_many_lines)]
+#[instrument(skip_all, level = Level::TRACE, name = "fuzzy_match", fields(history_count = all_history.len()))]
+async fn fuzzy_search(
+    engine: &SkimMatcherV2,
+    state: &SearchState,
+    all_history: &[(History, i32)],
+) -> Vec<History> {
+    let mut set = Vec::with_capacity(200);
+    let mut ranks = Vec::with_capacity(200);
+    let query = state.input.as_str();
+    let now = OffsetDateTime::now_utc();
+
+    for (i, (history, count)) in all_history.iter().enumerate() {
+        if i % 256 == 0 {
+            yield_now().await;
+        }
+        if is_known_agent(&history.author) {
+            continue;
+        }
+        let context = &state.context;
+        let git_root = context
+            .git_root
+            .as_ref()
+            .and_then(|git_root| git_root.to_str())
+            .unwrap_or(&context.cwd);
+        match state.filter_mode {
+            FilterMode::Global => {}
+            // we aggregate host by ',' separating them
+            FilterMode::Host
+                if history
+                    .hostname
+                    .split(',')
+                    .contains(&context.hostname.as_str()) => {}
+            // we aggregate session by concatenating them.
+            // sessions are 32 byte simple uuid formats
+            FilterMode::Session
+                if history
+                    .session
+                    .as_bytes()
+                    .chunks(32)
+                    .contains(&context.session.as_bytes()) => {}
+            // SessionPreload: include current session + global history from before session start
+            FilterMode::SessionPreload => {
+                let is_current_session = {
+                    history
+                        .session
+                        .as_bytes()
+                        .chunks(32)
+                        .any(|chunk| chunk == context.session.as_bytes())
+                };
+
+                if !is_current_session {
+                    let Ok(uuid) = uuid::Uuid::parse_str(&context.session) else {
+                        warn!("failed to parse session id '{}'", context.session);
+                        continue;
+                    };
+                    let Some(timestamp) = uuid.get_timestamp() else {
+                        warn!(
+                            "failed to get timestamp from uuid '{}'",
+                            uuid.as_hyphenated()
+                        );
+                        continue;
+                    };
+                    let (seconds, nanos) = timestamp.to_unix();
+                    let Ok(session_start) = time::OffsetDateTime::from_unix_timestamp_nanos(
+                        i128::from(seconds) * 1_000_000_000 + i128::from(nanos),
+                    ) else {
+                        warn!(
+                            "failed to create OffsetDateTime from second: {seconds}, nanosecond: {nanos}"
+                        );
+                        continue;
+                    };
+
+                    if history.timestamp >= session_start {
+                        continue;
+                    }
+                }
+            }
+            // we aggregate directory by ':' separating them
+            FilterMode::Directory if history.cwd.split(':').contains(&context.cwd.as_str()) => {}
+            FilterMode::Workspace if history.cwd.split(':').contains(&git_root) => {}
+            _ => continue,
+        }
+        #[expect(clippy::cast_lossless, clippy::cast_precision_loss)]
+        if let Some((score, indices)) = engine.fuzzy_indices(&history.command, query) {
+            let begin = indices.first().copied().unwrap_or_default();
+
+            let mut duration = (now - history.timestamp).as_seconds_f64().log2();
+            if !duration.is_finite() || duration <= 1.0 {
+                duration = 1.0;
+            }
+            // these + X.0 just make the log result a bit smoother.
+            // log is very spiky towards 1-4, but I want a gradual decay.
+            // eg:
+            // log2(4) = 2, log2(5) = 2.3 (16% increase)
+            // log2(8) = 3, log2(9) = 3.16 (5% increase)
+            // log2(16) = 4, log2(17) = 4.08 (2% increase)
+            let count = (*count as f64 + 8.0).log2();
+            let begin = (begin as f64 + 16.0).log2();
+            let path = path_dist(history.cwd.as_ref(), state.context.cwd.as_ref());
+            let path = (path as f64 + 8.0).log2();
+
+            // reduce longer durations, raise higher counts, raise matches close to the start
+            let score = (-score as f64) * count / path / duration / begin;
+
+            'insert: {
+                // algorithm:
+                // 1. find either the position that this command ranks
+                // 2. find the same command positioned better than our rank.
+                for i in 0..set.len() {
+                    // do we out score the current position?
+                    if ranks[i] > score {
+                        ranks.insert(i, score);
+                        set.insert(i, history.clone());
+                        let mut j = i + 1;
+                        while j < set.len() {
+                            // remove duplicates that have a worse score
+                            if set[j].command == history.command {
+                                ranks.remove(j);
+                                set.remove(j);
+
+                                // break this while loop because there won't be any other
+                                // duplicates.
+                                break;
+                            }
+                            j += 1;
+                        }
+
+                        // keep it limited
+                        if ranks.len() > 200 {
+                            ranks.pop();
+                            set.pop();
+                        }
+
+                        break 'insert;
+                    }
+                    // don't continue if this command has a better score already
+                    if set[i].command == history.command {
+                        break 'insert;
+                    }
+                }
+
+                if set.len() < 200 {
+                    ranks.push(score);
+                    set.push(history.clone());
+                }
+            }
+        }
+    }
+
+    set
+}
+
+/// Number of component steps between `a` and `b`: components popped from `a` to reach
+/// a common prefix of `b`, plus the components of `b` below that prefix.
+fn path_dist(a: &Path, b: &Path) -> usize {
+    let mut a: Vec<_> = a.components().collect();
+    let b: Vec<_> = b.components().collect();
+
+    let mut dist = 0;
+
+    // pop a until there's a common ancestor
+    while !b.starts_with(&a) {
+        dist += 1;
+        a.pop();
+    }
+
+    b.len() - a.len() + dist
+}
