diff options
Diffstat (limited to 'crates/atuin-nucleo')
| -rw-r--r-- | crates/atuin-nucleo/src/lib.rs | 62 | ||||
| -rw-r--r-- | crates/atuin-nucleo/src/tests.rs | 242 | ||||
| -rw-r--r-- | crates/atuin-nucleo/src/worker.rs | 128 |
3 files changed, 418 insertions, 14 deletions
diff --git a/crates/atuin-nucleo/src/lib.rs b/crates/atuin-nucleo/src/lib.rs index 7ddb7407..5a500481 100644 --- a/crates/atuin-nucleo/src/lib.rs +++ b/crates/atuin-nucleo/src/lib.rs @@ -31,6 +31,15 @@ use std::sync::atomic::{self, AtomicBool, Ordering}; use std::sync::Arc; use std::time::Duration; +/// A filter predicate that determines whether an item should be included in matching. +/// Return `true` to include the item, `false` to skip it. +pub type Filter<T> = Arc<dyn Fn(&T) -> bool + Send + Sync>; + +/// A scorer callback that computes the final ranking score for an item. +/// Receives a reference to the item and its fuzzy match score. +/// Returns the combined/external score used for sorting results. +pub type Scorer<T> = Arc<dyn Fn(&T, u32) -> u32 + Send + Sync>; + use parking_lot::Mutex; use rayon::ThreadPool; @@ -124,7 +133,13 @@ impl<T> Injector<T> { /// An [item](crate::Item) that was successfully matched by a [`Nucleo`] worker. #[derive(PartialEq, Eq, Debug, Clone, Copy)] pub struct Match { + /// The raw fuzzy match score from the matcher. pub score: u32, + /// The external/combined score used for sorting. + /// If no scorer callback is set, this equals `score`. + /// If a scorer callback is set, this is the value returned by the callback. + pub external_score: u32, + /// The index of the matched item in the item list. pub idx: u32, } @@ -290,6 +305,12 @@ pub struct Nucleo<T: Sync + Send + 'static> { /// Note that the matcher worker will only become aware of the new pattern /// after a call to [`tick`](Nucleo::tick). pub pattern: MultiPattern, + /// Optional filter predicate. Items where filter returns false are skipped. + filter: Option<Filter<T>>, + /// Optional scorer callback. Returns combined score used for sorting. + scorer: Option<Scorer<T>>, + /// Flag indicating filter or scorer has changed and rescore is needed. + filter_scorer_changed: bool, } impl<T: Sync + Send + 'static> Nucleo<T> { @@ -328,6 +349,9 @@ impl<T: Sync + Send + 'static> Nucleo<T> { worker: Arc::new(Mutex::new(worker)), state: State::Init, notify, + filter: None, + scorer: None, + filter_scorer_changed: false, } } @@ -388,13 +412,46 @@ impl<T: Sync + Send + 'static> Nucleo<T> { self.worker.lock().reverse_items(reverse_items) } + /// Set a filter predicate. Items where the filter returns `false` are + /// skipped during matching. This is applied before fuzzy matching, so + /// filtered items don't incur the cost of fuzzy matching. + /// + /// Setting a new filter triggers a rescore on the next [`tick`](Nucleo::tick). + /// + /// Pass `None` to remove the filter. + pub fn set_filter(&mut self, filter: Option<Filter<T>>) { + self.filter = filter; + self.filter_scorer_changed = true; + } + + /// Set a scorer callback. The callback receives a reference to the item + /// and its fuzzy match score, and returns the combined score used for + /// sorting results. + /// + /// If no scorer is set, results are sorted by fuzzy match score. + /// + /// Setting a new scorer triggers a rescore on the next [`tick`](Nucleo::tick). + /// + /// Pass `None` to remove the scorer and use default fuzzy score sorting. + pub fn set_scorer(&mut self, scorer: Option<Scorer<T>>) { + self.scorer = scorer; + self.filter_scorer_changed = true; + } + /// The main way to interact with the matcher, this should be called /// regularly (for example each time a frame is rendered). To avoid /// excessive redraws this method will wait `timeout` milliseconds for the /// worker thread to finish. It is recommend to set the timeout to 10ms. pub fn tick(&mut self, timeout: u64) -> Status { self.should_notify.store(false, atomic::Ordering::Relaxed); - let status = self.pattern.status(); + let mut status = self.pattern.status(); + // If filter or scorer changed, treat as rescore + if self.filter_scorer_changed { + if status == pattern::Status::Unchanged { + status = pattern::Status::Rescore; + } + self.filter_scorer_changed = false; + } let canceled = status != pattern::Status::Unchanged || self.state.canceled(); let mut res = self.tick_inner(timeout, canceled, status); if !canceled { @@ -434,6 +491,9 @@ impl<T: Sync + Send + 'static> Nucleo<T> { } if running { inner.pattern.clone_from(&self.pattern); + // Update filter and scorer in worker + inner.set_filter(self.filter.clone()); + inner.set_scorer(self.scorer.clone()); self.canceled.store(false, atomic::Ordering::Relaxed); if !canceled { self.should_notify.store(true, atomic::Ordering::Release); diff --git a/crates/atuin-nucleo/src/tests.rs b/crates/atuin-nucleo/src/tests.rs index 676c50df..96c4d99c 100644 --- a/crates/atuin-nucleo/src/tests.rs +++ b/crates/atuin-nucleo/src/tests.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use nucleo_matcher::Config; -use crate::Nucleo; +use crate::{pattern, Nucleo}; #[test] fn active_injector_count() { @@ -25,3 +25,243 @@ fn active_injector_count() { drop(injector3); assert_eq!(nucleo.active_injectors(), 0); } + +#[derive(Clone, Debug)] +struct TestItem { + text: String, + category: u32, + priority: u32, +} + +#[test] +fn filter_excludes_items() { + let mut nucleo: Nucleo<TestItem> = Nucleo::new(Config::DEFAULT, Arc::new(|| ()), Some(1), 1); + let injector = nucleo.injector(); + + // Add items with different categories + injector.push( + TestItem { + text: "apple".into(), + category: 1, + priority: 10, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + injector.push( + TestItem { + text: "apricot".into(), + category: 2, + priority: 20, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + injector.push( + TestItem { + text: "avocado".into(), + category: 1, + priority: 30, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + + // Search without filter - should get all 3 + nucleo.pattern.reparse( + 0, + "a", + pattern::CaseMatching::Ignore, + pattern::Normalization::Smart, + false, + ); + while nucleo.tick(10).running {} + assert_eq!(nucleo.snapshot().matched_item_count(), 3); + + // Set filter to only include category 1 + nucleo.set_filter(Some(Arc::new(|item: &TestItem| item.category == 1))); + + // Search again - should only get 2 items (apple, avocado) + while nucleo.tick(10).running {} + assert_eq!(nucleo.snapshot().matched_item_count(), 2); + + // Verify the items are correct + let items: Vec<_> = nucleo + .snapshot() + .matched_items(..) + .map(|i| i.data.text.clone()) + .collect(); + assert!(items.contains(&"apple".to_string())); + assert!(items.contains(&"avocado".to_string())); + assert!(!items.contains(&"apricot".to_string())); + + // Remove filter - should get all 3 again + nucleo.set_filter(None); + while nucleo.tick(10).running {} + assert_eq!(nucleo.snapshot().matched_item_count(), 3); +} + +#[test] +fn scorer_affects_sort_order() { + let mut nucleo: Nucleo<TestItem> = Nucleo::new(Config::DEFAULT, Arc::new(|| ()), Some(1), 1); + let injector = nucleo.injector(); + + // Add items with different priorities + injector.push( + TestItem { + text: "banana".into(), + category: 1, + priority: 10, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + injector.push( + TestItem { + text: "blueberry".into(), + category: 1, + priority: 100, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + injector.push( + TestItem { + text: "blackberry".into(), + category: 1, + priority: 50, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + + // Search without scorer - results sorted by fuzzy score + nucleo.pattern.reparse( + 0, + "b", + pattern::CaseMatching::Ignore, + pattern::Normalization::Smart, + false, + ); + while nucleo.tick(10).running {} + assert_eq!(nucleo.snapshot().matched_item_count(), 3); + + // Set scorer that uses priority as the score (ignoring fuzzy score) + nucleo.set_scorer(Some(Arc::new(|item: &TestItem, _fuzzy_score| { + item.priority + }))); + + // Search again - should be sorted by priority (high to low) + while nucleo.tick(10).running {} + let items: Vec<_> = nucleo + .snapshot() + .matched_items(..) + .map(|i| i.data.clone()) + .collect(); + assert_eq!(items.len(), 3); + assert_eq!(items[0].text, "blueberry"); // priority 100 + assert_eq!(items[1].text, "blackberry"); // priority 50 + assert_eq!(items[2].text, "banana"); // priority 10 + + // Verify external_score is set correctly + let matches = nucleo.snapshot().matches(); + assert_eq!(matches[0].external_score, 100); + assert_eq!(matches[1].external_score, 50); + assert_eq!(matches[2].external_score, 10); +} + +#[test] +fn filter_and_scorer_combined() { + let mut nucleo: Nucleo<TestItem> = Nucleo::new(Config::DEFAULT, Arc::new(|| ()), Some(1), 1); + let injector = nucleo.injector(); + + injector.push( + TestItem { + text: "cherry".into(), + category: 1, + priority: 10, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + injector.push( + TestItem { + text: "cranberry".into(), + category: 2, + priority: 100, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + injector.push( + TestItem { + text: "coconut".into(), + category: 1, + priority: 50, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + + // Set both filter (category 1) and scorer (priority) + nucleo.set_filter(Some(Arc::new(|item: &TestItem| item.category == 1))); + nucleo.set_scorer(Some(Arc::new(|item: &TestItem, _| item.priority))); + + nucleo.pattern.reparse( + 0, + "c", + pattern::CaseMatching::Ignore, + pattern::Normalization::Smart, + false, + ); + while nucleo.tick(10).running {} + + // Should have 2 items (cherry, coconut) sorted by priority + let items: Vec<_> = nucleo + .snapshot() + .matched_items(..) + .map(|i| i.data.clone()) + .collect(); + assert_eq!(items.len(), 2); + assert_eq!(items[0].text, "coconut"); // priority 50 + assert_eq!(items[1].text, "cherry"); // priority 10 +} + +#[test] +fn scorer_combines_with_fuzzy_score() { + let mut nucleo: Nucleo<TestItem> = Nucleo::new(Config::DEFAULT, Arc::new(|| ()), Some(1), 1); + let injector = nucleo.injector(); + + injector.push( + TestItem { + text: "date".into(), + category: 1, + priority: 100, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + injector.push( + TestItem { + text: "dragon fruit".into(), + category: 1, + priority: 10, + }, + |item, cols| cols[0] = item.text.clone().into(), + ); + + // Set scorer that combines fuzzy score with priority + nucleo.set_scorer(Some(Arc::new(|item: &TestItem, fuzzy_score| { + fuzzy_score + item.priority + }))); + + nucleo.pattern.reparse( + 0, + "d", + pattern::CaseMatching::Ignore, + pattern::Normalization::Smart, + false, + ); + while nucleo.tick(10).running {} + + // Both items match, verify that external_score includes priority boost + let matches = nucleo.snapshot().matches(); + assert_eq!(matches.len(), 2); + + // The raw fuzzy scores should be in Match.score + // The combined scores should be in Match.external_score + for m in matches { + let item = nucleo.snapshot().get_item(m.idx).unwrap(); + assert_eq!(m.external_score, m.score + item.data.priority); + } +} diff --git a/crates/atuin-nucleo/src/worker.rs b/crates/atuin-nucleo/src/worker.rs index f4077e6e..ddd546ad 100644 --- a/crates/atuin-nucleo/src/worker.rs +++ b/crates/atuin-nucleo/src/worker.rs @@ -9,7 +9,7 @@ use rayon::{prelude::*, ThreadPool}; use crate::par_sort::par_quicksort; use crate::pattern::{self, MultiPattern}; -use crate::{boxcar, Match}; +use crate::{boxcar, Filter, Match, Scorer}; struct Matchers(Box<[UnsafeCell<nucleo_matcher::Matcher>]>); @@ -38,6 +38,8 @@ pub(crate) struct Worker<T: Sync + Send + 'static> { notify: Arc<(dyn Fn() + Sync + Send)>, pub(crate) items: Arc<boxcar::Vec<T>>, in_flight: Vec<u32>, + pub(crate) filter: Option<Filter<T>>, + pub(crate) scorer: Option<Scorer<T>>, } impl<T: Sync + Send + 'static> Worker<T> { @@ -56,6 +58,14 @@ impl<T: Sync + Send + 'static> Worker<T> { self.reverse_items = reverse_items; } + pub(crate) fn set_filter(&mut self, filter: Option<Filter<T>>) { + self.filter = filter; + } + + pub(crate) fn set_scorer(&mut self, scorer: Option<Scorer<T>>) { + self.scorer = scorer; + } + pub(crate) fn new( worker_threads: Option<usize>, config: Config, @@ -87,6 +97,8 @@ impl<T: Sync + Send + 'static> Worker<T> { notify, items: Arc::new(boxcar::Vec::with_capacity(2 * 1024, cols)), in_flight: Vec::with_capacity(64), + filter: None, + scorer: None, }; (pool, worker) } @@ -99,8 +111,22 @@ impl<T: Sync + Send + 'static> Worker<T> { let Some(item) = self.items.get(idx) else { return true; }; + // Apply filter if set + if let Some(ref filter) = self.filter { + if !filter(item.data) { + return false; // Item is ready but filtered out + } + } if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) { - self.matches.push(Match { score, idx }); + let external_score = match &self.scorer { + Some(scorer) => scorer(item.data, score), + None => score, + }; + self.matches.push(Match { + score, + external_score, + idx, + }); }; false }); @@ -114,20 +140,45 @@ impl<T: Sync + Send + 'static> Worker<T> { unmatched.fetch_add(1, atomic::Ordering::Relaxed); return Match { score: 0, + external_score: 0, idx: u32::MAX, }; }; if self.canceled.load(atomic::Ordering::Relaxed) { - return Match { score: 0, idx }; + return Match { + score: 0, + external_score: 0, + idx, + }; + } + // Apply filter if set + if let Some(ref filter) = self.filter { + if !filter(item.data) { + unmatched.fetch_add(1, atomic::Ordering::Relaxed); + return Match { + score: 0, + external_score: 0, + idx: u32::MAX, + }; + } } let Some(score) = pattern.score(item.matcher_columns, matchers.get()) else { unmatched.fetch_add(1, atomic::Ordering::Relaxed); return Match { score: 0, + external_score: 0, idx: u32::MAX, }; }; - Match { score, idx } + let external_score = match &self.scorer { + Some(scorer) => scorer(item.data, score), + None => score, + }; + Match { + score, + external_score, + idx, + } }); self.matches.par_extend(items); self.last_snapshot = end; @@ -151,11 +202,23 @@ impl<T: Sync + Send + 'static> Worker<T> { if new_snapshot.end() != self.last_snapshot { let end = new_snapshot.end(); let items = new_snapshot.filter_map(|(idx, item)| { - if item.is_none() { - self.in_flight.push(idx); - return None; + let item = item?; + // Apply filter if set + if let Some(ref filter) = self.filter { + if !filter(item.data) { + return None; + } + } + // For empty pattern, apply scorer with score=0 if set + let external_score = match &self.scorer { + Some(scorer) => scorer(item.data, 0), + None => 0, }; - Some(Match { score: 0, idx }) + Some(Match { + score: 0, + external_score, + idx, + }) }); self.matches.extend(items); self.last_snapshot = end; @@ -205,11 +268,26 @@ impl<T: Sync + Send + 'static> Worker<T> { } // safety: in-flight items are never added to the matches let item = self.items.get_unchecked(match_.idx); + // Apply filter if set + if let Some(ref filter) = self.filter { + if !filter(item.data) { + unmatched.fetch_add(1, atomic::Ordering::Relaxed); + match_.score = 0; + match_.external_score = 0; + match_.idx = u32::MAX; + return; + } + } if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) { match_.score = score; + match_.external_score = match &self.scorer { + Some(scorer) => scorer(item.data, score), + None => score, + }; } else { unmatched.fetch_add(1, atomic::Ordering::Relaxed); match_.score = 0; + match_.external_score = 0; match_.idx = u32::MAX; } }); @@ -234,8 +312,9 @@ impl<T: Sync + Send + 'static> Worker<T> { par_quicksort( &mut self.matches, |match1, match2| { - if match1.score != match2.score { - return match1.score > match2.score; + // Primary sort: external_score (used for frecency/custom ranking) + if match1.external_score != match2.external_score { + return match1.external_score > match2.external_score; } if match1.idx == u32::MAX { return false; @@ -293,8 +372,33 @@ impl<T: Sync + Send + 'static> Worker<T> { fn reset_matches(&mut self) { self.matches.clear(); - self.matches - .extend((0..self.last_snapshot).map(|idx| Match { score: 0, idx })); + // When resetting, apply filter if set + if let Some(ref filter) = self.filter { + for idx in 0..self.last_snapshot { + // Items up to last_snapshot should be initialized + if let Some(item) = self.items.get(idx) { + if filter(item.data) { + let external_score = match &self.scorer { + Some(scorer) => scorer(item.data, 0), + None => 0, + }; + self.matches.push(Match { + score: 0, + external_score, + idx, + }); + } + } + } + } else { + // No filter - add all items + self.matches + .extend((0..self.last_snapshot).map(|idx| Match { + score: 0, + external_score: 0, + idx, + })); + } // there are usually only very few in flight items (one for each writer) self.remove_in_flight_matches(); } |
