diff --git a/Cargo.toml b/Cargo.toml index 8de0064b4..cf8d5ec13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ bench = false test = true [dependencies] +num-integer = "0.1.39" num-traits = "0.2" num-complex = "0.2" itertools = { version = "0.7.0", default-features = false } diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index ee612efea..7fb3f76fa 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -6,9 +6,10 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use {Ix, Ixs}; +use {Ix, Ixs, Slice, SliceOrIndex}; use error::{from_kind, ErrorKind, ShapeError}; use itertools::izip; +use num_integer::div_floor; pub use self::dim::*; pub use self::axis::Axis; @@ -329,25 +330,18 @@ pub fn abs_index(len: Ix, index: Ixs) -> Ix { } } -/// Modify dimension, stride and return data pointer offset +/// Determines nonnegative start and end indices, and performs sanity checks. +/// +/// The return value is (start, end, step). /// /// **Panics** if stride is 0 or if any index is out of bounds. -pub fn do_slice( - dim: &mut Ix, - stride: &mut Ix, - start: Ixs, - end: Option, - step: Ixs, -) -> isize { - let mut offset = 0; - - let axis_len = *dim; +fn to_abs_slice(axis_len: usize, slice: Slice) -> (usize, usize, isize) { + let Slice { start, end, step } = slice; let start = abs_index(axis_len, start); - let mut end = abs_index(axis_len, end.unwrap_or(axis_len as Ixs)); + let mut end = abs_index(axis_len, end.unwrap_or(axis_len as isize)); if end < start { end = start; } - ndassert!( start <= axis_len, "Slice begin {} is past end of axis of length {}", @@ -360,15 +354,23 @@ pub fn do_slice( end, axis_len, ); + ndassert!(step != 0, "Slice stride must not be zero"); + (start, end, step) +} + +/// Modify dimension, stride and return data pointer offset +/// +/// **Panics** if stride is 0 or if any index is out of bounds. +pub fn do_slice(dim: &mut usize, stride: &mut usize, slice: Slice) -> isize { + let (start, end, step) = to_abs_slice(*dim, slice); let m = end - start; - // stride - let s = (*stride) as Ixs; + let s = (*stride) as isize; // Data pointer offset - offset += stride_offset(start, *stride); + let mut offset = stride_offset(start, *stride); // Adjust for strides - ndassert!(step != 0, "Slice stride must not be zero"); + // // How to implement negative strides: // // Increase start pointer by @@ -380,17 +382,210 @@ pub fn do_slice( let s_prim = s * step; - let d = m / step.abs() as Ix; - let r = m % step.abs() as Ix; + let d = m / step.abs() as usize; + let r = m % step.abs() as usize; let m_prim = d + if r > 0 { 1 } else { 0 }; // Update dimension and stride coordinate *dim = m_prim; - *stride = s_prim as Ix; + *stride = s_prim as usize; offset } +/// Solves `a * x + b * y = gcd(a, b)` for `x`, `y`, and `gcd(a, b)`. +/// +/// Returns `(g, (x, y))`, where `g` is `gcd(a, b)`, and `g` is always +/// nonnegative. +/// +/// See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm +fn extended_gcd(a: isize, b: isize) -> (isize, (isize, isize)) { + if a == 0 { + (b.abs(), (0, b.signum())) + } else if b == 0 { + (a.abs(), (a.signum(), 0)) + } else { + let mut r = (a, b); + let mut s = (1, 0); + let mut t = (0, 1); + while r.1 != 0 { + let q = r.0 / r.1; + r = (r.1, r.0 - q * r.1); + s = (s.1, s.0 - q * s.1); + t = (t.1, t.0 - q * t.1); + } + if r.0 > 0 { + (r.0, (s.0, t.0)) + } else { + (-r.0, (-s.0, -t.0)) + } + } +} + +/// Solves `a * x + b * y = c` for `x` where `a`, `b`, `c`, `x`, and `y` are +/// integers. +/// +/// If the return value is `Some((x0, xd))`, there is a solution. `xd` is +/// always positive. Solutions `x` are given by `x0 + xd * t` where `t` is any +/// integer. The value of `y` for any `x` is then `y = (c - a * x) / b`. +/// +/// If the return value is `None`, no solutions exist. +/// +/// **Note** `a` and `b` must be nonzero. +/// +/// See https://en.wikipedia.org/wiki/Diophantine_equation#One_equation +/// and https://math.stackexchange.com/questions/1656120#1656138 +fn solve_linear_diophantine_eq(a: isize, b: isize, c: isize) -> Option<(isize, isize)> { + debug_assert_ne!(a, 0); + debug_assert_ne!(b, 0); + let (g, (u, _)) = extended_gcd(a, b); + if c % g == 0 { + Some((c / g * u, (b / g).abs())) + } else { + None + } +} + +/// Returns `true` if two (finite length) arithmetic sequences intersect. +/// +/// `min*` and `max*` are the (inclusive) bounds of the sequences, and they +/// must be elements in the sequences. `step*` are the steps between +/// consecutive elements (the sign is irrelevant). +/// +/// **Note** `step1` and `step2` must be nonzero. +fn arith_seq_intersect( + (min1, max1, step1): (isize, isize, isize), + (min2, max2, step2): (isize, isize, isize), +) -> bool { + debug_assert!(max1 >= min1); + debug_assert!(max2 >= min2); + debug_assert_eq!((max1 - min1) % step1, 0); + debug_assert_eq!((max2 - min2) % step2, 0); + + // Handle the easy case where we don't have to solve anything. + if min1 > max2 || min2 > max1 { + false + } else { + // The sign doesn't matter semantically, and it's mathematically convenient + // for `step1` and `step2` to be positive. + let step1 = step1.abs(); + let step2 = step2.abs(); + // Ignoring the min/max bounds, the sequences are + // a(x) = min1 + step1 * x + // b(y) = min2 + step2 * y + // + // For intersections a(x) = b(y), we have: + // min1 + step1 * x = min2 + step2 * y + // ⇒ -step1 * x + step2 * y = min1 - min2 + // which is a linear Diophantine equation. + if let Some((x0, xd)) = solve_linear_diophantine_eq(-step1, step2, min1 - min2) { + // Minimum of [min1, max1] ∩ [min2, max2] + let min = ::std::cmp::max(min1, min2); + // Maximum of [min1, max1] ∩ [min2, max2] + let max = ::std::cmp::min(max1, max2); + // The potential intersections are + // a(x) = min1 + step1 * (x0 + xd * t) + // where `t` is any integer. + // + // There is an intersection in `[min, max]` if there exists an + // integer `t` such that + // min ≤ a(x) ≤ max + // ⇒ min ≤ min1 + step1 * (x0 + xd * t) ≤ max + // ⇒ min ≤ min1 + step1 * x0 + step1 * xd * t ≤ max + // ⇒ min - min1 - step1 * x0 ≤ (step1 * xd) * t ≤ max - min1 - step1 * x0 + // + // Therefore, the least possible intersection `a(x)` that is ≥ `min` has + // t = ⌈(min - min1 - step1 * x0) / (step1 * xd)⌉ + // If this `a(x) is also ≤ `max`, then there is an intersection in `[min, max]`. + // + // The greatest possible intersection `a(x)` that is ≤ `max` has + // t = ⌊(max - min1 - step1 * x0) / (step1 * xd)⌋ + // If this `a(x) is also ≥ `min`, then there is an intersection in `[min, max]`. + min1 + step1 * (x0 - xd * div_floor(min - min1 - step1 * x0, -step1 * xd)) <= max + || min1 + step1 * (x0 + xd * div_floor(max - min1 - step1 * x0, step1 * xd)) >= min + } else { + false + } + } +} + +/// Returns the minimum and maximum values of the indices (inclusive). +/// +/// If the slice is empty, then returns `None`, otherwise returns `Some((min, max))`. +fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> { + let (start, end, step) = to_abs_slice(axis_len, slice); + if start == end { + None + } else { + if step > 0 { + Some((start, end - 1 - (end - start - 1) % (step as usize))) + } else { + Some((start + (end - start - 1) % (-step as usize), end - 1)) + } + } +} + +/// Returns `true` iff the slices intersect. +#[doc(hidden)] +pub fn slices_intersect( + dim: &D, + indices1: &D::SliceArg, + indices2: &D::SliceArg, +) -> bool { + debug_assert_eq!(indices1.as_ref().len(), indices2.as_ref().len()); + for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) { + // The slices do not intersect iff any pair of `SliceOrIndex` does not intersect. + match (si1, si2) { + ( + SliceOrIndex::Slice { + start: start1, + end: end1, + step: step1, + }, + SliceOrIndex::Slice { + start: start2, + end: end2, + step: step2, + }, + ) => { + let (min1, max1) = match slice_min_max(axis_len, Slice::new(start1, end1, step1)) { + Some(m) => m, + None => return false, + }; + let (min2, max2) = match slice_min_max(axis_len, Slice::new(start2, end2, step2)) { + Some(m) => m, + None => return false, + }; + if !arith_seq_intersect( + (min1 as isize, max1 as isize, step1), + (min2 as isize, max2 as isize, step2), + ) { + return false; + } + } + (SliceOrIndex::Slice { start, end, step }, SliceOrIndex::Index(ind)) | + (SliceOrIndex::Index(ind), SliceOrIndex::Slice { start, end, step }) => { + let ind = abs_index(axis_len, ind); + let (min, max) = match slice_min_max(axis_len, Slice::new(start, end, step)) { + Some(m) => m, + None => return false, + }; + if ind < min || ind > max || (ind - min) % step.abs() as usize != 0 { + return false; + } + } + (SliceOrIndex::Index(ind1), SliceOrIndex::Index(ind2)) => { + let ind1 = abs_index(axis_len, ind1); + let ind2 = abs_index(axis_len, ind2); + if ind1 != ind2 { + return false; + } + } + } + } + true +} + pub fn merge_axes(dim: &mut D, strides: &mut D, take: Axis, into: Axis) -> bool where D: Dimension, { @@ -422,11 +617,15 @@ pub fn merge_axes(dim: &mut D, strides: &mut D, take: Axis, into: Axis) -> bo #[cfg(test)] mod test { use super::{ - can_index_slice, can_index_slice_not_custom, max_abs_offset_check_overflow, IntoDimension + arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd, + max_abs_offset_check_overflow, slice_min_max, slices_intersect, + solve_linear_diophantine_eq, IntoDimension }; use error::{from_kind, ErrorKind}; - use quickcheck::quickcheck; - use {Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn}; + use num_integer::gcd; + use quickcheck::{quickcheck, TestResult}; + use slice::{Slice, SliceOrIndex}; + use {Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn}; #[test] fn slice_indexing_uncommon_strides() { @@ -548,4 +747,151 @@ mod test { } } } + + quickcheck! { + fn extended_gcd_solves_eq(a: isize, b: isize) -> bool { + let (g, (x, y)) = extended_gcd(a, b); + a * x + b * y == g + } + + fn extended_gcd_correct_gcd(a: isize, b: isize) -> bool { + let (g, _) = extended_gcd(a, b); + g == gcd(a, b) + } + } + + #[test] + fn extended_gcd_zero() { + assert_eq!(extended_gcd(0, 0), (0, (0, 0))); + assert_eq!(extended_gcd(0, 5), (5, (0, 1))); + assert_eq!(extended_gcd(5, 0), (5, (1, 0))); + assert_eq!(extended_gcd(0, -5), (5, (0, -1))); + assert_eq!(extended_gcd(-5, 0), (5, (-1, 0))); + } + + quickcheck! { + fn solve_linear_diophantine_eq_solution_existence( + a: isize, b: isize, c: isize + ) -> TestResult { + if a == 0 || b == 0 { + TestResult::discard() + } else { + TestResult::from_bool( + (c % gcd(a, b) == 0) == solve_linear_diophantine_eq(a, b, c).is_some() + ) + } + } + + fn solve_linear_diophantine_eq_correct_solution( + a: isize, b: isize, c: isize, t: isize + ) -> TestResult { + if a == 0 || b == 0 { + TestResult::discard() + } else { + match solve_linear_diophantine_eq(a, b, c) { + Some((x0, xd)) => { + let x = x0 + xd * t; + let y = (c - a * x) / b; + TestResult::from_bool(a * x + b * y == c) + } + None => TestResult::discard(), + } + } + } + } + + quickcheck! { + fn arith_seq_intersect_correct( + first1: isize, len1: isize, step1: isize, + first2: isize, len2: isize, step2: isize + ) -> TestResult { + use std::cmp; + + if len1 == 0 || len2 == 0 { + // This case is impossible to reach in `arith_seq_intersect()` + // because the `min*` and `max*` arguments are inclusive. + return TestResult::discard(); + } + let len1 = len1.abs(); + let len2 = len2.abs(); + + // Convert to `min*` and `max*` arguments for `arith_seq_intersect()`. + let last1 = first1 + step1 * (len1 - 1); + let (min1, max1) = (cmp::min(first1, last1), cmp::max(first1, last1)); + let last2 = first2 + step2 * (len2 - 1); + let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2)); + + // Naively determine if the sequences intersect. + let seq1: Vec<_> = (0..len1) + .map(|n| first1 + step1 * n) + .collect(); + let intersects = (0..len2) + .map(|n| first2 + step2 * n) + .any(|elem2| seq1.contains(&elem2)); + + TestResult::from_bool( + arith_seq_intersect( + (min1, max1, if step1 == 0 { 1 } else { step1 }), + (min2, max2, if step2 == 0 { 1 } else { step2 }) + ) == intersects + ) + } + } + + #[test] + fn slice_min_max_empty() { + assert_eq!(slice_min_max(0, Slice::new(0, None, 3)), None); + assert_eq!(slice_min_max(10, Slice::new(1, Some(1), 3)), None); + assert_eq!(slice_min_max(10, Slice::new(-1, Some(-1), 3)), None); + assert_eq!(slice_min_max(10, Slice::new(1, Some(1), -3)), None); + assert_eq!(slice_min_max(10, Slice::new(-1, Some(-1), -3)), None); + } + + #[test] + fn slice_min_max_pos_step() { + assert_eq!(slice_min_max(10, Slice::new(1, Some(8), 3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(1, Some(9), 3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(-9, Some(8), 3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(-9, Some(9), 3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(1, Some(-2), 3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(1, Some(-1), 3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(-9, Some(-2), 3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(-9, Some(-1), 3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(1, None, 3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(-9, None, 3)), Some((1, 7))); + assert_eq!(slice_min_max(11, Slice::new(1, None, 3)), Some((1, 10))); + assert_eq!(slice_min_max(11, Slice::new(-10, None, 3)), Some((1, 10))); + } + + #[test] + fn slice_min_max_neg_step() { + assert_eq!(slice_min_max(10, Slice::new(1, Some(8), -3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(2, Some(8), -3)), Some((4, 7))); + assert_eq!(slice_min_max(10, Slice::new(-9, Some(8), -3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(-8, Some(8), -3)), Some((4, 7))); + assert_eq!(slice_min_max(10, Slice::new(1, Some(-2), -3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(2, Some(-2), -3)), Some((4, 7))); + assert_eq!(slice_min_max(10, Slice::new(-9, Some(-2), -3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(-8, Some(-2), -3)), Some((4, 7))); + assert_eq!(slice_min_max(9, Slice::new(2, None, -3)), Some((2, 8))); + assert_eq!(slice_min_max(9, Slice::new(-7, None, -3)), Some((2, 8))); + assert_eq!(slice_min_max(9, Slice::new(3, None, -3)), Some((5, 8))); + assert_eq!(slice_min_max(9, Slice::new(-6, None, -3)), Some((5, 8))); + } + + #[test] + fn slices_intersect_true() { + assert!(slices_intersect(&Dim([4, 5]), s![.., ..], s![.., ..])); + assert!(slices_intersect(&Dim([4, 5]), s![0, ..], s![0, ..])); + assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, ..])); + assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3])); + assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6])); + } + + #[test] + fn slices_intersect_false() { + assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;2, ..])); + assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;3, ..])); + assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6])); + } } diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 562517dae..8e2f42eb3 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -445,9 +445,7 @@ where let offset = do_slice( &mut self.dim.slice_mut()[axis.index()], &mut self.strides.slice_mut()[axis.index()], - indices.start, - indices.end, - indices.step, + indices, ); unsafe { self.ptr = self.ptr.offset(offset); diff --git a/src/lib.rs b/src/lib.rs index c56df3d02..7ec0e3a94 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -97,6 +97,7 @@ extern crate matrixmultiply; extern crate itertools; extern crate num_traits as libnum; extern crate num_complex; +extern crate num_integer; #[cfg(test)] extern crate quickcheck; @@ -113,6 +114,7 @@ pub use dimension::{ RemoveAxis, Axis, AxisDescription, + slices_intersect, }; pub use dimension::dim::*; @@ -120,7 +122,10 @@ pub use dimension::NdIndex; pub use dimension::IxDynImpl; pub use indexes::{indices, indices_of}; pub use error::{ShapeError, ErrorKind}; -pub use slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex}; +pub use slice::{ + deref_raw_view_mut_into_view_with_life, deref_raw_view_mut_into_view_mut_with_life, + life_of_view_mut, Slice, SliceInfo, SliceNextDim, SliceOrIndex +}; use iterators::Baseiter; use iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, LanesMut}; @@ -155,13 +160,11 @@ pub use data_traits::{ DataClone, }; -mod dimension; - mod free_functions; pub use free_functions::*; pub use iterators::iter; -mod slice; +#[macro_use] mod slice; mod layout; mod indexes; mod iterators; @@ -174,6 +177,8 @@ mod stacking; #[macro_use] mod zip; +mod dimension; + pub use zip::{ Zip, NdProducer, @@ -468,10 +473,13 @@ pub type Ixs = isize; /// [`.slice_move()`]: #method.slice_move /// [`.slice_collapse()`]: #method.slice_collapse /// +/// It's possible to take multiple simultaneous *mutable* slices with the +/// [`multislice!()`](macro.multislice!.html) macro. +/// /// ``` /// extern crate ndarray; /// -/// use ndarray::{arr2, arr3, s}; +/// use ndarray::{arr2, arr3, multislice, s}; /// /// fn main() { /// @@ -518,6 +526,20 @@ pub type Ixs = isize; /// [12, 11, 10]]); /// assert_eq!(f, g); /// assert_eq!(f.shape(), &[2, 3]); +/// +/// // Let's take two disjoint, mutable slices of a matrix with +/// // +/// // - One containing all the even-index columns in the matrix +/// // - One containing all the odd-index columns in the matrix +/// let mut h = arr2(&[[0, 1, 2, 3], +/// [4, 5, 6, 7]]); +/// let (s0, s1) = multislice!(h, mut [.., ..;2], mut [.., 1..;2]); +/// let i = arr2(&[[0, 2], +/// [4, 6]]); +/// let j = arr2(&[[1, 3], +/// [5, 7]]); +/// assert_eq!(s0, i); +/// assert_eq!(s1, j); /// } /// ``` /// diff --git a/src/slice.rs b/src/slice.rs index 791364e6b..a02925ba3 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -9,7 +9,7 @@ use error::{ShapeError, ErrorKind}; use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; use std::fmt; use std::marker::PhantomData; -use super::Dimension; +use crate::{ArrayView, ArrayViewMut, Dimension, RawArrayViewMut}; /// A slice (range with step size). /// @@ -623,3 +623,377 @@ macro_rules! s( &*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*] }; ); + +/// Returns a ZST representing the lifetime of the mutable view. +#[doc(hidden)] +pub fn life_of_view_mut<'a, A, D: Dimension>( + _view: &ArrayViewMut<'a, A, D> +) -> PhantomData<&'a mut A> { + PhantomData +} + +/// Derefs the raw mutable view into a view, using the given lifetime. +#[doc(hidden)] +pub unsafe fn deref_raw_view_mut_into_view_with_life<'a, A, D: Dimension>( + raw: RawArrayViewMut, + _life: PhantomData<&'a mut A>, +) -> ArrayView<'a, A, D> { + raw.deref_into_view() +} + +/// Derefs the raw mutable view into a mutable view, using the given lifetime. +#[doc(hidden)] +pub unsafe fn deref_raw_view_mut_into_view_mut_with_life<'a, A, D: Dimension>( + raw: RawArrayViewMut, + _life: PhantomData<&'a mut A>, +) -> ArrayViewMut<'a, A, D> { + raw.deref_into_view_mut() +} + +/// Take multiple slices simultaneously. +/// +/// This macro makes it possible to take multiple slices of the same array, as +/// long as Rust's aliasing rules are followed for *elements* in the slices. +/// For example, it's possible to take two disjoint, mutable slices of an +/// array, with one referencing the even-index elements and the other +/// referencing the odd-index elements. If you tried to achieve this by calling +/// `.slice_mut()` twice, the borrow checker would complain about mutably +/// borrowing the array twice (even though it's safe as long as the slices are +/// disjoint). +/// +/// The syntax is `multislice!(` *expression, pattern [, pattern [, …]]* `)`, +/// where *expression* evaluates to a mutable array, and each *pattern* is +/// either +/// +/// * `mut` *s-args-or-expr*: creates an `ArrayViewMut` or +/// * *s-args-or-expr*: creates an `ArrayView` +/// +/// where *s-args-or-expr* is either (1) arguments enclosed in `[]` to pass to +/// the [`s!`] macro to create a `&SliceInfo` instance or (2) an expression +/// that evaluates to a `&SliceInfo` instance. +/// +/// **Note** that this macro always mutably borrows the array even if there are +/// no `mut` patterns. If all you want to do is take read-only slices, you +/// don't need `multislice!()`; just call +/// [`.slice()`](struct.ArrayBase.html#method.slice) multiple times instead. +/// +/// `multislice!()` evaluates to a tuple of `ArrayView` and/or `ArrayViewMut` +/// instances. It checks Rust's aliasing rules: +/// +/// * An `ArrayViewMut` and `ArrayView` cannot reference the same element. +/// * Two `ArrayViewMut` cannot reference the same element. +/// * Two `ArrayView` can reference the same element. +/// +/// **Panics** at runtime if any of the aliasing rules is violated. +/// +/// See also [*Slicing*](struct.ArrayBase.html#slicing). +/// +/// # Examples +/// +/// In this example, there are two overlapping read-only slices, and two +/// disjoint mutable slices. Neither of the mutable slices intersects any of +/// the other slices. +/// +/// ``` +/// extern crate ndarray; +/// +/// use ndarray::multislice; +/// use ndarray::prelude::*; +/// +/// # fn main() { +/// let mut arr = Array1::from_iter(0..12); +/// let (a, b, c, d) = multislice!(arr, [0..5], mut [6..;2], [1..6], mut [7..;2]); +/// assert_eq!(a, array![0, 1, 2, 3, 4]); +/// assert_eq!(b, array![6, 8, 10]); +/// assert_eq!(c, array![1, 2, 3, 4, 5]); +/// assert_eq!(d, array![7, 9, 11]); +/// # } +/// ``` +/// +/// These examples panic because they don't follow the aliasing rules: +/// +/// * `ArrayViewMut` and `ArrayView` cannot reference the same element. +/// +/// ```should_panic +/// # extern crate ndarray; +/// # use ndarray::multislice; +/// # use ndarray::prelude::*; +/// # fn main() { +/// let mut arr = Array1::from_iter(0..12); +/// multislice!(arr, [0..5], mut [1..;2]); // panic! +/// # } +/// ``` +/// +/// * Two `ArrayViewMut` cannot reference the same element. +/// +/// ```should_panic +/// # extern crate ndarray; +/// # use ndarray::multislice; +/// # use ndarray::prelude::*; +/// # fn main() { +/// let mut arr = Array1::from_iter(0..12); +/// multislice!(arr, mut [0..5], mut [1..;2]); // panic! +/// # } +/// ``` +#[macro_export] +macro_rules! multislice( + (@check $view:expr, $info:expr, ()) => {}; + // Check that $info doesn't intersect $other. + (@check $view:expr, $info:expr, ($other:expr,)) => { + assert!( + !$crate::slices_intersect(&$view.raw_dim(), $info, $other), + "Slice {:?} must not intersect slice {:?}", $info, $other + ) + }; + // Check that $info doesn't intersect any of the other info in the tuple. + (@check $view:expr, $info:expr, ($other:expr, $($more:tt)*)) => { + { + $crate::multislice!(@check $view, $info, ($other,)); + $crate::multislice!(@check $view, $info, ($($more)*)); + } + }; + // Create the (mutable) slice. + (@slice $view:expr, $life:expr, mut $info:expr) => { + #[allow(unsafe_code)] + unsafe { + $crate::deref_raw_view_mut_into_view_mut_with_life( + $view.clone().slice_move($info), + $life, + ) + } + }; + // Create the (read-only) slice. + (@slice $view:expr, $life:expr, $info:expr) => { + #[allow(unsafe_code)] + unsafe { + $crate::deref_raw_view_mut_into_view_with_life( + $view.clone().slice_move($info), + $life, + ) + } + }; + // Parse last slice (mutable), no trailing comma, applying `s![]` macro. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + (mut [$($info:tt)*]) + ) => { + // Apply `s![]` macro to info. + $crate::multislice!( + @parse $view, $life, + ($($sliced)*), + ($($mut_info)*), + ($($immut_info)*), + (mut $crate::s![$($info)*],) + ) + }; + // Parse last slice (read-only), no trailing comma, applying `s![]` macro. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + ([$($info:tt)*]) + ) => { + // Apply `s![]` macro to info. + $crate::multislice!( + @parse $view, $life, + ($($sliced)*), + ($($mut_info)*), + ($($immut_info)*), + ($crate::s![$($info)*],) + ) + }; + // Parse last slice (mutable), with trailing comma, applying `s![]` macro. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + (mut [$($info:tt)*],) + ) => { + // Apply `s![]` macro to info. + $crate::multislice!( + @parse $view, $life, + ($($sliced)*), + ($($mut_info)*), + ($($immut_info)*), + (mut $crate::s![$($info)*],) + ) + }; + // Parse last slice (read-only), with trailing comma, applying `s![]` macro. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + ([$($info:tt)*],) + ) => { + // Apply `s![]` macro to info. + $crate::multislice!( + @parse $view, $life, + ($($sliced)*), + ($($mut_info)*), + ($($immut_info)*), + ($crate::s![$($info)*],) + ) + }; + // Parse a mutable slice, applying `s![]` macro. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + (mut [$($info:tt)*], $($t:tt)*) + ) => { + // Apply `s![]` macro to info. + $crate::multislice!( + @parse $view, $life, + ($($sliced)*), + ($($mut_info)*), + ($($immut_info)*), + (mut $crate::s![$($info)*], $($t)*) + ) + }; + // Parse a read-only slice, applying `s![]` macro. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + ([$($info:tt)*], $($t:tt)*) + ) => { + // Apply `s![]` macro to info. + $crate::multislice!( + @parse $view, $life, + ($($sliced)*), + ($($mut_info)*), + ($($immut_info)*), + ($crate::s![$($info)*], $($t)*) + ) + }; + // Parse last slice (mutable), no trailing comma. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + (mut $info:expr) + ) => { + // Add trailing comma. + $crate::multislice!( + @parse $view, $life, + ($($sliced)*), + ($($mut_info)*), + ($($immut_info)*), + (mut $info,) + ) + }; + // Parse last slice (read-only), no trailing comma. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + ($info:expr) + ) => { + // Add trailing comma. + $crate::multislice!( + @parse $view, $life, + ($($sliced)*), + ($($mut_info)*), + ($($immut_info)*), + ($info,) + ) + }; + // Parse last slice (mutable), with trailing comma. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + (mut $info:expr,) + ) => { + match $info { + info => { + // Check for overlap with all previous mutable and immutable slices. + $crate::multislice!(@check $view, info, ($($mut_info)*)); + $crate::multislice!(@check $view, info, ($($immut_info)*)); + ($($sliced)* $crate::multislice!(@slice $view, $life, mut info),) + } + } + }; + // Parse last slice (read-only), with trailing comma. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + ($info:expr,) + ) => { + match $info { + info => { + // Check for overlap with all previous mutable slices. + $crate::multislice!(@check $view, info, ($($mut_info)*)); + ($($sliced)* $crate::multislice!(@slice $view, $life, info),) + } + } + }; + // Parse a mutable slice. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + (mut $info:expr, $($t:tt)*) + ) => { + match $info { + info => { + // Check for overlap with all previous mutable and immutable slices. + $crate::multislice!(@check $view, info, ($($mut_info)*)); + $crate::multislice!(@check $view, info, ($($immut_info)*)); + $crate::multislice!( + @parse $view, $life, + ($($sliced)* $crate::multislice!(@slice $view, $life, mut info),), + ($($mut_info)* info,), + ($($immut_info)*), + ($($t)*) + ) + } + } + }; + // Parse a read-only slice. + ( + @parse $view:expr, $life:expr, + ($($sliced:tt)*), + ($($mut_info:tt)*), + ($($immut_info:tt)*), + ($info:expr, $($t:tt)*) + ) => { + match $info { + info => { + // Check for overlap with all previous mutable slices. + $crate::multislice!(@check $view, info, ($($mut_info)*)); + $crate::multislice!( + @parse $view, $life, + ($($sliced)* $crate::multislice!(@slice $view, $life, info),), + ($($mut_info)*), + ($($immut_info)* info,), + ($($t)*) + ) + } + } + }; + // Entry point. + ($arr:expr, $($t:tt)*) => { + { + let (life, raw_view) = { + let mut view = $crate::ArrayBase::view_mut(&mut $arr); + ($crate::life_of_view_mut(&view), view.raw_view_mut()) + }; + $crate::multislice!(@parse raw_view, life, (), (), (), ($($t)*)) + } + }; +); diff --git a/tests/array.rs b/tests/array.rs index e50442e4f..28e2e7fbc 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -9,11 +9,26 @@ use ndarray::prelude::*; use ndarray::{ rcarr2, arr3, + multislice, }; use ndarray::indices; use defmac::defmac; use itertools::{enumerate, zip}; +macro_rules! assert_panics { + ($body:expr) => { + if let Ok(v) = ::std::panic::catch_unwind(|| $body) { + panic!("assertion failed: should_panic; \ + non-panicking result: {:?}", v); + } + }; + ($body:expr, $($arg:tt)*) => { + if let Ok(_) = ::std::panic::catch_unwind(|| $body) { + panic!($($arg)*); + } + }; +} + #[test] fn test_matmul_arcarray() { @@ -314,6 +329,138 @@ fn test_slice_collapse_with_indices() { assert_eq!(vi, Array3::from_elem((1, 1, 1), elem)); } +#[test] +fn test_multislice() { + defmac!(test_multislice mut arr, s1, s2 => { + { + let copy = arr.clone(); + assert_eq!( + multislice!(arr, mut s1, mut s2,), + (copy.clone().slice_mut(s1), copy.clone().slice_mut(s2)) + ); + } + { + let copy = arr.clone(); + assert_eq!( + multislice!(arr, mut s1, s2,), + (copy.clone().slice_mut(s1), copy.clone().slice(s2)) + ); + } + { + let copy = arr.clone(); + assert_eq!( + multislice!(arr, s1, mut s2), + (copy.clone().slice(s1), copy.clone().slice_mut(s2)) + ); + } + { + let copy = arr.clone(); + assert_eq!( + multislice!(arr, s1, s2), + (copy.clone().slice(s1), copy.clone().slice(s2)) + ); + } + }); + let mut arr = Array1::from_iter(0..48).into_shape((8, 6)).unwrap(); + + assert_eq!((arr.clone().view(),), multislice!(arr, [.., ..])); + test_multislice!(&mut arr, s![0, ..], s![1, ..]); + test_multislice!(&mut arr, s![0, ..], s![-1, ..]); + test_multislice!(&mut arr, s![0, ..], s![1.., ..]); + test_multislice!(&mut arr, s![1, ..], s![..;2, ..]); + test_multislice!(&mut arr, s![..2, ..], s![2.., ..]); + test_multislice!(&mut arr, s![1..;2, ..], s![..;2, ..]); + test_multislice!(&mut arr, s![..;-2, ..], s![..;2, ..]); + test_multislice!(&mut arr, s![..;12, ..], s![3..;3, ..]); +} + +#[test] +fn test_multislice_intersecting() { + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + multislice!(arr, mut [3, ..], [3, ..]); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + multislice!(arr, mut [3, ..], [3.., ..]); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + multislice!(arr, mut [3, ..], [..;3, ..]); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + multislice!(arr, mut [..;6, ..], [3..;3, ..]); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + multislice!(arr, mut [2, ..], mut [..-1;-2, ..]); + }); + { + let mut arr = Array2::::zeros((8, 6)); + multislice!(arr, [3, ..], [-1..;-2, ..]); + } +} + +#[test] +fn test_multislice_eval_args_only_once() { + let mut arr = Array1::::zeros(10); + let mut eval_count = 0; + { + let mut slice = || { + eval_count += 1; + s![1..2].clone() + }; + multislice!(arr, mut &slice(), [3..4], [5..6]); + } + assert_eq!(eval_count, 1); + let mut eval_count = 0; + { + let mut slice = || { + eval_count += 1; + s![1..2].clone() + }; + multislice!(arr, [3..4], mut &slice(), [5..6]); + } + assert_eq!(eval_count, 1); + let mut eval_count = 0; + { + let mut slice = || { + eval_count += 1; + s![1..2].clone() + }; + multislice!(arr, [3..4], [5..6], mut &slice()); + } + assert_eq!(eval_count, 1); + let mut eval_count = 0; + { + let mut slice = || { + eval_count += 1; + s![1..2].clone() + }; + multislice!(arr, &slice(), mut [3..4], [5..6]); + } + assert_eq!(eval_count, 1); + let mut eval_count = 0; + { + let mut slice = || { + eval_count += 1; + s![1..2].clone() + }; + multislice!(arr, mut [3..4], &slice(), [5..6]); + } + assert_eq!(eval_count, 1); + let mut eval_count = 0; + { + let mut slice = || { + eval_count += 1; + s![1..2].clone() + }; + multislice!(arr, mut [3..4], [5..6], &slice()); + } + assert_eq!(eval_count, 1); +} + #[should_panic] #[test] fn index_out_of_bounds() {