diff --git a/src/data_traits.rs b/src/data_traits.rs index bc6c08090..e09b809a9 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -83,8 +83,8 @@ unsafe impl DataMut for Rc> // Create a new vec if the current view is less than half of // backing data. unsafe { - *self_ = ArrayBase::from_vec_dim_unchecked(self_.dim.clone(), - self_.iter() + *self_ = ArrayBase::from_shape_vec_unchecked(self_.dim.clone(), + self_.iter() .cloned() .collect()); } diff --git a/src/free_functions.rs b/src/free_functions.rs index f78edb002..34037c5c9 100644 --- a/src/free_functions.rs +++ b/src/free_functions.rs @@ -13,7 +13,7 @@ use imp_prelude::*; /// Create a zero-dimensional array with the element `x`. pub fn arr0(x: A) -> OwnedArray { - unsafe { ArrayBase::from_vec_dim_unchecked((), vec![x]) } + unsafe { ArrayBase::from_shape_vec_unchecked((), vec![x]) } } /// Create a one-dimensional array with elements from `xs`. @@ -129,7 +129,7 @@ pub fn arr2>(xs: &[V]) -> OwnedArray, U: FixedInitializer>( } } unsafe { - ArrayBase::from_vec_dim_unchecked(dim, result) + ArrayBase::from_shape_vec_unchecked(dim, result) } } diff --git a/src/impl_constructors.rs b/src/impl_constructors.rs index 0fa617970..3bccb5786 100644 --- a/src/impl_constructors.rs +++ b/src/impl_constructors.rs @@ -8,9 +8,14 @@ //! Constructor methods for ndarray //! -use libnum; +//! + +#![allow(deprecated)] // from_shape_vec + +use libnum::{Zero, One, Float}; use imp_prelude::*; +use {Shape, StrideShape}; use dimension; use linspace; use error::{self, ShapeError, ErrorKind}; @@ -30,7 +35,7 @@ impl ArrayBase /// let array = OwnedArray::from_vec(vec![1., 2., 3., 4.]); /// ``` pub fn from_vec(v: Vec) -> ArrayBase { - unsafe { Self::from_vec_dim_unchecked(v.len() as Ix, v) } + unsafe { Self::from_shape_vec_unchecked(v.len() as Ix, v) } } /// Create a one-dimensional array from an iterable. @@ -58,7 +63,7 @@ impl ArrayBase /// ``` pub fn linspace(start: F, end: F, n: usize) -> ArrayBase where S: Data, - F: libnum::Float, + F: Float, { Self::from_vec(::iterators::to_vec(linspace::linspace(start, end, n))) } @@ -74,7 +79,7 @@ impl ArrayBase /// ``` pub fn range(start: F, end: F, step: F) -> ArrayBase where S: Data, - F: libnum::Float, + F: Float, { Self::from_vec(::iterators::to_vec(linspace::range(start, end, step))) } @@ -89,7 +94,7 @@ impl ArrayBase /// **Panics** if `n * n` would overflow usize. pub fn eye(n: Ix) -> ArrayBase where S: DataMut, - A: Clone + libnum::Zero + libnum::One, + A: Clone + Zero + One, { let mut eye = Self::zeros((n, n)); for a_ii in eye.diag_mut() { @@ -113,13 +118,12 @@ impl ArrayBase where S: DataOwned, D: Dimension, { - /// Create an array with copies of `elem`, dimension `dim`. + /// Create an array with copies of `elem`, shape `shape`. /// - /// **Panics** if the number of elements in `dim` would overflow usize. + /// **Panics** if the number of elements in `shape` would overflow usize. /// /// ``` - /// use ndarray::OwnedArray; - /// use ndarray::arr3; + /// use ndarray::{OwnedArray, arr3, ShapeBuilder}; /// /// let a = OwnedArray::from_elem((2, 2, 2), 1.); /// @@ -130,68 +134,110 @@ impl ArrayBase /// [1., 1.]]]) /// ); /// assert!(a.strides() == &[4, 2, 1]); + /// + /// let b = OwnedArray::from_elem((2, 2, 2).f(), 1.); + /// assert!(b.strides() == &[1, 2, 4]); /// ``` - pub fn from_elem(dim: D, elem: A) -> ArrayBase - where A: Clone + pub fn from_elem(shape: Sh, elem: A) -> ArrayBase + where A: Clone, + Sh: Into>, { // Note: We don't need to check the case of a size between // isize::MAX -> usize::MAX; in this case, the vec constructor itself // panics. - let size = size_checked_unwrap!(dim); + let shape = shape.into(); + let size = size_checked_unwrap!(shape.dim); let v = vec![elem; size]; - unsafe { Self::from_vec_dim_unchecked(dim, v) } + unsafe { Self::from_shape_vec_unchecked(shape, v) } } - /// Create an array with copies of `elem`, dimension `dim` and fortran - /// memory order. - /// - /// **Panics** if the number of elements would overflow usize. + /// Create an array with zeros, shape `shape`. /// - /// ``` - /// use ndarray::OwnedArray; - /// - /// let a = OwnedArray::from_elem_f((2, 2, 2), 1.); - /// assert!(a.strides() == &[1, 2, 4]); - /// ``` - pub fn from_elem_f(dim: D, elem: A) -> ArrayBase - where A: Clone + /// **Panics** if the number of elements in `shape` would overflow usize. + pub fn zeros(shape: Sh) -> ArrayBase + where A: Clone + Zero, + Sh: Into>, { - let size = size_checked_unwrap!(dim); - let v = vec![elem; size]; - unsafe { Self::from_vec_dim_unchecked_f(dim, v) } + Self::from_elem(shape, A::zero()) } - /// Create an array with zeros, dimension `dim`. + /// Create an array with default values, shape `shape` /// - /// **Panics** if the number of elements in `dim` would overflow usize. - pub fn zeros(dim: D) -> ArrayBase - where A: Clone + libnum::Zero + /// **Panics** if the number of elements in `shape` would overflow usize. + pub fn default(shape: Sh) -> ArrayBase + where A: Default, + Sh: Into>, { - Self::from_elem(dim, libnum::zero()) + let shape = shape.into(); + let v = (0..shape.dim.size()).map(|_| A::default()).collect(); + unsafe { Self::from_shape_vec_unchecked(shape, v) } } - /// Create an array with zeros, dimension `dim` and fortran memory order. + /// Create an array with the given shape from a vector. (No cloning of + /// elements needed.) /// - /// **Panics** if the number of elements in `dim` would overflow usize. - pub fn zeros_f(dim: D) -> ArrayBase - where A: Clone + libnum::Zero + /// ---- + /// + /// For a contiguous c- or f-order shape, the following applies: + /// + /// **Errors** if `shape` does not correspond to the number of elements in `v`. + /// + /// ---- + /// + /// For custom strides, the following applies: + /// + /// **Errors** if strides and dimensions can point out of bounds of `v`.
+ /// **Errors** if strides allow multiple indices to point to the same element. + pub fn from_shape_vec(shape: Sh, v: Vec
) -> Result, ShapeError> + where Sh: Into>, { - Self::from_elem_f(dim, libnum::zero()) + // eliminate the type parameter Sh as soon as possible + Self::from_shape_vec_impl(shape.into(), v) } - /// Create an array with default values, dimension `dim`. - /// - /// **Panics** if the number of elements in `dim` would overflow usize. - pub fn default(dim: D) -> ArrayBase - where A: Default + fn from_shape_vec_impl(shape: StrideShape, v: Vec) -> Result, ShapeError> { - let v = (0..dim.size()).map(|_| A::default()).collect(); - unsafe { Self::from_vec_dim_unchecked(dim, v) } + if shape.custom { + Self::from_vec_dim_stride(shape.dim, shape.strides, v) + } else { + let dim = shape.dim; + let strides = shape.strides; + if dim.size_checked() != Some(v.len()) { + return Err(error::incompatible_shapes(&v.len(), &dim)); + } + unsafe { Ok(Self::from_vec_dim_stride_unchecked(dim, strides, v)) } + } } - /// Create an array from a vector (no copying needed). + /// Create an array from a vector and interpret it according to the + /// provided dimensions and strides. (No cloning of elements needed.) /// - /// **Errors** if `dim` does not correspond to the number of elements in `v`. + /// Unsafe because dimension and strides are unchecked. + pub unsafe fn from_shape_vec_unchecked(shape: Sh, v: Vec) -> ArrayBase + where Sh: Into>, + { + let shape = shape.into(); + Self::from_vec_dim_stride_unchecked(shape.dim, shape.strides, v) + } + + #[cfg_attr(has_deprecated, deprecated(note="Use from_elem instead."))] + /// ***Deprecated: Use from_elem instead*** + pub fn from_elem_f(dim: D, elem: A) -> ArrayBase + where A: Clone + { + Self::from_elem(dim.f(), elem) + } + + #[cfg_attr(has_deprecated, deprecated(note="Use zeros instead."))] + /// ***Deprecated: Use zeros instead*** + pub fn zeros_f(dim: D) -> ArrayBase + where A: Clone + Zero + { + Self::from_elem_f(dim, A::zero()) + } + + #[cfg_attr(has_deprecated, deprecated(note="Use from_shape_vec instead."))] + /// ***Deprecated: Use from_shape_vec instead*** pub fn from_vec_dim(dim: D, v: Vec) -> Result, ShapeError> { if dim.size_checked() != Some(v.len()) { return Err(error::incompatible_shapes(&v.len(), &dim)); @@ -199,10 +245,8 @@ impl ArrayBase unsafe { Ok(Self::from_vec_dim_unchecked(dim, v)) } } - /// Create an array from a vector (no copying needed) using fortran - /// memory order to interpret the data. - /// - /// **Errors** if `dim` does not correspond to the number of elements in `v`. + #[cfg_attr(has_deprecated, deprecated(note="Use from_shape_vec instead."))] + /// ***Deprecated: Use from_shape_vec instead*** pub fn from_vec_dim_f(dim: D, v: Vec) -> Result, ShapeError> { if dim.size_checked() != Some(v.len()) { return Err(error::incompatible_shapes(&v.len(), &dim)); @@ -210,9 +254,8 @@ impl ArrayBase unsafe { Ok(Self::from_vec_dim_unchecked_f(dim, v)) } } - /// Create an array from a vector (no copying needed). - /// - /// Unsafe because dimension is unchecked, and must be correct. + #[cfg_attr(has_deprecated, deprecated(note="Use from_shape_vec_unchecked instead."))] + /// ***Deprecated: Use from_shape_vec_unchecked instead*** pub unsafe fn from_vec_dim_unchecked(dim: D, mut v: Vec) -> ArrayBase { debug_assert!(dim.size_checked() == Some(v.len())); ArrayBase { @@ -223,24 +266,16 @@ impl ArrayBase } } - /// Create an array from a vector (with no copying needed), - /// using fortran memory order to interpret the data. - /// - /// Unsafe because dimension is unchecked, and must be correct. + #[cfg_attr(has_deprecated, deprecated(note="Use from_shape_vec_unchecked instead."))] + /// ***Deprecated: Use from_shape_vec_unchecked instead*** pub unsafe fn from_vec_dim_unchecked_f(dim: D, v: Vec) -> ArrayBase { debug_assert!(dim.size_checked() == Some(v.len())); let strides = dim.fortran_strides(); Self::from_vec_dim_stride_unchecked(dim, strides, v) } - /// Create an array from a vector and interpret it according to the - /// provided dimensions and strides. No allocation needed. - /// - /// Checks whether `dim` and `strides` are compatible with the vector's - /// length, returning an `Err` if not compatible. - /// - /// **Errors** if strides and dimensions can point out of bounds of `v`.
- /// **Errors** if strides allow multiple indices to point to the same element. + #[cfg_attr(has_deprecated, deprecated(note="Use from_shape_vec instead."))] + /// ***Deprecated: Use from_shape_vec instead*** pub fn from_vec_dim_stride(dim: D, strides: D, v: Vec
) -> Result, ShapeError> { @@ -251,10 +286,9 @@ impl ArrayBase }) } - /// Create an array from a vector and interpret it according to the - /// provided dimensions and strides. No allocation needed. + #[cfg_attr(has_deprecated, deprecated(note="Use from_shape_vec_unchecked instead."))] + /// ***Deprecated: Use from_shape_vec_unchecked instead*** /// - /// Unsafe because dimension and strides are unchecked. pub unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec) -> ArrayBase { diff --git a/src/impl_methods.rs b/src/impl_methods.rs index b9fe5eb7a..3ba1fbba9 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -91,7 +91,7 @@ impl ArrayBase where S: Data, D: Dimension (self.iter().cloned().collect(), self.dim.default_strides()) }; unsafe { - ArrayBase::from_vec_dim_stride_unchecked(self.dim.clone(), strides, data) + ArrayBase::from_shape_vec_unchecked(self.dim.clone().strides(strides), data) } } @@ -427,7 +427,7 @@ impl ArrayBase where S: Data, D: Dimension let mut dim = self.dim(); dim.set_axis(axis, 0); unsafe { - OwnedArray::from_vec_dim_unchecked(dim, vec![]) + OwnedArray::from_shape_vec_unchecked(dim, vec![]) } } else { stack(axis, &subs).unwrap() @@ -800,7 +800,7 @@ impl ArrayBase where S: Data, D: Dimension } else { let v = self.iter().map(|x| x.clone()).collect::>(); unsafe { - ArrayBase::from_vec_dim_unchecked(shape, v) + ArrayBase::from_shape_vec_unchecked(shape, v) } } } @@ -1150,13 +1150,13 @@ impl ArrayBase where S: Data, D: Dimension if let Some(slc) = self.as_slice_memory_order() { let v = ::iterators::to_vec(slc.iter().map(f)); unsafe { - ArrayBase::from_vec_dim_stride_unchecked( - self.dim.clone(), self.strides.clone(), v) + ArrayBase::from_shape_vec_unchecked( + self.dim.clone().strides(self.strides.clone()), v) } } else { let v = ::iterators::to_vec(self.iter().map(f)); unsafe { - ArrayBase::from_vec_dim_unchecked(self.dim.clone(), v) + ArrayBase::from_shape_vec_unchecked(self.dim.clone(), v) } } } diff --git a/src/impl_views.rs b/src/impl_views.rs index 398a17251..209603255 100644 --- a/src/impl_views.rs +++ b/src/impl_views.rs @@ -10,6 +10,8 @@ use imp_prelude::*; use dimension::{self, stride_offset}; use error::ShapeError; +use StrideShape; + /// # Methods for Array Views /// /// Methods for read-only array views `ArrayView<'a, A, D>` @@ -21,17 +23,17 @@ impl<'a, A, D> ArrayBase, D> { /// Create a read-only array view borrowing its data from a slice. /// - /// Checks whether `dim` and `strides` are compatible with the slice's + /// Checks whether `shape` are compatible with the slice's /// length, returning an `Err` if not compatible. /// /// ``` /// use ndarray::ArrayView; /// use ndarray::arr3; + /// use ndarray::ShapeBuilder; /// /// let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - /// let a = ArrayView::from_slice_dim_stride((2, 3, 2), - /// (1, 4, 2), - /// &s).unwrap(); + /// let a = ArrayView::from_shape((2, 3, 2).strides((1, 4, 2)), + /// &s).unwrap(); /// /// assert!( /// a == arr3(&[[[0, 2], @@ -43,6 +45,22 @@ impl<'a, A, D> ArrayBase, D> /// ); /// assert!(a.strides() == &[1, 4, 2]); /// ``` + pub fn from_shape(shape: Sh, xs: &'a [A]) + -> Result + where Sh: Into>, + { + let shape = shape.into(); + let dim = shape.dim; + let strides = shape.strides; + dimension::can_index_slice(xs, &dim, &strides).map(|_| { + unsafe { + Self::new_(xs.as_ptr(), dim, strides) + } + }) + } + + #[cfg_attr(has_deprecated, deprecated(note="Use from_shape instead."))] + /// ***Deprecated: Use from_shape instead*** pub fn from_slice_dim_stride(dim: D, strides: D, xs: &'a [A]) -> Result { @@ -111,11 +129,11 @@ impl<'a, A, D> ArrayBase, D> /// ``` /// use ndarray::ArrayViewMut; /// use ndarray::arr3; + /// use ndarray::ShapeBuilder; /// /// let mut s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - /// let mut a = ArrayViewMut::from_slice_dim_stride((2, 3, 2), - /// (1, 4, 2), - /// &mut s).unwrap(); + /// let mut a = ArrayViewMut::from_shape((2, 3, 2).strides((1, 4, 2)), + /// &mut s).unwrap(); /// /// a[[0, 0, 0]] = 1; /// assert!( @@ -128,6 +146,22 @@ impl<'a, A, D> ArrayBase, D> /// ); /// assert!(a.strides() == &[1, 4, 2]); /// ``` + pub fn from_shape(shape: Sh, xs: &'a mut [A]) + -> Result + where Sh: Into>, + { + let shape = shape.into(); + let dim = shape.dim; + let strides = shape.strides; + dimension::can_index_slice(xs, &dim, &strides).map(|_| { + unsafe { + Self::new_(xs.as_mut_ptr(), dim, strides) + } + }) + } + + #[cfg_attr(has_deprecated, deprecated(note="Use from_shape instead."))] + /// ***Deprecated: Use from_shape instead*** pub fn from_slice_dim_stride(dim: D, strides: D, xs: &'a mut [A]) -> Result { diff --git a/src/lib.rs b/src/lib.rs index efe716769..8b69b2141 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,6 +108,8 @@ pub use arraytraits::AsArray; pub use linalg_traits::{LinalgScalar, NdFloat}; pub use stacking::stack; +pub use shape_builder::{ ShapeBuilder }; + mod arraytraits; #[cfg(feature = "serde")] mod arrayserialize; @@ -134,6 +136,7 @@ mod linspace; mod numeric_util; mod si; mod error; +mod shape_builder; mod stacking; /// Implementation's prelude. Common types used everywhere. @@ -697,3 +700,21 @@ enum ElementsRepr { Slice(S), Counted(C), } + + +/// A contiguous array shape of n dimensions. +/// +/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default). +#[derive(Copy, Clone, Debug)] +pub struct Shape { + dim: D, + is_c: bool, +} + +/// An array shape of n dimensions c-order, f-order or custom strides. +#[derive(Copy, Clone, Debug)] +pub struct StrideShape { + dim: D, + strides: D, + custom: bool, +} diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 44f35f976..824fdae0a 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -230,11 +230,7 @@ impl Dot> for ArrayBase let mut c; unsafe { v.set_len(m * n); - if !column_major { - c = OwnedArray::from_vec_dim_unchecked((m, n), v); - } else { - c = OwnedArray::from_vec_dim_unchecked_f((m, n), v); - } + c = OwnedArray::from_shape_vec_unchecked((m, n).set_f(column_major), v); } mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut()); c @@ -293,7 +289,7 @@ impl Dot> for ArrayBase } } unsafe { - ArrayBase::from_vec_dim_unchecked(m, res_elems) + ArrayBase::from_shape_vec_unchecked(m, res_elems) } } } @@ -353,10 +349,10 @@ fn mat_mul_impl(alpha: A, let mut rhs_trans = CblasNoTrans; if both_f { // A^t B^t = C^t => B A = C - lhs_ = lhs_.reversed_axes(); - rhs_ = rhs_.reversed_axes(); + let lhs_t = lhs_.reversed_axes(); + lhs_ = rhs_.reversed_axes(); + rhs_ = lhs_t; c_ = c_.reversed_axes(); - swap(&mut lhs_, &mut rhs_); swap(&mut m, &mut n); } else if lhs_s0 == 1 && m == a { lhs_ = lhs_.reversed_axes(); diff --git a/src/prelude.rs b/src/prelude.rs index 3d00390c7..2e422b877 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -42,3 +42,8 @@ pub use { arr1, arr2, aview0, aview1, aview2, }; + +#[doc(no_inline)] +pub use { + ShapeBuilder, +}; diff --git a/src/shape_builder.rs b/src/shape_builder.rs new file mode 100644 index 000000000..2b548c45d --- /dev/null +++ b/src/shape_builder.rs @@ -0,0 +1,86 @@ + +use Dimension; +use {Shape, StrideShape}; + +/// A trait for `Shape` and `D where D: Dimension` that allows +/// customizing the memory layout (strides) of an array shape. +/// +/// This trait is used together with array constructor methods like +/// `OwnedArray::from_shape_vec`. +pub trait ShapeBuilder { + type Dim: Dimension; + + fn f(self) -> Shape; + fn set_f(self, is_f: bool) -> Shape; + fn strides(self, strides: Self::Dim) -> StrideShape; +} + +impl From for Shape + where D: Dimension +{ + fn from(d: D) -> Self { + Shape { + dim: d, + is_c: true, + } + } +} + +impl From for StrideShape + where D: Dimension +{ + fn from(d: D) -> Self { + StrideShape { + strides: d.default_strides(), + dim: d, + custom: false, + } + } +} + +impl From> for StrideShape + where D: Dimension +{ + fn from(shape: Shape) -> Self { + let d = shape.dim; + let st = if shape.is_c { d.default_strides() } else { d.fortran_strides() }; + StrideShape { + strides: st, + dim: d, + custom: false, + } + } +} + +impl ShapeBuilder for D + where D: Dimension +{ + type Dim = D; + fn f(self) -> Shape { self.set_f(true) } + fn set_f(self, is_f: bool) -> Shape { + Shape::from(self).set_f(is_f) + } + fn strides(self, st: D) -> StrideShape { + Shape::from(self).strides(st) + } +} + +impl ShapeBuilder for Shape + where D: Dimension +{ + type Dim = D; + fn f(self) -> Self { self.set_f(true) } + fn set_f(mut self, is_f: bool) -> Self { + self.is_c = !is_f; + self + } + fn strides(self, st: D) -> StrideShape { + StrideShape { + dim: self.dim, + strides: st, + custom: true, + } + } +} + + diff --git a/src/stacking.rs b/src/stacking.rs index ca152274b..e4233eb60 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -57,7 +57,7 @@ pub fn stack<'a, A, D>(axis: Axis, arrays: &[ArrayView<'a, A, D>]) unsafe { v.set_len(size); } - let mut res = try!(OwnedArray::from_vec_dim(res_dim, v)); + let mut res = try!(OwnedArray::from_shape_vec(res_dim, v)); { let mut assign_view = res.view_mut(); diff --git a/tests/array.rs b/tests/array.rs index ff830c757..32ea4ac70 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -4,20 +4,14 @@ extern crate ndarray; extern crate itertools; -use ndarray::{RcArray, S, Si, - OwnedArray, -}; +use ndarray::{S, Si}; +use ndarray::prelude::*; use ndarray::{ rcarr2, - arr0, arr1, arr2, arr3, - aview0, - aview1, - aview2, + arr0, arr3, aview_mut1, - Dimension, }; use ndarray::Indexes; -use ndarray::Axis; use itertools::free::enumerate; #[test] @@ -38,7 +32,7 @@ fn test_matmul_rcarray() println!("B = \n{:?}", B); println!("A x B = \n{:?}", c); unsafe { - let result = RcArray::from_vec_dim_unchecked((2, 4), vec![20, 23, 26, 29, 56, 68, 80, 92]); + let result = RcArray::from_shape_vec_unchecked((2, 4), vec![20, 23, 26, 29, 56, 68, 80, 92]); assert_eq!(c.shape(), result.shape()); assert!(c.iter().zip(result.iter()).all(|(a,b)| a == b)); assert!(c == result); @@ -454,7 +448,7 @@ fn owned_array_with_stride() { let dim = (2, 3, 2); let strides = (1, 4, 2); - let a = OwnedArray::from_vec_dim_stride(dim, strides, v).unwrap(); + let a = OwnedArray::from_shape_vec(dim.strides(strides), v).unwrap(); assert_eq!(a.strides(), &[1, 4, 2]); } @@ -471,7 +465,7 @@ macro_rules! assert_matches { #[test] fn from_vec_dim_stride_empty_1d() { let empty: [f32; 0] = []; - assert_matches!(OwnedArray::from_vec_dim_stride(0, 1, empty.to_vec()), + assert_matches!(OwnedArray::from_shape_vec(0.strides(1), empty.to_vec()), Ok(_)); } @@ -481,11 +475,11 @@ fn from_vec_dim_stride_0d() { let one = [1.]; let two = [1., 2.]; // too few elements - assert_matches!(OwnedArray::from_vec_dim_stride((), (), empty.to_vec()), Err(_)); + assert_matches!(OwnedArray::from_shape_vec(().strides(()), empty.to_vec()), Err(_)); // exact number of elements - assert_matches!(OwnedArray::from_vec_dim_stride((), (), one.to_vec()), Ok(_)); + assert_matches!(OwnedArray::from_shape_vec(().strides(()), one.to_vec()), Ok(_)); // too many are ok - assert_matches!(OwnedArray::from_vec_dim_stride((), (), two.to_vec()), Ok(_)); + assert_matches!(OwnedArray::from_shape_vec(().strides(()), two.to_vec()), Ok(_)); } #[test] @@ -493,7 +487,7 @@ fn from_vec_dim_stride_2d_1() { let two = [1., 2.]; let d = (2, 1); let s = d.default_strides(); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, two.to_vec()), Ok(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), two.to_vec()), Ok(_)); } #[test] @@ -501,7 +495,7 @@ fn from_vec_dim_stride_2d_2() { let two = [1., 2.]; let d = (1, 2); let s = d.default_strides(); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, two.to_vec()), Ok(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), two.to_vec()), Ok(_)); } #[test] @@ -511,7 +505,7 @@ fn from_vec_dim_stride_2d_3() { [[3]]]); let d = a.dim(); let s = d.default_strides(); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, a.as_slice().unwrap().to_vec()), Ok(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), a.as_slice().unwrap().to_vec()), Ok(_)); } #[test] @@ -521,7 +515,7 @@ fn from_vec_dim_stride_2d_4() { [[3]]]); let d = a.dim(); let s = d.fortran_strides(); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, a.as_slice().unwrap().to_vec()), Ok(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), a.as_slice().unwrap().to_vec()), Ok(_)); } #[test] @@ -529,7 +523,7 @@ fn from_vec_dim_stride_2d_5() { let a = arr3(&[[[1, 2, 3]]]); let d = a.dim(); let s = d.fortran_strides(); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, a.as_slice().unwrap().to_vec()), Ok(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), a.as_slice().unwrap().to_vec()), Ok(_)); } #[test] @@ -537,11 +531,11 @@ fn from_vec_dim_stride_2d_6() { let a = [1., 2., 3., 4., 5., 6.]; let d = (2, 1, 1); let s = (2, 2, 1); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, a.to_vec()), Ok(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), a.to_vec()), Ok(_)); let d = (1, 2, 1); let s = (2, 2, 1); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, a.to_vec()), Ok(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), a.to_vec()), Ok(_)); } #[test] @@ -551,7 +545,7 @@ fn from_vec_dim_stride_2d_7() { // [[]] shape=[4, 0], strides=[0, 1] let d = (4, 0); let s = (0, 1); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, a.to_vec()), Ok(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), a.to_vec()), Ok(_)); } #[test] @@ -560,7 +554,7 @@ fn from_vec_dim_stride_2d_8() { let a = [1.]; let d = (1, 1); let s = (0, 1); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, a.to_vec()), Err(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), a.to_vec()), Err(_)); } #[test] @@ -568,11 +562,11 @@ fn from_vec_dim_stride_2d_rejects() { let two = [1., 2.]; let d = (2, 2); let s = (1, 0); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, two.to_vec()), Err(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), two.to_vec()), Err(_)); let d = (2, 2); let s = (0, 1); - assert_matches!(OwnedArray::from_vec_dim_stride(d, s, two.to_vec()), Err(_)); + assert_matches!(OwnedArray::from_shape_vec(d.strides(s), two.to_vec()), Err(_)); } #[test] @@ -741,7 +735,7 @@ fn reshape_error2() { #[test] fn reshape_f() { - let mut u = OwnedArray::zeros_f((3, 4)); + let mut u = OwnedArray::zeros((3, 4).f()); for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) { *elt = i as i32; } @@ -825,9 +819,9 @@ fn scalar_ops() { #[test] fn deny_wraparound_from_vec() { let five = vec![0; 5]; - let _five_large = OwnedArray::from_vec_dim((3, 7, 29, 36760123, 823996703), five.clone()); - assert!(_five_large.is_err()); - let six = OwnedArray::from_vec_dim(6, five.clone()); + let five_large = OwnedArray::from_shape_vec((3, 7, 29, 36760123, 823996703), five.clone()); + assert!(five_large.is_err()); + let six = OwnedArray::from_shape_vec(6, five.clone()); assert!(six.is_err()); } @@ -924,7 +918,7 @@ fn test_f_order() { // even if the underlying memory order is different let c = arr2(&[[1, 2, 3], [4, 5, 6]]); - let mut f = OwnedArray::zeros_f(c.dim()); + let mut f = OwnedArray::zeros(c.dim().f()); f.assign(&c); assert_eq!(f, c); assert_eq!(f.shape(), c.shape()); @@ -991,7 +985,7 @@ fn test_contiguous() { assert!(v.as_slice_memory_order().is_some()); let a = OwnedArray::::zeros((20, 1)); - let b = OwnedArray::::zeros_f((20, 1)); + let b = OwnedArray::::zeros((20, 1).f()); assert!(a.as_slice().is_some()); assert!(b.as_slice().is_some()); assert!(a.as_slice_memory_order().is_some()); @@ -1028,3 +1022,17 @@ fn test_swap() { } assert_eq!(a, b.t()); } + +#[test] +fn test_shape() { + let data = [0, 1, 2, 3, 4, 5]; + let a = OwnedArray::from_shape_vec((1, 2, 3), data.to_vec()).unwrap(); + let b = OwnedArray::from_shape_vec((1, 2, 3).f(), data.to_vec()).unwrap(); + let c = OwnedArray::from_shape_vec((1, 2, 3).strides((1, 3, 1)), data.to_vec()).unwrap(); + println!("{:?}", a); + println!("{:?}", b); + println!("{:?}", c); + assert_eq!(a.strides(), &[6, 3, 1]); + assert_eq!(b.strides(), &[1, 1, 2]); + assert_eq!(c.strides(), &[1, 3, 1]); +} diff --git a/tests/oper.rs b/tests/oper.rs index 1727df5ab..29253674f 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -243,7 +243,9 @@ fn reference_mat_mul(lhs: &ArrayBase, rhs: &ArrayBase, S2: Data, { - let ((m, k), (_, n)) = (lhs.dim(), rhs.dim()); + let ((m, k), (k2, n)) = (lhs.dim(), rhs.dim()); + assert!(m.checked_mul(n).is_some()); + assert_eq!(k, k2); let mut res_elems = Vec::::with_capacity(m * n); unsafe { res_elems.set_len(m * n); @@ -263,7 +265,7 @@ fn reference_mat_mul(lhs: &ArrayBase, rhs: &ArrayBase