Skip to content

Commit 77336f3

Browse files
committed
Add multislice! macro
1 parent 0b34a22 commit 77336f3

File tree

3 files changed

+316
-2
lines changed

3 files changed

+316
-2
lines changed

src/lib.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,12 @@ pub type Ixs = isize;
447447
/// [`.slice_move()`]: #method.slice_move
448448
/// [`.slice_inplace()`]: #method.slice_inplace
449449
///
450+
/// It's possible to take multiple simultaneous *mutable* slices with the
451+
/// [`multislice!()`](macro.multislice!.html) macro.
452+
///
450453
/// ```
451-
/// // import the s![] macro
452-
/// #[macro_use(s)]
454+
/// // import the multislice!() and s![] macros
455+
/// #[macro_use(multislice, s)]
453456
/// extern crate ndarray;
454457
///
455458
/// use ndarray::{arr2, arr3};
@@ -499,6 +502,20 @@ pub type Ixs = isize;
499502
/// [12, 11, 10]]);
500503
/// assert_eq!(f, g);
501504
/// assert_eq!(f.shape(), &[2, 3]);
505+
///
506+
/// // Let's take two disjoint, mutable slices of a matrix with
507+
/// //
508+
/// // - One containing all the even-index columns in the matrix
509+
/// // - One containing all the odd-index columns in the matrix
510+
/// let mut h = arr2(&[[0, 1, 2, 3],
511+
/// [4, 5, 6, 7]]);
512+
/// let (s0, s1) = multislice!(h, (mut s![.., ..;2], mut s![.., 1..;2]));
513+
/// let i = arr2(&[[0, 2],
514+
/// [4, 6]]);
515+
/// let j = arr2(&[[1, 3],
516+
/// [5, 7]]);
517+
/// assert_eq!(s0, i);
518+
/// assert_eq!(s1, j);
502519
/// }
503520
/// ```
504521
///

