@@ -835,7 +835,7 @@ getUntiledProducerFromSliceSource(OpOperand *source,
835
835
// / Implementation of fusing producer of a single slice by computing the
836
836
// / slice of the producer in-place.
837
837
std::optional<scf::SCFFuseProducerOfSliceResult>
838
- mlir::scf::tileAndFuseProducerOfSlice (
838
+ mlir::scf::tileAndFuseProducerOfSliceImpl (
839
839
RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
840
840
MutableArrayRef<LoopLikeOpInterface> loops) {
841
841
// 1. Get the producer of the source (potentially walking through
@@ -949,6 +949,135 @@ mlir::scf::tileAndFuseProducerOfSlice(
949
949
tileAndFuseResult->tiledOps };
950
950
}
951
951
952
+ // / Get the Root source of target ExtractSliceOp
953
+ // / %0 =
954
+ // / %1 = scf.for(%arg1 = %0)
955
+ // / %2 = extract %arg1
956
+ // / %3 = scf.for(%arg2 = %2)
957
+ // / %4 = extract %args2
958
+ // / ...
959
+ // / @param targetSliceOp: %4 = extract %args2
960
+ // / @param extractSliceOpChain: chain of all related extract sliceOp
961
+ // / @return Value of Root Source : %0
962
+ static FailureOr<Value> getRootSourceOfExtractSliceOp (
963
+ Operation *targetSliceOp,
964
+ SmallVectorImpl<tensor::ExtractSliceOp> &extractSliceOpChain,
965
+ int curDepth = 0 , int maxDepth = 5 ) {
966
+ assert (isa<tensor::ExtractSliceOp>(targetSliceOp));
967
+ // control recursive time in avoid of stack overflow
968
+ if (curDepth > maxDepth)
969
+ return failure ();
970
+
971
+ auto extractOp = cast<tensor::ExtractSliceOp>(targetSliceOp);
972
+ extractSliceOpChain.push_back (extractOp);
973
+ Value rootSource = extractOp.getSourceMutable ().get ();
974
+
975
+ while (true ) {
976
+ if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
977
+ if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
978
+ iterArg.getOwner ()->getParentOp ())) {
979
+ rootSource = outerLoop.getTiedLoopInit (iterArg)->get ();
980
+ continue ;
981
+ }
982
+ return failure ();
983
+ } else if (auto sliceOp =
984
+ rootSource.getDefiningOp <tensor::ExtractSliceOp>()) {
985
+ // walk up loop to find larger candidate extractSliceOp
986
+ return getRootSourceOfExtractSliceOp (sliceOp, extractSliceOpChain,
987
+ curDepth + 1 );
988
+ }
989
+ break ;
990
+ }
991
+ return rootSource;
992
+ }
993
+
994
+ // / Recursively find the outer nest loops of given loop(included) while the
995
+ // / predict function succeed, sorted from outer to inner.
996
+ // /
997
+ // / @param loop: target loop, note that this loop will be also included. I.e.
998
+ // / if no other nest loops were found, just return itself.
999
+ // / @param pred: predict function, the termination condition of recursive
1000
+ // / process.
1001
+ // / @return Outer Nest Loops: nest loops outside given target loop(included).
1002
+ // /
1003
+ // / E.g.
1004
+ // /
1005
+ // / ```
1006
+ // / %0 = scf.for()
1007
+ // / %1 = scf.for()
1008
+ // / %2 = scf.for()
1009
+ // / ```
1010
+ // /
1011
+ // / If `%2 = scf.for` is given without specific prediction function, this
1012
+ // / function will return three nest loops: %0 + %1 + %2.
1013
+ static SmallVector<LoopLikeOpInterface>
1014
+ getOuterNestLoopsWhile (LoopLikeOpInterface loop,
1015
+ std::function<LogicalResult(LoopLikeOpInterface)> pred) {
1016
+ SmallVector<LoopLikeOpInterface> nestLoops = {loop};
1017
+ auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp ());
1018
+ while (outerLoop && succeeded (pred (outerLoop))) {
1019
+ nestLoops.push_back (outerLoop);
1020
+ outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp ());
1021
+ }
1022
+ // sorted from outer to inner
1023
+ return {nestLoops.rbegin (), nestLoops.rend ()};
1024
+ }
1025
+
1026
+ // / Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with
1027
+ // / multi-level `extractSliceOp`. E.g.
1028
+ // /
1029
+ // / ```
1030
+ // / %0 = untiled_producer
1031
+ // / %1 = scf.for(%arg1 = %0)
1032
+ // / %2 = extract %arg1
1033
+ // / %3 = scf.for(%arg2 = %2)
1034
+ // / %4 = extract %args2
1035
+ // / %5 = tiled_consumer ins(%4)
1036
+ // / ```
1037
+ std::optional<scf::SCFFuseProducerOfSliceResult>
1038
+ mlir::scf::tileAndFuseProducerOfSlice (RewriterBase &rewriter,
1039
+ Operation *candidateSliceOp) {
1040
+ SmallVector<tensor::ExtractSliceOp> sliceOpChain;
1041
+ if (failed (getRootSourceOfExtractSliceOp (candidateSliceOp, sliceOpChain))) {
1042
+ return std::nullopt;
1043
+ }
1044
+
1045
+ std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
1046
+ // reverse from outer to inner
1047
+ std::reverse (sliceOpChain.begin (), sliceOpChain.end ());
1048
+ // multiple application of `tileAndFuseProducerOfSliceImpl`
1049
+ for (auto &&[index , sliceOp] : llvm::enumerate (sliceOpChain)) {
1050
+ // get nest loops between next candidate sliceOp and tiled producer.
1051
+ auto whileProducerOutOfBlock =
1052
+ [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
1053
+ if (fuseProducerResult) {
1054
+ Block &body = loop->getRegion (0 ).front ();
1055
+ if (fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
1056
+ ->getBlock () == &body)
1057
+ return failure ();
1058
+ }
1059
+ return success ();
1060
+ };
1061
+ SmallVector<LoopLikeOpInterface> outerLoops =
1062
+ getOuterNestLoopsWhile (sliceOp->getParentOfType <LoopLikeOpInterface>(),
1063
+ whileProducerOutOfBlock);
1064
+ fuseProducerResult =
1065
+ tileAndFuseProducerOfSliceImpl (rewriter, sliceOp, outerLoops);
1066
+ if (!fuseProducerResult) {
1067
+ return std::nullopt;
1068
+ }
1069
+ }
1070
+ return fuseProducerResult;
1071
+ }
1072
+
1073
+ // / To be compatible with previous behavior
1074
+ std::optional<scf::SCFFuseProducerOfSliceResult>
1075
+ mlir::scf::tileAndFuseProducerOfSlice (
1076
+ RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1077
+ MutableArrayRef<LoopLikeOpInterface> loops) {
1078
+ return tileAndFuseProducerOfSliceImpl (rewriter, candidateSliceOp, loops);
1079
+ }
1080
+
952
1081
// / Reconstruct the fused producer from within the tiled-and-fused code.
953
1082
LogicalResult mlir::scf::yieldReplacementForFusedProducer (
954
1083
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
0 commit comments