aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/boxcar.rs786
-rw-r--r--src/lib.rs462
-rw-r--r--src/par_sort.rs895
-rw-r--r--src/pattern.rs100
-rw-r--r--src/pattern/tests.rs14
-rw-r--r--src/tests.rs27
-rw-r--r--src/worker.rs301
7 files changed, 2585 insertions, 0 deletions
diff --git a/src/boxcar.rs b/src/boxcar.rs
new file mode 100644
index 00000000..9b48809d
--- /dev/null
+++ b/src/boxcar.rs
@@ -0,0 +1,786 @@
+//! Adapted from the `boxcar` crate at <https://github.com/ibraheemdev/boxcar/blob/master/src/raw.rs>
+//! under MIT licenes:
+//!
+//! Copyright (c) 2022 Ibraheem Ahmed
+//!
+//! Permission is hereby granted, free of charge, to any person obtaining a copy
+//! of this software and associated documentation files (the "Software"), to deal
+//! in the Software without restriction, including without limitation the rights
+//! to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+//! copies of the Software, and to permit persons to whom the Software is
+//! furnished to do so, subject to the following conditions:
+//!
+//! The above copyright notice and this permission notice shall be included in all
+//! copies or substantial portions of the Software.
+//!
+//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+//! IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+//! FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+//! AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+//! LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+//! OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+//! SOFTWARE.
+
+use std::alloc::Layout;
+use std::cell::UnsafeCell;
+use std::fmt::Debug;
+use std::mem::MaybeUninit;
+use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU64, Ordering};
+use std::{ptr, slice};
+
+use crate::{Item, Utf32String};
+
+const BUCKETS: u32 = u32::BITS - SKIP_BUCKET;
+const MAX_ENTRIES: u32 = u32::MAX - SKIP;
+
+/// A lock-free, append-only vector.
+pub(crate) struct Vec<T> {
+ /// a counter used to retrieve a unique index to push to.
+ ///
+ /// this value may be more than the true length as it will
+ /// be incremented before values are actually stored.
+ inflight: AtomicU64,
+ /// buckets of length 32, 64 .. 2^31
+ buckets: [Bucket<T>; BUCKETS as usize],
+ /// the number of matcher columns in this vector, its absolutely critical that
+ /// this remains constant and after initilaziaton (safety invariant) since
+ /// it is used to calculate the Entry layout
+ columns: u32,
+}
+
+impl<T> Vec<T> {
+ /// Constructs a new, empty `Vec<T>` with the specified capacity and matcher columns.
+ pub fn with_capacity(capacity: u32, columns: u32) -> Vec<T> {
+ assert_ne!(columns, 0, "there must be atleast one matcher column");
+ let init = match capacity {
+ 0 => 0,
+ // initialize enough buckets for `capacity` elements
+ n => Location::of(n).bucket,
+ };
+
+ let mut buckets = [ptr::null_mut(); BUCKETS as usize];
+
+ for (i, bucket) in buckets[..=init as usize].iter_mut().enumerate() {
+ let len = Location::bucket_len(i as u32);
+ *bucket = unsafe { Bucket::alloc(len, columns) };
+ }
+
+ Vec {
+ buckets: buckets.map(Bucket::new),
+ inflight: AtomicU64::new(0),
+ columns,
+ }
+ }
+ pub fn columns(&self) -> u32 {
+ self.columns
+ }
+
+ /// Returns the number of elements in the vector.
+ #[inline]
+ pub fn count(&self) -> u32 {
+ self.inflight
+ .load(Ordering::Acquire)
+ .min(MAX_ENTRIES as u64) as u32
+ }
+
+ // Returns a reference to the element at the given index.
+ //
+ // # Safety
+ //
+ // Entry at `index` must be initialized.
+ #[inline]
+ pub unsafe fn get_unchecked(&self, index: u32) -> Item<'_, T> {
+ let location = Location::of(index);
+
+ unsafe {
+ let entries = self
+ .buckets
+ .get_unchecked(location.bucket as usize)
+ .entries
+ .load(Ordering::Relaxed);
+ debug_assert!(!entries.is_null());
+ let entry = Bucket::<T>::get(entries, location.entry, self.columns);
+ // this looks odd but is necessary to ensure cross
+ // thread synchronization (essentially acting as a memory barrier)
+ // since the caller must only guarantee that he has observed active on any thread
+ // but the current thread might still have an old value cached (although unlikely)
+ let _ = (*entry).active.load(Ordering::Acquire);
+ Entry::read(entry, self.columns)
+ }
+ }
+
+ /// Returns a reference to the element at the given index.
+ pub fn get(&self, index: u32) -> Option<Item<'_, T>> {
+ let location = Location::of(index);
+
+ unsafe {
+ // safety: `location.bucket` is always in bounds
+ let entries = self
+ .buckets
+ .get_unchecked(location.bucket as usize)
+ .entries
+ .load(Ordering::Relaxed);
+
+ // bucket is uninitialized
+ if entries.is_null() {
+ return None;
+ }
+
+ // safety: `location.entry` is always in bounds for it's bucket
+ let entry = Bucket::<T>::get(entries, location.entry, self.columns);
+
+ // safety: the entry is active
+ (*entry)
+ .active
+ .load(Ordering::Acquire)
+ .then(|| Entry::read(entry, self.columns))
+ }
+ }
+
+ /// Appends an element to the back of the vector.
+ pub fn push(&self, value: T, fill_columns: impl FnOnce(&T, &mut [Utf32String])) -> u32 {
+ let index = self.inflight.fetch_add(1, Ordering::Release);
+ // the inflight counter is a `u64` to catch overflows of the vector'scapacity
+ let index: u32 = index.try_into().expect("overflowed maximum capacity");
+ let location = Location::of(index);
+
+ // eagerly allocate the next bucket if we are close to the end of this one
+ if index == (location.bucket_len - (location.bucket_len >> 3)) {
+ if let Some(next_bucket) = self.buckets.get(location.bucket as usize + 1) {
+ Vec::get_or_alloc(next_bucket, location.bucket_len << 1, self.columns);
+ }
+ }
+
+ // safety: `location.bucket` is always in bounds
+ let bucket = unsafe { self.buckets.get_unchecked(location.bucket as usize) };
+ let mut entries = bucket.entries.load(Ordering::Acquire);
+
+ // the bucket has not been allocated yet
+ if entries.is_null() {
+ entries = Vec::get_or_alloc(bucket, location.bucket_len, self.columns);
+ }
+
+ unsafe {
+ // safety: `location.entry` is always in bounds for it's bucket
+ let entry = Bucket::get(entries, location.entry, self.columns);
+
+ // safety: we have unique access to this entry.
+ //
+ // 1. it is impossible for another thread to attempt a `push`
+ // to this location as we retrieved it from `inflight.fetch_add`
+ //
+ // 2. any thread trying to `get` this entry will see `active == false`,
+ // and will not try to access it
+ for col in Entry::matcher_cols_raw(entry, self.columns) {
+ col.get().write(MaybeUninit::new(Utf32String::default()))
+ }
+ fill_columns(&value, Entry::matcher_cols_mut(entry, self.columns));
+ (*entry).slot.get().write(MaybeUninit::new(value));
+ // let other threads know that this entry is active
+ (*entry).active.store(true, Ordering::Release);
+ }
+
+ index
+ }
+
+ /// Extends the vector by appending multiple elements at once.
+ pub fn extend<I>(&self, values: I, fill_columns: impl Fn(&T, &mut [Utf32String]))
+ where
+ I: IntoIterator<Item = T> + ExactSizeIterator,
+ {
+ let count: u32 = values
+ .len()
+ .try_into()
+ .expect("overflowed maximum capacity");
+ if count == 0 {
+ assert!(
+ values.into_iter().next().is_none(),
+ "The `values` variable reported incorrect length."
+ );
+ return;
+ }
+
+ // Reserve all indices at once
+ let start_index: u32 = self
+ .inflight
+ .fetch_add(u64::from(count), Ordering::Release)
+ .try_into()
+ .expect("overflowed maximum capacity");
+
+ // Compute first and last locations
+ let start_location = Location::of(start_index);
+ let end_location = Location::of(start_index + count);
+
+ // Eagerly allocate the next bucket if the last entry is close to the end of its next bucket
+ let alloc_entry = end_location.alloc_next_bucket_entry();
+ if end_location.entry >= alloc_entry
+ && (start_location.bucket != end_location.bucket || start_location.entry <= alloc_entry)
+ {
+ // This might be the last bucket, hence the check
+ if let Some(next_bucket) = self.buckets.get(end_location.bucket as usize + 1) {
+ Vec::get_or_alloc(next_bucket, end_location.bucket_len << 1, self.columns);
+ }
+ }
+
+ let mut bucket = unsafe { self.buckets.get_unchecked(start_location.bucket as usize) };
+ let mut entries = bucket.entries.load(Ordering::Acquire);
+ if entries.is_null() {
+ entries = Vec::get_or_alloc(
+ bucket,
+ Location::bucket_len(start_location.bucket),
+ self.columns,
+ );
+ }
+ // Route each value to its corresponding bucket
+ let mut location;
+ let count = count as usize;
+ for (i, v) in values.into_iter().enumerate() {
+ // ExactSizeIterator is a safe trait that can have bugs/lie about it's size.
+ // Unsafe code cannot rely on the reported length being correct.
+ assert!(i < count);
+
+ location =
+ Location::of(start_index + u32::try_from(i).expect("overflowed maximum capacity"));
+
+ // if we're starting to insert into a different bucket, allocate it beforehand
+ if location.entry == 0 && i != 0 {
+ // safety: `location.bucket` is always in bounds
+ bucket = unsafe { self.buckets.get_unchecked(location.bucket as usize) };
+ entries = bucket.entries.load(Ordering::Acquire);
+
+ if entries.is_null() {
+ entries = Vec::get_or_alloc(
+ bucket,
+ Location::bucket_len(location.bucket),
+ self.columns,
+ );
+ }
+ }
+
+ unsafe {
+ let entry = Bucket::get(entries, location.entry, self.columns);
+
+ // Initialize matcher columns
+ for col in Entry::matcher_cols_raw(entry, self.columns) {
+ col.get().write(MaybeUninit::new(Utf32String::default()));
+ }
+ fill_columns(&v, Entry::matcher_cols_mut(entry, self.columns));
+ (*entry).slot.get().write(MaybeUninit::new(v));
+ (*entry).active.store(true, Ordering::Release);
+ }
+ }
+ }
+
+ /// race to initialize a bucket
+ fn get_or_alloc(bucket: &Bucket<T>, len: u32, cols: u32) -> *mut Entry<T> {
+ let entries = unsafe { Bucket::alloc(len, cols) };
+ match bucket.entries.compare_exchange(
+ ptr::null_mut(),
+ entries,
+ Ordering::Release,
+ Ordering::Acquire,
+ ) {
+ Ok(_) => entries,
+ Err(found) => unsafe {
+ Bucket::dealloc(entries, len, cols);
+ found
+ },
+ }
+ }
+
+ /// Returns an iterator over the vector starting at `start`
+ /// the iterator is deterministically sized and will not grow
+ /// as more elements are pushed
+ pub unsafe fn snapshot(&self, start: u32) -> Iter<'_, T> {
+ let end = self
+ .inflight
+ .load(Ordering::Acquire)
+ .min(MAX_ENTRIES as u64) as u32;
+ assert!(start <= end, "index {start} is out of bounds!");
+ Iter {
+ location: Location::of(start),
+ vec: self,
+ idx: start,
+ end,
+ }
+ }
+
+ /// Returns an iterator over the vector starting at `start`
+ /// the iterator is deterministically sized and will not grow
+ /// as more elements are pushed
+ pub unsafe fn par_snapshot(&self, start: u32) -> ParIter<'_, T> {
+ let end = self
+ .inflight
+ .load(Ordering::Acquire)
+ .min(MAX_ENTRIES as u64) as u32;
+ assert!(start <= end, "index {start} is out of bounds!");
+
+ ParIter {
+ start,
+ end,
+ vec: self,
+ }
+ }
+}
+
+impl<T> Drop for Vec<T> {
+ fn drop(&mut self) {
+ for (i, bucket) in self.buckets.iter_mut().enumerate() {
+ let entries = *bucket.entries.get_mut();
+
+ if entries.is_null() {
+ break;
+ }
+
+ let len = Location::bucket_len(i as u32);
+ // safety: in drop
+ unsafe { Bucket::dealloc(entries, len, self.columns) }
+ }
+ }
+}
+type SnapshotItem<'v, T> = (u32, Option<Item<'v, T>>);
+
+pub struct Iter<'v, T> {
+ location: Location,
+ idx: u32,
+ end: u32,
+ vec: &'v Vec<T>,
+}
+impl<T> Iter<'_, T> {
+ pub fn end(&self) -> u32 {
+ self.end
+ }
+}
+
+impl<'v, T> Iterator for Iter<'v, T> {
+ type Item = SnapshotItem<'v, T>;
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ (
+ (self.end - self.idx) as usize,
+ Some((self.end - self.idx) as usize),
+ )
+ }
+
+ fn next(&mut self) -> Option<SnapshotItem<'v, T>> {
+ if self.end == self.idx {
+ return None;
+ }
+ debug_assert!(self.idx < self.end, "huh {} {}", self.idx, self.end);
+ debug_assert!(self.end as u64 <= self.vec.inflight.load(Ordering::Relaxed));
+
+ loop {
+ let entries = unsafe {
+ self.vec
+ .buckets
+ .get_unchecked(self.location.bucket as usize)
+ .entries
+ .load(Ordering::Relaxed)
+ };
+ debug_assert!(self.location.bucket < BUCKETS);
+
+ if self.location.entry < self.location.bucket_len {
+ if entries.is_null() {
+ // we still want to yield these
+ let index = self.idx;
+ self.location.entry += 1;
+ self.idx += 1;
+ return Some((index, None));
+ }
+ // safety: bounds and null checked above
+ let entry = unsafe { Bucket::get(entries, self.location.entry, self.vec.columns) };
+ let index = self.idx;
+ self.location.entry += 1;
+ self.idx += 1;
+
+ let entry = unsafe {
+ (*entry)
+ .active
+ .load(Ordering::Acquire)
+ .then(|| Entry::read(entry, self.vec.columns))
+ };
+ return Some((index, entry));
+ }
+
+ self.location.entry = 0;
+ self.location.bucket += 1;
+
+ if self.location.bucket < BUCKETS {
+ self.location.bucket_len = Location::bucket_len(self.location.bucket);
+ }
+ }
+ }
+}
+impl<T> ExactSizeIterator for Iter<'_, T> {}
+impl<T> DoubleEndedIterator for Iter<'_, T> {
+ fn next_back(&mut self) -> Option<Self::Item> {
+ unimplemented!()
+ }
+}
+
+pub struct ParIter<'v, T> {
+ end: u32,
+ start: u32,
+ vec: &'v Vec<T>,
+}
+impl<T> ParIter<'_, T> {
+ pub fn end(&self) -> u32 {
+ self.end
+ }
+}
+
+impl<'v, T: Send + Sync> rayon::iter::ParallelIterator for ParIter<'v, T> {
+ type Item = SnapshotItem<'v, T>;
+
+ fn drive_unindexed<C>(self, consumer: C) -> C::Result
+ where
+ C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
+ {
+ rayon::iter::plumbing::bridge(self, consumer)
+ }
+
+ fn opt_len(&self) -> Option<usize> {
+ Some((self.end - self.start) as usize)
+ }
+}
+
+impl<T: Send + Sync> rayon::iter::IndexedParallelIterator for ParIter<'_, T> {
+ fn len(&self) -> usize {
+ (self.end - self.start) as usize
+ }
+
+ fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
+ rayon::iter::plumbing::bridge(self, consumer)
+ }
+
+ fn with_producer<CB>(self, callback: CB) -> CB::Output
+ where
+ CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
+ {
+ callback.callback(ParIterProducer {
+ start: self.start,
+ end: self.end,
+ vec: self.vec,
+ })
+ }
+}
+
+struct ParIterProducer<'v, T: Send> {
+ start: u32,
+ end: u32,
+ vec: &'v Vec<T>,
+}
+
+impl<'v, T: 'v + Send + Sync> rayon::iter::plumbing::Producer for ParIterProducer<'v, T> {
+ type Item = SnapshotItem<'v, T>;
+ type IntoIter = Iter<'v, T>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ debug_assert!(self.start <= self.end);
+ Iter {
+ location: Location::of(self.start),
+ idx: self.start,
+ end: self.end,
+ vec: self.vec,
+ }
+ }
+
+ fn split_at(self, index: usize) -> (Self, Self) {
+ assert!(index <= (self.end - self.start) as usize);
+ let index = index as u32;
+ (
+ ParIterProducer {
+ start: self.start,
+ end: self.start + index,
+ vec: self.vec,
+ },
+ ParIterProducer {
+ start: self.start + index,
+ end: self.end,
+ vec: self.vec,
+ },
+ )
+ }
+}
+
+struct Bucket<T> {
+ entries: AtomicPtr<Entry<T>>,
+}
+
+impl<T> Bucket<T> {
+ fn layout(len: u32, layout: Layout) -> Layout {
+ Layout::from_size_align(layout.size() * len as usize, layout.align())
+ .expect("exceeded maximum allocation size")
+ }
+
+ unsafe fn alloc(len: u32, cols: u32) -> *mut Entry<T> {
+ let layout = Entry::<T>::layout(cols);
+ let arr_layout = Self::layout(len, layout);
+ let entries = std::alloc::alloc(arr_layout);
+ if entries.is_null() {
+ std::alloc::handle_alloc_error(arr_layout)
+ }
+
+ for i in 0..len {
+ let active = entries.add(i as usize * layout.size()) as *mut AtomicBool;
+ active.write(AtomicBool::new(false))
+ }
+ entries as *mut Entry<T>
+ }
+
+ unsafe fn dealloc(entries: *mut Entry<T>, len: u32, cols: u32) {
+ let layout = Entry::<T>::layout(cols);
+ let arr_layout = Self::layout(len, layout);
+ for i in 0..len {
+ let entry = Bucket::get(entries, i, cols);
+ if *(*entry).active.get_mut() {
+ ptr::drop_in_place((*(*entry).slot.get()).as_mut_ptr());
+ for matcher_col in Entry::matcher_cols_raw(entry, cols) {
+ ptr::drop_in_place((*matcher_col.get()).as_mut_ptr());
+ }
+ }
+ }
+ std::alloc::dealloc(entries as *mut u8, arr_layout)
+ }
+
+ unsafe fn get(entries: *mut Entry<T>, idx: u32, cols: u32) -> *mut Entry<T> {
+ let layout = Entry::<T>::layout(cols);
+ let ptr = entries as *mut u8;
+ ptr.add(layout.size() * idx as usize) as *mut Entry<T>
+ }
+
+ fn new(entries: *mut Entry<T>) -> Bucket<T> {
+ Bucket {
+ entries: AtomicPtr::new(entries),
+ }
+ }
+}
+
+#[repr(C)]
+struct Entry<T> {
+ active: AtomicBool,
+ slot: UnsafeCell<MaybeUninit<T>>,
+ tail: [UnsafeCell<MaybeUninit<Utf32String>>; 0],
+}
+
+impl<T> Entry<T> {
+ fn layout(cols: u32) -> Layout {
+ let head = Layout::new::<Self>();
+ let tail = Layout::array::<Utf32String>(cols as usize).expect("invalid memory layout");
+ head.extend(tail)
+ .expect("invalid memory layout")
+ .0
+ .pad_to_align()
+ }
+
+ unsafe fn matcher_cols_raw<'a>(
+ ptr: *mut Entry<T>,
+ cols: u32,
+ ) -> &'a [UnsafeCell<MaybeUninit<Utf32String>>] {
+ // this whole thing looks weird. The reason we do this is that
+ // we must make sure the pointer retains its provenance which may (or may not?)
+ // be lost if we used tail.as_ptr()
+ let tail = std::ptr::addr_of!((*ptr).tail) as *const u8;
+ let offset = tail.offset_from(ptr as *mut u8) as usize;
+ let ptr = (ptr as *mut u8).add(offset) as *mut _;
+ slice::from_raw_parts(ptr, cols as usize)
+ }
+
+ unsafe fn matcher_cols_mut<'a>(ptr: *mut Entry<T>, cols: u32) -> &'a mut [Utf32String] {
+ // this whole thing looks weird. The reason we do this is that
+ // we must make sure the pointer retains its provenance which may (or may not?)
+ // be lost if we used tail.as_ptr()
+ let tail = std::ptr::addr_of!((*ptr).tail) as *const u8;
+ let offset = tail.offset_from(ptr as *mut u8) as usize;
+ let ptr = (ptr as *mut u8).add(offset) as *mut _;
+ slice::from_raw_parts_mut(ptr, cols as usize)
+ }
+ // # Safety
+ //
+ // Value must be initialized.
+ unsafe fn read<'a>(ptr: *mut Entry<T>, cols: u32) -> Item<'a, T> {
+ // this whole thing looks weird. The reason we do this is that
+ // we must make sure the pointer retains its provenance which may (or may not?)
+ // be lost if we used tail.as_ptr()
+ let data = (*(*ptr).slot.get()).assume_init_ref();
+ let tail = std::ptr::addr_of!((*ptr).tail) as *const u8;
+ let offset = tail.offset_from(ptr as *mut u8) as usize;
+ let ptr = (ptr as *mut u8).add(offset) as *mut _;
+ let matcher_columns = slice::from_raw_parts(ptr, cols as usize);
+ Item {
+ data,
+ matcher_columns,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Location {
+ // the index of the bucket
+ bucket: u32,
+ // the length of `bucket`
+ bucket_len: u32,
+ // the index of the entry in `bucket`
+ entry: u32,
+}
+
+// skip the shorter buckets to avoid unnecessary allocations.
+// this also reduces the maximum capacity of a vector.
+const SKIP: u32 = 32;
+const SKIP_BUCKET: u32 = (u32::BITS - SKIP.leading_zeros()) - 1;
+
+impl Location {
+ fn of(index: u32) -> Location {
+ let skipped = index.checked_add(SKIP).expect("exceeded maximum length");
+ let bucket = u32::BITS - skipped.leading_zeros();
+ let bucket = bucket - (SKIP_BUCKET + 1);
+ let bucket_len = Location::bucket_len(bucket);
+ let entry = skipped ^ bucket_len;
+
+ Location {
+ bucket,
+ bucket_len,
+ entry,
+ }
+ }
+
+ fn bucket_len(bucket: u32) -> u32 {
+ 1 << (bucket + SKIP_BUCKET)
+ }
+
+ /// The entry index at which the next bucket should be pre-allocated.
+ fn alloc_next_bucket_entry(&self) -> u32 {
+ self.bucket_len - (self.bucket_len >> 3)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn location() {
+ assert_eq!(Location::bucket_len(0), 32);
+ for i in 0..32 {
+ let loc = Location::of(i);
+ assert_eq!(loc.bucket_len, 32);
+ assert_eq!(loc.bucket, 0);
+ assert_eq!(loc.entry, i);
+ }
+
+ assert_eq!(Location::bucket_len(1), 64);
+ for i in 33..96 {
+ let loc = Location::of(i);
+ assert_eq!(loc.bucket_len, 64);
+ assert_eq!(loc.bucket, 1);
+ assert_eq!(loc.entry, i - 32);
+ }
+
+ assert_eq!(Location::bucket_len(2), 128);
+ for i in 96..224 {
+ let loc = Location::of(i);
+ assert_eq!(loc.bucket_len, 128);
+ assert_eq!(loc.bucket, 2);
+ assert_eq!(loc.entry, i - 96);
+ }
+
+ let max = Location::of(MAX_ENTRIES);
+ assert_eq!(max.bucket, BUCKETS - 1);
+ assert_eq!(max.bucket_len, 1 << 31);
+ assert_eq!(max.entry, (1 << 31) - 1);
+ }
+
+ #[test]
+ fn extend_unique_bucket() {
+ let vec = Vec::<u32>::with_capacity(1, 1);
+ vec.extend(0..10, |_, _| {});
+ assert_eq!(vec.count(), 10);
+ for i in 0..10 {
+ assert_eq!(*vec.get(i).unwrap().data, i);
+ }
+ assert!(vec.get(10).is_none());
+ }
+
+ #[test]
+ fn extend_over_two_buckets() {
+ let vec = Vec::<u32>::with_capacity(1, 1);
+ vec.extend(0..100, |_, _| {});
+ assert_eq!(vec.count(), 100);
+ for i in 0..100 {
+ assert_eq!(*vec.get(i).unwrap().data, i);
+ }
+ assert!(vec.get(100).is_none());
+ }
+
+ #[test]
+ fn extend_over_more_than_two_buckets() {
+ let vec = Vec::<u32>::with_capacity(1, 1);
+ vec.extend(0..1000, |_, _| {});
+ assert_eq!(vec.count(), 1000);
+ for i in 0..1000 {
+ assert_eq!(*vec.get(i).unwrap().data, i);
+ }
+ assert!(vec.get(1000).is_none());
+ }
+
+ #[test]
+ /// test that ExactSizeIterator returning incorrect length is caught (0 AND more than reported)
+ fn extend_with_incorrect_reported_len_is_caught() {
+ struct IncorrectLenIter {
+ len: usize,
+ iter: std::ops::Range<u32>,
+ }
+
+ impl Iterator for IncorrectLenIter {
+ type Item = u32;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ self.iter.next()
+ }
+ }
+
+ impl ExactSizeIterator for IncorrectLenIter {
+ fn len(&self) -> usize {
+ self.len
+ }
+ }
+
+ let vec = Vec::<u32>::with_capacity(1, 1);
+ let iter = IncorrectLenIter {
+ len: 10,
+ iter: (0..12),
+ };
+ // this should panic
+ assert!(std::panic::catch_unwind(|| vec.extend(iter, |_, _| {})).is_err());
+
+ let vec = Vec::<u32>::with_capacity(1, 1);
+ let iter = IncorrectLenIter {
+ len: 12,
+ iter: (0..10),
+ };
+ // this shouldn't panic and should just ignore the extra elements
+ assert!(std::panic::catch_unwind(|| vec.extend(iter, |_, _| {})).is_ok());
+ // we should reserve 12 elements but only 10 should be present
+ assert_eq!(vec.count(), 12);
+ for i in 0..10 {
+ assert_eq!(*vec.get(i).unwrap().data, i);
+ }
+ assert!(vec.get(10).is_none());
+
+ let vec = Vec::<u32>::with_capacity(1, 1);
+ let iter = IncorrectLenIter {
+ len: 0,
+ iter: (0..2),
+ };
+ // this should panic
+ assert!(std::panic::catch_unwind(|| vec.extend(iter, |_, _| {})).is_err());
+ }
+
+ // test |values| does not fit in the boxcar
+ #[test]
+ fn extend_over_max_capacity() {
+ let vec = Vec::<u32>::with_capacity(1, 1);
+ let count = MAX_ENTRIES as usize + 2;
+ let iter = std::iter::repeat(0).take(count);
+ assert!(std::panic::catch_unwind(|| vec.extend(iter, |_, _| {})).is_err());
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 00000000..7ddb7407
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,462 @@
+/*!
+`nucleo` is a high level crate that provides a high level matcher API that
+provides a highly effective (parallel) matcher worker. It's designed to allow
+quickly plugging a fully featured (and faster) fzf/skim like fuzzy matcher into
+your TUI application.
+
+It's designed to run matching on a background threadpool while providing a
+snapshot of the last complete match. That means the matcher can update the
+results live while the user is typing while never blocking the main UI thread
+(beyond a user provided timeout). Nucleo also supports fully concurrent lock-free
+(and wait-free) streaming of input items.
+
+The [`Nucleo`] struct serves as the main API entrypoint for this crate.
+
+# Status
+
+Nucleo is used in the helix-editor and therefore has a large user base with lots
+or real world testing. The core matcher implementation is considered complete
+and is unlikely to see major changes. The `nucleo-matcher` crate is finished and
+ready for widespread use, breaking changes should be very rare (a 1.0 release
+should not be far away).
+
+While the high level `nucleo` crate also works well (and is also used in helix),
+there are still additional features that will be added in the future. The high
+level crate also need better documentation and will likely see a few minor API
+changes in the future.
+
+*/
+use std::ops::{Bound, RangeBounds};
+use std::sync::atomic::{self, AtomicBool, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
+
+use parking_lot::Mutex;
+use rayon::ThreadPool;
+
+use crate::pattern::MultiPattern;
+use crate::worker::Worker;
+pub use nucleo_matcher::{chars, Config, Matcher, Utf32Str, Utf32String};
+
+mod boxcar;
+mod par_sort;
+pub mod pattern;
+mod worker;
+
+#[cfg(test)]
+mod tests;
+
+/// A match candidate stored in a [`Nucleo`] worker.
+pub struct Item<'a, T> {
+ pub data: &'a T,
+ pub matcher_columns: &'a [Utf32String],
+}
+
+/// A handle that allows adding new items to a [`Nucleo`] worker.
+///
+/// It's internally reference counted and can be cheaply cloned
+/// and sent across threads.
+pub struct Injector<T> {
+ items: Arc<boxcar::Vec<T>>,
+ notify: Arc<(dyn Fn() + Sync + Send)>,
+}
+
+impl<T> Clone for Injector<T> {
+ fn clone(&self) -> Self {
+ Injector {
+ items: self.items.clone(),
+ notify: self.notify.clone(),
+ }
+ }
+}
+
+impl<T> Injector<T> {
+ /// Appends an element to the list of matched items.
+ /// This function is lock-free and wait-free.
+ pub fn push(&self, value: T, fill_columns: impl FnOnce(&T, &mut [Utf32String])) -> u32 {
+ let idx = self.items.push(value, fill_columns);
+ (self.notify)();
+ idx
+ }
+
+ /// Appends multiple elements to the list of matched items.
+ /// This function is lock-free and wait-free.
+ ///
+ /// You should favor this function over `push` if at least one of the following is true:
+ /// - the number of items you're adding can be computed beforehand and is typically larger
+ /// than 1k
+ /// - you're able to batch incoming items
+ /// - you're adding items from multiple threads concurrently (this function results in less
+ /// contention)
+ pub fn extend<I>(&self, values: I, fill_columns: impl Fn(&T, &mut [Utf32String]))
+ where
+ I: IntoIterator<Item = T> + ExactSizeIterator,
+ {
+ self.items.extend(values, fill_columns);
+ (self.notify)();
+ }
+
+ /// Returns the total number of items injected in the matcher. This might
+ /// not match the number of items in the match snapshot (if the matcher
+ /// is still running)
+ pub fn injected_items(&self) -> u32 {
+ self.items.count()
+ }
+
+ /// Returns a reference to the item at the given index.
+ ///
+ /// # Safety
+ ///
+ /// Item at `index` must be initialized. That means you must have observed
+ /// `push` returning this value or `get` returning `Some` for this value.
+ /// Just because a later index is initialized doesn't mean that this index
+ /// is initialized
+ pub unsafe fn get_unchecked(&self, index: u32) -> Item<'_, T> {
+ self.items.get_unchecked(index)
+ }
+
+ /// Returns a reference to the element at the given index.
+ pub fn get(&self, index: u32) -> Option<Item<'_, T>> {
+ self.items.get(index)
+ }
+}
+
+/// An [item](crate::Item) that was successfully matched by a [`Nucleo`] worker.
+#[derive(PartialEq, Eq, Debug, Clone, Copy)]
+pub struct Match {
+ pub score: u32,
+ pub idx: u32,
+}
+
+/// That status of a [`Nucleo`] worker after a match.
+#[derive(PartialEq, Eq, Debug, Clone, Copy)]
+pub struct Status {
+ /// Whether the current snapshot has changed.
+ pub changed: bool,
+ /// Whether the matcher is still processing in the background.
+ pub running: bool,
+}
+
+/// A snapshot represent the results of a [`Nucleo`] worker after
+/// finishing a [`tick`](Nucleo::tick).
+pub struct Snapshot<T: Sync + Send + 'static> {
+ item_count: u32,
+ matches: Vec<Match>,
+ pattern: MultiPattern,
+ items: Arc<boxcar::Vec<T>>,
+}
+
+impl<T: Sync + Send + 'static> Snapshot<T> {
+ fn clear(&mut self, new_items: Arc<boxcar::Vec<T>>) {
+ self.item_count = 0;
+ self.matches.clear();
+ self.items = new_items
+ }
+
+ fn update(&mut self, worker: &Worker<T>) {
+ self.item_count = worker.item_count();
+ self.pattern.clone_from(&worker.pattern);
+ self.matches.clone_from(&worker.matches);
+ if !Arc::ptr_eq(&worker.items, &self.items) {
+ self.items = worker.items.clone()
+ }
+ }
+
+ /// Returns that total number of items
+ pub fn item_count(&self) -> u32 {
+ self.item_count
+ }
+
+ /// Returns the pattern which items were matched against
+ pub fn pattern(&self) -> &MultiPattern {
+ &self.pattern
+ }
+
+ /// Returns that number of items that matched the pattern
+ pub fn matched_item_count(&self) -> u32 {
+ self.matches.len() as u32
+ }
+
+ /// Returns an iterator over the items that correspond to a subrange of
+ /// all the matches in this snapshot.
+ ///
+ /// # Panics
+ /// Panics if `range` has a range bound that is larger than
+ /// the matched item count
+ pub fn matched_items(
+ &self,
+ range: impl RangeBounds<u32>,
+ ) -> impl ExactSizeIterator<Item = Item<'_, T>> + DoubleEndedIterator + '_ {
+ // TODO: use TAIT
+ let start = match range.start_bound() {
+ Bound::Included(&start) => start as usize,
+ Bound::Excluded(&start) => start as usize + 1,
+ Bound::Unbounded => 0,
+ };
+ let end = match range.end_bound() {
+ Bound::Included(&end) => end as usize + 1,
+ Bound::Excluded(&end) => end as usize,
+ Bound::Unbounded => self.matches.len(),
+ };
+ self.matches[start..end]
+ .iter()
+ .map(|&m| unsafe { self.items.get_unchecked(m.idx) })
+ }
+
+ /// Returns a reference to the item at the given index.
+ ///
+ /// # Safety
+ ///
+ /// Item at `index` must be initialized. That means you must have observed a
+ /// match with the corresponding index in this exact snapshot. Observing
+ /// a higher index is not enough as item indices can be non-contigously
+ /// initialized
+ #[inline]
+ pub unsafe fn get_item_unchecked(&self, index: u32) -> Item<'_, T> {
+ self.items.get_unchecked(index)
+ }
+
+ /// Returns a reference to the item at the given index.
+ ///
+ /// Returns `None` if the given `index` is not initialized. This function
+ /// is only guarteed to return `Some` for item indices that can be found in
+ /// the `matches` of this struct. Both smaller and larger indices may return
+ /// `None`.
+ #[inline]
+ pub fn get_item(&self, index: u32) -> Option<Item<'_, T>> {
+ self.items.get(index)
+ }
+
+ /// Return the matches corresponding to this snapshot.
+ #[inline]
+ pub fn matches(&self) -> &[Match] {
+ &self.matches
+ }
+
+ /// A convenience function to return the [`Item`] corresponding to the
+ /// `n`th match.
+ ///
+ /// Returns `None` if `n` is greater than or equal to the match count.
+ #[inline]
+ pub fn get_matched_item(&self, n: u32) -> Option<Item<'_, T>> {
+ // SAFETY: A match index is guaranteed to corresponding to a valid global index in this
+ // snapshot.
+ unsafe { Some(self.get_item_unchecked(self.matches.get(n as usize)?.idx)) }
+ }
+}
+
+#[repr(u8)]
+#[derive(Clone, Copy, PartialEq, Eq)]
+enum State {
+ Init,
+ /// items have been cleared but snapshot and items are still outdated
+ Cleared,
+ /// items are fresh
+ Fresh,
+}
+
+impl State {
+ fn matcher_item_refs(self) -> usize {
+ match self {
+ State::Cleared => 1,
+ State::Init | State::Fresh => 2,
+ }
+ }
+
+ fn canceled(self) -> bool {
+ self != State::Fresh
+ }
+
+ fn cleared(self) -> bool {
+ self != State::Fresh
+ }
+}
+
+/// A high level matcher worker that quickly computes matches in a background
+/// threadpool.
+pub struct Nucleo<T: Sync + Send + 'static> {
+ // the way the API is build we totally don't actually need these to be Arcs
+ // but this lets us avoid some unsafe
+ canceled: Arc<AtomicBool>,
+ should_notify: Arc<AtomicBool>,
+ worker: Arc<Mutex<Worker<T>>>,
+ pool: ThreadPool,
+ state: State,
+ items: Arc<boxcar::Vec<T>>,
+ notify: Arc<(dyn Fn() + Sync + Send)>,
+ snapshot: Snapshot<T>,
+ /// The pattern matched by this matcher. To update the match pattern
+ /// [`MultiPattern::reparse`](`pattern::MultiPattern::reparse`) should be used.
+ /// Note that the matcher worker will only become aware of the new pattern
+ /// after a call to [`tick`](Nucleo::tick).
+ pub pattern: MultiPattern,
+}
+
+impl<T: Sync + Send + 'static> Nucleo<T> {
+ /// Constructs a new `nucleo` worker threadpool with the provided `config`.
+ ///
+ /// `notify` is called everytime new information is available and
+ /// [`tick`](Nucleo::tick) should be called. Note that `notify` is not
+ /// debounced, that should be handled by the downstream crate (for example
+ /// debouncing to only redraw at most every 1/60 seconds).
+ ///
+ /// If `None` is passed for the number of worker threads, nucleo will use
+ /// one thread per hardware thread.
+ ///
+ /// Nucleo can match items with multiple orthogonal properties. `columns`
+ /// indicates how many matching columns each item (and the pattern) has. The
+ /// number of columns cannot be changed after construction.
+ pub fn new(
+ config: Config,
+ notify: Arc<(dyn Fn() + Sync + Send)>,
+ num_threads: Option<usize>,
+ columns: u32,
+ ) -> Self {
+ let (pool, worker) = Worker::new(num_threads, config, notify.clone(), columns);
+ Self {
+ canceled: worker.canceled.clone(),
+ should_notify: worker.should_notify.clone(),
+ items: worker.items.clone(),
+ pool,
+ pattern: MultiPattern::new(columns as usize),
+ snapshot: Snapshot {
+ matches: Vec::with_capacity(2 * 1024),
+ pattern: MultiPattern::new(columns as usize),
+ item_count: 0,
+ items: worker.items.clone(),
+ },
+ worker: Arc::new(Mutex::new(worker)),
+ state: State::Init,
+ notify,
+ }
+ }
+
+ /// Returns the total number of active injectors
+ pub fn active_injectors(&self) -> usize {
+ Arc::strong_count(&self.items)
+ - self.state.matcher_item_refs()
+ - (Arc::ptr_eq(&self.snapshot.items, &self.items)) as usize
+ }
+
+ /// Returns a snapshot of the current matcher state.
+ pub fn snapshot(&self) -> &Snapshot<T> {
+ &self.snapshot
+ }
+
+ /// Returns an injector that can be used for adding candidates to the matcher.
+ pub fn injector(&self) -> Injector<T> {
+ Injector {
+ items: self.items.clone(),
+ notify: self.notify.clone(),
+ }
+ }
+
+ /// Restart the the item stream. Removes all items and disconnects all
+ /// previously created injectors from this instance. If `clear_snapshot`
+ /// is `true` then all items and matched are removed from the [`Snapshot`]
+ /// immediately. Otherwise the snapshot will keep the current matches until
+ /// the matcher has run again.
+ ///
+ /// # Note
+ ///
+ /// The injectors will continue to function but they will not affect this
+ /// instance anymore. The old items will only be dropped when all injectors
+ /// were dropped.
+ pub fn restart(&mut self, clear_snapshot: bool) {
+ self.canceled.store(true, Ordering::Relaxed);
+ self.items = Arc::new(boxcar::Vec::with_capacity(1024, self.items.columns()));
+ self.state = State::Cleared;
+ if clear_snapshot {
+ self.snapshot.clear(self.items.clone());
+ }
+ }
+
+ /// Update the internal configuration.
+ pub fn update_config(&mut self, config: Config) {
+ self.worker.lock().update_config(config)
+ }
+
+ // Set whether the matcher should sort search results by score after
+ // matching. Defaults to true.
+ pub fn sort_results(&mut self, sort_results: bool) {
+ self.worker.lock().sort_results(sort_results)
+ }
+
+ // Set whether the matcher should reverse the order of the input.
+ // Defaults to false.
+ pub fn reverse_items(&mut self, reverse_items: bool) {
+ self.worker.lock().reverse_items(reverse_items)
+ }
+
+ /// The main way to interact with the matcher, this should be called
+ /// regularly (for example each time a frame is rendered). To avoid
+ /// excessive redraws this method will wait `timeout` milliseconds for the
+ /// worker thread to finish. It is recommend to set the timeout to 10ms.
+ pub fn tick(&mut self, timeout: u64) -> Status {
+ self.should_notify.store(false, atomic::Ordering::Relaxed);
+ let status = self.pattern.status();
+ let canceled = status != pattern::Status::Unchanged || self.state.canceled();
+ let mut res = self.tick_inner(timeout, canceled, status);
+ if !canceled {
+ return res;
+ }
+ self.state = State::Fresh;
+ let status2 = self.tick_inner(timeout, false, pattern::Status::Unchanged);
+ res.changed |= status2.changed;
+ res.running = status2.running;
+ res
+ }
+
+ fn tick_inner(&mut self, timeout: u64, canceled: bool, status: pattern::Status) -> Status {
+ let mut inner = if canceled {
+ self.pattern.reset_status();
+ self.canceled.store(true, atomic::Ordering::Relaxed);
+ self.worker.lock_arc()
+ } else {
+ let Some(worker) = self.worker.try_lock_arc_for(Duration::from_millis(timeout)) else {
+ self.should_notify.store(true, Ordering::Release);
+ return Status {
+ changed: false,
+ running: true,
+ };
+ };
+ worker
+ };
+
+ let changed = inner.running;
+
+ let running = canceled || self.items.count() > inner.item_count();
+ if inner.running {
+ inner.running = false;
+ if !inner.was_canceled && !self.state.canceled() {
+ self.snapshot.update(&inner)
+ }
+ }
+ if running {
+ inner.pattern.clone_from(&self.pattern);
+ self.canceled.store(false, atomic::Ordering::Relaxed);
+ if !canceled {
+ self.should_notify.store(true, atomic::Ordering::Release);
+ }
+ let cleared = self.state.cleared();
+ if cleared {
+ inner.items = self.items.clone();
+ }
+ self.pool
+ .spawn(move || unsafe { inner.run(status, cleared) })
+ }
+ Status { changed, running }
+ }
+}
+
+impl<T: Sync + Send> Drop for Nucleo<T> {
+ fn drop(&mut self) {
+ // we ensure the worker quits before dropping items to ensure that
+ // the worker can always assume the items outlive it
+ self.canceled.store(true, atomic::Ordering::Relaxed);
+ let lock = self.worker.try_lock_for(Duration::from_secs(1));
+ if lock.is_none() {
+ unreachable!("thread pool failed to shutdown properly")
+ }
+ }
+}
diff --git a/src/par_sort.rs b/src/par_sort.rs
new file mode 100644
index 00000000..92f716cc
--- /dev/null
+++ b/src/par_sort.rs
@@ -0,0 +1,895 @@
+//! Parallel quicksort.
+//!
+//! This implementation is copied verbatim from `std::slice::sort_unstable` and then parallelized.
+//! The only difference from the original is that calls to `recurse` are executed in parallel using
+//! `rayon_core::join`.
+//! Further modified for nucleo to allow canceling the sort
+
+// Copyright (c) 2010 The Rust Project Developers
+//
+// Permission is hereby granted, free of charge, to any
+// person obtaining a copy of this software and associated
+// documentation files (the "Software"), to deal in the
+// Software without restriction, including without
+// limitation the rights to use, copy, modify, merge,
+// publish, distribute, sublicense, and/or sell copies of
+// the Software, and to permit persons to whom the Software
+// is furnished to do so, subject to the following
+// conditions:
+//
+// The above copyright notice and this permission notice
+// shall be included in all copies or substantial portions
+// of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+// DEALINGS IN THE SOFTWARE.
+
+use std::cmp;
+use std::mem::{self, MaybeUninit};
+use std::ptr;
+use std::sync::atomic::{self, AtomicBool};
+
+/// When dropped, copies from `src` into `dest`.
+struct CopyOnDrop<T> {
+ src: *const T,
+ dest: *mut T,
+}
+
+impl<T> Drop for CopyOnDrop<T> {
+ fn drop(&mut self) {
+ // SAFETY: This is a helper class.
+ // Please refer to its usage for correctness.
+ // Namely, one must be sure that `src` and `dst` does not overlap as required by `ptr::copy_nonoverlapping`.
+ unsafe {
+ ptr::copy_nonoverlapping(self.src, self.dest, 1);
+ }
+ }
+}
+
+/// Shifts the first element to the right until it encounters a greater or equal element.
+fn shift_head<T, F>(v: &mut [T], is_less: &F)
+where
+ F: Fn(&T, &T) -> bool,
+{
+ let len = v.len();
+ // SAFETY: The unsafe operations below involves indexing without a bounds check (by offsetting a
+ // pointer) and copying memory (`ptr::copy_nonoverlapping`).
+ //
+ // a. Indexing:
+ // 1. We checked the size of the array to >=2.
+ // 2. All the indexing that we will do is always between {0 <= index < len} at most.
+ //
+ // b. Memory copying
+ // 1. We are obtaining pointers to references which are guaranteed to be valid.
+ // 2. They cannot overlap because we obtain pointers to difference indices of the slice.
+ // Namely, `i` and `i-1`.
+ // 3. If the slice is properly aligned, the elements are properly aligned.
+ // It is the caller's responsibility to make sure the slice is properly aligned.
+ //
+ // See comments below for further detail.
+ unsafe {
+ // If the first two elements are out-of-order...
+ if len >= 2 && is_less(v.get_unchecked(1), v.get_unchecked(0)) {
+ // Read the first element into a stack-allocated variable. If a following comparison
+ // operation panics, `hole` will get dropped and automatically write the element back
+ // into the slice.
+ let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(0)));
+ let v = v.as_mut_ptr();
+ let mut hole = CopyOnDrop {
+ src: &*tmp,
+ dest: v.add(1),
+ };
+ ptr::copy_nonoverlapping(v.add(1), v.add(0), 1);
+
+ for i in 2..len {
+ if !is_less(&*v.add(i), &*tmp) {
+ break;
+ }
+
+ // Move `i`-th element one place to the left, thus shifting the hole to the right.
+ ptr::copy_nonoverlapping(v.add(i), v.add(i - 1), 1);
+ hole.dest = v.add(i);
+ }
+ // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
+ }
+ }
+}
+
+/// Shifts the last element to the left until it encounters a smaller or equal element.
+fn shift_tail<T, F>(v: &mut [T], is_less: &F)
+where
+ F: Fn(&T, &T) -> bool,
+{
+ let len = v.len();
+ // SAFETY: The unsafe operations below involves indexing without a bound check (by offsetting a
+ // pointer) and copying memory (`ptr::copy_nonoverlapping`).
+ //
+ // a. Indexing:
+ // 1. We checked the size of the array to >= 2.
+ // 2. All the indexing that we will do is always between `0 <= index < len-1` at most.
+ //
+ // b. Memory copying
+ // 1. We are obtaining pointers to references which are guaranteed to be valid.
+ // 2. They cannot overlap because we obtain pointers to difference indices of the slice.
+ // Namely, `i` and `i+1`.
+ // 3. If the slice is properly aligned, the elements are properly aligned.
+ // It is the caller's responsibility to make sure the slice is properly aligned.
+ //
+ // See comments below for further detail.
+ unsafe {
+ // If the last two elements are out-of-order...
+ if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) {
+ // Read the last element into a stack-allocated variable. If a following comparison
+ // operation panics, `hole` will get dropped and automatically write the element back
+ // into the slice.
+ let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1)));
+ let v = v.as_mut_ptr();
+ let mut hole = CopyOnDrop {
+ src: &*tmp,
+ dest: v.add(len - 2),
+ };
+ ptr::copy_nonoverlapping(v.add(len - 2), v.add(len - 1), 1);
+
+ for i in (0..len - 2).rev() {
+ if !is_less(&*tmp, &*v.add(i)) {
+ break;
+ }
+
+ // Move `i`-th element one place to the right, thus shifting the hole to the left.
+ ptr::copy_nonoverlapping(v.add(i), v.add(i + 1), 1);
+ hole.dest = v.add(i);
+ }
+ // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
+ }
+ }
+}
+
+/// Partially sorts a slice by shifting several out-of-order elements around.
+///
+/// Returns `true` if the slice is sorted at the end. This function is *O*(*n*) worst-case.
+#[cold]
+fn partial_insertion_sort<T, F>(v: &mut [T], is_less: &F) -> bool
+where
+ F: Fn(&T, &T) -> bool,
+{
+ // Maximum number of adjacent out-of-order pairs that will get shifted.
+ const MAX_STEPS: usize = 5;
+ // If the slice is shorter than this, don't shift any elements.
+ const SHORTEST_SHIFTING: usize = 50;
+
+ let len = v.len();
+ let mut i = 1;
+
+ for _ in 0..MAX_STEPS {
+ // SAFETY: We already explicitly did the bound checking with `i < len`.
+ // All our subsequent indexing is only in the range `0 <= index < len`
+ unsafe {
+ // Find the next pair of adjacent out-of-order elements.
+ while i < len && !is_less(v.get_unchecked(i), v.get_unchecked(i - 1)) {
+ i += 1;
+ }
+ }
+
+ // Are we done?
+ if i == len {
+ return true;
+ }
+
+ // Don't shift elements on short arrays, that has a performance cost.
+ if len < SHORTEST_SHIFTING {
+ return false;
+ }
+
+ // Swap the found pair of elements. This puts them in correct order.
+ v.swap(i - 1, i);
+
+ // Shift the smaller element to the left.
+ shift_tail(&mut v[..i], is_less);
+ // Shift the greater element to the right.
+ shift_head(&mut v[i..], is_less);
+ }
+
+ // Didn't manage to sort the slice in the limited number of steps.
+ false
+}
+
+/// Sorts a slice using insertion sort, which is *O*(*n*^2) worst-case.
+fn insertion_sort<T, F>(v: &mut [T], is_less: &F)
+where
+ F: Fn(&T, &T) -> bool,
+{
+ for i in 1..v.len() {
+ shift_tail(&mut v[..i + 1], is_less);
+ }
+}
+
+/// Sorts `v` using heapsort, which guarantees *O*(*n* \* log(*n*)) worst-case.
+#[cold]
+fn heapsort<T, F>(v: &mut [T], is_less: &F)
+where
+ F: Fn(&T, &T) -> bool,
+{
+ // This binary heap respects the invariant `parent >= child`.
+ let sift_down = |v: &mut [T], mut node| {
+ loop {
+ // Children of `node`.
+ let mut child = 2 * node + 1;
+ if child >= v.len() {
+ break;
+ }
+
+ // Choose the greater child.
+ if child + 1 < v.len() && is_less(&v[child], &v[child + 1]) {
+ child += 1;
+ }
+
+ // Stop if the invariant holds at `node`.
+ if !is_less(&v[node], &v[child]) {
+ break;
+ }
+
+ // Swap `node` with the greater child, move one step down, and continue sifting.
+ v.swap(node, child);
+ node = child;
+ }
+ };
+
+ // Build the heap in linear time.
+ for i in (0..v.len() / 2).rev() {
+ sift_down(v, i);
+ }
+
+ // Pop maximal elements from the heap.
+ for i in (1..v.len()).rev() {
+ v.swap(0, i);
+ sift_down(&mut v[..i], 0);
+ }
+}
+
+/// Partitions `v` into elements smaller than `pivot`, followed by elements greater than or equal
+/// to `pivot`.
+///
+/// Returns the number of elements smaller than `pivot`.
+///
+/// Partitioning is performed block-by-block in order to minimize the cost of branching operations.
+/// This idea is presented in the [BlockQuicksort][pdf] paper.
+///
+/// [pdf]: https://drops.dagstuhl.de/opus/volltexte/2016/6389/pdf/LIPIcs-ESA-2016-38.pdf
+fn partition_in_blocks<T, F>(v: &mut [T], pivot: &T, is_less: &F) -> usize
+where
+ F: Fn(&T, &T) -> bool,
+{
+ // Number of elements in a typical block.
+ const BLOCK: usize = 128;
+
+ // The partitioning algorithm repeats the following steps until completion:
+ //
+ // 1. Trace a block from the left side to identify elements greater than or equal to the pivot.
+ // 2. Trace a block from the right side to identify elements smaller than the pivot.
+ // 3. Exchange the identified elements between the left and right side.
+ //
+ // We keep the following variables for a block of elements:
+ //
+ // 1. `block` - Number of elements in the block.
+ // 2. `start` - Start pointer into the `offsets` array.
+ // 3. `end` - End pointer into the `offsets` array.
+ // 4. `offsets - Indices of out-of-order elements within the block.
+
+ // The current block on the left side (from `l` to `l.add(block_l)`).
+ let mut l = v.as_mut_ptr();
+ let mut block_l = BLOCK;
+ let mut start_l = ptr::null_mut();
+ let mut end_l = ptr::null_mut();
+ let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK];
+
+ // The current block on the right side (from `r.sub(block_r)` to `r`).
+ // SAFETY: The documentation for .add() specifically mention that `vec.as_ptr().add(vec.len())` is always safe`
+ let mut r = unsafe { l.add(v.len()) };
+ let mut block_r = BLOCK;
+ let mut start_r = ptr::null_mut();
+ let mut end_r = ptr::null_mut();
+ let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK];
+
+ // FIXME: When we get VLAs, try creating one array of length `min(v.len(), 2 * BLOCK)` rather
+ // than two fixed-size arrays of length `BLOCK`. VLAs might be more cache-efficient.
+
+ // Returns the number of elements between pointers `l` (inclusive) and `r` (exclusive).
+ fn width<T>(l: *mut T, r: *mut T) -> usize {
+ assert!(mem::size_of::<T>() > 0);
+ // FIXME: this should *likely* use `offset_from`, but more
+ // investigation is needed (including running tests in miri).
+ // TODO unstable: (r.addr() - l.addr()) / mem::size_of::<T>()
+ (r as usize - l as usize) / mem::size_of::<T>()
+ }
+
+ loop {
+ // We are done with partitioning block-by-block when `l` and `r` get very close. Then we do
+ // some patch-up work in order to partition the remaining elements in between.
+ let is_done = width(l, r) <= 2 * BLOCK;
+
+ if is_done {
+ // Number of remaining elements (still not compared to the pivot).
+ let mut rem = width(l, r);
+ if start_l < end_l || start_r < end_r {
+ rem -= BLOCK;
+ }
+
+ // Adjust block sizes so that the left and right block don't overlap, but get perfectly
+ // aligned to cover the whole remaining gap.
+ if start_l < end_l {
+ block_r = rem;
+ } else if start_r < end_r {
+ block_l = rem;
+ } else {
+ // There were the same number of elements to switch on both blocks during the last
+ // iteration, so there are no remaining elements on either block. Cover the remaining
+ // items with roughly equally-sized blocks.
+ block_l = rem / 2;
+ block_r = rem - block_l;
+ }
+ debug_assert!(block_l <= BLOCK && block_r <= BLOCK);
+ debug_assert!(width(l, r) == block_l + block_r);
+ }
+
+ if start_l == end_l {
+ // Trace `block_l` elements from the left side.
+ // TODO unstable: start_l = MaybeUninit::slice_as_mut_ptr(&mut offsets_l);
+ start_l = offsets_l.as_mut_ptr() as *mut u8;
+ end_l = start_l;
+ let mut elem = l;
+
+ for i in 0..block_l {
+ // SAFETY: The unsafety operations below involve the usage of the `offset`.
+ // According to the conditions required by the function, we satisfy them because:
+ // 1. `offsets_l` is stack-allocated, and thus considered separate allocated object.
+ // 2. The function `is_less` returns a `bool`.
+ // Casting a `bool` will never overflow `isize`.
+ // 3. We have guaranteed that `block_l` will be `<= BLOCK`.
+ // Plus, `end_l` was initially set to the begin pointer of `offsets_` which was declared on the stack.
+ // Thus, we know that even in the worst case (all invocations of `is_less` returns false) we will only be at most 1 byte pass the end.
+ // Another unsafety operation here is dereferencing `elem`.
+ // However, `elem` was initially the begin pointer to the slice which is always valid.
+ unsafe {
+ // Branchless comparison.
+ *end_l = i as u8;
+ end_l = end_l.offset(!is_less(&*elem, pivot) as isize);
+ elem = elem.offset(1);
+ }
+ }
+ }
+
+ if start_r == end_r {
+ // Trace `block_r` elements from the right side.
+ // TODO unstable: start_r = MaybeUninit::slice_as_mut_ptr(&mut offsets_r);
+ start_r = offsets_r.as_mut_ptr() as *mut u8;
+ end_r = start_r;
+ let mut elem = r;
+
+ for i in 0..block_r {
+ // SAFETY: The unsafety operations below involve the usage of the `offset`.
+ // According to the conditions required by the function, we satisfy them because:
+ // 1. `offsets_r` is stack-allocated, and thus considered separate allocated object.
+ // 2. The function `is_less` returns a `bool`.
+ // Casting a `bool` will never overflow `isize`.
+ // 3. We have guaranteed that `block_r` will be `<= BLOCK`.
+ // Plus, `end_r` was initially set to the begin pointer of `offsets_` which was declared on the stack.
+ // Thus, we know that even in the worst case (all invocations of `is_less` returns true) we will only be at most 1 byte pass the end.
+ // Another unsafety operation here is dereferencing `elem`.
+ // However, `elem` was initially `1 * sizeof(T)` past the end and we decrement it by `1 * sizeof(T)` before accessing it.
+ // Plus, `block_r` was asserted to be less than `BLOCK` and `elem` will therefore at most be pointing to the beginning of the slice.
+ unsafe {
+ // Branchless comparison.
+ elem = elem.offset(-1);
+ *end_r = i as u8;
+ end_r = end_r.offset(is_less(&*elem, pivot) as isize);
+ }
+ }
+ }
+
+ // Number of out-of-order elements to swap between the left and right side.
+ let count = cmp::min(width(start_l, end_l), width(start_r, end_r));
+
+ if count > 0 {
+ macro_rules! left {
+ () => {
+ l.offset(*start_l as isize)
+ };
+ }
+ macro_rules! right {
+ () => {
+ r.offset(-(*start_r as isize) - 1)
+ };
+ }
+
+ // Instead of swapping one pair at the time, it is more efficient to perform a cyclic
+ // permutation. This is not strictly equivalent to swapping, but produces a similar
+ // result using fewer memory operations.
+
+ // SAFETY: The use of `ptr::read` is valid because there is at least one element in
+ // both `offsets_l` and `offsets_r`, so `left!` is a valid pointer to read from.
+ //
+ // The uses of `left!` involve calls to `offset` on `l`, which points to the
+ // beginning of `v`. All the offsets pointed-to by `start_l` are at most `block_l`, so
+ // these `offset` calls are safe as all reads are within the block. The same argument
+ // applies for the uses of `right!`.
+ //
+ // The calls to `start_l.offset` are valid because there are at most `count-1` of them,
+ // plus the final one at the end of the unsafe block, where `count` is the minimum number
+ // of collected offsets in `offsets_l` and `offsets_r`, so there is no risk of there not
+ // being enough elements. The same reasoning applies to the calls to `start_r.offset`.
+ //
+ // The calls to `copy_nonoverlapping` are safe because `left!` and `right!` are guaranteed
+ // not to overlap, and are valid because of the reasoning above.
+ unsafe {
+ let tmp = ptr::read(left!());
+ ptr::copy_nonoverlapping(right!(), left!(), 1);
+
+ for _ in 1..count {
+ start_l = start_l.offset(1);
+ ptr::copy_nonoverlapping(left!(), right!(), 1);
+ start_r = start_r.offset(1);
+ ptr::copy_nonoverlapping(right!(), left!(), 1);
+ }
+
+ ptr::copy_nonoverlapping(&tmp, right!(), 1);
+ mem::forget(tmp);
+ start_l = start_l.offset(1);
+ start_r = start_r.offset(1);
+ }
+ }
+
+ if start_l == end_l {
+ // All out-of-order elements in the left block were moved. Move to the next block.
+
+ // block-width-guarantee
+ // SAFETY: if `!is_done` then the slice width is guaranteed to be at least `2*BLOCK` wide. There
+ // are at most `BLOCK` elements in `offsets_l` because of its size, so the `offset` operation is
+ // safe. Otherwise, the debug assertions in the `is_done` case guarantee that
+ // `width(l, r) == block_l + block_r`, namely, that the block sizes have been adjusted to account
+ // for the smaller number of remaining elements.
+ l = unsafe { l.add(block_l) };
+ }
+
+ if start_r == end_r {
+ // All out-of-order elements in the right block were moved. Move to the previous block.
+
+ // SAFETY: Same argument as [block-width-guarantee]. Either this is a full block `2*BLOCK`-wide,
+ // or `block_r` has been adjusted for the last handful of elements.
+ r = unsafe { r.offset(-(block_r as isize)) };
+ }
+
+ if is_done {
+ break;
+ }
+ }
+
+ // All that remains now is at most one block (either the left or the right) with out-of-order
+ // elements that need to be moved. Such remaining elements can be simply shifted to the end
+ // within their block.
+
+ if start_l < end_l {
+ // The left block remains.
+ // Move its remaining out-of-order elements to the far right.
+ debug_assert_eq!(width(l, r), block_l);
+ while start_l < end_l {
+ // remaining-elements-safety
+ // SAFETY: while the loop condition holds there are still elements in `offsets_l`, so it
+ // is safe to point `end_l` to the previous element.
+ //
+ // The `ptr::swap` is safe if both its arguments are valid for reads and writes:
+ // - Per the debug assert above, the distance between `l` and `r` is `block_l`
+ // elements, so there can be at most `block_l` remaining offsets between `start_l`
+ // and `end_l`. This means `r` will be moved at most `block_l` steps back, which
+ // makes the `r.offset` calls valid (at that point `l == r`).
+ // - `offsets_l` contains valid offsets into `v` collected during the partitioning of
+ // the last block, so the `l.offset` calls are valid.
+ unsafe {
+ end_l = end_l.offset(-1);
+ ptr::swap(l.offset(*end_l as isize), r.offset(-1));
+ r = r.offset(-1);
+ }
+ }
+ width(v.as_mut_ptr(), r)
+ } else if start_r < end_r {
+ // The right block remains.
+ // Move its remaining out-of-order elements to the far left.
+ debug_assert_eq!(width(l, r), block_r);
+ while start_r < end_r {
+ // SAFETY: See the reasoning in [remaining-elements-safety].
+ unsafe {
+ end_r = end_r.offset(-1);
+ ptr::swap(l, r.offset(-(*end_r as isize) - 1));
+ l = l.offset(1);
+ }
+ }
+ width(v.as_mut_ptr(), l)
+ } else {
+ // Nothing else to do, we're done.
+ width(v.as_mut_ptr(), l)
+ }
+}
+
+/// Partitions `v` into elements smaller than `v[pivot]`, followed by elements greater than or
+/// equal to `v[pivot]`.
+///
+/// Returns a tuple of:
+///
+/// 1. Number of elements smaller than `v[pivot]`.
+/// 2. True if `v` was already partitioned.
+fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> (usize, bool)
+where
+ F: Fn(&T, &T) -> bool,
+{
+ let (mid, was_partitioned) = {
+ // Place the pivot at the beginning of slice.
+ v.swap(0, pivot);
+ let (pivot, v) = v.split_at_mut(1);
+ let pivot = &mut pivot[0];
+
+ // Read the pivot into a stack-allocated variable for efficiency. If a following comparison
+ // operation panics, the pivot will be automatically written back into the slice.
+
+ // SAFETY: `pivot` is a reference to the first element of `v`, so `ptr::read` is safe.
+ let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
+ let _pivot_guard = CopyOnDrop {
+ src: &*tmp,
+ dest: pivot,
+ };
+ let pivot = &*tmp;
+
+ // Find the first pair of out-of-order elements.
+ let mut l = 0;
+ let mut r = v.len();
+
+ // SAFETY: The unsafety below involves indexing an array.
+ // For the first one: We already do the bounds checking here with `l < r`.
+ // For the second one: We initially have `l == 0` and `r == v.len()` and we checked that `l < r` at every indexing operation.
+ // From here we know that `r` must be at least `r == l` which was shown to be valid from the first one.
+ unsafe {
+ // Find the first element greater than or equal to the pivot.
+ while l < r && is_less(v.get_unchecked(l), pivot) {
+ l += 1;
+ }
+
+ // Find the last element smaller that the pivot.
+ while l < r && !is_less(v.get_unchecked(r - 1), pivot) {
+ r -= 1;
+ }
+ }
+
+ (
+ l + partition_in_blocks(&mut v[l..r], pivot, is_less),
+ l >= r,
+ )
+
+ // `_pivot_guard` goes out of scope and writes the pivot (which is a stack-allocated
+ // variable) back into the slice where it originally was. This step is critical in ensuring
+ // safety!
+ };
+
+ // Place the pivot between the two partitions.
+ v.swap(0, mid);
+
+ (mid, was_partitioned)
+}
+
+/// Partitions `v` into elements equal to `v[pivot]` followed by elements greater than `v[pivot]`.
+///
+/// Returns the number of elements equal to the pivot. It is assumed that `v` does not contain
+/// elements smaller than the pivot.
+fn partition_equal<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> usize
+where
+ F: Fn(&T, &T) -> bool,
+{
+ // Place the pivot at the beginning of slice.
+ v.swap(0, pivot);
+ let (pivot, v) = v.split_at_mut(1);
+ let pivot = &mut pivot[0];
+
+ // Read the pivot into a stack-allocated variable for efficiency. If a following comparison
+ // operation panics, the pivot will be automatically written back into the slice.
+ // SAFETY: The pointer here is valid because it is obtained from a reference to a slice.
+ let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
+ let _pivot_guard = CopyOnDrop {
+ src: &*tmp,
+ dest: pivot,
+ };
+ let pivot = &*tmp;
+
+ // Now partition the slice.
+ let mut l = 0;
+ let mut r = v.len();
+ loop {
+ // SAFETY: The unsafety below involves indexing an array.
+ // For the first one: We already do the bounds checking here with `l < r`.
+ // For the second one: We initially have `l == 0` and `r == v.len()` and we checked that `l < r` at every indexing operation.
+ // From here we know that `r` must be at least `r == l` which was shown to be valid from the first one.
+ unsafe {
+ // Find the first element greater than the pivot.
+ while l < r && !is_less(pivot, v.get_unchecked(l)) {
+ l += 1;
+ }
+
+ // Find the last element equal to the pivot.
+ while l < r && is_less(pivot, v.get_unchecked(r - 1)) {
+ r -= 1;
+ }
+
+ // Are we done?
+ if l >= r {
+ break;
+ }
+
+ // Swap the found pair of out-of-order elements.
+ r -= 1;
+ let ptr = v.as_mut_ptr();
+ ptr::swap(ptr.add(l), ptr.add(r));
+ l += 1;
+ }
+ }
+
+ // We found `l` elements equal to the pivot. Add 1 to account for the pivot itself.
+ l + 1
+
+ // `_pivot_guard` goes out of scope and writes the pivot (which is a stack-allocated variable)
+ // back into the slice where it originally was. This step is critical in ensuring safety!
+}
+
+/// Scatters some elements around in an attempt to break patterns that might cause imbalanced
+/// partitions in quicksort.
+#[cold]
+fn break_patterns<T>(v: &mut [T]) {
+ let len = v.len();
+ if len >= 8 {
+ // Pseudorandom number generator from the "Xorshift RNGs" paper by George Marsaglia.
+ let mut random = len as u32;
+ let mut gen_u32 = || {
+ random ^= random << 13;
+ random ^= random >> 17;
+ random ^= random << 5;
+ random
+ };
+ let mut gen_usize = || {
+ if usize::BITS <= 32 {
+ gen_u32() as usize
+ } else {
+ (((gen_u32() as u64) << 32) | (gen_u32() as u64)) as usize
+ }
+ };
+
+ // Take random numbers modulo this number.
+ // The number fits into `usize` because `len` is not greater than `isize::MAX`.
+ let modulus = len.next_power_of_two();
+
+ // Some pivot candidates will be in the nearby of this index. Let's randomize them.
+ let pos = len / 4 * 2;
+
+ for i in 0..3 {
+ // Generate a random number modulo `len`. However, in order to avoid costly operations
+ // we first take it modulo a power of two, and then decrease by `len` until it fits
+ // into the range `[0, len - 1]`.
+ let mut other = gen_usize() & (modulus - 1);
+
+ // `other` is guaranteed to be less than `2 * len`.
+ if other >= len {
+ other -= len;
+ }
+
+ v.swap(pos - 1 + i, other);
+ }
+ }
+}
+
+/// Chooses a pivot in `v` and returns the index and `true` if the slice is likely already sorted.
+///
+/// Elements in `v` might be reordered in the process.
+fn choose_pivot<T, F>(v: &mut [T], is_less: &F) -> (usize, bool)
+where
+ F: Fn(&T, &T) -> bool,
+{
+ // Minimum length to choose the median-of-medians method.
+ // Shorter slices use the simple median-of-three method.
+ const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
+ // Maximum number of swaps that can be performed in this function.
+ const MAX_SWAPS: usize = 4 * 3;
+
+ let len = v.len();
+
+ // Three indices near which we are going to choose a pivot.
+ #[allow(clippy::identity_op)]
+ let mut a = len / 4 * 1;
+ let mut b = len / 4 * 2;
+ let mut c = len / 4 * 3;
+
+ // Counts the total number of swaps we are about to perform while sorting indices.
+ let mut swaps = 0;
+
+ if len >= 8 {
+ // Swaps indices so that `v[a] <= v[b]`.
+ // SAFETY: `len >= 8` so there are at least two elements in the neighborhoods of
+ // `a`, `b` and `c`. This means the three calls to `sort_adjacent` result in
+ // corresponding calls to `sort3` with valid 3-item neighborhoods around each
+ // pointer, which in turn means the calls to `sort2` are done with valid
+ // references. Thus the `v.get_unchecked` calls are safe, as is the `ptr::swap`
+ // call.
+ let mut sort2 = |a: &mut usize, b: &mut usize| unsafe {
+ if is_less(v.get_unchecked(*b), v.get_unchecked(*a)) {
+ ptr::swap(a, b);
+ swaps += 1;
+ }
+ };
+
+ // Swaps indices so that `v[a] <= v[b] <= v[c]`.
+ let mut sort3 = |a: &mut usize, b: &mut usize, c: &mut usize| {
+ sort2(a, b);
+ sort2(b, c);
+ sort2(a, b);
+ };
+
+ if len >= SHORTEST_MEDIAN_OF_MEDIANS {
+ // Finds the median of `v[a - 1], v[a], v[a + 1]` and stores the index into `a`.
+ let mut sort_adjacent = |a: &mut usize| {
+ let tmp = *a;
+ sort3(&mut (tmp - 1), a, &mut (tmp + 1));
+ };
+
+ // Find medians in the neighborhoods of `a`, `b`, and `c`.
+ sort_adjacent(&mut a);
+ sort_adjacent(&mut b);
+ sort_adjacent(&mut c);
+ }
+
+ // Find the median among `a`, `b`, and `c`.
+ sort3(&mut a, &mut b, &mut c);
+ }
+
+ if swaps < MAX_SWAPS {
+ (b, swaps == 0)
+ } else {
+ // The maximum number of swaps was performed. Chances are the slice is descending or mostly
+ // descending, so reversing will probably help sort it faster.
+ v.reverse();
+ (len - 1 - b, true)
+ }
+}
+
+/// Sorts `v` recursively.
+///
+/// If the slice had a predecessor in the original array, it is specified as `pred`.
+///
+/// `limit` is the number of allowed imbalanced partitions before switching to `heapsort`. If zero,
+/// this function will immediately switch to heapsort.
+fn recurse<'a, T, F>(
+ mut v: &'a mut [T],
+ is_less: &F,
+ mut pred: Option<&'a mut T>,
+ mut limit: u32,
+ canceled: &AtomicBool,
+) -> bool
+where
+ T: Send,
+ F: Fn(&T, &T) -> bool + Sync,
+{
+ // Slices of up to this length get sorted using insertion sort.
+ const MAX_INSERTION: usize = 20;
+ // If both partitions are up to this length, we continue sequentially. This number is as small
+ // as possible but so that the overhead of Rayon's task scheduling is still negligible.
+ const MAX_SEQUENTIAL: usize = 2000;
+
+ // True if the last partitioning was reasonably balanced.
+ let mut was_balanced = true;
+ // True if the last partitioning didn't shuffle elements (the slice was already partitioned).
+ let mut was_partitioned = true;
+
+ loop {
+ let len = v.len();
+
+ // Very short slices get sorted using insertion sort.
+ if len <= MAX_INSERTION {
+ insertion_sort(v, is_less);
+ return false;
+ }
+
+ // If too many bad pivot choices were made, simply fall back to heapsort in order to
+ // guarantee `O(n * log(n))` worst-case.
+ if limit == 0 {
+ heapsort(v, is_less);
+ return false;
+ }
+
+ // If the last partitioning was imbalanced, try breaking patterns in the slice by shuffling
+ // some elements around. Hopefully we'll choose a better pivot this time.
+ if !was_balanced {
+ break_patterns(v);
+ limit -= 1;
+ }
+
+ // Choose a pivot and try guessing whether the slice is already sorted.
+ let (pivot, likely_sorted) = choose_pivot(v, is_less);
+
+ // If the last partitioning was decently balanced and didn't shuffle elements, and if pivot
+ // selection predicts the slice is likely already sorted...
+ if was_balanced && was_partitioned && likely_sorted {
+ // Try identifying several out-of-order elements and shifting them to correct
+ // positions. If the slice ends up being completely sorted, we're done.
+ if partial_insertion_sort(v, is_less) {
+ return false;
+ }
+ }
+
+ // If the chosen pivot is equal to the predecessor, then it's the smallest element in the
+ // slice. Partition the slice into elements equal to and elements greater than the pivot.
+ // This case is usually hit when the slice contains many duplicate elements.
+ if let Some(ref p) = pred {
+ if !is_less(p, &v[pivot]) {
+ let mid = partition_equal(v, pivot, is_less);
+
+ // Continue sorting elements greater than the pivot.
+ v = &mut v[mid..];
+ continue;
+ }
+ }
+
+ // Partition the slice.
+ let (mid, was_p) = partition(v, pivot, is_less);
+ was_balanced = cmp::min(mid, len - mid) >= len / 8;
+ was_partitioned = was_p;
+
+ // Split the slice into `left`, `pivot`, and `right`.
+ let (left, right) = v.split_at_mut(mid);
+ let (pivot, right) = right.split_at_mut(1);
+ let pivot = &mut pivot[0];
+
+ if cmp::max(left.len(), right.len()) <= MAX_SEQUENTIAL {
+ // Recurse into the shorter side only in order to minimize the total number of recursive
+ // calls and consume less stack space. Then just continue with the longer side (this is
+ // akin to tail recursion).
+ if left.len() < right.len() {
+ recurse(left, is_less, pred, limit, canceled);
+ v = right;
+ pred = Some(pivot);
+ } else {
+ recurse(right, is_less, Some(pivot), limit, canceled);
+ v = left;
+ }
+ } else if canceled.load(atomic::Ordering::Relaxed) {
+ break true;
+ } else {
+ // Sort the left and right half in parallel.
+ let (canceled1, canceled2) = rayon::join(
+ || recurse(left, is_less, pred, limit, canceled),
+ || recurse(right, is_less, Some(pivot), limit, canceled),
+ );
+ break canceled1 | canceled2;
+ }
+ }
+}
+
+/// Sorts `v` using pattern-defeating quicksort in parallel.
+///
+/// The algorithm is unstable, in-place, and *O*(*n* \* log(*n*)) worst-case.
+pub(crate) fn par_quicksort<T, F>(v: &mut [T], is_less: F, canceled: &AtomicBool) -> bool
+where
+ T: Send,
+ F: Fn(&T, &T) -> bool + Sync,
+{
+ // Sorting has no meaningful behavior on zero-sized types.
+ if mem::size_of::<T>() == 0 {
+ return false;
+ }
+ if canceled.load(atomic::Ordering::Relaxed) {
+ return true;
+ }
+
+ // Limit the number of imbalanced partitions to `floor(log2(len)) + 1`.
+ let limit = usize::BITS - v.len().leading_zeros();
+
+ recurse(v, &is_less, None, limit, canceled)
+}
diff --git a/src/pattern.rs b/src/pattern.rs
new file mode 100644
index 00000000..816b0a31
--- /dev/null
+++ b/src/pattern.rs
@@ -0,0 +1,100 @@
+pub use nucleo_matcher::pattern::{Atom, AtomKind, CaseMatching, Normalization, Pattern};
+use nucleo_matcher::{Matcher, Utf32String};
+
+#[cfg(test)]
+mod tests;
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord, Default)]
+pub(crate) enum Status {
+ #[default]
+ Unchanged,
+ Update,
+ Rescore,
+}
+
+#[derive(Debug)]
+pub struct MultiPattern {
+ cols: Vec<(Pattern, Status)>,
+}
+
+impl Clone for MultiPattern {
+ fn clone(&self) -> Self {
+ Self {
+ cols: self.cols.clone(),
+ }
+ }
+
+ fn clone_from(&mut self, source: &Self) {
+ self.cols.clone_from(&source.cols)
+ }
+}
+
+impl MultiPattern {
+ /// Creates a multi pattern with `columns` empty column patterns.
+ pub fn new(columns: usize) -> Self {
+ Self {
+ cols: vec![Default::default(); columns],
+ }
+ }
+
+ /// Reparses a column. By specifying `append` the caller promises that text passed
+ /// to the previous `reparse` invocation is a prefix of `new_text`. This enables
+ /// additional optimizations but can lead to missing matches if an incorrect value
+ /// is passed.
+ pub fn reparse(
+ &mut self,
+ column: usize,
+ new_text: &str,
+ case_matching: CaseMatching,
+ normalization: Normalization,
+ append: bool,
+ ) {
+ let old_status = self.cols[column].1;
+ if append
+ && old_status != Status::Rescore
+ && self.cols[column]
+ .0
+ .atoms
+ .last()
+ .map_or(true, |last| !last.negative)
+ {
+ self.cols[column].1 = Status::Update;
+ } else {
+ self.cols[column].1 = Status::Rescore;
+ }
+ self.cols[column]
+ .0
+ .reparse(new_text, case_matching, normalization);
+ }
+
+ pub fn column_pattern(&self, column: usize) -> &Pattern {
+ &self.cols[column].0
+ }
+
+ pub(crate) fn status(&self) -> Status {
+ self.cols
+ .iter()
+ .map(|&(_, status)| status)
+ .max()
+ .unwrap_or(Status::Unchanged)
+ }
+
+ pub(crate) fn reset_status(&mut self) {
+ for (_, status) in &mut self.cols {
+ *status = Status::Unchanged
+ }
+ }
+
+ pub fn score(&self, haystack: &[Utf32String], matcher: &mut Matcher) -> Option<u32> {
+ // TODO: wheight columns?
+ let mut score = 0;
+ for ((pattern, _), haystack) in self.cols.iter().zip(haystack) {
+ score += pattern.score(haystack.slice(..), matcher)?
+ }
+ Some(score)
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.cols.iter().all(|(pat, _)| pat.atoms.is_empty())
+ }
+}
diff --git a/src/pattern/tests.rs b/src/pattern/tests.rs
new file mode 100644
index 00000000..40e8e328
--- /dev/null
+++ b/src/pattern/tests.rs
@@ -0,0 +1,14 @@
+use nucleo_matcher::pattern::{CaseMatching, Normalization};
+
+use crate::pattern::{MultiPattern, Status};
+
+#[test]
+fn append() {
+ let mut pat = MultiPattern::new(1);
+ pat.reparse(0, "!", CaseMatching::Smart, Normalization::Smart, true);
+ assert_eq!(pat.status(), Status::Update);
+ pat.reparse(0, "!f", CaseMatching::Smart, Normalization::Smart, true);
+ assert_eq!(pat.status(), Status::Update);
+ pat.reparse(0, "!fo", CaseMatching::Smart, Normalization::Smart, true);
+ assert_eq!(pat.status(), Status::Rescore);
+}
diff --git a/src/tests.rs b/src/tests.rs
new file mode 100644
index 00000000..676c50df
--- /dev/null
+++ b/src/tests.rs
@@ -0,0 +1,27 @@
+use std::sync::Arc;
+
+use nucleo_matcher::Config;
+
+use crate::Nucleo;
+
+#[test]
+fn active_injector_count() {
+ let mut nucleo: Nucleo<()> = Nucleo::new(Config::DEFAULT, Arc::new(|| ()), Some(1), 1);
+ assert_eq!(nucleo.active_injectors(), 0);
+ let injector = nucleo.injector();
+ assert_eq!(nucleo.active_injectors(), 1);
+ let injector2 = nucleo.injector();
+ assert_eq!(nucleo.active_injectors(), 2);
+ drop(injector2);
+ assert_eq!(nucleo.active_injectors(), 1);
+ nucleo.restart(false);
+ assert_eq!(nucleo.active_injectors(), 0);
+ let injector3 = nucleo.injector();
+ assert_eq!(nucleo.active_injectors(), 1);
+ nucleo.tick(0);
+ assert_eq!(nucleo.active_injectors(), 1);
+ drop(injector);
+ assert_eq!(nucleo.active_injectors(), 1);
+ drop(injector3);
+ assert_eq!(nucleo.active_injectors(), 0);
+}
diff --git a/src/worker.rs b/src/worker.rs
new file mode 100644
index 00000000..f4077e6e
--- /dev/null
+++ b/src/worker.rs
@@ -0,0 +1,301 @@
+use std::cell::UnsafeCell;
+use std::mem::take;
+use std::sync::atomic::{self, AtomicBool, AtomicU32};
+use std::sync::Arc;
+
+use nucleo_matcher::Config;
+use parking_lot::Mutex;
+use rayon::{prelude::*, ThreadPool};
+
+use crate::par_sort::par_quicksort;
+use crate::pattern::{self, MultiPattern};
+use crate::{boxcar, Match};
+
+struct Matchers(Box<[UnsafeCell<nucleo_matcher::Matcher>]>);
+
+impl Matchers {
+ // this is not a true mut from ref, we use a cell here
+ #[allow(clippy::mut_from_ref)]
+ unsafe fn get(&self) -> &mut nucleo_matcher::Matcher {
+ &mut *self.0[rayon::current_thread_index().unwrap()].get()
+ }
+}
+
+unsafe impl Sync for Matchers {}
+unsafe impl Send for Matchers {}
+
+pub(crate) struct Worker<T: Sync + Send + 'static> {
+ pub(crate) running: bool,
+ matchers: Matchers,
+ pub(crate) matches: Vec<Match>,
+ pub(crate) pattern: MultiPattern,
+ pub(crate) sort_results: bool,
+ pub(crate) reverse_items: bool,
+ pub(crate) canceled: Arc<AtomicBool>,
+ pub(crate) should_notify: Arc<AtomicBool>,
+ pub(crate) was_canceled: bool,
+ pub(crate) last_snapshot: u32,
+ notify: Arc<(dyn Fn() + Sync + Send)>,
+ pub(crate) items: Arc<boxcar::Vec<T>>,
+ in_flight: Vec<u32>,
+}
+
+impl<T: Sync + Send + 'static> Worker<T> {
+ pub(crate) fn item_count(&self) -> u32 {
+ self.last_snapshot - self.in_flight.len() as u32
+ }
+ pub(crate) fn update_config(&mut self, config: Config) {
+ for matcher in self.matchers.0.iter_mut() {
+ matcher.get_mut().config = config.clone();
+ }
+ }
+ pub(crate) fn sort_results(&mut self, sort_results: bool) {
+ self.sort_results = sort_results;
+ }
+ pub(crate) fn reverse_items(&mut self, reverse_items: bool) {
+ self.reverse_items = reverse_items;
+ }
+
+ pub(crate) fn new(
+ worker_threads: Option<usize>,
+ config: Config,
+ notify: Arc<(dyn Fn() + Sync + Send)>,
+ cols: u32,
+ ) -> (ThreadPool, Self) {
+ let worker_threads = worker_threads
+ .unwrap_or_else(|| std::thread::available_parallelism().map_or(4, |it| it.get()));
+ let pool = rayon::ThreadPoolBuilder::new()
+ .thread_name(|i| format!("nucleo worker {i}"))
+ .num_threads(worker_threads)
+ .build()
+ .expect("creating threadpool failed");
+ let matchers = (0..worker_threads)
+ .map(|_| UnsafeCell::new(nucleo_matcher::Matcher::new(config.clone())))
+ .collect();
+ let worker = Worker {
+ running: false,
+ matchers: Matchers(matchers),
+ last_snapshot: 0,
+ matches: Vec::new(),
+ // just a placeholder
+ pattern: MultiPattern::new(cols as usize),
+ sort_results: true,
+ reverse_items: false,
+ canceled: Arc::new(AtomicBool::new(false)),
+ should_notify: Arc::new(AtomicBool::new(false)),
+ was_canceled: false,
+ notify,
+ items: Arc::new(boxcar::Vec::with_capacity(2 * 1024, cols)),
+ in_flight: Vec::with_capacity(64),
+ };
+ (pool, worker)
+ }
+
+ unsafe fn process_new_items(&mut self, unmatched: &AtomicU32) {
+ let matchers = &self.matchers;
+ let pattern = &self.pattern;
+ self.matches.reserve(self.in_flight.len());
+ self.in_flight.retain(|&idx| {
+ let Some(item) = self.items.get(idx) else {
+ return true;
+ };
+ if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) {
+ self.matches.push(Match { score, idx });
+ };
+ false
+ });
+ let new_snapshot = self.items.par_snapshot(self.last_snapshot);
+ if new_snapshot.end() != self.last_snapshot {
+ let end = new_snapshot.end();
+ let in_flight = Mutex::new(&mut self.in_flight);
+ let items = new_snapshot.map(|(idx, item)| {
+ let Some(item) = item else {
+ in_flight.lock().push(idx);
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ return Match {
+ score: 0,
+ idx: u32::MAX,
+ };
+ };
+ if self.canceled.load(atomic::Ordering::Relaxed) {
+ return Match { score: 0, idx };
+ }
+ let Some(score) = pattern.score(item.matcher_columns, matchers.get()) else {
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ return Match {
+ score: 0,
+ idx: u32::MAX,
+ };
+ };
+ Match { score, idx }
+ });
+ self.matches.par_extend(items);
+ self.last_snapshot = end;
+ }
+ }
+
+ fn remove_in_flight_matches(&mut self) {
+ let mut off = 0;
+ self.in_flight.retain(|&i| {
+ let is_in_flight = self.items.get(i).is_none();
+ if is_in_flight {
+ self.matches.remove((i - off) as usize);
+ off += 1;
+ }
+ is_in_flight
+ });
+ }
+
+ unsafe fn process_new_items_trivial(&mut self) {
+ let new_snapshot = self.items.snapshot(self.last_snapshot);
+ if new_snapshot.end() != self.last_snapshot {
+ let end = new_snapshot.end();
+ let items = new_snapshot.filter_map(|(idx, item)| {
+ if item.is_none() {
+ self.in_flight.push(idx);
+ return None;
+ };
+ Some(Match { score: 0, idx })
+ });
+ self.matches.extend(items);
+ self.last_snapshot = end;
+ }
+ }
+
+ pub(crate) unsafe fn run(&mut self, pattern_status: pattern::Status, cleared: bool) {
+ self.running = true;
+ self.was_canceled = false;
+
+ if cleared {
+ self.last_snapshot = 0;
+ self.in_flight.clear();
+ self.matches.clear();
+ }
+
+ // TODO: be smarter around reusing past results for rescoring
+ if self.pattern.is_empty() {
+ self.reset_matches();
+ self.process_new_items_trivial();
+ let canceled = self.sort_matches();
+ if canceled {
+ self.was_canceled = true;
+ } else if self.should_notify.load(atomic::Ordering::Relaxed) {
+ (self.notify)();
+ }
+ return;
+ }
+
+ if pattern_status == pattern::Status::Rescore {
+ self.reset_matches();
+ }
+
+ let mut unmatched = AtomicU32::new(0);
+ if pattern_status != pattern::Status::Unchanged && !self.matches.is_empty() {
+ self.process_new_items_trivial();
+ let matchers = &self.matchers;
+ let pattern = &self.pattern;
+ self.matches
+ .par_iter_mut()
+ .take_any_while(|_| !self.canceled.load(atomic::Ordering::Relaxed))
+ .for_each(|match_| {
+ if match_.idx == u32::MAX {
+ debug_assert_eq!(match_.score, 0);
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ return;
+ }
+ // safety: in-flight items are never added to the matches
+ let item = self.items.get_unchecked(match_.idx);
+ if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) {
+ match_.score = score;
+ } else {
+ unmatched.fetch_add(1, atomic::Ordering::Relaxed);
+ match_.score = 0;
+ match_.idx = u32::MAX;
+ }
+ });
+ } else {
+ self.process_new_items(&unmatched);
+ }
+
+ let canceled = self.sort_matches();
+ if canceled {
+ self.was_canceled = true;
+ } else {
+ self.matches
+ .truncate(self.matches.len() - take(unmatched.get_mut()) as usize);
+ if self.should_notify.load(atomic::Ordering::Relaxed) {
+ (self.notify)();
+ }
+ }
+ }
+
+ unsafe fn sort_matches(&mut self) -> bool {
+ if self.sort_results {
+ par_quicksort(
+ &mut self.matches,
+ |match1, match2| {
+ if match1.score != match2.score {
+ return match1.score > match2.score;
+ }
+ if match1.idx == u32::MAX {
+ return false;
+ }
+ if match2.idx == u32::MAX {
+ return true;
+ }
+ // the tie breaker is comparatively rarely needed so we keep it
+ // in a branch especially because we need to access the items
+ // array here which involves some pointer chasing
+ let item1 = self.items.get_unchecked(match1.idx);
+ let item2 = &self.items.get_unchecked(match2.idx);
+ let len1: u32 = item1
+ .matcher_columns
+ .iter()
+ .map(|haystack| haystack.len() as u32)
+ .sum();
+ let len2 = item2
+ .matcher_columns
+ .iter()
+ .map(|haystack| haystack.len() as u32)
+ .sum();
+ if len1 == len2 {
+ if self.reverse_items {
+ match2.idx < match1.idx
+ } else {
+ match1.idx < match2.idx
+ }
+ } else {
+ len1 < len2
+ }
+ },
+ &self.canceled,
+ )
+ } else {
+ par_quicksort(
+ &mut self.matches,
+ |match1, match2| {
+ if match1.idx == u32::MAX {
+ return false;
+ }
+ if match2.idx == u32::MAX {
+ return true;
+ }
+ if self.reverse_items {
+ match2.idx < match1.idx
+ } else {
+ match1.idx < match2.idx
+ }
+ },
+ &self.canceled,
+ )
+ }
+ }
+
+ fn reset_matches(&mut self) {
+ self.matches.clear();
+ self.matches
+ .extend((0..self.last_snapshot).map(|idx| Match { score: 0, idx }));
+ // there are usually only very few in flight items (one for each writer)
+ self.remove_in_flight_matches();
+ }
+}