diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 6b86af653..2c2962d07 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -545,7 +545,6 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> { } /// Returns `true` iff the slices intersect. -#[allow(dead_code)] pub fn slices_intersect( dim: &D, indices1: &D::SliceArg, diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4fbfcd98a..263c566d2 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -28,6 +28,7 @@ use crate::iter::{ AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut, IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, }; +use crate::slice::MultiSlice; use crate::stacking::stack; use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex}; @@ -350,6 +351,39 @@ where self.view_mut().slice_move(info) } + /// Return multiple disjoint, sliced, mutable views of the array. + /// + /// See [*Slicing*](#slicing) for full documentation. + /// See also [`SliceInfo`] and [`D::SliceArg`]. + /// + /// [`SliceInfo`]: struct.SliceInfo.html + /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// + /// **Panics** if any of the following occur: + /// + /// * if any of the views would intersect (i.e. if any element would appear in multiple slices) + /// * if an index is out of bounds or step size is zero + /// * if `D` is `IxDyn` and `info` does not match the number of array axes + /// + /// # Example + /// + /// ``` + /// use ndarray::{arr2, s}; + /// + /// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]); + /// let (mut edges, mut middle) = a.multi_slice_mut((s![.., ..;2], s![.., 1])); + /// edges.fill(1); + /// middle.fill(0); + /// assert_eq!(a, arr2(&[[1, 0, 1], [1, 0, 1]])); + /// ``` + pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output + where + M: MultiSlice<'a, A, D>, + S: DataMut, + { + unsafe { info.slice_and_deref(self.raw_view_mut()) } + } + /// Slice the array, possibly changing the number of dimensions. /// /// See [*Slicing*](#slicing) for full documentation. diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index dd6b87a0c..529fac152 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -7,6 +7,7 @@ // except according to those terms. use crate::imp_prelude::*; +use crate::slice::MultiSlice; /// Methods for read-only array views. impl<'a, A, D> ArrayView<'a, A, D> @@ -109,4 +110,29 @@ where (left.deref_into_view_mut(), right.deref_into_view_mut()) } } + + /// Split the view into multiple disjoint slices. + /// + /// This is similar to [`.multi_slice_mut()`], but `.multi_slice_move()` + /// consumes `self` and produces views with lifetimes matching that of + /// `self`. + /// + /// See [*Slicing*](#slicing) for full documentation. + /// See also [`SliceInfo`] and [`D::SliceArg`]. + /// + /// [`.multi_slice_mut()`]: struct.ArrayBase.html#method.multi_slice_mut + /// [`SliceInfo`]: struct.SliceInfo.html + /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// + /// **Panics** if any of the following occur: + /// + /// * if any of the views would intersect (i.e. if any element would appear in multiple slices) + /// * if an index is out of bounds or step size is zero + /// * if `D` is `IxDyn` and `info` does not match the number of array axes + pub fn multi_slice_move(mut self, info: M) -> M::Output + where + M: MultiSlice<'a, A, D>, + { + unsafe { info.slice_and_deref(self.raw_view_mut()) } + } } diff --git a/src/lib.rs b/src/lib.rs index 18d0938cc..dd78fa56d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -473,6 +473,13 @@ pub type Ixs = isize; /// [`.slice_move()`]: #method.slice_move /// [`.slice_collapse()`]: #method.slice_collapse /// +/// It's possible to take multiple simultaneous *mutable* slices with the +/// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only) +/// [`.multi_slice_move()`]. +/// +/// [`.multi_slice_mut()`]: #method.multi_slice_mut +/// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move +/// /// ``` /// extern crate ndarray; /// @@ -523,6 +530,20 @@ pub type Ixs = isize; /// [12, 11, 10]]); /// assert_eq!(f, g); /// assert_eq!(f.shape(), &[2, 3]); +/// +/// // Let's take two disjoint, mutable slices of a matrix with +/// // +/// // - One containing all the even-index columns in the matrix +/// // - One containing all the odd-index columns in the matrix +/// let mut h = arr2(&[[0, 1, 2, 3], +/// [4, 5, 6, 7]]); +/// let (s0, s1) = h.multi_slice_mut((s![.., ..;2], s![.., 1..;2])); +/// let i = arr2(&[[0, 2], +/// [4, 6]]); +/// let j = arr2(&[[1, 3], +/// [5, 7]]); +/// assert_eq!(s0, i); +/// assert_eq!(s1, j); /// } /// ``` /// diff --git a/src/slice.rs b/src/slice.rs index 1d0dfa2b0..b2731e1cf 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -5,8 +5,9 @@ // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. +use crate::dimension::slices_intersect; use crate::error::{ErrorKind, ShapeError}; -use crate::Dimension; +use crate::{ArrayViewMut, Dimension, RawArrayViewMut}; use std::fmt; use std::marker::PhantomData; use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; @@ -629,3 +630,189 @@ macro_rules! s( &*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*] }; ); + +/// Slicing information describing multiple mutable, disjoint slices. +/// +/// It's unfortunate that we need `'out` and `A` to be parameters of the trait, +/// but they're necessary until Rust supports generic associated types. +/// +/// # Safety +/// +/// Implementers of this trait must ensure that: +/// +/// * `.slice_and_deref()` panics or aborts if the slices would intersect, and +/// +/// * the `.intersects_self()`, `.intersects_indices()`, and +/// `.intersects_other()` implementations are correct. +pub unsafe trait MultiSlice<'out, A, D> +where + A: 'out, + D: Dimension, +{ + /// The type of the slices created by `.slice_and_deref()`. + type Output; + + /// Slice the raw view into multiple raw views, and dereference them. + /// + /// **Panics** if performing any individual slice panics or if the slices + /// are not disjoint (i.e. if they intersect). + /// + /// # Safety + /// + /// The caller must ensure that it is safe to mutably dereference the view + /// using the lifetime `'out`. + unsafe fn slice_and_deref(&self, view: RawArrayViewMut) -> Self::Output; + + /// Returns `true` if slicing an array of the specified `shape` with `self` + /// would result in intersecting slices. + /// + /// If `self.intersects_self(&view.raw_dim())` is `true`, then + /// `self.slice_and_deref(view)` must panic. + fn intersects_self(&self, shape: &D) -> bool; + + /// Returns `true` if any slices created by slicing an array of the + /// specified `shape` with `self` would intersect with the specified + /// indices. + /// + /// Note that even if this returns `false`, `self.intersects_self(shape)` + /// may still return `true`. (`.intersects_indices()` doesn't check for + /// intersections within `self`; it only checks for intersections between + /// `self` and `indices`.) + fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool; + + /// Returns `true` if any slices created by slicing an array of the + /// specified `shape` with `self` would intersect any slices created by + /// slicing the array with `other`. + /// + /// Note that even if this returns `false`, `self.intersects_self(shape)` + /// or `other.intersects_self(shape)` may still return `true`. + /// (`.intersects_other()` doesn't check for intersections within `self` or + /// within `other`; it only checks for intersections between `self` and + /// `other`.) + fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool; +} + +unsafe impl<'out, A, D, Do> MultiSlice<'out, A, D> for SliceInfo +where + A: 'out, + D: Dimension, + Do: Dimension, +{ + type Output = ArrayViewMut<'out, A, Do>; + + unsafe fn slice_and_deref(&self, view: RawArrayViewMut) -> Self::Output { + view.slice_move(self).deref_into_view_mut() + } + + fn intersects_self(&self, _shape: &D) -> bool { + false + } + + fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool { + slices_intersect(shape, &*self, indices) + } + + fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool { + other.intersects_indices(shape, &*self) + } +} + +unsafe impl<'out, A, D> MultiSlice<'out, A, D> for () +where + A: 'out, + D: Dimension, +{ + type Output = (); + + unsafe fn slice_and_deref(&self, _view: RawArrayViewMut) -> Self::Output {} + + fn intersects_self(&self, _shape: &D) -> bool { + false + } + + fn intersects_indices(&self, _shape: &D, _indices: &D::SliceArg) -> bool { + false + } + + fn intersects_other(&self, _shape: &D, _other: impl MultiSlice<'out, A, D>) -> bool { + false + } +} + +macro_rules! impl_multislice_tuple { + ($($T:ident,)*) => { + unsafe impl<'out, A, D, $($T,)*> MultiSlice<'out, A, D> for ($($T,)*) + where + A: 'out, + D: Dimension, + $($T: MultiSlice<'out, A, D>,)* + { + type Output = ($($T::Output,)*); + + unsafe fn slice_and_deref(&self, view: RawArrayViewMut) -> Self::Output { + assert!(!self.intersects_self(&view.raw_dim())); + + #[allow(non_snake_case)] + let ($($T,)*) = self; + ($($T.slice_and_deref(view.clone()),)*) + } + + fn intersects_self(&self, shape: &D) -> bool { + #[allow(non_snake_case)] + let ($($T,)*) = self; + impl_multislice_tuple!(@intersects_self shape, ($($T,)*)) + } + + fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool { + #[allow(non_snake_case)] + let ($($T,)*) = self; + $($T.intersects_indices(shape, indices)) ||* + } + + fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool { + #[allow(non_snake_case)] + let ($($T,)*) = self; + $($T.intersects_other(shape, &other)) ||* + } + } + }; + (@intersects_self $shape:expr, ($head:expr,)) => { + $head.intersects_self($shape) + }; + (@intersects_self $shape:expr, ($head:expr, $($tail:expr,)*)) => { + $head.intersects_self($shape) || + $($head.intersects_other($shape, &$tail)) ||* || + impl_multislice_tuple!(@intersects_self $shape, ($($tail,)*)) + }; +} +impl_multislice_tuple!(T0,); +impl_multislice_tuple!(T0, T1,); +impl_multislice_tuple!(T0, T1, T2,); +impl_multislice_tuple!(T0, T1, T2, T3,); +impl_multislice_tuple!(T0, T1, T2, T3, T4,); +impl_multislice_tuple!(T0, T1, T2, T3, T4, T5,); + +unsafe impl<'out, A, D, T> MultiSlice<'out, A, D> for &'_ T +where + A: 'out, + D: Dimension, + T: MultiSlice<'out, A, D>, +{ + type Output = T::Output; + + unsafe fn slice_and_deref(&self, view: RawArrayViewMut) -> Self::Output { + T::slice_and_deref(self, view) + } + + fn intersects_self(&self, shape: &D) -> bool { + T::intersects_self(self, shape) + } + + fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool { + T::intersects_indices(self, shape, indices) + } + + fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool { + T::intersects_other(self, shape, other) + } +} diff --git a/tests/array.rs b/tests/array.rs index ea374bc7c..8aaa697c2 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -15,6 +15,20 @@ use ndarray::{arr3, rcarr2}; use ndarray::{Slice, SliceInfo, SliceOrIndex}; use std::iter::FromIterator; +macro_rules! assert_panics { + ($body:expr) => { + if let Ok(v) = ::std::panic::catch_unwind(|| $body) { + panic!("assertion failed: should_panic; \ + non-panicking result: {:?}", v); + } + }; + ($body:expr, $($arg:tt)*) => { + if let Ok(_) = ::std::panic::catch_unwind(|| $body) { + panic!($($arg)*); + } + }; +} + #[test] fn test_matmul_arcarray() { let mut A = ArcArray::::zeros((2, 3)); @@ -328,6 +342,57 @@ fn test_slice_collapse_with_indices() { assert_eq!(vi, Array3::from_elem((1, 1, 1), elem)); } +#[test] +#[allow(clippy::cognitive_complexity)] +fn test_multislice() { + defmac!(test_multislice arr, s1, s2 => { + let copy = arr.clone(); + assert_eq!( + arr.multi_slice_mut((s1, s2)), + (copy.clone().slice_mut(s1), copy.clone().slice_mut(s2)) + ); + }); + let mut arr = Array1::from_iter(0..48).into_shape((8, 6)).unwrap(); + + assert_eq!(arr.clone().view(), arr.multi_slice_mut(s![.., ..])); + test_multislice!(&mut arr, s![0, ..], s![1, ..]); + test_multislice!(&mut arr, s![0, ..], s![-1, ..]); + test_multislice!(&mut arr, s![0, ..], s![1.., ..]); + test_multislice!(&mut arr, s![1, ..], s![..;2, ..]); + test_multislice!(&mut arr, s![..2, ..], s![2.., ..]); + test_multislice!(&mut arr, s![1..;2, ..], s![..;2, ..]); + test_multislice!(&mut arr, s![..;-2, ..], s![..;2, ..]); + test_multislice!(&mut arr, s![..;12, ..], s![3..;3, ..]); +} + +#[test] +fn test_multislice_intersecting() { + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![3, ..], s![3, ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![3, ..], s![3.., ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![3, ..], s![..;3, ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![..;6, ..], s![3..;3, ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![2, ..], s![..-1;-2, ..])); + }); + { + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![3, ..], s![-1..;-2, ..])); + } +} + #[should_panic] #[test] fn index_out_of_bounds() {