Skip to content

Commit f7b9816

Browse files
authored
Merge pull request #911 from jturner314/improve-map_inplace
Improve `map_inplace`, and use it to replace `unordered_foreach_mut`
2 parents a66f364 + 0dce73a commit f7b9816

File tree

4 files changed

+65
-63
lines changed

4 files changed

+65
-63
lines changed

src/dimension/mod.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,36 @@ where
678678
}
679679
}
680680

681+
/// Move the axis which has the smallest absolute stride and a length
682+
/// greater than one to be the last axis.
683+
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
684+
where
685+
D: Dimension,
686+
{
687+
debug_assert_eq!(dim.ndim(), strides.ndim());
688+
match dim.ndim() {
689+
0 | 1 => {}
690+
2 => {
691+
if dim[1] <= 1
692+
|| dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs()
693+
{
694+
dim.slice_mut().swap(0, 1);
695+
strides.slice_mut().swap(0, 1);
696+
}
697+
}
698+
n => {
699+
if let Some(min_stride_axis) = (0..n)
700+
.filter(|&ax| dim[ax] > 1)
701+
.min_by_key(|&ax| (strides[ax] as isize).abs())
702+
{
703+
let last = n - 1;
704+
dim.slice_mut().swap(last, min_stride_axis);
705+
strides.slice_mut().swap(last, min_stride_axis);
706+
}
707+
}
708+
}
709+
}
710+
681711
#[cfg(test)]
682712
mod test {
683713
use super::{

src/impl_methods.rs

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ use crate::arraytraits;
1818
use crate::dimension;
1919
use crate::dimension::IntoDimension;
2020
use crate::dimension::{
21-
abs_index, axes_of, do_slice, merge_axes, offset_from_ptr_to_memory, size_of_shape_checked,
22-
stride_offset, Axes,
21+
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
22+
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
2323
};
2424
use crate::error::{self, ErrorKind, ShapeError};
2525
use crate::math_cell::MathCell;
@@ -1448,20 +1448,29 @@ where
14481448
/// Return the array’s data as a slice if it is contiguous,
14491449
/// return `None` otherwise.
14501450
pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]>
1451+
where
1452+
S: DataMut,
1453+
{
1454+
self.try_as_slice_memory_order_mut().ok()
1455+
}
1456+
1457+
/// Return the array’s data as a slice if it is contiguous, otherwise
1458+
/// return `self` in the `Err` variant.
1459+
pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self>
14511460
where
14521461
S: DataMut,
14531462
{
14541463
if self.is_contiguous() {
14551464
self.ensure_unique();
14561465
let offset = offset_from_ptr_to_memory(&self.dim, &self.strides);
14571466
unsafe {
1458-
Some(slice::from_raw_parts_mut(
1467+
Ok(slice::from_raw_parts_mut(
14591468
self.ptr.offset(offset).as_ptr(),
14601469
self.len(),
14611470
))
14621471
}
14631472
} else {
1464-
None
1473+
Err(self)
14651474
}
14661475
}
14671476

@@ -1943,7 +1952,7 @@ where
19431952
S: DataMut,
19441953
A: Clone,
19451954
{
1946-
self.unordered_foreach_mut(move |elt| *elt = x.clone());
1955+
self.map_inplace(move |elt| *elt = x.clone());
19471956
}
19481957

19491958
fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
@@ -1995,7 +2004,7 @@ where
19952004
S: DataMut,
19962005
F: FnMut(&mut A, &B),
19972006
{
1998-
self.unordered_foreach_mut(move |elt| f(elt, rhs_elem));
2007+
self.map_inplace(move |elt| f(elt, rhs_elem));
19992008
}
20002009

20012010
/// Traverse two arrays in unspecified order, in lock step,
@@ -2037,27 +2046,7 @@ where
20372046
slc.iter().fold(init, f)
20382047
} else {
20392048
let mut v = self.view();
2040-
// put the narrowest axis at the last position
2041-
match v.ndim() {
2042-
0 | 1 => {}
2043-
2 => {
2044-
if self.len_of(Axis(1)) <= 1
2045-
|| self.len_of(Axis(0)) > 1
2046-
&& self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs()
2047-
{
2048-
v.swap_axes(0, 1);
2049-
}
2050-
}
2051-
n => {
2052-
let last = n - 1;
2053-
let narrow_axis = v
2054-
.axes()
2055-
.filter(|ax| ax.len() > 1)
2056-
.min_by_key(|ax| ax.stride().abs())
2057-
.map_or(last, |ax| ax.axis().index());
2058-
v.swap_axes(last, narrow_axis);
2059-
}
2060-
}
2049+
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
20612050
v.into_elements_base().fold(init, f)
20622051
}
20632052
}
@@ -2167,12 +2156,20 @@ where
21672156
/// Modify the array in place by calling `f` by mutable reference on each element.
21682157
///
21692158
/// Elements are visited in arbitrary order.
2170-
pub fn map_inplace<F>(&mut self, f: F)
2159+
pub fn map_inplace<'a, F>(&'a mut self, f: F)
21712160
where
21722161
S: DataMut,
2173-
F: FnMut(&mut A),
2174-
{
2175-
self.unordered_foreach_mut(f);
2162+
A: 'a,
2163+
F: FnMut(&'a mut A),
2164+
{
2165+
match self.try_as_slice_memory_order_mut() {
2166+
Ok(slc) => slc.iter_mut().for_each(f),
2167+
Err(arr) => {
2168+
let mut v = arr.view_mut();
2169+
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
2170+
v.into_elements_base().for_each(f);
2171+
}
2172+
}
21762173
}
21772174

21782175
/// Modify the array in place by calling `f` by **v**alue on each element.
@@ -2202,7 +2199,7 @@ where
22022199
F: FnMut(A) -> A,
22032200
A: Clone,
22042201
{
2205-
self.unordered_foreach_mut(move |x| *x = f(x.clone()));
2202+
self.map_inplace(move |x| *x = f(x.clone()));
22062203
}
22072204

22082205
/// Call `f` for each element in the array.

src/impl_ops.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
141141
{
142142
type Output = ArrayBase<S, D>;
143143
fn $mth(mut self, x: B) -> ArrayBase<S, D> {
144-
self.unordered_foreach_mut(move |elt| {
144+
self.map_inplace(move |elt| {
145145
*elt = elt.clone() $operator x.clone();
146146
});
147147
self
@@ -194,7 +194,7 @@ impl<S, D> $trt<ArrayBase<S, D>> for $scalar
194194
rhs.$mth(self)
195195
} or {{
196196
let mut rhs = rhs;
197-
rhs.unordered_foreach_mut(move |elt| {
197+
rhs.map_inplace(move |elt| {
198198
*elt = self $operator *elt;
199199
});
200200
rhs
@@ -299,7 +299,7 @@ mod arithmetic_ops {
299299
type Output = Self;
300300
/// Perform an elementwise negation of `self` and return the result.
301301
fn neg(mut self) -> Self {
302-
self.unordered_foreach_mut(|elt| {
302+
self.map_inplace(|elt| {
303303
*elt = -elt.clone();
304304
});
305305
self
@@ -329,7 +329,7 @@ mod arithmetic_ops {
329329
type Output = Self;
330330
/// Perform an elementwise unary not of `self` and return the result.
331331
fn not(mut self) -> Self {
332-
self.unordered_foreach_mut(|elt| {
332+
self.map_inplace(|elt| {
333333
*elt = !elt.clone();
334334
});
335335
self
@@ -386,7 +386,7 @@ mod assign_ops {
386386
D: Dimension,
387387
{
388388
fn $method(&mut self, rhs: A) {
389-
self.unordered_foreach_mut(move |elt| {
389+
self.map_inplace(move |elt| {
390390
elt.$method(rhs.clone());
391391
});
392392
}

src/lib.rs

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ pub use crate::indexes::{indices, indices_of};
143143
pub use crate::slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex};
144144

145145
use crate::iterators::Baseiter;
146-
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, LanesMut};
146+
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes};
147147

148148
pub use crate::arraytraits::AsArray;
149149
#[cfg(feature = "std")]
@@ -1545,22 +1545,6 @@ where
15451545
self.strides.clone()
15461546
}
15471547

1548-
/// Apply closure `f` to each element in the array, in whatever
1549-
/// order is the fastest to visit.
1550-
fn unordered_foreach_mut<F>(&mut self, mut f: F)
1551-
where
1552-
S: DataMut,
1553-
F: FnMut(&mut A),
1554-
{
1555-
if let Some(slc) = self.as_slice_memory_order_mut() {
1556-
slc.iter_mut().for_each(f);
1557-
} else {
1558-
for row in self.inner_rows_mut() {
1559-
row.into_iter_().fold((), |(), elt| f(elt));
1560-
}
1561-
}
1562-
}
1563-
15641548
/// Remove array axis `axis` and return the result.
15651549
fn try_remove_axis(self, axis: Axis) -> ArrayBase<S, D::Smaller> {
15661550
let d = self.dim.try_remove_axis(axis);
@@ -1576,15 +1560,6 @@ where
15761560
let n = self.ndim();
15771561
Lanes::new(self.view(), Axis(n.saturating_sub(1)))
15781562
}
1579-
1580-
/// n-d generalization of rows, just like inner iter
1581-
fn inner_rows_mut(&mut self) -> iterators::LanesMut<'_, A, D::Smaller>
1582-
where
1583-
S: DataMut,
1584-
{
1585-
let n = self.ndim();
1586-
LanesMut::new(self.view_mut(), Axis(n.saturating_sub(1)))
1587-
}
15881563
}
15891564

15901565
// parallel methods

0 commit comments

Comments
 (0)