Skip to content

Commit ef14b21

Browse files
committed
FIX: Solve axis iteration order problem by sorting axes
1 parent c300ec5 commit ef14b21

File tree

3 files changed

+110
-13
lines changed

3 files changed

+110
-13
lines changed

src/impl_owned_array.rs

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ impl<A, D> Array<A, D>
252252
/// [1., 1., 1., 1.],
253253
/// [1., 1., 1., 1.]]);
254254
/// ```
255-
pub fn try_append_array(&mut self, axis: Axis, array: ArrayView<A, D>)
255+
pub fn try_append_array(&mut self, axis: Axis, mut array: ArrayView<A, D>)
256256
-> Result<(), ShapeError>
257257
where
258258
A: Clone,
@@ -312,7 +312,7 @@ impl<A, D> Array<A, D>
312312
// make a raw view with the new row
313313
// safe because the data was "full"
314314
let tail_ptr = self.data.as_end_nonnull();
315-
let tail_view = RawArrayViewMut::new(tail_ptr, array_shape, strides.clone());
315+
let mut tail_view = RawArrayViewMut::new(tail_ptr, array_shape, strides.clone());
316316

317317
struct SetLenOnDrop<'a, A: 'a> {
318318
len: usize,
@@ -332,37 +332,86 @@ impl<A, D> Array<A, D>
332332
}
333333
}
334334

335-
// we have a problem here XXX
336-
//
337335
// To be robust for panics and drop the right elements, we want
338336
// to fill the tail in-order, so that we can drop the right elements on
339-
// panic. Don't know how to achieve that.
337+
// panic.
340338
//
341-
// It might be easier to retrace our steps in a scope guard to drop the right
342-
// elements.. (PartialArray style).
339+
// We have: Zip::from(tail_view).and(array)
340+
// Transform tail_view into standard order by inverting and moving its axes.
341+
// Keep the Zip traversal unchanged by applying the same axis transformations to
342+
// `array`. This ensures the Zip traverses the underlying memory in order.
343343
//
344-
// assign the new elements
344+
// XXX It would be possible to skip this transformation if the element
345+
// doesn't have drop. However, in the interest of code coverage, all elements
346+
// use this code initially.
347+
348+
if tail_view.ndim() > 1 {
349+
for i in 0..tail_view.ndim() {
350+
if tail_view.stride_of(Axis(i)) < 0 {
351+
tail_view.invert_axis(Axis(i));
352+
array.invert_axis(Axis(i));
353+
}
354+
}
355+
sort_axes_to_standard_order(&mut tail_view, &mut array);
356+
}
345357
Zip::from(tail_view).and(array)
358+
.debug_assert_c_order()
346359
.for_each(|to, from| {
347360
to.write(from.clone());
348361
length_guard.len += 1;
349362
});
350363

351-
//length_guard.len += len_to_append;
352-
dbg!(len_to_append);
353364
drop(length_guard);
354365

355366
// update array dimension
356367
self.strides = strides;
357368
self.dim = res_dim;
358-
dbg!(&self.dim);
359-
360369
}
361370
// multiple assertions after pointer & dimension update
362371
debug_assert_eq!(self.data.len(), self.len());
363372
debug_assert_eq!(self.len(), new_len);
364-
debug_assert!(self.is_standard_layout());
365373

366374
Ok(())
367375
}
368376
}
377+
378+
fn sort_axes_to_standard_order<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
379+
where
380+
S: RawData,
381+
S2: RawData,
382+
D: Dimension,
383+
{
384+
if a.ndim() <= 1 {
385+
return;
386+
}
387+
sort_axes_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides);
388+
debug_assert!(a.is_standard_layout());
389+
}
390+
391+
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)
392+
where
393+
D: Dimension,
394+
{
395+
debug_assert!(adim.ndim() > 1);
396+
debug_assert_eq!(adim.ndim(), bdim.ndim());
397+
// bubble sort axes
398+
let mut changed = true;
399+
while changed {
400+
changed = false;
401+
for i in 0..adim.ndim() - 1 {
402+
let axis_i = i;
403+
let next_axis = i + 1;
404+
405+
// make sure higher stride axes sort before.
406+
debug_assert!(astrides.slice()[axis_i] as isize >= 0);
407+
if (astrides.slice()[axis_i] as isize) < astrides.slice()[next_axis] as isize {
408+
changed = true;
409+
adim.slice_mut().swap(axis_i, next_axis);
410+
astrides.slice_mut().swap(axis_i, next_axis);
411+
bdim.slice_mut().swap(axis_i, next_axis);
412+
bstrides.slice_mut().swap(axis_i, next_axis);
413+
}
414+
}
415+
}
416+
}
417+

src/zip/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,13 @@ macro_rules! map_impl {
673673
self.build_and(part)
674674
}
675675

676+
#[allow(unused)]
677+
#[inline]
678+
pub(crate) fn debug_assert_c_order(self) -> Self {
679+
debug_assert!(self.layout.is(CORDER) || self.layout_tendency >= 0);
680+
self
681+
}
682+
676683
fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
677684
where P: NdProducer<Dim=D>,
678685
{

tests/append.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,44 @@ fn append_array1() {
105105
[5., 5., 4., 4.],
106106
[3., 3., 2., 2.]]);
107107
}
108+
109+
#[test]
110+
fn append_array_3d() {
111+
let mut a = Array::zeros((0, 2, 2));
112+
a.try_append_array(Axis(0), array![[[0, 1], [2, 3]]].view()).unwrap();
113+
println!("{:?}", a);
114+
115+
let aa = array![[[51, 52], [53, 54]], [[55, 56], [57, 58]]];
116+
let av = aa.view();
117+
println!("Send {:?} to append", av);
118+
a.try_append_array(Axis(0), av.clone()).unwrap();
119+
120+
a.swap_axes(0, 1);
121+
let aa = array![[[71, 72], [73, 74]], [[75, 76], [77, 78]]];
122+
let mut av = aa.view();
123+
av.swap_axes(0, 1);
124+
println!("Send {:?} to append", av);
125+
a.try_append_array(Axis(1), av.clone()).unwrap();
126+
println!("{:?}", a);
127+
let aa = array![[[81, 82], [83, 84]], [[85, 86], [87, 88]]];
128+
let mut av = aa.view();
129+
av.swap_axes(0, 1);
130+
println!("Send {:?} to append", av);
131+
a.try_append_array(Axis(1), av).unwrap();
132+
println!("{:?}", a);
133+
assert_eq!(a,
134+
array![[[0, 1],
135+
[51, 52],
136+
[55, 56],
137+
[71, 72],
138+
[75, 76],
139+
[81, 82],
140+
[85, 86]],
141+
[[2, 3],
142+
[53, 54],
143+
[57, 58],
144+
[73, 74],
145+
[77, 78],
146+
[83, 84],
147+
[87, 88]]]);
148+
}

0 commit comments

Comments
 (0)