From 24a3299007ae569c9c2aa6c4038ccd7e2302ee8d Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 9 Dec 2018 15:42:53 -0500 Subject: [PATCH 01/28] Rename SliceOrIndex to AxisSliceInfo --- blas-tests/tests/oper.rs | 8 +- src/dimension/dimension_trait.rs | 22 +++--- src/dimension/mod.rs | 14 ++-- src/impl_methods.rs | 14 ++-- src/lib.rs | 2 +- src/slice.rs | 130 +++++++++++++++---------------- tests/array.rs | 26 +++---- tests/oper.rs | 8 +- 8 files changed, 112 insertions(+), 112 deletions(-) diff --git a/blas-tests/tests/oper.rs b/blas-tests/tests/oper.rs index 51ac7824c..2475c4e2d 100644 --- a/blas-tests/tests/oper.rs +++ b/blas-tests/tests/oper.rs @@ -6,8 +6,8 @@ extern crate num_traits; use ndarray::linalg::general_mat_mul; use ndarray::linalg::general_mat_vec_mul; use ndarray::prelude::*; +use ndarray::{AxisSliceInfo, Ix, Ixs, SliceInfo}; use ndarray::{Data, LinalgScalar}; -use ndarray::{Ix, Ixs, SliceInfo, SliceOrIndex}; use approx::{assert_abs_diff_eq, assert_relative_eq}; use defmac::defmac; @@ -420,11 +420,11 @@ fn scaled_add_3() { let mut answer = a.clone(); let cdim = if n == 1 { vec![q] } else { vec![n, q] }; let cslice = if n == 1 { - vec![SliceOrIndex::from(..).step_by(s2)] + vec![AxisSliceInfo::from(..).step_by(s2)] } else { vec![ - SliceOrIndex::from(..).step_by(s1), - SliceOrIndex::from(..).step_by(s2), + AxisSliceInfo::from(..).step_by(s1), + AxisSliceInfo::from(..).step_by(s2), ] }; diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 6007f93ab..c152ae3da 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -19,7 +19,7 @@ use crate::{Axis, DimMax}; use crate::IntoDimension; use crate::RemoveAxis; use crate::{ArrayView1, ArrayViewMut1}; -use crate::{Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs, SliceOrIndex}; +use crate::{AxisSliceInfo, Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs}; /// Array shape and index trait. /// @@ -63,14 +63,14 @@ pub trait Dimension: /// size, which you pass by reference. For the dynamic dimension it is /// a slice. /// - /// - For `Ix1`: `[SliceOrIndex; 1]` - /// - For `Ix2`: `[SliceOrIndex; 2]` + /// - For `Ix1`: `[AxisSliceInfo; 1]` + /// - For `Ix2`: `[AxisSliceInfo; 2]` /// - and so on.. - /// - For `IxDyn`: `[SliceOrIndex]` + /// - For `IxDyn`: `[AxisSliceInfo]` /// /// The easiest way to create a `&SliceInfo` is using the /// [`s![]`](macro.s!.html) macro. - type SliceArg: ?Sized + AsRef<[SliceOrIndex]>; + type SliceArg: ?Sized + AsRef<[AxisSliceInfo]>; /// Pattern matching friendly form of the dimension value. /// /// - For `Ix1`: `usize`, @@ -399,7 +399,7 @@ macro_rules! impl_insert_axis_array( impl Dimension for Dim<[Ix; 0]> { const NDIM: Option = Some(0); - type SliceArg = [SliceOrIndex; 0]; + type SliceArg = [AxisSliceInfo; 0]; type Pattern = (); type Smaller = Self; type Larger = Ix1; @@ -443,7 +443,7 @@ impl Dimension for Dim<[Ix; 0]> { impl Dimension for Dim<[Ix; 1]> { const NDIM: Option = Some(1); - type SliceArg = [SliceOrIndex; 1]; + type SliceArg = [AxisSliceInfo; 1]; type Pattern = Ix; type Smaller = Ix0; type Larger = Ix2; @@ -559,7 +559,7 @@ impl Dimension for Dim<[Ix; 1]> { impl Dimension for Dim<[Ix; 2]> { const NDIM: Option = Some(2); - type SliceArg = [SliceOrIndex; 2]; + type SliceArg = [AxisSliceInfo; 2]; type Pattern = (Ix, Ix); type Smaller = Ix1; type Larger = Ix3; @@ -716,7 +716,7 @@ impl Dimension for Dim<[Ix; 2]> { impl Dimension for Dim<[Ix; 3]> { const NDIM: Option = Some(3); - type SliceArg = [SliceOrIndex; 3]; + type SliceArg = [AxisSliceInfo; 3]; type Pattern = (Ix, Ix, Ix); type Smaller = Ix2; type Larger = Ix4; @@ -839,7 +839,7 @@ macro_rules! large_dim { ($n:expr, $name:ident, $pattern:ty, $larger:ty, { $($insert_axis:tt)* }) => ( impl Dimension for Dim<[Ix; $n]> { const NDIM: Option = Some($n); - type SliceArg = [SliceOrIndex; $n]; + type SliceArg = [AxisSliceInfo; $n]; type Pattern = $pattern; type Smaller = Dim<[Ix; $n - 1]>; type Larger = $larger; @@ -890,7 +890,7 @@ large_dim!(6, Ix6, (Ix, Ix, Ix, Ix, Ix, Ix), IxDyn, { /// and memory wasteful, but it allows an arbitrary and dynamic number of axes. impl Dimension for IxDyn { const NDIM: Option = None; - type SliceArg = [SliceOrIndex]; + type SliceArg = [AxisSliceInfo]; type Pattern = Self; type Smaller = Self; type Larger = Self; diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 2505681b5..9bc603c53 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -7,7 +7,7 @@ // except according to those terms. use crate::error::{from_kind, ErrorKind, ShapeError}; -use crate::{Ix, Ixs, Slice, SliceOrIndex}; +use crate::{AxisSliceInfo, Ix, Ixs, Slice}; use num_integer::div_floor; pub use self::axes::{axes_of, Axes, AxisDescription}; @@ -601,15 +601,15 @@ pub fn slices_intersect( ) -> bool { debug_assert_eq!(indices1.as_ref().len(), indices2.as_ref().len()); for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) { - // The slices do not intersect iff any pair of `SliceOrIndex` does not intersect. + // The slices do not intersect iff any pair of `AxisSliceInfo` does not intersect. match (si1, si2) { ( - SliceOrIndex::Slice { + AxisSliceInfo::Slice { start: start1, end: end1, step: step1, }, - SliceOrIndex::Slice { + AxisSliceInfo::Slice { start: start2, end: end2, step: step2, @@ -630,8 +630,8 @@ pub fn slices_intersect( return false; } } - (SliceOrIndex::Slice { start, end, step }, SliceOrIndex::Index(ind)) - | (SliceOrIndex::Index(ind), SliceOrIndex::Slice { start, end, step }) => { + (AxisSliceInfo::Slice { start, end, step }, AxisSliceInfo::Index(ind)) + | (AxisSliceInfo::Index(ind), AxisSliceInfo::Slice { start, end, step }) => { let ind = abs_index(axis_len, ind); let (min, max) = match slice_min_max(axis_len, Slice::new(start, end, step)) { Some(m) => m, @@ -641,7 +641,7 @@ pub fn slices_intersect( return false; } } - (SliceOrIndex::Index(ind1), SliceOrIndex::Index(ind2)) => { + (AxisSliceInfo::Index(ind1), AxisSliceInfo::Index(ind2)) => { let ind1 = abs_index(axis_len, ind1); let ind2 = abs_index(axis_len, ind2); if ind1 != ind2 { diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 958fc3f1c..02b182cb3 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -34,7 +34,7 @@ use crate::iter::{ }; use crate::slice::MultiSlice; use crate::stacking::concatenate; -use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex}; +use crate::{AxisSliceInfo, NdIndex, Slice, SliceInfo}; /// # Methods For All Array Types impl ArrayBase @@ -417,7 +417,7 @@ where // Slice and collapse in-place without changing the number of dimensions. self.slice_collapse(&*info); - let indices: &[SliceOrIndex] = (**info).as_ref(); + let indices: &[AxisSliceInfo] = (**info).as_ref(); // Copy the dim and strides that remain after removing the subview axes. let out_ndim = info.out_ndim(); @@ -425,8 +425,8 @@ where let mut new_strides = Do::zeros(out_ndim); izip!(self.dim.slice(), self.strides.slice(), indices) .filter_map(|(d, s, slice_or_index)| match slice_or_index { - SliceOrIndex::Slice { .. } => Some((d, s)), - SliceOrIndex::Index(_) => None, + AxisSliceInfo::Slice { .. } => Some((d, s)), + AxisSliceInfo::Index(_) => None, }) .zip(izip!(new_dim.slice_mut(), new_strides.slice_mut())) .for_each(|((d, s), (new_d, new_s))| { @@ -455,16 +455,16 @@ where /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `indices` does not match the number of array axes.) pub fn slice_collapse(&mut self, indices: &D::SliceArg) { - let indices: &[SliceOrIndex] = indices.as_ref(); + let indices: &[AxisSliceInfo] = indices.as_ref(); assert_eq!(indices.len(), self.ndim()); indices .iter() .enumerate() .for_each(|(axis, &slice_or_index)| match slice_or_index { - SliceOrIndex::Slice { start, end, step } => { + AxisSliceInfo::Slice { start, end, step } => { self.slice_axis_inplace(Axis(axis), Slice { start, end, step }) } - SliceOrIndex::Index(index) => { + AxisSliceInfo::Index(index) => { let i_usize = abs_index(self.len_of(Axis(axis)), index); self.collapse_axis(Axis(axis), i_usize) } diff --git a/src/lib.rs b/src/lib.rs index 4d35a1064..d1da388b6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -141,7 +141,7 @@ pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; -pub use crate::slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex}; +pub use crate::slice::{AxisSliceInfo, Slice, SliceInfo, SliceNextDim}; use crate::iterators::Baseiter; use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes}; diff --git a/src/slice.rs b/src/slice.rs index 86a2b0b8f..dbd8aea23 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -70,29 +70,29 @@ impl Slice { /// A slice (range with step) or an index. /// /// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a -/// `&SliceInfo<[SliceOrIndex; n], D>`. +/// `&SliceInfo<[AxisSliceInfo; n], D>`. /// /// ## Examples /// -/// `SliceOrIndex::Index(a)` is the index `a`. It can also be created with -/// `SliceOrIndex::from(a)`. The Python equivalent is `[a]`. The macro +/// `AxisSliceInfo::Index(a)` is the index `a`. It can also be created with +/// `AxisSliceInfo::from(a)`. The Python equivalent is `[a]`. The macro /// equivalent is `s![a]`. /// -/// `SliceOrIndex::Slice { start: 0, end: None, step: 1 }` is the full range of -/// an axis. It can also be created with `SliceOrIndex::from(..)`. The Python -/// equivalent is `[:]`. The macro equivalent is `s![..]`. +/// `AxisSliceInfo::Slice { start: 0, end: None, step: 1 }` is the full range +/// of an axis. It can also be created with `AxisSliceInfo::from(..)`. The +/// Python equivalent is `[:]`. The macro equivalent is `s![..]`. /// -/// `SliceOrIndex::Slice { start: a, end: Some(b), step: 2 }` is every second +/// `AxisSliceInfo::Slice { start: a, end: Some(b), step: 2 }` is every second /// element from `a` until `b`. It can also be created with -/// `SliceOrIndex::from(a..b).step_by(2)`. The Python equivalent is `[a:b:2]`. +/// `AxisSliceInfo::from(a..b).step_by(2)`. The Python equivalent is `[a:b:2]`. /// The macro equivalent is `s![a..b;2]`. /// -/// `SliceOrIndex::Slice { start: a, end: None, step: -1 }` is every element, +/// `AxisSliceInfo::Slice { start: a, end: None, step: -1 }` is every element, /// from `a` until the end, in reverse order. It can also be created with -/// `SliceOrIndex::from(a..).step_by(-1)`. The Python equivalent is `[a::-1]`. +/// `AxisSliceInfo::from(a..).step_by(-1)`. The Python equivalent is `[a::-1]`. /// The macro equivalent is `s![a..;-1]`. #[derive(Debug, PartialEq, Eq, Hash)] -pub enum SliceOrIndex { +pub enum AxisSliceInfo { /// A range with step size. `end` is an exclusive index. Negative `begin` /// or `end` indexes are counted from the back of the axis. If `end` is /// `None`, the slice extends to the end of the axis. @@ -105,47 +105,47 @@ pub enum SliceOrIndex { Index(isize), } -copy_and_clone! {SliceOrIndex} +copy_and_clone! {AxisSliceInfo} -impl SliceOrIndex { +impl AxisSliceInfo { /// Returns `true` if `self` is a `Slice` value. pub fn is_slice(&self) -> bool { - matches!(self, SliceOrIndex::Slice { .. }) + matches!(self, AxisSliceInfo::Slice { .. }) } /// Returns `true` if `self` is an `Index` value. pub fn is_index(&self) -> bool { - matches!(self, SliceOrIndex::Index(_)) + matches!(self, AxisSliceInfo::Index(_)) } - /// Returns a new `SliceOrIndex` with the given step size (multiplied with + /// Returns a new `AxisSliceInfo` with the given step size (multiplied with /// the previous step size). /// /// `step` must be nonzero. /// (This method checks with a debug assertion that `step` is not zero.) #[inline] pub fn step_by(self, step: isize) -> Self { - debug_assert_ne!(step, 0, "SliceOrIndex::step_by: step must be nonzero"); + debug_assert_ne!(step, 0, "AxisSliceInfo::step_by: step must be nonzero"); match self { - SliceOrIndex::Slice { + AxisSliceInfo::Slice { start, end, step: orig_step, - } => SliceOrIndex::Slice { + } => AxisSliceInfo::Slice { start, end, step: orig_step * step, }, - SliceOrIndex::Index(s) => SliceOrIndex::Index(s), + AxisSliceInfo::Index(s) => AxisSliceInfo::Index(s), } } } -impl fmt::Display for SliceOrIndex { +impl fmt::Display for AxisSliceInfo { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - SliceOrIndex::Index(index) => write!(f, "{}", index)?, - SliceOrIndex::Slice { start, end, step } => { + AxisSliceInfo::Index(index) => write!(f, "{}", index)?, + AxisSliceInfo::Slice { start, end, step } => { if start != 0 { write!(f, "{}", start)?; } @@ -225,9 +225,9 @@ macro_rules! impl_slice_variant_from_range { impl_slice_variant_from_range!(Slice, Slice, isize); impl_slice_variant_from_range!(Slice, Slice, usize); impl_slice_variant_from_range!(Slice, Slice, i32); -impl_slice_variant_from_range!(SliceOrIndex, SliceOrIndex::Slice, isize); -impl_slice_variant_from_range!(SliceOrIndex, SliceOrIndex::Slice, usize); -impl_slice_variant_from_range!(SliceOrIndex, SliceOrIndex::Slice, i32); +impl_slice_variant_from_range!(AxisSliceInfo, AxisSliceInfo::Slice, isize); +impl_slice_variant_from_range!(AxisSliceInfo, AxisSliceInfo::Slice, usize); +impl_slice_variant_from_range!(AxisSliceInfo, AxisSliceInfo::Slice, i32); impl From for Slice { #[inline] @@ -240,10 +240,10 @@ impl From for Slice { } } -impl From for SliceOrIndex { +impl From for AxisSliceInfo { #[inline] - fn from(_: RangeFull) -> SliceOrIndex { - SliceOrIndex::Slice { + fn from(_: RangeFull) -> AxisSliceInfo { + AxisSliceInfo::Slice { start: 0, end: None, step: 1, @@ -251,10 +251,10 @@ impl From for SliceOrIndex { } } -impl From for SliceOrIndex { +impl From for AxisSliceInfo { #[inline] - fn from(s: Slice) -> SliceOrIndex { - SliceOrIndex::Slice { + fn from(s: Slice) -> AxisSliceInfo { + AxisSliceInfo::Slice { start: s.start, end: s.end, step: s.step, @@ -262,24 +262,24 @@ impl From for SliceOrIndex { } } -macro_rules! impl_sliceorindex_from_index { +macro_rules! impl_axissliceinfo_from_index { ($index:ty) => { - impl From<$index> for SliceOrIndex { + impl From<$index> for AxisSliceInfo { #[inline] - fn from(r: $index) -> SliceOrIndex { - SliceOrIndex::Index(r as isize) + fn from(r: $index) -> AxisSliceInfo { + AxisSliceInfo::Index(r as isize) } } }; } -impl_sliceorindex_from_index!(isize); -impl_sliceorindex_from_index!(usize); -impl_sliceorindex_from_index!(i32); +impl_axissliceinfo_from_index!(isize); +impl_axissliceinfo_from_index!(usize); +impl_axissliceinfo_from_index!(i32); /// Represents all of the necessary information to perform a slice. /// -/// The type `T` is typically `[SliceOrIndex; n]`, `[SliceOrIndex]`, or -/// `Vec`. The type `D` is the output dimension after calling +/// The type `T` is typically `[AxisSliceInfo; n]`, `[AxisSliceInfo]`, or +/// `Vec`. The type `D` is the output dimension after calling /// [`.slice()`]. /// /// [`.slice()`]: struct.ArrayBase.html#method.slice @@ -316,7 +316,7 @@ where impl SliceInfo where - T: AsRef<[SliceOrIndex]>, + T: AsRef<[AxisSliceInfo]>, D: Dimension, { /// Returns a new `SliceInfo` instance. @@ -337,7 +337,7 @@ where impl SliceInfo where - T: AsRef<[SliceOrIndex]>, + T: AsRef<[AxisSliceInfo]>, D: Dimension, { /// Returns the number of dimensions after calling @@ -358,29 +358,29 @@ where } } -impl AsRef<[SliceOrIndex]> for SliceInfo +impl AsRef<[AxisSliceInfo]> for SliceInfo where - T: AsRef<[SliceOrIndex]>, + T: AsRef<[AxisSliceInfo]>, D: Dimension, { - fn as_ref(&self) -> &[SliceOrIndex] { + fn as_ref(&self) -> &[AxisSliceInfo] { self.indices.as_ref() } } -impl AsRef> for SliceInfo +impl AsRef> for SliceInfo where - T: AsRef<[SliceOrIndex]>, + T: AsRef<[AxisSliceInfo]>, D: Dimension, { - fn as_ref(&self) -> &SliceInfo<[SliceOrIndex], D> { + fn as_ref(&self) -> &SliceInfo<[AxisSliceInfo], D> { unsafe { // This is okay because the only non-zero-sized member of - // `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>` + // `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], D>` // should have the same bitwise representation as - // `&[SliceOrIndex]`. - &*(self.indices.as_ref() as *const [SliceOrIndex] - as *const SliceInfo<[SliceOrIndex], D>) + // `&[AxisSliceInfo]`. + &*(self.indices.as_ref() as *const [AxisSliceInfo] + as *const SliceInfo<[AxisSliceInfo], D>) } } } @@ -452,8 +452,8 @@ impl_slicenextdim_larger!((), Slice); /// counted from the end of the axis. Step sizes are also signed and may be /// negative, but must not be zero. /// -/// The syntax is `s![` *[ axis-slice-or-index [, axis-slice-or-index [ , ... ] -/// ] ]* `]`, where *axis-slice-or-index* is any of the following: +/// The syntax is `s![` *[ axis-slice-info [, axis-slice-info [ , ... ] ] ]* +/// `]`, where *axis-slice-info* is any of the following: /// /// * *index*: an index to use for taking a subview with respect to that axis. /// (The index is selected. The axis is removed except with @@ -466,12 +466,12 @@ impl_slicenextdim_larger!((), Slice); /// /// [`Slice`]: struct.Slice.html /// -/// The number of *axis-slice-or-index* must match the number of axes in the -/// array. *index*, *range*, *slice*, and *step* can be expressions. *index* -/// must be of type `isize`, `usize`, or `i32`. *range* must be of type -/// `Range`, `RangeTo`, `RangeFrom`, or `RangeFull` where `I` is -/// `isize`, `usize`, or `i32`. *step* must be a type that can be converted to -/// `isize` with the `as` keyword. +/// The number of *axis-slice-info* must match the number of axes in the array. +/// *index*, *range*, *slice*, and *step* can be expressions. *index* must be +/// of type `isize`, `usize`, or `i32`. *range* must be of type `Range`, +/// `RangeTo`, `RangeFrom`, or `RangeFull` where `I` is `isize`, `usize`, +/// or `i32`. *step* must be a type that can be converted to `isize` with the +/// `as` keyword. /// /// For example `s![0..4;2, 6, 1..5]` is a slice of the first axis for 0..4 /// with step size 2, a subview of the second axis at index 6, and a slice of @@ -606,13 +606,13 @@ macro_rules! s( }; // Catch-all clause for syntax errors (@parse $($t:tt)*) => { compile_error!("Invalid syntax in s![] call.") }; - // convert range/index into SliceOrIndex + // convert range/index into AxisSliceInfo (@convert $r:expr) => { - <$crate::SliceOrIndex as ::std::convert::From<_>>::from($r) + <$crate::AxisSliceInfo as ::std::convert::From<_>>::from($r) }; - // convert range/index and step into SliceOrIndex + // convert range/index and step into AxisSliceInfo (@convert $r:expr, $s:expr) => { - <$crate::SliceOrIndex as ::std::convert::From<_>>::from($r).step_by($s as isize) + <$crate::AxisSliceInfo as ::std::convert::From<_>>::from($r).step_by($s as isize) }; ($($t:tt)*) => { // The extra `*&` is a workaround for this compiler bug: diff --git a/tests/array.rs b/tests/array.rs index b0a28ca41..4b30c15b4 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -12,7 +12,7 @@ use itertools::{enumerate, zip, Itertools}; use ndarray::prelude::*; use ndarray::{arr3, rcarr2}; use ndarray::indices; -use ndarray::{Slice, SliceInfo, SliceOrIndex}; +use ndarray::{AxisSliceInfo, Slice, SliceInfo}; macro_rules! assert_panics { ($body:expr) => { @@ -217,9 +217,9 @@ fn test_slice_dyninput_array_fixed() { fn test_slice_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)); let info = &SliceInfo::<_, IxDyn>::new([ - SliceOrIndex::from(1..), - SliceOrIndex::from(1), - SliceOrIndex::from(..).step_by(2), + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(..).step_by(2), ]) .unwrap(); arr.slice(info); @@ -232,9 +232,9 @@ fn test_slice_array_dyn() { fn test_slice_dyninput_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); let info = &SliceInfo::<_, IxDyn>::new([ - SliceOrIndex::from(1..), - SliceOrIndex::from(1), - SliceOrIndex::from(..).step_by(2), + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(..).step_by(2), ]) .unwrap(); arr.slice(info); @@ -247,9 +247,9 @@ fn test_slice_dyninput_array_dyn() { fn test_slice_dyninput_vec_fixed() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); let info = &SliceInfo::<_, Ix2>::new(vec![ - SliceOrIndex::from(1..), - SliceOrIndex::from(1), - SliceOrIndex::from(..).step_by(2), + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(..).step_by(2), ]) .unwrap(); arr.slice(info.as_ref()); @@ -262,9 +262,9 @@ fn test_slice_dyninput_vec_fixed() { fn test_slice_dyninput_vec_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); let info = &SliceInfo::<_, IxDyn>::new(vec![ - SliceOrIndex::from(1..), - SliceOrIndex::from(1), - SliceOrIndex::from(..).step_by(2), + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(..).step_by(2), ]) .unwrap(); arr.slice(info.as_ref()); diff --git a/tests/oper.rs b/tests/oper.rs index 0d659fa1e..22dc6603e 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -561,7 +561,7 @@ fn scaled_add_2() { #[test] fn scaled_add_3() { use approx::assert_relative_eq; - use ndarray::{SliceInfo, SliceOrIndex}; + use ndarray::{SliceInfo, AxisSliceInfo}; let beta = -2.3; let sizes = vec![ @@ -583,11 +583,11 @@ fn scaled_add_3() { let mut answer = a.clone(); let cdim = if n == 1 { vec![q] } else { vec![n, q] }; let cslice = if n == 1 { - vec![SliceOrIndex::from(..).step_by(s2)] + vec![AxisSliceInfo::from(..).step_by(s2)] } else { vec![ - SliceOrIndex::from(..).step_by(s1), - SliceOrIndex::from(..).step_by(s2), + AxisSliceInfo::from(..).step_by(s1), + AxisSliceInfo::from(..).step_by(s2), ] }; From 6a16b881647d405ed9123ba9b04f59f876f97487 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 9 Dec 2018 16:39:00 -0500 Subject: [PATCH 02/28] Switch from Dimension::SliceArg to CanSlice trait --- blas-tests/tests/oper.rs | 2 +- src/dimension/dimension_trait.rs | 23 +-- src/dimension/mod.rs | 7 +- src/impl_methods.rs | 80 +++++---- src/lib.rs | 2 +- src/slice.rs | 286 ++++++++++++++++++++++--------- tests/array.rs | 28 +-- tests/oper.rs | 2 +- 8 files changed, 273 insertions(+), 157 deletions(-) diff --git a/blas-tests/tests/oper.rs b/blas-tests/tests/oper.rs index 2475c4e2d..25d26b7ba 100644 --- a/blas-tests/tests/oper.rs +++ b/blas-tests/tests/oper.rs @@ -432,7 +432,7 @@ fn scaled_add_3() { { let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(SliceInfo::<_, IxDyn>::new(cslice).unwrap().as_ref()); + let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap()); let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); answerv += &(beta * &c); diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index c152ae3da..df38904f4 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -19,7 +19,7 @@ use crate::{Axis, DimMax}; use crate::IntoDimension; use crate::RemoveAxis; use crate::{ArrayView1, ArrayViewMut1}; -use crate::{AxisSliceInfo, Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs}; +use crate::{Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs}; /// Array shape and index trait. /// @@ -56,21 +56,6 @@ pub trait Dimension: /// `Some(ndim)`, and for variable-size dimension representations (e.g. /// `IxDyn`), this should be `None`. const NDIM: Option; - /// `SliceArg` is the type which is used to specify slicing for this - /// dimension. - /// - /// For the fixed size dimensions it is a fixed size array of the correct - /// size, which you pass by reference. For the dynamic dimension it is - /// a slice. - /// - /// - For `Ix1`: `[AxisSliceInfo; 1]` - /// - For `Ix2`: `[AxisSliceInfo; 2]` - /// - and so on.. - /// - For `IxDyn`: `[AxisSliceInfo]` - /// - /// The easiest way to create a `&SliceInfo` is using the - /// [`s![]`](macro.s!.html) macro. - type SliceArg: ?Sized + AsRef<[AxisSliceInfo]>; /// Pattern matching friendly form of the dimension value. /// /// - For `Ix1`: `usize`, @@ -399,7 +384,6 @@ macro_rules! impl_insert_axis_array( impl Dimension for Dim<[Ix; 0]> { const NDIM: Option = Some(0); - type SliceArg = [AxisSliceInfo; 0]; type Pattern = (); type Smaller = Self; type Larger = Ix1; @@ -443,7 +427,6 @@ impl Dimension for Dim<[Ix; 0]> { impl Dimension for Dim<[Ix; 1]> { const NDIM: Option = Some(1); - type SliceArg = [AxisSliceInfo; 1]; type Pattern = Ix; type Smaller = Ix0; type Larger = Ix2; @@ -559,7 +542,6 @@ impl Dimension for Dim<[Ix; 1]> { impl Dimension for Dim<[Ix; 2]> { const NDIM: Option = Some(2); - type SliceArg = [AxisSliceInfo; 2]; type Pattern = (Ix, Ix); type Smaller = Ix1; type Larger = Ix3; @@ -716,7 +698,6 @@ impl Dimension for Dim<[Ix; 2]> { impl Dimension for Dim<[Ix; 3]> { const NDIM: Option = Some(3); - type SliceArg = [AxisSliceInfo; 3]; type Pattern = (Ix, Ix, Ix); type Smaller = Ix2; type Larger = Ix4; @@ -839,7 +820,6 @@ macro_rules! large_dim { ($n:expr, $name:ident, $pattern:ty, $larger:ty, { $($insert_axis:tt)* }) => ( impl Dimension for Dim<[Ix; $n]> { const NDIM: Option = Some($n); - type SliceArg = [AxisSliceInfo; $n]; type Pattern = $pattern; type Smaller = Dim<[Ix; $n - 1]>; type Larger = $larger; @@ -890,7 +870,6 @@ large_dim!(6, Ix6, (Ix, Ix, Ix, Ix, Ix, Ix), IxDyn, { /// and memory wasteful, but it allows an arbitrary and dynamic number of axes. impl Dimension for IxDyn { const NDIM: Option = None; - type SliceArg = [AxisSliceInfo]; type Pattern = Self; type Smaller = Self; type Larger = Self; diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 9bc603c53..a1e245884 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -7,6 +7,7 @@ // except according to those terms. use crate::error::{from_kind, ErrorKind, ShapeError}; +use crate::slice::CanSlice; use crate::{AxisSliceInfo, Ix, Ixs, Slice}; use num_integer::div_floor; @@ -596,10 +597,10 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> { /// Returns `true` iff the slices intersect. pub fn slices_intersect( dim: &D, - indices1: &D::SliceArg, - indices2: &D::SliceArg, + indices1: &impl CanSlice, + indices2: &impl CanSlice, ) -> bool { - debug_assert_eq!(indices1.as_ref().len(), indices2.as_ref().len()); + debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim()); for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) { // The slices do not intersect iff any pair of `AxisSliceInfo` does not intersect. match (si1, si2) { diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 02b182cb3..2cbfcacf2 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -32,9 +32,9 @@ use crate::iter::{ AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut, IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, }; -use crate::slice::MultiSlice; +use crate::slice::{CanSlice, MultiSlice}; use crate::stacking::concatenate; -use crate::{AxisSliceInfo, NdIndex, Slice, SliceInfo}; +use crate::{AxisSliceInfo, NdIndex, Slice}; /// # Methods For All Array Types impl ArrayBase @@ -341,9 +341,9 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice(&self, info: &SliceInfo) -> ArrayView<'_, A, Do> + pub fn slice(&self, info: &I) -> ArrayView<'_, A, I::OutDim> where - Do: Dimension, + I: CanSlice, S: Data, { self.view().slice_move(info) @@ -359,9 +359,9 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice_mut(&mut self, info: &SliceInfo) -> ArrayViewMut<'_, A, Do> + pub fn slice_mut(&mut self, info: &I) -> ArrayViewMut<'_, A, I::OutDim> where - Do: Dimension, + I: CanSlice, S: DataMut, { self.view_mut().slice_move(info) @@ -410,29 +410,37 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice_move(mut self, info: &SliceInfo) -> ArrayBase + pub fn slice_move(mut self, info: &I) -> ArrayBase where - Do: Dimension, + I: CanSlice, { // Slice and collapse in-place without changing the number of dimensions. - self.slice_collapse(&*info); + self.slice_collapse(info); - let indices: &[AxisSliceInfo] = (**info).as_ref(); - - // Copy the dim and strides that remain after removing the subview axes. let out_ndim = info.out_ndim(); - let mut new_dim = Do::zeros(out_ndim); - let mut new_strides = Do::zeros(out_ndim); - izip!(self.dim.slice(), self.strides.slice(), indices) - .filter_map(|(d, s, slice_or_index)| match slice_or_index { - AxisSliceInfo::Slice { .. } => Some((d, s)), - AxisSliceInfo::Index(_) => None, - }) - .zip(izip!(new_dim.slice_mut(), new_strides.slice_mut())) - .for_each(|((d, s), (new_d, new_s))| { - *new_d = *d; - *new_s = *s; + let mut new_dim = I::OutDim::zeros(out_ndim); + let mut new_strides = I::OutDim::zeros(out_ndim); + + // Write the dim and strides to the correct new axes. + { + let mut old_axis = 0; + let mut new_axis = 0; + info.as_ref().iter().for_each(|ax_info| match ax_info { + AxisSliceInfo::Slice { .. } => { + // Copy the old dim and stride to corresponding axis. + new_dim[new_axis] = self.dim[old_axis]; + new_strides[new_axis] = self.strides[old_axis]; + old_axis += 1; + new_axis += 1; + } + AxisSliceInfo::Index(_) => { + // Skip the old axis since it should be removed. + old_axis += 1; + } }); + debug_assert_eq!(old_axis, self.ndim()); + debug_assert_eq!(new_axis, out_ndim); + } // safe because new dimension, strides allow access to a subset of old data unsafe { @@ -442,25 +450,23 @@ where /// Slice the array in place without changing the number of dimensions. /// - /// Note that [`&SliceInfo`](struct.SliceInfo.html) (produced by the - /// [`s![]`](macro.s!.html) macro) will usually coerce into `&D::SliceArg` - /// automatically, but in some cases (e.g. if `D` is `IxDyn`), you may need - /// to call `.as_ref()`. - /// /// See [*Slicing*](#slicing) for full documentation. - /// See also [`D::SliceArg`]. - /// - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg /// /// **Panics** if an index is out of bounds or step size is zero.
- /// (**Panics** if `D` is `IxDyn` and `indices` does not match the number of array axes.) - pub fn slice_collapse(&mut self, indices: &D::SliceArg) { - let indices: &[AxisSliceInfo] = indices.as_ref(); - assert_eq!(indices.len(), self.ndim()); - indices + /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) + pub fn slice_collapse(&mut self, info: &I) + where + I: CanSlice, + { + assert_eq!( + info.in_ndim(), + self.ndim(), + "The input dimension of `info` must match the array to be sliced.", + ); + info.as_ref() .iter() .enumerate() - .for_each(|(axis, &slice_or_index)| match slice_or_index { + .for_each(|(axis, &ax_info)| match ax_info { AxisSliceInfo::Slice { start, end, step } => { self.slice_axis_inplace(Axis(axis), Slice { start, end, step }) } diff --git a/src/lib.rs b/src/lib.rs index d1da388b6..7c9439411 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -141,7 +141,7 @@ pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; -pub use crate::slice::{AxisSliceInfo, Slice, SliceInfo, SliceNextDim}; +pub use crate::slice::{AxisSliceInfo, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim}; use crate::iterators::Baseiter; use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes}; diff --git a/src/slice.rs b/src/slice.rs index dbd8aea23..dc4ad9d7a 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -7,10 +7,10 @@ // except according to those terms. use crate::dimension::slices_intersect; use crate::error::{ErrorKind, ShapeError}; -use crate::{ArrayViewMut, Dimension}; use std::fmt; use std::marker::PhantomData; use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; +use crate::{ArrayViewMut, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; /// A slice (range with step size). /// @@ -70,7 +70,7 @@ impl Slice { /// A slice (range with step) or an index. /// /// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a -/// `&SliceInfo<[AxisSliceInfo; n], D>`. +/// `&SliceInfo<[AxisSliceInfo; n], Di, Do>`. /// /// ## Examples /// @@ -276,23 +276,85 @@ impl_axissliceinfo_from_index!(isize); impl_axissliceinfo_from_index!(usize); impl_axissliceinfo_from_index!(i32); +/// A type that can slice an array of dimension `D`. +/// +/// This trait is unsafe to implement because the implementation must ensure +/// that `D`, `Self::OutDim`, `self.in_dim()`, and `self.out_ndim()` are +/// consistent with the `&[AxisSliceInfo]` returned by `self.as_ref()`. +pub unsafe trait CanSlice: AsRef<[AxisSliceInfo]> { + type OutDim: Dimension; + + fn in_ndim(&self) -> usize; + + fn out_ndim(&self) -> usize; +} + +macro_rules! impl_canslice_samedim { + ($in_dim:ty) => { + unsafe impl CanSlice<$in_dim> for SliceInfo + where + T: AsRef<[AxisSliceInfo]>, + Do: Dimension, + { + type OutDim = Do; + + fn in_ndim(&self) -> usize { + self.in_ndim() + } + + fn out_ndim(&self) -> usize { + self.out_ndim() + } + } + }; +} +impl_canslice_samedim!(Ix0); +impl_canslice_samedim!(Ix1); +impl_canslice_samedim!(Ix2); +impl_canslice_samedim!(Ix3); +impl_canslice_samedim!(Ix4); +impl_canslice_samedim!(Ix5); +impl_canslice_samedim!(Ix6); + +unsafe impl CanSlice for SliceInfo +where + T: AsRef<[AxisSliceInfo]>, + Di: Dimension, + Do: Dimension, +{ + type OutDim = Do; + + fn in_ndim(&self) -> usize { + self.in_ndim() + } + + fn out_ndim(&self) -> usize { + self.out_ndim() + } +} + /// Represents all of the necessary information to perform a slice. /// /// The type `T` is typically `[AxisSliceInfo; n]`, `[AxisSliceInfo]`, or -/// `Vec`. The type `D` is the output dimension after calling -/// [`.slice()`]. +/// `Vec`. The type `Di` is the dimension of the array to be +/// sliced, and `Do` is the output dimension after calling [`.slice()`]. Note +/// that if `Di` is a fixed dimension type (`Ix0`, `Ix1`, `Ix2`, etc.), the +/// `SliceInfo` instance can still be used to slice an array with dimension +/// `IxDyn` as long as the number of axes matches. /// /// [`.slice()`]: struct.ArrayBase.html#method.slice #[derive(Debug)] #[repr(C)] -pub struct SliceInfo { - out_dim: PhantomData, +pub struct SliceInfo { + in_dim: PhantomData, + out_dim: PhantomData, indices: T, } -impl Deref for SliceInfo +impl Deref for SliceInfo where - D: Dimension, + Di: Dimension, + Do: Dimension, { type Target = T; fn deref(&self) -> &Self::Target { @@ -300,55 +362,78 @@ where } } -impl SliceInfo +impl SliceInfo where - D: Dimension, + Di: Dimension, + Do: Dimension, { /// Returns a new `SliceInfo` instance. /// - /// If you call this method, you are guaranteeing that `out_dim` is - /// consistent with `indices`. + /// If you call this method, you are guaranteeing that `in_dim` and + /// `out_dim` are consistent with `indices`. #[doc(hidden)] - pub unsafe fn new_unchecked(indices: T, out_dim: PhantomData) -> SliceInfo { - SliceInfo { out_dim, indices } + pub unsafe fn new_unchecked( + indices: T, + in_dim: PhantomData, + out_dim: PhantomData, + ) -> SliceInfo { + SliceInfo { + in_dim: in_dim, + out_dim: out_dim, + indices: indices, + } } } -impl SliceInfo +impl SliceInfo where T: AsRef<[AxisSliceInfo]>, - D: Dimension, + Di: Dimension, + Do: Dimension, { /// Returns a new `SliceInfo` instance. /// - /// Errors if `D` is not consistent with `indices`. - pub fn new(indices: T) -> Result, ShapeError> { - if let Some(ndim) = D::NDIM { + /// Errors if `Di` or `Do` is not consistent with `indices`. + pub fn new(indices: T) -> Result, ShapeError> { + if let Some(ndim) = Di::NDIM { + if ndim != indices.as_ref().len() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + } + if let Some(ndim) = Do::NDIM { if ndim != indices.as_ref().iter().filter(|s| s.is_slice()).count() { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } Ok(SliceInfo { + in_dim: PhantomData, out_dim: PhantomData, indices, }) } } -impl SliceInfo +impl SliceInfo where T: AsRef<[AxisSliceInfo]>, - D: Dimension, + Di: Dimension, + Do: Dimension, { + /// Returns the number of dimensions of the input array for + /// [`.slice()`](struct.ArrayBase.html#method.slice). + pub fn in_ndim(&self) -> usize { + Di::NDIM.unwrap_or_else(|| self.indices.as_ref().len()) + } + /// Returns the number of dimensions after calling /// [`.slice()`](struct.ArrayBase.html#method.slice) (including taking /// subviews). /// - /// If `D` is a fixed-size dimension type, then this is equivalent to - /// `D::NDIM.unwrap()`. Otherwise, the value is calculated by iterating - /// over the ranges/indices. + /// If `Do` is a fixed-size dimension type, then this is equivalent to + /// `Do::NDIM.unwrap()`. Otherwise, the value is calculated by iterating + /// over the `AxisSliceInfo` elements. pub fn out_ndim(&self) -> usize { - D::NDIM.unwrap_or_else(|| { + Do::NDIM.unwrap_or_else(|| { self.indices .as_ref() .iter() @@ -358,47 +443,52 @@ where } } -impl AsRef<[AxisSliceInfo]> for SliceInfo +impl AsRef<[AxisSliceInfo]> for SliceInfo where T: AsRef<[AxisSliceInfo]>, - D: Dimension, + Di: Dimension, + Do: Dimension, { fn as_ref(&self) -> &[AxisSliceInfo] { self.indices.as_ref() } } -impl AsRef> for SliceInfo +impl AsRef> for SliceInfo where T: AsRef<[AxisSliceInfo]>, - D: Dimension, + Di: Dimension, + Do: Dimension, { - fn as_ref(&self) -> &SliceInfo<[AxisSliceInfo], D> { + fn as_ref(&self) -> &SliceInfo<[AxisSliceInfo], Di, Do> { unsafe { // This is okay because the only non-zero-sized member of - // `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], D>` + // `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Di, Do>` // should have the same bitwise representation as // `&[AxisSliceInfo]`. &*(self.indices.as_ref() as *const [AxisSliceInfo] - as *const SliceInfo<[AxisSliceInfo], D>) + as *const SliceInfo<[AxisSliceInfo], Di, Do>) } } } -impl Copy for SliceInfo +impl Copy for SliceInfo where T: Copy, - D: Dimension, + Di: Dimension, + Do: Dimension, { } -impl Clone for SliceInfo +impl Clone for SliceInfo where T: Clone, - D: Dimension, + Di: Dimension, + Do: Dimension, { fn clone(&self) -> Self { SliceInfo { + in_dim: PhantomData, out_dim: PhantomData, indices: self.indices.clone(), } @@ -406,39 +496,64 @@ where } #[doc(hidden)] -pub trait SliceNextDim { +pub trait SliceNextInDim { + fn next_dim(&self, _: PhantomData) -> PhantomData; +} + +macro_rules! impl_slicenextindim_larger { + (($($generics:tt)*), $self:ty) => { + impl SliceNextInDim for $self { + fn next_dim(&self, _: PhantomData) -> PhantomData { + PhantomData + } + } + } +} +impl_slicenextindim_larger!((), isize); +impl_slicenextindim_larger!((), usize); +impl_slicenextindim_larger!((), i32); +impl_slicenextindim_larger!((T), Range); +impl_slicenextindim_larger!((T), RangeInclusive); +impl_slicenextindim_larger!((T), RangeFrom); +impl_slicenextindim_larger!((T), RangeTo); +impl_slicenextindim_larger!((T), RangeToInclusive); +impl_slicenextindim_larger!((), RangeFull); +impl_slicenextindim_larger!((), Slice); + +#[doc(hidden)] +pub trait SliceNextOutDim { fn next_dim(&self, _: PhantomData) -> PhantomData; } -macro_rules! impl_slicenextdim_equal { +macro_rules! impl_slicenextoutdim_equal { ($self:ty) => { - impl SliceNextDim for $self { + impl SliceNextOutDim for $self { fn next_dim(&self, _: PhantomData) -> PhantomData { PhantomData } } }; } -impl_slicenextdim_equal!(isize); -impl_slicenextdim_equal!(usize); -impl_slicenextdim_equal!(i32); +impl_slicenextoutdim_equal!(isize); +impl_slicenextoutdim_equal!(usize); +impl_slicenextoutdim_equal!(i32); -macro_rules! impl_slicenextdim_larger { +macro_rules! impl_slicenextoutdim_larger { (($($generics:tt)*), $self:ty) => { - impl SliceNextDim for $self { + impl SliceNextOutDim for $self { fn next_dim(&self, _: PhantomData) -> PhantomData { PhantomData } } } } -impl_slicenextdim_larger!((T), Range); -impl_slicenextdim_larger!((T), RangeInclusive); -impl_slicenextdim_larger!((T), RangeFrom); -impl_slicenextdim_larger!((T), RangeTo); -impl_slicenextdim_larger!((T), RangeToInclusive); -impl_slicenextdim_larger!((), RangeFull); -impl_slicenextdim_larger!((), Slice); +impl_slicenextoutdim_larger!((T), Range); +impl_slicenextoutdim_larger!((T), RangeInclusive); +impl_slicenextoutdim_larger!((T), RangeFrom); +impl_slicenextoutdim_larger!((T), RangeTo); +impl_slicenextoutdim_larger!((T), RangeToInclusive); +impl_slicenextoutdim_larger!((), RangeFull); +impl_slicenextoutdim_larger!((), Slice); /// Slice argument constructor. /// @@ -534,14 +649,16 @@ impl_slicenextdim_larger!((), Slice); #[macro_export] macro_rules! s( // convert a..b;c into @convert(a..b, c), final item - (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr) => { + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr;$s:expr) => { match $r { r => { - let out_dim = $crate::SliceNextDim::next_dim(&r, $dim); + let in_dim = $crate::SliceNextInDim::next_dim(&r, $in_dim); + let out_dim = $crate::SliceNextOutDim::next_dim(&r, $out_dim); #[allow(unsafe_code)] unsafe { $crate::SliceInfo::new_unchecked( [$($stack)* $crate::s!(@convert r, $s)], + in_dim, out_dim, ) } @@ -549,14 +666,16 @@ macro_rules! s( } }; // convert a..b into @convert(a..b), final item - (@parse $dim:expr, [$($stack:tt)*] $r:expr) => { + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr) => { match $r { r => { - let out_dim = $crate::SliceNextDim::next_dim(&r, $dim); + let in_dim = $crate::SliceNextInDim::next_dim(&r, $in_dim); + let out_dim = $crate::SliceNextOutDim::next_dim(&r, $out_dim); #[allow(unsafe_code)] unsafe { $crate::SliceInfo::new_unchecked( [$($stack)* $crate::s!(@convert r)], + in_dim, out_dim, ) } @@ -564,19 +683,20 @@ macro_rules! s( } }; // convert a..b;c into @convert(a..b, c), final item, trailing comma - (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr ,) => { - $crate::s![@parse $dim, [$($stack)*] $r;$s] + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr;$s:expr ,) => { + $crate::s![@parse $in_dim, $out_dim, [$($stack)*] $r;$s] }; // convert a..b into @convert(a..b), final item, trailing comma - (@parse $dim:expr, [$($stack:tt)*] $r:expr ,) => { - $crate::s![@parse $dim, [$($stack)*] $r] + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr ,) => { + $crate::s![@parse $in_dim, $out_dim, [$($stack)*] $r] }; // convert a..b;c into @convert(a..b, c) - (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr, $($t:tt)*) => { + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr;$s:expr, $($t:tt)*) => { match $r { r => { $crate::s![@parse - $crate::SliceNextDim::next_dim(&r, $dim), + $crate::SliceNextInDim::next_dim(&r, $in_dim), + $crate::SliceNextOutDim::next_dim(&r, $out_dim), [$($stack)* $crate::s!(@convert r, $s),] $($t)* ] @@ -584,11 +704,12 @@ macro_rules! s( } }; // convert a..b into @convert(a..b) - (@parse $dim:expr, [$($stack:tt)*] $r:expr, $($t:tt)*) => { + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr, $($t:tt)*) => { match $r { r => { $crate::s![@parse - $crate::SliceNextDim::next_dim(&r, $dim), + $crate::SliceNextInDim::next_dim(&r, $in_dim), + $crate::SliceNextOutDim::next_dim(&r, $out_dim), [$($stack)* $crate::s!(@convert r),] $($t)* ] @@ -596,11 +717,15 @@ macro_rules! s( } }; // empty call, i.e. `s![]` - (@parse ::std::marker::PhantomData::<$crate::Ix0>, []) => { + (@parse ::std::marker::PhantomData::<$crate::Ix0>, ::std::marker::PhantomData::<$crate::Ix0>, []) => { { #[allow(unsafe_code)] unsafe { - $crate::SliceInfo::new_unchecked([], ::std::marker::PhantomData::<$crate::Ix0>) + $crate::SliceInfo::new_unchecked( + [], + ::std::marker::PhantomData::<$crate::Ix0>, + ::std::marker::PhantomData::<$crate::Ix0>, + ) } } }; @@ -617,7 +742,12 @@ macro_rules! s( ($($t:tt)*) => { // The extra `*&` is a workaround for this compiler bug: // https://github.com/rust-lang/rust/issues/23014 - &*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*] + &*&$crate::s![@parse + ::std::marker::PhantomData::<$crate::Ix0>, + ::std::marker::PhantomData::<$crate::Ix0>, + [] + $($t)* + ] }; ); @@ -650,13 +780,13 @@ where fn multi_slice_move(&self, _view: ArrayViewMut<'a, A, D>) -> Self::Output {} } -impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo,) +impl<'a, A, D, I0> MultiSlice<'a, A, D> for (&I0,) where A: 'a, D: Dimension, - Do0: Dimension, + I0: CanSlice, { - type Output = (ArrayViewMut<'a, A, Do0>,); + type Output = (ArrayViewMut<'a, A, I0::OutDim>,); fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { (view.slice_move(self.0),) @@ -668,17 +798,17 @@ macro_rules! impl_multislice_tuple { impl_multislice_tuple!(@def_impl ($($but_last,)* $last,), [$($but_last)*] $last); }; (@def_impl ($($all:ident,)*), [$($but_last:ident)*] $last:ident) => { - impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&SliceInfo,)*) + impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&$all,)*) where A: 'a, D: Dimension, - $($all: Dimension,)* + $($all: CanSlice,)* { - type Output = ($(ArrayViewMut<'a, A, $all>,)*); + type Output = ($(ArrayViewMut<'a, A, $all::OutDim>,)*); fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { #[allow(non_snake_case)] - let ($($all,)*) = self; + let &($($all,)*) = self; let shape = view.raw_dim(); assert!(!impl_multislice_tuple!(@intersects_self &shape, ($($all,)*))); @@ -702,11 +832,11 @@ macro_rules! impl_multislice_tuple { }; } -impl_multislice_tuple!([Do0] Do1); -impl_multislice_tuple!([Do0 Do1] Do2); -impl_multislice_tuple!([Do0 Do1 Do2] Do3); -impl_multislice_tuple!([Do0 Do1 Do2 Do3] Do4); -impl_multislice_tuple!([Do0 Do1 Do2 Do3 Do4] Do5); +impl_multislice_tuple!([I0] I1); +impl_multislice_tuple!([I0 I1] I2); +impl_multislice_tuple!([I0 I1 I2] I3); +impl_multislice_tuple!([I0 I1 I2 I3] I4); +impl_multislice_tuple!([I0 I1 I2 I3 I4] I5); impl<'a, A, D, T> MultiSlice<'a, A, D> for &T where diff --git a/tests/array.rs b/tests/array.rs index 4b30c15b4..3da6ba1f7 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -210,13 +210,13 @@ fn test_slice_dyninput_array_fixed() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info.as_ref()); + arr.view().slice_collapse(info); } #[test] fn test_slice_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)); - let info = &SliceInfo::<_, IxDyn>::new([ + let info = &SliceInfo::<_, Ix3, IxDyn>::new([ AxisSliceInfo::from(1..), AxisSliceInfo::from(1), AxisSliceInfo::from(..).step_by(2), @@ -231,7 +231,7 @@ fn test_slice_array_dyn() { #[test] fn test_slice_dyninput_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, IxDyn>::new([ + let info = &SliceInfo::<_, Ix3, IxDyn>::new([ AxisSliceInfo::from(1..), AxisSliceInfo::from(1), AxisSliceInfo::from(..).step_by(2), @@ -240,37 +240,37 @@ fn test_slice_dyninput_array_dyn() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info.as_ref()); + arr.view().slice_collapse(info); } #[test] fn test_slice_dyninput_vec_fixed() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, Ix2>::new(vec![ + let info = &SliceInfo::<_, Ix3, Ix2>::new(vec![ AxisSliceInfo::from(1..), AxisSliceInfo::from(1), AxisSliceInfo::from(..).step_by(2), ]) .unwrap(); - arr.slice(info.as_ref()); - arr.slice_mut(info.as_ref()); - arr.view().slice_move(info.as_ref()); - arr.view().slice_collapse(info.as_ref()); + arr.slice(info); + arr.slice_mut(info); + arr.view().slice_move(info); + arr.view().slice_collapse(info); } #[test] fn test_slice_dyninput_vec_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, IxDyn>::new(vec![ + let info = &SliceInfo::<_, Ix3, IxDyn>::new(vec![ AxisSliceInfo::from(1..), AxisSliceInfo::from(1), AxisSliceInfo::from(..).step_by(2), ]) .unwrap(); - arr.slice(info.as_ref()); - arr.slice_mut(info.as_ref()); - arr.view().slice_move(info.as_ref()); - arr.view().slice_collapse(info.as_ref()); + arr.slice(info); + arr.slice_mut(info); + arr.view().slice_move(info); + arr.view().slice_collapse(info); } #[test] diff --git a/tests/oper.rs b/tests/oper.rs index 22dc6603e..16f3edbc6 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -595,7 +595,7 @@ fn scaled_add_3() { { let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(SliceInfo::<_, IxDyn>::new(cslice).unwrap().as_ref()); + let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap()); let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); answerv += &(beta * &c); From 546b69cd6e3cae7509716b1aa692daf6eb0a9ce4 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 9 Dec 2018 16:53:27 -0500 Subject: [PATCH 03/28] Add support for inserting new axes while slicing --- src/dimension/mod.rs | 51 +++++++++++--- src/doc/ndarray_for_numpy_users/mod.rs | 2 +- src/impl_methods.rs | 22 ++++-- src/lib.rs | 23 ++++--- src/prelude.rs | 3 + src/slice.rs | 93 +++++++++++++++++++------- tests/array.rs | 48 +++++++------ 7 files changed, 173 insertions(+), 69 deletions(-) diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index a1e245884..28bdc7f91 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -601,7 +601,11 @@ pub fn slices_intersect( indices2: &impl CanSlice, ) -> bool { debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim()); - for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) { + for (&axis_len, &si1, &si2) in izip!( + dim.slice(), + indices1.as_ref().iter().filter(|si| !si.is_new_axis()), + indices2.as_ref().iter().filter(|si| !si.is_new_axis()), + ) { // The slices do not intersect iff any pair of `AxisSliceInfo` does not intersect. match (si1, si2) { ( @@ -649,6 +653,7 @@ pub fn slices_intersect( return false; } } + (AxisSliceInfo::NewAxis, _) | (_, AxisSliceInfo::NewAxis) => unreachable!(), } } true @@ -720,7 +725,7 @@ mod test { }; use crate::error::{from_kind, ErrorKind}; use crate::slice::Slice; - use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn}; + use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn, NewAxis}; use num_integer::gcd; use quickcheck::{quickcheck, TestResult}; @@ -994,17 +999,45 @@ mod test { #[test] fn slices_intersect_true() { - assert!(slices_intersect(&Dim([4, 5]), s![.., ..], s![.., ..])); - assert!(slices_intersect(&Dim([4, 5]), s![0, ..], s![0, ..])); - assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, ..])); - assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3])); + assert!(slices_intersect( + &Dim([4, 5]), + s![NewAxis, .., NewAxis, ..], + s![.., NewAxis, .., NewAxis] + )); + assert!(slices_intersect( + &Dim([4, 5]), + s![NewAxis, 0, ..], + s![0, ..] + )); + assert!(slices_intersect( + &Dim([4, 5]), + s![..;2, ..], + s![..;3, NewAxis, ..] + )); + assert!(slices_intersect( + &Dim([4, 5]), + s![.., ..;2], + s![.., 1..;3, NewAxis] + )); assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6])); } #[test] fn slices_intersect_false() { - assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;2, ..])); - assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;3, ..])); - assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6])); + assert!(!slices_intersect( + &Dim([4, 5]), + s![..;2, ..], + s![NewAxis, 1..;2, ..] + )); + assert!(!slices_intersect( + &Dim([4, 5]), + s![..;2, NewAxis, ..], + s![1..;3, ..] + )); + assert!(!slices_intersect( + &Dim([4, 5]), + s![.., ..;9], + s![.., 3..;6, NewAxis] + )); } } diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs index aea97da3f..478cc2cec 100644 --- a/src/doc/ndarray_for_numpy_users/mod.rs +++ b/src/doc/ndarray_for_numpy_users/mod.rs @@ -532,7 +532,7 @@ //! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a` //! `np.concatenate((a,b), axis=1)` | [`concatenate![Axis(1), a, b]`][concatenate!] or [`concatenate(Axis(1), &[a.view(), b.view()])`][concatenate()] | concatenate arrays `a` and `b` along axis 1 //! `np.stack((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), vec![a.view(), b.view()])`][stack()] | stack arrays `a` and `b` along axis 1 -//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1 +//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.slice(s![.., NewAxis])`][.slice()] or [`a.insert_axis(Axis(1))`][.insert_axis()] | create an view of 1-D array `a`, inserting a new axis 1 //! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`) //! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a` //! `a.flatten()` | [`use std::iter::FromIterator; Array::from_iter(a.iter().cloned())`][::from_iter()] | create a 1-D array by flattening `a` diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 2cbfcacf2..143824850 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -437,6 +437,12 @@ where // Skip the old axis since it should be removed. old_axis += 1; } + AxisSliceInfo::NewAxis => { + // Set the dim and stride of the new axis. + new_dim[new_axis] = 1; + new_strides[new_axis] = 0; + new_axis += 1; + } }); debug_assert_eq!(old_axis, self.ndim()); debug_assert_eq!(new_axis, out_ndim); @@ -450,6 +456,8 @@ where /// Slice the array in place without changing the number of dimensions. /// + /// Note that `NewAxis` elements in `info` are ignored. + /// /// See [*Slicing*](#slicing) for full documentation. /// /// **Panics** if an index is out of bounds or step size is zero.
@@ -463,18 +471,20 @@ where self.ndim(), "The input dimension of `info` must match the array to be sliced.", ); - info.as_ref() - .iter() - .enumerate() - .for_each(|(axis, &ax_info)| match ax_info { + let mut axis = 0; + info.as_ref().iter().for_each(|&ax_info| match ax_info { AxisSliceInfo::Slice { start, end, step } => { - self.slice_axis_inplace(Axis(axis), Slice { start, end, step }) + self.slice_axis_inplace(Axis(axis), Slice { start, end, step }); + axis += 1; } AxisSliceInfo::Index(index) => { let i_usize = abs_index(self.len_of(Axis(axis)), index); - self.collapse_axis(Axis(axis), i_usize) + self.collapse_axis(Axis(axis), i_usize); + axis += 1; } + AxisSliceInfo::NewAxis => {} }); + debug_assert_eq!(axis, self.ndim()); } /// Return a view of the array, sliced along the specified axis. diff --git a/src/lib.rs b/src/lib.rs index 7c9439411..ae750f3d8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -141,7 +141,7 @@ pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; -pub use crate::slice::{AxisSliceInfo, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim}; +pub use crate::slice::{AxisSliceInfo, NewAxis, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim}; use crate::iterators::Baseiter; use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes}; @@ -496,14 +496,16 @@ pub type Ixs = isize; /// /// If a range is used, the axis is preserved. If an index is used, that index /// is selected and the axis is removed; this selects a subview. See -/// [*Subviews*](#subviews) for more information about subviews. Note that -/// [`.slice_collapse()`] behaves like [`.collapse_axis()`] by preserving -/// the number of dimensions. +/// [*Subviews*](#subviews) for more information about subviews. If a +/// [`NewAxis`] instance is used, a new axis is inserted. Note that +/// [`.slice_collapse()`] ignores `NewAxis` elements and behaves like +/// [`.collapse_axis()`] by preserving the number of dimensions. /// /// [`.slice()`]: #method.slice /// [`.slice_mut()`]: #method.slice_mut /// [`.slice_move()`]: #method.slice_move /// [`.slice_collapse()`]: #method.slice_collapse +/// [`NewAxis`]: struct.NewAxis.html /// /// When slicing arrays with generic dimensionality, creating an instance of /// [`&SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`] @@ -526,7 +528,7 @@ pub type Ixs = isize; /// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move /// /// ``` -/// use ndarray::{arr2, arr3, s, ArrayBase, DataMut, Dimension, Slice}; +/// use ndarray::{arr2, arr3, s, ArrayBase, DataMut, Dimension, NewAxis, Slice}; /// /// // 2 submatrices of 2 rows with 3 elements per row, means a shape of `[2, 2, 3]`. /// @@ -561,16 +563,17 @@ pub type Ixs = isize; /// assert_eq!(d, e); /// assert_eq!(d.shape(), &[2, 1, 3]); /// -/// // Let’s create a slice while selecting a subview with +/// // Let’s create a slice while selecting a subview and inserting a new axis with /// // /// // - Both submatrices of the greatest dimension: `..` /// // - The last row in each submatrix, removing that axis: `-1` /// // - Row elements in reverse order: `..;-1` -/// let f = a.slice(s![.., -1, ..;-1]); -/// let g = arr2(&[[ 6, 5, 4], -/// [12, 11, 10]]); +/// // - A new axis at the end. +/// let f = a.slice(s![.., -1, ..;-1, NewAxis]); +/// let g = arr3(&[[ [6], [5], [4]], +/// [[12], [11], [10]]]); /// assert_eq!(f, g); -/// assert_eq!(f.shape(), &[2, 3]); +/// assert_eq!(f.shape(), &[2, 3, 1]); /// /// // Let's take two disjoint, mutable slices of a matrix with /// // diff --git a/src/prelude.rs b/src/prelude.rs index def236841..ea6dfb08f 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -49,6 +49,9 @@ pub use crate::{array, azip, s}; #[doc(no_inline)] pub use crate::ShapeBuilder; +#[doc(no_inline)] +pub use crate::NewAxis; + #[doc(no_inline)] pub use crate::AsArray; diff --git a/src/slice.rs b/src/slice.rs index dc4ad9d7a..5e2a9ac7f 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -67,7 +67,12 @@ impl Slice { } } -/// A slice (range with step) or an index. +/// Token to represent a new axis in a slice description. +/// +/// See also the [`s![]`](macro.s!.html) macro. +pub struct NewAxis; + +/// A slice (range with step), an index, or a new axis token. /// /// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a /// `&SliceInfo<[AxisSliceInfo; n], Di, Do>`. @@ -91,6 +96,10 @@ impl Slice { /// from `a` until the end, in reverse order. It can also be created with /// `AxisSliceInfo::from(a..).step_by(-1)`. The Python equivalent is `[a::-1]`. /// The macro equivalent is `s![a..;-1]`. +/// +/// `AxisSliceInfo::NewAxis` is a new axis of length 1. It can also be created +/// with `AxisSliceInfo::from(NewAxis)`. The Python equivalent is +/// `[np.newaxis]`. The macro equivalent is `s![NewAxis]`. #[derive(Debug, PartialEq, Eq, Hash)] pub enum AxisSliceInfo { /// A range with step size. `end` is an exclusive index. Negative `begin` @@ -103,6 +112,8 @@ pub enum AxisSliceInfo { }, /// A single index. Index(isize), + /// A new axis of length 1. + NewAxis, } copy_and_clone! {AxisSliceInfo} @@ -118,6 +129,11 @@ impl AxisSliceInfo { matches!(self, AxisSliceInfo::Index(_)) } + /// Returns `true` if `self` is a `NewAxis` value. + pub fn is_new_axis(&self) -> bool { + matches!(self, AxisSliceInfo::NewAxis) + } + /// Returns a new `AxisSliceInfo` with the given step size (multiplied with /// the previous step size). /// @@ -137,6 +153,7 @@ impl AxisSliceInfo { step: orig_step * step, }, AxisSliceInfo::Index(s) => AxisSliceInfo::Index(s), + AxisSliceInfo::NewAxis => AxisSliceInfo::NewAxis, } } } @@ -157,6 +174,7 @@ impl fmt::Display for AxisSliceInfo { write!(f, ";{}", step)?; } } + AxisSliceInfo::NewAxis => write!(f, "NewAxis")?, } Ok(()) } @@ -276,6 +294,13 @@ impl_axissliceinfo_from_index!(isize); impl_axissliceinfo_from_index!(usize); impl_axissliceinfo_from_index!(i32); +impl From for AxisSliceInfo { + #[inline] + fn from(_: NewAxis) -> AxisSliceInfo { + AxisSliceInfo::NewAxis + } +} + /// A type that can slice an array of dimension `D`. /// /// This trait is unsafe to implement because the implementation must ensure @@ -396,12 +421,12 @@ where /// Errors if `Di` or `Do` is not consistent with `indices`. pub fn new(indices: T) -> Result, ShapeError> { if let Some(ndim) = Di::NDIM { - if ndim != indices.as_ref().len() { + if ndim != indices.as_ref().iter().filter(|s| !s.is_new_axis()).count() { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } if let Some(ndim) = Do::NDIM { - if ndim != indices.as_ref().iter().filter(|s| s.is_slice()).count() { + if ndim != indices.as_ref().iter().filter(|s| !s.is_index()).count() { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } @@ -421,8 +446,18 @@ where { /// Returns the number of dimensions of the input array for /// [`.slice()`](struct.ArrayBase.html#method.slice). + /// + /// If `Di` is a fixed-size dimension type, then this is equivalent to + /// `Di::NDIM.unwrap()`. Otherwise, the value is calculated by iterating + /// over the `AxisSliceInfo` elements. pub fn in_ndim(&self) -> usize { - Di::NDIM.unwrap_or_else(|| self.indices.as_ref().len()) + Di::NDIM.unwrap_or_else(|| { + self.indices + .as_ref() + .iter() + .filter(|s| !s.is_new_axis()) + .count() + }) } /// Returns the number of dimensions after calling @@ -437,7 +472,7 @@ where self.indices .as_ref() .iter() - .filter(|s| s.is_slice()) + .filter(|s| !s.is_index()) .count() }) } @@ -500,6 +535,12 @@ pub trait SliceNextInDim { fn next_dim(&self, _: PhantomData) -> PhantomData; } +impl SliceNextInDim for NewAxis { + fn next_dim(&self, _: PhantomData) -> PhantomData { + PhantomData + } +} + macro_rules! impl_slicenextindim_larger { (($($generics:tt)*), $self:ty) => { impl SliceNextInDim for $self { @@ -554,12 +595,13 @@ impl_slicenextoutdim_larger!((T), RangeTo); impl_slicenextoutdim_larger!((T), RangeToInclusive); impl_slicenextoutdim_larger!((), RangeFull); impl_slicenextoutdim_larger!((), Slice); +impl_slicenextoutdim_larger!((), NewAxis); /// Slice argument constructor. /// -/// `s![]` takes a list of ranges/slices/indices, separated by comma, with -/// optional step sizes that are separated from the range by a semicolon. It is -/// converted into a [`&SliceInfo`] instance. +/// `s![]` takes a list of ranges/slices/indices/new-axes, separated by comma, +/// with optional step sizes that are separated from the range by a semicolon. +/// It is converted into a [`&SliceInfo`] instance. /// /// [`&SliceInfo`]: struct.SliceInfo.html /// @@ -578,22 +620,25 @@ impl_slicenextoutdim_larger!((), Slice); /// * *slice*: a [`Slice`] instance to use for slicing that axis. /// * *slice* `;` *step*: a range constructed from the start and end of a [`Slice`] /// instance, with new step size *step*, to use for slicing that axis. +/// * *new-axis*: a [`NewAxis`] instance that represents the creation of a new axis. /// /// [`Slice`]: struct.Slice.html -/// -/// The number of *axis-slice-info* must match the number of axes in the array. -/// *index*, *range*, *slice*, and *step* can be expressions. *index* must be -/// of type `isize`, `usize`, or `i32`. *range* must be of type `Range`, -/// `RangeTo`, `RangeFrom`, or `RangeFull` where `I` is `isize`, `usize`, -/// or `i32`. *step* must be a type that can be converted to `isize` with the -/// `as` keyword. -/// -/// For example `s![0..4;2, 6, 1..5]` is a slice of the first axis for 0..4 -/// with step size 2, a subview of the second axis at index 6, and a slice of -/// the third axis for 1..5 with default step size 1. The input array must have -/// 3 dimensions. The resulting slice would have shape `[2, 4]` for -/// [`.slice()`], [`.slice_mut()`], and [`.slice_move()`], and shape -/// `[2, 1, 4]` for [`.slice_collapse()`]. +/// [`NewAxis`]: struct.NewAxis.html +/// +/// The number of *axis-slice-info*, not including *new-axis*, must match the +/// number of axes in the array. *index*, *range*, *slice*, *step*, and +/// *new-axis* can be expressions. *index* must be of type `isize`, `usize`, or +/// `i32`. *range* must be of type `Range`, `RangeTo`, `RangeFrom`, or +/// `RangeFull` where `I` is `isize`, `usize`, or `i32`. *step* must be a type +/// that can be converted to `isize` with the `as` keyword. +/// +/// For example `s![0..4;2, 6, 1..5, NewAxis]` is a slice of the first axis for +/// 0..4 with step size 2, a subview of the second axis at index 6, a slice of +/// the third axis for 1..5 with default step size 1, and a new axis of length +/// 1 at the end of the shape. The input array must have 3 dimensions. The +/// resulting slice would have shape `[2, 4, 1]` for [`.slice()`], +/// [`.slice_mut()`], and [`.slice_move()`], and shape `[2, 1, 4]` for +/// [`.slice_collapse()`]. /// /// [`.slice()`]: struct.ArrayBase.html#method.slice /// [`.slice_mut()`]: struct.ArrayBase.html#method.slice_mut @@ -731,11 +776,11 @@ macro_rules! s( }; // Catch-all clause for syntax errors (@parse $($t:tt)*) => { compile_error!("Invalid syntax in s![] call.") }; - // convert range/index into AxisSliceInfo + // convert range/index/new-axis into AxisSliceInfo (@convert $r:expr) => { <$crate::AxisSliceInfo as ::std::convert::From<_>>::from($r) }; - // convert range/index and step into AxisSliceInfo + // convert range/index/new-axis and step into AxisSliceInfo (@convert $r:expr, $s:expr) => { <$crate::AxisSliceInfo as ::std::convert::From<_>>::from($r).step_by($s as isize) }; diff --git a/tests/array.rs b/tests/array.rs index 3da6ba1f7..51f59fcb1 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -85,8 +85,8 @@ fn test_slice() { *elt = i; } - let vi = A.slice(s![1.., ..;2, Slice::new(0, None, 2)]); - assert_eq!(vi.shape(), &[2, 2, 3]); + let vi = A.slice(s![1.., ..;2, NewAxis, Slice::new(0, None, 2)]); + assert_eq!(vi.shape(), &[2, 2, 1, 3]); let vi = A.slice(s![.., .., ..]); assert_eq!(vi.shape(), A.shape()); assert!(vi.iter().zip(A.iter()).all(|(a, b)| a == b)); @@ -138,8 +138,8 @@ fn test_slice_with_many_dim() { *elt = i; } - let vi = A.slice(s![..2, .., ..;2, ..1, ..1, 1.., ..]); - let new_shape = &[2, 1, 2, 1, 1, 1, 1][..]; + let vi = A.slice(s![..2, NewAxis, .., ..;2, NewAxis, ..1, ..1, 1.., ..]); + let new_shape = &[2, 1, 1, 2, 1, 1, 1, 1, 1][..]; assert_eq!(vi.shape(), new_shape); let correct = array![ [A[&[0, 0, 0, 0, 0, 1, 0][..]], A[&[0, 0, 2, 0, 0, 1, 0][..]]], @@ -196,7 +196,7 @@ fn test_slice_args_eval_step_once() { #[test] fn test_slice_array_fixed() { let mut arr = Array3::::zeros((5, 2, 5)); - let info = s![1.., 1, ..;2]; + let info = s![1.., 1, NewAxis, ..;2]; arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); @@ -206,7 +206,7 @@ fn test_slice_array_fixed() { #[test] fn test_slice_dyninput_array_fixed() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = s![1.., 1, ..;2]; + let info = s![1.., 1, NewAxis, ..;2]; arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); @@ -219,6 +219,7 @@ fn test_slice_array_dyn() { let info = &SliceInfo::<_, Ix3, IxDyn>::new([ AxisSliceInfo::from(1..), AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), AxisSliceInfo::from(..).step_by(2), ]) .unwrap(); @@ -234,6 +235,7 @@ fn test_slice_dyninput_array_dyn() { let info = &SliceInfo::<_, Ix3, IxDyn>::new([ AxisSliceInfo::from(1..), AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), AxisSliceInfo::from(..).step_by(2), ]) .unwrap(); @@ -246,9 +248,10 @@ fn test_slice_dyninput_array_dyn() { #[test] fn test_slice_dyninput_vec_fixed() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, Ix3, Ix2>::new(vec![ + let info = &SliceInfo::<_, Ix3, Ix3>::new(vec![ AxisSliceInfo::from(1..), AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), AxisSliceInfo::from(..).step_by(2), ]) .unwrap(); @@ -264,6 +267,7 @@ fn test_slice_dyninput_vec_dyn() { let info = &SliceInfo::<_, Ix3, IxDyn>::new(vec![ AxisSliceInfo::from(1..), AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), AxisSliceInfo::from(..).step_by(2), ]) .unwrap(); @@ -274,27 +278,33 @@ fn test_slice_dyninput_vec_dyn() { } #[test] -fn test_slice_with_subview() { +fn test_slice_with_subview_and_new_axis() { let mut arr = ArcArray::::zeros((3, 5, 4)); for (i, elt) in arr.iter_mut().enumerate() { *elt = i; } - let vi = arr.slice(s![1.., 2, ..;2]); - assert_eq!(vi.shape(), &[2, 2]); + let vi = arr.slice(s![NewAxis, 1.., 2, ..;2]); + assert_eq!(vi.shape(), &[1, 2, 2]); assert!(vi .iter() - .zip(arr.index_axis(Axis(1), 2).slice(s![1.., ..;2]).iter()) + .zip( + arr.index_axis(Axis(1), 2) + .slice(s![1.., ..;2]) + .insert_axis(Axis(0)) + .iter() + ) .all(|(a, b)| a == b)); - let vi = arr.slice(s![1, 2, ..;2]); - assert_eq!(vi.shape(), &[2]); + let vi = arr.slice(s![1, NewAxis, 2, ..;2]); + assert_eq!(vi.shape(), &[1, 2]); assert!(vi .iter() .zip( arr.index_axis(Axis(0), 1) .index_axis(Axis(0), 2) .slice(s![..;2]) + .insert_axis(Axis(0)) .iter() ) .all(|(a, b)| a == b)); @@ -313,7 +323,7 @@ fn test_slice_collapse_with_indices() { { let mut vi = arr.view(); - vi.slice_collapse(s![1.., 2, ..;2]); + vi.slice_collapse(s![NewAxis, 1.., 2, ..;2]); assert_eq!(vi.shape(), &[2, 1, 2]); assert!(vi .iter() @@ -321,7 +331,7 @@ fn test_slice_collapse_with_indices() { .all(|(a, b)| a == b)); let mut vi = arr.view(); - vi.slice_collapse(s![1, 2, ..;2]); + vi.slice_collapse(s![1, NewAxis, 2, ..;2]); assert_eq!(vi.shape(), &[1, 1, 2]); assert!(vi .iter() @@ -329,7 +339,7 @@ fn test_slice_collapse_with_indices() { .all(|(a, b)| a == b)); let mut vi = arr.view(); - vi.slice_collapse(s![1, 2, 3]); + vi.slice_collapse(s![1, 2, NewAxis, 3]); assert_eq!(vi.shape(), &[1, 1, 1]); assert_eq!(vi, Array3::from_elem((1, 1, 1), arr[(1, 2, 3)])); } @@ -337,7 +347,7 @@ fn test_slice_collapse_with_indices() { // Do it to the ArcArray itself let elem = arr[(1, 2, 3)]; let mut vi = arr; - vi.slice_collapse(s![1, 2, 3]); + vi.slice_collapse(s![1, 2, 3, NewAxis]); assert_eq!(vi.shape(), &[1, 1, 1]); assert_eq!(vi, Array3::from_elem((1, 1, 1), elem)); } @@ -382,7 +392,7 @@ fn test_multislice() { fn test_multislice_intersecting() { assert_panics!({ let mut arr = Array2::::zeros((8, 6)); - arr.multi_slice_mut((s![3, ..], s![3, ..])); + arr.multi_slice_mut((s![3, .., NewAxis], s![3, ..])); }); assert_panics!({ let mut arr = Array2::::zeros((8, 6)); @@ -390,7 +400,7 @@ fn test_multislice_intersecting() { }); assert_panics!({ let mut arr = Array2::::zeros((8, 6)); - arr.multi_slice_mut((s![3, ..], s![..;3, ..])); + arr.multi_slice_mut((s![3, ..], s![..;3, NewAxis, ..])); }); assert_panics!({ let mut arr = Array2::::zeros((8, 6)); From 6e335ca21a1310ecfe084751ca70fc0f131296cc Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 16 Dec 2018 23:31:07 -0500 Subject: [PATCH 04/28] Rename SliceInfo generic params to Din and Dout --- src/slice.rs | 108 +++++++++++++++++++++++++-------------------------- 1 file changed, 54 insertions(+), 54 deletions(-) diff --git a/src/slice.rs b/src/slice.rs index 5e2a9ac7f..e0ba0df78 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -75,7 +75,7 @@ pub struct NewAxis; /// A slice (range with step), an index, or a new axis token. /// /// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a -/// `&SliceInfo<[AxisSliceInfo; n], Di, Do>`. +/// `&SliceInfo<[AxisSliceInfo; n], Din, Dout>`. /// /// ## Examples /// @@ -316,12 +316,12 @@ pub unsafe trait CanSlice: AsRef<[AxisSliceInfo]> { macro_rules! impl_canslice_samedim { ($in_dim:ty) => { - unsafe impl CanSlice<$in_dim> for SliceInfo + unsafe impl CanSlice<$in_dim> for SliceInfo where T: AsRef<[AxisSliceInfo]>, - Do: Dimension, + Dout: Dimension, { - type OutDim = Do; + type OutDim = Dout; fn in_ndim(&self) -> usize { self.in_ndim() @@ -341,13 +341,13 @@ impl_canslice_samedim!(Ix4); impl_canslice_samedim!(Ix5); impl_canslice_samedim!(Ix6); -unsafe impl CanSlice for SliceInfo +unsafe impl CanSlice for SliceInfo where T: AsRef<[AxisSliceInfo]>, - Di: Dimension, - Do: Dimension, + Din: Dimension, + Dout: Dimension, { - type OutDim = Do; + type OutDim = Dout; fn in_ndim(&self) -> usize { self.in_ndim() @@ -361,25 +361,25 @@ where /// Represents all of the necessary information to perform a slice. /// /// The type `T` is typically `[AxisSliceInfo; n]`, `[AxisSliceInfo]`, or -/// `Vec`. The type `Di` is the dimension of the array to be -/// sliced, and `Do` is the output dimension after calling [`.slice()`]. Note -/// that if `Di` is a fixed dimension type (`Ix0`, `Ix1`, `Ix2`, etc.), the +/// `Vec`. The type `Din` is the dimension of the array to be +/// sliced, and `Dout` is the output dimension after calling [`.slice()`]. Note +/// that if `Din` is a fixed dimension type (`Ix0`, `Ix1`, `Ix2`, etc.), the /// `SliceInfo` instance can still be used to slice an array with dimension /// `IxDyn` as long as the number of axes matches. /// /// [`.slice()`]: struct.ArrayBase.html#method.slice #[derive(Debug)] #[repr(C)] -pub struct SliceInfo { - in_dim: PhantomData, - out_dim: PhantomData, +pub struct SliceInfo { + in_dim: PhantomData, + out_dim: PhantomData, indices: T, } -impl Deref for SliceInfo +impl Deref for SliceInfo where - Di: Dimension, - Do: Dimension, + Din: Dimension, + Dout: Dimension, { type Target = T; fn deref(&self) -> &Self::Target { @@ -387,10 +387,10 @@ where } } -impl SliceInfo +impl SliceInfo where - Di: Dimension, - Do: Dimension, + Din: Dimension, + Dout: Dimension, { /// Returns a new `SliceInfo` instance. /// @@ -399,9 +399,9 @@ where #[doc(hidden)] pub unsafe fn new_unchecked( indices: T, - in_dim: PhantomData, - out_dim: PhantomData, - ) -> SliceInfo { + in_dim: PhantomData, + out_dim: PhantomData, + ) -> SliceInfo { SliceInfo { in_dim: in_dim, out_dim: out_dim, @@ -410,22 +410,22 @@ where } } -impl SliceInfo +impl SliceInfo where T: AsRef<[AxisSliceInfo]>, - Di: Dimension, - Do: Dimension, + Din: Dimension, + Dout: Dimension, { /// Returns a new `SliceInfo` instance. /// - /// Errors if `Di` or `Do` is not consistent with `indices`. - pub fn new(indices: T) -> Result, ShapeError> { - if let Some(ndim) = Di::NDIM { + /// Errors if `Din` or `Dout` is not consistent with `indices`. + pub fn new(indices: T) -> Result, ShapeError> { + if let Some(ndim) = Din::NDIM { if ndim != indices.as_ref().iter().filter(|s| !s.is_new_axis()).count() { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } - if let Some(ndim) = Do::NDIM { + if let Some(ndim) = Dout::NDIM { if ndim != indices.as_ref().iter().filter(|s| !s.is_index()).count() { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } @@ -438,20 +438,20 @@ where } } -impl SliceInfo +impl SliceInfo where T: AsRef<[AxisSliceInfo]>, - Di: Dimension, - Do: Dimension, + Din: Dimension, + Dout: Dimension, { /// Returns the number of dimensions of the input array for /// [`.slice()`](struct.ArrayBase.html#method.slice). /// - /// If `Di` is a fixed-size dimension type, then this is equivalent to - /// `Di::NDIM.unwrap()`. Otherwise, the value is calculated by iterating + /// If `Din` is a fixed-size dimension type, then this is equivalent to + /// `Din::NDIM.unwrap()`. Otherwise, the value is calculated by iterating /// over the `AxisSliceInfo` elements. pub fn in_ndim(&self) -> usize { - Di::NDIM.unwrap_or_else(|| { + Din::NDIM.unwrap_or_else(|| { self.indices .as_ref() .iter() @@ -464,11 +464,11 @@ where /// [`.slice()`](struct.ArrayBase.html#method.slice) (including taking /// subviews). /// - /// If `Do` is a fixed-size dimension type, then this is equivalent to - /// `Do::NDIM.unwrap()`. Otherwise, the value is calculated by iterating + /// If `Dout` is a fixed-size dimension type, then this is equivalent to + /// `Dout::NDIM.unwrap()`. Otherwise, the value is calculated by iterating /// over the `AxisSliceInfo` elements. pub fn out_ndim(&self) -> usize { - Do::NDIM.unwrap_or_else(|| { + Dout::NDIM.unwrap_or_else(|| { self.indices .as_ref() .iter() @@ -478,48 +478,48 @@ where } } -impl AsRef<[AxisSliceInfo]> for SliceInfo +impl AsRef<[AxisSliceInfo]> for SliceInfo where T: AsRef<[AxisSliceInfo]>, - Di: Dimension, - Do: Dimension, + Din: Dimension, + Dout: Dimension, { fn as_ref(&self) -> &[AxisSliceInfo] { self.indices.as_ref() } } -impl AsRef> for SliceInfo +impl AsRef> for SliceInfo where T: AsRef<[AxisSliceInfo]>, - Di: Dimension, - Do: Dimension, + Din: Dimension, + Dout: Dimension, { - fn as_ref(&self) -> &SliceInfo<[AxisSliceInfo], Di, Do> { + fn as_ref(&self) -> &SliceInfo<[AxisSliceInfo], Din, Dout> { unsafe { // This is okay because the only non-zero-sized member of - // `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Di, Do>` + // `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Din, Dout>` // should have the same bitwise representation as // `&[AxisSliceInfo]`. &*(self.indices.as_ref() as *const [AxisSliceInfo] - as *const SliceInfo<[AxisSliceInfo], Di, Do>) + as *const SliceInfo<[AxisSliceInfo], Din, Dout>) } } } -impl Copy for SliceInfo +impl Copy for SliceInfo where T: Copy, - Di: Dimension, - Do: Dimension, + Din: Dimension, + Dout: Dimension, { } -impl Clone for SliceInfo +impl Clone for SliceInfo where T: Clone, - Di: Dimension, - Do: Dimension, + Din: Dimension, + Dout: Dimension, { fn clone(&self) -> Self { SliceInfo { From d6b9cb0d3655f7ee4c9b6cff72b504f294fbfa83 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 16 Dec 2018 23:35:03 -0500 Subject: [PATCH 05/28] Improve code style --- src/slice.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/slice.rs b/src/slice.rs index e0ba0df78..e7a2fe172 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -451,13 +451,15 @@ where /// `Din::NDIM.unwrap()`. Otherwise, the value is calculated by iterating /// over the `AxisSliceInfo` elements. pub fn in_ndim(&self) -> usize { - Din::NDIM.unwrap_or_else(|| { + if let Some(ndim) = Din::NDIM { + ndim + } else { self.indices .as_ref() .iter() .filter(|s| !s.is_new_axis()) .count() - }) + } } /// Returns the number of dimensions after calling @@ -468,13 +470,15 @@ where /// `Dout::NDIM.unwrap()`. Otherwise, the value is calculated by iterating /// over the `AxisSliceInfo` elements. pub fn out_ndim(&self) -> usize { - Dout::NDIM.unwrap_or_else(|| { + if let Some(ndim) = Dout::NDIM { + ndim + } else { self.indices .as_ref() .iter() .filter(|s| !s.is_index()) .count() - }) + } } } From 438d69abcfabcbba149cbb70be055ec77cecdc85 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 16 Dec 2018 23:35:23 -0500 Subject: [PATCH 06/28] Derive Clone, Copy, and Debug for NewAxis --- src/slice.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/slice.rs b/src/slice.rs index e7a2fe172..90d67ba16 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -70,6 +70,7 @@ impl Slice { /// Token to represent a new axis in a slice description. /// /// See also the [`s![]`](macro.s!.html) macro. +#[derive(Clone, Copy, Debug)] pub struct NewAxis; /// A slice (range with step), an index, or a new axis token. From 6050df3bbd0b54e62b1871c985df8126be418932 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 17 Dec 2018 19:02:45 -0500 Subject: [PATCH 07/28] Use stringify! for string literal of type name --- src/slice.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/slice.rs b/src/slice.rs index 90d67ba16..75083f193 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -175,7 +175,7 @@ impl fmt::Display for AxisSliceInfo { write!(f, ";{}", step)?; } } - AxisSliceInfo::NewAxis => write!(f, "NewAxis")?, + AxisSliceInfo::NewAxis => write!(f, stringify!(NewAxis))?, } Ok(()) } From 8d45268435bd8484162efcbd27ad0d3c88c70496 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 17 Dec 2018 19:06:29 -0500 Subject: [PATCH 08/28] Make step_by panic for variants other than Slice --- src/slice.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/slice.rs b/src/slice.rs index 75083f193..24cbafda2 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -140,6 +140,8 @@ impl AxisSliceInfo { /// /// `step` must be nonzero. /// (This method checks with a debug assertion that `step` is not zero.) + /// + /// **Panics** if `self` is not the `AxisSliceInfo::Slice` variant. #[inline] pub fn step_by(self, step: isize) -> Self { debug_assert_ne!(step, 0, "AxisSliceInfo::step_by: step must be nonzero"); @@ -153,8 +155,7 @@ impl AxisSliceInfo { end, step: orig_step * step, }, - AxisSliceInfo::Index(s) => AxisSliceInfo::Index(s), - AxisSliceInfo::NewAxis => AxisSliceInfo::NewAxis, + _ => panic!("AxisSliceInfo::step_by: `self` must be the `Slice` variant"), } } } From 1d15275eb5b38610bb04145b22e1fc745e90970f Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 17 Dec 2018 19:53:47 -0500 Subject: [PATCH 09/28] Add DimAdd trait --- src/dimension/mod.rs | 2 + src/dimension/ops.rs | 95 ++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 +- 3 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 src/dimension/ops.rs diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 28bdc7f91..a83c49015 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -19,6 +19,7 @@ pub use self::dim::*; pub use self::dimension_trait::Dimension; pub use self::dynindeximpl::IxDynImpl; pub use self::ndindex::NdIndex; +pub use self::ops::DimAdd; pub use self::remove_axis::RemoveAxis; use crate::shape_builder::Strides; @@ -36,6 +37,7 @@ pub mod dim; mod dimension_trait; mod dynindeximpl; mod ndindex; +mod ops; mod remove_axis; /// Calculate offset from `Ix` stride converting sign properly diff --git a/src/dimension/ops.rs b/src/dimension/ops.rs new file mode 100644 index 000000000..c31d67412 --- /dev/null +++ b/src/dimension/ops.rs @@ -0,0 +1,95 @@ +use crate::imp_prelude::*; + +/// Adds the two dimensions at compile time. +pub trait DimAdd: Dimension { + /// The sum of the two dimensions. + type Out: Dimension; +} + +macro_rules! impl_dimadd_const_out_const { + ($lhs:expr, $rhs:expr) => { + impl DimAdd> for Dim<[usize; $lhs]> { + type Out = Dim<[usize; $lhs + $rhs]>; + } + }; +} + +macro_rules! impl_dimadd_const_out_dyn { + ($lhs:expr, IxDyn) => { + impl DimAdd for Dim<[usize; $lhs]> { + type Out = IxDyn; + } + }; + ($lhs:expr, $rhs:expr) => { + impl DimAdd> for Dim<[usize; $lhs]> { + type Out = IxDyn; + } + }; +} + +impl_dimadd_const_out_const!(0, 0); +impl_dimadd_const_out_const!(0, 1); +impl_dimadd_const_out_const!(0, 2); +impl_dimadd_const_out_const!(0, 3); +impl_dimadd_const_out_const!(0, 4); +impl_dimadd_const_out_const!(0, 5); +impl_dimadd_const_out_const!(0, 6); +impl_dimadd_const_out_dyn!(0, IxDyn); + +impl_dimadd_const_out_const!(1, 0); +impl_dimadd_const_out_const!(1, 1); +impl_dimadd_const_out_const!(1, 2); +impl_dimadd_const_out_const!(1, 3); +impl_dimadd_const_out_const!(1, 4); +impl_dimadd_const_out_const!(1, 5); +impl_dimadd_const_out_dyn!(1, 6); +impl_dimadd_const_out_dyn!(1, IxDyn); + +impl_dimadd_const_out_const!(2, 0); +impl_dimadd_const_out_const!(2, 1); +impl_dimadd_const_out_const!(2, 2); +impl_dimadd_const_out_const!(2, 3); +impl_dimadd_const_out_const!(2, 4); +impl_dimadd_const_out_dyn!(2, 5); +impl_dimadd_const_out_dyn!(2, 6); +impl_dimadd_const_out_dyn!(2, IxDyn); + +impl_dimadd_const_out_const!(3, 0); +impl_dimadd_const_out_const!(3, 1); +impl_dimadd_const_out_const!(3, 2); +impl_dimadd_const_out_const!(3, 3); +impl_dimadd_const_out_dyn!(3, 4); +impl_dimadd_const_out_dyn!(3, 5); +impl_dimadd_const_out_dyn!(3, 6); +impl_dimadd_const_out_dyn!(3, IxDyn); + +impl_dimadd_const_out_const!(4, 0); +impl_dimadd_const_out_const!(4, 1); +impl_dimadd_const_out_const!(4, 2); +impl_dimadd_const_out_dyn!(4, 3); +impl_dimadd_const_out_dyn!(4, 4); +impl_dimadd_const_out_dyn!(4, 5); +impl_dimadd_const_out_dyn!(4, 6); +impl_dimadd_const_out_dyn!(4, IxDyn); + +impl_dimadd_const_out_const!(5, 0); +impl_dimadd_const_out_const!(5, 1); +impl_dimadd_const_out_dyn!(5, 2); +impl_dimadd_const_out_dyn!(5, 3); +impl_dimadd_const_out_dyn!(5, 4); +impl_dimadd_const_out_dyn!(5, 5); +impl_dimadd_const_out_dyn!(5, 6); +impl_dimadd_const_out_dyn!(5, IxDyn); + +impl_dimadd_const_out_const!(6, 0); +impl_dimadd_const_out_dyn!(6, 1); +impl_dimadd_const_out_dyn!(6, 2); +impl_dimadd_const_out_dyn!(6, 3); +impl_dimadd_const_out_dyn!(6, 4); +impl_dimadd_const_out_dyn!(6, 5); +impl_dimadd_const_out_dyn!(6, 6); +impl_dimadd_const_out_dyn!(6, IxDyn); + +impl DimAdd for IxDyn { + type Out = IxDyn; +} diff --git a/src/lib.rs b/src/lib.rs index ae750f3d8..5c1ef755e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,8 +134,8 @@ use std::marker::PhantomData; use alloc::sync::Arc; pub use crate::dimension::dim::*; -pub use crate::dimension::DimMax; pub use crate::dimension::{Axis, AxisDescription, Dimension, IntoDimension, RemoveAxis}; +pub use crate::dimension::{DimAdd, DimMax}; pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; From 41cc4a1248f8b4698a5c98490d4847b5642ac0fd Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 17 Dec 2018 20:08:22 -0500 Subject: [PATCH 10/28] Replace SliceNextIn/OutDim with SliceArg trait --- src/lib.rs | 2 +- src/slice.rs | 107 +++++++++++++++++++++------------------------------ 2 files changed, 44 insertions(+), 65 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5c1ef755e..d82a99649 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -141,7 +141,7 @@ pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; -pub use crate::slice::{AxisSliceInfo, NewAxis, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim}; +pub use crate::slice::{AxisSliceInfo, NewAxis, Slice, SliceArg, SliceInfo}; use crate::iterators::Baseiter; use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes}; diff --git a/src/slice.rs b/src/slice.rs index 24cbafda2..d19a889e9 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -7,10 +7,10 @@ // except according to those terms. use crate::dimension::slices_intersect; use crate::error::{ErrorKind, ShapeError}; +use crate::{ArrayViewMut, DimAdd, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; use std::fmt; use std::marker::PhantomData; use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; -use crate::{ArrayViewMut, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; /// A slice (range with step size). /// @@ -536,72 +536,51 @@ where } } +/// Trait for determining dimensionality of input and output for [`s!`] macro. #[doc(hidden)] -pub trait SliceNextInDim { - fn next_dim(&self, _: PhantomData) -> PhantomData; -} +pub trait SliceArg { + /// Number of dimensions that this slicing argument consumes in the input array. + type InDim: Dimension; + /// Number of dimensions that this slicing argument produces in the output array. + type OutDim: Dimension; -impl SliceNextInDim for NewAxis { - fn next_dim(&self, _: PhantomData) -> PhantomData { + fn next_in_dim(&self, _: PhantomData) -> PhantomData + where + D: Dimension + DimAdd, + { PhantomData } -} -macro_rules! impl_slicenextindim_larger { - (($($generics:tt)*), $self:ty) => { - impl SliceNextInDim for $self { - fn next_dim(&self, _: PhantomData) -> PhantomData { - PhantomData - } - } + fn next_out_dim(&self, _: PhantomData) -> PhantomData + where + D: Dimension + DimAdd, + { + PhantomData } } -impl_slicenextindim_larger!((), isize); -impl_slicenextindim_larger!((), usize); -impl_slicenextindim_larger!((), i32); -impl_slicenextindim_larger!((T), Range); -impl_slicenextindim_larger!((T), RangeInclusive); -impl_slicenextindim_larger!((T), RangeFrom); -impl_slicenextindim_larger!((T), RangeTo); -impl_slicenextindim_larger!((T), RangeToInclusive); -impl_slicenextindim_larger!((), RangeFull); -impl_slicenextindim_larger!((), Slice); - -#[doc(hidden)] -pub trait SliceNextOutDim { - fn next_dim(&self, _: PhantomData) -> PhantomData; -} -macro_rules! impl_slicenextoutdim_equal { - ($self:ty) => { - impl SliceNextOutDim for $self { - fn next_dim(&self, _: PhantomData) -> PhantomData { - PhantomData - } +macro_rules! impl_slicearg { + (($($generics:tt)*), $self:ty, $in:ty, $out:ty) => { + impl<$($generics)*> SliceArg for $self { + type InDim = $in; + type OutDim = $out; } }; } -impl_slicenextoutdim_equal!(isize); -impl_slicenextoutdim_equal!(usize); -impl_slicenextoutdim_equal!(i32); - -macro_rules! impl_slicenextoutdim_larger { - (($($generics:tt)*), $self:ty) => { - impl SliceNextOutDim for $self { - fn next_dim(&self, _: PhantomData) -> PhantomData { - PhantomData - } - } - } -} -impl_slicenextoutdim_larger!((T), Range); -impl_slicenextoutdim_larger!((T), RangeInclusive); -impl_slicenextoutdim_larger!((T), RangeFrom); -impl_slicenextoutdim_larger!((T), RangeTo); -impl_slicenextoutdim_larger!((T), RangeToInclusive); -impl_slicenextoutdim_larger!((), RangeFull); -impl_slicenextoutdim_larger!((), Slice); -impl_slicenextoutdim_larger!((), NewAxis); + +impl_slicearg!((), isize, Ix1, Ix0); +impl_slicearg!((), usize, Ix1, Ix0); +impl_slicearg!((), i32, Ix1, Ix0); + +impl_slicearg!((T), Range, Ix1, Ix1); +impl_slicearg!((T), RangeInclusive, Ix1, Ix1); +impl_slicearg!((T), RangeFrom, Ix1, Ix1); +impl_slicearg!((T), RangeTo, Ix1, Ix1); +impl_slicearg!((T), RangeToInclusive, Ix1, Ix1); +impl_slicearg!((), RangeFull, Ix1, Ix1); +impl_slicearg!((), Slice, Ix1, Ix1); + +impl_slicearg!((), NewAxis, Ix0, Ix1); /// Slice argument constructor. /// @@ -703,8 +682,8 @@ macro_rules! s( (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr;$s:expr) => { match $r { r => { - let in_dim = $crate::SliceNextInDim::next_dim(&r, $in_dim); - let out_dim = $crate::SliceNextOutDim::next_dim(&r, $out_dim); + let in_dim = $crate::SliceArg::next_in_dim(&r, $in_dim); + let out_dim = $crate::SliceArg::next_out_dim(&r, $out_dim); #[allow(unsafe_code)] unsafe { $crate::SliceInfo::new_unchecked( @@ -720,8 +699,8 @@ macro_rules! s( (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr) => { match $r { r => { - let in_dim = $crate::SliceNextInDim::next_dim(&r, $in_dim); - let out_dim = $crate::SliceNextOutDim::next_dim(&r, $out_dim); + let in_dim = $crate::SliceArg::next_in_dim(&r, $in_dim); + let out_dim = $crate::SliceArg::next_out_dim(&r, $out_dim); #[allow(unsafe_code)] unsafe { $crate::SliceInfo::new_unchecked( @@ -746,8 +725,8 @@ macro_rules! s( match $r { r => { $crate::s![@parse - $crate::SliceNextInDim::next_dim(&r, $in_dim), - $crate::SliceNextOutDim::next_dim(&r, $out_dim), + $crate::SliceArg::next_in_dim(&r, $in_dim), + $crate::SliceArg::next_out_dim(&r, $out_dim), [$($stack)* $crate::s!(@convert r, $s),] $($t)* ] @@ -759,8 +738,8 @@ macro_rules! s( match $r { r => { $crate::s![@parse - $crate::SliceNextInDim::next_dim(&r, $in_dim), - $crate::SliceNextOutDim::next_dim(&r, $out_dim), + $crate::SliceArg::next_in_dim(&r, $in_dim), + $crate::SliceArg::next_out_dim(&r, $out_dim), [$($stack)* $crate::s!(@convert r),] $($t)* ] From c66ad8ca932374c91b567b6e6442a4d0ea9c2e43 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 7 Feb 2021 16:29:43 -0500 Subject: [PATCH 11/28] Combine DimAdd impls for Ix0 --- src/dimension/ops.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/dimension/ops.rs b/src/dimension/ops.rs index c31d67412..855b3b4e1 100644 --- a/src/dimension/ops.rs +++ b/src/dimension/ops.rs @@ -27,14 +27,9 @@ macro_rules! impl_dimadd_const_out_dyn { }; } -impl_dimadd_const_out_const!(0, 0); -impl_dimadd_const_out_const!(0, 1); -impl_dimadd_const_out_const!(0, 2); -impl_dimadd_const_out_const!(0, 3); -impl_dimadd_const_out_const!(0, 4); -impl_dimadd_const_out_const!(0, 5); -impl_dimadd_const_out_const!(0, 6); -impl_dimadd_const_out_dyn!(0, IxDyn); +impl DimAdd for Ix0 { + type Out = D; +} impl_dimadd_const_out_const!(1, 0); impl_dimadd_const_out_const!(1, 1); From 7776bfcc7069aa4e159a7816b35933d9a7de0a0f Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 14 Feb 2021 03:21:59 -0500 Subject: [PATCH 12/28] Implement CanSlice for [AxisSliceInfo] --- blas-tests/tests/oper.rs | 4 ++-- src/impl_methods.rs | 8 ++++---- src/slice.rs | 32 ++++++++++++++++++-------------- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/blas-tests/tests/oper.rs b/blas-tests/tests/oper.rs index 25d26b7ba..6b6797f12 100644 --- a/blas-tests/tests/oper.rs +++ b/blas-tests/tests/oper.rs @@ -6,7 +6,7 @@ extern crate num_traits; use ndarray::linalg::general_mat_mul; use ndarray::linalg::general_mat_vec_mul; use ndarray::prelude::*; -use ndarray::{AxisSliceInfo, Ix, Ixs, SliceInfo}; +use ndarray::{AxisSliceInfo, Ix, Ixs}; use ndarray::{Data, LinalgScalar}; use approx::{assert_abs_diff_eq, assert_relative_eq}; @@ -432,7 +432,7 @@ fn scaled_add_3() { { let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap()); + let c = c.slice(&*cslice); let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); answerv += &(beta * &c); diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 143824850..fb7097d56 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -343,7 +343,7 @@ where /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) pub fn slice(&self, info: &I) -> ArrayView<'_, A, I::OutDim> where - I: CanSlice, + I: CanSlice + ?Sized, S: Data, { self.view().slice_move(info) @@ -361,7 +361,7 @@ where /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) pub fn slice_mut(&mut self, info: &I) -> ArrayViewMut<'_, A, I::OutDim> where - I: CanSlice, + I: CanSlice + ?Sized, S: DataMut, { self.view_mut().slice_move(info) @@ -412,7 +412,7 @@ where /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) pub fn slice_move(mut self, info: &I) -> ArrayBase where - I: CanSlice, + I: CanSlice + ?Sized, { // Slice and collapse in-place without changing the number of dimensions. self.slice_collapse(info); @@ -464,7 +464,7 @@ where /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) pub fn slice_collapse(&mut self, info: &I) where - I: CanSlice, + I: CanSlice + ?Sized, { assert_eq!( info.in_ndim(), diff --git a/src/slice.rs b/src/slice.rs index d19a889e9..ed567c8fc 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -360,6 +360,18 @@ where } } +unsafe impl CanSlice for [AxisSliceInfo] { + type OutDim = IxDyn; + + fn in_ndim(&self) -> usize { + self.iter().filter(|s| !s.is_new_axis()).count() + } + + fn out_ndim(&self) -> usize { + self.iter().filter(|s| !s.is_index()).count() + } +} + /// Represents all of the necessary information to perform a slice. /// /// The type `T` is typically `[AxisSliceInfo; n]`, `[AxisSliceInfo]`, or @@ -422,13 +434,13 @@ where /// /// Errors if `Din` or `Dout` is not consistent with `indices`. pub fn new(indices: T) -> Result, ShapeError> { - if let Some(ndim) = Din::NDIM { - if ndim != indices.as_ref().iter().filter(|s| !s.is_new_axis()).count() { + if let Some(in_ndim) = Din::NDIM { + if in_ndim != indices.as_ref().in_ndim() { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } - if let Some(ndim) = Dout::NDIM { - if ndim != indices.as_ref().iter().filter(|s| !s.is_index()).count() { + if let Some(out_ndim) = Dout::NDIM { + if out_ndim != indices.as_ref().out_ndim() { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } @@ -456,11 +468,7 @@ where if let Some(ndim) = Din::NDIM { ndim } else { - self.indices - .as_ref() - .iter() - .filter(|s| !s.is_new_axis()) - .count() + self.indices.as_ref().in_ndim() } } @@ -475,11 +483,7 @@ where if let Some(ndim) = Dout::NDIM { ndim } else { - self.indices - .as_ref() - .iter() - .filter(|s| !s.is_index()) - .count() + self.indices.as_ref().out_ndim() } } } From ab79d2857b8fcd8ab5d86558590f48aa87b46c61 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 14 Feb 2021 20:28:47 -0500 Subject: [PATCH 13/28] Change SliceInfo to be repr(transparent) --- src/slice.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/slice.rs b/src/slice.rs index ed567c8fc..8d9d85fd6 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -383,7 +383,7 @@ unsafe impl CanSlice for [AxisSliceInfo] { /// /// [`.slice()`]: struct.ArrayBase.html#method.slice #[derive(Debug)] -#[repr(C)] +#[repr(transparent)] pub struct SliceInfo { in_dim: PhantomData, out_dim: PhantomData, From 615113e0eb3a853ff4a56576881e611634558095 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 14 Feb 2021 21:04:42 -0500 Subject: [PATCH 14/28] Add debug assertions to SliceInfo::new_unchecked --- src/slice.rs | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/slice.rs b/src/slice.rs index 8d9d85fd6..204e15785 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -403,6 +403,7 @@ where impl SliceInfo where + T: AsRef<[AxisSliceInfo]>, Din: Dimension, Dout: Dimension, { @@ -410,16 +411,26 @@ where /// /// If you call this method, you are guaranteeing that `in_dim` and /// `out_dim` are consistent with `indices`. + /// + /// **Note:** only unchecked for non-debug builds of `ndarray`. #[doc(hidden)] pub unsafe fn new_unchecked( indices: T, in_dim: PhantomData, out_dim: PhantomData, ) -> SliceInfo { + if cfg!(debug_assertions) { + if let Some(in_ndim) = Din::NDIM { + assert_eq!(in_ndim, indices.as_ref().in_ndim()); + } + if let Some(out_ndim) = Dout::NDIM { + assert_eq!(out_ndim, indices.as_ref().out_ndim()); + } + } SliceInfo { - in_dim: in_dim, - out_dim: out_dim, - indices: indices, + in_dim, + out_dim, + indices, } } } From e66e3c89f1304f770679b45cfc554d83bc8a702f Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 15 Feb 2021 16:31:11 -0500 Subject: [PATCH 15/28] Fix safety of SliceInfo::new --- src/slice.rs | 19 +++++++++++---- tests/array.rs | 64 ++++++++++++++++++++++++++++---------------------- tests/oper.rs | 2 +- 3 files changed, 51 insertions(+), 34 deletions(-) diff --git a/src/slice.rs b/src/slice.rs index 204e15785..2c0b7d323 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -307,7 +307,8 @@ impl From for AxisSliceInfo { /// /// This trait is unsafe to implement because the implementation must ensure /// that `D`, `Self::OutDim`, `self.in_dim()`, and `self.out_ndim()` are -/// consistent with the `&[AxisSliceInfo]` returned by `self.as_ref()`. +/// consistent with the `&[AxisSliceInfo]` returned by `self.as_ref()` and that +/// `self.as_ref()` always returns the same value when called multiple times. pub unsafe trait CanSlice: AsRef<[AxisSliceInfo]> { type OutDim: Dimension; @@ -409,10 +410,13 @@ where { /// Returns a new `SliceInfo` instance. /// - /// If you call this method, you are guaranteeing that `in_dim` and - /// `out_dim` are consistent with `indices`. - /// /// **Note:** only unchecked for non-debug builds of `ndarray`. + /// + /// # Safety + /// + /// The caller must ensure that `in_dim` and `out_dim` are consistent with + /// `indices` and that `indices.as_ref()` always returns the same value + /// when called multiple times. #[doc(hidden)] pub unsafe fn new_unchecked( indices: T, @@ -444,7 +448,12 @@ where /// Returns a new `SliceInfo` instance. /// /// Errors if `Din` or `Dout` is not consistent with `indices`. - pub fn new(indices: T) -> Result, ShapeError> { + /// + /// # Safety + /// + /// The caller must ensure `indices.as_ref()` always returns the same value + /// when called multiple times. + pub unsafe fn new(indices: T) -> Result, ShapeError> { if let Some(in_ndim) = Din::NDIM { if in_ndim != indices.as_ref().in_ndim() { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); diff --git a/tests/array.rs b/tests/array.rs index 51f59fcb1..f9bd0d67b 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -216,13 +216,15 @@ fn test_slice_dyninput_array_fixed() { #[test] fn test_slice_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)); - let info = &SliceInfo::<_, Ix3, IxDyn>::new([ - AxisSliceInfo::from(1..), - AxisSliceInfo::from(1), - AxisSliceInfo::from(NewAxis), - AxisSliceInfo::from(..).step_by(2), - ]) - .unwrap(); + let info = &unsafe { + SliceInfo::<_, Ix3, IxDyn>::new([ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap() + }; arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); @@ -232,13 +234,15 @@ fn test_slice_array_dyn() { #[test] fn test_slice_dyninput_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, Ix3, IxDyn>::new([ - AxisSliceInfo::from(1..), - AxisSliceInfo::from(1), - AxisSliceInfo::from(NewAxis), - AxisSliceInfo::from(..).step_by(2), - ]) - .unwrap(); + let info = &unsafe { + SliceInfo::<_, Ix3, IxDyn>::new([ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap() + }; arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); @@ -248,13 +252,15 @@ fn test_slice_dyninput_array_dyn() { #[test] fn test_slice_dyninput_vec_fixed() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, Ix3, Ix3>::new(vec![ - AxisSliceInfo::from(1..), - AxisSliceInfo::from(1), - AxisSliceInfo::from(NewAxis), - AxisSliceInfo::from(..).step_by(2), - ]) - .unwrap(); + let info = &unsafe { + SliceInfo::<_, Ix3, Ix3>::new(vec![ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap() + }; arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); @@ -264,13 +270,15 @@ fn test_slice_dyninput_vec_fixed() { #[test] fn test_slice_dyninput_vec_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, Ix3, IxDyn>::new(vec![ - AxisSliceInfo::from(1..), - AxisSliceInfo::from(1), - AxisSliceInfo::from(NewAxis), - AxisSliceInfo::from(..).step_by(2), - ]) - .unwrap(); + let info = &unsafe { + SliceInfo::<_, Ix3, IxDyn>::new(vec![ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap() + }; arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); diff --git a/tests/oper.rs b/tests/oper.rs index 16f3edbc6..3a4a26ac8 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -595,7 +595,7 @@ fn scaled_add_3() { { let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap()); + let c = c.slice(&unsafe { SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap() }); let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); answerv += &(beta * &c); From 3ba6ceb8cd0463cb5164e19b295a26b376cd5261 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 15 Feb 2021 17:21:52 -0500 Subject: [PATCH 16/28] Add some impls of TryFrom for SliceInfo --- src/slice.rs | 114 ++++++++++++++++++++++++++++++++++++++++++------- tests/array.rs | 65 +++++++++++++--------------- tests/oper.rs | 3 +- 3 files changed, 129 insertions(+), 53 deletions(-) diff --git a/src/slice.rs b/src/slice.rs index 2c0b7d323..f2da78038 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -8,6 +8,8 @@ use crate::dimension::slices_intersect; use crate::error::{ErrorKind, ShapeError}; use crate::{ArrayViewMut, DimAdd, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; +use alloc::vec::Vec; +use std::convert::TryFrom; use std::fmt; use std::marker::PhantomData; use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; @@ -402,6 +404,24 @@ where } } +fn check_dims_for_sliceinfo(indices: &[AxisSliceInfo]) -> Result<(), ShapeError> +where + Din: Dimension, + Dout: Dimension, +{ + if let Some(in_ndim) = Din::NDIM { + if in_ndim != indices.in_ndim() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + } + if let Some(out_ndim) = Dout::NDIM { + if out_ndim != indices.out_ndim() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + } + Ok(()) +} + impl SliceInfo where T: AsRef<[AxisSliceInfo]>, @@ -424,12 +444,8 @@ where out_dim: PhantomData, ) -> SliceInfo { if cfg!(debug_assertions) { - if let Some(in_ndim) = Din::NDIM { - assert_eq!(in_ndim, indices.as_ref().in_ndim()); - } - if let Some(out_ndim) = Dout::NDIM { - assert_eq!(out_ndim, indices.as_ref().out_ndim()); - } + check_dims_for_sliceinfo::(indices.as_ref()) + .expect("`Din` and `Dout` must be consistent with `indices`."); } SliceInfo { in_dim, @@ -449,21 +465,14 @@ where /// /// Errors if `Din` or `Dout` is not consistent with `indices`. /// + /// For common types, a safe alternative is to use `TryFrom` instead. + /// /// # Safety /// /// The caller must ensure `indices.as_ref()` always returns the same value /// when called multiple times. pub unsafe fn new(indices: T) -> Result, ShapeError> { - if let Some(in_ndim) = Din::NDIM { - if in_ndim != indices.as_ref().in_ndim() { - return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); - } - } - if let Some(out_ndim) = Dout::NDIM { - if out_ndim != indices.as_ref().out_ndim() { - return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); - } - } + check_dims_for_sliceinfo::(indices.as_ref())?; Ok(SliceInfo { in_dim: PhantomData, out_dim: PhantomData, @@ -508,6 +517,79 @@ where } } +impl<'a, Din, Dout> TryFrom<&'a [AxisSliceInfo]> for &'a SliceInfo<[AxisSliceInfo], Din, Dout> +where + Din: Dimension, + Dout: Dimension, +{ + type Error = ShapeError; + + fn try_from( + indices: &'a [AxisSliceInfo], + ) -> Result<&'a SliceInfo<[AxisSliceInfo], Din, Dout>, ShapeError> { + check_dims_for_sliceinfo::(indices)?; + unsafe { + // This is okay because we've already checked the correctness of + // `Din` and `Dout`, and the only non-zero-sized member of + // `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Din, + // Dout>` should have the same bitwise representation as + // `&[AxisSliceInfo]`. + Ok(&*(indices as *const [AxisSliceInfo] + as *const SliceInfo<[AxisSliceInfo], Din, Dout>)) + } + } +} + +impl TryFrom> for SliceInfo, Din, Dout> +where + Din: Dimension, + Dout: Dimension, +{ + type Error = ShapeError; + + fn try_from( + indices: Vec, + ) -> Result, Din, Dout>, ShapeError> { + unsafe { + // This is okay because `Vec` always returns the same value for + // `.as_ref()`. + Self::new(indices) + } + } +} + +macro_rules! impl_tryfrom_array_for_sliceinfo { + ($len:expr) => { + impl TryFrom<[AxisSliceInfo; $len]> + for SliceInfo<[AxisSliceInfo; $len], Din, Dout> + where + Din: Dimension, + Dout: Dimension, + { + type Error = ShapeError; + + fn try_from( + indices: [AxisSliceInfo; $len], + ) -> Result, ShapeError> { + unsafe { + // This is okay because `[AxisSliceInfo; N]` always returns + // the same value for `.as_ref()`. + Self::new(indices) + } + } + } + }; +} +impl_tryfrom_array_for_sliceinfo!(0); +impl_tryfrom_array_for_sliceinfo!(1); +impl_tryfrom_array_for_sliceinfo!(2); +impl_tryfrom_array_for_sliceinfo!(3); +impl_tryfrom_array_for_sliceinfo!(4); +impl_tryfrom_array_for_sliceinfo!(5); +impl_tryfrom_array_for_sliceinfo!(6); +impl_tryfrom_array_for_sliceinfo!(7); +impl_tryfrom_array_for_sliceinfo!(8); + impl AsRef<[AxisSliceInfo]> for SliceInfo where T: AsRef<[AxisSliceInfo]>, diff --git a/tests/array.rs b/tests/array.rs index f9bd0d67b..ac4791d6f 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -13,6 +13,7 @@ use ndarray::prelude::*; use ndarray::{arr3, rcarr2}; use ndarray::indices; use ndarray::{AxisSliceInfo, Slice, SliceInfo}; +use std::convert::TryFrom; macro_rules! assert_panics { ($body:expr) => { @@ -216,15 +217,13 @@ fn test_slice_dyninput_array_fixed() { #[test] fn test_slice_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)); - let info = &unsafe { - SliceInfo::<_, Ix3, IxDyn>::new([ - AxisSliceInfo::from(1..), - AxisSliceInfo::from(1), - AxisSliceInfo::from(NewAxis), - AxisSliceInfo::from(..).step_by(2), - ]) - .unwrap() - }; + let info = &SliceInfo::<_, Ix3, IxDyn>::try_from([ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap(); arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); @@ -234,15 +233,13 @@ fn test_slice_array_dyn() { #[test] fn test_slice_dyninput_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &unsafe { - SliceInfo::<_, Ix3, IxDyn>::new([ - AxisSliceInfo::from(1..), - AxisSliceInfo::from(1), - AxisSliceInfo::from(NewAxis), - AxisSliceInfo::from(..).step_by(2), - ]) - .unwrap() - }; + let info = &SliceInfo::<_, Ix3, IxDyn>::try_from([ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap(); arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); @@ -252,15 +249,13 @@ fn test_slice_dyninput_array_dyn() { #[test] fn test_slice_dyninput_vec_fixed() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &unsafe { - SliceInfo::<_, Ix3, Ix3>::new(vec![ - AxisSliceInfo::from(1..), - AxisSliceInfo::from(1), - AxisSliceInfo::from(NewAxis), - AxisSliceInfo::from(..).step_by(2), - ]) - .unwrap() - }; + let info = &SliceInfo::<_, Ix3, Ix3>::try_from(vec![ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap(); arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); @@ -270,15 +265,13 @@ fn test_slice_dyninput_vec_fixed() { #[test] fn test_slice_dyninput_vec_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &unsafe { - SliceInfo::<_, Ix3, IxDyn>::new(vec![ - AxisSliceInfo::from(1..), - AxisSliceInfo::from(1), - AxisSliceInfo::from(NewAxis), - AxisSliceInfo::from(..).step_by(2), - ]) - .unwrap() - }; + let info = &SliceInfo::<_, Ix3, IxDyn>::try_from(vec![ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(NewAxis), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap(); arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); diff --git a/tests/oper.rs b/tests/oper.rs index 3a4a26ac8..b91a9bd85 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -562,6 +562,7 @@ fn scaled_add_2() { fn scaled_add_3() { use approx::assert_relative_eq; use ndarray::{SliceInfo, AxisSliceInfo}; + use std::convert::TryFrom; let beta = -2.3; let sizes = vec![ @@ -595,7 +596,7 @@ fn scaled_add_3() { { let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(&unsafe { SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap() }); + let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::try_from(cslice).unwrap()); let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); answerv += &(beta * &c); From 815e708d5b05de1bf69099617ad5de1e6f6c1b35 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 15 Feb 2021 22:31:45 -0500 Subject: [PATCH 17/28] Make slice_move not call slice_collapse This isn't much more code and simplifies the logic somewhat. --- src/impl_methods.rs | 69 ++++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index fb7097d56..e9f87583f 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -414,44 +414,49 @@ where where I: CanSlice + ?Sized, { - // Slice and collapse in-place without changing the number of dimensions. - self.slice_collapse(info); - + assert_eq!( + info.in_ndim(), + self.ndim(), + "The input dimension of `info` must match the array to be sliced.", + ); let out_ndim = info.out_ndim(); let mut new_dim = I::OutDim::zeros(out_ndim); let mut new_strides = I::OutDim::zeros(out_ndim); - // Write the dim and strides to the correct new axes. - { - let mut old_axis = 0; - let mut new_axis = 0; - info.as_ref().iter().for_each(|ax_info| match ax_info { - AxisSliceInfo::Slice { .. } => { - // Copy the old dim and stride to corresponding axis. - new_dim[new_axis] = self.dim[old_axis]; - new_strides[new_axis] = self.strides[old_axis]; - old_axis += 1; - new_axis += 1; - } - AxisSliceInfo::Index(_) => { - // Skip the old axis since it should be removed. - old_axis += 1; - } - AxisSliceInfo::NewAxis => { - // Set the dim and stride of the new axis. - new_dim[new_axis] = 1; - new_strides[new_axis] = 0; - new_axis += 1; - } - }); - debug_assert_eq!(old_axis, self.ndim()); - debug_assert_eq!(new_axis, out_ndim); - } + let mut old_axis = 0; + let mut new_axis = 0; + info.as_ref().iter().for_each(|&ax_info| match ax_info { + AxisSliceInfo::Slice { start, end, step } => { + // Slice the axis in-place to update the `dim`, `strides`, and `ptr`. + self.slice_axis_inplace(Axis(old_axis), Slice { start, end, step }); + // Copy the sliced dim and stride to corresponding axis. + new_dim[new_axis] = self.dim[old_axis]; + new_strides[new_axis] = self.strides[old_axis]; + old_axis += 1; + new_axis += 1; + } + AxisSliceInfo::Index(index) => { + // Collapse the axis in-place to update the `ptr`. + let i_usize = abs_index(self.len_of(Axis(old_axis)), index); + self.collapse_axis(Axis(old_axis), i_usize); + // Skip copying the axis since it should be removed. Note that + // removing this axis is safe because `.collapse_axis()` panics + // if the index is out-of-bounds, so it will panic if the axis + // is zero length. + old_axis += 1; + } + AxisSliceInfo::NewAxis => { + // Set the dim and stride of the new axis. + new_dim[new_axis] = 1; + new_strides[new_axis] = 0; + new_axis += 1; + } + }); + debug_assert_eq!(old_axis, self.ndim()); + debug_assert_eq!(new_axis, out_ndim); // safe because new dimension, strides allow access to a subset of old data - unsafe { - self.with_strides_dim(new_strides, new_dim) - } + unsafe { self.with_strides_dim(new_strides, new_dim) } } /// Slice the array in place without changing the number of dimensions. From 25a7bb0b699ef302d0682c970535ce26bd72f23c Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 15 Feb 2021 22:35:13 -0500 Subject: [PATCH 18/28] Make slice_collapse return Err(_) for NewAxis --- examples/axis_ops.rs | 4 +-- serialization-tests/tests/serialize.rs | 6 ++-- src/impl_methods.rs | 36 +++++++++++++-------- tests/array.rs | 44 +++++++++++++------------- tests/iterators.rs | 4 +-- 5 files changed, 52 insertions(+), 42 deletions(-) diff --git a/examples/axis_ops.rs b/examples/axis_ops.rs index 624af32c3..1ff4a3105 100644 --- a/examples/axis_ops.rs +++ b/examples/axis_ops.rs @@ -51,7 +51,7 @@ fn main() { } a.swap_axes(0, 1); a.swap_axes(0, 2); - a.slice_collapse(s![.., ..;-1, ..]); + a.slice_collapse(s![.., ..;-1, ..]).unwrap(); regularize(&mut a).ok(); let mut b = Array::::zeros((2, 3, 4)); @@ -68,6 +68,6 @@ fn main() { for (i, elt) in (0..).zip(&mut a) { *elt = i; } - a.slice_collapse(s![..;-1, ..;2, ..]); + a.slice_collapse(s![..;-1, ..;2, ..]).unwrap(); regularize(&mut a).ok(); } diff --git a/serialization-tests/tests/serialize.rs b/serialization-tests/tests/serialize.rs index efb3bacd9..0afaffe1b 100644 --- a/serialization-tests/tests/serialize.rs +++ b/serialization-tests/tests/serialize.rs @@ -46,7 +46,7 @@ fn serial_many_dim_serde() { { // Test a sliced array. let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4)); - a.slice_collapse(s![..;-1, .., .., ..2]); + a.slice_collapse(s![..;-1, .., .., ..2]).unwrap(); let serial = serde_json::to_string(&a).unwrap(); println!("Encode {:?} => {:?}", a, serial); let res = serde_json::from_str::>(&serial); @@ -156,7 +156,7 @@ fn serial_many_dim_serde_msgpack() { { // Test a sliced array. let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4)); - a.slice_collapse(s![..;-1, .., .., ..2]); + a.slice_collapse(s![..;-1, .., .., ..2]).unwrap(); let mut buf = Vec::new(); serde::Serialize::serialize(&a, &mut rmp_serde::Serializer::new(&mut buf)) @@ -209,7 +209,7 @@ fn serial_many_dim_ron() { { // Test a sliced array. let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4)); - a.slice_collapse(s![..;-1, .., .., ..2]); + a.slice_collapse(s![..;-1, .., .., ..2]).unwrap(); let a_s = ron_serialize(&a).unwrap(); diff --git a/src/impl_methods.rs b/src/impl_methods.rs index e9f87583f..354eb0534 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -461,13 +461,15 @@ where /// Slice the array in place without changing the number of dimensions. /// - /// Note that `NewAxis` elements in `info` are ignored. + /// If there are any `NewAxis` elements in `info`, slicing is performed + /// using the other elements in `info` (i.e. ignoring the `NewAxis` + /// elements), and `Err(_)` is returned to notify the caller. /// /// See [*Slicing*](#slicing) for full documentation. /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice_collapse(&mut self, info: &I) + pub fn slice_collapse(&mut self, info: &I) -> Result<(), ShapeError> where I: CanSlice + ?Sized, { @@ -476,20 +478,28 @@ where self.ndim(), "The input dimension of `info` must match the array to be sliced.", ); + let mut new_axis_in_info = false; let mut axis = 0; info.as_ref().iter().for_each(|&ax_info| match ax_info { - AxisSliceInfo::Slice { start, end, step } => { - self.slice_axis_inplace(Axis(axis), Slice { start, end, step }); - axis += 1; - } - AxisSliceInfo::Index(index) => { - let i_usize = abs_index(self.len_of(Axis(axis)), index); - self.collapse_axis(Axis(axis), i_usize); - axis += 1; - } - AxisSliceInfo::NewAxis => {} - }); + AxisSliceInfo::Slice { start, end, step } => { + self.slice_axis_inplace(Axis(axis), Slice { start, end, step }); + axis += 1; + } + AxisSliceInfo::Index(index) => { + let i_usize = abs_index(self.len_of(Axis(axis)), index); + self.collapse_axis(Axis(axis), i_usize); + axis += 1; + } + AxisSliceInfo::NewAxis => { + new_axis_in_info = true; + } + }); debug_assert_eq!(axis, self.ndim()); + if new_axis_in_info { + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } else { + Ok(()) + } } /// Return a view of the array, sliced along the specified axis. diff --git a/tests/array.rs b/tests/array.rs index ac4791d6f..8e5c01f53 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -103,10 +103,10 @@ fn test_slice_ix0() { #[test] fn test_slice_edge_cases() { let mut arr = Array3::::zeros((3, 4, 5)); - arr.slice_collapse(s![0..0;-1, .., ..]); + arr.slice_collapse(s![0..0;-1, .., ..]).unwrap(); assert_eq!(arr.shape(), &[0, 4, 5]); let mut arr = Array2::::from_shape_vec((1, 1).strides((10, 1)), vec![5]).unwrap(); - arr.slice_collapse(s![1..1, ..]); + arr.slice_collapse(s![1..1, ..]).unwrap(); assert_eq!(arr.shape(), &[0, 1]); } @@ -201,7 +201,7 @@ fn test_slice_array_fixed() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + arr.view().slice_collapse(info).unwrap_err(); } #[test] @@ -211,7 +211,7 @@ fn test_slice_dyninput_array_fixed() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + arr.view().slice_collapse(info).unwrap_err(); } #[test] @@ -227,7 +227,7 @@ fn test_slice_array_dyn() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + arr.view().slice_collapse(info).unwrap_err(); } #[test] @@ -243,7 +243,7 @@ fn test_slice_dyninput_array_dyn() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + arr.view().slice_collapse(info).unwrap_err(); } #[test] @@ -259,7 +259,7 @@ fn test_slice_dyninput_vec_fixed() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + arr.view().slice_collapse(info).unwrap_err(); } #[test] @@ -275,7 +275,7 @@ fn test_slice_dyninput_vec_dyn() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + arr.view().slice_collapse(info).unwrap_err(); } #[test] @@ -324,7 +324,7 @@ fn test_slice_collapse_with_indices() { { let mut vi = arr.view(); - vi.slice_collapse(s![NewAxis, 1.., 2, ..;2]); + vi.slice_collapse(s![NewAxis, 1.., 2, ..;2]).unwrap_err(); assert_eq!(vi.shape(), &[2, 1, 2]); assert!(vi .iter() @@ -332,7 +332,7 @@ fn test_slice_collapse_with_indices() { .all(|(a, b)| a == b)); let mut vi = arr.view(); - vi.slice_collapse(s![1, NewAxis, 2, ..;2]); + vi.slice_collapse(s![1, NewAxis, 2, ..;2]).unwrap_err(); assert_eq!(vi.shape(), &[1, 1, 2]); assert!(vi .iter() @@ -340,7 +340,7 @@ fn test_slice_collapse_with_indices() { .all(|(a, b)| a == b)); let mut vi = arr.view(); - vi.slice_collapse(s![1, 2, NewAxis, 3]); + vi.slice_collapse(s![1, 2, 3]).unwrap(); assert_eq!(vi.shape(), &[1, 1, 1]); assert_eq!(vi, Array3::from_elem((1, 1, 1), arr[(1, 2, 3)])); } @@ -348,7 +348,7 @@ fn test_slice_collapse_with_indices() { // Do it to the ArcArray itself let elem = arr[(1, 2, 3)]; let mut vi = arr; - vi.slice_collapse(s![1, 2, 3, NewAxis]); + vi.slice_collapse(s![1, 2, 3, NewAxis]).unwrap_err(); assert_eq!(vi.shape(), &[1, 1, 1]); assert_eq!(vi, Array3::from_elem((1, 1, 1), elem)); } @@ -567,7 +567,7 @@ fn test_cow() { assert_eq!(n[[0, 1]], 0); assert_eq!(n.get((0, 1)), Some(&0)); let mut rev = mat.reshape(4); - rev.slice_collapse(s![..;-1]); + rev.slice_collapse(s![..;-1]).unwrap(); assert_eq!(rev[0], 4); assert_eq!(rev[1], 3); assert_eq!(rev[2], 2); @@ -591,7 +591,7 @@ fn test_cow_shrink() { // mutation shrinks the array and gives it different strides // let mut mat = ArcArray::zeros((2, 3)); - //mat.slice_collapse(s![.., ..;2]); + //mat.slice_collapse(s![.., ..;2]).unwrap(); mat[[0, 0]] = 1; let n = mat.clone(); mat[[0, 1]] = 2; @@ -606,7 +606,7 @@ fn test_cow_shrink() { assert_eq!(n.get((0, 1)), Some(&0)); // small has non-C strides this way let mut small = mat.reshape(6); - small.slice_collapse(s![4..;-1]); + small.slice_collapse(s![4..;-1]).unwrap(); assert_eq!(small[0], 6); assert_eq!(small[1], 5); let before = small.clone(); @@ -886,7 +886,7 @@ fn assign() { let mut a = arr2(&[[1, 2], [3, 4]]); { let mut v = a.view_mut(); - v.slice_collapse(s![..1, ..]); + v.slice_collapse(s![..1, ..]).unwrap(); v.fill(0); } assert_eq!(a, arr2(&[[0, 0], [3, 4]])); @@ -1093,7 +1093,7 @@ fn owned_array_discontiguous_drop() { .collect(); let mut a = Array::from_shape_vec((2, 6), v).unwrap(); // discontiguous and non-zero offset - a.slice_collapse(s![.., 1..]); + a.slice_collapse(s![.., 1..]).unwrap(); } // each item was dropped exactly once itertools::assert_equal(set.borrow().iter().cloned(), 0..12); @@ -1792,7 +1792,7 @@ fn to_owned_memory_order() { #[test] fn to_owned_neg_stride() { let mut c = arr2(&[[1, 2, 3], [4, 5, 6]]); - c.slice_collapse(s![.., ..;-1]); + c.slice_collapse(s![.., ..;-1]).unwrap(); let co = c.to_owned(); assert_eq!(c, co); assert_eq!(c.strides(), co.strides()); @@ -1801,7 +1801,7 @@ fn to_owned_neg_stride() { #[test] fn discontiguous_owned_to_owned() { let mut c = arr2(&[[1, 2, 3], [4, 5, 6]]); - c.slice_collapse(s![.., ..;2]); + c.slice_collapse(s![.., ..;2]).unwrap(); let co = c.to_owned(); assert_eq!(c.strides(), &[3, 2]); @@ -2062,10 +2062,10 @@ fn test_accumulate_axis_inplace_nonstandard_layout() { fn test_to_vec() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); - a.slice_collapse(s![..;-1, ..]); + a.slice_collapse(s![..;-1, ..]).unwrap(); assert_eq!(a.row(3).to_vec(), vec![1, 2, 3]); assert_eq!(a.column(2).to_vec(), vec![12, 9, 6, 3]); - a.slice_collapse(s![.., ..;-1]); + a.slice_collapse(s![.., ..;-1]).unwrap(); assert_eq!(a.row(3).to_vec(), vec![3, 2, 1]); } @@ -2081,7 +2081,7 @@ fn test_array_clone_unalias() { #[test] fn test_array_clone_same_view() { let mut a = Array::from_iter(0..9).into_shape((3, 3)).unwrap(); - a.slice_collapse(s![..;-1, ..;-1]); + a.slice_collapse(s![..;-1, ..;-1]).unwrap(); let b = a.clone(); assert_eq!(a, b); } diff --git a/tests/iterators.rs b/tests/iterators.rs index 4e4bbc666..20693e228 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -332,7 +332,7 @@ fn axis_iter_zip_partially_consumed_discontiguous() { while iter.next().is_some() { consumed += 1; let mut b = Array::zeros((a.len() - consumed) * 2); - b.slice_collapse(s![..;2]); + b.slice_collapse(s![..;2]).unwrap(); Zip::from(&mut b).and(iter.clone()).for_each(|b, a| *b = a[()]); assert_eq!(a.slice(s![consumed..]), b); } @@ -519,7 +519,7 @@ fn axis_iter_mut_zip_partially_consumed_discontiguous() { iter.next(); } let mut b = Array::zeros(remaining * 2); - b.slice_collapse(s![..;2]); + b.slice_collapse(s![..;2]).unwrap(); Zip::from(&mut b).and(iter).for_each(|b, a| *b = a[()]); assert_eq!(a.slice(s![consumed..]), b); } From 5202a5021f0801d2ebd0771cc4c4ab27399d5a5d Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 15 Feb 2021 22:54:17 -0500 Subject: [PATCH 19/28] Expose CanSlice trait in public API This makes it visible in the docs, but the private marker trick prevents other crates from implementing it. --- src/lib.rs | 2 +- src/slice.rs | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index d82a99649..9a4f67394 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -141,7 +141,7 @@ pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; -pub use crate::slice::{AxisSliceInfo, NewAxis, Slice, SliceArg, SliceInfo}; +pub use crate::slice::{AxisSliceInfo, CanSlice, NewAxis, Slice, SliceArg, SliceInfo}; use crate::iterators::Baseiter; use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes}; diff --git a/src/slice.rs b/src/slice.rs index f2da78038..597507246 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -312,11 +312,16 @@ impl From for AxisSliceInfo { /// consistent with the `&[AxisSliceInfo]` returned by `self.as_ref()` and that /// `self.as_ref()` always returns the same value when called multiple times. pub unsafe trait CanSlice: AsRef<[AxisSliceInfo]> { + /// Dimensionality of the output array. type OutDim: Dimension; + /// Returns the number of axes in the input array. fn in_ndim(&self) -> usize; + /// Returns the number of axes in the output array. fn out_ndim(&self) -> usize; + + private_decl! {} } macro_rules! impl_canslice_samedim { @@ -335,6 +340,8 @@ macro_rules! impl_canslice_samedim { fn out_ndim(&self) -> usize { self.out_ndim() } + + private_impl! {} } }; } @@ -361,6 +368,8 @@ where fn out_ndim(&self) -> usize { self.out_ndim() } + + private_impl! {} } unsafe impl CanSlice for [AxisSliceInfo] { @@ -373,6 +382,8 @@ unsafe impl CanSlice for [AxisSliceInfo] { fn out_ndim(&self) -> usize { self.iter().filter(|s| !s.is_index()).count() } + + private_impl! {} } /// Represents all of the necessary information to perform a slice. From 319701de7281638b8fb7d7c93d82c54f4257131c Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 15 Feb 2021 23:02:50 -0500 Subject: [PATCH 20/28] Expose MultiSlice trait in public API This makes it visible in the docs, but the private marker trick prevents other crates from implementing it. --- src/lib.rs | 2 +- src/slice.rs | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 9a4f67394..a0f28e352 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -141,7 +141,7 @@ pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; -pub use crate::slice::{AxisSliceInfo, CanSlice, NewAxis, Slice, SliceArg, SliceInfo}; +pub use crate::slice::{AxisSliceInfo, CanSlice, MultiSlice, NewAxis, Slice, SliceArg, SliceInfo}; use crate::iterators::Baseiter; use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes}; diff --git a/src/slice.rs b/src/slice.rs index 597507246..8021d9a83 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -915,6 +915,8 @@ where /// **Panics** if performing any individual slice panics or if the slices /// are not disjoint (i.e. if they intersect). fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output; + + private_decl! {} } impl<'a, A, D> MultiSlice<'a, A, D> for () @@ -925,6 +927,8 @@ where type Output = (); fn multi_slice_move(&self, _view: ArrayViewMut<'a, A, D>) -> Self::Output {} + + private_impl! {} } impl<'a, A, D, I0> MultiSlice<'a, A, D> for (&I0,) @@ -938,6 +942,8 @@ where fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { (view.slice_move(self.0),) } + + private_impl! {} } macro_rules! impl_multislice_tuple { @@ -968,6 +974,8 @@ macro_rules! impl_multislice_tuple { ) } } + + private_impl! {} } }; (@intersects_self $shape:expr, ($head:expr,)) => { @@ -996,4 +1004,6 @@ where fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { T::multi_slice_move(self, view) } + + private_impl! {} } From d5d6482d3b3a524257bca2e06b2cc30036ead440 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Tue, 16 Feb 2021 01:12:24 -0500 Subject: [PATCH 21/28] Add DimAdd bounds to Dimension trait This reduces how often an explicit `DimAdd` bound is necessary. --- src/dimension/dimension_trait.rs | 7 +++++++ src/dimension/ops.rs | 2 +- src/slice.rs | 4 ++-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index df38904f4..a49cc815e 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -13,6 +13,7 @@ use alloc::vec::Vec; use super::axes_of; use super::conversion::Convert; +use super::ops::DimAdd; use super::{stride_offset, stride_offset_checked}; use crate::itertools::{enumerate, zip}; use crate::{Axis, DimMax}; @@ -51,6 +52,12 @@ pub trait Dimension: + DimMax + DimMax<::Smaller, Output=Self> + DimMax<::Larger, Output=::Larger> + + DimAdd + + DimAdd<::Smaller> + + DimAdd<::Larger> + + DimAdd + + DimAdd::Larger> + + DimAdd { /// For fixed-size dimension representations (e.g. `Ix2`), this should be /// `Some(ndim)`, and for variable-size dimension representations (e.g. diff --git a/src/dimension/ops.rs b/src/dimension/ops.rs index 855b3b4e1..10855f6c7 100644 --- a/src/dimension/ops.rs +++ b/src/dimension/ops.rs @@ -1,7 +1,7 @@ use crate::imp_prelude::*; /// Adds the two dimensions at compile time. -pub trait DimAdd: Dimension { +pub trait DimAdd { /// The sum of the two dimensions. type Out: Dimension; } diff --git a/src/slice.rs b/src/slice.rs index 8021d9a83..e11543b72 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -661,14 +661,14 @@ pub trait SliceArg { /// Number of dimensions that this slicing argument produces in the output array. type OutDim: Dimension; - fn next_in_dim(&self, _: PhantomData) -> PhantomData + fn next_in_dim(&self, _: PhantomData) -> PhantomData<>::Out> where D: Dimension + DimAdd, { PhantomData } - fn next_out_dim(&self, _: PhantomData) -> PhantomData + fn next_out_dim(&self, _: PhantomData) -> PhantomData<>::Out> where D: Dimension + DimAdd, { From 9614b13701b5389e10ae5d2a483cf06265450b15 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Tue, 16 Feb 2021 22:34:35 -0500 Subject: [PATCH 22/28] Revert "Make slice_collapse return Err(_) for NewAxis" This reverts commit 381546267a102d0dfa258b8e037a811a6770445e. --- examples/axis_ops.rs | 4 +-- serialization-tests/tests/serialize.rs | 6 ++-- src/impl_methods.rs | 36 ++++++++------------- tests/array.rs | 44 +++++++++++++------------- tests/iterators.rs | 4 +-- 5 files changed, 42 insertions(+), 52 deletions(-) diff --git a/examples/axis_ops.rs b/examples/axis_ops.rs index 1ff4a3105..624af32c3 100644 --- a/examples/axis_ops.rs +++ b/examples/axis_ops.rs @@ -51,7 +51,7 @@ fn main() { } a.swap_axes(0, 1); a.swap_axes(0, 2); - a.slice_collapse(s![.., ..;-1, ..]).unwrap(); + a.slice_collapse(s![.., ..;-1, ..]); regularize(&mut a).ok(); let mut b = Array::::zeros((2, 3, 4)); @@ -68,6 +68,6 @@ fn main() { for (i, elt) in (0..).zip(&mut a) { *elt = i; } - a.slice_collapse(s![..;-1, ..;2, ..]).unwrap(); + a.slice_collapse(s![..;-1, ..;2, ..]); regularize(&mut a).ok(); } diff --git a/serialization-tests/tests/serialize.rs b/serialization-tests/tests/serialize.rs index 0afaffe1b..efb3bacd9 100644 --- a/serialization-tests/tests/serialize.rs +++ b/serialization-tests/tests/serialize.rs @@ -46,7 +46,7 @@ fn serial_many_dim_serde() { { // Test a sliced array. let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4)); - a.slice_collapse(s![..;-1, .., .., ..2]).unwrap(); + a.slice_collapse(s![..;-1, .., .., ..2]); let serial = serde_json::to_string(&a).unwrap(); println!("Encode {:?} => {:?}", a, serial); let res = serde_json::from_str::>(&serial); @@ -156,7 +156,7 @@ fn serial_many_dim_serde_msgpack() { { // Test a sliced array. let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4)); - a.slice_collapse(s![..;-1, .., .., ..2]).unwrap(); + a.slice_collapse(s![..;-1, .., .., ..2]); let mut buf = Vec::new(); serde::Serialize::serialize(&a, &mut rmp_serde::Serializer::new(&mut buf)) @@ -209,7 +209,7 @@ fn serial_many_dim_ron() { { // Test a sliced array. let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4)); - a.slice_collapse(s![..;-1, .., .., ..2]).unwrap(); + a.slice_collapse(s![..;-1, .., .., ..2]); let a_s = ron_serialize(&a).unwrap(); diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 354eb0534..e9f87583f 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -461,15 +461,13 @@ where /// Slice the array in place without changing the number of dimensions. /// - /// If there are any `NewAxis` elements in `info`, slicing is performed - /// using the other elements in `info` (i.e. ignoring the `NewAxis` - /// elements), and `Err(_)` is returned to notify the caller. + /// Note that `NewAxis` elements in `info` are ignored. /// /// See [*Slicing*](#slicing) for full documentation. /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice_collapse(&mut self, info: &I) -> Result<(), ShapeError> + pub fn slice_collapse(&mut self, info: &I) where I: CanSlice + ?Sized, { @@ -478,28 +476,20 @@ where self.ndim(), "The input dimension of `info` must match the array to be sliced.", ); - let mut new_axis_in_info = false; let mut axis = 0; info.as_ref().iter().for_each(|&ax_info| match ax_info { - AxisSliceInfo::Slice { start, end, step } => { - self.slice_axis_inplace(Axis(axis), Slice { start, end, step }); - axis += 1; - } - AxisSliceInfo::Index(index) => { - let i_usize = abs_index(self.len_of(Axis(axis)), index); - self.collapse_axis(Axis(axis), i_usize); - axis += 1; - } - AxisSliceInfo::NewAxis => { - new_axis_in_info = true; - } - }); + AxisSliceInfo::Slice { start, end, step } => { + self.slice_axis_inplace(Axis(axis), Slice { start, end, step }); + axis += 1; + } + AxisSliceInfo::Index(index) => { + let i_usize = abs_index(self.len_of(Axis(axis)), index); + self.collapse_axis(Axis(axis), i_usize); + axis += 1; + } + AxisSliceInfo::NewAxis => {} + }); debug_assert_eq!(axis, self.ndim()); - if new_axis_in_info { - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) - } else { - Ok(()) - } } /// Return a view of the array, sliced along the specified axis. diff --git a/tests/array.rs b/tests/array.rs index 8e5c01f53..ac4791d6f 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -103,10 +103,10 @@ fn test_slice_ix0() { #[test] fn test_slice_edge_cases() { let mut arr = Array3::::zeros((3, 4, 5)); - arr.slice_collapse(s![0..0;-1, .., ..]).unwrap(); + arr.slice_collapse(s![0..0;-1, .., ..]); assert_eq!(arr.shape(), &[0, 4, 5]); let mut arr = Array2::::from_shape_vec((1, 1).strides((10, 1)), vec![5]).unwrap(); - arr.slice_collapse(s![1..1, ..]).unwrap(); + arr.slice_collapse(s![1..1, ..]); assert_eq!(arr.shape(), &[0, 1]); } @@ -201,7 +201,7 @@ fn test_slice_array_fixed() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info).unwrap_err(); + arr.view().slice_collapse(info); } #[test] @@ -211,7 +211,7 @@ fn test_slice_dyninput_array_fixed() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info).unwrap_err(); + arr.view().slice_collapse(info); } #[test] @@ -227,7 +227,7 @@ fn test_slice_array_dyn() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info).unwrap_err(); + arr.view().slice_collapse(info); } #[test] @@ -243,7 +243,7 @@ fn test_slice_dyninput_array_dyn() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info).unwrap_err(); + arr.view().slice_collapse(info); } #[test] @@ -259,7 +259,7 @@ fn test_slice_dyninput_vec_fixed() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info).unwrap_err(); + arr.view().slice_collapse(info); } #[test] @@ -275,7 +275,7 @@ fn test_slice_dyninput_vec_dyn() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info).unwrap_err(); + arr.view().slice_collapse(info); } #[test] @@ -324,7 +324,7 @@ fn test_slice_collapse_with_indices() { { let mut vi = arr.view(); - vi.slice_collapse(s![NewAxis, 1.., 2, ..;2]).unwrap_err(); + vi.slice_collapse(s![NewAxis, 1.., 2, ..;2]); assert_eq!(vi.shape(), &[2, 1, 2]); assert!(vi .iter() @@ -332,7 +332,7 @@ fn test_slice_collapse_with_indices() { .all(|(a, b)| a == b)); let mut vi = arr.view(); - vi.slice_collapse(s![1, NewAxis, 2, ..;2]).unwrap_err(); + vi.slice_collapse(s![1, NewAxis, 2, ..;2]); assert_eq!(vi.shape(), &[1, 1, 2]); assert!(vi .iter() @@ -340,7 +340,7 @@ fn test_slice_collapse_with_indices() { .all(|(a, b)| a == b)); let mut vi = arr.view(); - vi.slice_collapse(s![1, 2, 3]).unwrap(); + vi.slice_collapse(s![1, 2, NewAxis, 3]); assert_eq!(vi.shape(), &[1, 1, 1]); assert_eq!(vi, Array3::from_elem((1, 1, 1), arr[(1, 2, 3)])); } @@ -348,7 +348,7 @@ fn test_slice_collapse_with_indices() { // Do it to the ArcArray itself let elem = arr[(1, 2, 3)]; let mut vi = arr; - vi.slice_collapse(s![1, 2, 3, NewAxis]).unwrap_err(); + vi.slice_collapse(s![1, 2, 3, NewAxis]); assert_eq!(vi.shape(), &[1, 1, 1]); assert_eq!(vi, Array3::from_elem((1, 1, 1), elem)); } @@ -567,7 +567,7 @@ fn test_cow() { assert_eq!(n[[0, 1]], 0); assert_eq!(n.get((0, 1)), Some(&0)); let mut rev = mat.reshape(4); - rev.slice_collapse(s![..;-1]).unwrap(); + rev.slice_collapse(s![..;-1]); assert_eq!(rev[0], 4); assert_eq!(rev[1], 3); assert_eq!(rev[2], 2); @@ -591,7 +591,7 @@ fn test_cow_shrink() { // mutation shrinks the array and gives it different strides // let mut mat = ArcArray::zeros((2, 3)); - //mat.slice_collapse(s![.., ..;2]).unwrap(); + //mat.slice_collapse(s![.., ..;2]); mat[[0, 0]] = 1; let n = mat.clone(); mat[[0, 1]] = 2; @@ -606,7 +606,7 @@ fn test_cow_shrink() { assert_eq!(n.get((0, 1)), Some(&0)); // small has non-C strides this way let mut small = mat.reshape(6); - small.slice_collapse(s![4..;-1]).unwrap(); + small.slice_collapse(s![4..;-1]); assert_eq!(small[0], 6); assert_eq!(small[1], 5); let before = small.clone(); @@ -886,7 +886,7 @@ fn assign() { let mut a = arr2(&[[1, 2], [3, 4]]); { let mut v = a.view_mut(); - v.slice_collapse(s![..1, ..]).unwrap(); + v.slice_collapse(s![..1, ..]); v.fill(0); } assert_eq!(a, arr2(&[[0, 0], [3, 4]])); @@ -1093,7 +1093,7 @@ fn owned_array_discontiguous_drop() { .collect(); let mut a = Array::from_shape_vec((2, 6), v).unwrap(); // discontiguous and non-zero offset - a.slice_collapse(s![.., 1..]).unwrap(); + a.slice_collapse(s![.., 1..]); } // each item was dropped exactly once itertools::assert_equal(set.borrow().iter().cloned(), 0..12); @@ -1792,7 +1792,7 @@ fn to_owned_memory_order() { #[test] fn to_owned_neg_stride() { let mut c = arr2(&[[1, 2, 3], [4, 5, 6]]); - c.slice_collapse(s![.., ..;-1]).unwrap(); + c.slice_collapse(s![.., ..;-1]); let co = c.to_owned(); assert_eq!(c, co); assert_eq!(c.strides(), co.strides()); @@ -1801,7 +1801,7 @@ fn to_owned_neg_stride() { #[test] fn discontiguous_owned_to_owned() { let mut c = arr2(&[[1, 2, 3], [4, 5, 6]]); - c.slice_collapse(s![.., ..;2]).unwrap(); + c.slice_collapse(s![.., ..;2]); let co = c.to_owned(); assert_eq!(c.strides(), &[3, 2]); @@ -2062,10 +2062,10 @@ fn test_accumulate_axis_inplace_nonstandard_layout() { fn test_to_vec() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); - a.slice_collapse(s![..;-1, ..]).unwrap(); + a.slice_collapse(s![..;-1, ..]); assert_eq!(a.row(3).to_vec(), vec![1, 2, 3]); assert_eq!(a.column(2).to_vec(), vec![12, 9, 6, 3]); - a.slice_collapse(s![.., ..;-1]).unwrap(); + a.slice_collapse(s![.., ..;-1]); assert_eq!(a.row(3).to_vec(), vec![3, 2, 1]); } @@ -2081,7 +2081,7 @@ fn test_array_clone_unalias() { #[test] fn test_array_clone_same_view() { let mut a = Array::from_iter(0..9).into_shape((3, 3)).unwrap(); - a.slice_collapse(s![..;-1, ..;-1]).unwrap(); + a.slice_collapse(s![..;-1, ..;-1]); let b = a.clone(); assert_eq!(a, b); } diff --git a/tests/iterators.rs b/tests/iterators.rs index 20693e228..4e4bbc666 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -332,7 +332,7 @@ fn axis_iter_zip_partially_consumed_discontiguous() { while iter.next().is_some() { consumed += 1; let mut b = Array::zeros((a.len() - consumed) * 2); - b.slice_collapse(s![..;2]).unwrap(); + b.slice_collapse(s![..;2]); Zip::from(&mut b).and(iter.clone()).for_each(|b, a| *b = a[()]); assert_eq!(a.slice(s![consumed..]), b); } @@ -519,7 +519,7 @@ fn axis_iter_mut_zip_partially_consumed_discontiguous() { iter.next(); } let mut b = Array::zeros(remaining * 2); - b.slice_collapse(s![..;2]).unwrap(); + b.slice_collapse(s![..;2]); Zip::from(&mut b).and(iter).for_each(|b, a| *b = a[()]); assert_eq!(a.slice(s![consumed..]), b); } From 61cf7c049f6195f80dc51b0064d42ecbe7aa7d03 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Tue, 16 Feb 2021 22:59:28 -0500 Subject: [PATCH 23/28] Make slice_collapse panic on NewAxis --- src/impl_methods.rs | 13 ++++++----- src/lib.rs | 2 +- src/slice.rs | 16 ++++++++------ tests/array.rs | 53 ++++++++++++++++++++++++++++++++++++--------- 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index e9f87583f..0c40f4c0e 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -461,12 +461,15 @@ where /// Slice the array in place without changing the number of dimensions. /// - /// Note that `NewAxis` elements in `info` are ignored. - /// /// See [*Slicing*](#slicing) for full documentation. /// - /// **Panics** if an index is out of bounds or step size is zero.
- /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) + /// **Panics** in the following cases: + /// + /// - if an index is out of bounds + /// - if a step size is zero + /// - if [`AxisSliceInfo::NewAxis`] is in `info`, e.g. if [`NewAxis`] was + /// used in the [`s!`] macro + /// - if `D` is `IxDyn` and `info` does not match the number of array axes pub fn slice_collapse(&mut self, info: &I) where I: CanSlice + ?Sized, @@ -487,7 +490,7 @@ where self.collapse_axis(Axis(axis), i_usize); axis += 1; } - AxisSliceInfo::NewAxis => {} + AxisSliceInfo::NewAxis => panic!("`slice_collapse` does not support `NewAxis`."), }); debug_assert_eq!(axis, self.ndim()); } diff --git a/src/lib.rs b/src/lib.rs index a0f28e352..9bc5aa08d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -498,7 +498,7 @@ pub type Ixs = isize; /// is selected and the axis is removed; this selects a subview. See /// [*Subviews*](#subviews) for more information about subviews. If a /// [`NewAxis`] instance is used, a new axis is inserted. Note that -/// [`.slice_collapse()`] ignores `NewAxis` elements and behaves like +/// [`.slice_collapse()`] panics on `NewAxis` elements and behaves like /// [`.collapse_axis()`] by preserving the number of dimensions. /// /// [`.slice()`]: #method.slice diff --git a/src/slice.rs b/src/slice.rs index e11543b72..8582ee8bb 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -723,6 +723,7 @@ impl_slicearg!((), NewAxis, Ix0, Ix1); /// * *slice* `;` *step*: a range constructed from the start and end of a [`Slice`] /// instance, with new step size *step*, to use for slicing that axis. /// * *new-axis*: a [`NewAxis`] instance that represents the creation of a new axis. +/// (Except for [`.slice_collapse()`], which panics on [`NewAxis`] elements.) /// /// [`Slice`]: struct.Slice.html /// [`NewAxis`]: struct.NewAxis.html @@ -734,13 +735,14 @@ impl_slicearg!((), NewAxis, Ix0, Ix1); /// `RangeFull` where `I` is `isize`, `usize`, or `i32`. *step* must be a type /// that can be converted to `isize` with the `as` keyword. /// -/// For example `s![0..4;2, 6, 1..5, NewAxis]` is a slice of the first axis for -/// 0..4 with step size 2, a subview of the second axis at index 6, a slice of -/// the third axis for 1..5 with default step size 1, and a new axis of length -/// 1 at the end of the shape. The input array must have 3 dimensions. The -/// resulting slice would have shape `[2, 4, 1]` for [`.slice()`], -/// [`.slice_mut()`], and [`.slice_move()`], and shape `[2, 1, 4]` for -/// [`.slice_collapse()`]. +/// For example, `s![0..4;2, 6, 1..5, NewAxis]` is a slice of the first axis +/// for 0..4 with step size 2, a subview of the second axis at index 6, a slice +/// of the third axis for 1..5 with default step size 1, and a new axis of +/// length 1 at the end of the shape. The input array must have 3 dimensions. +/// The resulting slice would have shape `[2, 4, 1]` for [`.slice()`], +/// [`.slice_mut()`], and [`.slice_move()`], while [`.slice_collapse()`] would +/// panic. Without the `NewAxis`, i.e. `s![0..4;2, 6, 1..5]`, +/// [`.slice_collapse()`] would result in an array of shape `[2, 1, 4]`. /// /// [`.slice()`]: struct.ArrayBase.html#method.slice /// [`.slice_mut()`]: struct.ArrayBase.html#method.slice_mut diff --git a/tests/array.rs b/tests/array.rs index ac4791d6f..fa5da4419 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -201,7 +201,8 @@ fn test_slice_array_fixed() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + let info2 = s![1.., 1, ..;2]; + arr.view().slice_collapse(info2); } #[test] @@ -211,7 +212,8 @@ fn test_slice_dyninput_array_fixed() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + let info2 = s![1.., 1, ..;2]; + arr.view().slice_collapse(info2); } #[test] @@ -227,7 +229,13 @@ fn test_slice_array_dyn() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + let info2 = &SliceInfo::<_, Ix3, IxDyn>::try_from([ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap(); + arr.view().slice_collapse(info2); } #[test] @@ -243,7 +251,13 @@ fn test_slice_dyninput_array_dyn() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + let info2 = &SliceInfo::<_, Ix3, IxDyn>::try_from([ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap(); + arr.view().slice_collapse(info2); } #[test] @@ -259,7 +273,13 @@ fn test_slice_dyninput_vec_fixed() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + let info2 = &SliceInfo::<_, Ix3, Ix2>::try_from(vec![ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap(); + arr.view().slice_collapse(info2); } #[test] @@ -275,7 +295,13 @@ fn test_slice_dyninput_vec_dyn() { arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + let info2 = &SliceInfo::<_, Ix3, IxDyn>::try_from(vec![ + AxisSliceInfo::from(1..), + AxisSliceInfo::from(1), + AxisSliceInfo::from(..).step_by(2), + ]) + .unwrap(); + arr.view().slice_collapse(info2); } #[test] @@ -324,7 +350,7 @@ fn test_slice_collapse_with_indices() { { let mut vi = arr.view(); - vi.slice_collapse(s![NewAxis, 1.., 2, ..;2]); + vi.slice_collapse(s![1.., 2, ..;2]); assert_eq!(vi.shape(), &[2, 1, 2]); assert!(vi .iter() @@ -332,7 +358,7 @@ fn test_slice_collapse_with_indices() { .all(|(a, b)| a == b)); let mut vi = arr.view(); - vi.slice_collapse(s![1, NewAxis, 2, ..;2]); + vi.slice_collapse(s![1, 2, ..;2]); assert_eq!(vi.shape(), &[1, 1, 2]); assert!(vi .iter() @@ -340,7 +366,7 @@ fn test_slice_collapse_with_indices() { .all(|(a, b)| a == b)); let mut vi = arr.view(); - vi.slice_collapse(s![1, 2, NewAxis, 3]); + vi.slice_collapse(s![1, 2, 3]); assert_eq!(vi.shape(), &[1, 1, 1]); assert_eq!(vi, Array3::from_elem((1, 1, 1), arr[(1, 2, 3)])); } @@ -348,11 +374,18 @@ fn test_slice_collapse_with_indices() { // Do it to the ArcArray itself let elem = arr[(1, 2, 3)]; let mut vi = arr; - vi.slice_collapse(s![1, 2, 3, NewAxis]); + vi.slice_collapse(s![1, 2, 3]); assert_eq!(vi.shape(), &[1, 1, 1]); assert_eq!(vi, Array3::from_elem((1, 1, 1), elem)); } +#[test] +#[should_panic] +fn test_slice_collapse_with_newaxis() { + let mut arr = Array2::::zeros((2, 3)); + arr.slice_collapse(s![0, 0, NewAxis]); +} + #[test] fn test_multislice() { macro_rules! do_test { From 91dbf3f987aa044f775ee3e5480aa642ff857d83 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Tue, 16 Feb 2021 23:01:27 -0500 Subject: [PATCH 24/28] Rename DimAdd::Out to DimAdd::Output --- src/dimension/dimension_trait.rs | 6 +++--- src/dimension/ops.rs | 12 ++++++------ src/slice.rs | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index a49cc815e..1eace34a7 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -55,9 +55,9 @@ pub trait Dimension: + DimAdd + DimAdd<::Smaller> + DimAdd<::Larger> - + DimAdd - + DimAdd::Larger> - + DimAdd + + DimAdd + + DimAdd::Larger> + + DimAdd { /// For fixed-size dimension representations (e.g. `Ix2`), this should be /// `Some(ndim)`, and for variable-size dimension representations (e.g. diff --git a/src/dimension/ops.rs b/src/dimension/ops.rs index 10855f6c7..dd23216f6 100644 --- a/src/dimension/ops.rs +++ b/src/dimension/ops.rs @@ -3,13 +3,13 @@ use crate::imp_prelude::*; /// Adds the two dimensions at compile time. pub trait DimAdd { /// The sum of the two dimensions. - type Out: Dimension; + type Output: Dimension; } macro_rules! impl_dimadd_const_out_const { ($lhs:expr, $rhs:expr) => { impl DimAdd> for Dim<[usize; $lhs]> { - type Out = Dim<[usize; $lhs + $rhs]>; + type Output = Dim<[usize; $lhs + $rhs]>; } }; } @@ -17,18 +17,18 @@ macro_rules! impl_dimadd_const_out_const { macro_rules! impl_dimadd_const_out_dyn { ($lhs:expr, IxDyn) => { impl DimAdd for Dim<[usize; $lhs]> { - type Out = IxDyn; + type Output = IxDyn; } }; ($lhs:expr, $rhs:expr) => { impl DimAdd> for Dim<[usize; $lhs]> { - type Out = IxDyn; + type Output = IxDyn; } }; } impl DimAdd for Ix0 { - type Out = D; + type Output = D; } impl_dimadd_const_out_const!(1, 0); @@ -86,5 +86,5 @@ impl_dimadd_const_out_dyn!(6, 6); impl_dimadd_const_out_dyn!(6, IxDyn); impl DimAdd for IxDyn { - type Out = IxDyn; + type Output = IxDyn; } diff --git a/src/slice.rs b/src/slice.rs index 8582ee8bb..a59fa13ff 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -661,14 +661,14 @@ pub trait SliceArg { /// Number of dimensions that this slicing argument produces in the output array. type OutDim: Dimension; - fn next_in_dim(&self, _: PhantomData) -> PhantomData<>::Out> + fn next_in_dim(&self, _: PhantomData) -> PhantomData<>::Output> where D: Dimension + DimAdd, { PhantomData } - fn next_out_dim(&self, _: PhantomData) -> PhantomData<>::Out> + fn next_out_dim(&self, _: PhantomData) -> PhantomData<>::Output> where D: Dimension + DimAdd, { From 5dc77bd9f18643677388966a7d45555a9b839cbf Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Tue, 16 Feb 2021 23:32:37 -0500 Subject: [PATCH 25/28] Rename SliceArg to SliceNextDim --- src/lib.rs | 4 +++- src/slice.rs | 44 ++++++++++++++++++++++---------------------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9bc5aa08d..0da511b01 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -141,7 +141,9 @@ pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; -pub use crate::slice::{AxisSliceInfo, CanSlice, MultiSlice, NewAxis, Slice, SliceArg, SliceInfo}; +pub use crate::slice::{ + AxisSliceInfo, CanSlice, MultiSlice, NewAxis, Slice, SliceInfo, SliceNextDim, +}; use crate::iterators::Baseiter; use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes}; diff --git a/src/slice.rs b/src/slice.rs index a59fa13ff..eea10d3e4 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -655,7 +655,7 @@ where /// Trait for determining dimensionality of input and output for [`s!`] macro. #[doc(hidden)] -pub trait SliceArg { +pub trait SliceNextDim { /// Number of dimensions that this slicing argument consumes in the input array. type InDim: Dimension; /// Number of dimensions that this slicing argument produces in the output array. @@ -676,28 +676,28 @@ pub trait SliceArg { } } -macro_rules! impl_slicearg { +macro_rules! impl_slicenextdim { (($($generics:tt)*), $self:ty, $in:ty, $out:ty) => { - impl<$($generics)*> SliceArg for $self { + impl<$($generics)*> SliceNextDim for $self { type InDim = $in; type OutDim = $out; } }; } -impl_slicearg!((), isize, Ix1, Ix0); -impl_slicearg!((), usize, Ix1, Ix0); -impl_slicearg!((), i32, Ix1, Ix0); +impl_slicenextdim!((), isize, Ix1, Ix0); +impl_slicenextdim!((), usize, Ix1, Ix0); +impl_slicenextdim!((), i32, Ix1, Ix0); -impl_slicearg!((T), Range, Ix1, Ix1); -impl_slicearg!((T), RangeInclusive, Ix1, Ix1); -impl_slicearg!((T), RangeFrom, Ix1, Ix1); -impl_slicearg!((T), RangeTo, Ix1, Ix1); -impl_slicearg!((T), RangeToInclusive, Ix1, Ix1); -impl_slicearg!((), RangeFull, Ix1, Ix1); -impl_slicearg!((), Slice, Ix1, Ix1); +impl_slicenextdim!((T), Range, Ix1, Ix1); +impl_slicenextdim!((T), RangeInclusive, Ix1, Ix1); +impl_slicenextdim!((T), RangeFrom, Ix1, Ix1); +impl_slicenextdim!((T), RangeTo, Ix1, Ix1); +impl_slicenextdim!((T), RangeToInclusive, Ix1, Ix1); +impl_slicenextdim!((), RangeFull, Ix1, Ix1); +impl_slicenextdim!((), Slice, Ix1, Ix1); -impl_slicearg!((), NewAxis, Ix0, Ix1); +impl_slicenextdim!((), NewAxis, Ix0, Ix1); /// Slice argument constructor. /// @@ -801,8 +801,8 @@ macro_rules! s( (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr;$s:expr) => { match $r { r => { - let in_dim = $crate::SliceArg::next_in_dim(&r, $in_dim); - let out_dim = $crate::SliceArg::next_out_dim(&r, $out_dim); + let in_dim = $crate::SliceNextDim::next_in_dim(&r, $in_dim); + let out_dim = $crate::SliceNextDim::next_out_dim(&r, $out_dim); #[allow(unsafe_code)] unsafe { $crate::SliceInfo::new_unchecked( @@ -818,8 +818,8 @@ macro_rules! s( (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr) => { match $r { r => { - let in_dim = $crate::SliceArg::next_in_dim(&r, $in_dim); - let out_dim = $crate::SliceArg::next_out_dim(&r, $out_dim); + let in_dim = $crate::SliceNextDim::next_in_dim(&r, $in_dim); + let out_dim = $crate::SliceNextDim::next_out_dim(&r, $out_dim); #[allow(unsafe_code)] unsafe { $crate::SliceInfo::new_unchecked( @@ -844,8 +844,8 @@ macro_rules! s( match $r { r => { $crate::s![@parse - $crate::SliceArg::next_in_dim(&r, $in_dim), - $crate::SliceArg::next_out_dim(&r, $out_dim), + $crate::SliceNextDim::next_in_dim(&r, $in_dim), + $crate::SliceNextDim::next_out_dim(&r, $out_dim), [$($stack)* $crate::s!(@convert r, $s),] $($t)* ] @@ -857,8 +857,8 @@ macro_rules! s( match $r { r => { $crate::s![@parse - $crate::SliceArg::next_in_dim(&r, $in_dim), - $crate::SliceArg::next_out_dim(&r, $out_dim), + $crate::SliceNextDim::next_in_dim(&r, $in_dim), + $crate::SliceNextDim::next_out_dim(&r, $out_dim), [$($stack)* $crate::s!(@convert r),] $($t)* ] From 87515c6672000fe1020b0388bda13f5494eb5f1d Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Wed, 17 Feb 2021 00:22:27 -0500 Subject: [PATCH 26/28] Rename CanSlice to SliceArg --- src/dimension/mod.rs | 6 +++--- src/impl_methods.rs | 31 ++++++++++--------------------- src/impl_views/splitting.rs | 6 ++---- src/lib.rs | 2 +- src/slice.rs | 28 ++++++++++++++-------------- 5 files changed, 30 insertions(+), 43 deletions(-) diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index a83c49015..243d5eeea 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -7,7 +7,7 @@ // except according to those terms. use crate::error::{from_kind, ErrorKind, ShapeError}; -use crate::slice::CanSlice; +use crate::slice::SliceArg; use crate::{AxisSliceInfo, Ix, Ixs, Slice}; use num_integer::div_floor; @@ -599,8 +599,8 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> { /// Returns `true` iff the slices intersect. pub fn slices_intersect( dim: &D, - indices1: &impl CanSlice, - indices2: &impl CanSlice, + indices1: &impl SliceArg, + indices2: &impl SliceArg, ) -> bool { debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim()); for (&axis_len, &si1, &si2) in izip!( diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 0c40f4c0e..8eaae18a7 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -32,7 +32,7 @@ use crate::iter::{ AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut, IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, }; -use crate::slice::{CanSlice, MultiSlice}; +use crate::slice::{MultiSlice, SliceArg}; use crate::stacking::concatenate; use crate::{AxisSliceInfo, NdIndex, Slice}; @@ -334,16 +334,13 @@ where /// Return a sliced view of the array. /// /// See [*Slicing*](#slicing) for full documentation. - /// See also [`SliceInfo`] and [`D::SliceArg`]. - /// - /// [`SliceInfo`]: struct.SliceInfo.html - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) pub fn slice(&self, info: &I) -> ArrayView<'_, A, I::OutDim> where - I: CanSlice + ?Sized, + I: SliceArg + ?Sized, S: Data, { self.view().slice_move(info) @@ -352,16 +349,13 @@ where /// Return a sliced read-write view of the array. /// /// See [*Slicing*](#slicing) for full documentation. - /// See also [`SliceInfo`] and [`D::SliceArg`]. - /// - /// [`SliceInfo`]: struct.SliceInfo.html - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) pub fn slice_mut(&mut self, info: &I) -> ArrayViewMut<'_, A, I::OutDim> where - I: CanSlice + ?Sized, + I: SliceArg + ?Sized, S: DataMut, { self.view_mut().slice_move(info) @@ -370,10 +364,7 @@ where /// Return multiple disjoint, sliced, mutable views of the array. /// /// See [*Slicing*](#slicing) for full documentation. - /// See also [`SliceInfo`] and [`D::SliceArg`]. - /// - /// [`SliceInfo`]: struct.SliceInfo.html - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). /// /// **Panics** if any of the following occur: /// @@ -403,16 +394,13 @@ where /// Slice the array, possibly changing the number of dimensions. /// /// See [*Slicing*](#slicing) for full documentation. - /// See also [`SliceInfo`] and [`D::SliceArg`]. - /// - /// [`SliceInfo`]: struct.SliceInfo.html - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) pub fn slice_move(mut self, info: &I) -> ArrayBase where - I: CanSlice + ?Sized, + I: SliceArg + ?Sized, { assert_eq!( info.in_ndim(), @@ -462,6 +450,7 @@ where /// Slice the array in place without changing the number of dimensions. /// /// See [*Slicing*](#slicing) for full documentation. + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). /// /// **Panics** in the following cases: /// @@ -472,7 +461,7 @@ where /// - if `D` is `IxDyn` and `info` does not match the number of array axes pub fn slice_collapse(&mut self, info: &I) where - I: CanSlice + ?Sized, + I: SliceArg + ?Sized, { assert_eq!( info.in_ndim(), diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index 38d07594a..dd39c7e22 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -117,12 +117,10 @@ where /// consumes `self` and produces views with lifetimes matching that of /// `self`. /// - /// See [*Slicing*](#slicing) for full documentation. - /// See also [`SliceInfo`] and [`D::SliceArg`]. + /// See [*Slicing*](#slicing) for full documentation. See also [`s!`], + /// [`SliceArg`](crate::SliceArg), and [`SliceInfo`](crate::SliceInfo). /// /// [`.multi_slice_mut()`]: struct.ArrayBase.html#method.multi_slice_mut - /// [`SliceInfo`]: struct.SliceInfo.html - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg /// /// **Panics** if any of the following occur: /// diff --git a/src/lib.rs b/src/lib.rs index 0da511b01..3a6d169cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -142,7 +142,7 @@ pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; pub use crate::slice::{ - AxisSliceInfo, CanSlice, MultiSlice, NewAxis, Slice, SliceInfo, SliceNextDim, + AxisSliceInfo, MultiSlice, NewAxis, Slice, SliceArg, SliceInfo, SliceNextDim, }; use crate::iterators::Baseiter; diff --git a/src/slice.rs b/src/slice.rs index eea10d3e4..c7fceefa7 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -311,7 +311,7 @@ impl From for AxisSliceInfo { /// that `D`, `Self::OutDim`, `self.in_dim()`, and `self.out_ndim()` are /// consistent with the `&[AxisSliceInfo]` returned by `self.as_ref()` and that /// `self.as_ref()` always returns the same value when called multiple times. -pub unsafe trait CanSlice: AsRef<[AxisSliceInfo]> { +pub unsafe trait SliceArg: AsRef<[AxisSliceInfo]> { /// Dimensionality of the output array. type OutDim: Dimension; @@ -324,9 +324,9 @@ pub unsafe trait CanSlice: AsRef<[AxisSliceInfo]> { private_decl! {} } -macro_rules! impl_canslice_samedim { +macro_rules! impl_slicearg_samedim { ($in_dim:ty) => { - unsafe impl CanSlice<$in_dim> for SliceInfo + unsafe impl SliceArg<$in_dim> for SliceInfo where T: AsRef<[AxisSliceInfo]>, Dout: Dimension, @@ -345,15 +345,15 @@ macro_rules! impl_canslice_samedim { } }; } -impl_canslice_samedim!(Ix0); -impl_canslice_samedim!(Ix1); -impl_canslice_samedim!(Ix2); -impl_canslice_samedim!(Ix3); -impl_canslice_samedim!(Ix4); -impl_canslice_samedim!(Ix5); -impl_canslice_samedim!(Ix6); +impl_slicearg_samedim!(Ix0); +impl_slicearg_samedim!(Ix1); +impl_slicearg_samedim!(Ix2); +impl_slicearg_samedim!(Ix3); +impl_slicearg_samedim!(Ix4); +impl_slicearg_samedim!(Ix5); +impl_slicearg_samedim!(Ix6); -unsafe impl CanSlice for SliceInfo +unsafe impl SliceArg for SliceInfo where T: AsRef<[AxisSliceInfo]>, Din: Dimension, @@ -372,7 +372,7 @@ where private_impl! {} } -unsafe impl CanSlice for [AxisSliceInfo] { +unsafe impl SliceArg for [AxisSliceInfo] { type OutDim = IxDyn; fn in_ndim(&self) -> usize { @@ -937,7 +937,7 @@ impl<'a, A, D, I0> MultiSlice<'a, A, D> for (&I0,) where A: 'a, D: Dimension, - I0: CanSlice, + I0: SliceArg, { type Output = (ArrayViewMut<'a, A, I0::OutDim>,); @@ -957,7 +957,7 @@ macro_rules! impl_multislice_tuple { where A: 'a, D: Dimension, - $($all: CanSlice,)* + $($all: SliceArg,)* { type Output = ($(ArrayViewMut<'a, A, $all::OutDim>,)*); From c4efbbfca33e1d993f0216b0bb32a6652a89f74a Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Wed, 17 Feb 2021 00:27:13 -0500 Subject: [PATCH 27/28] Rename MultiSlice to MultiSliceArg --- src/impl_methods.rs | 9 +++++---- src/impl_views/splitting.rs | 9 +++++---- src/lib.rs | 2 +- src/slice.rs | 12 ++++++------ 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 8eaae18a7..afd915786 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -32,7 +32,7 @@ use crate::iter::{ AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut, IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, }; -use crate::slice::{MultiSlice, SliceArg}; +use crate::slice::{MultiSliceArg, SliceArg}; use crate::stacking::concatenate; use crate::{AxisSliceInfo, NdIndex, Slice}; @@ -363,8 +363,9 @@ where /// Return multiple disjoint, sliced, mutable views of the array. /// - /// See [*Slicing*](#slicing) for full documentation. - /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). + /// See [*Slicing*](#slicing) for full documentation. See also + /// [`MultiSliceArg`], [`s!`], [`SliceArg`], and + /// [`SliceInfo`](crate::SliceInfo). /// /// **Panics** if any of the following occur: /// @@ -385,7 +386,7 @@ where /// ``` pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output where - M: MultiSlice<'a, A, D>, + M: MultiSliceArg<'a, A, D>, S: DataMut, { info.multi_slice_move(self.view_mut()) diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index dd39c7e22..a36ae4ddb 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -7,7 +7,7 @@ // except according to those terms. use crate::imp_prelude::*; -use crate::slice::MultiSlice; +use crate::slice::MultiSliceArg; /// Methods for read-only array views. impl<'a, A, D> ArrayView<'a, A, D> @@ -117,8 +117,9 @@ where /// consumes `self` and produces views with lifetimes matching that of /// `self`. /// - /// See [*Slicing*](#slicing) for full documentation. See also [`s!`], - /// [`SliceArg`](crate::SliceArg), and [`SliceInfo`](crate::SliceInfo). + /// See [*Slicing*](#slicing) for full documentation. See also + /// [`MultiSliceArg`], [`s!`], [`SliceArg`](crate::SliceArg), and + /// [`SliceInfo`](crate::SliceInfo). /// /// [`.multi_slice_mut()`]: struct.ArrayBase.html#method.multi_slice_mut /// @@ -129,7 +130,7 @@ where /// * if `D` is `IxDyn` and `info` does not match the number of array axes pub fn multi_slice_move(self, info: M) -> M::Output where - M: MultiSlice<'a, A, D>, + M: MultiSliceArg<'a, A, D>, { info.multi_slice_move(self) } diff --git a/src/lib.rs b/src/lib.rs index 3a6d169cb..f48da4b32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -142,7 +142,7 @@ pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; pub use crate::slice::{ - AxisSliceInfo, MultiSlice, NewAxis, Slice, SliceArg, SliceInfo, SliceNextDim, + AxisSliceInfo, MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceNextDim, }; use crate::iterators::Baseiter; diff --git a/src/slice.rs b/src/slice.rs index c7fceefa7..43785555d 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -904,7 +904,7 @@ macro_rules! s( /// /// It's unfortunate that we need `'a` and `A` to be parameters of the trait, /// but they're necessary until Rust supports generic associated types. -pub trait MultiSlice<'a, A, D> +pub trait MultiSliceArg<'a, A, D> where A: 'a, D: Dimension, @@ -921,7 +921,7 @@ where private_decl! {} } -impl<'a, A, D> MultiSlice<'a, A, D> for () +impl<'a, A, D> MultiSliceArg<'a, A, D> for () where A: 'a, D: Dimension, @@ -933,7 +933,7 @@ where private_impl! {} } -impl<'a, A, D, I0> MultiSlice<'a, A, D> for (&I0,) +impl<'a, A, D, I0> MultiSliceArg<'a, A, D> for (&I0,) where A: 'a, D: Dimension, @@ -953,7 +953,7 @@ macro_rules! impl_multislice_tuple { impl_multislice_tuple!(@def_impl ($($but_last,)* $last,), [$($but_last)*] $last); }; (@def_impl ($($all:ident,)*), [$($but_last:ident)*] $last:ident) => { - impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&$all,)*) + impl<'a, A, D, $($all,)*> MultiSliceArg<'a, A, D> for ($(&$all,)*) where A: 'a, D: Dimension, @@ -995,11 +995,11 @@ impl_multislice_tuple!([I0 I1 I2] I3); impl_multislice_tuple!([I0 I1 I2 I3] I4); impl_multislice_tuple!([I0 I1 I2 I3 I4] I5); -impl<'a, A, D, T> MultiSlice<'a, A, D> for &T +impl<'a, A, D, T> MultiSliceArg<'a, A, D> for &T where A: 'a, D: Dimension, - T: MultiSlice<'a, A, D>, + T: MultiSliceArg<'a, A, D>, { type Output = T::Output; From 7506f9051a04057297c799eeccdd94963c4e5633 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Wed, 17 Feb 2021 00:27:34 -0500 Subject: [PATCH 28/28] Clarify docs of .slice_collapse() --- src/impl_methods.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index afd915786..d039d68a6 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -450,6 +450,14 @@ where /// Slice the array in place without changing the number of dimensions. /// + /// In particular, if an axis is sliced with an index, the axis is + /// collapsed, as in [`.collapse_axis()`], rather than removed, as in + /// [`.slice_move()`] or [`.index_axis_move()`]. + /// + /// [`.collapse_axis()`]: #method.collapse_axis + /// [`.slice_move()`]: #method.slice_move + /// [`.index_axis_move()`]: #method.index_axis_move + /// /// See [*Slicing*](#slicing) for full documentation. /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). ///