Skip to content

Commit 97d5c20

Browse files
committed
Add multislice! macro
1 parent 34d738a commit 97d5c20

File tree

3 files changed

+316
-2
lines changed

3 files changed

+316
-2
lines changed

src/lib.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,12 @@ pub type Ixs = isize;
447447
/// [`.slice_move()`]: #method.slice_move
448448
/// [`.slice_inplace()`]: #method.slice_inplace
449449
///
450+
/// It's possible to take multiple simultaneous *mutable* slices with the
451+
/// [`multislice!()`](macro.multislice!.html) macro.
452+
///
450453
/// ```
451-
/// // import the s![] macro
452-
/// #[macro_use(s)]
454+
/// // import the multislice!() and s![] macros
455+
/// #[macro_use(multislice, s)]
453456
/// extern crate ndarray;
454457
///
455458
/// use ndarray::{arr2, arr3};
@@ -499,6 +502,20 @@ pub type Ixs = isize;
499502
/// [12, 11, 10]]);
500503
/// assert_eq!(f, g);
501504
/// assert_eq!(f.shape(), &[2, 3]);
505+
///
506+
/// // Let's take two disjoint, mutable slices of a matrix with
507+
/// //
508+
/// // - One containing all the even-index columns in the matrix
509+
/// // - One containing all the odd-index columns in the matrix
510+
/// let mut h = arr2(&[[0, 1, 2, 3],
511+
/// [4, 5, 6, 7]]);
512+
/// let (s0, s1) = multislice!(h, (mut s![.., ..;2], mut s![.., 1..;2]));
513+
/// let i = arr2(&[[0, 2],
514+
/// [4, 6]]);
515+
/// let j = arr2(&[[1, 3],
516+
/// [5, 7]]);
517+
/// assert_eq!(s0, i);
518+
/// assert_eq!(s1, j);
502519
/// }
503520
/// ```
504521
///

