Skip to content

Commit b0f9669

Browse files
committed
FEAT: Add dimension::squeeze to remove dimensions with len == 1
1 parent be5e8c8 commit b0f9669

File tree

1 file changed

+65
-1
lines changed

1 file changed

+65
-1
lines changed

src/dimension/mod.rs

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,12 +731,46 @@ where
731731
}
732732
}
733733

734+
/// Remove axes with length one, except never removing the last axis.
735+
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
736+
where
737+
D: Dimension,
738+
{
739+
if let Some(_) = D::NDIM {
740+
return;
741+
}
742+
debug_assert_eq!(dim.ndim(), strides.ndim());
743+
744+
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
745+
let mut ndim_new = 0;
746+
for &d in dim.slice() {
747+
if d != 1 { ndim_new += 1; }
748+
}
749+
ndim_new = Ord::max(1, ndim_new);
750+
let mut new_dim = D::zeros(ndim_new);
751+
let mut new_strides = D::zeros(ndim_new);
752+
let mut i = 0;
753+
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
754+
if d != 1 {
755+
new_dim[i] = d;
756+
new_strides[i] = s;
757+
i += 1;
758+
}
759+
}
760+
if i == 0 {
761+
new_dim[i] = 1;
762+
new_strides[i] = 1;
763+
}
764+
*dim = new_dim;
765+
*strides = new_strides;
766+
}
767+
734768
#[cfg(test)]
735769
mod test {
736770
use super::{
737771
arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
738772
max_abs_offset_check_overflow, slice_min_max, slices_intersect,
739-
solve_linear_diophantine_eq, IntoDimension,
773+
solve_linear_diophantine_eq, IntoDimension, squeeze,
740774
};
741775
use crate::error::{from_kind, ErrorKind};
742776
use crate::slice::Slice;
@@ -1055,4 +1089,34 @@ mod test {
10551089
s![.., 3..;6, NewAxis]
10561090
));
10571091
}
1092+
1093+
#[test]
1094+
#[cfg(feature = "std")]
1095+
fn test_squeeze() {
1096+
let dyndim = Dim::<&[usize]>;
1097+
1098+
let mut d = dyndim(&[1, 2, 1, 1, 3, 1]);
1099+
let mut s = dyndim(&[!0, !0, !0, 9, 10, !0]);
1100+
let dans = dyndim(&[2, 3]);
1101+
let sans = dyndim(&[!0, 10]);
1102+
squeeze(&mut d, &mut s);
1103+
assert_eq!(d, dans);
1104+
assert_eq!(s, sans);
1105+
1106+
let mut d = dyndim(&[1, 1]);
1107+
let mut s = dyndim(&[3, 4]);
1108+
let dans = dyndim(&[1]);
1109+
let sans = dyndim(&[1]);
1110+
squeeze(&mut d, &mut s);
1111+
assert_eq!(d, dans);
1112+
assert_eq!(s, sans);
1113+
1114+
let mut d = dyndim(&[0, 1, 3, 4]);
1115+
let mut s = dyndim(&[2, 3, 4, 5]);
1116+
let dans = dyndim(&[0, 3, 4]);
1117+
let sans = dyndim(&[2, 4, 5]);
1118+
squeeze(&mut d, &mut s);
1119+
assert_eq!(d, dans);
1120+
assert_eq!(s, sans);
1121+
}
10581122
}

0 commit comments

Comments
 (0)