Skip to content

Commit a064023

Browse files
committed
Lengthen life in map_inplace and unify impl with fold
1 parent cb544a0 commit a064023

File tree

2 files changed

+56
-29
lines changed

2 files changed

+56
-29
lines changed

src/dimension/mod.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,36 @@ where
678678
}
679679
}
680680

681+
/// Move the axis which has the smallest absolute stride and a length
682+
/// greater than one to be the last axis.
683+
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
684+
where
685+
D: Dimension,
686+
{
687+
debug_assert_eq!(dim.ndim(), strides.ndim());
688+
match dim.ndim() {
689+
0 | 1 => {}
690+
2 => {
691+
if dim[1] <= 1
692+
|| dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs()
693+
{
694+
dim.slice_mut().swap(0, 1);
695+
strides.slice_mut().swap(0, 1);
696+
}
697+
}
698+
n => {
699+
if let Some(min_stride_axis) = (0..n)
700+
.filter(|&ax| dim[ax] > 1)
701+
.min_by_key(|&ax| (strides[ax] as isize).abs())
702+
{
703+
let last = n - 1;
704+
dim.slice_mut().swap(last, min_stride_axis);
705+
strides.slice_mut().swap(last, min_stride_axis);
706+
}
707+
}
708+
}
709+
}
710+
681711
#[cfg(test)]
682712
mod test {
683713
use super::{

src/impl_methods.rs

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ use crate::arraytraits;
1818
use crate::dimension;
1919
use crate::dimension::IntoDimension;
2020
use crate::dimension::{
21-
abs_index, axes_of, do_slice, merge_axes, offset_from_ptr_to_memory, size_of_shape_checked,
22-
stride_offset, Axes,
21+
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
22+
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
2323
};
2424
use crate::error::{self, ErrorKind, ShapeError};
2525
use crate::math_cell::MathCell;
@@ -1456,20 +1456,29 @@ where
14561456
/// Return the array’s data as a slice if it is contiguous,
14571457
/// return `None` otherwise.
14581458
pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]>
1459+
where
1460+
S: DataMut,
1461+
{
1462+
self.try_as_slice_memory_order_mut().ok()
1463+
}
1464+
1465+
/// Return the array’s data as a slice if it is contiguous, otherwise
1466+
/// return `self` in the `Err` variant.
1467+
pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self>
14591468
where
14601469
S: DataMut,
14611470
{
14621471
if self.is_contiguous() {
14631472
self.ensure_unique();
14641473
let offset = offset_from_ptr_to_memory(&self.dim, &self.strides);
14651474
unsafe {
1466-
Some(slice::from_raw_parts_mut(
1475+
Ok(slice::from_raw_parts_mut(
14671476
self.ptr.offset(offset).as_ptr(),
14681477
self.len(),
14691478
))
14701479
}
14711480
} else {
1472-
None
1481+
Err(self)
14731482
}
14741483
}
14751484

@@ -2070,27 +2079,7 @@ where
20702079
slc.iter().fold(init, f)
20712080
} else {
20722081
let mut v = self.view();
2073-
// put the narrowest axis at the last position
2074-
match v.ndim() {
2075-
0 | 1 => {}
2076-
2 => {
2077-
if self.len_of(Axis(1)) <= 1
2078-
|| self.len_of(Axis(0)) > 1
2079-
&& self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs()
2080-
{
2081-
v.swap_axes(0, 1);
2082-
}
2083-
}
2084-
n => {
2085-
let last = n - 1;
2086-
let narrow_axis = v
2087-
.axes()
2088-
.filter(|ax| ax.len() > 1)
2089-
.min_by_key(|ax| ax.stride().abs())
2090-
.map_or(last, |ax| ax.axis().index());
2091-
v.swap_axes(last, narrow_axis);
2092-
}
2093-
}
2082+
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
20942083
v.into_elements_base().fold(init, f)
20952084
}
20962085
}
@@ -2200,12 +2189,20 @@ where
22002189
/// Modify the array in place by calling `f` by mutable reference on each element.
22012190
///
22022191
/// Elements are visited in arbitrary order.
2203-
pub fn map_inplace<F>(&mut self, f: F)
2192+
pub fn map_inplace<'a, F>(&'a mut self, f: F)
22042193
where
22052194
S: DataMut,
2206-
F: FnMut(&mut A),
2207-
{
2208-
self.unordered_foreach_mut(f);
2195+
A: 'a,
2196+
F: FnMut(&'a mut A),
2197+
{
2198+
match self.try_as_slice_memory_order_mut() {
2199+
Ok(slc) => slc.iter_mut().for_each(f),
2200+
Err(arr) => {
2201+
let mut v = arr.view_mut();
2202+
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
2203+
v.into_elements_base().for_each(f);
2204+
}
2205+
}
22092206
}
22102207

22112208
/// Modify the array in place by calling `f` by **v**alue on each element.

0 commit comments

Comments
 (0)