diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/boxcar.rs | 786 | ||||
| -rw-r--r-- | src/lib.rs | 462 | ||||
| -rw-r--r-- | src/par_sort.rs | 895 | ||||
| -rw-r--r-- | src/pattern.rs | 100 | ||||
| -rw-r--r-- | src/pattern/tests.rs | 14 | ||||
| -rw-r--r-- | src/tests.rs | 27 | ||||
| -rw-r--r-- | src/worker.rs | 301 |
7 files changed, 2585 insertions, 0 deletions
diff --git a/src/boxcar.rs b/src/boxcar.rs new file mode 100644 index 00000000..9b48809d --- /dev/null +++ b/src/boxcar.rs @@ -0,0 +1,786 @@ +//! Adapted from the `boxcar` crate at <https://github.com/ibraheemdev/boxcar/blob/master/src/raw.rs> +//! under MIT licenes: +//! +//! Copyright (c) 2022 Ibraheem Ahmed +//! +//! Permission is hereby granted, free of charge, to any person obtaining a copy +//! of this software and associated documentation files (the "Software"), to deal +//! in the Software without restriction, including without limitation the rights +//! to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +//! copies of the Software, and to permit persons to whom the Software is +//! furnished to do so, subject to the following conditions: +//! +//! The above copyright notice and this permission notice shall be included in all +//! copies or substantial portions of the Software. +//! +//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +//! IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +//! FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +//! AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +//! LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +//! OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +//! SOFTWARE. + +use std::alloc::Layout; +use std::cell::UnsafeCell; +use std::fmt::Debug; +use std::mem::MaybeUninit; +use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU64, Ordering}; +use std::{ptr, slice}; + +use crate::{Item, Utf32String}; + +const BUCKETS: u32 = u32::BITS - SKIP_BUCKET; +const MAX_ENTRIES: u32 = u32::MAX - SKIP; + +/// A lock-free, append-only vector. +pub(crate) struct Vec<T> { + /// a counter used to retrieve a unique index to push to. + /// + /// this value may be more than the true length as it will + /// be incremented before values are actually stored. + inflight: AtomicU64, + /// buckets of length 32, 64 .. 2^31 + buckets: [Bucket<T>; BUCKETS as usize], + /// the number of matcher columns in this vector, its absolutely critical that + /// this remains constant and after initilaziaton (safety invariant) since + /// it is used to calculate the Entry layout + columns: u32, +} + +impl<T> Vec<T> { + /// Constructs a new, empty `Vec<T>` with the specified capacity and matcher columns. + pub fn with_capacity(capacity: u32, columns: u32) -> Vec<T> { + assert_ne!(columns, 0, "there must be atleast one matcher column"); + let init = match capacity { + 0 => 0, + // initialize enough buckets for `capacity` elements + n => Location::of(n).bucket, + }; + + let mut buckets = [ptr::null_mut(); BUCKETS as usize]; + + for (i, bucket) in buckets[..=init as usize].iter_mut().enumerate() { + let len = Location::bucket_len(i as u32); + *bucket = unsafe { Bucket::alloc(len, columns) }; + } + + Vec { + buckets: buckets.map(Bucket::new), + inflight: AtomicU64::new(0), + columns, + } + } + pub fn columns(&self) -> u32 { + self.columns + } + + /// Returns the number of elements in the vector. + #[inline] + pub fn count(&self) -> u32 { + self.inflight + .load(Ordering::Acquire) + .min(MAX_ENTRIES as u64) as u32 + } + + // Returns a reference to the element at the given index. + // + // # Safety + // + // Entry at `index` must be initialized. + #[inline] + pub unsafe fn get_unchecked(&self, index: u32) -> Item<'_, T> { + let location = Location::of(index); + + unsafe { + let entries = self + .buckets + .get_unchecked(location.bucket as usize) + .entries + .load(Ordering::Relaxed); + debug_assert!(!entries.is_null()); + let entry = Bucket::<T>::get(entries, location.entry, self.columns); + // this looks odd but is necessary to ensure cross + // thread synchronization (essentially acting as a memory barrier) + // since the caller must only guarantee that he has observed active on any thread + // but the current thread might still have an old value cached (although unlikely) + let _ = (*entry).active.load(Ordering::Acquire); + Entry::read(entry, self.columns) + } + } + + /// Returns a reference to the element at the given index. + pub fn get(&self, index: u32) -> Option<Item<'_, T>> { + let location = Location::of(index); + + unsafe { + // safety: `location.bucket` is always in bounds + let entries = self + .buckets + .get_unchecked(location.bucket as usize) + .entries + .load(Ordering::Relaxed); + + // bucket is uninitialized + if entries.is_null() { + return None; + } + + // safety: `location.entry` is always in bounds for it's bucket + let entry = Bucket::<T>::get(entries, location.entry, self.columns); + + // safety: the entry is active + (*entry) + .active + .load(Ordering::Acquire) + .then(|| Entry::read(entry, self.columns)) + } + } + + /// Appends an element to the back of the vector. + pub fn push(&self, value: T, fill_columns: impl FnOnce(&T, &mut [Utf32String])) -> u32 { + let index = self.inflight.fetch_add(1, Ordering::Release); + // the inflight counter is a `u64` to catch overflows of the vector'scapacity + let index: u32 = index.try_into().expect("overflowed maximum capacity"); + let location = Location::of(index); + + // eagerly allocate the next bucket if we are close to the end of this one + if index == (location.bucket_len - (location.bucket_len >> 3)) { + if let Some(next_bucket) = self.buckets.get(location.bucket as usize + 1) { + Vec::get_or_alloc(next_bucket, location.bucket_len << 1, self.columns); + } + } + + // safety: `location.bucket` is always in bounds + let bucket = unsafe { self.buckets.get_unchecked(location.bucket as usize) }; + let mut entries = bucket.entries.load(Ordering::Acquire); + + // the bucket has not been allocated yet + if entries.is_null() { + entries = Vec::get_or_alloc(bucket, location.bucket_len, self.columns); + } + + unsafe { + // safety: `location.entry` is always in bounds for it's bucket + let entry = Bucket::get(entries, location.entry, self.columns); + + // safety: we have unique access to this entry. + // + // 1. it is impossible for another thread to attempt a `push` + // to this location as we retrieved it from `inflight.fetch_add` + // + // 2. any thread trying to `get` this entry will see `active == false`, + // and will not try to access it + for col in Entry::matcher_cols_raw(entry, self.columns) { + col.get().write(MaybeUninit::new(Utf32String::default())) + } + fill_columns(&value, Entry::matcher_cols_mut(entry, self.columns)); + (*entry).slot.get().write(MaybeUninit::new(value)); + // let other threads know that this entry is active + (*entry).active.store(true, Ordering::Release); + } + + index + } + + /// Extends the vector by appending multiple elements at once. + pub fn extend<I>(&self, values: I, fill_columns: impl Fn(&T, &mut [Utf32String])) + where + I: IntoIterator<Item = T> + ExactSizeIterator, + { + let count: u32 = values + .len() + .try_into() + .expect("overflowed maximum capacity"); + if count == 0 { + assert!( + values.into_iter().next().is_none(), + "The `values` variable reported incorrect length." + ); + return; + } + + // Reserve all indices at once + let start_index: u32 = self + .inflight + .fetch_add(u64::from(count), Ordering::Release) + .try_into() + .expect("overflowed maximum capacity"); + + // Compute first and last locations + let start_location = Location::of(start_index); + let end_location = Location::of(start_index + count); + + // Eagerly allocate the next bucket if the last entry is close to the end of its next bucket + let alloc_entry = end_location.alloc_next_bucket_entry(); + if end_location.entry >= alloc_entry + && (start_location.bucket != end_location.bucket || start_location.entry <= alloc_entry) + { + // This might be the last bucket, hence the check + if let Some(next_bucket) = self.buckets.get(end_location.bucket as usize + 1) { + Vec::get_or_alloc(next_bucket, end_location.bucket_len << 1, self.columns); + } + } + + let mut bucket = unsafe { self.buckets.get_unchecked(start_location.bucket as usize) }; + let mut entries = bucket.entries.load(Ordering::Acquire); + if entries.is_null() { + entries = Vec::get_or_alloc( + bucket, + Location::bucket_len(start_location.bucket), + self.columns, + ); + } + // Route each value to its corresponding bucket + let mut location; + let count = count as usize; + for (i, v) in values.into_iter().enumerate() { + // ExactSizeIterator is a safe trait that can have bugs/lie about it's size. + // Unsafe code cannot rely on the reported length being correct. + assert!(i < count); + + location = + Location::of(start_index + u32::try_from(i).expect("overflowed maximum capacity")); + + // if we're starting to insert into a different bucket, allocate it beforehand + if location.entry == 0 && i != 0 { + // safety: `location.bucket` is always in bounds + bucket = unsafe { self.buckets.get_unchecked(location.bucket as usize) }; + entries = bucket.entries.load(Ordering::Acquire); + + if entries.is_null() { + entries = Vec::get_or_alloc( + bucket, + Location::bucket_len(location.bucket), + self.columns, + ); + } + } + + unsafe { + let entry = Bucket::get(entries, location.entry, self.columns); + + // Initialize matcher columns + for col in Entry::matcher_cols_raw(entry, self.columns) { + col.get().write(MaybeUninit::new(Utf32String::default())); + } + fill_columns(&v, Entry::matcher_cols_mut(entry, self.columns)); + (*entry).slot.get().write(MaybeUninit::new(v)); + (*entry).active.store(true, Ordering::Release); + } + } + } + + /// race to initialize a bucket + fn get_or_alloc(bucket: &Bucket<T>, len: u32, cols: u32) -> *mut Entry<T> { + let entries = unsafe { Bucket::alloc(len, cols) }; + match bucket.entries.compare_exchange( + ptr::null_mut(), + entries, + Ordering::Release, + Ordering::Acquire, + ) { + Ok(_) => entries, + Err(found) => unsafe { + Bucket::dealloc(entries, len, cols); + found + }, + } + } + + /// Returns an iterator over the vector starting at `start` + /// the iterator is deterministically sized and will not grow + /// as more elements are pushed + pub unsafe fn snapshot(&self, start: u32) -> Iter<'_, T> { + let end = self + .inflight + .load(Ordering::Acquire) + .min(MAX_ENTRIES as u64) as u32; + assert!(start <= end, "index {start} is out of bounds!"); + Iter { + location: Location::of(start), + vec: self, + idx: start, + end, + } + } + + /// Returns an iterator over the vector starting at `start` + /// the iterator is deterministically sized and will not grow + /// as more elements are pushed + pub unsafe fn par_snapshot(&self, start: u32) -> ParIter<'_, T> { + let end = self + .inflight + .load(Ordering::Acquire) + .min(MAX_ENTRIES as u64) as u32; + assert!(start <= end, "index {start} is out of bounds!"); + + ParIter { + start, + end, + vec: self, + } + } +} + +impl<T> Drop for Vec<T> { + fn drop(&mut self) { + for (i, bucket) in self.buckets.iter_mut().enumerate() { + let entries = *bucket.entries.get_mut(); + + if entries.is_null() { + break; + } + + let len = Location::bucket_len(i as u32); + // safety: in drop + unsafe { Bucket::dealloc(entries, len, self.columns) } + } + } +} +type SnapshotItem<'v, T> = (u32, Option<Item<'v, T>>); + +pub struct Iter<'v, T> { + location: Location, + idx: u32, + end: u32, + vec: &'v Vec<T>, +} +impl<T> Iter<'_, T> { + pub fn end(&self) -> u32 { + self.end + } +} + +impl<'v, T> Iterator for Iter<'v, T> { + type Item = SnapshotItem<'v, T>; + fn size_hint(&self) -> (usize, Option<usize>) { + ( + (self.end - self.idx) as usize, + Some((self.end - self.idx) as usize), + ) + } + + fn next(&mut self) -> Option<SnapshotItem<'v, T>> { + if self.end == self.idx { + return None; + } + debug_assert!(self.idx < self.end, "huh {} {}", self.idx, self.end); + debug_assert!(self.end as u64 <= self.vec.inflight.load(Ordering::Relaxed)); + + loop { + let entries = unsafe { + self.vec + .buckets + .get_unchecked(self.location.bucket as usize) + .entries + .load(Ordering::Relaxed) + }; + debug_assert!(self.location.bucket < BUCKETS); + + if self.location.entry < self.location.bucket_len { + if entries.is_null() { + // we still want to yield these + let index = self.idx; + self.location.entry += 1; + self.idx += 1; + return Some((index, None)); + } + // safety: bounds and null checked above + let entry = unsafe { Bucket::get(entries, self.location.entry, self.vec.columns) }; + let index = self.idx; + self.location.entry += 1; + self.idx += 1; + + let entry = unsafe { + (*entry) + .active + .load(Ordering::Acquire) + .then(|| Entry::read(entry, self.vec.columns)) + }; + return Some((index, entry)); + } + + self.location.entry = 0; + self.location.bucket += 1; + + if self.location.bucket < BUCKETS { + self.location.bucket_len = Location::bucket_len(self.location.bucket); + } + } + } +} +impl<T> ExactSizeIterator for Iter<'_, T> {} +impl<T> DoubleEndedIterator for Iter<'_, T> { + fn next_back(&mut self) -> Option<Self::Item> { + unimplemented!() + } +} + +pub struct ParIter<'v, T> { + end: u32, + start: u32, + vec: &'v Vec<T>, +} +impl<T> ParIter<'_, T> { + pub fn end(&self) -> u32 { + self.end + } +} + +impl<'v, T: Send + Sync> rayon::iter::ParallelIterator for ParIter<'v, T> { + type Item = SnapshotItem<'v, T>; + + fn drive_unindexed<C>(self, consumer: C) -> C::Result + where + C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>, + { + rayon::iter::plumbing::bridge(self, consumer) + } + + fn opt_len(&self) -> Option<usize> { + Some((self.end - self.start) as usize) + } +} + +impl<T: Send + Sync> rayon::iter::IndexedParallelIterator for ParIter<'_, T> { + fn len(&self) -> usize { + (self.end - self.start) as usize + } + + fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result { + rayon::iter::plumbing::bridge(self, consumer) + } + + fn with_producer<CB>(self, callback: CB) -> CB::Output + where + CB: rayon::iter::plumbing::ProducerCallback<Self::Item>, + { + callback.callback(ParIterProducer { + start: self.start, + end: self.end, + vec: self.vec, + }) + } +} + +struct ParIterProducer<'v, T: Send> { + start: u32, + end: u32, + vec: &'v Vec<T>, +} + +impl<'v, T: 'v + Send + Sync> rayon::iter::plumbing::Producer for ParIterProducer<'v, T> { + type Item = SnapshotItem<'v, T>; + type IntoIter = Iter<'v, T>; + + fn into_iter(self) -> Self::IntoIter { + debug_assert!(self.start <= self.end); + Iter { + location: Location::of(self.start), + idx: self.start, + end: self.end, + vec: self.vec, + } + } + + fn split_at(self, index: usize) -> (Self, Self) { + assert!(index <= (self.end - self.start) as usize); + let index = index as u32; + ( + ParIterProducer { + start: self.start, + end: self.start + index, + vec: self.vec, + }, + ParIterProducer { + start: self.start + index, + end: self.end, + vec: self.vec, + }, + ) + } +} + +struct Bucket<T> { + entries: AtomicPtr<Entry<T>>, +} + +impl<T> Bucket<T> { + fn layout(len: u32, layout: Layout) -> Layout { + Layout::from_size_align(layout.size() * len as usize, layout.align()) + .expect("exceeded maximum allocation size") + } + + unsafe fn alloc(len: u32, cols: u32) -> *mut Entry<T> { + let layout = Entry::<T>::layout(cols); + let arr_layout = Self::layout(len, layout); + let entries = std::alloc::alloc(arr_layout); + if entries.is_null() { + std::alloc::handle_alloc_error(arr_layout) + } + + for i in 0..len { + let active = entries.add(i as usize * layout.size()) as *mut AtomicBool; + active.write(AtomicBool::new(false)) + } + entries as *mut Entry<T> + } + + unsafe fn dealloc(entries: *mut Entry<T>, len: u32, cols: u32) { + let layout = Entry::<T>::layout(cols); + let arr_layout = Self::layout(len, layout); + for i in 0..len { + let entry = Bucket::get(entries, i, cols); + if *(*entry).active.get_mut() { + ptr::drop_in_place((*(*entry).slot.get()).as_mut_ptr()); + for matcher_col in Entry::matcher_cols_raw(entry, cols) { + ptr::drop_in_place((*matcher_col.get()).as_mut_ptr()); + } + } + } + std::alloc::dealloc(entries as *mut u8, arr_layout) + } + + unsafe fn get(entries: *mut Entry<T>, idx: u32, cols: u32) -> *mut Entry<T> { + let layout = Entry::<T>::layout(cols); + let ptr = entries as *mut u8; + ptr.add(layout.size() * idx as usize) as *mut Entry<T> + } + + fn new(entries: *mut Entry<T>) -> Bucket<T> { + Bucket { + entries: AtomicPtr::new(entries), + } + } +} + +#[repr(C)] +struct Entry<T> { + active: AtomicBool, + slot: UnsafeCell<MaybeUninit<T>>, + tail: [UnsafeCell<MaybeUninit<Utf32String>>; 0], +} + +impl<T> Entry<T> { + fn layout(cols: u32) -> Layout { + let head = Layout::new::<Self>(); + let tail = Layout::array::<Utf32String>(cols as usize).expect("invalid memory layout"); + head.extend(tail) + .expect("invalid memory layout") + .0 + .pad_to_align() + } + + unsafe fn matcher_cols_raw<'a>( + ptr: *mut Entry<T>, + cols: u32, + ) -> &'a [UnsafeCell<MaybeUninit<Utf32String>>] { + // this whole thing looks weird. The reason we do this is that + // we must make sure the pointer retains its provenance which may (or may not?) + // be lost if we used tail.as_ptr() + let tail = std::ptr::addr_of!((*ptr).tail) as *const u8; + let offset = tail.offset_from(ptr as *mut u8) as usize; + let ptr = (ptr as *mut u8).add(offset) as *mut _; + slice::from_raw_parts(ptr, cols as usize) + } + + unsafe fn matcher_cols_mut<'a>(ptr: *mut Entry<T>, cols: u32) -> &'a mut [Utf32String] { + // this whole thing looks weird. The reason we do this is that + // we must make sure the pointer retains its provenance which may (or may not?) + // be lost if we used tail.as_ptr() + let tail = std::ptr::addr_of!((*ptr).tail) as *const u8; + let offset = tail.offset_from(ptr as *mut u8) as usize; + let ptr = (ptr as *mut u8).add(offset) as *mut _; + slice::from_raw_parts_mut(ptr, cols as usize) + } + // # Safety + // + // Value must be initialized. + unsafe fn read<'a>(ptr: *mut Entry<T>, cols: u32) -> Item<'a, T> { + // this whole thing looks weird. The reason we do this is that + // we must make sure the pointer retains its provenance which may (or may not?) + // be lost if we used tail.as_ptr() + let data = (*(*ptr).slot.get()).assume_init_ref(); + let tail = std::ptr::addr_of!((*ptr).tail) as *const u8; + let offset = tail.offset_from(ptr as *mut u8) as usize; + let ptr = (ptr as *mut u8).add(offset) as *mut _; + let matcher_columns = slice::from_raw_parts(ptr, cols as usize); + Item { + data, + matcher_columns, + } + } +} + +#[derive(Debug)] +struct Location { + // the index of the bucket + bucket: u32, + // the length of `bucket` + bucket_len: u32, + // the index of the entry in `bucket` + entry: u32, +} + +// skip the shorter buckets to avoid unnecessary allocations. +// this also reduces the maximum capacity of a vector. +const SKIP: u32 = 32; +const SKIP_BUCKET: u32 = (u32::BITS - SKIP.leading_zeros()) - 1; + +impl Location { + fn of(index: u32) -> Location { + let skipped = index.checked_add(SKIP).expect("exceeded maximum length"); + let bucket = u32::BITS - skipped.leading_zeros(); + let bucket = bucket - (SKIP_BUCKET + 1); + let bucket_len = Location::bucket_len(bucket); + let entry = skipped ^ bucket_len; + + Location { + bucket, + bucket_len, + entry, + } + } + + fn bucket_len(bucket: u32) -> u32 { + 1 << (bucket + SKIP_BUCKET) + } + + /// The entry index at which the next bucket should be pre-allocated. + fn alloc_next_bucket_entry(&self) -> u32 { + self.bucket_len - (self.bucket_len >> 3) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn location() { + assert_eq!(Location::bucket_len(0), 32); + for i in 0..32 { + let loc = Location::of(i); + assert_eq!(loc.bucket_len, 32); + assert_eq!(loc.bucket, 0); + assert_eq!(loc.entry, i); + } + + assert_eq!(Location::bucket_len(1), 64); + for i in 33..96 { + let loc = Location::of(i); + assert_eq!(loc.bucket_len, 64); + assert_eq!(loc.bucket, 1); + assert_eq!(loc.entry, i - 32); + } + + assert_eq!(Location::bucket_len(2), 128); + for i in 96..224 { + let loc = Location::of(i); + assert_eq!(loc.bucket_len, 128); + assert_eq!(loc.bucket, 2); + assert_eq!(loc.entry, i - 96); + } + + let max = Location::of(MAX_ENTRIES); + assert_eq!(max.bucket, BUCKETS - 1); + assert_eq!(max.bucket_len, 1 << 31); + assert_eq!(max.entry, (1 << 31) - 1); + } + + #[test] + fn extend_unique_bucket() { + let vec = Vec::<u32>::with_capacity(1, 1); + vec.extend(0..10, |_, _| {}); + assert_eq!(vec.count(), 10); + for i in 0..10 { + assert_eq!(*vec.get(i).unwrap().data, i); + } + assert!(vec.get(10).is_none()); + } + + #[test] + fn extend_over_two_buckets() { + let vec = Vec::<u32>::with_capacity(1, 1); + vec.extend(0..100, |_, _| {}); + assert_eq!(vec.count(), 100); + for i in 0..100 { + assert_eq!(*vec.get(i).unwrap().data, i); + } + assert!(vec.get(100).is_none()); + } + + #[test] + fn extend_over_more_than_two_buckets() { + let vec = Vec::<u32>::with_capacity(1, 1); + vec.extend(0..1000, |_, _| {}); + assert_eq!(vec.count(), 1000); + for i in 0..1000 { + assert_eq!(*vec.get(i).unwrap().data, i); + } + assert!(vec.get(1000).is_none()); + } + + #[test] + /// test that ExactSizeIterator returning incorrect length is caught (0 AND more than reported) + fn extend_with_incorrect_reported_len_is_caught() { + struct IncorrectLenIter { + len: usize, + iter: std::ops::Range<u32>, + } + + impl Iterator for IncorrectLenIter { + type Item = u32; + + fn next(&mut self) -> Option<Self::Item> { + self.iter.next() + } + } + + impl ExactSizeIterator for IncorrectLenIter { + fn len(&self) -> usize { + self.len + } + } + + let vec = Vec::<u32>::with_capacity(1, 1); + let iter = IncorrectLenIter { + len: 10, + iter: (0..12), + }; + // this should panic + assert!(std::panic::catch_unwind(|| vec.extend(iter, |_, _| {})).is_err()); + + let vec = Vec::<u32>::with_capacity(1, 1); + let iter = IncorrectLenIter { + len: 12, + iter: (0..10), + }; + // this shouldn't panic and should just ignore the extra elements + assert!(std::panic::catch_unwind(|| vec.extend(iter, |_, _| {})).is_ok()); + // we should reserve 12 elements but only 10 should be present + assert_eq!(vec.count(), 12); + for i in 0..10 { + assert_eq!(*vec.get(i).unwrap().data, i); + } + assert!(vec.get(10).is_none()); + + let vec = Vec::<u32>::with_capacity(1, 1); + let iter = IncorrectLenIter { + len: 0, + iter: (0..2), + }; + // this should panic + assert!(std::panic::catch_unwind(|| vec.extend(iter, |_, _| {})).is_err()); + } + + // test |values| does not fit in the boxcar + #[test] + fn extend_over_max_capacity() { + let vec = Vec::<u32>::with_capacity(1, 1); + let count = MAX_ENTRIES as usize + 2; + let iter = std::iter::repeat(0).take(count); + assert!(std::panic::catch_unwind(|| vec.extend(iter, |_, _| {})).is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..7ddb7407 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,462 @@ +/*! +`nucleo` is a high level crate that provides a high level matcher API that +provides a highly effective (parallel) matcher worker. It's designed to allow +quickly plugging a fully featured (and faster) fzf/skim like fuzzy matcher into +your TUI application. + +It's designed to run matching on a background threadpool while providing a +snapshot of the last complete match. That means the matcher can update the +results live while the user is typing while never blocking the main UI thread +(beyond a user provided timeout). Nucleo also supports fully concurrent lock-free +(and wait-free) streaming of input items. + +The [`Nucleo`] struct serves as the main API entrypoint for this crate. + +# Status + +Nucleo is used in the helix-editor and therefore has a large user base with lots +or real world testing. The core matcher implementation is considered complete +and is unlikely to see major changes. The `nucleo-matcher` crate is finished and +ready for widespread use, breaking changes should be very rare (a 1.0 release +should not be far away). + +While the high level `nucleo` crate also works well (and is also used in helix), +there are still additional features that will be added in the future. The high +level crate also need better documentation and will likely see a few minor API +changes in the future. + +*/ +use std::ops::{Bound, RangeBounds}; +use std::sync::atomic::{self, AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use parking_lot::Mutex; +use rayon::ThreadPool; + +use crate::pattern::MultiPattern; +use crate::worker::Worker; +pub use nucleo_matcher::{chars, Config, Matcher, Utf32Str, Utf32String}; + +mod boxcar; +mod par_sort; +pub mod pattern; +mod worker; + +#[cfg(test)] +mod tests; + +/// A match candidate stored in a [`Nucleo`] worker. +pub struct Item<'a, T> { + pub data: &'a T, + pub matcher_columns: &'a [Utf32String], +} + +/// A handle that allows adding new items to a [`Nucleo`] worker. +/// +/// It's internally reference counted and can be cheaply cloned +/// and sent across threads. +pub struct Injector<T> { + items: Arc<boxcar::Vec<T>>, + notify: Arc<(dyn Fn() + Sync + Send)>, +} + +impl<T> Clone for Injector<T> { + fn clone(&self) -> Self { + Injector { + items: self.items.clone(), + notify: self.notify.clone(), + } + } +} + +impl<T> Injector<T> { + /// Appends an element to the list of matched items. + /// This function is lock-free and wait-free. + pub fn push(&self, value: T, fill_columns: impl FnOnce(&T, &mut [Utf32String])) -> u32 { + let idx = self.items.push(value, fill_columns); + (self.notify)(); + idx + } + + /// Appends multiple elements to the list of matched items. + /// This function is lock-free and wait-free. + /// + /// You should favor this function over `push` if at least one of the following is true: + /// - the number of items you're adding can be computed beforehand and is typically larger + /// than 1k + /// - you're able to batch incoming items + /// - you're adding items from multiple threads concurrently (this function results in less + /// contention) + pub fn extend<I>(&self, values: I, fill_columns: impl Fn(&T, &mut [Utf32String])) + where + I: IntoIterator<Item = T> + ExactSizeIterator, + { + self.items.extend(values, fill_columns); + (self.notify)(); + } + + /// Returns the total number of items injected in the matcher. This might + /// not match the number of items in the match snapshot (if the matcher + /// is still running) + pub fn injected_items(&self) -> u32 { + self.items.count() + } + + /// Returns a reference to the item at the given index. + /// + /// # Safety + /// + /// Item at `index` must be initialized. That means you must have observed + /// `push` returning this value or `get` returning `Some` for this value. + /// Just because a later index is initialized doesn't mean that this index + /// is initialized + pub unsafe fn get_unchecked(&self, index: u32) -> Item<'_, T> { + self.items.get_unchecked(index) + } + + /// Returns a reference to the element at the given index. + pub fn get(&self, index: u32) -> Option<Item<'_, T>> { + self.items.get(index) + } +} + +/// An [item](crate::Item) that was successfully matched by a [`Nucleo`] worker. +#[derive(PartialEq, Eq, Debug, Clone, Copy)] +pub struct Match { + pub score: u32, + pub idx: u32, +} + +/// That status of a [`Nucleo`] worker after a match. +#[derive(PartialEq, Eq, Debug, Clone, Copy)] +pub struct Status { + /// Whether the current snapshot has changed. + pub changed: bool, + /// Whether the matcher is still processing in the background. + pub running: bool, +} + +/// A snapshot represent the results of a [`Nucleo`] worker after +/// finishing a [`tick`](Nucleo::tick). +pub struct Snapshot<T: Sync + Send + 'static> { + item_count: u32, + matches: Vec<Match>, + pattern: MultiPattern, + items: Arc<boxcar::Vec<T>>, +} + +impl<T: Sync + Send + 'static> Snapshot<T> { + fn clear(&mut self, new_items: Arc<boxcar::Vec<T>>) { + self.item_count = 0; + self.matches.clear(); + self.items = new_items + } + + fn update(&mut self, worker: &Worker<T>) { + self.item_count = worker.item_count(); + self.pattern.clone_from(&worker.pattern); + self.matches.clone_from(&worker.matches); + if !Arc::ptr_eq(&worker.items, &self.items) { + self.items = worker.items.clone() + } + } + + /// Returns that total number of items + pub fn item_count(&self) -> u32 { + self.item_count + } + + /// Returns the pattern which items were matched against + pub fn pattern(&self) -> &MultiPattern { + &self.pattern + } + + /// Returns that number of items that matched the pattern + pub fn matched_item_count(&self) -> u32 { + self.matches.len() as u32 + } + + /// Returns an iterator over the items that correspond to a subrange of + /// all the matches in this snapshot. + /// + /// # Panics + /// Panics if `range` has a range bound that is larger than + /// the matched item count + pub fn matched_items( + &self, + range: impl RangeBounds<u32>, + ) -> impl ExactSizeIterator<Item = Item<'_, T>> + DoubleEndedIterator + '_ { + // TODO: use TAIT + let start = match range.start_bound() { + Bound::Included(&start) => start as usize, + Bound::Excluded(&start) => start as usize + 1, + Bound::Unbounded => 0, + }; + let end = match range.end_bound() { + Bound::Included(&end) => end as usize + 1, + Bound::Excluded(&end) => end as usize, + Bound::Unbounded => self.matches.len(), + }; + self.matches[start..end] + .iter() + .map(|&m| unsafe { self.items.get_unchecked(m.idx) }) + } + + /// Returns a reference to the item at the given index. + /// + /// # Safety + /// + /// Item at `index` must be initialized. That means you must have observed a + /// match with the corresponding index in this exact snapshot. Observing + /// a higher index is not enough as item indices can be non-contigously + /// initialized + #[inline] + pub unsafe fn get_item_unchecked(&self, index: u32) -> Item<'_, T> { + self.items.get_unchecked(index) + } + + /// Returns a reference to the item at the given index. + /// + /// Returns `None` if the given `index` is not initialized. This function + /// is only guarteed to return `Some` for item indices that can be found in + /// the `matches` of this struct. Both smaller and larger indices may return + /// `None`. + #[inline] + pub fn get_item(&self, index: u32) -> Option<Item<'_, T>> { + self.items.get(index) + } + + /// Return the matches corresponding to this snapshot. + #[inline] + pub fn matches(&self) -> &[Match] { + &self.matches + } + + /// A convenience function to return the [`Item`] corresponding to the + /// `n`th match. + /// + /// Returns `None` if `n` is greater than or equal to the match count. + #[inline] + pub fn get_matched_item(&self, n: u32) -> Option<Item<'_, T>> { + // SAFETY: A match index is guaranteed to corresponding to a valid global index in this + // snapshot. + unsafe { Some(self.get_item_unchecked(self.matches.get(n as usize)?.idx)) } + } +} + +#[repr(u8)] +#[derive(Clone, Copy, PartialEq, Eq)] +enum State { + Init, + /// items have been cleared but snapshot and items are still outdated + Cleared, + /// items are fresh + Fresh, +} + +impl State { + fn matcher_item_refs(self) -> usize { + match self { + State::Cleared => 1, + State::Init | State::Fresh => 2, + } + } + + fn canceled(self) -> bool { + self != State::Fresh + } + + fn cleared(self) -> bool { + self != State::Fresh + } +} + +/// A high level matcher worker that quickly computes matches in a background +/// threadpool. +pub struct Nucleo<T: Sync + Send + 'static> { + // the way the API is build we totally don't actually need these to be Arcs + // but this lets us avoid some unsafe + canceled: Arc<AtomicBool>, + should_notify: Arc<AtomicBool>, + worker: Arc<Mutex<Worker<T>>>, + pool: ThreadPool, + state: State, + items: Arc<boxcar::Vec<T>>, + notify: Arc<(dyn Fn() + Sync + Send)>, + snapshot: Snapshot<T>, + /// The pattern matched by this matcher. To update the match pattern + /// [`MultiPattern::reparse`](`pattern::MultiPattern::reparse`) should be used. + /// Note that the matcher worker will only become aware of the new pattern + /// after a call to [`tick`](Nucleo::tick). + pub pattern: MultiPattern, +} + +impl<T: Sync + Send + 'static> Nucleo<T> { + /// Constructs a new `nucleo` worker threadpool with the provided `config`. + /// + /// `notify` is called everytime new information is available and + /// [`tick`](Nucleo::tick) should be called. Note that `notify` is not + /// debounced, that should be handled by the downstream crate (for example + /// debouncing to only redraw at most every 1/60 seconds). + /// + /// If `None` is passed for the number of worker threads, nucleo will use + /// one thread per hardware thread. + /// + /// Nucleo can match items with multiple orthogonal properties. `columns` + /// indicates how many matching columns each item (and the pattern) has. The + /// number of columns cannot be changed after construction. + pub fn new( + config: Config, + notify: Arc<(dyn Fn() + Sync + Send)>, + num_threads: Option<usize>, + columns: u32, + ) -> Self { + let (pool, worker) = Worker::new(num_threads, config, notify.clone(), columns); + Self { + canceled: worker.canceled.clone(), + should_notify: worker.should_notify.clone(), + items: worker.items.clone(), + pool, + pattern: MultiPattern::new(columns as usize), + snapshot: Snapshot { + matches: Vec::with_capacity(2 * 1024), + pattern: MultiPattern::new(columns as usize), + item_count: 0, + items: worker.items.clone(), + }, + worker: Arc::new(Mutex::new(worker)), + state: State::Init, + notify, + } + } + + /// Returns the total number of active injectors + pub fn active_injectors(&self) -> usize { + Arc::strong_count(&self.items) + - self.state.matcher_item_refs() + - (Arc::ptr_eq(&self.snapshot.items, &self.items)) as usize + } + + /// Returns a snapshot of the current matcher state. + pub fn snapshot(&self) -> &Snapshot<T> { + &self.snapshot + } + + /// Returns an injector that can be used for adding candidates to the matcher. + pub fn injector(&self) -> Injector<T> { + Injector { + items: self.items.clone(), + notify: self.notify.clone(), + } + } + + /// Restart the the item stream. Removes all items and disconnects all + /// previously created injectors from this instance. If `clear_snapshot` + /// is `true` then all items and matched are removed from the [`Snapshot`] + /// immediately. Otherwise the snapshot will keep the current matches until + /// the matcher has run again. + /// + /// # Note + /// + /// The injectors will continue to function but they will not affect this + /// instance anymore. The old items will only be dropped when all injectors + /// were dropped. + pub fn restart(&mut self, clear_snapshot: bool) { + self.canceled.store(true, Ordering::Relaxed); + self.items = Arc::new(boxcar::Vec::with_capacity(1024, self.items.columns())); + self.state = State::Cleared; + if clear_snapshot { + self.snapshot.clear(self.items.clone()); + } + } + + /// Update the internal configuration. + pub fn update_config(&mut self, config: Config) { + self.worker.lock().update_config(config) + } + + // Set whether the matcher should sort search results by score after + // matching. Defaults to true. + pub fn sort_results(&mut self, sort_results: bool) { + self.worker.lock().sort_results(sort_results) + } + + // Set whether the matcher should reverse the order of the input. + // Defaults to false. + pub fn reverse_items(&mut self, reverse_items: bool) { + self.worker.lock().reverse_items(reverse_items) + } + + /// The main way to interact with the matcher, this should be called + /// regularly (for example each time a frame is rendered). To avoid + /// excessive redraws this method will wait `timeout` milliseconds for the + /// worker thread to finish. It is recommend to set the timeout to 10ms. + pub fn tick(&mut self, timeout: u64) -> Status { + self.should_notify.store(false, atomic::Ordering::Relaxed); + let status = self.pattern.status(); + let canceled = status != pattern::Status::Unchanged || self.state.canceled(); + let mut res = self.tick_inner(timeout, canceled, status); + if !canceled { + return res; + } + self.state = State::Fresh; + let status2 = self.tick_inner(timeout, false, pattern::Status::Unchanged); + res.changed |= status2.changed; + res.running = status2.running; + res + } + + fn tick_inner(&mut self, timeout: u64, canceled: bool, status: pattern::Status) -> Status { + let mut inner = if canceled { + self.pattern.reset_status(); + self.canceled.store(true, atomic::Ordering::Relaxed); + self.worker.lock_arc() + } else { + let Some(worker) = self.worker.try_lock_arc_for(Duration::from_millis(timeout)) else { + self.should_notify.store(true, Ordering::Release); + return Status { + changed: false, + running: true, + }; + }; + worker + }; + + let changed = inner.running; + + let running = canceled || self.items.count() > inner.item_count(); + if inner.running { + inner.running = false; + if !inner.was_canceled && !self.state.canceled() { + self.snapshot.update(&inner) + } + } + if running { + inner.pattern.clone_from(&self.pattern); + self.canceled.store(false, atomic::Ordering::Relaxed); + if !canceled { + self.should_notify.store(true, atomic::Ordering::Release); + } + let cleared = self.state.cleared(); + if cleared { + inner.items = self.items.clone(); + } + self.pool + .spawn(move || unsafe { inner.run(status, cleared) }) + } + Status { changed, running } + } +} + +impl<T: Sync + Send> Drop for Nucleo<T> { + fn drop(&mut self) { + // we ensure the worker quits before dropping items to ensure that + // the worker can always assume the items outlive it + self.canceled.store(true, atomic::Ordering::Relaxed); + let lock = self.worker.try_lock_for(Duration::from_secs(1)); + if lock.is_none() { + unreachable!("thread pool failed to shutdown properly") + } + } +} diff --git a/src/par_sort.rs b/src/par_sort.rs new file mode 100644 index 00000000..92f716cc --- /dev/null +++ b/src/par_sort.rs @@ -0,0 +1,895 @@ +//! Parallel quicksort. +//! +//! This implementation is copied verbatim from `std::slice::sort_unstable` and then parallelized. +//! The only difference from the original is that calls to `recurse` are executed in parallel using +//! `rayon_core::join`. +//! Further modified for nucleo to allow canceling the sort + +// Copyright (c) 2010 The Rust Project Developers +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::cmp; +use std::mem::{self, MaybeUninit}; +use std::ptr; +use std::sync::atomic::{self, AtomicBool}; + +/// When dropped, copies from `src` into `dest`. +struct CopyOnDrop<T> { + src: *const T, + dest: *mut T, +} + +impl<T> Drop for CopyOnDrop<T> { + fn drop(&mut self) { + // SAFETY: This is a helper class. + // Please refer to its usage for correctness. + // Namely, one must be sure that `src` and `dst` does not overlap as required by `ptr::copy_nonoverlapping`. + unsafe { + ptr::copy_nonoverlapping(self.src, self.dest, 1); + } + } +} + +/// Shifts the first element to the right until it encounters a greater or equal element. +fn shift_head<T, F>(v: &mut [T], is_less: &F) +where + F: Fn(&T, &T) -> bool, +{ + let len = v.len(); + // SAFETY: The unsafe operations below involves indexing without a bounds check (by offsetting a + // pointer) and copying memory (`ptr::copy_nonoverlapping`). + // + // a. Indexing: + // 1. We checked the size of the array to >=2. + // 2. All the indexing that we will do is always between {0 <= index < len} at most. + // + // b. Memory copying + // 1. We are obtaining pointers to references which are guaranteed to be valid. + // 2. They cannot overlap because we obtain pointers to difference indices of the slice. + // Namely, `i` and `i-1`. + // 3. If the slice is properly aligned, the elements are properly aligned. + // It is the caller's responsibility to make sure the slice is properly aligned. + // + // See comments below for further detail. + unsafe { + // If the first two elements are out-of-order... + if len >= 2 && is_less(v.get_unchecked(1), v.get_unchecked(0)) { + // Read the first element into a stack-allocated variable. If a following comparison + // operation panics, `hole` will get dropped and automatically write the element back + // into the slice. + let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(0))); + let v = v.as_mut_ptr(); + let mut hole = CopyOnDrop { + src: &*tmp, + dest: v.add(1), + }; + ptr::copy_nonoverlapping(v.add(1), v.add(0), 1); + + for i in 2..len { + if !is_less(&*v.add(i), &*tmp) { + break; + } + + // Move `i`-th element one place to the left, thus shifting the hole to the right. + ptr::copy_nonoverlapping(v.add(i), v.add(i - 1), 1); + hole.dest = v.add(i); + } + // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. + } + } +} + +/// Shifts the last element to the left until it encounters a smaller or equal element. +fn shift_tail<T, F>(v: &mut [T], is_less: &F) +where + F: Fn(&T, &T) -> bool, +{ + let len = v.len(); + // SAFETY: The unsafe operations below involves indexing without a bound check (by offsetting a + // pointer) and copying memory (`ptr::copy_nonoverlapping`). + // + // a. Indexing: + // 1. We checked the size of the array to >= 2. + // 2. All the indexing that we will do is always between `0 <= index < len-1` at most. + // + // b. Memory copying + // 1. We are obtaining pointers to references which are guaranteed to be valid. + // 2. They cannot overlap because we obtain pointers to difference indices of the slice. + // Namely, `i` and `i+1`. + // 3. If the slice is properly aligned, the elements are properly aligned. + // It is the caller's responsibility to make sure the slice is properly aligned. + // + // See comments below for further detail. + unsafe { + // If the last two elements are out-of-order... + if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) { + // Read the last element into a stack-allocated variable. If a following comparison + // operation panics, `hole` will get dropped and automatically write the element back + // into the slice. + let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1))); + let v = v.as_mut_ptr(); + let mut hole = CopyOnDrop { + src: &*tmp, + dest: v.add(len - 2), + }; + ptr::copy_nonoverlapping(v.add(len - 2), v.add(len - 1), 1); + + for i in (0..len - 2).rev() { + if !is_less(&*tmp, &*v.add(i)) { + break; + } + + // Move `i`-th element one place to the right, thus shifting the hole to the left. + ptr::copy_nonoverlapping(v.add(i), v.add(i + 1), 1); + hole.dest = v.add(i); + } + // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. + } + } +} + +/// Partially sorts a slice by shifting several out-of-order elements around. +/// +/// Returns `true` if the slice is sorted at the end. This function is *O*(*n*) worst-case. +#[cold] +fn partial_insertion_sort<T, F>(v: &mut [T], is_less: &F) -> bool +where + F: Fn(&T, &T) -> bool, +{ + // Maximum number of adjacent out-of-order pairs that will get shifted. + const MAX_STEPS: usize = 5; + // If the slice is shorter than this, don't shift any elements. + const SHORTEST_SHIFTING: usize = 50; + + let len = v.len(); + let mut i = 1; + + for _ in 0..MAX_STEPS { + // SAFETY: We already explicitly did the bound checking with `i < len`. + // All our subsequent indexing is only in the range `0 <= index < len` + unsafe { + // Find the next pair of adjacent out-of-order elements. + while i < len && !is_less(v.get_unchecked(i), v.get_unchecked(i - 1)) { + i += 1; + } + } + + // Are we done? + if i == len { + return true; + } + + // Don't shift elements on short arrays, that has a performance cost. + if len < SHORTEST_SHIFTING { + return false; + } + + // Swap the found pair of elements. This puts them in correct order. + v.swap(i - 1, i); + + // Shift the smaller element to the left. + shift_tail(&mut v[..i], is_less); + // Shift the greater element to the right. + shift_head(&mut v[i..], is_less); + } + + // Didn't manage to sort the slice in the limited number of steps. + false +} + +/// Sorts a slice using insertion sort, which is *O*(*n*^2) worst-case. +fn insertion_sort<T, F>(v: &mut [T], is_less: &F) +where + F: Fn(&T, &T) -> bool, +{ + for i in 1..v.len() { + shift_tail(&mut v[..i + 1], is_less); + } +} + +/// Sorts `v` using heapsort, which guarantees *O*(*n* \* log(*n*)) worst-case. +#[cold] +fn heapsort<T, F>(v: &mut [T], is_less: &F) +where + F: Fn(&T, &T) -> bool, +{ + // This binary heap respects the invariant `parent >= child`. + let sift_down = |v: &mut [T], mut node| { + loop { + // Children of `node`. + let mut child = 2 * node + 1; + if child >= v.len() { + break; + } + + // Choose the greater child. + if child + 1 < v.len() && is_less(&v[child], &v[child + 1]) { + child += 1; + } + + // Stop if the invariant holds at `node`. + if !is_less(&v[node], &v[child]) { + break; + } + + // Swap `node` with the greater child, move one step down, and continue sifting. + v.swap(node, child); + node = child; + } + }; + + // Build the heap in linear time. + for i in (0..v.len() / 2).rev() { + sift_down(v, i); + } + + // Pop maximal elements from the heap. + for i in (1..v.len()).rev() { + v.swap(0, i); + sift_down(&mut v[..i], 0); + } +} + +/// Partitions `v` into elements smaller than `pivot`, followed by elements greater than or equal +/// to `pivot`. +/// +/// Returns the number of elements smaller than `pivot`. +/// +/// Partitioning is performed block-by-block in order to minimize the cost of branching operations. +/// This idea is presented in the [BlockQuicksort][pdf] paper. +/// +/// [pdf]: https://drops.dagstuhl.de/opus/volltexte/2016/6389/pdf/LIPIcs-ESA-2016-38.pdf +fn partition_in_blocks<T, F>(v: &mut [T], pivot: &T, is_less: &F) -> usize +where + F: Fn(&T, &T) -> bool, +{ + // Number of elements in a typical block. + const BLOCK: usize = 128; + + // The partitioning algorithm repeats the following steps until completion: + // + // 1. Trace a block from the left side to identify elements greater than or equal to the pivot. + // 2. Trace a block from the right side to identify elements smaller than the pivot. + // 3. Exchange the identified elements between the left and right side. + // + // We keep the following variables for a block of elements: + // + // 1. `block` - Number of elements in the block. + // 2. `start` - Start pointer into the `offsets` array. + // 3. `end` - End pointer into the `offsets` array. + // 4. `offsets - Indices of out-of-order elements within the block. + + // The current block on the left side (from `l` to `l.add(block_l)`). + let mut l = v.as_mut_ptr(); + let mut block_l = BLOCK; + let mut start_l = ptr::null_mut(); + let mut end_l = ptr::null_mut(); + let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK]; + + // The current block on the right side (from `r.sub(block_r)` to `r`). + // SAFETY: The documentation for .add() specifically mention that `vec.as_ptr().add(vec.len())` is always safe` + let mut r = unsafe { l.add(v.len()) }; + let mut block_r = BLOCK; + let mut start_r = ptr::null_mut(); + let mut end_r = ptr::null_mut(); + let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK]; + + // FIXME: When we get VLAs, try creating one array of length `min(v.len(), 2 * BLOCK)` rather + // than two fixed-size arrays of length `BLOCK`. VLAs might be more cache-efficient. + + // Returns the number of elements between pointers `l` (inclusive) and `r` (exclusive). + fn width<T>(l: *mut T, r: *mut T) -> usize { + assert!(mem::size_of::<T>() > 0); + // FIXME: this should *likely* use `offset_from`, but more + // investigation is needed (including running tests in miri). + // TODO unstable: (r.addr() - l.addr()) / mem::size_of::<T>() + (r as usize - l as usize) / mem::size_of::<T>() + } + + loop { + // We are done with partitioning block-by-block when `l` and `r` get very close. Then we do + // some patch-up work in order to partition the remaining elements in between. + let is_done = width(l, r) <= 2 * BLOCK; + + if is_done { + // Number of remaining elements (still not compared to the pivot). + let mut rem = width(l, r); + if start_l < end_l || start_r < end_r { + rem -= BLOCK; + } + + // Adjust block sizes so that the left and right block don't overlap, but get perfectly + // aligned to cover the whole remaining gap. + if start_l < end_l { + block_r = rem; + } else if start_r < end_r { + block_l = rem; + } else { + // There were the same number of elements to switch on both blocks during the last + // iteration, so there are no remaining elements on either block. Cover the remaining + // items with roughly equally-sized blocks. + block_l = rem / 2; + block_r = rem - block_l; + } + debug_assert!(block_l <= BLOCK && block_r <= BLOCK); + debug_assert!(width(l, r) == block_l + block_r); + } + + if start_l == end_l { + // Trace `block_l` elements from the left side. + // TODO unstable: start_l = MaybeUninit::slice_as_mut_ptr(&mut offsets_l); + start_l = offsets_l.as_mut_ptr() as *mut u8; + end_l = start_l; + let mut elem = l; + + for i in 0..block_l { + // SAFETY: The unsafety operations below involve the usage of the `offset`. + // According to the conditions required by the function, we satisfy them because: + // 1. `offsets_l` is stack-allocated, and thus considered separate allocated object. + // 2. The function `is_less` returns a `bool`. + // Casting a `bool` will never overflow `isize`. + // 3. We have guaranteed that `block_l` will be `<= BLOCK`. + // Plus, `end_l` was initially set to the begin pointer of `offsets_` which was declared on the stack. + // Thus, we know that even in the worst case (all invocations of `is_less` returns false) we will only be at most 1 byte pass the end. + // Another unsafety operation here is dereferencing `elem`. + // However, `elem` was initially the begin pointer to the slice which is always valid. + unsafe { + // Branchless comparison. + *end_l = i as u8; + end_l = end_l.offset(!is_less(&*elem, pivot) as isize); + elem = elem.offset(1); + } + } + } + + if start_r == end_r { + // Trace `block_r` elements from the right side. + // TODO unstable: start_r = MaybeUninit::slice_as_mut_ptr(&mut offsets_r); + start_r = offsets_r.as_mut_ptr() as *mut u8; + end_r = start_r; + let mut elem = r; + + for i in 0..block_r { + // SAFETY: The unsafety operations below involve the usage of the `offset`. + // According to the conditions required by the function, we satisfy them because: + // 1. `offsets_r` is stack-allocated, and thus considered separate allocated object. + // 2. The function `is_less` returns a `bool`. + // Casting a `bool` will never overflow `isize`. + // 3. We have guaranteed that `block_r` will be `<= BLOCK`. + // Plus, `end_r` was initially set to the begin pointer of `offsets_` which was declared on the stack. + // Thus, we know that even in the worst case (all invocations of `is_less` returns true) we will only be at most 1 byte pass the end. + // Another unsafety operation here is dereferencing `elem`. + // However, `elem` was initially `1 * sizeof(T)` past the end and we decrement it by `1 * sizeof(T)` before accessing it. + // Plus, `block_r` was asserted to be less than `BLOCK` and `elem` will therefore at most be pointing to the beginning of the slice. + unsafe { + // Branchless comparison. + elem = elem.offset(-1); + *end_r = i as u8; + end_r = end_r.offset(is_less(&*elem, pivot) as isize); + } + } + } + + // Number of out-of-order elements to swap between the left and right side. + let count = cmp::min(width(start_l, end_l), width(start_r, end_r)); + + if count > 0 { + macro_rules! left { + () => { + l.offset(*start_l as isize) + }; + } + macro_rules! right { + () => { + r.offset(-(*start_r as isize) - 1) + }; + } + + // Instead of swapping one pair at the time, it is more efficient to perform a cyclic + // permutation. This is not strictly equivalent to swapping, but produces a similar + // result using fewer memory operations. + + // SAFETY: The use of `ptr::read` is valid because there is at least one element in + // both `offsets_l` and `offsets_r`, so `left!` is a valid pointer to read from. + // + // The uses of `left!` involve calls to `offset` on `l`, which points to the + // beginning of `v`. All the offsets pointed-to by `start_l` are at most `block_l`, so + // these `offset` calls are safe as all reads are within the block. The same argument + // applies for the uses of `right!`. + // + // The calls to `start_l.offset` are valid because there are at most `count-1` of them, + // plus the final one at the end of the unsafe block, where `count` is the minimum number + // of collected offsets in `offsets_l` and `offsets_r`, so there is no risk of there not + // being enough elements. The same reasoning applies to the calls to `start_r.offset`. + // + // The calls to `copy_nonoverlapping` are safe because `left!` and `right!` are guaranteed + // not to overlap, and are valid because of the reasoning above. + unsafe { + let tmp = ptr::read(left!()); + ptr::copy_nonoverlapping(right!(), left!(), 1); + + for _ in 1..count { + start_l = start_l.offset(1); + ptr::copy_nonoverlapping(left!(), right!(), 1); + start_r = start_r.offset(1); + ptr::copy_nonoverlapping(right!(), left!(), 1); + } + + ptr::copy_nonoverlapping(&tmp, right!(), 1); + mem::forget(tmp); + start_l = start_l.offset(1); + start_r = start_r.offset(1); + } + } + + if start_l == end_l { + // All out-of-order elements in the left block were moved. Move to the next block. + + // block-width-guarantee + // SAFETY: if `!is_done` then the slice width is guaranteed to be at least `2*BLOCK` wide. There + // are at most `BLOCK` elements in `offsets_l` because of its size, so the `offset` operation is + // safe. Otherwise, the debug assertions in the `is_done` case guarantee that + // `width(l, r) == block_l + block_r`, namely, that the block sizes have been adjusted to account + // for the smaller number of remaining elements. + l = unsafe { l.add(block_l) }; + } + + if start_r == end_r { + // All out-of-order elements in the right block were moved. Move to the previous block. + + // SAFETY: Same argument as [block-width-guarantee]. Either this is a full block `2*BLOCK`-wide, + // or `block_r` has been adjusted for the last handful of elements. + r = unsafe { r.offset(-(block_r as isize)) }; + } + + if is_done { + break; + } + } + + // All that remains now is at most one block (either the left or the right) with out-of-order + // elements that need to be moved. Such remaining elements can be simply shifted to the end + // within their block. + + if start_l < end_l { + // The left block remains. + // Move its remaining out-of-order elements to the far right. + debug_assert_eq!(width(l, r), block_l); + while start_l < end_l { + // remaining-elements-safety + // SAFETY: while the loop condition holds there are still elements in `offsets_l`, so it + // is safe to point `end_l` to the previous element. + // + // The `ptr::swap` is safe if both its arguments are valid for reads and writes: + // - Per the debug assert above, the distance between `l` and `r` is `block_l` + // elements, so there can be at most `block_l` remaining offsets between `start_l` + // and `end_l`. This means `r` will be moved at most `block_l` steps back, which + // makes the `r.offset` calls valid (at that point `l == r`). + // - `offsets_l` contains valid offsets into `v` collected during the partitioning of + // the last block, so the `l.offset` calls are valid. + unsafe { + end_l = end_l.offset(-1); + ptr::swap(l.offset(*end_l as isize), r.offset(-1)); + r = r.offset(-1); + } + } + width(v.as_mut_ptr(), r) + } else if start_r < end_r { + // The right block remains. + // Move its remaining out-of-order elements to the far left. + debug_assert_eq!(width(l, r), block_r); + while start_r < end_r { + // SAFETY: See the reasoning in [remaining-elements-safety]. + unsafe { + end_r = end_r.offset(-1); + ptr::swap(l, r.offset(-(*end_r as isize) - 1)); + l = l.offset(1); + } + } + width(v.as_mut_ptr(), l) + } else { + // Nothing else to do, we're done. + width(v.as_mut_ptr(), l) + } +} + +/// Partitions `v` into elements smaller than `v[pivot]`, followed by elements greater than or +/// equal to `v[pivot]`. +/// +/// Returns a tuple of: +/// +/// 1. Number of elements smaller than `v[pivot]`. +/// 2. True if `v` was already partitioned. +fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> (usize, bool) +where + F: Fn(&T, &T) -> bool, +{ + let (mid, was_partitioned) = { + // Place the pivot at the beginning of slice. + v.swap(0, pivot); + let (pivot, v) = v.split_at_mut(1); + let pivot = &mut pivot[0]; + + // Read the pivot into a stack-allocated variable for efficiency. If a following comparison + // operation panics, the pivot will be automatically written back into the slice. + + // SAFETY: `pivot` is a reference to the first element of `v`, so `ptr::read` is safe. + let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) }); + let _pivot_guard = CopyOnDrop { + src: &*tmp, + dest: pivot, + }; + let pivot = &*tmp; + + // Find the first pair of out-of-order elements. + let mut l = 0; + let mut r = v.len(); + + // SAFETY: The unsafety below involves indexing an array. + // For the first one: We already do the bounds checking here with `l < r`. + // For the second one: We initially have `l == 0` and `r == v.len()` and we checked that `l < r` at every indexing operation. + // From here we know that `r` must be at least `r == l` which was shown to be valid from the first one. + unsafe { + // Find the first element greater than or equal to the pivot. + while l < r && is_less(v.get_unchecked(l), pivot) { + l += 1; + } + + // Find the last element smaller that the pivot. + while l < r && !is_less(v.get_unchecked(r - 1), pivot) { + r -= 1; + } + } + + ( + l + partition_in_blocks(&mut v[l..r], pivot, is_less), + l >= r, + ) + + // `_pivot_guard` goes out of scope and writes the pivot (which is a stack-allocated + // variable) back into the slice where it originally was. This step is critical in ensuring + // safety! + }; + + // Place the pivot between the two partitions. + v.swap(0, mid); + + (mid, was_partitioned) +} + +/// Partitions `v` into elements equal to `v[pivot]` followed by elements greater than `v[pivot]`. +/// +/// Returns the number of elements equal to the pivot. It is assumed that `v` does not contain +/// elements smaller than the pivot. +fn partition_equal<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> usize +where + F: Fn(&T, &T) -> bool, +{ + // Place the pivot at the beginning of slice. + v.swap(0, pivot); + let (pivot, v) = v.split_at_mut(1); + let pivot = &mut pivot[0]; + + // Read the pivot into a stack-allocated variable for efficiency. If a following comparison + // operation panics, the pivot will be automatically written back into the slice. + // SAFETY: The pointer here is valid because it is obtained from a reference to a slice. + let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) }); + let _pivot_guard = CopyOnDrop { + src: &*tmp, + dest: pivot, + }; + let pivot = &*tmp; + + // Now partition the slice. + let mut l = 0; + let mut r = v.len(); + loop { + // SAFETY: The unsafety below involves indexing an array. + // For the first one: We already do the bounds checking here with `l < r`. + // For the second one: We initially have `l == 0` and `r == v.len()` and we checked that `l < r` at every indexing operation. + // From here we know that `r` must be at least `r == l` which was shown to be valid from the first one. + unsafe { + // Find the first element greater than the pivot. + while l < r && !is_less(pivot, v.get_unchecked(l)) { + l += 1; + } + + // Find the last element equal to the pivot. + while l < r && is_less(pivot, v.get_unchecked(r - 1)) { + r -= 1; + } + + // Are we done? + if l >= r { + break; + } + + // Swap the found pair of out-of-order elements. + r -= 1; + let ptr = v.as_mut_ptr(); + ptr::swap(ptr.add(l), ptr.add(r)); + l += 1; + } + } + + // We found `l` elements equal to the pivot. Add 1 to account for the pivot itself. + l + 1 + + // `_pivot_guard` goes out of scope and writes the pivot (which is a stack-allocated variable) + // back into the slice where it originally was. This step is critical in ensuring safety! +} + +/// Scatters some elements around in an attempt to break patterns that might cause imbalanced +/// partitions in quicksort. +#[cold] +fn break_patterns<T>(v: &mut [T]) { + let len = v.len(); + if len >= 8 { + // Pseudorandom number generator from the "Xorshift RNGs" paper by George Marsaglia. + let mut random = len as u32; + let mut gen_u32 = || { + random ^= random << 13; + random ^= random >> 17; + random ^= random << 5; + random + }; + let mut gen_usize = || { + if usize::BITS <= 32 { + gen_u32() as usize + } else { + (((gen_u32() as u64) << 32) | (gen_u32() as u64)) as usize + } + }; + + // Take random numbers modulo this number. + // The number fits into `usize` because `len` is not greater than `isize::MAX`. + let modulus = len.next_power_of_two(); + + // Some pivot candidates will be in the nearby of this index. Let's randomize them. + let pos = len / 4 * 2; + + for i in 0..3 { + // Generate a random number modulo `len`. However, in order to avoid costly operations + // we first take it modulo a power of two, and then decrease by `len` until it fits + // into the range `[0, len - 1]`. + let mut other = gen_usize() & (modulus - 1); + + // `other` is guaranteed to be less than `2 * len`. + if other >= len { + other -= len; + } + + v.swap(pos - 1 + i, other); + } + } +} + +/// Chooses a pivot in `v` and returns the index and `true` if the slice is likely already sorted. +/// +/// Elements in `v` might be reordered in the process. +fn choose_pivot<T, F>(v: &mut [T], is_less: &F) -> (usize, bool) +where + F: Fn(&T, &T) -> bool, +{ + // Minimum length to choose the median-of-medians method. + // Shorter slices use the simple median-of-three method. + const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50; + // Maximum number of swaps that can be performed in this function. + const MAX_SWAPS: usize = 4 * 3; + + let len = v.len(); + + // Three indices near which we are going to choose a pivot. + #[allow(clippy::identity_op)] + let mut a = len / 4 * 1; + let mut b = len / 4 * 2; + let mut c = len / 4 * 3; + + // Counts the total number of swaps we are about to perform while sorting indices. + let mut swaps = 0; + + if len >= 8 { + // Swaps indices so that `v[a] <= v[b]`. + // SAFETY: `len >= 8` so there are at least two elements in the neighborhoods of + // `a`, `b` and `c`. This means the three calls to `sort_adjacent` result in + // corresponding calls to `sort3` with valid 3-item neighborhoods around each + // pointer, which in turn means the calls to `sort2` are done with valid + // references. Thus the `v.get_unchecked` calls are safe, as is the `ptr::swap` + // call. + let mut sort2 = |a: &mut usize, b: &mut usize| unsafe { + if is_less(v.get_unchecked(*b), v.get_unchecked(*a)) { + ptr::swap(a, b); + swaps += 1; + } + }; + + // Swaps indices so that `v[a] <= v[b] <= v[c]`. + let mut sort3 = |a: &mut usize, b: &mut usize, c: &mut usize| { + sort2(a, b); + sort2(b, c); + sort2(a, b); + }; + + if len >= SHORTEST_MEDIAN_OF_MEDIANS { + // Finds the median of `v[a - 1], v[a], v[a + 1]` and stores the index into `a`. + let mut sort_adjacent = |a: &mut usize| { + let tmp = *a; + sort3(&mut (tmp - 1), a, &mut (tmp + 1)); + }; + + // Find medians in the neighborhoods of `a`, `b`, and `c`. + sort_adjacent(&mut a); + sort_adjacent(&mut b); + sort_adjacent(&mut c); + } + + // Find the median among `a`, `b`, and `c`. + sort3(&mut a, &mut b, &mut c); + } + + if swaps < MAX_SWAPS { + (b, swaps == 0) + } else { + // The maximum number of swaps was performed. Chances are the slice is descending or mostly + // descending, so reversing will probably help sort it faster. + v.reverse(); + (len - 1 - b, true) + } +} + +/// Sorts `v` recursively. +/// +/// If the slice had a predecessor in the original array, it is specified as `pred`. +/// +/// `limit` is the number of allowed imbalanced partitions before switching to `heapsort`. If zero, +/// this function will immediately switch to heapsort. +fn recurse<'a, T, F>( + mut v: &'a mut [T], + is_less: &F, + mut pred: Option<&'a mut T>, + mut limit: u32, + canceled: &AtomicBool, +) -> bool +where + T: Send, + F: Fn(&T, &T) -> bool + Sync, +{ + // Slices of up to this length get sorted using insertion sort. + const MAX_INSERTION: usize = 20; + // If both partitions are up to this length, we continue sequentially. This number is as small + // as possible but so that the overhead of Rayon's task scheduling is still negligible. + const MAX_SEQUENTIAL: usize = 2000; + + // True if the last partitioning was reasonably balanced. + let mut was_balanced = true; + // True if the last partitioning didn't shuffle elements (the slice was already partitioned). + let mut was_partitioned = true; + + loop { + let len = v.len(); + + // Very short slices get sorted using insertion sort. + if len <= MAX_INSERTION { + insertion_sort(v, is_less); + return false; + } + + // If too many bad pivot choices were made, simply fall back to heapsort in order to + // guarantee `O(n * log(n))` worst-case. + if limit == 0 { + heapsort(v, is_less); + return false; + } + + // If the last partitioning was imbalanced, try breaking patterns in the slice by shuffling + // some elements around. Hopefully we'll choose a better pivot this time. + if !was_balanced { + break_patterns(v); + limit -= 1; + } + + // Choose a pivot and try guessing whether the slice is already sorted. + let (pivot, likely_sorted) = choose_pivot(v, is_less); + + // If the last partitioning was decently balanced and didn't shuffle elements, and if pivot + // selection predicts the slice is likely already sorted... + if was_balanced && was_partitioned && likely_sorted { + // Try identifying several out-of-order elements and shifting them to correct + // positions. If the slice ends up being completely sorted, we're done. + if partial_insertion_sort(v, is_less) { + return false; + } + } + + // If the chosen pivot is equal to the predecessor, then it's the smallest element in the + // slice. Partition the slice into elements equal to and elements greater than the pivot. + // This case is usually hit when the slice contains many duplicate elements. + if let Some(ref p) = pred { + if !is_less(p, &v[pivot]) { + let mid = partition_equal(v, pivot, is_less); + + // Continue sorting elements greater than the pivot. + v = &mut v[mid..]; + continue; + } + } + + // Partition the slice. + let (mid, was_p) = partition(v, pivot, is_less); + was_balanced = cmp::min(mid, len - mid) >= len / 8; + was_partitioned = was_p; + + // Split the slice into `left`, `pivot`, and `right`. + let (left, right) = v.split_at_mut(mid); + let (pivot, right) = right.split_at_mut(1); + let pivot = &mut pivot[0]; + + if cmp::max(left.len(), right.len()) <= MAX_SEQUENTIAL { + // Recurse into the shorter side only in order to minimize the total number of recursive + // calls and consume less stack space. Then just continue with the longer side (this is + // akin to tail recursion). + if left.len() < right.len() { + recurse(left, is_less, pred, limit, canceled); + v = right; + pred = Some(pivot); + } else { + recurse(right, is_less, Some(pivot), limit, canceled); + v = left; + } + } else if canceled.load(atomic::Ordering::Relaxed) { + break true; + } else { + // Sort the left and right half in parallel. + let (canceled1, canceled2) = rayon::join( + || recurse(left, is_less, pred, limit, canceled), + || recurse(right, is_less, Some(pivot), limit, canceled), + ); + break canceled1 | canceled2; + } + } +} + +/// Sorts `v` using pattern-defeating quicksort in parallel. +/// +/// The algorithm is unstable, in-place, and *O*(*n* \* log(*n*)) worst-case. +pub(crate) fn par_quicksort<T, F>(v: &mut [T], is_less: F, canceled: &AtomicBool) -> bool +where + T: Send, + F: Fn(&T, &T) -> bool + Sync, +{ + // Sorting has no meaningful behavior on zero-sized types. + if mem::size_of::<T>() == 0 { + return false; + } + if canceled.load(atomic::Ordering::Relaxed) { + return true; + } + + // Limit the number of imbalanced partitions to `floor(log2(len)) + 1`. + let limit = usize::BITS - v.len().leading_zeros(); + + recurse(v, &is_less, None, limit, canceled) +} diff --git a/src/pattern.rs b/src/pattern.rs new file mode 100644 index 00000000..816b0a31 --- /dev/null +++ b/src/pattern.rs @@ -0,0 +1,100 @@ +pub use nucleo_matcher::pattern::{Atom, AtomKind, CaseMatching, Normalization, Pattern}; +use nucleo_matcher::{Matcher, Utf32String}; + +#[cfg(test)] +mod tests; + +#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord, Default)] +pub(crate) enum Status { + #[default] + Unchanged, + Update, + Rescore, +} + +#[derive(Debug)] +pub struct MultiPattern { + cols: Vec<(Pattern, Status)>, +} + +impl Clone for MultiPattern { + fn clone(&self) -> Self { + Self { + cols: self.cols.clone(), + } + } + + fn clone_from(&mut self, source: &Self) { + self.cols.clone_from(&source.cols) + } +} + +impl MultiPattern { + /// Creates a multi pattern with `columns` empty column patterns. + pub fn new(columns: usize) -> Self { + Self { + cols: vec![Default::default(); columns], + } + } + + /// Reparses a column. By specifying `append` the caller promises that text passed + /// to the previous `reparse` invocation is a prefix of `new_text`. This enables + /// additional optimizations but can lead to missing matches if an incorrect value + /// is passed. + pub fn reparse( + &mut self, + column: usize, + new_text: &str, + case_matching: CaseMatching, + normalization: Normalization, + append: bool, + ) { + let old_status = self.cols[column].1; + if append + && old_status != Status::Rescore + && self.cols[column] + .0 + .atoms + .last() + .map_or(true, |last| !last.negative) + { + self.cols[column].1 = Status::Update; + } else { + self.cols[column].1 = Status::Rescore; + } + self.cols[column] + .0 + .reparse(new_text, case_matching, normalization); + } + + pub fn column_pattern(&self, column: usize) -> &Pattern { + &self.cols[column].0 + } + + pub(crate) fn status(&self) -> Status { + self.cols + .iter() + .map(|&(_, status)| status) + .max() + .unwrap_or(Status::Unchanged) + } + + pub(crate) fn reset_status(&mut self) { + for (_, status) in &mut self.cols { + *status = Status::Unchanged + } + } + + pub fn score(&self, haystack: &[Utf32String], matcher: &mut Matcher) -> Option<u32> { + // TODO: wheight columns? + let mut score = 0; + for ((pattern, _), haystack) in self.cols.iter().zip(haystack) { + score += pattern.score(haystack.slice(..), matcher)? + } + Some(score) + } + + pub fn is_empty(&self) -> bool { + self.cols.iter().all(|(pat, _)| pat.atoms.is_empty()) + } +} diff --git a/src/pattern/tests.rs b/src/pattern/tests.rs new file mode 100644 index 00000000..40e8e328 --- /dev/null +++ b/src/pattern/tests.rs @@ -0,0 +1,14 @@ +use nucleo_matcher::pattern::{CaseMatching, Normalization}; + +use crate::pattern::{MultiPattern, Status}; + +#[test] +fn append() { + let mut pat = MultiPattern::new(1); + pat.reparse(0, "!", CaseMatching::Smart, Normalization::Smart, true); + assert_eq!(pat.status(), Status::Update); + pat.reparse(0, "!f", CaseMatching::Smart, Normalization::Smart, true); + assert_eq!(pat.status(), Status::Update); + pat.reparse(0, "!fo", CaseMatching::Smart, Normalization::Smart, true); + assert_eq!(pat.status(), Status::Rescore); +} diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 00000000..676c50df --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,27 @@ +use std::sync::Arc; + +use nucleo_matcher::Config; + +use crate::Nucleo; + +#[test] +fn active_injector_count() { + let mut nucleo: Nucleo<()> = Nucleo::new(Config::DEFAULT, Arc::new(|| ()), Some(1), 1); + assert_eq!(nucleo.active_injectors(), 0); + let injector = nucleo.injector(); + assert_eq!(nucleo.active_injectors(), 1); + let injector2 = nucleo.injector(); + assert_eq!(nucleo.active_injectors(), 2); + drop(injector2); + assert_eq!(nucleo.active_injectors(), 1); + nucleo.restart(false); + assert_eq!(nucleo.active_injectors(), 0); + let injector3 = nucleo.injector(); + assert_eq!(nucleo.active_injectors(), 1); + nucleo.tick(0); + assert_eq!(nucleo.active_injectors(), 1); + drop(injector); + assert_eq!(nucleo.active_injectors(), 1); + drop(injector3); + assert_eq!(nucleo.active_injectors(), 0); +} diff --git a/src/worker.rs b/src/worker.rs new file mode 100644 index 00000000..f4077e6e --- /dev/null +++ b/src/worker.rs @@ -0,0 +1,301 @@ +use std::cell::UnsafeCell; +use std::mem::take; +use std::sync::atomic::{self, AtomicBool, AtomicU32}; +use std::sync::Arc; + +use nucleo_matcher::Config; +use parking_lot::Mutex; +use rayon::{prelude::*, ThreadPool}; + +use crate::par_sort::par_quicksort; +use crate::pattern::{self, MultiPattern}; +use crate::{boxcar, Match}; + +struct Matchers(Box<[UnsafeCell<nucleo_matcher::Matcher>]>); + +impl Matchers { + // this is not a true mut from ref, we use a cell here + #[allow(clippy::mut_from_ref)] + unsafe fn get(&self) -> &mut nucleo_matcher::Matcher { + &mut *self.0[rayon::current_thread_index().unwrap()].get() + } +} + +unsafe impl Sync for Matchers {} +unsafe impl Send for Matchers {} + +pub(crate) struct Worker<T: Sync + Send + 'static> { + pub(crate) running: bool, + matchers: Matchers, + pub(crate) matches: Vec<Match>, + pub(crate) pattern: MultiPattern, + pub(crate) sort_results: bool, + pub(crate) reverse_items: bool, + pub(crate) canceled: Arc<AtomicBool>, + pub(crate) should_notify: Arc<AtomicBool>, + pub(crate) was_canceled: bool, + pub(crate) last_snapshot: u32, + notify: Arc<(dyn Fn() + Sync + Send)>, + pub(crate) items: Arc<boxcar::Vec<T>>, + in_flight: Vec<u32>, +} + +impl<T: Sync + Send + 'static> Worker<T> { + pub(crate) fn item_count(&self) -> u32 { + self.last_snapshot - self.in_flight.len() as u32 + } + pub(crate) fn update_config(&mut self, config: Config) { + for matcher in self.matchers.0.iter_mut() { + matcher.get_mut().config = config.clone(); + } + } + pub(crate) fn sort_results(&mut self, sort_results: bool) { + self.sort_results = sort_results; + } + pub(crate) fn reverse_items(&mut self, reverse_items: bool) { + self.reverse_items = reverse_items; + } + + pub(crate) fn new( + worker_threads: Option<usize>, + config: Config, + notify: Arc<(dyn Fn() + Sync + Send)>, + cols: u32, + ) -> (ThreadPool, Self) { + let worker_threads = worker_threads + .unwrap_or_else(|| std::thread::available_parallelism().map_or(4, |it| it.get())); + let pool = rayon::ThreadPoolBuilder::new() + .thread_name(|i| format!("nucleo worker {i}")) + .num_threads(worker_threads) + .build() + .expect("creating threadpool failed"); + let matchers = (0..worker_threads) + .map(|_| UnsafeCell::new(nucleo_matcher::Matcher::new(config.clone()))) + .collect(); + let worker = Worker { + running: false, + matchers: Matchers(matchers), + last_snapshot: 0, + matches: Vec::new(), + // just a placeholder + pattern: MultiPattern::new(cols as usize), + sort_results: true, + reverse_items: false, + canceled: Arc::new(AtomicBool::new(false)), + should_notify: Arc::new(AtomicBool::new(false)), + was_canceled: false, + notify, + items: Arc::new(boxcar::Vec::with_capacity(2 * 1024, cols)), + in_flight: Vec::with_capacity(64), + }; + (pool, worker) + } + + unsafe fn process_new_items(&mut self, unmatched: &AtomicU32) { + let matchers = &self.matchers; + let pattern = &self.pattern; + self.matches.reserve(self.in_flight.len()); + self.in_flight.retain(|&idx| { + let Some(item) = self.items.get(idx) else { + return true; + }; + if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) { + self.matches.push(Match { score, idx }); + }; + false + }); + let new_snapshot = self.items.par_snapshot(self.last_snapshot); + if new_snapshot.end() != self.last_snapshot { + let end = new_snapshot.end(); + let in_flight = Mutex::new(&mut self.in_flight); + let items = new_snapshot.map(|(idx, item)| { + let Some(item) = item else { + in_flight.lock().push(idx); + unmatched.fetch_add(1, atomic::Ordering::Relaxed); + return Match { + score: 0, + idx: u32::MAX, + }; + }; + if self.canceled.load(atomic::Ordering::Relaxed) { + return Match { score: 0, idx }; + } + let Some(score) = pattern.score(item.matcher_columns, matchers.get()) else { + unmatched.fetch_add(1, atomic::Ordering::Relaxed); + return Match { + score: 0, + idx: u32::MAX, + }; + }; + Match { score, idx } + }); + self.matches.par_extend(items); + self.last_snapshot = end; + } + } + + fn remove_in_flight_matches(&mut self) { + let mut off = 0; + self.in_flight.retain(|&i| { + let is_in_flight = self.items.get(i).is_none(); + if is_in_flight { + self.matches.remove((i - off) as usize); + off += 1; + } + is_in_flight + }); + } + + unsafe fn process_new_items_trivial(&mut self) { + let new_snapshot = self.items.snapshot(self.last_snapshot); + if new_snapshot.end() != self.last_snapshot { + let end = new_snapshot.end(); + let items = new_snapshot.filter_map(|(idx, item)| { + if item.is_none() { + self.in_flight.push(idx); + return None; + }; + Some(Match { score: 0, idx }) + }); + self.matches.extend(items); + self.last_snapshot = end; + } + } + + pub(crate) unsafe fn run(&mut self, pattern_status: pattern::Status, cleared: bool) { + self.running = true; + self.was_canceled = false; + + if cleared { + self.last_snapshot = 0; + self.in_flight.clear(); + self.matches.clear(); + } + + // TODO: be smarter around reusing past results for rescoring + if self.pattern.is_empty() { + self.reset_matches(); + self.process_new_items_trivial(); + let canceled = self.sort_matches(); + if canceled { + self.was_canceled = true; + } else if self.should_notify.load(atomic::Ordering::Relaxed) { + (self.notify)(); + } + return; + } + + if pattern_status == pattern::Status::Rescore { + self.reset_matches(); + } + + let mut unmatched = AtomicU32::new(0); + if pattern_status != pattern::Status::Unchanged && !self.matches.is_empty() { + self.process_new_items_trivial(); + let matchers = &self.matchers; + let pattern = &self.pattern; + self.matches + .par_iter_mut() + .take_any_while(|_| !self.canceled.load(atomic::Ordering::Relaxed)) + .for_each(|match_| { + if match_.idx == u32::MAX { + debug_assert_eq!(match_.score, 0); + unmatched.fetch_add(1, atomic::Ordering::Relaxed); + return; + } + // safety: in-flight items are never added to the matches + let item = self.items.get_unchecked(match_.idx); + if let Some(score) = pattern.score(item.matcher_columns, matchers.get()) { + match_.score = score; + } else { + unmatched.fetch_add(1, atomic::Ordering::Relaxed); + match_.score = 0; + match_.idx = u32::MAX; + } + }); + } else { + self.process_new_items(&unmatched); + } + + let canceled = self.sort_matches(); + if canceled { + self.was_canceled = true; + } else { + self.matches + .truncate(self.matches.len() - take(unmatched.get_mut()) as usize); + if self.should_notify.load(atomic::Ordering::Relaxed) { + (self.notify)(); + } + } + } + + unsafe fn sort_matches(&mut self) -> bool { + if self.sort_results { + par_quicksort( + &mut self.matches, + |match1, match2| { + if match1.score != match2.score { + return match1.score > match2.score; + } + if match1.idx == u32::MAX { + return false; + } + if match2.idx == u32::MAX { + return true; + } + // the tie breaker is comparatively rarely needed so we keep it + // in a branch especially because we need to access the items + // array here which involves some pointer chasing + let item1 = self.items.get_unchecked(match1.idx); + let item2 = &self.items.get_unchecked(match2.idx); + let len1: u32 = item1 + .matcher_columns + .iter() + .map(|haystack| haystack.len() as u32) + .sum(); + let len2 = item2 + .matcher_columns + .iter() + .map(|haystack| haystack.len() as u32) + .sum(); + if len1 == len2 { + if self.reverse_items { + match2.idx < match1.idx + } else { + match1.idx < match2.idx + } + } else { + len1 < len2 + } + }, + &self.canceled, + ) + } else { + par_quicksort( + &mut self.matches, + |match1, match2| { + if match1.idx == u32::MAX { + return false; + } + if match2.idx == u32::MAX { + return true; + } + if self.reverse_items { + match2.idx < match1.idx + } else { + match1.idx < match2.idx + } + }, + &self.canceled, + ) + } + } + + fn reset_matches(&mut self) { + self.matches.clear(); + self.matches + .extend((0..self.last_snapshot).map(|idx| Match { score: 0, idx })); + // there are usually only very few in flight items (one for each writer) + self.remove_in_flight_matches(); + } +} |
