Skip to content

Commit a307ac2

Browse files
committed
FEAT: Add dimension merge function to merge contiguous axes
1 parent b181f32 commit a307ac2

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

src/dimension/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,33 @@ where
701701
}
702702
}
703703

704+
/// Attempt to merge axes if possible, starting from the back
705+
///
706+
/// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
707+
/// to merge all axes one by one into Axis(3); when/if this fails,
708+
/// it attempts to merge the rest of the axes together into the next
709+
/// axis in line, for example a result could be:
710+
///
711+
/// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
712+
/// mean axes were merged.
713+
pub(crate) fn merge_axes_from_the_back<D>(dim: &mut D, strides: &mut D)
714+
where
715+
D: Dimension,
716+
{
717+
debug_assert_eq!(dim.ndim(), strides.ndim());
718+
match dim.ndim() {
719+
0 | 1 => {}
720+
n => {
721+
let mut last = n - 1;
722+
for i in (0..last).rev() {
723+
if !merge_axes(dim, strides, Axis(i), Axis(last)) {
724+
last = i;
725+
}
726+
}
727+
}
728+
}
729+
}
730+
704731
/// Move the axis which has the smallest absolute stride and a length
705732
/// greater than one to be the last axis.
706733
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
@@ -765,12 +792,40 @@ where
765792
*strides = new_strides;
766793
}
767794

795+
796+
/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
797+
/// stride
798+
///
799+
/// The axes are sorted according to the .abs() of their stride.
800+
pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
801+
where
802+
D: Dimension,
803+
{
804+
debug_assert!(dim.ndim() > 1);
805+
debug_assert_eq!(dim.ndim(), strides.ndim());
806+
// bubble sort axes
807+
let mut changed = true;
808+
while changed {
809+
changed = false;
810+
for i in 0..dim.ndim() - 1 {
811+
// make sure higher stride axes sort before.
812+
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
813+
changed = true;
814+
dim.slice_mut().swap(i, i + 1);
815+
strides.slice_mut().swap(i, i + 1);
816+
}
817+
}
818+
}
819+
}
820+
821+
768822
#[cfg(test)]
769823
mod test {
770824
use super::{
771825
arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
772826
max_abs_offset_check_overflow, slice_min_max, slices_intersect,
773827
solve_linear_diophantine_eq, IntoDimension, squeeze,
828+
merge_axes_from_the_back,
774829
};
775830
use crate::error::{from_kind, ErrorKind};
776831
use crate::slice::Slice;
@@ -1119,4 +1174,26 @@ mod test {
11191174
assert_eq!(d, dans);
11201175
assert_eq!(s, sans);
11211176
}
1177+
1178+
#[test]
1179+
fn test_merge_axes_from_the_back() {
1180+
let dyndim = Dim::<&[usize]>;
1181+
1182+
let mut d = Dim([3, 4, 5]);
1183+
let mut s = Dim([20, 5, 1]);
1184+
merge_axes_from_the_back(&mut d, &mut s);
1185+
assert_eq!(d, Dim([1, 1, 60]));
1186+
assert_eq!(s, Dim([20, 5, 1]));
1187+
1188+
let mut d = Dim([3, 4, 5, 2]);
1189+
let mut s = Dim([80, 20, 2, 1]);
1190+
merge_axes_from_the_back(&mut d, &mut s);
1191+
assert_eq!(d, Dim([1, 12, 1, 10]));
1192+
assert_eq!(s, Dim([80, 20, 2, 1]));
1193+
let mut d = d.into_dyn();
1194+
let mut s = s.into_dyn();
1195+
squeeze(&mut d, &mut s);
1196+
assert_eq!(d, dyndim(&[12, 10]));
1197+
assert_eq!(s, dyndim(&[20, 1]));
1198+
}
11221199
}

0 commit comments

Comments
 (0)