aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-nucleo/src/worker.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-nucleo/src/worker.rs')
-rw-r--r--crates/atuin-nucleo/src/worker.rs128
1 files changed, 116 insertions, 12 deletions
diff --git a/crates/atuin-nucleo/src/worker.rs b/crates/atuin-nucleo/src/worker.rs
index f4077e6e..ddd546ad 100644
--- a/crates/atuin-nucleo/src/worker.rs
+++ b/crates/atuin-nucleo/src/worker.rs
@@ -9,7 +9,7 @@ use rayon::{prelude::*, ThreadPool};
use crate::par_sort::par_quicksort;
use crate::pattern::{self, MultiPattern};
-use crate::{boxcar, Match};
+use crate::{boxcar, Filter, Match, Scorer};
struct Matchers(Box<[UnsafeCell<nucleo_matcher::Matcher>]>);
@@ -38,6 +38,8 @@ pub(crate) struct Worker<T: Sync + Send + 'static> {
notify: Arc<(dyn Fn() + Sync + Send)>,
pub(crate) items: Arc<boxcar::Vec<T>>,
in_flight: Vec<u32>,
+ pub(crate) filter: Option<Filter<T>>,
+ pub(crate) scorer: Option<Scorer<T>>,
}
impl<T: Sync + Send + 'static> Worker<T> {
@@ -56,6 +58,14 @@ impl<T: Sync + Send + 'static> Worker<T> {
self.reverse_items = reverse_items;
}
+ pub(crate) fn set_filter(&mut self, filter: Option<Filter<T>>) {
+ self.filter = filter;
+ }
+
+ pub(crate) fn set_scorer(&mut self, scorer: Option<Scorer<T>>) {
+ self.scorer = scorer;
+ }
+
pub(crate) fn new(
worker_threads: Option<usize>,
config: Config,
@@ -87,6 +97,8 @@ impl<T: Sync + Send + 'static> Worker<T> {
notify,
items: Arc::new(boxcar::Vec::with_capacity(2 * 1024, cols)),
in_flight: Vec::with_capacity(64),
+ filter: None,
+ scorer: None,
};
(pool, worker)
}
@@ -99,8 +111,22 @@ impl<T: Sync + Send + 'static> Worker<T> {
let Some(item) = self.items.get(idx) else {
return true;
};
+ // Apply filter if set
+ if let Some(ref filter) = self.filter {
+ if !filter(item.data) {
+ return false; // Item is ready but filtered out
+ }
+ }
if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) {
- self.matches.push(Match { score, idx });
+ let external_score = match &self.scorer {
+ Some(scorer) => scorer(item.data, score),
+ None => score,
+ };
+ self.matches.push(Match {
+ score,
+ external_score,
+ idx,
+ });
};
false
});
@@ -114,20 +140,45 @@ impl<T: Sync + Send + 'static> Worker<T> {
unmatched.fetch_add(1, atomic::Ordering::Relaxed);
return Match {
score: 0,
+ external_score: 0,
idx: u32::MAX,
};
};
if self.canceled.load(atomic::Ordering::Relaxed) {
- return Match { score: 0, idx };
+ return Match {
+ score: 0,
+ external_score: 0,
+ idx,
+ };
+ }
+ // Apply filter if set
+ if let Some(ref filter) = self.filter {
+ if !filter(item.data) {
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ return Match {
+ score: 0,
+ external_score: 0,
+ idx: u32::MAX,
+ };
+ }
}
let Some(score) = pattern.score(item.matcher_columns, matchers.get()) else {
unmatched.fetch_add(1, atomic::Ordering::Relaxed);
return Match {
score: 0,
+ external_score: 0,
idx: u32::MAX,
};
};
- Match { score, idx }
+ let external_score = match &self.scorer {
+ Some(scorer) => scorer(item.data, score),
+ None => score,
+ };
+ Match {
+ score,
+ external_score,
+ idx,
+ }
});
self.matches.par_extend(items);
self.last_snapshot = end;
@@ -151,11 +202,23 @@ impl<T: Sync + Send + 'static> Worker<T> {
if new_snapshot.end() != self.last_snapshot {
let end = new_snapshot.end();
let items = new_snapshot.filter_map(|(idx, item)| {
- if item.is_none() {
- self.in_flight.push(idx);
- return None;
+ let item = item?;
+ // Apply filter if set
+ if let Some(ref filter) = self.filter {
+ if !filter(item.data) {
+ return None;
+ }
+ }
+ // For empty pattern, apply scorer with score=0 if set
+ let external_score = match &self.scorer {
+ Some(scorer) => scorer(item.data, 0),
+ None => 0,
};
- Some(Match { score: 0, idx })
+ Some(Match {
+ score: 0,
+ external_score,
+ idx,
+ })
});
self.matches.extend(items);
self.last_snapshot = end;
@@ -205,11 +268,26 @@ impl<T: Sync + Send + 'static> Worker<T> {
}
// safety: in-flight items are never added to the matches
let item = self.items.get_unchecked(match_.idx);
+ // Apply filter if set
+ if let Some(ref filter) = self.filter {
+ if !filter(item.data) {
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ match_.score = 0;
+ match_.external_score = 0;
+ match_.idx = u32::MAX;
+ return;
+ }
+ }
if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) {
match_.score = score;
+ match_.external_score = match &self.scorer {
+ Some(scorer) => scorer(item.data, score),
+ None => score,
+ };
} else {
unmatched.fetch_add(1, atomic::Ordering::Relaxed);
match_.score = 0;
+ match_.external_score = 0;
match_.idx = u32::MAX;
}
});
@@ -234,8 +312,9 @@ impl<T: Sync + Send + 'static> Worker<T> {
par_quicksort(
&mut self.matches,
|match1, match2| {
- if match1.score != match2.score {
- return match1.score > match2.score;
+ // Primary sort: external_score (used for frecency/custom ranking)
+ if match1.external_score != match2.external_score {
+ return match1.external_score > match2.external_score;
}
if match1.idx == u32::MAX {
return false;
@@ -293,8 +372,33 @@ impl<T: Sync + Send + 'static> Worker<T> {
fn reset_matches(&mut self) {
self.matches.clear();
- self.matches
- .extend((0..self.last_snapshot).map(|idx| Match { score: 0, idx }));
+ // When resetting, apply filter if set
+ if let Some(ref filter) = self.filter {
+ for idx in 0..self.last_snapshot {
+ // Items up to last_snapshot should be initialized
+ if let Some(item) = self.items.get(idx) {
+ if filter(item.data) {
+ let external_score = match &self.scorer {
+ Some(scorer) => scorer(item.data, 0),
+ None => 0,
+ };
+ self.matches.push(Match {
+ score: 0,
+ external_score,
+ idx,
+ });
+ }
+ }
+ }
+ } else {
+ // No filter - add all items
+ self.matches
+ .extend((0..self.last_snapshot).map(|idx| Match {
+ score: 0,
+ external_score: 0,
+ idx,
+ }));
+ }
// there are usually only very few in flight items (one for each writer)
self.remove_in_flight_matches();
}