about summary refs log tree commit diff stats
path: root/crates/turtle/src/command/client/search/engines
diff options
context:
space:
mode:
author Benedikt Peetz <benedikt.peetz@b-peetz.de> 2026-06-11 00:54:30 +0200
committer Benedikt Peetz <benedikt.peetz@b-peetz.de> 2026-06-11 00:54:30 +0200
commit 5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8 (patch)
tree c64baa8d5866c8e339eaf660dd3f94f30a3f7d8a /crates/turtle/src/command/client/search/engines
parent chore: Somewhat simplify sync code (diff)
download atuin-5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8.zip
chore: Move everything into one big crate
That helps remove duplicated code and rustc/cargo will now also show dead code correctly.
Diffstat (limited to 'crates/turtle/src/command/client/search/engines')
-rw-r--r-- crates/turtle/src/command/client/search/engines/daemon.rs 242
-rw-r--r-- crates/turtle/src/command/client/search/engines/db.rs 110
-rw-r--r-- crates/turtle/src/command/client/search/engines/skim.rs 229
3 files changed, 581 insertions, 0 deletions
diff --git a/crates/turtle/src/command/client/search/engines/daemon.rs b/crates/turtle/src/command/client/search/engines/daemon.rs
new file mode 100644
index 00000000..b1299c02
--- /dev/null
+++ b/crates/turtle/src/command/client/search/engines/daemon.rs
@@ -0,0 +1,242 @@
+use crate::atuin_client::{
+ database::{Database, OptFilters},
+ history::{AUTHOR_FILTER_ALL_USER, History},
+ settings::{SearchMode, Settings},
+};
+use crate::atuin_daemon::client::{DaemonClientErrorKind, SearchClient, classify_error};
+use async_trait::async_trait;
+use atuin_nucleo_matcher::{
+ Config, Matcher, Utf32Str,
+ pattern::{CaseMatching, Normalization, Pattern},
+};
+use eyre::Result;
+use tracing::{Level, debug, instrument, span};
+use uuid::Uuid;
+
+use super::{SearchEngine, SearchState};
+use crate::command::client::daemon;
+
+/// Search engine that sends queries to the daemon through a `SearchClient`.
+///
+/// The client is not created in `new`; `get_client` connects on first use.
+pub struct Search {
+ // Daemon connection; `None` until `connect` succeeds, and reset to `None`
+ // before an auto-start retry.
+ client: Option<SearchClient>,
+ // Counter incremented by `next_query_id` for every daemon query.
+ query_id: u64,
+ // Clone of the settings passed to `new`.
+ settings: Settings,
+ // Socket path handed to `SearchClient::new`; only compiled on unix targets.
+ #[cfg(unix)]
+ socket_path: String,
+}
+
+impl Search {
+    /// Builds a daemon-backed search engine from `settings`.
+    ///
+    /// No connection is opened here; the client starts out as `None`.
+    pub fn new(settings: &Settings) -> Self {
+        Self {
+            client: None,
+            query_id: 0,
+            settings: settings.clone(),
+            #[cfg(unix)]
+            socket_path: settings.daemon.socket_path.clone(),
+        }
+    }
+
+    /// Returns the daemon client, connecting first when none is stored yet.
+    #[instrument(skip_all, level = Level::TRACE, name = "get_daemon_client")]
+    async fn get_client(&mut self) -> Result<&mut SearchClient> {
+        let needs_connection = self.client.is_none();
+        if needs_connection {
+            self.connect().await?;
+        }
+        // `connect` stores a client whenever it returns `Ok`, so this is never `None`.
+        Ok(self.client.as_mut().unwrap())
+    }
+
+    /// Opens a new `SearchClient` on the configured socket and stores it.
+    async fn connect(&mut self) -> Result<()> {
+        #[cfg(unix)]
+        let connected = SearchClient::new(self.socket_path.clone()).await?;
+
+        self.client = Some(connected);
+        Ok(())
+    }
+
+    /// Reports whether `err` is one of the error kinds that warrant a retry
+    /// (connect, unavailable or unimplemented).
+    fn should_retry(err: &eyre::Report) -> bool {
+        let kind = classify_error(err);
+        matches!(
+            kind,
+            DaemonClientErrorKind::Connect
+                | DaemonClientErrorKind::Unavailable
+                | DaemonClientErrorKind::Unimplemented
+        )
+    }
+
+    /// Advances the query counter and returns the new value.
+    fn next_query_id(&mut self) -> u64 {
+        let next = self.query_id + 1;
+        self.query_id = next;
+        next
+    }
+
+    /// Check if query contains regex pattern (r/.../)
+    /// Nucleo doesn't support regex, so we fall back to database search
+    fn contains_regex_pattern(query: &str) -> bool {
+        let leads_with_regex = query.starts_with("r/");
+        leads_with_regex || query.contains(" r/")
+    }
+
+    /// Runs the current input as a full-text database search, limited to 200 rows.
+    ///
+    /// A failing database search is mapped to an empty list rather than an error.
+    #[instrument(skip_all, level = Level::TRACE, name = "daemon_db_fallback")]
+    async fn fallback_to_db_search(
+        &self,
+        state: &SearchState,
+        db: &dyn Database,
+    ) -> Result<Vec<History>> {
+        let filters = OptFilters {
+            limit: Some(200),
+            authors: vec![AUTHOR_FILTER_ALL_USER.to_string()],
+            ..Default::default()
+        };
+        let outcome = db
+            .search(
+                SearchMode::FullText,
+                state.filter_mode,
+                &state.context,
+                state.input.as_str(),
+                filters,
+            )
+            .await;
+        Ok(match outcome {
+            Ok(rows) => rows.into_iter().collect(),
+            Err(_) => Vec::new(),
+        })
+    }
+
+    /// Fetches the history rows whose id is listed in `ids`, newest first.
+    ///
+    /// The ids are interpolated into the SQL text as quoted literals.
+    #[instrument(skip_all, level = Level::TRACE, name = "hydrate_from_db", fields(count = ids.len()))]
+    async fn hydrate_from_db(&self, db: &dyn Database, ids: &[String]) -> Result<Vec<History>> {
+        let quoted = ids
+            .iter()
+            .map(|id| format!("'{id}'"))
+            .collect::<Vec<String>>()
+            .join(",");
+        let sql_query = format!(
+            "SELECT * FROM history WHERE id IN ({}) ORDER BY timestamp DESC",
+            quoted
+        );
+        Ok(db.query_history(&sql_query).await?)
+    }
+}
+
+#[async_trait]
+impl SearchEngine for Search {
+    /// Runs the current input against the daemon and returns the matching history rows.
+    ///
+    /// Inputs containing an `r/` regex token are answered by the database instead. If the
+    /// first daemon request fails with a retriable error and `daemon.autostart` is set, the
+    /// daemon is started and the request is retried once. The ids streamed back are hydrated
+    /// from `db` and returned in the order the daemon sent them.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when the daemon request fails (after the optional retry), when the
+    /// daemon cannot be started, or when hydrating the ids from `db` fails. A failure while
+    /// reading the response stream is logged and the ids received so far are used.
+    #[instrument(skip_all, level = Level::TRACE, name = "daemon_search", fields(query = %state.input.as_str()))]
+    async fn full_query(
+        &mut self,
+        state: &SearchState,
+        db: &mut dyn Database,
+    ) -> Result<Vec<History>> {
+        let query = state.input.as_str().to_string();
+
+        // Fall back to database for regex queries (Nucleo doesn't support regex)
+        if Self::contains_regex_pattern(&query) {
+            debug!(query = %query, "[daemon-client] regex detected, falling back to db");
+            return self.fallback_to_db_search(state, db).await;
+        }
+
+        let query_id = self.next_query_id();
+
+        let span =
+            span!(Level::TRACE, "daemon_search.req_resp", query = %query, query_id = query_id);
+
+        // Try to connect and search; if it fails with a retriable error,
+        // auto-start the daemon and retry once.
+        let first_attempt = async {
+            let client = self.get_client().await?;
+            client
+                .search(
+                    query.clone(),
+                    query_id,
+                    state.filter_mode,
+                    Some(state.context.clone()),
+                )
+                .await
+        }
+        .await;
+
+        let mut stream = match first_attempt {
+            Ok(stream) => stream,
+            Err(err) if self.settings.daemon.autostart && Self::should_retry(&err) => {
+                debug!("daemon not available, attempting auto-start");
+                // Drop the stale client so `get_client` reconnects after the daemon is up.
+                self.client = None;
+
+                daemon::ensure_daemon_running(&self.settings).await?;
+
+                let client = self.get_client().await?;
+                client
+                    .search(
+                        query.clone(),
+                        query_id,
+                        state.filter_mode,
+                        Some(state.context.clone()),
+                    )
+                    .await?
+            }
+            Err(err) => return Err(err),
+        };
+
+        let mut ids = Vec::with_capacity(200);
+        span!(Level::TRACE, "daemon_search.resp")
+            .in_scope(async || {
+                loop {
+                    let response = match stream.message().await {
+                        Ok(Some(response)) => response,
+                        Ok(None) => break,
+                        Err(err) => {
+                            // Keep whatever arrived before the failure, but leave a trace
+                            // instead of dropping the error silently.
+                            debug!(
+                                error = ?err,
+                                query_id,
+                                "[daemon-client] response stream failed"
+                            );
+                            break;
+                        }
+                    };
+                    let span2 = span!(
+                        Level::TRACE,
+                        "daemon_search.resp.item",
+                        query_id = response.query_id
+                    );
+                    let _span2 = span2.enter();
+                    // Only process if the query_id matches (prevents stale responses)
+                    if response.query_id == query_id {
+                        ids.extend(response.ids.iter().filter_map(|id| {
+                            // The daemon is a separate process; a malformed id is skipped
+                            // rather than allowed to panic the client.
+                            let Ok(bytes) = <[u8; 16]>::try_from(id.as_slice()) else {
+                                debug!(
+                                    len = id.len(),
+                                    "[daemon-client] skipping malformed history id"
+                                );
+                                return None;
+                            };
+                            Some(Uuid::from_bytes(bytes).as_simple().to_string())
+                        }));
+                    }
+                    drop(_span2);
+                    drop(span2);
+                }
+            })
+            .await;
+        drop(span);
+
+        if ids.is_empty() {
+            debug!(query = %query, results = 0, "[daemon-client] empty results");
+            return Ok(Vec::new());
+        }
+
+        // Hydrate from local database
+        let results = self.hydrate_from_db(db, &ids).await?;
+
+        // Reorder results to match the order from the daemon (which is ranked by relevance)
+        let ordered_results = span!(Level::TRACE, "reorder_results").in_scope(|| {
+            // Index the hydrated rows once so each id is a map lookup instead of a scan.
+            // `or_insert` keeps the first row seen for an id.
+            let mut by_id: std::collections::HashMap<&str, &History> =
+                std::collections::HashMap::with_capacity(results.len());
+            for history in &results {
+                by_id.entry(history.id.0.as_str()).or_insert(history);
+            }
+            ids.iter()
+                .filter_map(|id| by_id.get(id.as_str()).map(|history| (*history).clone()))
+                .collect::<Vec<History>>()
+        });
+
+        debug!(
+            query = %query,
+            results = results.len(),
+            "[daemon-client]"
+        );
+
+        Ok(ordered_results)
+    }
+
+    /// Returns the positions in `command` to highlight for `search_input`.
+    ///
+    /// Regex inputs use the full-text highlighter from the db engine; everything else is
+    /// matched with a Nucleo pattern using smart case matching and normalization.
+    #[instrument(skip_all, level = Level::TRACE, name = "daemon_highlight")]
+    fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> {
+        // Use fulltext highlighting for regex queries
+        if Self::contains_regex_pattern(search_input) {
+            return super::db::get_highlight_indices_fulltext(command, search_input);
+        }
+
+        let mut matcher = Matcher::new(Config::DEFAULT);
+        let pattern = Pattern::parse(search_input, CaseMatching::Smart, Normalization::Smart);
+
+        let mut indices: Vec<u32> = Vec::new();
+        let mut haystack_buf = Vec::new();
+
+        let haystack = Utf32Str::new(command, &mut haystack_buf);
+        pattern.indices(haystack, &mut matcher, &mut indices);
+
+        // Convert u32 indices to usize
+        indices.into_iter().map(|i| i as usize).collect()
+    }
+}
diff --git a/crates/turtle/src/command/client/search/engines/db.rs b/crates/turtle/src/command/client/search/engines/db.rs
new file mode 100644
index 00000000..2765faf5
--- /dev/null
+++ b/crates/turtle/src/command/client/search/engines/db.rs
@@ -0,0 +1,110 @@
+use super::{SearchEngine, SearchState};
+use async_trait::async_trait;
+use crate::atuin_client::{
+ database::Database,
+ database::OptFilters,
+ database::{QueryToken, QueryTokenizer},
+ history::{AUTHOR_FILTER_ALL_USER, History},
+ settings::SearchMode,
+};
+use eyre::Result;
+use norm::Metric;
+use norm::fzf::{FzfParser, FzfV2};
+use std::ops::Range;
+use tracing::{Level, instrument};
+
+/// Database-backed search engine; the wrapped `SearchMode` is passed to
+/// `Database::search` and decides how matches are highlighted.
+pub struct Search(pub SearchMode);
+
+#[async_trait]
+impl SearchEngine for Search {
+    /// Searches the database with the wrapped search mode, returning at most 200 rows.
+    ///
+    /// A failing search yields an empty list instead of an error.
+    #[instrument(skip_all, level = Level::TRACE, name = "db_search", fields(mode = ?self.0, query = %state.input.as_str()))]
+    async fn full_query(
+        &mut self,
+        state: &SearchState,
+        db: &mut dyn Database,
+    ) -> Result<Vec<History>> {
+        let filters = OptFilters {
+            limit: Some(200),
+            authors: vec![AUTHOR_FILTER_ALL_USER.to_string()],
+            ..Default::default()
+        };
+        let outcome = db
+            .search(
+                self.0,
+                state.filter_mode,
+                &state.context,
+                state.input.as_str(),
+                filters,
+            )
+            .await;
+        // ignore errors as it may be caused by incomplete regex
+        Ok(outcome.map_or_else(|_| Vec::new(), |rows| rows.into_iter().collect()))
+    }
+
+    /// Returns the positions in `command` to highlight for `search_input`.
+    ///
+    /// Prefix mode highlights nothing, full-text mode uses the token-based highlighter,
+    /// and every other mode uses the fzf matcher's ranges.
+    #[instrument(skip_all, level = Level::TRACE, name = "db_highlight")]
+    fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> {
+        if self.0 == SearchMode::Prefix {
+            return Vec::new();
+        }
+        if self.0 == SearchMode::FullText {
+            return get_highlight_indices_fulltext(command, search_input);
+        }
+
+        let mut fzf = FzfV2::new();
+        let mut parser = FzfParser::new();
+        let query = parser.parse(search_input);
+        let mut matched: Vec<Range<usize>> = Vec::new();
+        let _ = fzf.distance_and_ranges(query, command, &mut matched);
+
+        // expand every matched range into its individual indices
+        let mut indices = Vec::new();
+        for range in matched {
+            indices.extend(range);
+        }
+        indices
+    }
+}
+
+/// Returns the sorted, de-duplicated positions in `command` matched by the tokens of
+/// `search_input`.
+///
+/// Inverse tokens and `Or` contribute nothing. Regex tokens are matched against
+/// `command` as given; all other tokens are matched against an ASCII-lowercased copy
+/// unless the token contains an uppercase character.
+#[instrument(skip_all, level = Level::TRACE, name = "db_highlight_fulltext")]
+pub fn get_highlight_indices_fulltext(command: &str, search_input: &str) -> Vec<usize> {
+    let folded = command.to_ascii_lowercase();
+    let mut spans: Vec<Range<usize>> = Vec::new();
+
+    for token in QueryTokenizer::new(search_input) {
+        if token.is_inverse() {
+            continue;
+        }
+
+        let haystack: &str = if token.has_uppercase() {
+            command
+        } else {
+            &folded
+        };
+
+        match token {
+            QueryToken::Or => {}
+            QueryToken::Regex(pattern) => {
+                // A pattern that fails to compile simply highlights nothing.
+                if let Ok(re) = regex::Regex::new(pattern) {
+                    spans.extend(re.find_iter(command).map(|found| found.range()));
+                }
+            }
+            QueryToken::MatchStart(term, _) => {
+                if haystack.starts_with(term) {
+                    spans.push(0..term.len());
+                }
+            }
+            QueryToken::MatchEnd(term, _) => {
+                if haystack.ends_with(term) {
+                    let end = haystack.len();
+                    spans.push((end - term.len())..end);
+                }
+            }
+            QueryToken::Match(term, _) | QueryToken::MatchFull(term, _) => {
+                spans.extend(
+                    haystack
+                        .match_indices(term)
+                        .map(|(start, hit)| start..(start + hit.len())),
+                );
+            }
+        }
+    }
+
+    let mut indices: Vec<usize> = spans.into_iter().flatten().collect();
+    indices.sort_unstable();
+    indices.dedup();
+    indices
+}
diff --git a/crates/turtle/src/command/client/search/engines/skim.rs b/crates/turtle/src/command/client/search/engines/skim.rs
new file mode 100644
index 00000000..96a6574d
--- /dev/null
+++ b/crates/turtle/src/command/client/search/engines/skim.rs
@@ -0,0 +1,229 @@
+use std::path::Path;
+
+use async_trait::async_trait;
+use crate::atuin_client::{
+ database::Database,
+ history::{History, is_known_agent},
+ settings::FilterMode,
+};
+use eyre::Result;
+use fuzzy_matcher::{FuzzyMatcher, skim::SkimMatcherV2};
+use itertools::Itertools;
+use time::OffsetDateTime;
+use tokio::task::yield_now;
+use tracing::{Level, instrument, warn};
+use uuid;
+
+use super::{SearchEngine, SearchState};
+
+/// In-memory fuzzy search engine built on `SkimMatcherV2`.
+pub struct Search {
+ // History rows paired with the `i32` returned by `Database::all_with_count`;
+ // loaded by `full_query` while this is still empty.
+ all_history: Vec<(History, i32)>,
+ // Matcher used both for ranking and for highlight indices.
+ engine: SkimMatcherV2,
+}
+
+impl Search {
+    /// Creates an engine with an empty history cache and a default skim matcher.
+    pub fn new() -> Self {
+        let engine = SkimMatcherV2::default();
+        Self {
+            all_history: Vec::new(),
+            engine,
+        }
+    }
+}
+
+#[async_trait]
+impl SearchEngine for Search {
+    /// Fuzzy-matches the current input against the cached history.
+    ///
+    /// The cache is loaded from `db` whenever it is empty.
+    #[instrument(skip_all, level = Level::TRACE, name = "skim_search", fields(query = %state.input.as_str()))]
+    async fn full_query(
+        &mut self,
+        state: &SearchState,
+        db: &mut dyn Database,
+    ) -> Result<Vec<History>> {
+        let cache_is_empty = self.all_history.is_empty();
+        if cache_is_empty {
+            self.all_history = load_all_history(db).await;
+        }
+
+        let matches = fuzzy_search(&self.engine, state, &self.all_history).await;
+        Ok(matches)
+    }
+
+    /// Returns the positions the skim matcher reports for `search_input` in `command`,
+    /// or an empty list when there is no match.
+    #[instrument(skip_all, level = Level::TRACE, name = "skim_highlight")]
+    fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec<usize> {
+        match self.engine.fuzzy_indices(command, search_input) {
+            Some((_, indices)) => indices,
+            None => Vec::new(),
+        }
+    }
+}
+
+/// Loads every history row together with its count from `db`.
+///
+/// A database failure is logged as a warning and yields an empty list instead of
+/// panicking, so a broken database degrades to "no results" the same way the
+/// database-backed engine does.
+#[instrument(skip_all, level = Level::TRACE, name = "load_all_history")]
+async fn load_all_history(db: &dyn Database) -> Vec<(History, i32)> {
+    match db.all_with_count().await {
+        Ok(rows) => rows,
+        Err(err) => {
+            warn!(error = ?err, "failed to load history for fuzzy search");
+            Vec::new()
+        }
+    }
+}
+
+/// Fuzzy-matches `state.input` against `all_history` and returns up to 200 entries,
+/// best-ranked first.
+///
+/// Entries written by a known agent, and entries that do not pass `state.filter_mode`,
+/// are skipped. Each remaining match gets a rank built from the skim score, the entry's
+/// count, its age, the position of the first matched character and the path distance
+/// between the entry's cwd and the current cwd. A command that is already present with a
+/// better rank is not added again, and a worse-ranked duplicate is removed when a better
+/// one is inserted.
+#[expect(clippy::too_many_lines)]
+#[instrument(skip_all, level = Level::TRACE, name = "fuzzy_match", fields(history_count = all_history.len()))]
+async fn fuzzy_search(
+ engine: &SkimMatcherV2,
+ state: &SearchState,
+ all_history: &[(History, i32)],
+) -> Vec<History> {
+ // `set` and `ranks` are kept in lockstep: `ranks[i]` is the rank of `set[i]`.
+ let mut set = Vec::with_capacity(200);
+ let mut ranks = Vec::with_capacity(200);
+ let query = state.input.as_str();
+ let now = OffsetDateTime::now_utc();
+
+ for (i, (history, count)) in all_history.iter().enumerate() {
+ // give other tasks a chance to run every 256 entries
+ if i % 256 == 0 {
+ yield_now().await;
+ }
+ if is_known_agent(&history.author) {
+ continue;
+ }
+ let context = &state.context;
+ // fall back to the cwd when there is no git root or `to_str` yields `None`
+ let git_root = context
+ .git_root
+ .as_ref()
+ .and_then(|git_root| git_root.to_str())
+ .unwrap_or(&context.cwd);
+ // an empty arm keeps the entry; the final `_` arm skips it
+ match state.filter_mode {
+ FilterMode::Global => {}
+ // we aggregate host by ',' separating them
+ FilterMode::Host
+ if history
+ .hostname
+ .split(',')
+ .contains(&context.hostname.as_str()) => {}
+ // we aggregate session by concatenating them.
+ // sessions are 32 byte simple uuid formats
+ FilterMode::Session
+ if history
+ .session
+ .as_bytes()
+ .chunks(32)
+ .contains(&context.session.as_bytes()) => {}
+ // SessionPreload: include current session + global history from before session start
+ FilterMode::SessionPreload => {
+ let is_current_session = {
+ history
+ .session
+ .as_bytes()
+ .chunks(32)
+ .any(|chunk| chunk == context.session.as_bytes())
+ };
+
+ if !is_current_session {
+ // the session start time is read from the timestamp embedded in the session uuid
+ let Ok(uuid) = uuid::Uuid::parse_str(&context.session) else {
+ warn!("failed to parse session id '{}'", context.session);
+ continue;
+ };
+ let Some(timestamp) = uuid.get_timestamp() else {
+ warn!(
+ "failed to get timestamp from uuid '{}'",
+ uuid.as_hyphenated()
+ );
+ continue;
+ };
+ let (seconds, nanos) = timestamp.to_unix();
+ let Ok(session_start) = time::OffsetDateTime::from_unix_timestamp_nanos(
+ i128::from(seconds) * 1_000_000_000 + i128::from(nanos),
+ ) else {
+ warn!(
+ "failed to create OffsetDateTime from second: {seconds}, nanosecond: {nanos}"
+ );
+ continue;
+ };
+
+ // entries from other sessions only count when they predate this session
+ if history.timestamp >= session_start {
+ continue;
+ }
+ }
+ }
+ // we aggregate directory by ':' separating them
+ FilterMode::Directory if history.cwd.split(':').contains(&context.cwd.as_str()) => {}
+ FilterMode::Workspace if history.cwd.split(':').contains(&git_root) => {}
+ _ => continue,
+ }
+ #[expect(clippy::cast_lossless, clippy::cast_precision_loss)]
+ if let Some((score, indices)) = engine.fuzzy_indices(&history.command, query) {
+ let begin = indices.first().copied().unwrap_or_default();
+
+ // log2 of the entry's age in seconds, clamped to a minimum of 1.0
+ let mut duration = (now - history.timestamp).as_seconds_f64().log2();
+ if !duration.is_finite() || duration <= 1.0 {
+ duration = 1.0;
+ }
+ // these + X.0 just make the log result a bit smoother.
+ // log is very spiky towards 1-4, but I want a gradual decay.
+ // eg:
+ // log2(4) = 2, log2(5) = 2.3 (16% increase)
+ // log2(8) = 3, log2(9) = 3.16 (5% increase)
+ // log2(16) = 4, log2(17) = 4.08 (2% increase)
+ let count = (*count as f64 + 8.0).log2();
+ let begin = (begin as f64 + 16.0).log2();
+ let path = path_dist(history.cwd.as_ref(), state.context.cwd.as_ref());
+ let path = (path as f64 + 8.0).log2();
+
+ // reduce longer durations, raise higher counts, raise matches close to the start
+ // the skim score is negated, so a smaller rank value sorts earlier in `set`
+ let score = (-score as f64) * count / path / duration / begin;
+
+ 'insert: {
+ // algorithm:
+ // 1. find either the position that this command ranks
+ // 2. find the same command positioned better than our rank.
+ for i in 0..set.len() {
+ // do we out score the current position?
+ if ranks[i] > score {
+ ranks.insert(i, score);
+ set.insert(i, history.clone());
+ let mut j = i + 1;
+ while j < set.len() {
+ // remove duplicates that have a worse score
+ if set[j].command == history.command {
+ ranks.remove(j);
+ set.remove(j);
+
+ // break this while loop because there won't be any other
+ // duplicates.
+ break;
+ }
+ j += 1;
+ }
+
+ // keep it limited
+ if ranks.len() > 200 {
+ ranks.pop();
+ set.pop();
+ }
+
+ break 'insert;
+ }
+ // don't continue if this command has a better score already
+ if set[i].command == history.command {
+ break 'insert;
+ }
+ }
+
+ // nothing outranked and no duplicate found: append while there is room
+ if set.len() < 200 {
+ ranks.push(score);
+ set.push(history.clone());
+ }
+ }
+ }
+ }
+
+ set
+}
+
+/// Counts the path components separating `a` from `b`: the components of `a` below
+/// their longest shared prefix plus the components of `b` below it.
+fn path_dist(a: &Path, b: &Path) -> usize {
+    let from: Vec<_> = a.components().collect();
+    let to: Vec<_> = b.components().collect();
+
+    // length of the longest common leading run of components
+    let shared = from
+        .iter()
+        .zip(to.iter())
+        .take_while(|(left, right)| left == right)
+        .count();
+
+    let up = from.len() - shared;
+    let down = to.len() - shared;
+    down + up
+}