diff options
Diffstat (limited to 'crates/atuin-nucleo/src/lib.rs')
| -rw-r--r-- | crates/atuin-nucleo/src/lib.rs | 62 |
1 files changed, 61 insertions, 1 deletions
diff --git a/crates/atuin-nucleo/src/lib.rs b/crates/atuin-nucleo/src/lib.rs index 7ddb7407..5a500481 100644 --- a/crates/atuin-nucleo/src/lib.rs +++ b/crates/atuin-nucleo/src/lib.rs @@ -31,6 +31,15 @@ use std::sync::atomic::{self, AtomicBool, Ordering}; use std::sync::Arc; use std::time::Duration; +/// A filter predicate that determines whether an item should be included in matching. +/// Return `true` to include the item, `false` to skip it. +pub type Filter<T> = Arc<dyn Fn(&T) -> bool + Send + Sync>; + +/// A scorer callback that computes the final ranking score for an item. +/// Receives a reference to the item and its fuzzy match score. +/// Returns the combined/external score used for sorting results. +pub type Scorer<T> = Arc<dyn Fn(&T, u32) -> u32 + Send + Sync>; + use parking_lot::Mutex; use rayon::ThreadPool; @@ -124,7 +133,13 @@ impl<T> Injector<T> { /// An [item](crate::Item) that was successfully matched by a [`Nucleo`] worker. #[derive(PartialEq, Eq, Debug, Clone, Copy)] pub struct Match { + /// The raw fuzzy match score from the matcher. pub score: u32, + /// The external/combined score used for sorting. + /// If no scorer callback is set, this equals `score`. + /// If a scorer callback is set, this is the value returned by the callback. + pub external_score: u32, + /// The index of the matched item in the item list. pub idx: u32, } @@ -290,6 +305,12 @@ pub struct Nucleo<T: Sync + Send + 'static> { /// Note that the matcher worker will only become aware of the new pattern /// after a call to [`tick`](Nucleo::tick). pub pattern: MultiPattern, + /// Optional filter predicate. Items where filter returns false are skipped. + filter: Option<Filter<T>>, + /// Optional scorer callback. Returns combined score used for sorting. + scorer: Option<Scorer<T>>, + /// Flag indicating filter or scorer has changed and rescore is needed. + filter_scorer_changed: bool, } impl<T: Sync + Send + 'static> Nucleo<T> { @@ -328,6 +349,9 @@ impl<T: Sync + Send + 'static> Nucleo<T> { worker: Arc::new(Mutex::new(worker)), state: State::Init, notify, + filter: None, + scorer: None, + filter_scorer_changed: false, } } @@ -388,13 +412,46 @@ impl<T: Sync + Send + 'static> Nucleo<T> { self.worker.lock().reverse_items(reverse_items) } + /// Set a filter predicate. Items where the filter returns `false` are + /// skipped during matching. This is applied before fuzzy matching, so + /// filtered items don't incur the cost of fuzzy matching. + /// + /// Setting a new filter triggers a rescore on the next [`tick`](Nucleo::tick). + /// + /// Pass `None` to remove the filter. + pub fn set_filter(&mut self, filter: Option<Filter<T>>) { + self.filter = filter; + self.filter_scorer_changed = true; + } + + /// Set a scorer callback. The callback receives a reference to the item + /// and its fuzzy match score, and returns the combined score used for + /// sorting results. + /// + /// If no scorer is set, results are sorted by fuzzy match score. + /// + /// Setting a new scorer triggers a rescore on the next [`tick`](Nucleo::tick). + /// + /// Pass `None` to remove the scorer and use default fuzzy score sorting. + pub fn set_scorer(&mut self, scorer: Option<Scorer<T>>) { + self.scorer = scorer; + self.filter_scorer_changed = true; + } + /// The main way to interact with the matcher, this should be called /// regularly (for example each time a frame is rendered). To avoid /// excessive redraws this method will wait `timeout` milliseconds for the /// worker thread to finish. It is recommend to set the timeout to 10ms. pub fn tick(&mut self, timeout: u64) -> Status { self.should_notify.store(false, atomic::Ordering::Relaxed); - let status = self.pattern.status(); + let mut status = self.pattern.status(); + // If filter or scorer changed, treat as rescore + if self.filter_scorer_changed { + if status == pattern::Status::Unchanged { + status = pattern::Status::Rescore; + } + self.filter_scorer_changed = false; + } let canceled = status != pattern::Status::Unchanged || self.state.canceled(); let mut res = self.tick_inner(timeout, canceled, status); if !canceled { @@ -434,6 +491,9 @@ impl<T: Sync + Send + 'static> Nucleo<T> { } if running { inner.pattern.clone_from(&self.pattern); + // Update filter and scorer in worker + inner.set_filter(self.filter.clone()); + inner.set_scorer(self.scorer.clone()); self.canceled.store(false, atomic::Ordering::Relaxed); if !canceled { self.should_notify.store(true, atomic::Ordering::Release); |
