diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 18a10c3f0..b79a2a62e 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -9,13 +9,14 @@ use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; use std::ops::{Index, IndexMut}; +use std::convert::TryInto; use alloc::vec::Vec; use super::axes_of; use super::conversion::Convert; use super::{stride_offset, stride_offset_checked}; use crate::itertools::{enumerate, zip}; -use crate::Axis; +use crate::{Axis, ShapeError, ErrorKind}; use crate::IntoDimension; use crate::RemoveAxis; use crate::{ArrayView1, ArrayViewMut1}; @@ -77,7 +78,6 @@ pub trait Dimension: type Smaller: Dimension; /// Next larger dimension type Larger: Dimension + RemoveAxis; - /// Returns the number of dimensions (number of axes). fn ndim(&self) -> usize; @@ -375,6 +375,11 @@ pub trait Dimension: #[doc(hidden)] fn try_remove_axis(&self, axis: Axis) -> Self::Smaller; + /// Convert index to &Self::SliceArg. Return ShapeError if the length of index + /// doesn't consist with Self::NDIM(if it exists). + #[doc(hidden)] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError>; + private_decl! {} } @@ -432,6 +437,13 @@ impl Dimension for Dim<[Ix; 0]> { fn try_remove_axis(&self, _ignore: Axis) -> Self::Smaller { *self } + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + match index.as_ref().try_into() { + Ok(arg) => Ok(arg), + Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } + } private_impl! {} } @@ -549,6 +561,15 @@ impl Dimension for Dim<[Ix; 1]> { None } } + + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + match index.as_ref().try_into() { + Ok(arg) => Ok(arg), + Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } + } + private_impl! {} } @@ -706,6 +727,13 @@ impl Dimension for Dim<[Ix; 2]> { fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { self.remove_axis(axis) } + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + match index.as_ref().try_into() { + Ok(arg) => Ok(arg), + Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } + } private_impl! {} } @@ -827,6 +855,13 @@ impl Dimension for Dim<[Ix; 3]> { fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { self.remove_axis(axis) } + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + match index.as_ref().try_into() { + Ok(arg) => Ok(arg), + Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } + } private_impl! {} } @@ -859,6 +894,13 @@ macro_rules! large_dim { fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { self.remove_axis(axis) } + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + match index.as_ref().try_into() { + Ok(arg) => Ok(arg), + Err(_) => Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) + } + } private_impl!{} } ) @@ -930,6 +972,11 @@ impl Dimension for IxDyn { Some(IxDyn(d.slice())) } + #[inline] + fn slice_arg_from>(index: &T) -> Result<&Self::SliceArg, ShapeError> { + Ok(index.as_ref()) + } + fn into_dyn(self) -> IxDyn { self } diff --git a/src/slice.rs b/src/slice.rs index 86a2b0b8f..5203df47d 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -333,6 +333,23 @@ where indices, }) } + + /// Generate the corresponding SliceInfo from AsRef<[SliceOrIndex]> + /// for the specific dimension E. + /// + /// Return ShapeError if length does not match + pub fn for_dimensionality(indices: &T) -> Result<&SliceInfo, ShapeError> + { + let arg_ref = E::slice_arg_from(indices)?; + unsafe { + // This is okay because the only non-zero-sized member of + // `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>` + // should have the same bitwise representation as + // `&[SliceOrIndex]`. + Ok(&*(arg_ref as *const E::SliceArg + as *const SliceInfo)) + } + } } impl SliceInfo diff --git a/tests/array.rs b/tests/array.rs index dddcd5e1e..d863d4b7e 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -274,6 +274,28 @@ fn test_slice_dyninput_vec_dyn() { arr.view().slice_collapse(info.as_ref()); } +#[test] +fn test_slice_arg() { + fn use_arg_map(shape: Sh, f: impl Fn(&usize) -> SliceOrIndex, shape2: D) + where + Sh: ShapeBuilder, + D: Dimension, + { + let shape = shape.into_shape(); + let mut x = Array::from_elem(shape, 0); + let indices = x.shape().iter().map(f).collect::>(); + let s = x.slice_mut( + SliceInfo::<_, Sh::Dim>::for_dimensionality::(&indices).unwrap() + ); + let s2 = shape2.slice(); + assert_eq!(s.shape(), s2) + } + use_arg_map(0, |x| SliceOrIndex::from(*x/2..*x),Dim([0])); + use_arg_map((2, 4, 8), |x| SliceOrIndex::from(*x/2..*x),Dim([1, 2, 4])); + use_arg_map(vec![3, 6, 9], |x| SliceOrIndex::from(*x/3..*x/2),Dim([0, 1, 1])); + use_arg_map(vec![1, 2, 3, 4, 5, 6, 7], |x| SliceOrIndex::from(x-1), Dim([])); +} + #[test] fn test_slice_with_subview() { let mut arr = ArcArray::::zeros((3, 5, 4));