aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-nucleo/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-nucleo/src/lib.rs')
-rw-r--r--crates/atuin-nucleo/src/lib.rs62
1 files changed, 61 insertions, 1 deletions
diff --git a/crates/atuin-nucleo/src/lib.rs b/crates/atuin-nucleo/src/lib.rs
index 7ddb7407..5a500481 100644
--- a/crates/atuin-nucleo/src/lib.rs
+++ b/crates/atuin-nucleo/src/lib.rs
@@ -31,6 +31,15 @@ use std::sync::atomic::{self, AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
+/// A filter predicate that determines whether an item should be included in matching.
+/// Return `true` to include the item, `false` to skip it.
+pub type Filter<T> = Arc<dyn Fn(&T) -> bool + Send + Sync>;
+
+/// A scorer callback that computes the final ranking score for an item.
+/// Receives a reference to the item and its fuzzy match score.
+/// Returns the combined/external score used for sorting results.
+pub type Scorer<T> = Arc<dyn Fn(&T, u32) -> u32 + Send + Sync>;
+
use parking_lot::Mutex;
use rayon::ThreadPool;
@@ -124,7 +133,13 @@ impl<T> Injector<T> {
/// An [item](crate::Item) that was successfully matched by a [`Nucleo`] worker.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub struct Match {
+ /// The raw fuzzy match score from the matcher.
pub score: u32,
+ /// The external/combined score used for sorting.
+ /// If no scorer callback is set, this equals `score`.
+ /// If a scorer callback is set, this is the value returned by the callback.
+ pub external_score: u32,
+ /// The index of the matched item in the item list.
pub idx: u32,
}
@@ -290,6 +305,12 @@ pub struct Nucleo<T: Sync + Send + 'static> {
/// Note that the matcher worker will only become aware of the new pattern
/// after a call to [`tick`](Nucleo::tick).
pub pattern: MultiPattern,
+ /// Optional filter predicate. Items where filter returns false are skipped.
+ filter: Option<Filter<T>>,
+ /// Optional scorer callback. Returns combined score used for sorting.
+ scorer: Option<Scorer<T>>,
+ /// Flag indicating filter or scorer has changed and rescore is needed.
+ filter_scorer_changed: bool,
}
impl<T: Sync + Send + 'static> Nucleo<T> {
@@ -328,6 +349,9 @@ impl<T: Sync + Send + 'static> Nucleo<T> {
worker: Arc::new(Mutex::new(worker)),
state: State::Init,
notify,
+ filter: None,
+ scorer: None,
+ filter_scorer_changed: false,
}
}
@@ -388,13 +412,46 @@ impl<T: Sync + Send + 'static> Nucleo<T> {
self.worker.lock().reverse_items(reverse_items)
}
+ /// Set a filter predicate. Items where the filter returns `false` are
+ /// skipped during matching. This is applied before fuzzy matching, so
+ /// filtered items don't incur the cost of fuzzy matching.
+ ///
+ /// Setting a new filter triggers a rescore on the next [`tick`](Nucleo::tick).
+ ///
+ /// Pass `None` to remove the filter.
+ pub fn set_filter(&mut self, filter: Option<Filter<T>>) {
+ self.filter = filter;
+ self.filter_scorer_changed = true;
+ }
+
+ /// Set a scorer callback. The callback receives a reference to the item
+ /// and its fuzzy match score, and returns the combined score used for
+ /// sorting results.
+ ///
+ /// If no scorer is set, results are sorted by fuzzy match score.
+ ///
+ /// Setting a new scorer triggers a rescore on the next [`tick`](Nucleo::tick).
+ ///
+ /// Pass `None` to remove the scorer and use default fuzzy score sorting.
+ pub fn set_scorer(&mut self, scorer: Option<Scorer<T>>) {
+ self.scorer = scorer;
+ self.filter_scorer_changed = true;
+ }
+
/// The main way to interact with the matcher, this should be called
/// regularly (for example each time a frame is rendered). To avoid
/// excessive redraws this method will wait `timeout` milliseconds for the
/// worker thread to finish. It is recommend to set the timeout to 10ms.
pub fn tick(&mut self, timeout: u64) -> Status {
self.should_notify.store(false, atomic::Ordering::Relaxed);
- let status = self.pattern.status();
+ let mut status = self.pattern.status();
+ // If filter or scorer changed, treat as rescore
+ if self.filter_scorer_changed {
+ if status == pattern::Status::Unchanged {
+ status = pattern::Status::Rescore;
+ }
+ self.filter_scorer_changed = false;
+ }
let canceled = status != pattern::Status::Unchanged || self.state.canceled();
let mut res = self.tick_inner(timeout, canceled, status);
if !canceled {
@@ -434,6 +491,9 @@ impl<T: Sync + Send + 'static> Nucleo<T> {
}
if running {
inner.pattern.clone_from(&self.pattern);
+ // Update filter and scorer in worker
+ inner.set_filter(self.filter.clone());
+ inner.set_scorer(self.scorer.clone());
self.canceled.store(false, atomic::Ordering::Relaxed);
if !canceled {
self.should_notify.store(true, atomic::Ordering::Release);