Skip to content

Commit d223bc6

Browse files
committed
FEAT: Add dimension::squeeze to remove dimensions with len == 1
1 parent 572dea0 commit d223bc6

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

src/dimension/mod.rs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,43 @@ where D: Dimension
785785
}
786786
}
787787

788+
/// Remove axes with length one, except never removing the last axis.
789+
///
790+
/// This only has effect on dynamic dimensions.
791+
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
792+
where D: Dimension
793+
{
794+
if let Some(_) = D::NDIM {
795+
return;
796+
}
797+
debug_assert_eq!(dim.ndim(), strides.ndim());
798+
799+
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
800+
let mut ndim_new = 0;
801+
for &d in dim.slice() {
802+
if d != 1 {
803+
ndim_new += 1;
804+
}
805+
}
806+
ndim_new = Ord::max(1, ndim_new);
807+
let mut new_dim = D::zeros(ndim_new);
808+
let mut new_strides = D::zeros(ndim_new);
809+
let mut i = 0;
810+
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
811+
if d != 1 {
812+
new_dim[i] = d;
813+
new_strides[i] = s;
814+
i += 1;
815+
}
816+
}
817+
if i == 0 {
818+
new_dim[i] = 1;
819+
new_strides[i] = 1;
820+
}
821+
*dim = new_dim;
822+
*strides = new_strides;
823+
}
824+
788825
#[cfg(test)]
789826
mod test
790827
{
@@ -797,6 +834,7 @@ mod test
797834
slice_min_max,
798835
slices_intersect,
799836
solve_linear_diophantine_eq,
837+
squeeze,
800838
IntoDimension,
801839
};
802840
use crate::error::{from_kind, ErrorKind};
@@ -1146,4 +1184,35 @@ mod test
11461184
s![.., 3..;6, NewAxis]
11471185
));
11481186
}
1187+
1188+
#[test]
1189+
#[cfg(feature = "std")]
1190+
fn test_squeeze()
1191+
{
1192+
let dyndim = Dim::<&[usize]>;
1193+
1194+
let mut d = dyndim(&[1, 2, 1, 1, 3, 1]);
1195+
let mut s = dyndim(&[!0, !0, !0, 9, 10, !0]);
1196+
let dans = dyndim(&[2, 3]);
1197+
let sans = dyndim(&[!0, 10]);
1198+
squeeze(&mut d, &mut s);
1199+
assert_eq!(d, dans);
1200+
assert_eq!(s, sans);
1201+
1202+
let mut d = dyndim(&[1, 1]);
1203+
let mut s = dyndim(&[3, 4]);
1204+
let dans = dyndim(&[1]);
1205+
let sans = dyndim(&[1]);
1206+
squeeze(&mut d, &mut s);
1207+
assert_eq!(d, dans);
1208+
assert_eq!(s, sans);
1209+
1210+
let mut d = dyndim(&[0, 1, 3, 4]);
1211+
let mut s = dyndim(&[2, 3, 4, 5]);
1212+
let dans = dyndim(&[0, 3, 4]);
1213+
let sans = dyndim(&[2, 4, 5]);
1214+
squeeze(&mut d, &mut s);
1215+
assert_eq!(d, dans);
1216+
assert_eq!(s, sans);
1217+
}
11491218
}

0 commit comments

Comments
 (0)