Skip to content

Commit 3bb5e7d

Browse files
committed
Add multi_slice_* methods
1 parent 1b6a2ad commit 3bb5e7d

File tree

5 files changed

+334
-1
lines changed

5 files changed

+334
-1
lines changed

src/impl_methods.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use crate::iter::{
2828
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,
2929
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
3030
};
31+
use crate::slice::MultiSlice;
3132
use crate::stacking::stack;
3233
use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex};
3334

@@ -350,6 +351,39 @@ where
350351
self.view_mut().slice_move(info)
351352
}
352353

354+
/// Return multiple disjoint, sliced, mutable views of the array.
355+
///
356+
/// See [*Slicing*](#slicing) for full documentation.
357+
/// See also [`SliceInfo`] and [`D::SliceArg`].
358+
///
359+
/// [`SliceInfo`]: struct.SliceInfo.html
360+
/// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg
361+
///
362+
/// **Panics** if any of the following occur:
363+
///
364+
/// * if any of the views would intersect (i.e. if any element would appear in multiple slices)
365+
/// * if an index is out of bounds or step size is zero
366+
/// * if `D` is `IxDyn` and `info` does not match the number of array axes
367+
///
368+
/// # Example
369+
///
370+
/// ```
371+
/// use ndarray::{arr2, s};
372+
///
373+
/// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
374+
/// let (mut edges, mut middle) = a.multi_slice_mut((s![.., ..;2], s![.., 1]));
375+
/// edges.fill(1);
376+
/// middle.fill(0);
377+
/// assert_eq!(a, arr2(&[[1, 0, 1], [1, 0, 1]]));
378+
/// ```
379+
pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output
380+
where
381+
M: MultiSlice<'a, A, D>,
382+
S: DataMut,
383+
{
384+
unsafe { info.slice_and_deref(self.raw_view_mut()) }
385+
}
386+
353387
/// Slice the array, possibly changing the number of dimensions.
354388
///
355389
/// See [*Slicing*](#slicing) for full documentation.

src/impl_views/splitting.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
// except according to those terms.
88

99
use crate::imp_prelude::*;
10+
use crate::slice::MultiSlice;
1011

1112
/// Methods for read-only array views.
1213
impl<'a, A, D> ArrayView<'a, A, D>
@@ -109,4 +110,29 @@ where
109110
(left.deref_into_view_mut(), right.deref_into_view_mut())
110111
}
111112
}
113+
114+
/// Split the view into multiple disjoint slices.
115+
///
116+
/// This is similar to [`.multi_slice_mut()`], but `.multi_slice_move()`
117+
/// consumes `self` and produces views with lifetimes matching that of
118+
/// `self`.
119+
///
120+
/// See [*Slicing*](#slicing) for full documentation.
121+
/// See also [`SliceInfo`] and [`D::SliceArg`].
122+
///
123+
/// [`.multi_slice_mut()`]: struct.ArrayBase.html#method.multi_slice_mut
124+
/// [`SliceInfo`]: struct.SliceInfo.html
125+
/// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg
126+
///
127+
/// **Panics** if any of the following occur:
128+
///
129+
/// * if any of the views would intersect (i.e. if any element would appear in multiple slices)
130+
/// * if an index is out of bounds or step size is zero
131+
/// * if `D` is `IxDyn` and `info` does not match the number of array axes
132+
pub fn multi_slice_move<M>(mut self, info: M) -> M::Output
133+
where
134+
M: MultiSlice<'a, A, D>,
135+
{
136+
unsafe { info.slice_and_deref(self.raw_view_mut()) }
137+
}
112138
}

