Skip to content

Commit 9447328

Browse files
author
Nil Goyette
authored
Merge pull request #1297 from LazaroHurtado/fix/window_stride
Updated Windows `base` Computations to be Safer
2 parents 17a8d25 + 5bcc73e commit 9447328

File tree

2 files changed

+53
-31
lines changed

2 files changed

+53
-31
lines changed

src/iterators/windows.rs

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::imp_prelude::*;
33
use crate::IntoDimension;
44
use crate::Layout;
55
use crate::NdProducer;
6+
use crate::Slice;
67

78
/// Window producer and iterable
89
///
@@ -24,16 +25,19 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> {
2425

2526
let mut unit_stride = D::zeros(ndim);
2627
unit_stride.slice_mut().fill(1);
27-
28+
2829
Windows::new_with_stride(a, window, unit_stride)
2930
}
3031

31-
pub(crate) fn new_with_stride<E>(a: ArrayView<'a, A, D>, window_size: E, strides: E) -> Self
32+
pub(crate) fn new_with_stride<E>(a: ArrayView<'a, A, D>, window_size: E, axis_strides: E) -> Self
3233
where
3334
E: IntoDimension<Dim = D>,
3435
{
3536
let window = window_size.into_dimension();
36-
let strides_d = strides.into_dimension();
37+
38+
let strides = axis_strides.into_dimension();
39+
let window_strides = a.strides.clone();
40+
3741
ndassert!(
3842
a.ndim() == window.ndim(),
3943
concat!(
@@ -44,45 +48,35 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> {
4448
a.ndim(),
4549
a.shape()
4650
);
51+
4752
ndassert!(
48-
a.ndim() == strides_d.ndim(),
53+
a.ndim() == strides.ndim(),
4954
concat!(
5055
"Stride dimension {} does not match array dimension {} ",
5156
"(with array of shape {:?})"
5257
),
53-
strides_d.ndim(),
58+
strides.ndim(),
5459
a.ndim(),
5560
a.shape()
5661
);
57-
let mut size = a.dim;
58-
for ((sz, &ws), &stride) in size
59-
.slice_mut()
60-
.iter_mut()
61-
.zip(window.slice())
62-
.zip(strides_d.slice())
63-
{
64-
assert_ne!(ws, 0, "window-size must not be zero!");
65-
assert_ne!(stride, 0, "stride cannot have a dimension as zero!");
66-
// cannot use std::cmp::max(0, ..) since arithmetic underflow panics
67-
*sz = if *sz < ws {
68-
0
69-
} else {
70-
((*sz - (ws - 1) - 1) / stride) + 1
71-
};
72-
}
73-
let window_strides = a.strides.clone();
7462

75-
let mut array_strides = a.strides.clone();
76-
for (arr_stride, ix_stride) in array_strides.slice_mut().iter_mut().zip(strides_d.slice()) {
77-
*arr_stride *= ix_stride;
78-
}
63+
let mut base = a;
64+
base.slice_each_axis_inplace(|ax_desc| {
65+
let len = ax_desc.len;
66+
let wsz = window[ax_desc.axis.index()];
67+
let stride = strides[ax_desc.axis.index()];
7968

80-
unsafe {
81-
Windows {
82-
base: ArrayView::new(a.ptr, size, array_strides),
83-
window,
84-
strides: window_strides,
69+
if len < wsz {
70+
Slice::new(0, Some(0), 1)
71+
} else {
72+
Slice::new(0, Some((len - wsz + 1) as isize), stride as isize)
8573
}
74+
});
75+
76+
Windows {
77+
base,
78+
window,
79+
strides: window_strides,
8680
}
8781
}
8882
}

tests/windows.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,31 @@ fn test_window_neg_stride() {
302302
answer.iter()
303303
);
304304
}
305+
306+
#[test]
307+
fn test_windows_with_stride_on_inverted_axis() {
308+
let mut array = Array::from_iter(1..17).into_shape((4, 4)).unwrap();
309+
310+
// inverting axis results in negative stride
311+
array.invert_axis(Axis(0));
312+
itertools::assert_equal(
313+
array.windows_with_stride((2, 2), (2,2)),
314+
vec![
315+
arr2(&[[13, 14], [9, 10]]),
316+
arr2(&[[15, 16], [11, 12]]),
317+
arr2(&[[5, 6], [1, 2]]),
318+
arr2(&[[7, 8], [3, 4]]),
319+
],
320+
);
321+
322+
array.invert_axis(Axis(1));
323+
itertools::assert_equal(
324+
array.windows_with_stride((2, 2), (2,2)),
325+
vec![
326+
arr2(&[[16, 15], [12, 11]]),
327+
arr2(&[[14, 13], [10, 9]]),
328+
arr2(&[[8, 7], [4, 3]]),
329+
arr2(&[[6, 5], [2, 1]]),
330+
],
331+
);
332+
}

0 commit comments

Comments
 (0)