From bed9c63337fa8bf7b6116e58d2c6892ff59a9c75 Mon Sep 17 00:00:00 2001 From: Stokhos Date: Fri, 9 Apr 2021 22:53:59 -0400 Subject: [PATCH] add getter for dim in shape minor fix changed method name --- src/shape_builder.rs | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/src/shape_builder.rs b/src/shape_builder.rs index 6fc99d0b2..dcfddc1b9 100644 --- a/src/shape_builder.rs +++ b/src/shape_builder.rs @@ -13,7 +13,7 @@ pub struct Shape { } #[derive(Copy, Clone, Debug)] -pub(crate) enum Contiguous { } +pub(crate) enum Contiguous {} impl Shape { pub(crate) fn is_c(&self) -> bool { @@ -21,7 +21,6 @@ impl Shape { } } - /// An array shape of n dimensions in c-order, f-order or custom strides. #[derive(Copy, Clone, Debug)] pub struct StrideShape { @@ -29,6 +28,20 @@ pub struct StrideShape { pub(crate) strides: Strides, } +impl StrideShape +where + D: Dimension, +{ + /// Return a reference to the dimension + pub fn raw_dim(&self) -> &D { + &self.dim + } + /// Return the size of the shape in number of elements + pub fn size(&self) -> usize { + self.dim.size() + } +} + /// Stride description #[derive(Copy, Clone, Debug)] pub(crate) enum Strides { @@ -37,21 +50,26 @@ pub(crate) enum Strides { /// Column-major ("F"-order) F, /// Custom strides - Custom(D) + Custom(D), } impl Strides { /// Return strides for `dim` (computed from dimension if c/f, else return the custom stride) pub(crate) fn strides_for_dim(self, dim: &D) -> D - where D: Dimension + where + D: Dimension, { match self { Strides::C => dim.default_strides(), Strides::F => dim.fortran_strides(), Strides::Custom(c) => { - debug_assert_eq!(c.ndim(), dim.ndim(), + debug_assert_eq!( + c.ndim(), + dim.ndim(), "Custom strides given with {} dimensions, expected {}", - c.ndim(), dim.ndim()); + c.ndim(), + dim.ndim() + ); c } } @@ -94,11 +112,7 @@ where { fn from(value: T) -> Self { let shape = value.into_shape(); - let st = if shape.is_c() { - Strides::C - } else { - Strides::F - }; + let st = if shape.is_c() { Strides::C } else { Strides::F }; StrideShape { strides: st, dim: shape.dim, @@ -161,8 +175,10 @@ impl Shape where D: Dimension, { - // Return a reference to the dimension - //pub fn dimension(&self) -> &D { &self.dim } + /// Return a reference to the dimension + pub fn raw_dim(&self) -> &D { + &self.dim + } /// Return the size of the shape in number of elements pub fn size(&self) -> usize { self.dim.size()