aboutsummaryrefslogtreecommitdiffstats
path: root/src/worker.rs
diff options
context:
space:
mode:
authorEllie Huxtable <ellie@elliehuxtable.com>2026-03-16 15:22:49 -0700
committerEllie Huxtable <ellie@elliehuxtable.com>2026-03-16 15:22:49 -0700
commit8f9777ce7aecfe1a163a915e3245466b9dd9ac2e (patch)
treea878a0495e5f84266b7e36918a9e1a9432b0ddd8 /src/worker.rs
downloadatuin-8f9777ce7aecfe1a163a915e3245466b9dd9ac2e.zip
Squashed 'crates/atuin-nucleo/' content from commit 4253de9f
git-subtree-dir: crates/atuin-nucleo git-subtree-split: 4253de9faabb4e5c6d81d946a5e35a90f87347ee
Diffstat (limited to 'src/worker.rs')
-rw-r--r--src/worker.rs301
1 files changed, 301 insertions, 0 deletions
diff --git a/src/worker.rs b/src/worker.rs
new file mode 100644
index 00000000..f4077e6e
--- /dev/null
+++ b/src/worker.rs
@@ -0,0 +1,301 @@
+use std::cell::UnsafeCell;
+use std::mem::take;
+use std::sync::atomic::{self, AtomicBool, AtomicU32};
+use std::sync::Arc;
+
+use nucleo_matcher::Config;
+use parking_lot::Mutex;
+use rayon::{prelude::*, ThreadPool};
+
+use crate::par_sort::par_quicksort;
+use crate::pattern::{self, MultiPattern};
+use crate::{boxcar, Match};
+
+struct Matchers(Box<[UnsafeCell<nucleo_matcher::Matcher>]>);
+
+impl Matchers {
+ // this is not a true mut from ref, we use a cell here
+ #[allow(clippy::mut_from_ref)]
+ unsafe fn get(&self) -> &mut nucleo_matcher::Matcher {
+ &mut *self.0[rayon::current_thread_index().unwrap()].get()
+ }
+}
+
+unsafe impl Sync for Matchers {}
+unsafe impl Send for Matchers {}
+
+pub(crate) struct Worker<T: Sync + Send + 'static> {
+ pub(crate) running: bool,
+ matchers: Matchers,
+ pub(crate) matches: Vec<Match>,
+ pub(crate) pattern: MultiPattern,
+ pub(crate) sort_results: bool,
+ pub(crate) reverse_items: bool,
+ pub(crate) canceled: Arc<AtomicBool>,
+ pub(crate) should_notify: Arc<AtomicBool>,
+ pub(crate) was_canceled: bool,
+ pub(crate) last_snapshot: u32,
+ notify: Arc<(dyn Fn() + Sync + Send)>,
+ pub(crate) items: Arc<boxcar::Vec<T>>,
+ in_flight: Vec<u32>,
+}
+
+impl<T: Sync + Send + 'static> Worker<T> {
+ pub(crate) fn item_count(&self) -> u32 {
+ self.last_snapshot - self.in_flight.len() as u32
+ }
+ pub(crate) fn update_config(&mut self, config: Config) {
+ for matcher in self.matchers.0.iter_mut() {
+ matcher.get_mut().config = config.clone();
+ }
+ }
+ pub(crate) fn sort_results(&mut self, sort_results: bool) {
+ self.sort_results = sort_results;
+ }
+ pub(crate) fn reverse_items(&mut self, reverse_items: bool) {
+ self.reverse_items = reverse_items;
+ }
+
+ pub(crate) fn new(
+ worker_threads: Option<usize>,
+ config: Config,
+ notify: Arc<(dyn Fn() + Sync + Send)>,
+ cols: u32,
+ ) -> (ThreadPool, Self) {
+ let worker_threads = worker_threads
+ .unwrap_or_else(|| std::thread::available_parallelism().map_or(4, |it| it.get()));
+ let pool = rayon::ThreadPoolBuilder::new()
+ .thread_name(|i| format!("nucleo worker {i}"))
+ .num_threads(worker_threads)
+ .build()
+ .expect("creating threadpool failed");
+ let matchers = (0..worker_threads)
+ .map(|_| UnsafeCell::new(nucleo_matcher::Matcher::new(config.clone())))
+ .collect();
+ let worker = Worker {
+ running: false,
+ matchers: Matchers(matchers),
+ last_snapshot: 0,
+ matches: Vec::new(),
+ // just a placeholder
+ pattern: MultiPattern::new(cols as usize),
+ sort_results: true,
+ reverse_items: false,
+ canceled: Arc::new(AtomicBool::new(false)),
+ should_notify: Arc::new(AtomicBool::new(false)),
+ was_canceled: false,
+ notify,
+ items: Arc::new(boxcar::Vec::with_capacity(2 * 1024, cols)),
+ in_flight: Vec::with_capacity(64),
+ };
+ (pool, worker)
+ }
+
+ unsafe fn process_new_items(&mut self, unmatched: &AtomicU32) {
+ let matchers = &self.matchers;
+ let pattern = &self.pattern;
+ self.matches.reserve(self.in_flight.len());
+ self.in_flight.retain(|&idx| {
+ let Some(item) = self.items.get(idx) else {
+ return true;
+ };
+ if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) {
+ self.matches.push(Match { score, idx });
+ };
+ false
+ });
+ let new_snapshot = self.items.par_snapshot(self.last_snapshot);
+ if new_snapshot.end() != self.last_snapshot {
+ let end = new_snapshot.end();
+ let in_flight = Mutex::new(&mut self.in_flight);
+ let items = new_snapshot.map(|(idx, item)| {
+ let Some(item) = item else {
+ in_flight.lock().push(idx);
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ return Match {
+ score: 0,
+ idx: u32::MAX,
+ };
+ };
+ if self.canceled.load(atomic::Ordering::Relaxed) {
+ return Match { score: 0, idx };
+ }
+ let Some(score) = pattern.score(item.matcher_columns, matchers.get()) else {
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ return Match {
+ score: 0,
+ idx: u32::MAX,
+ };
+ };
+ Match { score, idx }
+ });
+ self.matches.par_extend(items);
+ self.last_snapshot = end;
+ }
+ }
+
+ fn remove_in_flight_matches(&mut self) {
+ let mut off = 0;
+ self.in_flight.retain(|&i| {
+ let is_in_flight = self.items.get(i).is_none();
+ if is_in_flight {
+ self.matches.remove((i - off) as usize);
+ off += 1;
+ }
+ is_in_flight
+ });
+ }
+
+ unsafe fn process_new_items_trivial(&mut self) {
+ let new_snapshot = self.items.snapshot(self.last_snapshot);
+ if new_snapshot.end() != self.last_snapshot {
+ let end = new_snapshot.end();
+ let items = new_snapshot.filter_map(|(idx, item)| {
+ if item.is_none() {
+ self.in_flight.push(idx);
+ return None;
+ };
+ Some(Match { score: 0, idx })
+ });
+ self.matches.extend(items);
+ self.last_snapshot = end;
+ }
+ }
+
+ pub(crate) unsafe fn run(&mut self, pattern_status: pattern::Status, cleared: bool) {
+ self.running = true;
+ self.was_canceled = false;
+
+ if cleared {
+ self.last_snapshot = 0;
+ self.in_flight.clear();
+ self.matches.clear();
+ }
+
+ // TODO: be smarter around reusing past results for rescoring
+ if self.pattern.is_empty() {
+ self.reset_matches();
+ self.process_new_items_trivial();
+ let canceled = self.sort_matches();
+ if canceled {
+ self.was_canceled = true;
+ } else if self.should_notify.load(atomic::Ordering::Relaxed) {
+ (self.notify)();
+ }
+ return;
+ }
+
+ if pattern_status == pattern::Status::Rescore {
+ self.reset_matches();
+ }
+
+ let mut unmatched = AtomicU32::new(0);
+ if pattern_status != pattern::Status::Unchanged && !self.matches.is_empty() {
+ self.process_new_items_trivial();
+ let matchers = &self.matchers;
+ let pattern = &self.pattern;
+ self.matches
+ .par_iter_mut()
+ .take_any_while(|_| !self.canceled.load(atomic::Ordering::Relaxed))
+ .for_each(|match_| {
+ if match_.idx == u32::MAX {
+ debug_assert_eq!(match_.score, 0);
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ return;
+ }
+ // safety: in-flight items are never added to the matches
+ let item = self.items.get_unchecked(match_.idx);
+ if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) {
+ match_.score = score;
+ } else {
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ match_.score = 0;
+ match_.idx = u32::MAX;
+ }
+ });
+ } else {
+ self.process_new_items(&unmatched);
+ }
+
+ let canceled = self.sort_matches();
+ if canceled {
+ self.was_canceled = true;
+ } else {
+ self.matches
+ .truncate(self.matches.len() - take(unmatched.get_mut()) as usize);
+ if self.should_notify.load(atomic::Ordering::Relaxed) {
+ (self.notify)();
+ }
+ }
+ }
+
+ unsafe fn sort_matches(&mut self) -> bool {
+ if self.sort_results {
+ par_quicksort(
+ &mut self.matches,
+ |match1, match2| {
+ if match1.score != match2.score {
+ return match1.score > match2.score;
+ }
+ if match1.idx == u32::MAX {
+ return false;
+ }
+ if match2.idx == u32::MAX {
+ return true;
+ }
+ // the tie breaker is comparatively rarely needed so we keep it
+ // in a branch especially because we need to access the items
+ // array here which involves some pointer chasing
+ let item1 = self.items.get_unchecked(match1.idx);
+ let item2 = &self.items.get_unchecked(match2.idx);
+ let len1: u32 = item1
+ .matcher_columns
+ .iter()
+ .map(|haystack| haystack.len() as u32)
+ .sum();
+ let len2 = item2
+ .matcher_columns
+ .iter()
+ .map(|haystack| haystack.len() as u32)
+ .sum();
+ if len1 == len2 {
+ if self.reverse_items {
+ match2.idx < match1.idx
+ } else {
+ match1.idx < match2.idx
+ }
+ } else {
+ len1 < len2
+ }
+ },
+ &self.canceled,
+ )
+ } else {
+ par_quicksort(
+ &mut self.matches,
+ |match1, match2| {
+ if match1.idx == u32::MAX {
+ return false;
+ }
+ if match2.idx == u32::MAX {
+ return true;
+ }
+ if self.reverse_items {
+ match2.idx < match1.idx
+ } else {
+ match1.idx < match2.idx
+ }
+ },
+ &self.canceled,
+ )
+ }
+ }
+
+ fn reset_matches(&mut self) {
+ self.matches.clear();
+ self.matches
+ .extend((0..self.last_snapshot).map(|idx| Match { score: 0, idx }));
+ // there are usually only very few in flight items (one for each writer)
+ self.remove_in_flight_matches();
+ }
+}