aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-nucleo
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-nucleo')
-rw-r--r--crates/atuin-nucleo/src/lib.rs62
-rw-r--r--crates/atuin-nucleo/src/tests.rs242
-rw-r--r--crates/atuin-nucleo/src/worker.rs128
3 files changed, 418 insertions, 14 deletions
diff --git a/crates/atuin-nucleo/src/lib.rs b/crates/atuin-nucleo/src/lib.rs
index 7ddb7407..5a500481 100644
--- a/crates/atuin-nucleo/src/lib.rs
+++ b/crates/atuin-nucleo/src/lib.rs
@@ -31,6 +31,15 @@ use std::sync::atomic::{self, AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
+/// A filter predicate that determines whether an item should be included in matching.
+/// Return `true` to include the item, `false` to skip it.
+pub type Filter<T> = Arc<dyn Fn(&T) -> bool + Send + Sync>;
+
+/// A scorer callback that computes the final ranking score for an item.
+/// Receives a reference to the item and its fuzzy match score.
+/// Returns the combined/external score used for sorting results.
+pub type Scorer<T> = Arc<dyn Fn(&T, u32) -> u32 + Send + Sync>;
+
use parking_lot::Mutex;
use rayon::ThreadPool;
@@ -124,7 +133,13 @@ impl<T> Injector<T> {
/// An [item](crate::Item) that was successfully matched by a [`Nucleo`] worker.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub struct Match {
+ /// The raw fuzzy match score from the matcher.
pub score: u32,
+ /// The external/combined score used for sorting.
+ /// If no scorer callback is set, this equals `score`.
+ /// If a scorer callback is set, this is the value returned by the callback.
+ pub external_score: u32,
+ /// The index of the matched item in the item list.
pub idx: u32,
}
@@ -290,6 +305,12 @@ pub struct Nucleo<T: Sync + Send + 'static> {
/// Note that the matcher worker will only become aware of the new pattern
/// after a call to [`tick`](Nucleo::tick).
pub pattern: MultiPattern,
+ /// Optional filter predicate. Items where filter returns false are skipped.
+ filter: Option<Filter<T>>,
+ /// Optional scorer callback. Returns combined score used for sorting.
+ scorer: Option<Scorer<T>>,
+ /// Flag indicating filter or scorer has changed and rescore is needed.
+ filter_scorer_changed: bool,
}
impl<T: Sync + Send + 'static> Nucleo<T> {
@@ -328,6 +349,9 @@ impl<T: Sync + Send + 'static> Nucleo<T> {
worker: Arc::new(Mutex::new(worker)),
state: State::Init,
notify,
+ filter: None,
+ scorer: None,
+ filter_scorer_changed: false,
}
}
@@ -388,13 +412,46 @@ impl<T: Sync + Send + 'static> Nucleo<T> {
self.worker.lock().reverse_items(reverse_items)
}
+ /// Set a filter predicate. Items where the filter returns `false` are
+ /// skipped during matching. This is applied before fuzzy matching, so
+ /// filtered items don't incur the cost of fuzzy matching.
+ ///
+ /// Setting a new filter triggers a rescore on the next [`tick`](Nucleo::tick).
+ ///
+ /// Pass `None` to remove the filter.
+ pub fn set_filter(&mut self, filter: Option<Filter<T>>) {
+ self.filter = filter;
+ self.filter_scorer_changed = true;
+ }
+
+ /// Set a scorer callback. The callback receives a reference to the item
+ /// and its fuzzy match score, and returns the combined score used for
+ /// sorting results.
+ ///
+ /// If no scorer is set, results are sorted by fuzzy match score.
+ ///
+ /// Setting a new scorer triggers a rescore on the next [`tick`](Nucleo::tick).
+ ///
+ /// Pass `None` to remove the scorer and use default fuzzy score sorting.
+ pub fn set_scorer(&mut self, scorer: Option<Scorer<T>>) {
+ self.scorer = scorer;
+ self.filter_scorer_changed = true;
+ }
+
/// The main way to interact with the matcher, this should be called
/// regularly (for example each time a frame is rendered). To avoid
/// excessive redraws this method will wait `timeout` milliseconds for the
/// worker thread to finish. It is recommend to set the timeout to 10ms.
pub fn tick(&mut self, timeout: u64) -> Status {
self.should_notify.store(false, atomic::Ordering::Relaxed);
- let status = self.pattern.status();
+ let mut status = self.pattern.status();
+ // If filter or scorer changed, treat as rescore
+ if self.filter_scorer_changed {
+ if status == pattern::Status::Unchanged {
+ status = pattern::Status::Rescore;
+ }
+ self.filter_scorer_changed = false;
+ }
let canceled = status != pattern::Status::Unchanged || self.state.canceled();
let mut res = self.tick_inner(timeout, canceled, status);
if !canceled {
@@ -434,6 +491,9 @@ impl<T: Sync + Send + 'static> Nucleo<T> {
}
if running {
inner.pattern.clone_from(&self.pattern);
+ // Update filter and scorer in worker
+ inner.set_filter(self.filter.clone());
+ inner.set_scorer(self.scorer.clone());
self.canceled.store(false, atomic::Ordering::Relaxed);
if !canceled {
self.should_notify.store(true, atomic::Ordering::Release);
diff --git a/crates/atuin-nucleo/src/tests.rs b/crates/atuin-nucleo/src/tests.rs
index 676c50df..96c4d99c 100644
--- a/crates/atuin-nucleo/src/tests.rs
+++ b/crates/atuin-nucleo/src/tests.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;
use nucleo_matcher::Config;
-use crate::Nucleo;
+use crate::{pattern, Nucleo};
#[test]
fn active_injector_count() {
@@ -25,3 +25,243 @@ fn active_injector_count() {
drop(injector3);
assert_eq!(nucleo.active_injectors(), 0);
}
+
+#[derive(Clone, Debug)]
+struct TestItem {
+ text: String,
+ category: u32,
+ priority: u32,
+}
+
+#[test]
+fn filter_excludes_items() {
+ let mut nucleo: Nucleo<TestItem> = Nucleo::new(Config::DEFAULT, Arc::new(|| ()), Some(1), 1);
+ let injector = nucleo.injector();
+
+ // Add items with different categories
+ injector.push(
+ TestItem {
+ text: "apple".into(),
+ category: 1,
+ priority: 10,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+ injector.push(
+ TestItem {
+ text: "apricot".into(),
+ category: 2,
+ priority: 20,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+ injector.push(
+ TestItem {
+ text: "avocado".into(),
+ category: 1,
+ priority: 30,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+
+ // Search without filter - should get all 3
+ nucleo.pattern.reparse(
+ 0,
+ "a",
+ pattern::CaseMatching::Ignore,
+ pattern::Normalization::Smart,
+ false,
+ );
+ while nucleo.tick(10).running {}
+ assert_eq!(nucleo.snapshot().matched_item_count(), 3);
+
+ // Set filter to only include category 1
+ nucleo.set_filter(Some(Arc::new(|item: &TestItem| item.category == 1)));
+
+ // Search again - should only get 2 items (apple, avocado)
+ while nucleo.tick(10).running {}
+ assert_eq!(nucleo.snapshot().matched_item_count(), 2);
+
+ // Verify the items are correct
+ let items: Vec<_> = nucleo
+ .snapshot()
+ .matched_items(..)
+ .map(|i| i.data.text.clone())
+ .collect();
+ assert!(items.contains(&"apple".to_string()));
+ assert!(items.contains(&"avocado".to_string()));
+ assert!(!items.contains(&"apricot".to_string()));
+
+ // Remove filter - should get all 3 again
+ nucleo.set_filter(None);
+ while nucleo.tick(10).running {}
+ assert_eq!(nucleo.snapshot().matched_item_count(), 3);
+}
+
+#[test]
+fn scorer_affects_sort_order() {
+ let mut nucleo: Nucleo<TestItem> = Nucleo::new(Config::DEFAULT, Arc::new(|| ()), Some(1), 1);
+ let injector = nucleo.injector();
+
+ // Add items with different priorities
+ injector.push(
+ TestItem {
+ text: "banana".into(),
+ category: 1,
+ priority: 10,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+ injector.push(
+ TestItem {
+ text: "blueberry".into(),
+ category: 1,
+ priority: 100,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+ injector.push(
+ TestItem {
+ text: "blackberry".into(),
+ category: 1,
+ priority: 50,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+
+ // Search without scorer - results sorted by fuzzy score
+ nucleo.pattern.reparse(
+ 0,
+ "b",
+ pattern::CaseMatching::Ignore,
+ pattern::Normalization::Smart,
+ false,
+ );
+ while nucleo.tick(10).running {}
+ assert_eq!(nucleo.snapshot().matched_item_count(), 3);
+
+ // Set scorer that uses priority as the score (ignoring fuzzy score)
+ nucleo.set_scorer(Some(Arc::new(|item: &TestItem, _fuzzy_score| {
+ item.priority
+ })));
+
+ // Search again - should be sorted by priority (high to low)
+ while nucleo.tick(10).running {}
+ let items: Vec<_> = nucleo
+ .snapshot()
+ .matched_items(..)
+ .map(|i| i.data.clone())
+ .collect();
+ assert_eq!(items.len(), 3);
+ assert_eq!(items[0].text, "blueberry"); // priority 100
+ assert_eq!(items[1].text, "blackberry"); // priority 50
+ assert_eq!(items[2].text, "banana"); // priority 10
+
+ // Verify external_score is set correctly
+ let matches = nucleo.snapshot().matches();
+ assert_eq!(matches[0].external_score, 100);
+ assert_eq!(matches[1].external_score, 50);
+ assert_eq!(matches[2].external_score, 10);
+}
+
+#[test]
+fn filter_and_scorer_combined() {
+ let mut nucleo: Nucleo<TestItem> = Nucleo::new(Config::DEFAULT, Arc::new(|| ()), Some(1), 1);
+ let injector = nucleo.injector();
+
+ injector.push(
+ TestItem {
+ text: "cherry".into(),
+ category: 1,
+ priority: 10,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+ injector.push(
+ TestItem {
+ text: "cranberry".into(),
+ category: 2,
+ priority: 100,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+ injector.push(
+ TestItem {
+ text: "coconut".into(),
+ category: 1,
+ priority: 50,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+
+ // Set both filter (category 1) and scorer (priority)
+ nucleo.set_filter(Some(Arc::new(|item: &TestItem| item.category == 1)));
+ nucleo.set_scorer(Some(Arc::new(|item: &TestItem, _| item.priority)));
+
+ nucleo.pattern.reparse(
+ 0,
+ "c",
+ pattern::CaseMatching::Ignore,
+ pattern::Normalization::Smart,
+ false,
+ );
+ while nucleo.tick(10).running {}
+
+ // Should have 2 items (cherry, coconut) sorted by priority
+ let items: Vec<_> = nucleo
+ .snapshot()
+ .matched_items(..)
+ .map(|i| i.data.clone())
+ .collect();
+ assert_eq!(items.len(), 2);
+ assert_eq!(items[0].text, "coconut"); // priority 50
+ assert_eq!(items[1].text, "cherry"); // priority 10
+}
+
+#[test]
+fn scorer_combines_with_fuzzy_score() {
+ let mut nucleo: Nucleo<TestItem> = Nucleo::new(Config::DEFAULT, Arc::new(|| ()), Some(1), 1);
+ let injector = nucleo.injector();
+
+ injector.push(
+ TestItem {
+ text: "date".into(),
+ category: 1,
+ priority: 100,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+ injector.push(
+ TestItem {
+ text: "dragon fruit".into(),
+ category: 1,
+ priority: 10,
+ },
+ |item, cols| cols[0] = item.text.clone().into(),
+ );
+
+ // Set scorer that combines fuzzy score with priority
+ nucleo.set_scorer(Some(Arc::new(|item: &TestItem, fuzzy_score| {
+ fuzzy_score + item.priority
+ })));
+
+ nucleo.pattern.reparse(
+ 0,
+ "d",
+ pattern::CaseMatching::Ignore,
+ pattern::Normalization::Smart,
+ false,
+ );
+ while nucleo.tick(10).running {}
+
+ // Both items match, verify that external_score includes priority boost
+ let matches = nucleo.snapshot().matches();
+ assert_eq!(matches.len(), 2);
+
+ // The raw fuzzy scores should be in Match.score
+ // The combined scores should be in Match.external_score
+ for m in matches {
+ let item = nucleo.snapshot().get_item(m.idx).unwrap();
+ assert_eq!(m.external_score, m.score + item.data.priority);
+ }
+}
diff --git a/crates/atuin-nucleo/src/worker.rs b/crates/atuin-nucleo/src/worker.rs
index f4077e6e..ddd546ad 100644
--- a/crates/atuin-nucleo/src/worker.rs
+++ b/crates/atuin-nucleo/src/worker.rs
@@ -9,7 +9,7 @@ use rayon::{prelude::*, ThreadPool};
use crate::par_sort::par_quicksort;
use crate::pattern::{self, MultiPattern};
-use crate::{boxcar, Match};
+use crate::{boxcar, Filter, Match, Scorer};
struct Matchers(Box<[UnsafeCell<nucleo_matcher::Matcher>]>);
@@ -38,6 +38,8 @@ pub(crate) struct Worker<T: Sync + Send + 'static> {
notify: Arc<(dyn Fn() + Sync + Send)>,
pub(crate) items: Arc<boxcar::Vec<T>>,
in_flight: Vec<u32>,
+ pub(crate) filter: Option<Filter<T>>,
+ pub(crate) scorer: Option<Scorer<T>>,
}
impl<T: Sync + Send + 'static> Worker<T> {
@@ -56,6 +58,14 @@ impl<T: Sync + Send + 'static> Worker<T> {
self.reverse_items = reverse_items;
}
+ pub(crate) fn set_filter(&mut self, filter: Option<Filter<T>>) {
+ self.filter = filter;
+ }
+
+ pub(crate) fn set_scorer(&mut self, scorer: Option<Scorer<T>>) {
+ self.scorer = scorer;
+ }
+
pub(crate) fn new(
worker_threads: Option<usize>,
config: Config,
@@ -87,6 +97,8 @@ impl<T: Sync + Send + 'static> Worker<T> {
notify,
items: Arc::new(boxcar::Vec::with_capacity(2 * 1024, cols)),
in_flight: Vec::with_capacity(64),
+ filter: None,
+ scorer: None,
};
(pool, worker)
}
@@ -99,8 +111,22 @@ impl<T: Sync + Send + 'static> Worker<T> {
let Some(item) = self.items.get(idx) else {
return true;
};
+ // Apply filter if set
+ if let Some(ref filter) = self.filter {
+ if !filter(item.data) {
+ return false; // Item is ready but filtered out
+ }
+ }
if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) {
- self.matches.push(Match { score, idx });
+ let external_score = match &self.scorer {
+ Some(scorer) => scorer(item.data, score),
+ None => score,
+ };
+ self.matches.push(Match {
+ score,
+ external_score,
+ idx,
+ });
};
false
});
@@ -114,20 +140,45 @@ impl<T: Sync + Send + 'static> Worker<T> {
unmatched.fetch_add(1, atomic::Ordering::Relaxed);
return Match {
score: 0,
+ external_score: 0,
idx: u32::MAX,
};
};
if self.canceled.load(atomic::Ordering::Relaxed) {
- return Match { score: 0, idx };
+ return Match {
+ score: 0,
+ external_score: 0,
+ idx,
+ };
+ }
+ // Apply filter if set
+ if let Some(ref filter) = self.filter {
+ if !filter(item.data) {
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ return Match {
+ score: 0,
+ external_score: 0,
+ idx: u32::MAX,
+ };
+ }
}
let Some(score) = pattern.score(item.matcher_columns, matchers.get()) else {
unmatched.fetch_add(1, atomic::Ordering::Relaxed);
return Match {
score: 0,
+ external_score: 0,
idx: u32::MAX,
};
};
- Match { score, idx }
+ let external_score = match &self.scorer {
+ Some(scorer) => scorer(item.data, score),
+ None => score,
+ };
+ Match {
+ score,
+ external_score,
+ idx,
+ }
});
self.matches.par_extend(items);
self.last_snapshot = end;
@@ -151,11 +202,23 @@ impl<T: Sync + Send + 'static> Worker<T> {
if new_snapshot.end() != self.last_snapshot {
let end = new_snapshot.end();
let items = new_snapshot.filter_map(|(idx, item)| {
- if item.is_none() {
- self.in_flight.push(idx);
- return None;
+ let item = item?;
+ // Apply filter if set
+ if let Some(ref filter) = self.filter {
+ if !filter(item.data) {
+ return None;
+ }
+ }
+ // For empty pattern, apply scorer with score=0 if set
+ let external_score = match &self.scorer {
+ Some(scorer) => scorer(item.data, 0),
+ None => 0,
};
- Some(Match { score: 0, idx })
+ Some(Match {
+ score: 0,
+ external_score,
+ idx,
+ })
});
self.matches.extend(items);
self.last_snapshot = end;
@@ -205,11 +268,26 @@ impl<T: Sync + Send + 'static> Worker<T> {
}
// safety: in-flight items are never added to the matches
let item = self.items.get_unchecked(match_.idx);
+ // Apply filter if set
+ if let Some(ref filter) = self.filter {
+ if !filter(item.data) {
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ match_.score = 0;
+ match_.external_score = 0;
+ match_.idx = u32::MAX;
+ return;
+ }
+ }
if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) {
match_.score = score;
+ match_.external_score = match &self.scorer {
+ Some(scorer) => scorer(item.data, score),
+ None => score,
+ };
} else {
unmatched.fetch_add(1, atomic::Ordering::Relaxed);
match_.score = 0;
+ match_.external_score = 0;
match_.idx = u32::MAX;
}
});
@@ -234,8 +312,9 @@ impl<T: Sync + Send + 'static> Worker<T> {
par_quicksort(
&mut self.matches,
|match1, match2| {
- if match1.score != match2.score {
- return match1.score > match2.score;
+ // Primary sort: external_score (used for frecency/custom ranking)
+ if match1.external_score != match2.external_score {
+ return match1.external_score > match2.external_score;
}
if match1.idx == u32::MAX {
return false;
@@ -293,8 +372,33 @@ impl<T: Sync + Send + 'static> Worker<T> {
fn reset_matches(&mut self) {
self.matches.clear();
- self.matches
- .extend((0..self.last_snapshot).map(|idx| Match { score: 0, idx }));
+ // When resetting, apply filter if set
+ if let Some(ref filter) = self.filter {
+ for idx in 0..self.last_snapshot {
+ // Items up to last_snapshot should be initialized
+ if let Some(item) = self.items.get(idx) {
+ if filter(item.data) {
+ let external_score = match &self.scorer {
+ Some(scorer) => scorer(item.data, 0),
+ None => 0,
+ };
+ self.matches.push(Match {
+ score: 0,
+ external_score,
+ idx,
+ });
+ }
+ }
+ }
+ } else {
+ // No filter - add all items
+ self.matches
+ .extend((0..self.last_snapshot).map(|idx| Match {
+ score: 0,
+ external_score: 0,
+ idx,
+ }));
+ }
// there are usually only very few in flight items (one for each writer)
self.remove_in_flight_matches();
}