Skip to content

Commit 380be98

Browse files
committed
FEAT: Add dimension::squeeze to remove dimensions with len == 1
1 parent 572dea0 commit 380be98

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

src/dimension/mod.rs

+67
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,41 @@ where D: Dimension
785785
}
786786
}
787787

788+
/// Remove axes with length one, except never removing the last axis.
789+
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
790+
where D: Dimension
791+
{
792+
if let Some(_) = D::NDIM {
793+
return;
794+
}
795+
debug_assert_eq!(dim.ndim(), strides.ndim());
796+
797+
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
798+
let mut ndim_new = 0;
799+
for &d in dim.slice() {
800+
if d != 1 {
801+
ndim_new += 1;
802+
}
803+
}
804+
ndim_new = Ord::max(1, ndim_new);
805+
let mut new_dim = D::zeros(ndim_new);
806+
let mut new_strides = D::zeros(ndim_new);
807+
let mut i = 0;
808+
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
809+
if d != 1 {
810+
new_dim[i] = d;
811+
new_strides[i] = s;
812+
i += 1;
813+
}
814+
}
815+
if i == 0 {
816+
new_dim[i] = 1;
817+
new_strides[i] = 1;
818+
}
819+
*dim = new_dim;
820+
*strides = new_strides;
821+
}
822+
788823
#[cfg(test)]
789824
mod test
790825
{
@@ -797,6 +832,7 @@ mod test
797832
slice_min_max,
798833
slices_intersect,
799834
solve_linear_diophantine_eq,
835+
squeeze,
800836
IntoDimension,
801837
};
802838
use crate::error::{from_kind, ErrorKind};
@@ -1146,4 +1182,35 @@ mod test
11461182
s![.., 3..;6, NewAxis]
11471183
));
11481184
}
1185+
1186+
#[test]
1187+
#[cfg(feature = "std")]
1188+
fn test_squeeze()
1189+
{
1190+
let dyndim = Dim::<&[usize]>;
1191+
1192+
let mut d = dyndim(&[1, 2, 1, 1, 3, 1]);
1193+
let mut s = dyndim(&[!0, !0, !0, 9, 10, !0]);
1194+
let dans = dyndim(&[2, 3]);
1195+
let sans = dyndim(&[!0, 10]);
1196+
squeeze(&mut d, &mut s);
1197+
assert_eq!(d, dans);
1198+
assert_eq!(s, sans);
1199+
1200+
let mut d = dyndim(&[1, 1]);
1201+
let mut s = dyndim(&[3, 4]);
1202+
let dans = dyndim(&[1]);
1203+
let sans = dyndim(&[1]);
1204+
squeeze(&mut d, &mut s);
1205+
assert_eq!(d, dans);
1206+
assert_eq!(s, sans);
1207+
1208+
let mut d = dyndim(&[0, 1, 3, 4]);
1209+
let mut s = dyndim(&[2, 3, 4, 5]);
1210+
let dans = dyndim(&[0, 3, 4]);
1211+
let sans = dyndim(&[2, 4, 5]);
1212+
squeeze(&mut d, &mut s);
1213+
assert_eq!(d, dans);
1214+
assert_eq!(s, sans);
1215+
}
11491216
}

0 commit comments

Comments
 (0)