Skip to content

Commit d42ee96

Browse files
committed
FEAT: Add dimension merge function to merge contiguous axes
1 parent f13c63e commit d42ee96

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

src/dimension/mod.rs

+74
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,32 @@ where D: Dimension
759759
}
760760
}
761761

762+
/// Attempt to merge axes if possible, starting from the back
763+
///
764+
/// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
765+
/// to merge all axes one by one into Axis(3); when/if this fails,
766+
/// it attempts to merge the rest of the axes together into the next
767+
/// axis in line, for example a result could be:
768+
///
769+
/// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
770+
/// mean axes were merged.
771+
pub(crate) fn merge_axes_from_the_back<D>(dim: &mut D, strides: &mut D)
772+
where D: Dimension
773+
{
774+
debug_assert_eq!(dim.ndim(), strides.ndim());
775+
match dim.ndim() {
776+
0 | 1 => {}
777+
n => {
778+
let mut last = n - 1;
779+
for i in (0..last).rev() {
780+
if !merge_axes(dim, strides, Axis(i), Axis(last)) {
781+
last = i;
782+
}
783+
}
784+
}
785+
}
786+
}
787+
762788
/// Move the axis which has the smallest absolute stride and a length
763789
/// greater than one to be the last axis.
764790
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
@@ -820,6 +846,30 @@ where D: Dimension
820846
*strides = new_strides;
821847
}
822848

849+
/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
850+
/// stride
851+
///
852+
/// The axes are sorted according to the .abs() of their stride.
853+
pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
854+
where D: Dimension
855+
{
856+
debug_assert!(dim.ndim() > 1);
857+
debug_assert_eq!(dim.ndim(), strides.ndim());
858+
// bubble sort axes
859+
let mut changed = true;
860+
while changed {
861+
changed = false;
862+
for i in 0..dim.ndim() - 1 {
863+
// make sure higher stride axes sort before.
864+
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
865+
changed = true;
866+
dim.slice_mut().swap(i, i + 1);
867+
strides.slice_mut().swap(i, i + 1);
868+
}
869+
}
870+
}
871+
}
872+
823873
#[cfg(test)]
824874
mod test
825875
{
@@ -829,6 +879,7 @@ mod test
829879
can_index_slice_not_custom,
830880
extended_gcd,
831881
max_abs_offset_check_overflow,
882+
merge_axes_from_the_back,
832883
slice_min_max,
833884
slices_intersect,
834885
solve_linear_diophantine_eq,
@@ -1213,4 +1264,27 @@ mod test
12131264
assert_eq!(d, dans);
12141265
assert_eq!(s, sans);
12151266
}
1267+
1268+
#[test]
1269+
fn test_merge_axes_from_the_back()
1270+
{
1271+
let dyndim = Dim::<&[usize]>;
1272+
1273+
let mut d = Dim([3, 4, 5]);
1274+
let mut s = Dim([20, 5, 1]);
1275+
merge_axes_from_the_back(&mut d, &mut s);
1276+
assert_eq!(d, Dim([1, 1, 60]));
1277+
assert_eq!(s, Dim([20, 5, 1]));
1278+
1279+
let mut d = Dim([3, 4, 5, 2]);
1280+
let mut s = Dim([80, 20, 2, 1]);
1281+
merge_axes_from_the_back(&mut d, &mut s);
1282+
assert_eq!(d, Dim([1, 12, 1, 10]));
1283+
assert_eq!(s, Dim([80, 20, 2, 1]));
1284+
let mut d = d.into_dyn();
1285+
let mut s = s.into_dyn();
1286+
squeeze(&mut d, &mut s);
1287+
assert_eq!(d, dyndim(&[12, 10]));
1288+
assert_eq!(s, dyndim(&[20, 1]));
1289+
}
12161290
}

0 commit comments

Comments
 (0)