src/slice.rs

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,3 +595,214 @@ macro_rules! s(
595595
s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*]
596596
};
597597
);
598+
599+
/// Take multiple slices simultaneously.
600+
///
601+
/// This macro makes it possible to take multiple slices of the same array, as
602+
/// long as Rust's aliasing rules are followed for *elements* in the slices.
603+
/// For example, it's possible to take two disjoint, mutable slices of an
604+
/// array, with one referencing the even-index elements and the other
605+
/// referencing the odd-index elements. If you tried to achieve this by calling
606+
/// `.slice_mut()` twice, the borrow checker would complain about mutably
607+
/// borrowing the array twice (even though it's safe as long as the slices are
608+
/// disjoint).
609+
///
610+
/// The syntax is `multislice!(` *expression, (pattern [, pattern [, …]])* `)`,
611+
/// where *expression* evaluates to an `ArrayBase<S, D>` where `S: DataMut`,
612+
/// and `pattern` is one of the following:
613+
///
614+
/// * `mut expr`: creates an `ArrayViewMut`, where `expr` evaluates to a
615+
/// `&SliceInfo` instance used to slice the array.
616+
/// * `expr`: creates an `ArrayView`, where `expr` evaluates to a `&SliceInfo`
617+
/// instance used to slice the array.
618+
///
619+
/// **Note** that this macro always mutably borrows the array even if there are
620+
/// no `mut` patterns. If all you want to do is take read-only slices, you
621+
/// don't need `multislice!()`; just call
622+
/// [`.slice()`](struct.ArrayBase.html#method.slice) multiple times instead.
623+
///
624+
/// `multislice!()` follows Rust's aliasing rules:
625+
///
626+
/// * An `ArrayViewMut` and `ArrayView` cannot reference the same element.
627+
/// * Two `ArrayViewMut` cannot reference the same element.
628+
/// * Two `ArrayView` can reference the same element.
629+
///
630+
/// **Panics** at runtime if any of the aliasing rules is violated.
631+
///
632+
/// See also [*Slicing*](struct.ArrayBase.html#slicing).
633+
///
634+
/// # Examples
635+
///
636+
/// In this example, there are two overlapping read-only slices, and two
637+
/// disjoint mutable slices. Neither of the mutable slices intersects any of
638+
/// the other slices.
639+
///
640+
/// ```
641+
/// #[macro_use]
642+
/// extern crate ndarray;
643+
///
644+
/// use ndarray::prelude::*;
645+
///
646+
/// # fn main() {
647+
/// let mut arr = Array1::from_iter(0..12);
648+
/// let (a, b, c, d) = multislice!(arr, (s![0..5], mut s![6..;2], s![1..6], mut s![7..;2]));
649+
/// assert_eq!(a, array![0, 1, 2, 3, 4]);
650+
/// assert_eq!(b, array![6, 8, 10]);
651+
/// assert_eq!(c, array![1, 2, 3, 4, 5]);
652+
/// assert_eq!(d, array![7, 9, 11]);
653+
/// # }
654+
/// ```
655+
///
656+
/// These examples panic because they don't follow the aliasing rules:
657+
///
658+
/// * `ArrayViewMut` and `ArrayView` cannot reference the same element.
659+
///
660+
/// ```should_panic
661+
/// # #[macro_use] extern crate ndarray;
662+
/// # use ndarray::prelude::*;
663+
/// # fn main() {
664+
/// let mut arr = Array1::from_iter(0..12);
665+
/// multislice!(arr, (s![0..5], mut s![1..;2])); // panic!
666+
/// # }
667+
/// ```
668+
///
669+
/// * Two `ArrayViewMut` cannot reference the same element.
670+
///
671+
/// ```should_panic
672+
/// # #[macro_use] extern crate ndarray;
673+
/// # use ndarray::prelude::*;
674+
/// # fn main() {
675+
/// let mut arr = Array1::from_iter(0..12);
676+
/// multislice!(arr, (mut s![0..5], mut s![1..;2])); // panic!
677+
/// # }
678+
/// ```
679+
#[macro_export]
680+
macro_rules! multislice(
681+
(
682+
@check $view:expr,
683+
$info:expr,
684+
()
685+
) => {};
686+
// Check that $info doesn't intersect $other.
687+
(
688+
@check $view:expr,
689+
$info:expr,
690+
($other:expr,)
691+
) => {
692+
assert!(
693+
!$crate::slices_intersect(&$view.raw_dim(), $info, $other),
694+
"Slice {:?} must not intersect slice {:?}", $info, $other
695+
)
696+
};
697+
// Check that $info doesn't intersect any of the other info in the tuple.
698+
(
699+
@check $view:expr,
700+
$info:expr,
701+
($other:expr, $($more:tt)*)
702+
) => {
703+
{
704+
multislice!(@check $view, $info, ($other,));
705+
multislice!(@check $view, $info, ($($more)*));
706+
}
707+
};
708+
// Parse last slice (mutable), no trailing comma.
709+
(
710+
@parse $view:expr,
711+
($($sliced:tt)*),
712+
($($mut_info:tt)*),
713+
($($immut_info:tt)*),
714+
(mut $info:expr)
715+
) => {
716+
{
717+
multislice!(@check $view, $info, ($($mut_info)*));
718+
multislice!(@check $view, $info, ($($immut_info)*));
719+
($($sliced)* unsafe { $view.aliasing_view_mut() }.slice_move($info))
720+
}
721+
};
722+
// Parse last slice (read-only), no trailing comma.
723+
(
724+
@parse $view:expr,
725+
($($sliced:tt)*),
726+
($($mut_info:tt)*),
727+
($($immut_info:tt)*),
728+
($info:expr)
729+
) => {
730+
{
731+
multislice!(@check $view, $info, ($($mut_info)*));
732+
($($sliced)* unsafe { $view.aliasing_view() }.slice_move($info))
733+
}
734+
};
735+
// Parse last slice (mutable), with trailing comma.
736+
(
737+
@parse $view:expr,
738+
($($sliced:tt)*),
739+
($($mut_info:tt)*),
740+
($($immut_info:tt)*),
741+
(mut $info:expr,)
742+
) => {
743+
{
744+
multislice!(@check $view, $info, ($($mut_info)*));
745+
multislice!(@check $view, $info, ($($immut_info)*));
746+
($($sliced)* unsafe { $view.aliasing_view_mut() }.slice_move($info))
747+
}
748+
};
749+
// Parse last slice (read-only), with trailing comma.
750+
(
751+
@parse $view:expr,
752+
($($sliced:tt)*),
753+
($($mut_info:tt)*),
754+
($($immut_info:tt)*),
755+
($info:expr,)
756+
) => {
757+
{
758+
multislice!(@check $view, $info, ($($mut_info)*));
759+
($($sliced)* unsafe { $view.aliasing_view() }.slice_move($info))
760+
}
761+
};
762+
// Parse a mutable slice.
763+
(
764+
@parse $view:expr,
765+
($($sliced:tt)*),
766+
($($mut_info:tt)*),
767+
($($immut_info:tt)*),
768+
(mut $info:expr, $($t:tt)*)
769+
) => {
770+
{
771+
multislice!(@check $view, $info, ($($mut_info)*));
772+
multislice!(@check $view, $info, ($($immut_info)*));
773+
multislice!(
774+
@parse $view,
775+
($($sliced)* unsafe { $view.aliasing_view_mut() }.slice_move($info),),
776+
($($mut_info)* $info,),
777+
($($immut_info)*),
778+
($($t)*)
779+
)
780+
}
781+
};
782+
// Parse a read-only slice.
783+
(
784+
@parse $view:expr,
785+
($($sliced:tt)*),
786+
($($mut_info:tt)*),
787+
($($immut_info:tt)*),
788+
($info:expr, $($t:tt)*)
789+
) => {
790+
{
791+
multislice!(@check $view, $info, ($($mut_info)*));
792+
multislice!(
793+
@parse $view,
794+
($($sliced)* unsafe { $view.aliasing_view() }.slice_move($info),),
795+
($($mut_info)*),
796+
($($immut_info)* $info,),
797+
($($t)*)
798+
)
799+
}
800+
};
801+
// Entry point.
802+
($arr:expr, ($($t:tt)*)) => {
803+
{
804+
let view = $crate::ArrayBase::view_mut(&mut $arr);
805+
multislice!(@parse view, (), (), (), ($($t)*))
806+
}
807+
};
808+
);

