Skip to content

Add multi_slice_* methods (supports nested tuples) #716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,6 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> {
}

/// Returns `true` iff the slices intersect.
#[allow(dead_code)]
pub fn slices_intersect<D: Dimension>(
dim: &D,
indices1: &D::SliceArg,
Expand Down
34 changes: 34 additions & 0 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::iter::{
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
};
use crate::slice::MultiSlice;
use crate::stacking::stack;
use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex};

Expand Down Expand Up @@ -350,6 +351,39 @@ where
self.view_mut().slice_move(info)
}

/// Return multiple disjoint, sliced, mutable views of the array.
///
/// See [*Slicing*](#slicing) for full documentation.
/// See also [`SliceInfo`] and [`D::SliceArg`].
///
/// [`SliceInfo`]: struct.SliceInfo.html
/// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg
///
/// **Panics** if any of the following occur:
///
/// * if any of the views would intersect (i.e. if any element would appear in multiple slices)
/// * if an index is out of bounds or step size is zero
/// * if `D` is `IxDyn` and `info` does not match the number of array axes
///
/// # Example
///
/// ```
/// use ndarray::{arr2, s};
///
/// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
/// let (mut edges, mut middle) = a.multi_slice_mut((s![.., ..;2], s![.., 1]));
/// edges.fill(1);
/// middle.fill(0);
/// assert_eq!(a, arr2(&[[1, 0, 1], [1, 0, 1]]));
/// ```
pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output
where
M: MultiSlice<'a, A, D>,
S: DataMut,
{
unsafe { info.slice_and_deref(self.raw_view_mut()) }
}

/// Slice the array, possibly changing the number of dimensions.
///
/// See [*Slicing*](#slicing) for full documentation.
Expand Down
26 changes: 26 additions & 0 deletions src/impl_views/splitting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// except according to those terms.

use crate::imp_prelude::*;
use crate::slice::MultiSlice;

/// Methods for read-only array views.
impl<'a, A, D> ArrayView<'a, A, D>
Expand Down Expand Up @@ -109,4 +110,29 @@ where
(left.deref_into_view_mut(), right.deref_into_view_mut())
}
}

/// Split the view into multiple disjoint slices.
///
/// This is similar to [`.multi_slice_mut()`], but `.multi_slice_move()`
/// consumes `self` and produces views with lifetimes matching that of
/// `self`.
///
/// See [*Slicing*](#slicing) for full documentation.
/// See also [`SliceInfo`] and [`D::SliceArg`].
///
/// [`.multi_slice_mut()`]: struct.ArrayBase.html#method.multi_slice_mut
/// [`SliceInfo`]: struct.SliceInfo.html
/// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg
///
/// **Panics** if any of the following occur:
///
/// * if any of the views would intersect (i.e. if any element would appear in multiple slices)
/// * if an index is out of bounds or step size is zero
/// * if `D` is `IxDyn` and `info` does not match the number of array axes
pub fn multi_slice_move<M>(mut self, info: M) -> M::Output
where
M: MultiSlice<'a, A, D>,
{
unsafe { info.slice_and_deref(self.raw_view_mut()) }
}
}
21 changes: 21 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,13 @@ pub type Ixs = isize;
/// [`.slice_move()`]: #method.slice_move
/// [`.slice_collapse()`]: #method.slice_collapse
///
/// It's possible to take multiple simultaneous *mutable* slices with the
/// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only)
/// [`.multi_slice_move()`].
///
/// [`.multi_slice_mut()`]: #method.multi_slice_mut
/// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move
///
/// ```
/// extern crate ndarray;
///
Expand Down Expand Up @@ -523,6 +530,20 @@ pub type Ixs = isize;
/// [12, 11, 10]]);
/// assert_eq!(f, g);
/// assert_eq!(f.shape(), &[2, 3]);
///
/// // Let's take two disjoint, mutable slices of a matrix with
/// //
/// // - One containing all the even-index columns in the matrix
/// // - One containing all the odd-index columns in the matrix
/// let mut h = arr2(&[[0, 1, 2, 3],
/// [4, 5, 6, 7]]);
/// let (s0, s1) = h.multi_slice_mut((s![.., ..;2], s![.., 1..;2]));
/// let i = arr2(&[[0, 2],
/// [4, 6]]);
/// let j = arr2(&[[1, 3],
/// [5, 7]]);
/// assert_eq!(s0, i);
/// assert_eq!(s1, j);
/// }
/// ```
///
Expand Down
189 changes: 188 additions & 1 deletion src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use crate::dimension::slices_intersect;
use crate::error::{ErrorKind, ShapeError};
use crate::Dimension;
use crate::{ArrayViewMut, Dimension, RawArrayViewMut};
use std::fmt;
use std::marker::PhantomData;
use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
Expand Down Expand Up @@ -629,3 +630,189 @@ macro_rules! s(
&*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*]
};
);

