diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 3b14ea221..1359b8f39 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -678,6 +678,36 @@ where } } +/// Move the axis which has the smallest absolute stride and a length +/// greater than one to be the last axis. +pub fn move_min_stride_axis_to_last(dim: &mut D, strides: &mut D) +where + D: Dimension, +{ + debug_assert_eq!(dim.ndim(), strides.ndim()); + match dim.ndim() { + 0 | 1 => {} + 2 => { + if dim[1] <= 1 + || dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs() + { + dim.slice_mut().swap(0, 1); + strides.slice_mut().swap(0, 1); + } + } + n => { + if let Some(min_stride_axis) = (0..n) + .filter(|&ax| dim[ax] > 1) + .min_by_key(|&ax| (strides[ax] as isize).abs()) + { + let last = n - 1; + dim.slice_mut().swap(last, min_stride_axis); + strides.slice_mut().swap(last, min_stride_axis); + } + } + } +} + #[cfg(test)] mod test { use super::{ diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 7b6f3f6f5..81002291f 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -18,8 +18,8 @@ use crate::arraytraits; use crate::dimension; use crate::dimension::IntoDimension; use crate::dimension::{ - abs_index, axes_of, do_slice, merge_axes, offset_from_ptr_to_memory, size_of_shape_checked, - stride_offset, Axes, + abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last, + offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; use crate::error::{self, ErrorKind, ShapeError}; use crate::math_cell::MathCell; @@ -1456,6 +1456,15 @@ where /// Return the array’s data as a slice if it is contiguous, /// return `None` otherwise. pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]> + where + S: DataMut, + { + self.try_as_slice_memory_order_mut().ok() + } + + /// Return the array’s data as a slice if it is contiguous, otherwise + /// return `self` in the `Err` variant. + pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self> where S: DataMut, { @@ -1463,13 +1472,13 @@ where self.ensure_unique(); let offset = offset_from_ptr_to_memory(&self.dim, &self.strides); unsafe { - Some(slice::from_raw_parts_mut( + Ok(slice::from_raw_parts_mut( self.ptr.offset(offset).as_ptr(), self.len(), )) } } else { - None + Err(self) } } @@ -1976,7 +1985,7 @@ where S: DataMut, A: Clone, { - self.unordered_foreach_mut(move |elt| *elt = x.clone()); + self.map_inplace(move |elt| *elt = x.clone()); } fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) @@ -2028,7 +2037,7 @@ where S: DataMut, F: FnMut(&mut A, &B), { - self.unordered_foreach_mut(move |elt| f(elt, rhs_elem)); + self.map_inplace(move |elt| f(elt, rhs_elem)); } /// Traverse two arrays in unspecified order, in lock step, @@ -2070,27 +2079,7 @@ where slc.iter().fold(init, f) } else { let mut v = self.view(); - // put the narrowest axis at the last position - match v.ndim() { - 0 | 1 => {} - 2 => { - if self.len_of(Axis(1)) <= 1 - || self.len_of(Axis(0)) > 1 - && self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs() - { - v.swap_axes(0, 1); - } - } - n => { - let last = n - 1; - let narrow_axis = v - .axes() - .filter(|ax| ax.len() > 1) - .min_by_key(|ax| ax.stride().abs()) - .map_or(last, |ax| ax.axis().index()); - v.swap_axes(last, narrow_axis); - } - } + move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); v.into_elements_base().fold(init, f) } } @@ -2200,12 +2189,20 @@ where /// Modify the array in place by calling `f` by mutable reference on each element. /// /// Elements are visited in arbitrary order. - pub fn map_inplace(&mut self, f: F) + pub fn map_inplace<'a, F>(&'a mut self, f: F) where S: DataMut, - F: FnMut(&mut A), - { - self.unordered_foreach_mut(f); + A: 'a, + F: FnMut(&'a mut A), + { + match self.try_as_slice_memory_order_mut() { + Ok(slc) => slc.iter_mut().for_each(f), + Err(arr) => { + let mut v = arr.view_mut(); + move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); + v.into_elements_base().for_each(f); + } + } } /// Modify the array in place by calling `f` by **v**alue on each element. @@ -2235,7 +2232,7 @@ where F: FnMut(A) -> A, A: Clone, { - self.unordered_foreach_mut(move |x| *x = f(x.clone())); + self.map_inplace(move |x| *x = f(x.clone())); } /// Call `f` for each element in the array. diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 51d432ee6..256bee3e5 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -141,7 +141,7 @@ impl $trt for ArrayBase { type Output = ArrayBase; fn $mth(mut self, x: B) -> ArrayBase { - self.unordered_foreach_mut(move |elt| { + self.map_inplace(move |elt| { *elt = elt.clone() $operator x.clone(); }); self @@ -194,7 +194,7 @@ impl $trt> for $scalar rhs.$mth(self) } or {{ let mut rhs = rhs; - rhs.unordered_foreach_mut(move |elt| { + rhs.map_inplace(move |elt| { *elt = self $operator *elt; }); rhs @@ -299,7 +299,7 @@ mod arithmetic_ops { type Output = Self; /// Perform an elementwise negation of `self` and return the result. fn neg(mut self) -> Self { - self.unordered_foreach_mut(|elt| { + self.map_inplace(|elt| { *elt = -elt.clone(); }); self @@ -329,7 +329,7 @@ mod arithmetic_ops { type Output = Self; /// Perform an elementwise unary not of `self` and return the result. fn not(mut self) -> Self { - self.unordered_foreach_mut(|elt| { + self.map_inplace(|elt| { *elt = !elt.clone(); }); self @@ -386,7 +386,7 @@ mod assign_ops { D: Dimension, { fn $method(&mut self, rhs: A) { - self.unordered_foreach_mut(move |elt| { + self.map_inplace(move |elt| { elt.$method(rhs.clone()); }); } diff --git a/src/lib.rs b/src/lib.rs index c71f09aaf..8faaf5ba7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -143,7 +143,7 @@ pub use crate::indexes::{indices, indices_of}; pub use crate::slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex}; use crate::iterators::Baseiter; -use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, LanesMut}; +use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes}; pub use crate::arraytraits::AsArray; #[cfg(feature = "std")] @@ -1544,22 +1544,6 @@ where self.strides.clone() } - /// Apply closure `f` to each element in the array, in whatever - /// order is the fastest to visit. - fn unordered_foreach_mut(&mut self, mut f: F) - where - S: DataMut, - F: FnMut(&mut A), - { - if let Some(slc) = self.as_slice_memory_order_mut() { - slc.iter_mut().for_each(f); - } else { - for row in self.inner_rows_mut() { - row.into_iter_().fold((), |(), elt| f(elt)); - } - } - } - /// Remove array axis `axis` and return the result. fn try_remove_axis(self, axis: Axis) -> ArrayBase { let d = self.dim.try_remove_axis(axis); @@ -1577,15 +1561,6 @@ where let n = self.ndim(); Lanes::new(self.view(), Axis(n.saturating_sub(1))) } - - /// n-d generalization of rows, just like inner iter - fn inner_rows_mut(&mut self) -> iterators::LanesMut<'_, A, D::Smaller> - where - S: DataMut, - { - let n = self.ndim(); - LanesMut::new(self.view_mut(), Axis(n.saturating_sub(1))) - } } // parallel methods