src/slice.rs

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,3 +582,214 @@ macro_rules! s(
582582
s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*]
583583
};
584584
);
585+
586+
/// Take multiple slices simultaneously.
587+
///
588+
/// This macro makes it possible to take multiple slices of the same array, as
589+
/// long as Rust's aliasing rules are followed for *elements* in the slices.
590+
/// For example, it's possible to take two disjoint, mutable slices of an
591+
/// array, with one referencing the even-index elements and the other
592+
/// referencing the odd-index elements. If you tried to achieve this by calling
593+
/// `.slice_mut()` twice, the borrow checker would complain about mutably
594+
/// borrowing the array twice (even though it's safe as long as the slices are
595+
/// disjoint).
596+
///
597+
/// The syntax is `multislice!(` *expression, (pattern [, pattern [, …]])* `)`,
598+
/// where *expression* evaluates to an `ArrayBase<S, D>` where `S: DataMut`,
599+
/// and `pattern` is one of the following:
600+
///
601+
/// * `mut expr`: creates an `ArrayViewMut`, where `expr` evaluates to a
602+
/// `&SliceInfo` instance used to slice the array.
603+
/// * `expr`: creates an `ArrayView`, where `expr` evaluates to a `&SliceInfo`
604+
/// instance used to slice the array.
605+
///
606+
/// **Note** that this macro always mutably borrows the array even if there are
607+
/// no `mut` patterns. If all you want to do is take read-only slices, you
608+
/// don't need `multislice!()`; just call
609+
/// [`.slice()`](struct.ArrayBase.html#method.slice) multiple times instead.
610+
///
611+
/// `multislice!()` follows Rust's aliasing rules:
612+
///
613+
/// * An `ArrayViewMut` and `ArrayView` cannot reference the same element.
614+
/// * Two `ArrayViewMut` cannot reference the same element.
615+
/// * Two `ArrayView` can reference the same element.
616+
///
617+
/// **Panics** at runtime if any of the aliasing rules is violated.
618+
///
619+
/// See also [*Slicing*](struct.ArrayBase.html#slicing).
620+
///
621+
/// # Examples
622+
///
623+
/// In this example, there are two overlapping read-only slices, and two
624+
/// disjoint mutable slices. Neither of the mutable slices intersects any of
625+
/// the other slices.
626+
///
627+
/// ```
628+
/// #[macro_use]
629+
/// extern crate ndarray;
630+
///
631+
/// use ndarray::prelude::*;
632+
///
633+
/// # fn main() {
634+
/// let mut arr = Array1::from_iter(0..12);
635+
/// let (a, b, c, d) = multislice!(arr, (s![0..5], mut s![6..;2], s![1..6], mut s![7..;2]));
636+
/// assert_eq!(a, array![0, 1, 2, 3, 4]);
637+
/// assert_eq!(b, array![6, 8, 10]);
638+
/// assert_eq!(c, array![1, 2, 3, 4, 5]);
639+
/// assert_eq!(d, array![7, 9, 11]);
640+
/// # }
641+
/// ```
642+
///
643+
/// These examples panic because they don't follow the aliasing rules:
644+
///
645+
/// * `ArrayViewMut` and `ArrayView` cannot reference the same element.
646+
///
647+
/// ```should_panic
648+
/// # #[macro_use] extern crate ndarray;
649+
/// # use ndarray::prelude::*;
650+
/// # fn main() {
651+
/// let mut arr = Array1::from_iter(0..12);
652+
/// multislice!(arr, (s![0..5], mut s![1..;2])); // panic!
653+
/// # }
654+
/// ```
655+
///
656+
/// * Two `ArrayViewMut` cannot reference the same element.
657+
///
658+
/// ```should_panic
659+
/// # #[macro_use] extern crate ndarray;
660+
/// # use ndarray::prelude::*;
661+
/// # fn main() {
662+
/// let mut arr = Array1::from_iter(0..12);
663+
/// multislice!(arr, (mut s![0..5], mut s![1..;2])); // panic!
664+
/// # }
665+
/// ```
666+
#[macro_export]
667+
macro_rules! multislice(
668+
(
669+
@check $view:expr,
670+
$info:expr,
671+
()
672+
) => {};
673+
// Check that $info doesn't intersect $other.
674+
(
675+
@check $view:expr,
676+
$info:expr,
677+
($other:expr,)
678+
) => {
679+
assert!(
680+
!$crate::slices_intersect(&$view.raw_dim(), $info, $other),
681+
"Slice {:?} must not intersect slice {:?}", $info, $other
682+
)
683+
};
684+
// Check that $info doesn't intersect any of the other info in the tuple.
685+
(
686+
@check $view:expr,
687+
$info:expr,
688+
($other:expr, $($more:tt)*)
689+
) => {
690+
{
691+
multislice!(@check $view, $info, ($other,));
692+
multislice!(@check $view, $info, ($($more)*));
693+
}
694+
};
695+
// Parse last slice (mutable), no trailing comma.
696+
(
697+
@parse $view:expr,
698+
($($sliced:tt)*),
699+
($($mut_info:tt)*),
700+
($($immut_info:tt)*),
701+
(mut $info:expr)
702+
) => {
703+
{
704+
multislice!(@check $view, $info, ($($mut_info)*));
705+
multislice!(@check $view, $info, ($($immut_info)*));
706+
($($sliced)* unsafe { $view.aliasing_view_mut() }.slice_move($info))
707+
}
708+
};
709+
// Parse last slice (read-only), no trailing comma.
710+
(
711+
@parse $view:expr,
712+
($($sliced:tt)*),
713+
($($mut_info:tt)*),
714+
($($immut_info:tt)*),
715+
($info:expr)
716+
) => {
717+
{
718+
multislice!(@check $view, $info, ($($mut_info)*));
719+
($($sliced)* unsafe { $view.aliasing_view() }.slice_move($info))
720+
}
721+
};
722+
// Parse last slice (mutable), with trailing comma.
723+
(
724+
@parse $view:expr,
725+
($($sliced:tt)*),
726+
($($mut_info:tt)*),
727+
($($immut_info:tt)*),
728+
(mut $info:expr,)
729+
) => {
730+
{
731+
multislice!(@check $view, $info, ($($mut_info)*));
732+
multislice!(@check $view, $info, ($($immut_info)*));
733+
($($sliced)* unsafe { $view.aliasing_view_mut() }.slice_move($info))
734+
}
735+
};
736+
// Parse last slice (read-only), with trailing comma.
737+
(
738+
@parse $view:expr,
739+
($($sliced:tt)*),
740+
($($mut_info:tt)*),
741+
($($immut_info:tt)*),
742+
($info:expr,)
743+
) => {
744+
{
745+
multislice!(@check $view, $info, ($($mut_info)*));
746+
($($sliced)* unsafe { $view.aliasing_view() }.slice_move($info))
747+
}
748+
};
749+
// Parse a mutable slice.
750+
(
751+
@parse $view:expr,
752+
($($sliced:tt)*),
753+
($($mut_info:tt)*),
754+
($($immut_info:tt)*),
755+
(mut $info:expr, $($t:tt)*)
756+
) => {
757+
{
758+
multislice!(@check $view, $info, ($($mut_info)*));
759+
multislice!(@check $view, $info, ($($immut_info)*));
760+
multislice!(
761+
@parse $view,
762+
($($sliced)* unsafe { $view.aliasing_view_mut() }.slice_move($info),),
763+
($($mut_info)* $info,),
764+
($($immut_info)*),
765+
($($t)*)
766+
)
767+
}
768+
};
769+
// Parse a read-only slice.
770+
(
771+
@parse $view:expr,
772+
($($sliced:tt)*),
773+
($($mut_info:tt)*),
774+
($($immut_info:tt)*),
775+
($info:expr, $($t:tt)*)
776+
) => {
777+
{
778+
multislice!(@check $view, $info, ($($mut_info)*));
779+
multislice!(
780+
@parse $view,
781+
($($sliced)* unsafe { $view.aliasing_view() }.slice_move($info),),
782+
($($mut_info)*),
783+
($($immut_info)* $info,),
784+
($($t)*)
785+
)
786+
}
787+
};
788+
// Entry point.
789+
($arr:expr, ($($t:tt)*)) => {
790+
{
791+
let view = $crate::ArrayBase::view_mut(&mut $arr);
792+
multislice!(@parse view, (), (), (), ($($t)*))
793+
}
794+
};
795+
);

