@@ -759,6 +759,32 @@ where D: Dimension
759
759
}
760
760
}
761
761
762
+ /// Attempt to merge axes if possible, starting from the back
763
+ ///
764
+ /// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
765
+ /// to merge all axes one by one into Axis(3); when/if this fails,
766
+ /// it attempts to merge the rest of the axes together into the next
767
+ /// axis in line, for example a result could be:
768
+ ///
769
+ /// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
770
+ /// mean axes were merged.
771
+ pub ( crate ) fn merge_axes_from_the_back < D > ( dim : & mut D , strides : & mut D )
772
+ where D : Dimension
773
+ {
774
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
775
+ match dim. ndim ( ) {
776
+ 0 | 1 => { }
777
+ n => {
778
+ let mut last = n - 1 ;
779
+ for i in ( 0 ..last) . rev ( ) {
780
+ if !merge_axes ( dim, strides, Axis ( i) , Axis ( last) ) {
781
+ last = i;
782
+ }
783
+ }
784
+ }
785
+ }
786
+ }
787
+
762
788
/// Move the axis which has the smallest absolute stride and a length
763
789
/// greater than one to be the last axis.
764
790
pub fn move_min_stride_axis_to_last < D > ( dim : & mut D , strides : & mut D )
@@ -820,6 +846,30 @@ where D: Dimension
820
846
* strides = new_strides;
821
847
}
822
848
849
+ /// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
850
+ /// stride
851
+ ///
852
+ /// The axes are sorted according to the .abs() of their stride.
853
+ pub ( crate ) fn sort_axes_to_standard < D > ( dim : & mut D , strides : & mut D )
854
+ where D : Dimension
855
+ {
856
+ debug_assert ! ( dim. ndim( ) > 1 ) ;
857
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
858
+ // bubble sort axes
859
+ let mut changed = true ;
860
+ while changed {
861
+ changed = false ;
862
+ for i in 0 ..dim. ndim ( ) - 1 {
863
+ // make sure higher stride axes sort before.
864
+ if strides. get_stride ( Axis ( i) ) . abs ( ) < strides. get_stride ( Axis ( i + 1 ) ) . abs ( ) {
865
+ changed = true ;
866
+ dim. slice_mut ( ) . swap ( i, i + 1 ) ;
867
+ strides. slice_mut ( ) . swap ( i, i + 1 ) ;
868
+ }
869
+ }
870
+ }
871
+ }
872
+
823
873
#[ cfg( test) ]
824
874
mod test
825
875
{
@@ -829,6 +879,7 @@ mod test
829
879
can_index_slice_not_custom,
830
880
extended_gcd,
831
881
max_abs_offset_check_overflow,
882
+ merge_axes_from_the_back,
832
883
slice_min_max,
833
884
slices_intersect,
834
885
solve_linear_diophantine_eq,
@@ -1213,4 +1264,27 @@ mod test
1213
1264
assert_eq ! ( d, dans) ;
1214
1265
assert_eq ! ( s, sans) ;
1215
1266
}
1267
+
1268
+ #[ test]
1269
+ fn test_merge_axes_from_the_back ( )
1270
+ {
1271
+ let dyndim = Dim :: < & [ usize ] > ;
1272
+
1273
+ let mut d = Dim ( [ 3 , 4 , 5 ] ) ;
1274
+ let mut s = Dim ( [ 20 , 5 , 1 ] ) ;
1275
+ merge_axes_from_the_back ( & mut d, & mut s) ;
1276
+ assert_eq ! ( d, Dim ( [ 1 , 1 , 60 ] ) ) ;
1277
+ assert_eq ! ( s, Dim ( [ 20 , 5 , 1 ] ) ) ;
1278
+
1279
+ let mut d = Dim ( [ 3 , 4 , 5 , 2 ] ) ;
1280
+ let mut s = Dim ( [ 80 , 20 , 2 , 1 ] ) ;
1281
+ merge_axes_from_the_back ( & mut d, & mut s) ;
1282
+ assert_eq ! ( d, Dim ( [ 1 , 12 , 1 , 10 ] ) ) ;
1283
+ assert_eq ! ( s, Dim ( [ 80 , 20 , 2 , 1 ] ) ) ;
1284
+ let mut d = d. into_dyn ( ) ;
1285
+ let mut s = s. into_dyn ( ) ;
1286
+ squeeze ( & mut d, & mut s) ;
1287
+ assert_eq ! ( d, dyndim( & [ 12 , 10 ] ) ) ;
1288
+ assert_eq ! ( s, dyndim( & [ 20 , 1 ] ) ) ;
1289
+ }
1216
1290
}
0 commit comments