tests/array.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,20 @@ use ndarray::{
1515
use ndarray::indices;
1616
use itertools::{enumerate, zip};
1717

18+
macro_rules! assert_panics {
19+
($body:expr) => {
20+
if let Ok(v) = ::std::panic::catch_unwind(|| $body) {
21+
panic!("assertion failed: should_panic; \
22+
non-panicking result: {:?}", v);
23+
}
24+
};
25+
($body:expr, $($arg:tt)*) => {
26+
if let Ok(_) = ::std::panic::catch_unwind(|| $body) {
27+
panic!($($arg)*);
28+
}
29+
};
30+
}
31+
1832
#[test]
1933
fn test_matmul_rcarray()
2034
{
@@ -233,6 +247,78 @@ fn test_slice_inplace_with_subview_inplace() {
233247
assert_eq!(vi, Array3::from_elem((1, 1, 1), arr[(1, 2, 3)]));
234248
}
235249

250+
#[test]
251+
fn test_multislice() {
252+
defmac!(test_multislice mut arr, s1, s2 => {
253+
{
254+
let copy = arr.clone();
255+
assert_eq!(
256+
multislice!(arr, (mut s1, mut s2)),
257+
(copy.clone().slice_mut(s1), copy.clone().slice_mut(s2))
258+
);
259+
}
260+
{
261+
let copy = arr.clone();
262+
assert_eq!(
263+
multislice!(arr, (mut s1, s2)),
264+
(copy.clone().slice_mut(s1), copy.clone().slice(s2))
265+
);
266+
}
267+
{
268+
let copy = arr.clone();
269+
assert_eq!(
270+
multislice!(arr, (s1, mut s2)),
271+
(copy.clone().slice(s1), copy.clone().slice_mut(s2))
272+
);
273+
}
274+
{
275+
let copy = arr.clone();
276+
assert_eq!(
277+
multislice!(arr, (s1, s2)),
278+
(copy.clone().slice(s1), copy.clone().slice(s2))
279+
);
280+
}
281+
});
282+
let mut arr = Array1::from_iter(0..48).into_shape((8, 6)).unwrap();
283+
284+
test_multislice!(&mut arr, s![0, ..], s![1, ..]);
285+
test_multislice!(&mut arr, s![0, ..], s![-1, ..]);
286+
test_multislice!(&mut arr, s![0, ..], s![1.., ..]);
287+
test_multislice!(&mut arr, s![1, ..], s![..;2, ..]);
288+
test_multislice!(&mut arr, s![..2, ..], s![2.., ..]);
289+
test_multislice!(&mut arr, s![1..;2, ..], s![..;2, ..]);
290+
test_multislice!(&mut arr, s![..;-2, ..], s![..;2, ..]);
291+
test_multislice!(&mut arr, s![..;12, ..], s![3..;3, ..]);
292+
}
293+
294+
#[test]
295+
fn test_multislice_intersecting() {
296+
assert_panics!({
297+
let mut arr = Array2::<u8>::zeros((8, 6));
298+
multislice!(arr, (mut s![3, ..], s![3, ..]));
299+
});
300+
assert_panics!({
301+
let mut arr = Array2::<u8>::zeros((8, 6));
302+
multislice!(arr, (mut s![3, ..], s![3.., ..]));
303+
});
304+
assert_panics!({
305+
let mut arr = Array2::<u8>::zeros((8, 6));
306+
multislice!(arr, (mut s![3, ..], s![..;3, ..]));
307+
});
308+
assert_panics!({
309+
let mut arr = Array2::<u8>::zeros((8, 6));
310+
multislice!(arr, (mut s![..;6, ..], s![3..;3, ..]));
311+
});
312+
assert_panics!({
313+
let mut arr = Array2::<u8>::zeros((8, 6));
314+
multislice!(arr, (mut s![2, ..], mut s![..-1;-2, ..]));
315+
});
316+
{
317+
let mut arr = Array2::<u8>::zeros((8, 6));
318+
multislice!(arr, (s![3, ..], s![-1..;-2, ..]));
319+
}
320+
}
321+
236322
#[should_panic]
237323
#[test]
238324
fn index_out_of_bounds() {

0 commit comments

Comments
 (0)