From 085a319ea9cd1e24f669f6e5758b109ddfbb7dec Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Wed, 3 Feb 2021 21:04:16 +0800 Subject: [PATCH 1/3] implement creating SliceArg from arbitrary Dimension --- blas-tests/tests/oper.rs | 2 +- src/dimension/dimension_trait.rs | 26 +++++- src/impl_methods.rs | 9 +- src/slice.rs | 149 ++++++++++++++++++++++--------- tests/array.rs | 32 ++++++- tests/oper.rs | 2 +- 6 files changed, 167 insertions(+), 53 deletions(-) diff --git a/blas-tests/tests/oper.rs b/blas-tests/tests/oper.rs index b7c2d769d..c3a409888 100644 --- a/blas-tests/tests/oper.rs +++ b/blas-tests/tests/oper.rs @@ -433,7 +433,7 @@ fn scaled_add_3() { { let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(SliceInfo::<_, IxDyn>::new(cslice).unwrap().as_ref()); + let c = c.slice(SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap().as_ref()); let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); answerv += &(beta * &c); diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index aa1f7f95a..3554a8bb6 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -9,6 +9,7 @@ use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; use std::ops::{Index, IndexMut}; +use std::convert::TryInto; use alloc::vec::Vec; use super::axes_of; @@ -63,7 +64,7 @@ pub trait Dimension: /// - and so on.. /// - For `IxDyn`: `[SliceOrIndex]` /// - /// The easiest way to create a `&SliceInfo` is using the + /// The easiest way to create a `&SliceInfo` is using the /// [`s![]`](macro.s!.html) macro. type SliceArg: ?Sized + AsRef<[SliceOrIndex]>; /// Pattern matching friendly form of the dimension value. @@ -77,7 +78,9 @@ pub trait Dimension: type Smaller: Dimension; /// Next larger dimension type Larger: Dimension + RemoveAxis; - + /// Convert index to &Self::SliceArg. Make sure that length of index + /// consists with Self::NDIM(if it exists). + fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg ; /// Returns the number of dimensions (number of axes). fn ndim(&self) -> usize; @@ -398,6 +401,10 @@ impl Dimension for Dim<[Ix; 0]> { type Pattern = (); type Smaller = Self; type Larger = Ix1; + #[inline] + fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { + index.try_into().unwrap() + } // empty product is 1 -> size is 1 #[inline] fn ndim(&self) -> usize { @@ -442,6 +449,9 @@ impl Dimension for Dim<[Ix; 1]> { type Pattern = Ix; type Smaller = Ix0; type Larger = Ix2; + fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { + index.try_into().unwrap() + } #[inline] fn ndim(&self) -> usize { 1 @@ -558,6 +568,9 @@ impl Dimension for Dim<[Ix; 2]> { type Pattern = (Ix, Ix); type Smaller = Ix1; type Larger = Ix3; + fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { + index.try_into().unwrap() + } #[inline] fn ndim(&self) -> usize { 2 @@ -715,6 +728,9 @@ impl Dimension for Dim<[Ix; 3]> { type Pattern = (Ix, Ix, Ix); type Smaller = Ix2; type Larger = Ix4; + fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { + index.try_into().unwrap() + } #[inline] fn ndim(&self) -> usize { 3 @@ -838,6 +854,9 @@ macro_rules! large_dim { type Pattern = $pattern; type Smaller = Dim<[Ix; $n - 1]>; type Larger = $larger; + fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { + index.try_into().unwrap() + } #[inline] fn ndim(&self) -> usize { $n } #[inline] @@ -889,6 +908,9 @@ impl Dimension for IxDyn { type Pattern = Self; type Smaller = Self; type Larger = Self; + fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { + index + } #[inline] fn ndim(&self) -> usize { self.ix().len() diff --git a/src/impl_methods.rs b/src/impl_methods.rs index f8c0ee919..26ef7f8a7 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -340,9 +340,10 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice(&self, info: &SliceInfo) -> ArrayView<'_, A, Do> + pub fn slice(&self, info: &SliceInfo) -> ArrayView<'_, A, Do> where Do: Dimension, + D2: Dimension, S: Data, { self.view().slice_move(info) @@ -358,9 +359,10 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice_mut(&mut self, info: &SliceInfo) -> ArrayViewMut<'_, A, Do> + pub fn slice_mut(&mut self, info: &SliceInfo) -> ArrayViewMut<'_, A, Do> where Do: Dimension, + D2: Dimension, S: DataMut, { self.view_mut().slice_move(info) @@ -409,9 +411,10 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice_move(mut self, info: &SliceInfo) -> ArrayBase + pub fn slice_move(mut self, info: &SliceInfo) -> ArrayBase where Do: Dimension, + D2: Dimension, { // Slice and collapse in-place without changing the number of dimensions. self.slice_collapse(&*info); diff --git a/src/slice.rs b/src/slice.rs index 86a2b0b8f..ed5228e44 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -7,7 +7,7 @@ // except according to those terms. use crate::dimension::slices_intersect; use crate::error::{ErrorKind, ShapeError}; -use crate::{ArrayViewMut, Dimension}; +use crate::{ArrayViewMut, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; use std::fmt; use std::marker::PhantomData; use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; @@ -70,7 +70,7 @@ impl Slice { /// A slice (range with step) or an index. /// /// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a -/// `&SliceInfo<[SliceOrIndex; n], D>`. +/// `&SliceInfo<[SliceOrIndex; n], Do, D>`. /// /// ## Examples /// @@ -285,13 +285,15 @@ impl_sliceorindex_from_index!(i32); /// [`.slice()`]: struct.ArrayBase.html#method.slice #[derive(Debug)] #[repr(C)] -pub struct SliceInfo { - out_dim: PhantomData, +pub struct SliceInfo { + out_dim: PhantomData, + in_dim: PhantomData, indices: T, } -impl Deref for SliceInfo +impl Deref for SliceInfo where + Do: Dimension, D: Dimension, { type Target = T; @@ -300,8 +302,9 @@ where } } -impl SliceInfo +impl SliceInfo where + Do: Dimension, D: Dimension, { /// Returns a new `SliceInfo` instance. @@ -309,46 +312,54 @@ where /// If you call this method, you are guaranteeing that `out_dim` is /// consistent with `indices`. #[doc(hidden)] - pub unsafe fn new_unchecked(indices: T, out_dim: PhantomData) -> SliceInfo { - SliceInfo { out_dim, indices } + pub unsafe fn new_unchecked(indices: T, out_dim: PhantomData, in_dim: PhantomData,) -> SliceInfo { + SliceInfo { out_dim, in_dim, indices } } } -impl SliceInfo +impl SliceInfo where T: AsRef<[SliceOrIndex]>, + Do: Dimension, D: Dimension, { /// Returns a new `SliceInfo` instance. /// - /// Errors if `D` is not consistent with `indices`. - pub fn new(indices: T) -> Result, ShapeError> { - if let Some(ndim) = D::NDIM { + /// Errors if `Do` or `D` is not consistent with `indices`. + pub fn new(indices: T) -> Result, ShapeError> { + if let Some(ndim) = Do::NDIM { if ndim != indices.as_ref().iter().filter(|s| s.is_slice()).count() { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } + if let Some(ndim) = D::NDIM { + if ndim != indices.as_ref().iter().count() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + } Ok(SliceInfo { out_dim: PhantomData, + in_dim: PhantomData, indices, }) } } -impl SliceInfo +impl SliceInfo where T: AsRef<[SliceOrIndex]>, + Do: Dimension, D: Dimension, { /// Returns the number of dimensions after calling /// [`.slice()`](struct.ArrayBase.html#method.slice) (including taking /// subviews). /// - /// If `D` is a fixed-size dimension type, then this is equivalent to - /// `D::NDIM.unwrap()`. Otherwise, the value is calculated by iterating + /// If `Do` is a fixed-size dimension type, then this is equivalent to + /// `Do::NDIM.unwrap()`. Otherwise, the value is calculated by iterating /// over the ranges/indices. pub fn out_ndim(&self) -> usize { - D::NDIM.unwrap_or_else(|| { + Do::NDIM.unwrap_or_else(|| { self.indices .as_ref() .iter() @@ -358,9 +369,10 @@ where } } -impl AsRef<[SliceOrIndex]> for SliceInfo +impl AsRef<[SliceOrIndex]> for SliceInfo where T: AsRef<[SliceOrIndex]>, + Do: Dimension, D: Dimension, { fn as_ref(&self) -> &[SliceOrIndex] { @@ -368,54 +380,98 @@ where } } -impl AsRef> for SliceInfo -where - T: AsRef<[SliceOrIndex]>, - D: Dimension, +impl AsRef> for SliceInfo + where + T: AsRef<[SliceOrIndex]>, + Do: Dimension, + D: Dimension, { - fn as_ref(&self) -> &SliceInfo<[SliceOrIndex], D> { + fn as_ref(&self) -> &SliceInfo { + let index = self.indices.as_ref(); + if let Some(dim) = D::NDIM { + debug_assert!(index.len() == dim); + } + let arg_ref = D::slice_arg_from(index); unsafe { // This is okay because the only non-zero-sized member of // `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>` // should have the same bitwise representation as // `&[SliceOrIndex]`. - &*(self.indices.as_ref() as *const [SliceOrIndex] - as *const SliceInfo<[SliceOrIndex], D>) + &*(arg_ref as *const D::SliceArg + as *const SliceInfo) + } + } +} + +macro_rules! asref_dyn { + ($dim:ty) => { + impl AsRef> for SliceInfo + where + T: AsRef<[SliceOrIndex]>, + Do: Dimension, + { + fn as_ref(&self) -> &SliceInfo<[SliceOrIndex], Do, IxDyn> { + let index = self.indices.as_ref(); + unsafe { + // This is okay because the only non-zero-sized member of + // `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>` + // should have the same bitwise representation as + // `&[SliceOrIndex]`. + &*(index as *const [SliceOrIndex] + as *const SliceInfo<[SliceOrIndex], Do, IxDyn>) + } } } + }; } +asref_dyn!(Ix0); +asref_dyn!(Ix1); +asref_dyn!(Ix2); +asref_dyn!(Ix3); +asref_dyn!(Ix4); +asref_dyn!(Ix5); +asref_dyn!(Ix6); + -impl Copy for SliceInfo +impl Copy for SliceInfo where T: Copy, + Do: Dimension, D: Dimension, { } -impl Clone for SliceInfo +impl Clone for SliceInfo where T: Clone, + Do: Dimension, D: Dimension, { fn clone(&self) -> Self { SliceInfo { out_dim: PhantomData, + in_dim: PhantomData, indices: self.indices.clone(), } } } #[doc(hidden)] -pub trait SliceNextDim { +pub trait SliceNextDim { fn next_dim(&self, _: PhantomData) -> PhantomData; + + fn next_dim_inc(&self, _: PhantomData) -> PhantomData; } macro_rules! impl_slicenextdim_equal { ($self:ty) => { - impl SliceNextDim for $self { + impl SliceNextDim for $self { fn next_dim(&self, _: PhantomData) -> PhantomData { PhantomData } + fn next_dim_inc(&self, _: PhantomData) -> PhantomData { + PhantomData + } } }; } @@ -425,10 +481,13 @@ impl_slicenextdim_equal!(i32); macro_rules! impl_slicenextdim_larger { (($($generics:tt)*), $self:ty) => { - impl SliceNextDim for $self { + impl SliceNextDim for $self { fn next_dim(&self, _: PhantomData) -> PhantomData { PhantomData } + fn next_dim_inc(&self, _: PhantomData) -> PhantomData { + PhantomData + } } } } @@ -534,49 +593,54 @@ impl_slicenextdim_larger!((), Slice); #[macro_export] macro_rules! s( // convert a..b;c into @convert(a..b, c), final item - (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr) => { + (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr;$s:expr) => { match $r { r => { let out_dim = $crate::SliceNextDim::next_dim(&r, $dim); + let in_dim = $crate::SliceNextDim::next_dim_inc(&r, $dim2); #[allow(unsafe_code)] unsafe { $crate::SliceInfo::new_unchecked( [$($stack)* $crate::s!(@convert r, $s)], out_dim, + in_dim, ) } } } }; // convert a..b into @convert(a..b), final item - (@parse $dim:expr, [$($stack:tt)*] $r:expr) => { + (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr) => { match $r { r => { let out_dim = $crate::SliceNextDim::next_dim(&r, $dim); + let in_dim = $crate::SliceNextDim::next_dim_inc(&r, $dim2); #[allow(unsafe_code)] unsafe { $crate::SliceInfo::new_unchecked( [$($stack)* $crate::s!(@convert r)], out_dim, + in_dim, ) } } } }; // convert a..b;c into @convert(a..b, c), final item, trailing comma - (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr ,) => { - $crate::s![@parse $dim, [$($stack)*] $r;$s] + (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr;$s:expr ,) => { + $crate::s![@parse $dim, $dim2, [$($stack)*] $r;$s] }; // convert a..b into @convert(a..b), final item, trailing comma - (@parse $dim:expr, [$($stack:tt)*] $r:expr ,) => { - $crate::s![@parse $dim, [$($stack)*] $r] + (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr ,) => { + $crate::s![@parse $dim, $dim2, [$($stack)*] $r] }; // convert a..b;c into @convert(a..b, c) - (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr, $($t:tt)*) => { + (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr;$s:expr, $($t:tt)*) => { match $r { r => { $crate::s![@parse $crate::SliceNextDim::next_dim(&r, $dim), + $crate::SliceNextDim::next_dim_inc(&r, $dim2), [$($stack)* $crate::s!(@convert r, $s),] $($t)* ] @@ -584,11 +648,12 @@ macro_rules! s( } }; // convert a..b into @convert(a..b) - (@parse $dim:expr, [$($stack:tt)*] $r:expr, $($t:tt)*) => { + (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr, $($t:tt)*) => { match $r { r => { $crate::s![@parse $crate::SliceNextDim::next_dim(&r, $dim), + $crate::SliceNextDim::next_dim_inc(&r, $dim2), [$($stack)* $crate::s!(@convert r),] $($t)* ] @@ -596,11 +661,11 @@ macro_rules! s( } }; // empty call, i.e. `s![]` - (@parse ::std::marker::PhantomData::<$crate::Ix0>, []) => { + (@parse ::std::marker::PhantomData::<$crate::Ix0>, ::std::marker::PhantomData::<$crate::Ix0>, []) => { { #[allow(unsafe_code)] unsafe { - $crate::SliceInfo::new_unchecked([], ::std::marker::PhantomData::<$crate::Ix0>) + $crate::SliceInfo::new_unchecked([], ::std::marker::PhantomData::<$crate::Ix0>, ::std::marker::PhantomData::<$crate::Ix0>) } } }; @@ -617,7 +682,7 @@ macro_rules! s( ($($t:tt)*) => { // The extra `*&` is a workaround for this compiler bug: // https://github.com/rust-lang/rust/issues/23014 - &*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*] + &*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*] }; ); @@ -650,7 +715,7 @@ where fn multi_slice_move(&self, _view: ArrayViewMut<'a, A, D>) -> Self::Output {} } -impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo,) +impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo,) where A: 'a, D: Dimension, @@ -668,7 +733,7 @@ macro_rules! impl_multislice_tuple { impl_multislice_tuple!(@def_impl ($($but_last,)* $last,), [$($but_last)*] $last); }; (@def_impl ($($all:ident,)*), [$($but_last:ident)*] $last:ident) => { - impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&SliceInfo,)*) + impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&SliceInfo,)*) where A: 'a, D: Dimension, diff --git a/tests/array.rs b/tests/array.rs index 6581e572e..37c96c287 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -217,7 +217,7 @@ fn test_slice_dyninput_array_fixed() { #[test] fn test_slice_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)); - let info = &SliceInfo::<_, IxDyn>::new([ + let info = &SliceInfo::<_, IxDyn, IxDyn>::new([ SliceOrIndex::from(1..), SliceOrIndex::from(1), SliceOrIndex::from(..).step_by(2), @@ -232,7 +232,7 @@ fn test_slice_array_dyn() { #[test] fn test_slice_dyninput_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, IxDyn>::new([ + let info = &SliceInfo::<_, IxDyn, IxDyn>::new([ SliceOrIndex::from(1..), SliceOrIndex::from(1), SliceOrIndex::from(..).step_by(2), @@ -247,7 +247,7 @@ fn test_slice_dyninput_array_dyn() { #[test] fn test_slice_dyninput_vec_fixed() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, Ix2>::new(vec![ + let info = &SliceInfo::<_, Ix2, Ix3>::new(vec![ SliceOrIndex::from(1..), SliceOrIndex::from(1), SliceOrIndex::from(..).step_by(2), @@ -262,7 +262,7 @@ fn test_slice_dyninput_vec_fixed() { #[test] fn test_slice_dyninput_vec_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, IxDyn>::new(vec![ + let info = &SliceInfo::<_, IxDyn, IxDyn>::new(vec![ SliceOrIndex::from(1..), SliceOrIndex::from(1), SliceOrIndex::from(..).step_by(2), @@ -274,6 +274,30 @@ fn test_slice_dyninput_vec_dyn() { arr.view().slice_collapse(info.as_ref()); } +#[test] +fn test_slice_arg() { + fn use_arg_map(shape: Sh, f: impl Fn(&usize) -> SliceOrIndex, shape2: D) + where + Sh: ShapeBuilder, + D: Dimension, + { + let shape = shape.into_shape(); + let mut x = Array::from_elem(shape, 0); + let indices = x.shape().iter().map(f).collect::>(); + let s = x.slice_mut( + SliceInfo::<_, Sh::Dim, Sh::Dim>::new(indices) + .unwrap() + .as_ref(), + ); + let s2 = shape2.slice(); + assert_eq!(s.shape(), s2) + } + use_arg_map(0,|x| SliceOrIndex::from(*x/2..*x),Dim([0])); + use_arg_map((2, 4, 8),|x| SliceOrIndex::from(*x/2..*x),Dim([1, 2, 4])); + use_arg_map(vec![3, 6, 9],|x| SliceOrIndex::from(*x/3..*x/2),Dim([0, 1, 1])); + use_arg_map(vec![1, 2, 3, 4, 5, 6, 7], |x| SliceOrIndex::from(x-1), Dim([])); +} + #[test] fn test_slice_with_subview() { let mut arr = ArcArray::::zeros((3, 5, 4)); diff --git a/tests/oper.rs b/tests/oper.rs index 66c24c7f8..adcb6e7c7 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -597,7 +597,7 @@ fn scaled_add_3() { { let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(SliceInfo::<_, IxDyn>::new(cslice).unwrap().as_ref()); + let c = c.slice(SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap().as_ref()); let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); answerv += &(beta * &c); From 9d76f13a11d80f469baabc27ac5a3d007fab0bd1 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Thu, 4 Feb 2021 20:13:35 +0800 Subject: [PATCH 2/3] use slice_info_from in dimension_trait.rs instead --- blas-tests/tests/oper.rs | 2 +- src/dimension/dimension_trait.rs | 30 ++++++- src/impl_methods.rs | 9 +- src/slice.rs | 149 +++++++++---------------------- tests/array.rs | 12 ++- tests/oper.rs | 2 +- 6 files changed, 80 insertions(+), 124 deletions(-) diff --git a/blas-tests/tests/oper.rs b/blas-tests/tests/oper.rs index c3a409888..b7c2d769d 100644 --- a/blas-tests/tests/oper.rs +++ b/blas-tests/tests/oper.rs @@ -433,7 +433,7 @@ fn scaled_add_3() { { let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap().as_ref()); + let c = c.slice(SliceInfo::<_, IxDyn>::new(cslice).unwrap().as_ref()); let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); answerv += &(beta * &c); diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 3554a8bb6..3c705311a 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -16,7 +16,7 @@ use super::axes_of; use super::conversion::Convert; use super::{stride_offset, stride_offset_checked}; use crate::itertools::{enumerate, zip}; -use crate::Axis; +use crate::{Axis, SliceInfo}; use crate::IntoDimension; use crate::RemoveAxis; use crate::{ArrayView1, ArrayViewMut1}; @@ -64,7 +64,7 @@ pub trait Dimension: /// - and so on.. /// - For `IxDyn`: `[SliceOrIndex]` /// - /// The easiest way to create a `&SliceInfo` is using the + /// The easiest way to create a `&SliceInfo` is using the /// [`s![]`](macro.s!.html) macro. type SliceArg: ?Sized + AsRef<[SliceOrIndex]>; /// Pattern matching friendly form of the dimension value. @@ -78,9 +78,35 @@ pub trait Dimension: type Smaller: Dimension; /// Next larger dimension type Larger: Dimension + RemoveAxis; + /// Convert index to &Self::SliceArg. Make sure that length of index /// consists with Self::NDIM(if it exists). + /// + /// Panics if conversion failed. + #[doc(hidden)] fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg ; + + /// Convert &SliceInfo, Do> to &SliceInfo. + /// Generate SliceArg of any dimension via this method. + /// + /// Panics if conversion failed. + #[doc(hidden)] + fn slice_info_from(indices: &T) -> &SliceInfo + where + T: AsRef<[SliceOrIndex]>, + Do: Dimension, + { + let arg_ref = Self::slice_arg_from(indices.as_ref()); + unsafe { + // This is okay because the only non-zero-sized member of + // `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>` + // should have the same bitwise representation as + // `&[SliceOrIndex]`. + &*(arg_ref as *const Self::SliceArg + as *const SliceInfo) + } + } + /// Returns the number of dimensions (number of axes). fn ndim(&self) -> usize; diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 26ef7f8a7..f8c0ee919 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -340,10 +340,9 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice(&self, info: &SliceInfo) -> ArrayView<'_, A, Do> + pub fn slice(&self, info: &SliceInfo) -> ArrayView<'_, A, Do> where Do: Dimension, - D2: Dimension, S: Data, { self.view().slice_move(info) @@ -359,10 +358,9 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice_mut(&mut self, info: &SliceInfo) -> ArrayViewMut<'_, A, Do> + pub fn slice_mut(&mut self, info: &SliceInfo) -> ArrayViewMut<'_, A, Do> where Do: Dimension, - D2: Dimension, S: DataMut, { self.view_mut().slice_move(info) @@ -411,10 +409,9 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice_move(mut self, info: &SliceInfo) -> ArrayBase + pub fn slice_move(mut self, info: &SliceInfo) -> ArrayBase where Do: Dimension, - D2: Dimension, { // Slice and collapse in-place without changing the number of dimensions. self.slice_collapse(&*info); diff --git a/src/slice.rs b/src/slice.rs index ed5228e44..86a2b0b8f 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -7,7 +7,7 @@ // except according to those terms. use crate::dimension::slices_intersect; use crate::error::{ErrorKind, ShapeError}; -use crate::{ArrayViewMut, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; +use crate::{ArrayViewMut, Dimension}; use std::fmt; use std::marker::PhantomData; use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; @@ -70,7 +70,7 @@ impl Slice { /// A slice (range with step) or an index. /// /// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a -/// `&SliceInfo<[SliceOrIndex; n], Do, D>`. +/// `&SliceInfo<[SliceOrIndex; n], D>`. /// /// ## Examples /// @@ -285,15 +285,13 @@ impl_sliceorindex_from_index!(i32); /// [`.slice()`]: struct.ArrayBase.html#method.slice #[derive(Debug)] #[repr(C)] -pub struct SliceInfo { - out_dim: PhantomData, - in_dim: PhantomData, +pub struct SliceInfo { + out_dim: PhantomData, indices: T, } -impl Deref for SliceInfo +impl Deref for SliceInfo where - Do: Dimension, D: Dimension, { type Target = T; @@ -302,9 +300,8 @@ where } } -impl SliceInfo +impl SliceInfo where - Do: Dimension, D: Dimension, { /// Returns a new `SliceInfo` instance. @@ -312,54 +309,46 @@ where /// If you call this method, you are guaranteeing that `out_dim` is /// consistent with `indices`. #[doc(hidden)] - pub unsafe fn new_unchecked(indices: T, out_dim: PhantomData, in_dim: PhantomData,) -> SliceInfo { - SliceInfo { out_dim, in_dim, indices } + pub unsafe fn new_unchecked(indices: T, out_dim: PhantomData) -> SliceInfo { + SliceInfo { out_dim, indices } } } -impl SliceInfo +impl SliceInfo where T: AsRef<[SliceOrIndex]>, - Do: Dimension, D: Dimension, { /// Returns a new `SliceInfo` instance. /// - /// Errors if `Do` or `D` is not consistent with `indices`. - pub fn new(indices: T) -> Result, ShapeError> { - if let Some(ndim) = Do::NDIM { - if ndim != indices.as_ref().iter().filter(|s| s.is_slice()).count() { - return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); - } - } + /// Errors if `D` is not consistent with `indices`. + pub fn new(indices: T) -> Result, ShapeError> { if let Some(ndim) = D::NDIM { - if ndim != indices.as_ref().iter().count() { + if ndim != indices.as_ref().iter().filter(|s| s.is_slice()).count() { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } Ok(SliceInfo { out_dim: PhantomData, - in_dim: PhantomData, indices, }) } } -impl SliceInfo +impl SliceInfo where T: AsRef<[SliceOrIndex]>, - Do: Dimension, D: Dimension, { /// Returns the number of dimensions after calling /// [`.slice()`](struct.ArrayBase.html#method.slice) (including taking /// subviews). /// - /// If `Do` is a fixed-size dimension type, then this is equivalent to - /// `Do::NDIM.unwrap()`. Otherwise, the value is calculated by iterating + /// If `D` is a fixed-size dimension type, then this is equivalent to + /// `D::NDIM.unwrap()`. Otherwise, the value is calculated by iterating /// over the ranges/indices. pub fn out_ndim(&self) -> usize { - Do::NDIM.unwrap_or_else(|| { + D::NDIM.unwrap_or_else(|| { self.indices .as_ref() .iter() @@ -369,10 +358,9 @@ where } } -impl AsRef<[SliceOrIndex]> for SliceInfo +impl AsRef<[SliceOrIndex]> for SliceInfo where T: AsRef<[SliceOrIndex]>, - Do: Dimension, D: Dimension, { fn as_ref(&self) -> &[SliceOrIndex] { @@ -380,98 +368,54 @@ where } } -impl AsRef> for SliceInfo - where - T: AsRef<[SliceOrIndex]>, - Do: Dimension, - D: Dimension, +impl AsRef> for SliceInfo +where + T: AsRef<[SliceOrIndex]>, + D: Dimension, { - fn as_ref(&self) -> &SliceInfo { - let index = self.indices.as_ref(); - if let Some(dim) = D::NDIM { - debug_assert!(index.len() == dim); - } - let arg_ref = D::slice_arg_from(index); + fn as_ref(&self) -> &SliceInfo<[SliceOrIndex], D> { unsafe { // This is okay because the only non-zero-sized member of // `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>` // should have the same bitwise representation as // `&[SliceOrIndex]`. - &*(arg_ref as *const D::SliceArg - as *const SliceInfo) - } - } -} - -macro_rules! asref_dyn { - ($dim:ty) => { - impl AsRef> for SliceInfo - where - T: AsRef<[SliceOrIndex]>, - Do: Dimension, - { - fn as_ref(&self) -> &SliceInfo<[SliceOrIndex], Do, IxDyn> { - let index = self.indices.as_ref(); - unsafe { - // This is okay because the only non-zero-sized member of - // `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>` - // should have the same bitwise representation as - // `&[SliceOrIndex]`. - &*(index as *const [SliceOrIndex] - as *const SliceInfo<[SliceOrIndex], Do, IxDyn>) - } + &*(self.indices.as_ref() as *const [SliceOrIndex] + as *const SliceInfo<[SliceOrIndex], D>) } } - }; } -asref_dyn!(Ix0); -asref_dyn!(Ix1); -asref_dyn!(Ix2); -asref_dyn!(Ix3); -asref_dyn!(Ix4); -asref_dyn!(Ix5); -asref_dyn!(Ix6); - -impl Copy for SliceInfo +impl Copy for SliceInfo where T: Copy, - Do: Dimension, D: Dimension, { } -impl Clone for SliceInfo +impl Clone for SliceInfo where T: Clone, - Do: Dimension, D: Dimension, { fn clone(&self) -> Self { SliceInfo { out_dim: PhantomData, - in_dim: PhantomData, indices: self.indices.clone(), } } } #[doc(hidden)] -pub trait SliceNextDim { +pub trait SliceNextDim { fn next_dim(&self, _: PhantomData) -> PhantomData; - - fn next_dim_inc(&self, _: PhantomData) -> PhantomData; } macro_rules! impl_slicenextdim_equal { ($self:ty) => { - impl SliceNextDim for $self { + impl SliceNextDim for $self { fn next_dim(&self, _: PhantomData) -> PhantomData { PhantomData } - fn next_dim_inc(&self, _: PhantomData) -> PhantomData { - PhantomData - } } }; } @@ -481,13 +425,10 @@ impl_slicenextdim_equal!(i32); macro_rules! impl_slicenextdim_larger { (($($generics:tt)*), $self:ty) => { - impl SliceNextDim for $self { + impl SliceNextDim for $self { fn next_dim(&self, _: PhantomData) -> PhantomData { PhantomData } - fn next_dim_inc(&self, _: PhantomData) -> PhantomData { - PhantomData - } } } } @@ -593,54 +534,49 @@ impl_slicenextdim_larger!((), Slice); #[macro_export] macro_rules! s( // convert a..b;c into @convert(a..b, c), final item - (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr;$s:expr) => { + (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr) => { match $r { r => { let out_dim = $crate::SliceNextDim::next_dim(&r, $dim); - let in_dim = $crate::SliceNextDim::next_dim_inc(&r, $dim2); #[allow(unsafe_code)] unsafe { $crate::SliceInfo::new_unchecked( [$($stack)* $crate::s!(@convert r, $s)], out_dim, - in_dim, ) } } } }; // convert a..b into @convert(a..b), final item - (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr) => { + (@parse $dim:expr, [$($stack:tt)*] $r:expr) => { match $r { r => { let out_dim = $crate::SliceNextDim::next_dim(&r, $dim); - let in_dim = $crate::SliceNextDim::next_dim_inc(&r, $dim2); #[allow(unsafe_code)] unsafe { $crate::SliceInfo::new_unchecked( [$($stack)* $crate::s!(@convert r)], out_dim, - in_dim, ) } } } }; // convert a..b;c into @convert(a..b, c), final item, trailing comma - (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr;$s:expr ,) => { - $crate::s![@parse $dim, $dim2, [$($stack)*] $r;$s] + (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr ,) => { + $crate::s![@parse $dim, [$($stack)*] $r;$s] }; // convert a..b into @convert(a..b), final item, trailing comma - (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr ,) => { - $crate::s![@parse $dim, $dim2, [$($stack)*] $r] + (@parse $dim:expr, [$($stack:tt)*] $r:expr ,) => { + $crate::s![@parse $dim, [$($stack)*] $r] }; // convert a..b;c into @convert(a..b, c) - (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr;$s:expr, $($t:tt)*) => { + (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr, $($t:tt)*) => { match $r { r => { $crate::s![@parse $crate::SliceNextDim::next_dim(&r, $dim), - $crate::SliceNextDim::next_dim_inc(&r, $dim2), [$($stack)* $crate::s!(@convert r, $s),] $($t)* ] @@ -648,12 +584,11 @@ macro_rules! s( } }; // convert a..b into @convert(a..b) - (@parse $dim:expr, $dim2:expr, [$($stack:tt)*] $r:expr, $($t:tt)*) => { + (@parse $dim:expr, [$($stack:tt)*] $r:expr, $($t:tt)*) => { match $r { r => { $crate::s![@parse $crate::SliceNextDim::next_dim(&r, $dim), - $crate::SliceNextDim::next_dim_inc(&r, $dim2), [$($stack)* $crate::s!(@convert r),] $($t)* ] @@ -661,11 +596,11 @@ macro_rules! s( } }; // empty call, i.e. `s![]` - (@parse ::std::marker::PhantomData::<$crate::Ix0>, ::std::marker::PhantomData::<$crate::Ix0>, []) => { + (@parse ::std::marker::PhantomData::<$crate::Ix0>, []) => { { #[allow(unsafe_code)] unsafe { - $crate::SliceInfo::new_unchecked([], ::std::marker::PhantomData::<$crate::Ix0>, ::std::marker::PhantomData::<$crate::Ix0>) + $crate::SliceInfo::new_unchecked([], ::std::marker::PhantomData::<$crate::Ix0>) } } }; @@ -682,7 +617,7 @@ macro_rules! s( ($($t:tt)*) => { // The extra `*&` is a workaround for this compiler bug: // https://github.com/rust-lang/rust/issues/23014 - &*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*] + &*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*] }; ); @@ -715,7 +650,7 @@ where fn multi_slice_move(&self, _view: ArrayViewMut<'a, A, D>) -> Self::Output {} } -impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo,) +impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo,) where A: 'a, D: Dimension, @@ -733,7 +668,7 @@ macro_rules! impl_multislice_tuple { impl_multislice_tuple!(@def_impl ($($but_last,)* $last,), [$($but_last)*] $last); }; (@def_impl ($($all:ident,)*), [$($but_last:ident)*] $last:ident) => { - impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&SliceInfo,)*) + impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&SliceInfo,)*) where A: 'a, D: Dimension, diff --git a/tests/array.rs b/tests/array.rs index 37c96c287..b398f9e37 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -217,7 +217,7 @@ fn test_slice_dyninput_array_fixed() { #[test] fn test_slice_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)); - let info = &SliceInfo::<_, IxDyn, IxDyn>::new([ + let info = &SliceInfo::<_, IxDyn>::new([ SliceOrIndex::from(1..), SliceOrIndex::from(1), SliceOrIndex::from(..).step_by(2), @@ -232,7 +232,7 @@ fn test_slice_array_dyn() { #[test] fn test_slice_dyninput_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, IxDyn, IxDyn>::new([ + let info = &SliceInfo::<_, IxDyn>::new([ SliceOrIndex::from(1..), SliceOrIndex::from(1), SliceOrIndex::from(..).step_by(2), @@ -247,7 +247,7 @@ fn test_slice_dyninput_array_dyn() { #[test] fn test_slice_dyninput_vec_fixed() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, Ix2, Ix3>::new(vec![ + let info = &SliceInfo::<_, Ix2>::new(vec![ SliceOrIndex::from(1..), SliceOrIndex::from(1), SliceOrIndex::from(..).step_by(2), @@ -262,7 +262,7 @@ fn test_slice_dyninput_vec_fixed() { #[test] fn test_slice_dyninput_vec_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, IxDyn, IxDyn>::new(vec![ + let info = &SliceInfo::<_, IxDyn>::new(vec![ SliceOrIndex::from(1..), SliceOrIndex::from(1), SliceOrIndex::from(..).step_by(2), @@ -285,9 +285,7 @@ fn test_slice_arg() { let mut x = Array::from_elem(shape, 0); let indices = x.shape().iter().map(f).collect::>(); let s = x.slice_mut( - SliceInfo::<_, Sh::Dim, Sh::Dim>::new(indices) - .unwrap() - .as_ref(), + ::slice_info_from::<_, Sh::Dim>(&indices) ); let s2 = shape2.slice(); assert_eq!(s.shape(), s2) diff --git a/tests/oper.rs b/tests/oper.rs index adcb6e7c7..66c24c7f8 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -597,7 +597,7 @@ fn scaled_add_3() { { let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap().as_ref()); + let c = c.slice(SliceInfo::<_, IxDyn>::new(cslice).unwrap().as_ref()); let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); answerv += &(beta * &c); From 03c379165477b99b08b38f221d96c03c6519793d Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Fri, 5 Feb 2021 21:49:05 +0800 Subject: [PATCH 3/3] Transfer function into SliceInfo; Use Result as the return value --- src/dimension/dimension_trait.rs | 96 ++++++++++++++++---------------- src/slice.rs | 17 ++++++ tests/array.rs | 8 +-- 3 files changed, 68 insertions(+), 53 deletions(-) diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 3c705311a..c3534a7a6 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -16,7 +16,7 @@ use super::axes_of; use super::conversion::Convert; use super::{stride_offset, stride_offset_checked}; use crate::itertools::{enumerate, zip}; -use crate::{Axis, SliceInfo}; +use crate::{Axis, ShapeError, ErrorKind}; use crate::IntoDimension; use crate::RemoveAxis; use crate::{ArrayView1, ArrayViewMut1}; @@ -78,35 +78,6 @@ pub trait Dimension: type Smaller: Dimension; /// Next larger dimension type Larger: Dimension + RemoveAxis; - - /// Convert index to &Self::SliceArg. Make sure that length of index - /// consists with Self::NDIM(if it exists). - /// - /// Panics if conversion failed. - #[doc(hidden)] - fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg ; - - /// Convert &SliceInfo, Do> to &SliceInfo. - /// Generate SliceArg of any dimension via this method. - /// - /// Panics if conversion failed. - #[doc(hidden)] - fn slice_info_from(indices: &T) -> &SliceInfo - where - T: AsRef<[SliceOrIndex]>, - Do: Dimension, - { - let arg_ref = Self::slice_arg_from(indices.as_ref()); - unsafe { - // This is okay because the only non-zero-sized member of - // `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>` - // should have the same bitwise representation as - // `&[SliceOrIndex]`. - &*(arg_ref as *const Self::SliceArg - as *const SliceInfo) - } - } - /// Returns the number of dimensions (number of axes). fn ndim(&self) -> usize; @@ -404,6 +375,11 @@ pub trait Dimension: #[doc(hidden)] fn try_remove_axis(&self, axis: Axis) -> Self::Smaller; + /// Convert index to &Self::SliceArg. Return ShapeError if the length of index + /// doesn't consist with Self::NDIM(if it exists). + #[doc(hidden)] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError>; + private_decl! {} } @@ -427,10 +403,6 @@ impl Dimension for Dim<[Ix; 0]> { type Pattern = (); type Smaller = Self; type Larger = Ix1; - #[inline] - fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { - index.try_into().unwrap() - } // empty product is 1 -> size is 1 #[inline] fn ndim(&self) -> usize { @@ -465,6 +437,13 @@ impl Dimension for Dim<[Ix; 0]> { fn try_remove_axis(&self, _ignore: Axis) -> Self::Smaller { *self } + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + match index.as_ref().try_into() { + Ok(arg) => Ok(arg), + Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } + } private_impl! {} } @@ -475,9 +454,6 @@ impl Dimension for Dim<[Ix; 1]> { type Pattern = Ix; type Smaller = Ix0; type Larger = Ix2; - fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { - index.try_into().unwrap() - } #[inline] fn ndim(&self) -> usize { 1 @@ -585,6 +561,15 @@ impl Dimension for Dim<[Ix; 1]> { None } } + + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + match index.as_ref().try_into() { + Ok(arg) => Ok(arg), + Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } + } + private_impl! {} } @@ -594,9 +579,6 @@ impl Dimension for Dim<[Ix; 2]> { type Pattern = (Ix, Ix); type Smaller = Ix1; type Larger = Ix3; - fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { - index.try_into().unwrap() - } #[inline] fn ndim(&self) -> usize { 2 @@ -745,6 +727,13 @@ impl Dimension for Dim<[Ix; 2]> { fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { self.remove_axis(axis) } + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + match index.as_ref().try_into() { + Ok(arg) => Ok(arg), + Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } + } private_impl! {} } @@ -754,9 +743,6 @@ impl Dimension for Dim<[Ix; 3]> { type Pattern = (Ix, Ix, Ix); type Smaller = Ix2; type Larger = Ix4; - fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { - index.try_into().unwrap() - } #[inline] fn ndim(&self) -> usize { 3 @@ -869,6 +855,13 @@ impl Dimension for Dim<[Ix; 3]> { fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { self.remove_axis(axis) } + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + match index.as_ref().try_into() { + Ok(arg) => Ok(arg), + Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } + } private_impl! {} } @@ -880,9 +873,6 @@ macro_rules! large_dim { type Pattern = $pattern; type Smaller = Dim<[Ix; $n - 1]>; type Larger = $larger; - fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { - index.try_into().unwrap() - } #[inline] fn ndim(&self) -> usize { $n } #[inline] @@ -904,6 +894,13 @@ macro_rules! large_dim { fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { self.remove_axis(axis) } + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + match index.as_ref().try_into() { + Ok(arg) => Ok(arg), + Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } + } private_impl!{} } ) @@ -934,9 +931,6 @@ impl Dimension for IxDyn { type Pattern = Self; type Smaller = Self; type Larger = Self; - fn slice_arg_from(index: &[SliceOrIndex]) -> &Self::SliceArg { - index - } #[inline] fn ndim(&self) -> usize { self.ix().len() @@ -977,6 +971,10 @@ impl Dimension for IxDyn { fn from_dimension(d: &D2) -> Option { Some(IxDyn(d.slice())) } + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + Ok(index.as_ref()) + } private_impl! {} } diff --git a/src/slice.rs b/src/slice.rs index 86a2b0b8f..5203df47d 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -333,6 +333,23 @@ where indices, }) } + + /// Generate the corresponding SliceInfo from AsRef<[SliceOrIndex]> + /// for the specific dimension E. + /// + /// Return ShapeError if length does not match + pub fn for_dimensionality(indices: &T) -> Result<&SliceInfo, ShapeError> + { + let arg_ref = E::slice_arg_from(indices)?; + unsafe { + // This is okay because the only non-zero-sized member of + // `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>` + // should have the same bitwise representation as + // `&[SliceOrIndex]`. + Ok(&*(arg_ref as *const E::SliceArg + as *const SliceInfo)) + } + } } impl SliceInfo diff --git a/tests/array.rs b/tests/array.rs index b398f9e37..46e296284 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -285,14 +285,14 @@ fn test_slice_arg() { let mut x = Array::from_elem(shape, 0); let indices = x.shape().iter().map(f).collect::>(); let s = x.slice_mut( - ::slice_info_from::<_, Sh::Dim>(&indices) + SliceInfo::<_, Sh::Dim>::for_dimensionality::(&indices).unwrap() ); let s2 = shape2.slice(); assert_eq!(s.shape(), s2) } - use_arg_map(0,|x| SliceOrIndex::from(*x/2..*x),Dim([0])); - use_arg_map((2, 4, 8),|x| SliceOrIndex::from(*x/2..*x),Dim([1, 2, 4])); - use_arg_map(vec![3, 6, 9],|x| SliceOrIndex::from(*x/3..*x/2),Dim([0, 1, 1])); + use_arg_map(0, |x| SliceOrIndex::from(*x/2..*x),Dim([0])); + use_arg_map((2, 4, 8), |x| SliceOrIndex::from(*x/2..*x),Dim([1, 2, 4])); + use_arg_map(vec![3, 6, 9], |x| SliceOrIndex::from(*x/3..*x/2),Dim([0, 1, 1])); use_arg_map(vec![1, 2, 3, 4, 5, 6, 7], |x| SliceOrIndex::from(x-1), Dim([])); }