Skip to content

Commit bcd7078

Browse files
jturner314bluss
authored andcommitted
Add accumulate_axis_inplace method
1 parent 71c8c8f commit bcd7078

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

src/impl_methods.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,4 +2231,60 @@ where
22312231
})
22322232
}
22332233
}
2234+
2235+
/// Iterates over pairs of consecutive elements along the axis.
2236+
///
2237+
/// The first argument to the closure is an element, and the second
2238+
/// argument is the next element along the axis. Iteration is guaranteed to
2239+
/// proceed in order along the specified axis, but in all other respects
2240+
/// the iteration order is unspecified.
2241+
///
2242+
/// # Example
2243+
///
2244+
/// For example, this can be used to compute the cumulative sum along an
2245+
/// axis:
2246+
///
2247+
/// ```
2248+
/// use ndarray::{array, Axis};
2249+
///
2250+
/// let mut arr = array![
2251+
/// [[1, 2], [3, 4], [5, 6]],
2252+
/// [[7, 8], [9, 10], [11, 12]],
2253+
/// ];
2254+
/// arr.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
2255+
/// assert_eq!(
2256+
/// arr,
2257+
/// array![
2258+
/// [[1, 2], [4, 6], [9, 12]],
2259+
/// [[7, 8], [16, 18], [27, 30]],
2260+
/// ],
2261+
/// );
2262+
/// ```
2263+
pub fn accumulate_axis_inplace<F>(&mut self, axis: Axis, mut f: F)
2264+
where
2265+
F: FnMut(&A, &mut A),
2266+
S: DataMut,
2267+
{
2268+
if self.len_of(axis) <= 1 {
2269+
return;
2270+
}
2271+
let mut prev = self.raw_view();
2272+
prev.slice_axis_inplace(axis, Slice::from(..-1));
2273+
let mut curr = self.raw_view_mut();
2274+
curr.slice_axis_inplace(axis, Slice::from(1..));
2275+
// This implementation relies on `Zip` iterating along `axis` in order.
2276+
Zip::from(prev).and(curr).apply(|prev, curr| unsafe {
2277+
// These pointer dereferences and borrows are safe because:
2278+
//
2279+
// 1. They're pointers to elements in the array.
2280+
//
2281+
// 2. `S: DataMut` guarantees that elements are safe to borrow
2282+
// mutably and that they don't alias.
2283+
//
2284+
// 3. The lifetimes of the borrows last only for the duration
2285+
// of the call to `f`, so aliasing across calls to `f`
2286+
// cannot occur.
2287+
f(&*prev, &mut *curr)
2288+
});
2289+
}
22342290
}

tests/array.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1952,6 +1952,47 @@ fn test_map_axis() {
19521952
itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4);
19531953
}
19541954

1955+
#[test]
1956+
fn test_accumulate_axis_inplace_noop() {
1957+
let mut a = Array2::<u8>::zeros((0, 3));
1958+
a.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
1959+
assert_eq!(a, Array2::zeros((0, 3)));
1960+
1961+
let mut a = Array2::<u8>::zeros((3, 1));
1962+
a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
1963+
assert_eq!(a, Array2::zeros((3, 1)));
1964+
}
1965+
1966+
#[test]
1967+
fn test_accumulate_axis_inplace_nonstandard_layout() {
1968+
let a = arr2(&[[1, 2, 3],
1969+
[4, 5, 6],
1970+
[7, 8, 9],
1971+
[10,11,12]]);
1972+
1973+
let mut a_t = a.clone().reversed_axes();
1974+
a_t.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
1975+
assert_eq!(a_t, aview2(&[[1, 4, 7, 10],
1976+
[3, 9, 15, 21],
1977+
[6, 15, 24, 33]]));
1978+
1979+
let mut a0 = a.clone();
1980+
a0.invert_axis(Axis(0));
1981+
a0.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
1982+
assert_eq!(a0, aview2(&[[10, 11, 12],
1983+
[17, 19, 21],
1984+
[21, 24, 27],
1985+
[22, 26, 30]]));
1986+
1987+
let mut a1 = a.clone();
1988+
a1.invert_axis(Axis(1));
1989+
a1.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
1990+
assert_eq!(a1, aview2(&[[3, 5, 6],
1991+
[6, 11, 15],
1992+
[9, 17, 24],
1993+
[12, 23, 33]]));
1994+
}
1995+
19551996
#[test]
19561997
fn test_to_vec() {
19571998
let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);

0 commit comments

Comments
 (0)