@@ -701,6 +701,33 @@ where
701
701
}
702
702
}
703
703
704
+ /// Attempt to merge axes if possible, starting from the back
705
+ ///
706
+ /// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
707
+ /// to merge all axes one by one into Axis(3); when/if this fails,
708
+ /// it attempts to merge the rest of the axes together into the next
709
+ /// axis in line, for example a result could be:
710
+ ///
711
+ /// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
712
+ /// mean axes were merged.
713
+ pub ( crate ) fn merge_axes_from_the_back < D > ( dim : & mut D , strides : & mut D )
714
+ where
715
+ D : Dimension ,
716
+ {
717
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
718
+ match dim. ndim ( ) {
719
+ 0 | 1 => { }
720
+ n => {
721
+ let mut last = n - 1 ;
722
+ for i in ( 0 ..last) . rev ( ) {
723
+ if !merge_axes ( dim, strides, Axis ( i) , Axis ( last) ) {
724
+ last = i;
725
+ }
726
+ }
727
+ }
728
+ }
729
+ }
730
+
704
731
/// Move the axis which has the smallest absolute stride and a length
705
732
/// greater than one to be the last axis.
706
733
pub fn move_min_stride_axis_to_last < D > ( dim : & mut D , strides : & mut D )
@@ -765,12 +792,40 @@ where
765
792
* strides = new_strides;
766
793
}
767
794
795
+
796
+ /// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
797
+ /// stride
798
+ ///
799
+ /// The axes are sorted according to the .abs() of their stride.
800
+ pub ( crate ) fn sort_axes_to_standard < D > ( dim : & mut D , strides : & mut D )
801
+ where
802
+ D : Dimension ,
803
+ {
804
+ debug_assert ! ( dim. ndim( ) > 1 ) ;
805
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
806
+ // bubble sort axes
807
+ let mut changed = true ;
808
+ while changed {
809
+ changed = false ;
810
+ for i in 0 ..dim. ndim ( ) - 1 {
811
+ // make sure higher stride axes sort before.
812
+ if strides. get_stride ( Axis ( i) ) . abs ( ) < strides. get_stride ( Axis ( i + 1 ) ) . abs ( ) {
813
+ changed = true ;
814
+ dim. slice_mut ( ) . swap ( i, i + 1 ) ;
815
+ strides. slice_mut ( ) . swap ( i, i + 1 ) ;
816
+ }
817
+ }
818
+ }
819
+ }
820
+
821
+
768
822
#[ cfg( test) ]
769
823
mod test {
770
824
use super :: {
771
825
arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
772
826
max_abs_offset_check_overflow, slice_min_max, slices_intersect,
773
827
solve_linear_diophantine_eq, IntoDimension , squeeze,
828
+ merge_axes_from_the_back,
774
829
} ;
775
830
use crate :: error:: { from_kind, ErrorKind } ;
776
831
use crate :: slice:: Slice ;
@@ -1119,4 +1174,26 @@ mod test {
1119
1174
assert_eq ! ( d, dans) ;
1120
1175
assert_eq ! ( s, sans) ;
1121
1176
}
1177
+
1178
+ #[ test]
1179
+ fn test_merge_axes_from_the_back ( ) {
1180
+ let dyndim = Dim :: < & [ usize ] > ;
1181
+
1182
+ let mut d = Dim ( [ 3 , 4 , 5 ] ) ;
1183
+ let mut s = Dim ( [ 20 , 5 , 1 ] ) ;
1184
+ merge_axes_from_the_back ( & mut d, & mut s) ;
1185
+ assert_eq ! ( d, Dim ( [ 1 , 1 , 60 ] ) ) ;
1186
+ assert_eq ! ( s, Dim ( [ 20 , 5 , 1 ] ) ) ;
1187
+
1188
+ let mut d = Dim ( [ 3 , 4 , 5 , 2 ] ) ;
1189
+ let mut s = Dim ( [ 80 , 20 , 2 , 1 ] ) ;
1190
+ merge_axes_from_the_back ( & mut d, & mut s) ;
1191
+ assert_eq ! ( d, Dim ( [ 1 , 12 , 1 , 10 ] ) ) ;
1192
+ assert_eq ! ( s, Dim ( [ 80 , 20 , 2 , 1 ] ) ) ;
1193
+ let mut d = d. into_dyn ( ) ;
1194
+ let mut s = s. into_dyn ( ) ;
1195
+ squeeze ( & mut d, & mut s) ;
1196
+ assert_eq ! ( d, dyndim( & [ 12 , 10 ] ) ) ;
1197
+ assert_eq ! ( s, dyndim( & [ 20 , 1 ] ) ) ;
1198
+ }
1122
1199
}
0 commit comments