@@ -731,12 +731,46 @@ where
731
731
}
732
732
}
733
733
734
+ /// Remove axes with length one, except never removing the last axis.
735
+ pub ( crate ) fn squeeze < D > ( dim : & mut D , strides : & mut D )
736
+ where
737
+ D : Dimension ,
738
+ {
739
+ if let Some ( _) = D :: NDIM {
740
+ return ;
741
+ }
742
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
743
+
744
+ // Count axes with dim == 1; we keep axes with d == 0 or d > 1
745
+ let mut ndim_new = 0 ;
746
+ for & d in dim. slice ( ) {
747
+ if d != 1 { ndim_new += 1 ; }
748
+ }
749
+ ndim_new = Ord :: max ( 1 , ndim_new) ;
750
+ let mut new_dim = D :: zeros ( ndim_new) ;
751
+ let mut new_strides = D :: zeros ( ndim_new) ;
752
+ let mut i = 0 ;
753
+ for ( & d, & s) in izip ! ( dim. slice( ) , strides. slice( ) ) {
754
+ if d != 1 {
755
+ new_dim[ i] = d;
756
+ new_strides[ i] = s;
757
+ i += 1 ;
758
+ }
759
+ }
760
+ if i == 0 {
761
+ new_dim[ i] = 1 ;
762
+ new_strides[ i] = 1 ;
763
+ }
764
+ * dim = new_dim;
765
+ * strides = new_strides;
766
+ }
767
+
734
768
#[ cfg( test) ]
735
769
mod test {
736
770
use super :: {
737
771
arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
738
772
max_abs_offset_check_overflow, slice_min_max, slices_intersect,
739
- solve_linear_diophantine_eq, IntoDimension ,
773
+ solve_linear_diophantine_eq, IntoDimension , squeeze ,
740
774
} ;
741
775
use crate :: error:: { from_kind, ErrorKind } ;
742
776
use crate :: slice:: Slice ;
@@ -1055,4 +1089,34 @@ mod test {
1055
1089
s![ .., 3 ..; 6 , NewAxis ]
1056
1090
) ) ;
1057
1091
}
1092
+
1093
+ #[ test]
1094
+ #[ cfg( feature = "std" ) ]
1095
+ fn test_squeeze ( ) {
1096
+ let dyndim = Dim :: < & [ usize ] > ;
1097
+
1098
+ let mut d = dyndim ( & [ 1 , 2 , 1 , 1 , 3 , 1 ] ) ;
1099
+ let mut s = dyndim ( & [ !0 , !0 , !0 , 9 , 10 , !0 ] ) ;
1100
+ let dans = dyndim ( & [ 2 , 3 ] ) ;
1101
+ let sans = dyndim ( & [ !0 , 10 ] ) ;
1102
+ squeeze ( & mut d, & mut s) ;
1103
+ assert_eq ! ( d, dans) ;
1104
+ assert_eq ! ( s, sans) ;
1105
+
1106
+ let mut d = dyndim ( & [ 1 , 1 ] ) ;
1107
+ let mut s = dyndim ( & [ 3 , 4 ] ) ;
1108
+ let dans = dyndim ( & [ 1 ] ) ;
1109
+ let sans = dyndim ( & [ 1 ] ) ;
1110
+ squeeze ( & mut d, & mut s) ;
1111
+ assert_eq ! ( d, dans) ;
1112
+ assert_eq ! ( s, sans) ;
1113
+
1114
+ let mut d = dyndim ( & [ 0 , 1 , 3 , 4 ] ) ;
1115
+ let mut s = dyndim ( & [ 2 , 3 , 4 , 5 ] ) ;
1116
+ let dans = dyndim ( & [ 0 , 3 , 4 ] ) ;
1117
+ let sans = dyndim ( & [ 2 , 4 , 5 ] ) ;
1118
+ squeeze ( & mut d, & mut s) ;
1119
+ assert_eq ! ( d, dans) ;
1120
+ assert_eq ! ( s, sans) ;
1121
+ }
1058
1122
}
0 commit comments