src/lib.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,13 @@ pub type Ixs = isize;
473473
/// [`.slice_move()`]: #method.slice_move
474474
/// [`.slice_collapse()`]: #method.slice_collapse
475475
///
476+
/// It's possible to take multiple simultaneous *mutable* slices with the
477+
/// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only)
478+
/// [`.multi_slice_move()`].
479+
///
480+
/// [`.multi_slice_mut()`]: #method.multi_slice_mut
481+
/// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move
482+
///
476483
/// ```
477484
/// extern crate ndarray;
478485
///
@@ -523,6 +530,20 @@ pub type Ixs = isize;
523530
/// [12, 11, 10]]);
524531
/// assert_eq!(f, g);
525532
/// assert_eq!(f.shape(), &[2, 3]);
533+
///
534+
/// // Let's take two disjoint, mutable slices of a matrix with
535+
/// //
536+
/// // - One containing all the even-index columns in the matrix
537+
/// // - One containing all the odd-index columns in the matrix
538+
/// let mut h = arr2(&[[0, 1, 2, 3],
539+
/// [4, 5, 6, 7]]);
540+
/// let (s0, s1) = h.multi_slice_mut((s![.., ..;2], s![.., 1..;2]));
541+
/// let i = arr2(&[[0, 2],
542+
/// [4, 6]]);
543+
/// let j = arr2(&[[1, 3],
544+
/// [5, 7]]);
545+
/// assert_eq!(s0, i);
546+
/// assert_eq!(s1, j);
526547
/// }
527548
/// ```
528549
///

src/slice.rs

