diff --git a/Cargo.toml b/Cargo.toml index 57795ce8..6f9b2755 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,8 +20,15 @@ noisy_float = "0.1.8" num-traits = "0.2" rand = "0.6" itertools = { version = "0.7.0", default-features = false } +indexmap = "1.0" [dev-dependencies] -quickcheck = "0.7" +criterion = "0.2" +quickcheck = { version = "0.8.1", default-features = false } ndarray-rand = "0.9" approx = "0.3" +quickcheck_macros = "0.8" + +[[bench]] +name = "sort" +harness = false diff --git a/benches/sort.rs b/benches/sort.rs new file mode 100644 index 00000000..cdcf8dd3 --- /dev/null +++ b/benches/sort.rs @@ -0,0 +1,67 @@ +extern crate criterion; +extern crate ndarray; +extern crate ndarray_stats; +extern crate rand; + +use criterion::{ + black_box, criterion_group, criterion_main, AxisScale, BatchSize, Criterion, + ParameterizedBenchmark, PlotConfiguration, +}; +use ndarray::prelude::*; +use ndarray_stats::Sort1dExt; +use rand::prelude::*; + +fn get_from_sorted_mut(c: &mut Criterion) { + let lens = vec![10, 100, 1000, 10000]; + let benchmark = ParameterizedBenchmark::new( + "get_from_sorted_mut", + |bencher, &len| { + let mut rng = StdRng::seed_from_u64(42); + let mut data: Vec<_> = (0..len).collect(); + data.shuffle(&mut rng); + let indices: Vec<_> = (0..len).step_by(len / 10).collect(); + bencher.iter_batched( + || Array1::from(data.clone()), + |mut arr| { + for &i in &indices { + black_box(arr.get_from_sorted_mut(i)); + } + }, + BatchSize::SmallInput, + ) + }, + lens, + ) + .plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + c.bench("get_from_sorted_mut", benchmark); +} + +fn get_many_from_sorted_mut(c: &mut Criterion) { + let lens = vec![10, 100, 1000, 10000]; + let benchmark = ParameterizedBenchmark::new( + "get_many_from_sorted_mut", + |bencher, &len| { + let mut rng = StdRng::seed_from_u64(42); + let mut data: Vec<_> = (0..len).collect(); + data.shuffle(&mut rng); + let indices: Vec<_> = (0..len).step_by(len / 10).collect(); + 
bencher.iter_batched( + || Array1::from(data.clone()), + |mut arr| { + black_box(arr.get_many_from_sorted_mut(&indices)); + }, + BatchSize::SmallInput, + ) + }, + lens, + ) + .plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + c.bench("get_many_from_sorted_mut", benchmark); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = get_from_sorted_mut, get_many_from_sorted_mut +} +criterion_main!(benches); diff --git a/src/errors.rs b/src/errors.rs index d89112a5..e2ee6965 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,4 +1,5 @@ //! Custom errors returned from our methods and functions. +use noisy_float::types::N64; use std::error::Error; use std::fmt; @@ -112,3 +113,31 @@ impl From for MultiInputError { MultiInputError::ShapeMismatch(err) } } + +/// An error computing a quantile. +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum QuantileError { + /// The input was empty. + EmptyInput, + /// The `q` was not between `0.` and `1.` (inclusive). + InvalidQuantile(N64), +} + +impl fmt::Display for QuantileError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + QuantileError::EmptyInput => write!(f, "Empty input."), + QuantileError::InvalidQuantile(q) => { + write!(f, "{:} is not between 0. and 1. 
(inclusive).", q) + } + } + } +} + +impl Error for QuantileError {} + +impl From for QuantileError { + fn from(_: EmptyInput) -> QuantileError { + QuantileError::EmptyInput + } +} diff --git a/src/histogram/strategies.rs b/src/histogram/strategies.rs index 93d75a9b..0892b311 100644 --- a/src/histogram/strategies.rs +++ b/src/histogram/strategies.rs @@ -24,6 +24,7 @@ use super::errors::BinsBuildError; use super::{Bins, Edges}; use ndarray::prelude::*; use ndarray::Data; +use noisy_float::types::n64; use num_traits::{FromPrimitive, NumOps, Zero}; /// A trait implemented by all strategies to build [`Bins`] @@ -334,8 +335,8 @@ where } let mut a_copy = a.to_owned(); - let first_quartile = a_copy.quantile_mut::(0.25).unwrap(); - let third_quartile = a_copy.quantile_mut::(0.75).unwrap(); + let first_quartile = a_copy.quantile_mut(n64(0.25), &Nearest).unwrap(); + let third_quartile = a_copy.quantile_mut(n64(0.75), &Nearest).unwrap(); let iqr = third_quartile - first_quartile; let bin_width = FreedmanDiaconis::compute_bin_width(n_points, iqr); diff --git a/src/lib.rs b/src/lib.rs index 9cf586f1..14bbb3a7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,6 +25,7 @@ //! [`NumPy`]: https://docs.scipy.org/doc/numpy-1.14.1/reference/routines.statistics.html //! [`StatsBase.jl`]: https://juliastats.github.io/StatsBase.jl/latest/ +extern crate indexmap; extern crate itertools; extern crate ndarray; extern crate noisy_float; diff --git a/src/quantile/interpolate.rs b/src/quantile/interpolate.rs new file mode 100644 index 00000000..a0fe64d4 --- /dev/null +++ b/src/quantile/interpolate.rs @@ -0,0 +1,138 @@ +//! Interpolation strategies. +use noisy_float::types::N64; +use num_traits::{Float, FromPrimitive, NumOps, ToPrimitive}; + +fn float_quantile_index(q: N64, len: usize) -> N64 { + q * ((len - 1) as f64) +} + +/// Returns the fraction that the quantile is between the lower and higher indices. 
+/// +/// This ranges from 0, where the quantile exactly corresponds to the lower index, +/// to 1, where the quantile exactly corresponds to the higher index. +fn float_quantile_index_fraction(q: N64, len: usize) -> N64 { + float_quantile_index(q, len).fract() +} + +/// Returns the index of the value on the lower side of the quantile. +pub(crate) fn lower_index(q: N64, len: usize) -> usize { + float_quantile_index(q, len).floor().to_usize().unwrap() +} + +/// Returns the index of the value on the higher side of the quantile. +pub(crate) fn higher_index(q: N64, len: usize) -> usize { + float_quantile_index(q, len).ceil().to_usize().unwrap() +} + +/// Used to provide an interpolation strategy to [`quantile_axis_mut`]. +/// +/// [`quantile_axis_mut`]: ../trait.QuantileExt.html#tymethod.quantile_axis_mut +pub trait Interpolate { + /// Returns `true` iff the lower value is needed to compute the + /// interpolated value. + #[doc(hidden)] + fn needs_lower(q: N64, len: usize) -> bool; + + /// Returns `true` iff the higher value is needed to compute the + /// interpolated value. + #[doc(hidden)] + fn needs_higher(q: N64, len: usize) -> bool; + + /// Computes the interpolated value. + /// + /// **Panics** if `None` is provided for the lower value when it's needed + /// or if `None` is provided for the higher value when it's needed. + #[doc(hidden)] + fn interpolate(lower: Option, higher: Option, q: N64, len: usize) -> T; +} + +/// Select the higher value. +pub struct Higher; +/// Select the lower value. +pub struct Lower; +/// Select the nearest value. +pub struct Nearest; +/// Select the midpoint of the two values (`(lower + higher) / 2`). +pub struct Midpoint; +/// Linearly interpolate between the two values +/// (`lower + (higher - lower) * fraction`, where `fraction` is the +/// fractional part of the index surrounded by `lower` and `higher`). 
+pub struct Linear; + +impl Interpolate for Higher { + fn needs_lower(_q: N64, _len: usize) -> bool { + false + } + fn needs_higher(_q: N64, _len: usize) -> bool { + true + } + fn interpolate(_lower: Option, higher: Option, _q: N64, _len: usize) -> T { + higher.unwrap() + } +} + +impl Interpolate for Lower { + fn needs_lower(_q: N64, _len: usize) -> bool { + true + } + fn needs_higher(_q: N64, _len: usize) -> bool { + false + } + fn interpolate(lower: Option, _higher: Option, _q: N64, _len: usize) -> T { + lower.unwrap() + } +} + +impl Interpolate for Nearest { + fn needs_lower(q: N64, len: usize) -> bool { + float_quantile_index_fraction(q, len) < 0.5 + } + fn needs_higher(q: N64, len: usize) -> bool { + !>::needs_lower(q, len) + } + fn interpolate(lower: Option, higher: Option, q: N64, len: usize) -> T { + if >::needs_lower(q, len) { + lower.unwrap() + } else { + higher.unwrap() + } + } +} + +impl Interpolate for Midpoint +where + T: NumOps + Clone + FromPrimitive, +{ + fn needs_lower(_q: N64, _len: usize) -> bool { + true + } + fn needs_higher(_q: N64, _len: usize) -> bool { + true + } + fn interpolate(lower: Option, higher: Option, _q: N64, _len: usize) -> T { + let denom = T::from_u8(2).unwrap(); + let lower = lower.unwrap(); + let higher = higher.unwrap(); + lower.clone() + (higher.clone() - lower.clone()) / denom.clone() + } +} + +impl Interpolate for Linear +where + T: NumOps + Clone + FromPrimitive + ToPrimitive, +{ + fn needs_lower(_q: N64, _len: usize) -> bool { + true + } + fn needs_higher(_q: N64, _len: usize) -> bool { + true + } + fn interpolate(lower: Option, higher: Option, q: N64, len: usize) -> T { + let fraction = float_quantile_index_fraction(q, len).to_f64().unwrap(); + let lower = lower.unwrap(); + let higher = higher.unwrap(); + let lower_f64 = lower.to_f64().unwrap(); + let higher_f64 = higher.to_f64().unwrap(); + lower.clone() + T::from_f64(fraction * (higher_f64 - lower_f64)).unwrap() + } +} diff --git a/src/quantile.rs 
b/src/quantile/mod.rs similarity index 61% rename from src/quantile.rs rename to src/quantile/mod.rs index 626b27f9..3926e24f 100644 --- a/src/quantile.rs +++ b/src/quantile/mod.rs @@ -1,181 +1,12 @@ +use self::interpolate::{higher_index, lower_index, Interpolate}; +use super::sort::get_many_from_sorted_mut_unchecked; use crate::errors::{EmptyInput, MinMaxError, MinMaxError::UndefinedOrder}; -use interpolate::Interpolate; +use errors::QuantileError; use ndarray::prelude::*; -use ndarray::{s, Data, DataMut, RemoveAxis}; +use ndarray::{Data, DataMut, RemoveAxis, Zip}; +use noisy_float::types::N64; use std::cmp; -use {MaybeNan, MaybeNanExt, Sort1dExt}; - -/// Interpolation strategies. -pub mod interpolate { - use ndarray::azip; - use ndarray::prelude::*; - use num_traits::{FromPrimitive, NumOps, ToPrimitive}; - - /// Used to provide an interpolation strategy to [`quantile_axis_mut`]. - /// - /// [`quantile_axis_mut`]: ../trait.QuantileExt.html#tymethod.quantile_axis_mut - pub trait Interpolate { - #[doc(hidden)] - fn float_quantile_index(q: f64, len: usize) -> f64 { - ((len - 1) as f64) * q - } - #[doc(hidden)] - fn lower_index(q: f64, len: usize) -> usize { - Self::float_quantile_index(q, len).floor() as usize - } - #[doc(hidden)] - fn higher_index(q: f64, len: usize) -> usize { - Self::float_quantile_index(q, len).ceil() as usize - } - #[doc(hidden)] - fn float_quantile_index_fraction(q: f64, len: usize) -> f64 { - Self::float_quantile_index(q, len).fract() - } - #[doc(hidden)] - fn needs_lower(q: f64, len: usize) -> bool; - #[doc(hidden)] - fn needs_higher(q: f64, len: usize) -> bool; - #[doc(hidden)] - fn interpolate( - lower: Option>, - higher: Option>, - q: f64, - len: usize, - ) -> Array - where - D: Dimension; - } - - /// Select the higher value. - pub struct Higher; - /// Select the lower value. - pub struct Lower; - /// Select the nearest value. - pub struct Nearest; - /// Select the midpoint of the two values (`(lower + higher) / 2`). 
- pub struct Midpoint; - /// Linearly interpolate between the two values - /// (`lower + (higher - lower) * fraction`, where `fraction` is the - /// fractional part of the index surrounded by `lower` and `higher`). - pub struct Linear; - - impl Interpolate for Higher { - fn needs_lower(_q: f64, _len: usize) -> bool { - false - } - fn needs_higher(_q: f64, _len: usize) -> bool { - true - } - fn interpolate( - _lower: Option>, - higher: Option>, - _q: f64, - _len: usize, - ) -> Array { - higher.unwrap() - } - } - - impl Interpolate for Lower { - fn needs_lower(_q: f64, _len: usize) -> bool { - true - } - fn needs_higher(_q: f64, _len: usize) -> bool { - false - } - fn interpolate( - lower: Option>, - _higher: Option>, - _q: f64, - _len: usize, - ) -> Array { - lower.unwrap() - } - } - - impl Interpolate for Nearest { - fn needs_lower(q: f64, len: usize) -> bool { - >::float_quantile_index_fraction(q, len) < 0.5 - } - fn needs_higher(q: f64, len: usize) -> bool { - !>::needs_lower(q, len) - } - fn interpolate( - lower: Option>, - higher: Option>, - q: f64, - len: usize, - ) -> Array { - if >::needs_lower(q, len) { - lower.unwrap() - } else { - higher.unwrap() - } - } - } - - impl Interpolate for Midpoint - where - T: NumOps + Clone + FromPrimitive, - { - fn needs_lower(_q: f64, _len: usize) -> bool { - true - } - fn needs_higher(_q: f64, _len: usize) -> bool { - true - } - fn interpolate( - lower: Option>, - higher: Option>, - _q: f64, - _len: usize, - ) -> Array - where - D: Dimension, - { - let denom = T::from_u8(2).unwrap(); - let mut lower = lower.unwrap(); - let higher = higher.unwrap(); - azip!( - mut lower, ref higher in { - *lower = lower.clone() + (higher.clone() - lower.clone()) / denom.clone() - } - ); - lower - } - } - - impl Interpolate for Linear - where - T: NumOps + Clone + FromPrimitive + ToPrimitive, - { - fn needs_lower(_q: f64, _len: usize) -> bool { - true - } - fn needs_higher(_q: f64, _len: usize) -> bool { - true - } - fn interpolate( - lower: 
Option>, - higher: Option>, - q: f64, - len: usize, - ) -> Array - where - D: Dimension, - { - let fraction = >::float_quantile_index_fraction(q, len); - let mut a = lower.unwrap(); - let b = higher.unwrap(); - azip!(mut a, ref b in { - let a_f64 = a.to_f64().unwrap(); - let b_f64 = b.to_f64().unwrap(); - *a = a.clone() + T::from_f64((b_f64 - a_f64) * fraction).unwrap(); - }); - a - } - } -} +use {MaybeNan, MaybeNanExt}; /// Quantile methods for `ArrayBase`. pub trait QuantileExt @@ -214,7 +45,7 @@ where /// Finds the index of the minimum value of the array skipping NaN values. /// - /// Returns `None` if the array is empty or none of the values in the array + /// Returns `Err(MinMaxError::EmptyInput)` if the array is empty or none of the values in the array /// are non-NaN values. /// /// Even if there are multiple (equal) elements that are minima, only one @@ -232,9 +63,9 @@ where /// /// let a = array![[::std::f64::NAN, 3., 5.], /// [2., 0., 6.]]; - /// assert_eq!(a.argmin_skipnan(), Some((1, 1))); + /// assert_eq!(a.argmin_skipnan(), Ok((1, 1))); /// ``` - fn argmin_skipnan(&self) -> Option + fn argmin_skipnan(&self) -> Result where A: MaybeNan, A::NotNan: Ord; @@ -299,7 +130,7 @@ where /// Finds the index of the maximum value of the array skipping NaN values. /// - /// Returns `None` if the array is empty or none of the values in the array + /// Returns `Err(MinMaxError::EmptyInput)` if the array is empty or none of the values in the array /// are non-NaN values. /// /// Even if there are multiple (equal) elements that are maxima, only one @@ -317,9 +148,9 @@ where /// /// let a = array![[::std::f64::NAN, 3., 5.], /// [2., 0., 6.]]; - /// assert_eq!(a.argmax_skipnan(), Some((1, 2))); + /// assert_eq!(a.argmax_skipnan(), Ok((1, 2))); /// ``` - fn argmax_skipnan(&self) -> Option + fn argmax_skipnan(&self) -> Result where A: MaybeNan, A::NotNan: Ord; @@ -361,7 +192,7 @@ where /// in increasing order. 
/// If `(N-1)q` is not an integer the desired quantile lies between /// two data points: we return the lower, nearest, higher or interpolated - /// value depending on the type `Interpolate` bound `I`. + /// value depending on the `interpolate` strategy. /// /// Some examples: /// - `q=0.` returns the minimum along each 1-dimensional lane; @@ -381,19 +212,83 @@ where /// - worst case: O(`m`^2); /// where `m` is the number of elements in the array. /// - /// **Panics** if `axis` is out of bounds, if the axis has length 0, or if - /// `q` is not between `0.` and `1.` (inclusive). - fn quantile_axis_mut(&mut self, axis: Axis, q: f64) -> Array + /// Returns `Err(EmptyInput)` when the specified axis has length 0. + /// + /// Returns `Err(InvalidQuantile(q))` if `q` is not between `0.` and `1.` (inclusive). + /// + /// **Panics** if `axis` is out of bounds. + fn quantile_axis_mut( + &mut self, + axis: Axis, + q: N64, + interpolate: &I, + ) -> Result, QuantileError> + where + D: RemoveAxis, + A: Ord + Clone, + S: DataMut, + I: Interpolate; + + /// A bulk version of [`quantile_axis_mut`], optimized to retrieve multiple + /// quantiles at once. + /// + /// Returns an `Array`, where subviews along `axis` of the array correspond + /// to the elements of `qs`. + /// + /// See [`quantile_axis_mut`] for additional details on quantiles and the algorithm + /// used to retrieve them. + /// + /// Returns `Err(EmptyInput)` when the specified axis has length 0. + /// + /// Returns `Err(InvalidQuantile(q))` if any `q` in `qs` is not between `0.` and `1.` (inclusive). + /// + /// **Panics** if `axis` is out of bounds. 
+ /// + /// [`quantile_axis_mut`]: #tymethod.quantile_axis_mut + /// + /// # Example + /// + /// ```rust + /// # extern crate ndarray; + /// # extern crate ndarray_stats; + /// # extern crate noisy_float; + /// # + /// use ndarray::{array, aview1, Axis}; + /// use ndarray_stats::{QuantileExt, interpolate::Nearest}; + /// use noisy_float::types::n64; + /// + /// # fn main() { + /// let mut data = array![[3, 4, 5], [6, 7, 8]]; + /// let axis = Axis(1); + /// let qs = &[n64(0.3), n64(0.7)]; + /// let quantiles = data.quantiles_axis_mut(axis, &aview1(qs), &Nearest).unwrap(); + /// for (&q, quantile) in qs.iter().zip(quantiles.axis_iter(axis)) { + /// assert_eq!(quantile, data.quantile_axis_mut(axis, q, &Nearest).unwrap()); + /// } + /// # } + /// ``` + fn quantiles_axis_mut( + &mut self, + axis: Axis, + qs: &ArrayBase, + interpolate: &I, + ) -> Result, QuantileError> where D: RemoveAxis, A: Ord + Clone, S: DataMut, + S2: Data, I: Interpolate; /// Return the `q`th quantile of the data along the specified axis, skipping NaN values. /// - /// See [`quantile_axis_mut`](##tymethod.quantile_axis_mut) for details. - fn quantile_axis_skipnan_mut(&mut self, axis: Axis, q: f64) -> Array + /// See [`quantile_axis_mut`](#tymethod.quantile_axis_mut) for details. 
+ fn quantile_axis_skipnan_mut( + &mut self, + axis: Axis, + q: N64, + interpolate: &I, + ) -> Result, QuantileError> where D: RemoveAxis, A: MaybeNan, @@ -424,7 +319,7 @@ where Ok(current_pattern_min) } - fn argmin_skipnan(&self) -> Option + fn argmin_skipnan(&self) -> Result where A: MaybeNan, A::NotNan: Ord, @@ -440,9 +335,9 @@ where }) }); if min.is_some() { - Some(pattern_min) + Ok(pattern_min) } else { - None + Err(MinMaxError::EmptyInput) } } @@ -491,7 +386,7 @@ where Ok(current_pattern_max) } - fn argmax_skipnan(&self) -> Option + fn argmax_skipnan(&self) -> Result where A: MaybeNan, A::NotNan: Ord, @@ -507,9 +402,9 @@ where }) }); if max.is_some() { - Some(pattern_max) + Ok(pattern_max) } else { - None + Err(MinMaxError::EmptyInput) } } @@ -541,37 +436,108 @@ where })) } - fn quantile_axis_mut(&mut self, axis: Axis, q: f64) -> Array + fn quantiles_axis_mut( + &mut self, + axis: Axis, + qs: &ArrayBase, + interpolate: &I, + ) -> Result, QuantileError> where D: RemoveAxis, A: Ord + Clone, S: DataMut, + S2: Data, I: Interpolate, { - assert!((0. <= q) && (q <= 1.)); - let mut lower = None; - let mut higher = None; - let axis_len = self.len_of(axis); - if I::needs_lower(q, axis_len) { - let lower_index = I::lower_index(q, axis_len); - lower = Some(self.map_axis_mut(axis, |mut x| x.sorted_get_mut(lower_index))); - if I::needs_higher(q, axis_len) { - let higher_index = I::higher_index(q, axis_len); - let relative_higher_index = higher_index - lower_index; - higher = Some(self.map_axis_mut(axis, |mut x| { - x.slice_mut(s![lower_index..]) - .sorted_get_mut(relative_higher_index) - })); - }; - } else { - higher = Some( - self.map_axis_mut(axis, |mut x| x.sorted_get_mut(I::higher_index(q, axis_len))), - ); - }; - I::interpolate(lower, higher, q, axis_len) + // Minimize number of type parameters to avoid monomorphization bloat. 
+ fn quantiles_axis_mut( + mut data: ArrayViewMut, + axis: Axis, + qs: ArrayView1, + _interpolate: &I, + ) -> Result, QuantileError> + where + D: RemoveAxis, + A: Ord + Clone, + I: Interpolate, + { + for &q in qs { + if !((q >= 0.) && (q <= 1.)) { + return Err(QuantileError::InvalidQuantile(q)); + } + } + + let axis_len = data.len_of(axis); + if axis_len == 0 { + return Err(QuantileError::EmptyInput); + } + + let mut results_shape = data.raw_dim(); + results_shape[axis.index()] = qs.len(); + if results_shape.size() == 0 { + return Ok(Array::from_shape_vec(results_shape, Vec::new()).unwrap()); + } + + let mut searched_indexes = Vec::with_capacity(2 * qs.len()); + for &q in &qs { + if I::needs_lower(q, axis_len) { + searched_indexes.push(lower_index(q, axis_len)); + } + if I::needs_higher(q, axis_len) { + searched_indexes.push(higher_index(q, axis_len)); + } + } + searched_indexes.sort(); + searched_indexes.dedup(); + + let mut results = Array::from_elem(results_shape, data.first().unwrap().clone()); + Zip::from(results.lanes_mut(axis)) + .and(data.lanes_mut(axis)) + .apply(|mut results, mut data| { + let index_map = + get_many_from_sorted_mut_unchecked(&mut data, &searched_indexes); + for (result, &q) in results.iter_mut().zip(qs) { + let lower = if I::needs_lower(q, axis_len) { + Some(index_map[&lower_index(q, axis_len)].clone()) + } else { + None + }; + let higher = if I::needs_higher(q, axis_len) { + Some(index_map[&higher_index(q, axis_len)].clone()) + } else { + None + }; + *result = I::interpolate(lower, higher, q, axis_len); + } + }); + Ok(results) + } + + quantiles_axis_mut(self.view_mut(), axis, qs.view(), interpolate) } - fn quantile_axis_skipnan_mut(&mut self, axis: Axis, q: f64) -> Array + fn quantile_axis_mut( + &mut self, + axis: Axis, + q: N64, + interpolate: &I, + ) -> Result, QuantileError> + where + D: RemoveAxis, + A: Ord + Clone, + S: DataMut, + I: Interpolate, + { + self.quantiles_axis_mut(axis, &aview1(&[q]), interpolate) + .map(|a| 
a.index_axis_move(axis, 0)) + } + + fn quantile_axis_skipnan_mut( + &mut self, + axis: Axis, + q: N64, + interpolate: &I, + ) -> Result, QuantileError> where D: RemoveAxis, A: MaybeNan, @@ -579,19 +545,28 @@ where S: DataMut, I: Interpolate, { - self.map_axis_mut(axis, |lane| { + if !((q >= 0.) && (q <= 1.)) { + return Err(QuantileError::InvalidQuantile(q)); + } + + if self.len_of(axis) == 0 { + return Err(QuantileError::EmptyInput); + } + + let quantile = self.map_axis_mut(axis, |lane| { let mut not_nan = A::remove_nan_mut(lane); A::from_not_nan_opt(if not_nan.is_empty() { None } else { Some( not_nan - .quantile_axis_mut::(Axis(0), q) - .into_raw_vec() - .remove(0), + .quantile_axis_mut::(Axis(0), q, interpolate) + .unwrap() + .into_scalar(), ) }) - }) + }); + Ok(quantile) } } @@ -608,7 +583,7 @@ where /// in increasing order. /// If `(N-1)q` is not an integer the desired quantile lies between /// two data points: we return the lower, nearest, higher or interpolated - /// value depending on the type `Interpolate` bound `I`. + /// value depending on the `interpolate` strategy. /// /// Some examples: /// - `q=0.` returns the minimum; @@ -628,11 +603,37 @@ where /// /// Returns `Err(EmptyInput)` if the array is empty. /// - /// **Panics** if `q` is not between `0.` and `1.` (inclusive). - fn quantile_mut(&mut self, q: f64) -> Result + /// Returns `Err(InvalidQuantile(q))` if `q` is not between `0.` and `1.` (inclusive). + fn quantile_mut(&mut self, q: N64, interpolate: &I) -> Result + where + A: Ord + Clone, + S: DataMut, + I: Interpolate; + + /// A bulk version of [`quantile_mut`], optimized to retrieve multiple + /// quantiles at once. + /// + /// Returns an `Array`, where the elements of the array correspond to the + /// elements of `qs`. + /// + /// Returns `Err(EmptyInput)` if the array is empty. + /// + /// Returns `Err(InvalidQuantile(q))` if any `q` in + /// `qs` is not between `0.` and `1.` (inclusive). 
+ /// + /// See [`quantile_mut`] for additional details on quantiles and the algorithm + /// used to retrieve them. + /// + /// [`quantile_mut`]: #tymethod.quantile_mut + fn quantiles_mut( + &mut self, + qs: &ArrayBase, + interpolate: &I, + ) -> Result, QuantileError> where A: Ord + Clone, S: DataMut, + S2: Data, I: Interpolate; } @@ -640,16 +641,30 @@ impl Quantile1dExt for ArrayBase where S: Data, { - fn quantile_mut(&mut self, q: f64) -> Result + fn quantile_mut(&mut self, q: N64, interpolate: &I) -> Result where A: Ord + Clone, S: DataMut, I: Interpolate, { - if self.is_empty() { - Err(EmptyInput) - } else { - Ok(self.quantile_axis_mut::(Axis(0), q).into_scalar()) - } + Ok(self + .quantile_axis_mut(Axis(0), q, interpolate)? + .into_scalar()) + } + + fn quantiles_mut( + &mut self, + qs: &ArrayBase, + interpolate: &I, + ) -> Result, QuantileError> + where + A: Ord + Clone, + S: DataMut, + S2: Data, + I: Interpolate, + { + self.quantiles_axis_mut(Axis(0), qs, interpolate) } } + +pub mod interpolate; diff --git a/src/sort.rs b/src/sort.rs index eecda8f4..5a0b5d3b 100644 --- a/src/sort.rs +++ b/src/sort.rs @@ -1,5 +1,6 @@ +use indexmap::IndexMap; use ndarray::prelude::*; -use ndarray::{s, Data, DataMut}; +use ndarray::{Data, DataMut, Slice}; use rand::prelude::*; use rand::thread_rng; @@ -27,20 +28,38 @@ where /// where n is the number of elements in the array. /// /// **Panics** if `i` is greater than or equal to `n`. - fn sorted_get_mut(&mut self, i: usize) -> A + fn get_from_sorted_mut(&mut self, i: usize) -> A where A: Ord + Clone, S: DataMut; - /// Return the index of `self[partition_index]` if `self` were to be sorted - /// in increasing order. + /// A bulk version of [`get_from_sorted_mut`], optimized to retrieve multiple + /// indexes at once. + /// It returns an `IndexMap`, with indexes as keys and retrieved elements as + /// values. 
+ /// The `IndexMap` is sorted with respect to indexes in increasing order: + /// this ordering is preserved when you iterate over it (using `iter`/`into_iter`). /// - /// `self` elements are rearranged in such a way that `self[partition_index]` - /// is in the position it would be in an array sorted in increasing order. - /// All elements smaller than `self[partition_index]` are moved to its - /// left and all elements equal or greater than `self[partition_index]` - /// are moved to its right. - /// The ordering of the elements in the two partitions is undefined. + /// **Panics** if any element in `indexes` is greater than or equal to `n`, + /// where `n` is the length of the array. + /// + /// [`get_from_sorted_mut`]: #tymethod.get_from_sorted_mut + fn get_many_from_sorted_mut(&mut self, indexes: &ArrayBase) -> IndexMap + where + A: Ord + Clone, + S: DataMut, + S2: Data; + + /// Partitions the array in increasing order based on the value initially + /// located at `pivot_index` and returns the new index of the value. + /// + /// The elements are rearranged in such a way that the value initially + /// located at `pivot_index` is moved to the position it would be in an + /// array sorted in increasing order. The return value is the new index of + /// the value after rearrangement. All elements smaller than the value are + /// moved to its left and all elements equal or greater than the value are + /// moved to its right. The ordering of the elements in the two partitions + /// is undefined. /// /// `self` is shuffled **in place** to operate the desired partition: /// no copy of the array is allocated. @@ -50,7 +69,36 @@ where /// Average number of element swaps: n/6 - 1/3 (see /// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550)) /// - /// **Panics** if `partition_index` is greater than or equal to `n`. + /// **Panics** if `pivot_index` is greater than or equal to `n`. 
+ /// + /// # Example + /// + /// ``` + /// extern crate ndarray; + /// extern crate ndarray_stats; + /// + /// use ndarray::array; + /// use ndarray_stats::Sort1dExt; + /// + /// # fn main() { + /// let mut data = array![3, 1, 4, 5, 2]; + /// let pivot_index = 2; + /// let pivot_value = data[pivot_index]; + /// + /// // Partition by the value located at `pivot_index`. + /// let new_index = data.partition_mut(pivot_index); + /// // The pivot value is now located at `new_index`. + /// assert_eq!(data[new_index], pivot_value); + /// // Elements less than that value are moved to the left. + /// for i in 0..new_index { + /// assert!(data[i] < pivot_value); + /// } + /// // Elements greater than or equal to that value are moved to the right. + /// for i in (new_index + 1)..data.len() { + /// assert!(data[i] >= pivot_value); + /// } + /// # } + /// ``` fn partition_mut(&mut self, pivot_index: usize) -> usize where A: Ord + Clone, @@ -61,7 +109,7 @@ impl Sort1dExt for ArrayBase where S: Data, { - fn sorted_get_mut(&mut self, i: usize) -> A + fn get_from_sorted_mut(&mut self, i: usize) -> A where A: Ord + Clone, S: DataMut, @@ -74,16 +122,30 @@ where let pivot_index = rng.gen_range(0, n); let partition_index = self.partition_mut(pivot_index); if i < partition_index { - self.slice_mut(s![..partition_index]).sorted_get_mut(i) + self.slice_axis_mut(Axis(0), Slice::from(..partition_index)) + .get_from_sorted_mut(i) } else if i == partition_index { self[i].clone() } else { - self.slice_mut(s![partition_index + 1..]) - .sorted_get_mut(i - (partition_index + 1)) + self.slice_axis_mut(Axis(0), Slice::from(partition_index + 1..)) + .get_from_sorted_mut(i - (partition_index + 1)) } } } + fn get_many_from_sorted_mut(&mut self, indexes: &ArrayBase) -> IndexMap + where + A: Ord + Clone, + S: DataMut, + S2: Data, + { + let mut deduped_indexes: Vec = indexes.to_vec(); + deduped_indexes.sort_unstable(); + deduped_indexes.dedup(); + + get_many_from_sorted_mut_unchecked(self, 
&deduped_indexes) + } + fn partition_mut(&mut self, pivot_index: usize) -> usize where A: Ord + Clone, @@ -122,3 +184,116 @@ where i - 1 } } + +/// To retrieve multiple indexes from the sorted array in an optimized fashion, +/// [get_many_from_sorted_mut] first of all sorts and deduplicates the +/// `indexes` vector. +/// +/// `get_many_from_sorted_mut_unchecked` does not perform this sorting and +/// deduplication, assuming that the user has already taken care of it. +/// +/// Useful when you have to call [get_many_from_sorted_mut] multiple times +/// using the same indexes. +/// +/// [get_many_from_sorted_mut]: ../trait.Sort1dExt.html#tymethod.get_many_from_sorted_mut +pub(crate) fn get_many_from_sorted_mut_unchecked( + array: &mut ArrayBase, + indexes: &[usize], +) -> IndexMap +where + A: Ord + Clone, + S: DataMut, +{ + if indexes.is_empty() { + return IndexMap::new(); + } + + // Since `!indexes.is_empty()` and indexes must be in-bounds, `array` must + // be non-empty. + let mut values = vec![array[0].clone(); indexes.len()]; + _get_many_from_sorted_mut_unchecked(array.view_mut(), &mut indexes.to_owned(), &mut values); + + // We convert the vector to a more search-friendly `IndexMap`. + indexes.iter().cloned().zip(values.into_iter()).collect() +} + +/// This is the recursive portion of `get_many_from_sorted_mut_unchecked`. +/// +/// `indexes` is the list of indexes to get. `indexes` is mutable so that it +/// can be used as scratch space for this routine; the value of `indexes` after +/// calling this routine should be ignored. +/// +/// `values` is a pre-allocated slice to use for writing the output. Its +/// initial element values are ignored. 
+fn _get_many_from_sorted_mut_unchecked( + mut array: ArrayViewMut1, + indexes: &mut [usize], + values: &mut [A], +) where + A: Ord + Clone, +{ + let n = array.len(); + debug_assert!(n >= indexes.len()); // because indexes must be unique and in-bounds + debug_assert_eq!(indexes.len(), values.len()); + + if indexes.is_empty() { + // Nothing to do in this case. + return; + } + + // At this point, `n >= 1` since `indexes.len() >= 1`. + if n == 1 { + // We can only reach this point if `indexes.len() == 1`, so we only + // need to assign the single value, and then we're done. + debug_assert_eq!(indexes.len(), 1); + values[0] = array[0].clone(); + return; + } + + // We pick a random pivot index: the corresponding element is the pivot value + let mut rng = thread_rng(); + let pivot_index = rng.gen_range(0, n); + + // We partition the array with respect to the pivot value. + // The pivot value moves to `array_partition_index`. + // Elements strictly smaller than the pivot value have indexes < `array_partition_index`. + // Elements greater or equal to the pivot value have indexes > `array_partition_index`. + let array_partition_index = array.partition_mut(pivot_index); + + // We use a divide-and-conquer strategy, splitting the indexes we are + // searching for (`indexes`) and the corresponding portions of the output + // slice (`values`) into pieces with respect to `array_partition_index`. + let (found_exact, index_split) = match indexes.binary_search(&array_partition_index) { + Ok(index) => (true, index), + Err(index) => (false, index), + }; + let (smaller_indexes, other_indexes) = indexes.split_at_mut(index_split); + let (smaller_values, other_values) = values.split_at_mut(index_split); + let (bigger_indexes, bigger_values) = if found_exact { + other_values[0] = array[array_partition_index].clone(); // Write exactly found value. 
+ (&mut other_indexes[1..], &mut other_values[1..]) + } else { + (other_indexes, other_values) + }; + + // We search recursively for the values corresponding to strictly smaller + // indexes to the left of `partition_index`. + _get_many_from_sorted_mut_unchecked( + array.slice_axis_mut(Axis(0), Slice::from(..array_partition_index)), + smaller_indexes, + smaller_values, + ); + + // We search recursively for the values corresponding to strictly bigger + // indexes to the right of `partition_index`. Since only the right portion + // of the array is passed in, the indexes need to be shifted by length of + // the removed portion. + bigger_indexes + .iter_mut() + .for_each(|x| *x -= array_partition_index + 1); + _get_many_from_sorted_mut_unchecked( + array.slice_axis_mut(Axis(0), Slice::from(array_partition_index + 1..)), + bigger_indexes, + bigger_values, + ); +} diff --git a/tests/quantile.rs b/tests/quantile.rs index 05f1c7a3..36999089 100644 --- a/tests/quantile.rs +++ b/tests/quantile.rs @@ -1,15 +1,21 @@ -#[macro_use(array)] +extern crate itertools; extern crate ndarray; extern crate ndarray_stats; +extern crate noisy_float; +#[macro_use] extern crate quickcheck; +extern crate quickcheck_macros; +use itertools::izip; +use ndarray::array; use ndarray::prelude::*; use ndarray_stats::{ - errors::MinMaxError, - interpolate::{Higher, Linear, Lower, Midpoint, Nearest}, + errors::{MinMaxError, QuantileError}, + interpolate::{Higher, Interpolate, Linear, Lower, Midpoint, Nearest}, Quantile1dExt, QuantileExt, }; -use quickcheck::quickcheck; +use noisy_float::types::{n64, N64}; +use quickcheck_macros::quickcheck; #[test] fn test_argmin() { @@ -36,19 +42,19 @@ quickcheck! 
{ #[test] fn test_argmin_skipnan() { let a = array![[1., 5., 3.], [2., 0., 6.]]; - assert_eq!(a.argmin_skipnan(), Some((1, 1))); + assert_eq!(a.argmin_skipnan(), Ok((1, 1))); let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]]; - assert_eq!(a.argmin_skipnan(), Some((0, 0))); + assert_eq!(a.argmin_skipnan(), Ok((0, 0))); let a = array![[::std::f64::NAN, 5., 3.], [2., ::std::f64::NAN, 6.]]; - assert_eq!(a.argmin_skipnan(), Some((1, 0))); + assert_eq!(a.argmin_skipnan(), Ok((1, 0))); let a: Array2 = array![[], []]; - assert_eq!(a.argmin_skipnan(), None); + assert_eq!(a.argmin_skipnan(), Err(MinMaxError::EmptyInput)); let a = arr2(&[[::std::f64::NAN; 2]; 2]); - assert_eq!(a.argmin_skipnan(), None); + assert_eq!(a.argmin_skipnan(), Err(MinMaxError::EmptyInput)); } quickcheck! { @@ -57,7 +63,7 @@ quickcheck! { let min = a.min_skipnan(); let argmin = a.argmin_skipnan(); if min.is_none() { - argmin == None + argmin == Err(MinMaxError::EmptyInput) } else { a[argmin.unwrap()] == *min } @@ -116,22 +122,22 @@ quickcheck! { #[test] fn test_argmax_skipnan() { let a = array![[1., 5., 3.], [2., 0., 6.]]; - assert_eq!(a.argmax_skipnan(), Some((1, 2))); + assert_eq!(a.argmax_skipnan(), Ok((1, 2))); let a = array![[1., 5., 3.], [2., ::std::f64::NAN, ::std::f64::NAN]]; - assert_eq!(a.argmax_skipnan(), Some((0, 1))); + assert_eq!(a.argmax_skipnan(), Ok((0, 1))); let a = array![ [::std::f64::NAN, ::std::f64::NAN, 3.], [2., ::std::f64::NAN, 6.] ]; - assert_eq!(a.argmax_skipnan(), Some((1, 2))); + assert_eq!(a.argmax_skipnan(), Ok((1, 2))); let a: Array2 = array![[], []]; - assert_eq!(a.argmax_skipnan(), None); + assert_eq!(a.argmax_skipnan(), Err(MinMaxError::EmptyInput)); let a = arr2(&[[::std::f64::NAN; 2]; 2]); - assert_eq!(a.argmax_skipnan(), None); + assert_eq!(a.argmax_skipnan(), Err(MinMaxError::EmptyInput)); } quickcheck! { @@ -140,7 +146,7 @@ quickcheck! 
{ let max = a.max_skipnan(); let argmax = a.argmax_skipnan(); if max.is_none() { - argmax == None + argmax == Err(MinMaxError::EmptyInput) } else { a[argmax.unwrap()] == *max } @@ -177,49 +183,53 @@ fn test_max_skipnan_all_nan() { #[test] fn test_quantile_axis_mut_with_odd_axis_length() { let mut a = arr2(&[[1, 3, 2, 10], [2, 4, 3, 11], [3, 5, 6, 12]]); - let p = a.quantile_axis_mut::(Axis(0), 0.5); + let p = a.quantile_axis_mut(Axis(0), n64(0.5), &Lower).unwrap(); assert!(p == a.index_axis(Axis(0), 1)); } #[test] -#[should_panic] fn test_quantile_axis_mut_with_zero_axis_length() { let mut a = Array2::::zeros((5, 0)); - a.quantile_axis_mut::(Axis(1), 0.5); + assert_eq!( + a.quantile_axis_mut(Axis(1), n64(0.5), &Lower), + Err(QuantileError::EmptyInput) + ); } #[test] fn test_quantile_axis_mut_with_empty_array() { let mut a = Array2::::zeros((5, 0)); - let p = a.quantile_axis_mut::(Axis(0), 0.5); + let p = a.quantile_axis_mut(Axis(0), n64(0.5), &Lower).unwrap(); assert_eq!(p.shape(), &[0]); } #[test] fn test_quantile_axis_mut_with_even_axis_length() { let mut b = arr2(&[[1, 3, 2, 10], [2, 4, 3, 11], [3, 5, 6, 12], [4, 6, 7, 13]]); - let q = b.quantile_axis_mut::(Axis(0), 0.5); + let q = b.quantile_axis_mut(Axis(0), n64(0.5), &Lower).unwrap(); assert!(q == b.index_axis(Axis(0), 1)); } #[test] fn test_quantile_axis_mut_to_get_minimum() { let mut b = arr2(&[[1, 3, 22, 10]]); - let q = b.quantile_axis_mut::(Axis(1), 0.); + let q = b.quantile_axis_mut(Axis(1), n64(0.), &Lower).unwrap(); assert!(q == arr1(&[1])); } #[test] fn test_quantile_axis_mut_to_get_maximum() { let mut b = arr1(&[1, 3, 22, 10]); - let q = b.quantile_axis_mut::(Axis(0), 1.); + let q = b.quantile_axis_mut(Axis(0), n64(1.), &Lower).unwrap(); assert!(q == arr0(22)); } #[test] fn test_quantile_axis_skipnan_mut_higher_opt_i32() { let mut a = arr2(&[[Some(4), Some(2), None, Some(1), Some(5)], [None; 5]]); - let q = a.quantile_axis_skipnan_mut::(Axis(1), 0.6); + let q = a + 
.quantile_axis_skipnan_mut(Axis(1), n64(0.6), &Higher) + .unwrap(); assert_eq!(q.shape(), &[2]); assert_eq!(q[0], Some(4)); assert!(q[1].is_none()); @@ -228,7 +238,9 @@ fn test_quantile_axis_skipnan_mut_higher_opt_i32() { #[test] fn test_quantile_axis_skipnan_mut_nearest_opt_i32() { let mut a = arr2(&[[Some(4), Some(2), None, Some(1), Some(5)], [None; 5]]); - let q = a.quantile_axis_skipnan_mut::(Axis(1), 0.6); + let q = a + .quantile_axis_skipnan_mut(Axis(1), n64(0.6), &Nearest) + .unwrap(); assert_eq!(q.shape(), &[2]); assert_eq!(q[0], Some(4)); assert!(q[1].is_none()); @@ -237,7 +249,9 @@ fn test_quantile_axis_skipnan_mut_nearest_opt_i32() { #[test] fn test_quantile_axis_skipnan_mut_midpoint_opt_i32() { let mut a = arr2(&[[Some(4), Some(2), None, Some(1), Some(5)], [None; 5]]); - let q = a.quantile_axis_skipnan_mut::(Axis(1), 0.6); + let q = a + .quantile_axis_skipnan_mut(Axis(1), n64(0.6), &Midpoint) + .unwrap(); assert_eq!(q.shape(), &[2]); assert_eq!(q[0], Some(3)); assert!(q[1].is_none()); @@ -246,7 +260,9 @@ fn test_quantile_axis_skipnan_mut_midpoint_opt_i32() { #[test] fn test_quantile_axis_skipnan_mut_linear_f64() { let mut a = arr2(&[[1., 2., ::std::f64::NAN, 3.], [::std::f64::NAN; 4]]); - let q = a.quantile_axis_skipnan_mut::(Axis(1), 0.75); + let q = a + .quantile_axis_skipnan_mut(Axis(1), n64(0.75), &Linear) + .unwrap(); assert_eq!(q.shape(), &[2]); assert!((q[0] - 2.5).abs() < 1e-12); assert!(q[1].is_nan()); @@ -255,7 +271,9 @@ fn test_quantile_axis_skipnan_mut_linear_f64() { #[test] fn test_quantile_axis_skipnan_mut_linear_opt_i32() { let mut a = arr2(&[[Some(2), Some(4), None, Some(1)], [None; 4]]); - let q = a.quantile_axis_skipnan_mut::(Axis(1), 0.75); + let q = a + .quantile_axis_skipnan_mut(Axis(1), n64(0.75), &Linear) + .unwrap(); assert_eq!(q.shape(), &[2]); assert_eq!(q[0], Some(3)); assert!(q[1].is_none()); @@ -266,7 +284,148 @@ fn test_midpoint_overflow() { // Regression test // This triggered an overflow panic with a naive Midpoint 
implementation: (a+b)/2 let mut a: Array1 = array![129, 130, 130, 131]; - let median = a.quantile_mut::(0.5).unwrap(); + let median = a.quantile_mut(n64(0.5), &Midpoint).unwrap(); let expected_median = 130; assert_eq!(median, expected_median); } + +#[quickcheck] +fn test_quantiles_mut(xs: Vec) -> bool { + let v = Array::from_vec(xs.clone()); + + // Unordered list of quantile indexes to look up, with a duplicate + let quantile_indexes = Array::from(vec![ + n64(0.75), + n64(0.90), + n64(0.95), + n64(0.99), + n64(1.), + n64(0.), + n64(0.25), + n64(0.5), + n64(0.5), + ]); + let mut correct = true; + correct &= check_one_interpolation_method_for_quantiles_mut( + v.clone(), + quantile_indexes.view(), + &Linear, + ); + correct &= check_one_interpolation_method_for_quantiles_mut( + v.clone(), + quantile_indexes.view(), + &Higher, + ); + correct &= check_one_interpolation_method_for_quantiles_mut( + v.clone(), + quantile_indexes.view(), + &Lower, + ); + correct &= check_one_interpolation_method_for_quantiles_mut( + v.clone(), + quantile_indexes.view(), + &Midpoint, + ); + correct &= check_one_interpolation_method_for_quantiles_mut( + v.clone(), + quantile_indexes.view(), + &Nearest, + ); + correct +} + +fn check_one_interpolation_method_for_quantiles_mut( + mut v: Array1, + quantile_indexes: ArrayView1, + interpolate: &impl Interpolate, +) -> bool { + let bulk_quantiles = v.clone().quantiles_mut(&quantile_indexes, interpolate); + + if v.len() == 0 { + bulk_quantiles.is_err() + } else { + let bulk_quantiles = bulk_quantiles.unwrap(); + izip!(quantile_indexes, &bulk_quantiles).all(|(&quantile_index, &quantile)| { + quantile == v.quantile_mut(quantile_index, interpolate).unwrap() + }) + } +} + +#[quickcheck] +fn test_quantiles_axis_mut(mut xs: Vec) -> bool { + // We want a square matrix + let axis_length = (xs.len() as f64).sqrt().floor() as usize; + xs.truncate(axis_length * axis_length); + let m = Array::from_shape_vec((axis_length, axis_length), xs).unwrap(); + + // 
Unordered list of quantile indexes to look up, with a duplicate + let quantile_indexes = Array::from(vec![ + n64(0.75), + n64(0.90), + n64(0.95), + n64(0.99), + n64(1.), + n64(0.), + n64(0.25), + n64(0.5), + n64(0.5), + ]); + + // Test out all interpolation methods + let mut correct = true; + correct &= check_one_interpolation_method_for_quantiles_axis_mut( + m.clone(), + quantile_indexes.view(), + Axis(0), + &Linear, + ); + correct &= check_one_interpolation_method_for_quantiles_axis_mut( + m.clone(), + quantile_indexes.view(), + Axis(0), + &Higher, + ); + correct &= check_one_interpolation_method_for_quantiles_axis_mut( + m.clone(), + quantile_indexes.view(), + Axis(0), + &Lower, + ); + correct &= check_one_interpolation_method_for_quantiles_axis_mut( + m.clone(), + quantile_indexes.view(), + Axis(0), + &Midpoint, + ); + correct &= check_one_interpolation_method_for_quantiles_axis_mut( + m.clone(), + quantile_indexes.view(), + Axis(0), + &Nearest, + ); + correct +} + +fn check_one_interpolation_method_for_quantiles_axis_mut( + mut v: Array2, + quantile_indexes: ArrayView1, + axis: Axis, + interpolate: &impl Interpolate, +) -> bool { + let bulk_quantiles = v + .clone() + .quantiles_axis_mut(axis, &quantile_indexes, interpolate); + + if v.len() == 0 { + bulk_quantiles.is_err() + } else { + let bulk_quantiles = bulk_quantiles.unwrap(); + izip!(quantile_indexes, bulk_quantiles.axis_iter(axis)).all( + |(&quantile_index, quantile)| { + quantile + == v.quantile_axis_mut(axis, quantile_index, interpolate) + .unwrap() + }, + ) + } +} diff --git a/tests/sort.rs b/tests/sort.rs index 3c2cab58..2d1df06c 100644 --- a/tests/sort.rs +++ b/tests/sort.rs @@ -1,8 +1,11 @@ extern crate ndarray; extern crate ndarray_stats; +extern crate quickcheck; +extern crate quickcheck_macros; use ndarray::prelude::*; use ndarray_stats::Sort1dExt; +use quickcheck_macros::quickcheck; #[test] fn test_partition_mut() { @@ -27,7 +30,7 @@ fn test_partition_mut() { for i in 0..partition_index { 
assert!(a[i] < pivot_value); } - assert!(a[partition_index] == pivot_value); + assert_eq!(a[partition_index], pivot_value); for j in (partition_index + 1)..n { assert!(pivot_value <= a[j]); } @@ -37,10 +40,52 @@ fn test_partition_mut() { #[test] fn test_sorted_get_mut() { let a = arr1(&[1, 3, 2, 10]); - let j = a.clone().view_mut().sorted_get_mut(2); + let j = a.clone().view_mut().get_from_sorted_mut(2); assert_eq!(j, 3); - let j = a.clone().view_mut().sorted_get_mut(1); + let j = a.clone().view_mut().get_from_sorted_mut(1); assert_eq!(j, 2); - let j = a.clone().view_mut().sorted_get_mut(3); + let j = a.clone().view_mut().get_from_sorted_mut(3); assert_eq!(j, 10); } + +#[quickcheck] +fn test_sorted_get_many_mut(mut xs: Vec) -> bool { + let n = xs.len(); + if n == 0 { + true + } else { + let mut v = Array::from_vec(xs.clone()); + + // Insert each index twice, to get a set of indexes with duplicates, not sorted + let mut indexes: Vec = (0..n).into_iter().collect(); + indexes.append(&mut (0..n).collect()); + + let mut sorted_v = Vec::with_capacity(n); + for (i, (key, value)) in v + .get_many_from_sorted_mut(&Array::from(indexes)) + .into_iter() + .enumerate() + { + if i != key { + return false; + } + sorted_v.push(value); + } + xs.sort(); + println!("Sorted: {:?}. Truth: {:?}", sorted_v, xs); + xs == sorted_v + } +} + +#[quickcheck] +fn test_sorted_get_mut_as_sorting_algorithm(mut xs: Vec) -> bool { + let n = xs.len(); + if n == 0 { + true + } else { + let mut v = Array::from_vec(xs.clone()); + let sorted_v: Vec<_> = (0..n).map(|i| v.get_from_sorted_mut(i)).collect(); + xs.sort(); + xs == sorted_v + } +}