tests/array.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,20 @@ use ndarray::{
1515
use ndarray::indices;
1616
use itertools::{enumerate, zip};
1717

18+
macro_rules! assert_panics {
19+
($body:expr) => {
20+
if let Ok(v) = ::std::panic::catch_unwind(|| $body) {
21+
panic!("assertion failed: should_panic; \
22+
non-panicking result: {:?}", v);
23+
}
24+
};
25+
($body:expr, $($arg:tt)*) => {
26+
if let Ok(_) = ::std::panic::catch_unwind(|| $body) {
27+
panic!($($arg)*);
28+
}
29+
};
30+
}
31+
1832
#[test]
1933
fn test_matmul_rcarray()
2034
{
@@ -233,6 +247,78 @@ fn test_slice_inplace_with_subview_inplace() {
233247
assert_eq!(vi, Array3::from_elem((1, 1, 1), arr[(1, 2, 3)]));
234248
}
235249

250+
#[test]
251+
fn test_multislice() {
252+
defmac!(test_multislice mut arr, s1, s2 => {
253+
{
254+
let copy = arr.clone();
255+
assert_eq!(
256+
multislice!(arr, (mut s1, mut s2)),
257+
(copy.clone().slice_mut(s1), copy.clone().slice_mut(s2))
258+
);
259+
}
260+
{
261+
let copy = arr.clone();
262+
assert_eq!(
263+
multislice!(arr, (mut s1, s2)),
264+
(copy.clone().slice_mut(s1), copy.clone().slice(s2))
265+
);
266+
}
267+
{
268+
let copy = arr.clone();
269+
assert_eq!(
270+
multislice!(arr, (s1, mut s2)),
271+
(copy.clone().slice(s1), copy.clone().slice_mut(s2))
272+
);
273+
}
274+
{
275+
let copy = arr.clone();
276+
assert_eq!(
277+
multislice!(arr, (s1, s2)),
278+
(copy.clone().slice(s1), copy.clone().slice(s2))
279+
);
280+
}
281+
});
282+
let mut arr = Array1::from_iter(0..48).into_shape((8, 6)).unwrap();
283+
284+
test_multislice!(&mut arr, s![0, ..], s![1, ..]);
285+
test_multislice!(&mut arr, s![0, ..], s![-1, ..]);
286+
test_multislice!(&mut arr, s![0, ..], s![1.., ..]);
287+
test_multislice!(&mut arr, s![1, ..], s![..;2, ..]);
288+
test_multislice!(&mut arr, s![..2, ..], s![2.., ..]);
289+
test_multislice!(&mut arr, s![1..;2, ..], s![..;2, ..]);
290+
test_multislice!(&mut arr, s![..;-2, ..], s![..;2, ..]);
291+
test_multislice!(&mut arr, s![..;12, ..], s![3..;3, ..]);
292+
}
293+
294+
#[test]
295+
fn test_multislice_intersecting() {
296+
assert_panics!({
297+
let mut arr = Array2::<u8>::zeros((8, 6));
298+
multislice!(arr, (mut s![3, ..], s![3, ..]));
299+
});
300+
assert_panics!({
301+
let mut arr = Array2::<u8>::zeros((8, 6));
302+
multislice!(arr, (mut s![3, ..], s![3.., ..]));
303+
});
304+
assert_panics!({
305+
let mut arr = Array2::<u8>::zeros((8, 6));
306+
multislice!(arr, (mut s![3, ..], s![..;3, ..]));
307+
});
308+
assert_panics!({
309+
let mut arr = Array2::<u8>::zeros((8, 6));
310+
multislice!(arr, (mut s![..;6, ..], s![3..;3, ..]));
311+
});
312+
assert_panics!({
313+
let mut arr = Array2::<u8>::zeros((8, 6));
314+
multislice!(arr, (mut s![2, ..], mut s![..-1;-2, ..]));
315+
});
316+
{
317+
let mut arr = Array2::<u8>::zeros((8, 6));
318+
multislice!(arr, (s![3, ..], s![-1..;-2, ..]));
319+
}
320+
}
321+
236322
#[should_panic]
237323
#[test]
238324
fn index_out_of_bounds() {

0 commit comments

Comments
 (0)