diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-02-25 19:47:41 -0800 |
|---|---|---|
| committer | Ellie Huxtable <ellie@elliehuxtable.com> | 2026-03-16 15:36:14 -0700 |
| commit | 7049c9b7878b2a3be013272469f94ee39d8a7e2c (patch) | |
| tree | 3736e7e540341576d67e45c491e2f269e0acd061 /crates/atuin-nucleo/src/worker.rs | |
| parent | Update readme (diff) | |
| download | atuin-7049c9b7878b2a3be013272469f94ee39d8a7e2c.zip | |
feat: Add custom filtering and scoring mechanisms
Diffstat (limited to 'crates/atuin-nucleo/src/worker.rs')
| -rw-r--r-- | crates/atuin-nucleo/src/worker.rs | 128 |
1 files changed, 116 insertions, 12 deletions
diff --git a/crates/atuin-nucleo/src/worker.rs b/crates/atuin-nucleo/src/worker.rs index f4077e6e..ddd546ad 100644 --- a/crates/atuin-nucleo/src/worker.rs +++ b/crates/atuin-nucleo/src/worker.rs @@ -9,7 +9,7 @@ use rayon::{prelude::*, ThreadPool}; use crate::par_sort::par_quicksort; use crate::pattern::{self, MultiPattern}; -use crate::{boxcar, Match}; +use crate::{boxcar, Filter, Match, Scorer}; struct Matchers(Box<[UnsafeCell<nucleo_matcher::Matcher>]>); @@ -38,6 +38,8 @@ pub(crate) struct Worker<T: Sync + Send + 'static> { notify: Arc<(dyn Fn() + Sync + Send)>, pub(crate) items: Arc<boxcar::Vec<T>>, in_flight: Vec<u32>, + pub(crate) filter: Option<Filter<T>>, + pub(crate) scorer: Option<Scorer<T>>, } impl<T: Sync + Send + 'static> Worker<T> { @@ -56,6 +58,14 @@ impl<T: Sync + Send + 'static> Worker<T> { self.reverse_items = reverse_items; } + pub(crate) fn set_filter(&mut self, filter: Option<Filter<T>>) { + self.filter = filter; + } + + pub(crate) fn set_scorer(&mut self, scorer: Option<Scorer<T>>) { + self.scorer = scorer; + } + pub(crate) fn new( worker_threads: Option<usize>, config: Config, @@ -87,6 +97,8 @@ impl<T: Sync + Send + 'static> Worker<T> { notify, items: Arc::new(boxcar::Vec::with_capacity(2 * 1024, cols)), in_flight: Vec::with_capacity(64), + filter: None, + scorer: None, }; (pool, worker) } @@ -99,8 +111,22 @@ impl<T: Sync + Send + 'static> Worker<T> { let Some(item) = self.items.get(idx) else { return true; }; + // Apply filter if set + if let Some(ref filter) = self.filter { + if !filter(item.data) { + return false; // Item is ready but filtered out + } + } if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) { - self.matches.push(Match { score, idx }); + let external_score = match &self.scorer { + Some(scorer) => scorer(item.data, score), + None => score, + }; + self.matches.push(Match { + score, + external_score, + idx, + }); }; false }); @@ -114,20 +140,45 @@ impl<T: Sync + Send + 'static> Worker<T> { unmatched.fetch_add(1, atomic::Ordering::Relaxed); return Match { score: 0, + external_score: 0, idx: u32::MAX, }; }; if self.canceled.load(atomic::Ordering::Relaxed) { - return Match { score: 0, idx }; + return Match { + score: 0, + external_score: 0, + idx, + }; + } + // Apply filter if set + if let Some(ref filter) = self.filter { + if !filter(item.data) { + unmatched.fetch_add(1, atomic::Ordering::Relaxed); + return Match { + score: 0, + external_score: 0, + idx: u32::MAX, + }; + } } let Some(score) = pattern.score(item.matcher_columns, matchers.get()) else { unmatched.fetch_add(1, atomic::Ordering::Relaxed); return Match { score: 0, + external_score: 0, idx: u32::MAX, }; }; - Match { score, idx } + let external_score = match &self.scorer { + Some(scorer) => scorer(item.data, score), + None => score, + }; + Match { + score, + external_score, + idx, + } }); self.matches.par_extend(items); self.last_snapshot = end; @@ -151,11 +202,23 @@ impl<T: Sync + Send + 'static> Worker<T> { if new_snapshot.end() != self.last_snapshot { let end = new_snapshot.end(); let items = new_snapshot.filter_map(|(idx, item)| { - if item.is_none() { - self.in_flight.push(idx); - return None; + let item = item?; + // Apply filter if set + if let Some(ref filter) = self.filter { + if !filter(item.data) { + return None; + } + } + // For empty pattern, apply scorer with score=0 if set + let external_score = match &self.scorer { + Some(scorer) => scorer(item.data, 0), + None => 0, }; - Some(Match { score: 0, idx }) + Some(Match { + score: 0, + external_score, + idx, + }) }); self.matches.extend(items); self.last_snapshot = end; @@ -205,11 +268,26 @@ impl<T: Sync + Send + 'static> Worker<T> { } // safety: in-flight items are never added to the matches let item = self.items.get_unchecked(match_.idx); + // Apply filter if set + if let Some(ref filter) = self.filter { + if !filter(item.data) { + unmatched.fetch_add(1, atomic::Ordering::Relaxed); + match_.score = 0; + match_.external_score = 0; + match_.idx = u32::MAX; + return; + } + } if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) { match_.score = score; + match_.external_score = match &self.scorer { + Some(scorer) => scorer(item.data, score), + None => score, + }; } else { unmatched.fetch_add(1, atomic::Ordering::Relaxed); match_.score = 0; + match_.external_score = 0; match_.idx = u32::MAX; } }); @@ -234,8 +312,9 @@ impl<T: Sync + Send + 'static> Worker<T> { par_quicksort( &mut self.matches, |match1, match2| { - if match1.score != match2.score { - return match1.score > match2.score; + // Primary sort: external_score (used for frecency/custom ranking) + if match1.external_score != match2.external_score { + return match1.external_score > match2.external_score; } if match1.idx == u32::MAX { return false; @@ -293,8 +372,33 @@ impl<T: Sync + Send + 'static> Worker<T> { fn reset_matches(&mut self) { self.matches.clear(); - self.matches - .extend((0..self.last_snapshot).map(|idx| Match { score: 0, idx })); + // When resetting, apply filter if set + if let Some(ref filter) = self.filter { + for idx in 0..self.last_snapshot { + // Items up to last_snapshot should be initialized + if let Some(item) = self.items.get(idx) { + if filter(item.data) { + let external_score = match &self.scorer { + Some(scorer) => scorer(item.data, 0), + None => 0, + }; + self.matches.push(Match { + score: 0, + external_score, + idx, + }); + } + } + } + } else { + // No filter - add all items + self.matches + .extend((0..self.last_snapshot).map(|idx| Match { + score: 0, + external_score: 0, + idx, + })); + } // there are usually only very few in flight items (one for each writer) self.remove_in_flight_matches(); } |
