Skip to content

Commit 326097c

Browse files
committed
FIX: Fix situations where we need to recompute stride
When the axis has length 0, or 1, we need to carefully compute new strides.
1 parent ef14b21 commit 326097c

File tree

2 files changed

+85
-10
lines changed

2 files changed

+85
-10
lines changed

src/impl_owned_array.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ impl<A, D> Array<A, D>
262262
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
263263
}
264264

265+
let current_axis_len = self.len_of(axis);
265266
let remaining_shape = self.raw_dim().remove_axis(axis);
266267
let array_rem_shape = array.raw_dim().remove_axis(axis);
267268

@@ -281,22 +282,46 @@ impl<A, D> Array<A, D>
281282

282283
let self_is_empty = self.is_empty();
283284

284-
// array must be empty or have `axis` as the outermost (longest stride)
285-
// axis
286-
if !(self_is_empty ||
287-
self.axes().max_by_key(|ax| ax.stride).map(|ax| ax.axis) == Some(axis))
288-
{
289-
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
285+
// array must be empty or have `axis` as the outermost (longest stride) axis
286+
if !self_is_empty && current_axis_len > 1 {
287+
// `axis` must be max stride axis or equal to its stride
288+
let max_stride_axis = self.axes().max_by_key(|ax| ax.stride).unwrap();
289+
if max_stride_axis.axis != axis && max_stride_axis.stride > self.stride_of(axis) {
290+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
291+
}
290292
}
291293

292294
// array must be be "full" (have no exterior holes)
293295
if self.len() != self.data.len() {
294296
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
295297
}
298+
296299
let strides = if self_is_empty {
297-
// recompute strides - if the array was previously empty, it could have
298-
// zeros in strides.
299-
res_dim.default_strides()
300+
// recompute strides - if the array was previously empty, it could have zeros in
301+
// strides.
302+
// The new order is based on c/f-contig but must have `axis` as outermost axis.
303+
if axis == Axis(self.ndim() - 1) {
304+
// prefer f-contig when appending to the last axis
305+
// Axis n - 1 is outermost axis
306+
res_dim.fortran_strides()
307+
} else {
308+
// Default with modification
309+
res_dim.slice_mut().swap(0, axis.index());
310+
let mut strides = res_dim.default_strides();
311+
res_dim.slice_mut().swap(0, axis.index());
312+
strides.slice_mut().swap(0, axis.index());
313+
strides
314+
}
315+
} else if current_axis_len == 1 {
316+
// This is the outermost/longest stride axis; so we find the max across the other axes
317+
let new_stride = self.axes().fold(1, |acc, ax| {
318+
if ax.axis == axis { acc } else {
319+
Ord::max(acc, ax.len as isize * ax.stride)
320+
}
321+
});
322+
let mut strides = self.strides.clone();
323+
strides[axis.index()] = new_stride as usize;
324+
strides
300325
} else {
301326
self.strides.clone()
302327
};
@@ -385,7 +410,8 @@ where
385410
return;
386411
}
387412
sort_axes_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides);
388-
debug_assert!(a.is_standard_layout());
413+
debug_assert!(a.is_standard_layout(), "not std layout dim: {:?}, strides: {:?}",
414+
a.shape(), a.strides());
389415
}
390416

391417
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)

tests/append.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,52 @@ fn append_array_3d() {
146146
[83, 84],
147147
[87, 88]]]);
148148
}
149+
150+
#[test]
151+
fn test_append_2d() {
152+
// create an empty array and append
153+
let mut a = Array::zeros((0, 4));
154+
let ones = ArrayView::from(&[1.; 12]).into_shape((3, 4)).unwrap();
155+
let zeros = ArrayView::from(&[0.; 8]).into_shape((2, 4)).unwrap();
156+
a.try_append_array(Axis(0), ones).unwrap();
157+
a.try_append_array(Axis(0), zeros).unwrap();
158+
a.try_append_array(Axis(0), ones).unwrap();
159+
println!("{:?}", a);
160+
assert_eq!(a.shape(), &[8, 4]);
161+
for (i, row) in a.rows().into_iter().enumerate() {
162+
let ones = i < 3 || i >= 5;
163+
assert!(row.iter().all(|&x| x == ones as i32 as f64), "failed on lane {}", i);
164+
}
165+
166+
let mut a = Array::zeros((0, 4));
167+
a = a.reversed_axes();
168+
let ones = ones.reversed_axes();
169+
let zeros = zeros.reversed_axes();
170+
a.try_append_array(Axis(1), ones).unwrap();
171+
a.try_append_array(Axis(1), zeros).unwrap();
172+
a.try_append_array(Axis(1), ones).unwrap();
173+
println!("{:?}", a);
174+
assert_eq!(a.shape(), &[4, 8]);
175+
176+
for (i, row) in a.columns().into_iter().enumerate() {
177+
let ones = i < 3 || i >= 5;
178+
assert!(row.iter().all(|&x| x == ones as i32 as f64), "failed on lane {}", i);
179+
}
180+
}
181+
182+
#[test]
183+
fn test_append_middle_axis() {
184+
// ensure we can append to Axis(1) by letting it become outermost
185+
let mut a = Array::<i32, _>::zeros((3, 0, 2));
186+
a.try_append_array(Axis(1), Array::from_iter(0..12).into_shape((3, 2, 2)).unwrap().view()).unwrap();
187+
println!("{:?}", a);
188+
a.try_append_array(Axis(1), Array::from_iter(12..24).into_shape((3, 2, 2)).unwrap().view()).unwrap();
189+
println!("{:?}", a);
190+
191+
// ensure we can append to Axis(1) by letting it become outermost
192+
let mut a = Array::<i32, _>::zeros((3, 1, 2));
193+
a.try_append_array(Axis(1), Array::from_iter(0..12).into_shape((3, 2, 2)).unwrap().view()).unwrap();
194+
println!("{:?}", a);
195+
a.try_append_array(Axis(1), Array::from_iter(12..24).into_shape((3, 2, 2)).unwrap().view()).unwrap();
196+
println!("{:?}", a);
197+
}

0 commit comments

Comments
 (0)