@@ -785,6 +785,43 @@ where D: Dimension
785
785
}
786
786
}
787
787
788
+ /// Remove axes with length one, except never removing the last axis.
789
+ ///
790
+ /// This only has effect on dynamic dimensions.
791
+ pub ( crate ) fn squeeze < D > ( dim : & mut D , strides : & mut D )
792
+ where D : Dimension
793
+ {
794
+ if let Some ( _) = D :: NDIM {
795
+ return ;
796
+ }
797
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
798
+
799
+ // Count axes with dim == 1; we keep axes with d == 0 or d > 1
800
+ let mut ndim_new = 0 ;
801
+ for & d in dim. slice ( ) {
802
+ if d != 1 {
803
+ ndim_new += 1 ;
804
+ }
805
+ }
806
+ ndim_new = Ord :: max ( 1 , ndim_new) ;
807
+ let mut new_dim = D :: zeros ( ndim_new) ;
808
+ let mut new_strides = D :: zeros ( ndim_new) ;
809
+ let mut i = 0 ;
810
+ for ( & d, & s) in izip ! ( dim. slice( ) , strides. slice( ) ) {
811
+ if d != 1 {
812
+ new_dim[ i] = d;
813
+ new_strides[ i] = s;
814
+ i += 1 ;
815
+ }
816
+ }
817
+ if i == 0 {
818
+ new_dim[ i] = 1 ;
819
+ new_strides[ i] = 1 ;
820
+ }
821
+ * dim = new_dim;
822
+ * strides = new_strides;
823
+ }
824
+
788
825
#[ cfg( test) ]
789
826
mod test
790
827
{
@@ -797,6 +834,7 @@ mod test
797
834
slice_min_max,
798
835
slices_intersect,
799
836
solve_linear_diophantine_eq,
837
+ squeeze,
800
838
IntoDimension ,
801
839
} ;
802
840
use crate :: error:: { from_kind, ErrorKind } ;
@@ -1146,4 +1184,35 @@ mod test
1146
1184
s![ .., 3 ..; 6 , NewAxis ]
1147
1185
) ) ;
1148
1186
}
1187
+
1188
+ #[ test]
1189
+ #[ cfg( feature = "std" ) ]
1190
+ fn test_squeeze ( )
1191
+ {
1192
+ let dyndim = Dim :: < & [ usize ] > ;
1193
+
1194
+ let mut d = dyndim ( & [ 1 , 2 , 1 , 1 , 3 , 1 ] ) ;
1195
+ let mut s = dyndim ( & [ !0 , !0 , !0 , 9 , 10 , !0 ] ) ;
1196
+ let dans = dyndim ( & [ 2 , 3 ] ) ;
1197
+ let sans = dyndim ( & [ !0 , 10 ] ) ;
1198
+ squeeze ( & mut d, & mut s) ;
1199
+ assert_eq ! ( d, dans) ;
1200
+ assert_eq ! ( s, sans) ;
1201
+
1202
+ let mut d = dyndim ( & [ 1 , 1 ] ) ;
1203
+ let mut s = dyndim ( & [ 3 , 4 ] ) ;
1204
+ let dans = dyndim ( & [ 1 ] ) ;
1205
+ let sans = dyndim ( & [ 1 ] ) ;
1206
+ squeeze ( & mut d, & mut s) ;
1207
+ assert_eq ! ( d, dans) ;
1208
+ assert_eq ! ( s, sans) ;
1209
+
1210
+ let mut d = dyndim ( & [ 0 , 1 , 3 , 4 ] ) ;
1211
+ let mut s = dyndim ( & [ 2 , 3 , 4 , 5 ] ) ;
1212
+ let dans = dyndim ( & [ 0 , 3 , 4 ] ) ;
1213
+ let sans = dyndim ( & [ 2 , 4 , 5 ] ) ;
1214
+ squeeze ( & mut d, & mut s) ;
1215
+ assert_eq ! ( d, dans) ;
1216
+ assert_eq ! ( s, sans) ;
1217
+ }
1149
1218
}
0 commit comments