Lines changed: 188 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
8+
use crate::dimension::slices_intersect;
89
use crate::error::{ErrorKind, ShapeError};
9-
use crate::Dimension;
10+
use crate::{ArrayViewMut, Dimension, RawArrayViewMut};
1011
use std::fmt;
1112
use std::marker::PhantomData;
1213
use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
@@ -629,3 +630,189 @@ macro_rules! s(
629630
&*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*]
630631
};
631632
);
633+
634+
/// Slicing information describing multiple mutable, disjoint slices.
635+
///
636+
/// It's unfortunate that we need `'out` and `A` to be parameters of the trait,
637+
/// but they're necessary until Rust supports generic associated types.
638+
///
639+
/// # Safety
640+
///
641+
/// Implementers of this trait must ensure that:
642+
///
643+
/// * `.slice_and_deref()` panics or aborts if the slices would intersect, and
644+
///
645+
/// * the `.intersects_self()`, `.intersects_indices()`, and
646+
/// `.intersects_other()` implementations are correct.
647+
pub unsafe trait MultiSlice<'out, A, D>
648+
where
649+
A: 'out,
650+
D: Dimension,
651+
{
652+
/// The type of the slices created by `.slice_and_deref()`.
653+
type Output;
654+
655+
/// Slice the raw view into multiple raw views, and dereference them.
656+
///
657+
/// **Panics** if performing any individual slice panics or if the slices
658+
/// are not disjoint (i.e. if they intersect).
659+
///
660+
/// # Safety
661+
///
662+
/// The caller must ensure that it is safe to mutably dereference the view
663+
/// using the lifetime `'out`.
664+
unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output;
665+
666+
/// Returns `true` if slicing an array of the specified `shape` with `self`
667+
/// would result in intersecting slices.
668+
///
669+
/// If `self.intersects_self(&view.raw_dim())` is `true`, then
670+
/// `self.slice_and_deref(view)` must panic.
671+
fn intersects_self(&self, shape: &D) -> bool;
672+
673+
/// Returns `true` if any slices created by slicing an array of the
674+
/// specified `shape` with `self` would intersect with the specified
675+
/// indices.
676+
///
677+
/// Note that even if this returns `false`, `self.intersects_self(shape)`
678+
/// may still return `true`. (`.intersects_indices()` doesn't check for
679+
/// intersections within `self`; it only checks for intersections between
680+
/// `self` and `indices`.)
681+
fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool;
682+
683+
/// Returns `true` if any slices created by slicing an array of the
684+
/// specified `shape` with `self` would intersect any slices created by
685+
/// slicing the array with `other`.
686+
///
687+
/// Note that even if this returns `false`, `self.intersects_self(shape)`
688+
/// or `other.intersects_self(shape)` may still return `true`.
689+
/// (`.intersects_other()` doesn't check for intersections within `self` or
690+
/// within `other`; it only checks for intersections between `self` and
691+
/// `other`.)
692+
fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool;
693+
}
694+
695+
unsafe impl<'out, A, D, Do> MultiSlice<'out, A, D> for SliceInfo<D::SliceArg, Do>
696+
where
697+
A: 'out,
698+
D: Dimension,
699+
Do: Dimension,
700+
{
701+
type Output = ArrayViewMut<'out, A, Do>;
702+
703+
unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output {
704+
view.slice_move(self).deref_into_view_mut()
705+
}
706+
707+
fn intersects_self(&self, _shape: &D) -> bool {
708+
false
709+
}
710+
711+
fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool {
712+
slices_intersect(shape, &*self, indices)
713+
}
714+
715+
fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool {
716+
other.intersects_indices(shape, &*self)
717+
}
718+
}
719+
720+
unsafe impl<'out, A, D> MultiSlice<'out, A, D> for ()
721+
where
722+
A: 'out,
723+
D: Dimension,
724+
{
725+
type Output = ();
726+
727+
unsafe fn slice_and_deref(&self, _view: RawArrayViewMut<A, D>) -> Self::Output {}
728+
729+
fn intersects_self(&self, _shape: &D) -> bool {
730+
false
731+
}
732+
733+
fn intersects_indices(&self, _shape: &D, _indices: &D::SliceArg) -> bool {
734+
false
735+
}
736+
737+
fn intersects_other(&self, _shape: &D, _other: impl MultiSlice<'out, A, D>) -> bool {
738+
false
739+
}
740+
}
741+
742+
macro_rules! impl_multislice_tuple {
743+
($($T:ident,)*) => {
744+
unsafe impl<'out, A, D, $($T,)*> MultiSlice<'out, A, D> for ($($T,)*)
745+
where
746+
A: 'out,
747+
D: Dimension,
748+
$($T: MultiSlice<'out, A, D>,)*
749+
{
750+
type Output = ($($T::Output,)*);
751+
752+
unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output {
753+
assert!(!self.intersects_self(&view.raw_dim()));
754+
755+
#[allow(non_snake_case)]
756+
let ($($T,)*) = self;
757+
($($T.slice_and_deref(view.clone()),)*)
758+
}
759+
760+
fn intersects_self(&self, shape: &D) -> bool {
761+
#[allow(non_snake_case)]
762+
let ($($T,)*) = self;
763+
impl_multislice_tuple!(@intersects_self shape, ($($T,)*))
764+
}
765+
766+
fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool {
767+
#[allow(non_snake_case)]
768+
let ($($T,)*) = self;
769+
$($T.intersects_indices(shape, indices)) ||*
770+
}
771+
772+
fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool {
773+
#[allow(non_snake_case)]
774+
let ($($T,)*) = self;
775+
$($T.intersects_other(shape, &other)) ||*
776+
}
777+
}
778+
};
779+
(@intersects_self $shape:expr, ($head:expr,)) => {
780+
$head.intersects_self($shape)
781+
};
782+
(@intersects_self $shape:expr, ($head:expr, $($tail:expr,)*)) => {
783+
$head.intersects_self($shape) ||
784+
$($head.intersects_other($shape, &$tail)) ||* ||
785+
impl_multislice_tuple!(@intersects_self $shape, ($($tail,)*))
786+
};
787+
}
788+
impl_multislice_tuple!(T0,);
789+
impl_multislice_tuple!(T0, T1,);
790+
impl_multislice_tuple!(T0, T1, T2,);
791+
impl_multislice_tuple!(T0, T1, T2, T3,);
792+
impl_multislice_tuple!(T0, T1, T2, T3, T4,);
793+
impl_multislice_tuple!(T0, T1, T2, T3, T4, T5,);
794+
795+
unsafe impl<'out, A, D, T> MultiSlice<'out, A, D> for &'_ T
796+
where
797+
A: 'out,
798+
D: Dimension,
799+
T: MultiSlice<'out, A, D>,
800+
{
801+
type Output = T::Output;
802+
803+
unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output {
804+
T::slice_and_deref(self, view)
805+
}
806+
807+
fn intersects_self(&self, shape: &D) -> bool {
808+
T::intersects_self(self, shape)
809+
}
810+
811+
fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool {
812+
T::intersects_indices(self, shape, indices)
813+
}
814+
815+
fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool {
816+
T::intersects_other(self, shape, other)
817+
}
818+
}

0 commit comments

Comments
 (0)