/// Slicing information describing multiple mutable, disjoint slices.
///
/// It's unfortunate that we need `'out` and `A` to be parameters of the trait,
/// but they're necessary until Rust supports generic associated types.
///
/// # Safety
///
/// Implementers of this trait must ensure that:
///
/// * `.slice_and_deref()` panics or aborts if the slices would intersect, and
///
/// * the `.intersects_self()`, `.intersects_indices()`, and
/// `.intersects_other()` implementations are correct.
pub unsafe trait MultiSlice<'out, A, D>
where
A: 'out,
D: Dimension,
{
/// The type of the slices created by `.slice_and_deref()`.
type Output;

/// Slice the raw view into multiple raw views, and dereference them.
///
/// **Panics** if performing any individual slice panics or if the slices
/// are not disjoint (i.e. if they intersect).
///
/// # Safety
///
/// The caller must ensure that it is safe to mutably dereference the view
/// using the lifetime `'out`.
unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output;

/// Returns `true` if slicing an array of the specified `shape` with `self`
/// would result in intersecting slices.
///
/// If `self.intersects_self(&view.raw_dim())` is `true`, then
/// `self.slice_and_deref(view)` must panic.
fn intersects_self(&self, shape: &D) -> bool;

/// Returns `true` if any slices created by slicing an array of the
/// specified `shape` with `self` would intersect with the specified
/// indices.
///
/// Note that even if this returns `false`, `self.intersects_self(shape)`
/// may still return `true`. (`.intersects_indices()` doesn't check for
/// intersections within `self`; it only checks for intersections between
/// `self` and `indices`.)
fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool;

/// Returns `true` if any slices created by slicing an array of the
/// specified `shape` with `self` would intersect any slices created by
/// slicing the array with `other`.
///
/// Note that even if this returns `false`, `self.intersects_self(shape)`
/// or `other.intersects_self(shape)` may still return `true`.
/// (`.intersects_other()` doesn't check for intersections within `self` or
/// within `other`; it only checks for intersections between `self` and
/// `other`.)
fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool;
}

unsafe impl<'out, A, D, Do> MultiSlice<'out, A, D> for SliceInfo<D::SliceArg, Do>
where
A: 'out,
D: Dimension,
Do: Dimension,
{
type Output = ArrayViewMut<'out, A, Do>;

unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output {
view.slice_move(self).deref_into_view_mut()
}

fn intersects_self(&self, _shape: &D) -> bool {
false
}

fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool {
slices_intersect(shape, &*self, indices)
}

fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool {
other.intersects_indices(shape, &*self)
}
}

unsafe impl<'out, A, D> MultiSlice<'out, A, D> for ()
where
A: 'out,
D: Dimension,
{
type Output = ();

unsafe fn slice_and_deref(&self, _view: RawArrayViewMut<A, D>) -> Self::Output {}

fn intersects_self(&self, _shape: &D) -> bool {
false
}

fn intersects_indices(&self, _shape: &D, _indices: &D::SliceArg) -> bool {
false
}

fn intersects_other(&self, _shape: &D, _other: impl MultiSlice<'out, A, D>) -> bool {
false
}
}

macro_rules! impl_multislice_tuple {
($($T:ident,)*) => {
unsafe impl<'out, A, D, $($T,)*> MultiSlice<'out, A, D> for ($($T,)*)
where
A: 'out,
D: Dimension,
$($T: MultiSlice<'out, A, D>,)*
{
type Output = ($($T::Output,)*);

unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output {
assert!(!self.intersects_self(&view.raw_dim()));

#[allow(non_snake_case)]
let ($($T,)*) = self;
($($T.slice_and_deref(view.clone()),)*)
}

fn intersects_self(&self, shape: &D) -> bool {
#[allow(non_snake_case)]
let ($($T,)*) = self;
impl_multislice_tuple!(@intersects_self shape, ($($T,)*))
}

fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool {
#[allow(non_snake_case)]
let ($($T,)*) = self;
$($T.intersects_indices(shape, indices)) ||*
}

fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool {
#[allow(non_snake_case)]
let ($($T,)*) = self;
$($T.intersects_other(shape, &other)) ||*
}
}
};
(@intersects_self $shape:expr, ($head:expr,)) => {
$head.intersects_self($shape)
};
(@intersects_self $shape:expr, ($head:expr, $($tail:expr,)*)) => {
$head.intersects_self($shape) ||
$($head.intersects_other($shape, &$tail)) ||* ||
impl_multislice_tuple!(@intersects_self $shape, ($($tail,)*))
};
}
impl_multislice_tuple!(T0,);
impl_multislice_tuple!(T0, T1,);
impl_multislice_tuple!(T0, T1, T2,);
impl_multislice_tuple!(T0, T1, T2, T3,);
impl_multislice_tuple!(T0, T1, T2, T3, T4,);
impl_multislice_tuple!(T0, T1, T2, T3, T4, T5,);

unsafe impl<'out, A, D, T> MultiSlice<'out, A, D> for &'_ T
where
A: 'out,
D: Dimension,
T: MultiSlice<'out, A, D>,
{
type Output = T::Output;

unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output {
T::slice_and_deref(self, view)
}

fn intersects_self(&self, shape: &D) -> bool {
T::intersects_self(self, shape)
}

fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool {
T::intersects_indices(self, shape, indices)
}

fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool {
T::intersects_other(self, shape, other)
}
}
Loading