@@ -832,12 +832,12 @@ getUntiledProducerFromSliceSource(OpOperand *source,
832
832
return {dyn_cast<OpResult>(source->get ()), destinationIterArg};
833
833
}
834
834
835
- // / Implementation of fusing producer of a single slice by computing the
835
+ // / Basic implementation of fusing producer of a single slice by computing the
836
836
// / slice of the producer in-place.
837
- std::optional<scf::SCFFuseProducerOfSliceResult>
838
- mlir::scf::tileAndFuseProducerOfSlice (
839
- RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
840
- MutableArrayRef<LoopLikeOpInterface> loops) {
837
+ static std::optional<scf::SCFFuseProducerOfSliceResult>
838
+ tileAndFuseProducerOfSliceImpl (RewriterBase &rewriter,
839
+ tensor::ExtractSliceOp candidateSliceOp,
840
+ MutableArrayRef<LoopLikeOpInterface> loops) {
841
841
// 1. Get the producer of the source (potentially walking through
842
842
// `iter_args` of nested `scf.for`)
843
843
auto [fusableProducer, destinationInitArg] =
@@ -949,6 +949,145 @@ mlir::scf::tileAndFuseProducerOfSlice(
949
949
tileAndFuseResult->tiledOps };
950
950
}
951
951
952
+ // / Get the real producer from candidate ExtractSliceOp
953
+ // /
954
+ // / ```
955
+ // / %0 = producer
956
+ // / %1 = scf.for(%arg1 = %0)
957
+ // / %2 = extract %arg1
958
+ // / %3 = scf.for(%arg2 = %2)
959
+ // / %4 = extract %args2
960
+ // / ...
961
+ // / ```
962
+ // /
963
+ // / @param candidateSliceOp: %4 = extract %args2
964
+ // / @param backwardSlice: in-out parameter populated by backward extractSliceOps
965
+ // / @return OpResult Producer : %0 = producer
966
+ static FailureOr<OpResult> getRealProducerFromExtractSliceOp (
967
+ Operation *candidateSliceOp,
968
+ SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0 ,
969
+ int maxDepth = 5 ) {
970
+ if (!isa<tensor::ExtractSliceOp>(candidateSliceOp))
971
+ return failure ();
972
+ // control recursive time in avoid of stack overflow
973
+ if (curDepth > maxDepth)
974
+ return failure ();
975
+
976
+ auto extractOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
977
+ backwardSlice.push_back (extractOp);
978
+ Value rootSource = extractOp.getSourceMutable ().get ();
979
+
980
+ while (true ) {
981
+ if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
982
+ if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
983
+ iterArg.getOwner ()->getParentOp ())) {
984
+ rootSource = outerLoop.getTiedLoopInit (iterArg)->get ();
985
+ continue ;
986
+ }
987
+ return failure ();
988
+ } else if (auto sliceOp =
989
+ rootSource.getDefiningOp <tensor::ExtractSliceOp>()) {
990
+ // walk up loop to find larger candidate extractSliceOp
991
+ return getRealProducerFromExtractSliceOp (sliceOp, backwardSlice,
992
+ curDepth + 1 );
993
+ }
994
+ break ;
995
+ }
996
+ return dyn_cast<OpResult>(rootSource);
997
+ }
998
+
999
+ // / Recursively find the outer nest loops of given loop(included) while the
1000
+ // / predict function succeed, sorted from outer to inner.
1001
+ // /
1002
+ // / @param loop: target loop, note that this loop will be also included. I.e.
1003
+ // / if no other nest loops were found, just return itself.
1004
+ // / @param pred: predict function, the termination condition of recursive
1005
+ // / process.
1006
+ // / @return Outer Nest Loops: nest loops outside given target loop(included).
1007
+ // /
1008
+ // / E.g.
1009
+ // /
1010
+ // / ```
1011
+ // / %0 = scf.for()
1012
+ // / %1 = scf.for()
1013
+ // / %2 = scf.for()
1014
+ // / ```
1015
+ // /
1016
+ // / If `%2 = scf.for` is given without specific prediction function, this
1017
+ // / function will return three nest loops: %0 + %1 + %2.
1018
+ static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile (
1019
+ LoopLikeOpInterface loop,
1020
+ const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
1021
+ SmallVector<LoopLikeOpInterface> nestLoops = {loop};
1022
+ auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp ());
1023
+ while (outerLoop && succeeded (pred (outerLoop))) {
1024
+ nestLoops.push_back (outerLoop);
1025
+ outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp ());
1026
+ }
1027
+ // sorted from outer to inner
1028
+ return {nestLoops.rbegin (), nestLoops.rend ()};
1029
+ }
1030
+
1031
+ // / Enhanced version for basic implementation of fusing producer, which can deal
1032
+ // / with multi-level candidates. E.g.
1033
+ // /
1034
+ // / ```
1035
+ // / %0 = untiled_producer
1036
+ // / %1 = scf.for(%arg1 = %0)
1037
+ // / %2 = tensor.extract_slice %arg1
1038
+ // / %3 = scf.for(%arg2 = %2)
1039
+ // / %4 = tensor.extract_slice %args2
1040
+ // / %5 = tiled_consumer ins(%4)
1041
+ // / ```
1042
+ // /
1043
+ // / This utility can fuse untiled producer at `%4 = tensor.extract_slice` within
1044
+ // / inner loop `%3 = scf.for`.
1045
+ std::optional<scf::SCFFuseProducerOfSliceResult>
1046
+ mlir::scf::tileAndFuseProducerOfSlice (RewriterBase &rewriter,
1047
+ Operation *candidateSliceOp) {
1048
+ SmallVector<tensor::ExtractSliceOp> backwardSlice;
1049
+ if (failed (
1050
+ getRealProducerFromExtractSliceOp (candidateSliceOp, backwardSlice))) {
1051
+ return std::nullopt;
1052
+ }
1053
+
1054
+ std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
1055
+ // reverse from outer to inner
1056
+ std::reverse (backwardSlice.begin (), backwardSlice.end ());
1057
+ // multiple application of `tileAndFuseProducerOfSliceImpl`
1058
+ for (auto &&[index , sliceOp] : llvm::enumerate (backwardSlice)) {
1059
+ // get nest loops between next candidate sliceOp and tiled producer.
1060
+ auto whileProducerOutOfLoopBlock =
1061
+ [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
1062
+ if (fuseProducerResult) {
1063
+ Block &body = loop->getRegion (0 ).front ();
1064
+ if (fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
1065
+ ->getBlock () == &body)
1066
+ return failure ();
1067
+ }
1068
+ return success ();
1069
+ };
1070
+ SmallVector<LoopLikeOpInterface> outerLoops =
1071
+ getOuterNestLoopsWhile (sliceOp->getParentOfType <LoopLikeOpInterface>(),
1072
+ whileProducerOutOfLoopBlock);
1073
+ fuseProducerResult =
1074
+ tileAndFuseProducerOfSliceImpl (rewriter, sliceOp, outerLoops);
1075
+ if (!fuseProducerResult) {
1076
+ return std::nullopt;
1077
+ }
1078
+ }
1079
+ return fuseProducerResult;
1080
+ }
1081
+
1082
+ // / Implementation of fusing producer of a single slice by computing the
1083
+ // / slice of the producer in-place.
1084
+ std::optional<scf::SCFFuseProducerOfSliceResult>
1085
+ mlir::scf::tileAndFuseProducerOfSlice (
1086
+ RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1087
+ MutableArrayRef<LoopLikeOpInterface> loops) {
1088
+ return tileAndFuseProducerOfSliceImpl (rewriter, candidateSliceOp, loops);
1089
+ }
1090
+
952
1091
// / Reconstruct the fused producer from within the tiled-and-fused code.
953
1092
LogicalResult mlir::scf::yieldReplacementForFusedProducer (
954
1093
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
0 commit comments