From f7dd9d8ff798cc083bc7a5b41a65499e730c814b Mon Sep 17 00:00:00 2001 From: "Song, Yunfei" Date: Fri, 5 Jul 2024 02:12:59 -0700 Subject: [PATCH 1/2] extend fuse producer to multi-level extractSliceOp --- .../SCF/Transforms/TileUsingInterface.h | 4 + .../SCF/Transforms/TileUsingInterface.cpp | 149 +++++++++++++++++- .../tile-and-fuse-producer.mlir | 86 ++++++++++ .../TestTilingInterfaceTransformOps.cpp | 50 ++++++ .../TestTilingInterfaceTransformOps.td | 19 +++ 5 files changed, 303 insertions(+), 5 deletions(-) create mode 100644 mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 1f21af6d6a29a..76fdda3645a01 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -157,11 +157,15 @@ struct SCFFuseProducerOfSliceResult { Value tiledAndFusedProducer; // Tile and fused producer value. SmallVector tiledOps; }; + std::optional tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef loops); +std::optional +tileAndFuseProducerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp); + /// Reconstruct the fused producer from within the tiled-and-fused code. Based /// on the slice of the producer computed in place it is possible that within /// the loop nest same slice of the producer is computed multiple times. It is diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index e404c01010a32..ef4235c6015ad 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1068,12 +1068,12 @@ getUntiledProducerFromSliceSource(OpOperand *source, return {dyn_cast(source->get()), destinationIterArg}; } -/// Implementation of fusing producer of a single slice by computing the +/// Basic implementation of fusing producer of a single slice by computing the /// slice of the producer in-place. -std::optional -mlir::scf::tileAndFuseProducerOfSlice( - RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, - MutableArrayRef loops) { +static std::optional +tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter, + tensor::ExtractSliceOp candidateSliceOp, + MutableArrayRef loops) { // 1. Get the producer of the source (potentially walking through // `iter_args` of nested `scf.for`) auto [fusableProducer, destinationInitArg] = @@ -1185,6 +1185,145 @@ mlir::scf::tileAndFuseProducerOfSlice( tileAndFuseResult->tiledOps}; } +/// Get the real producer from candidate ExtractSliceOp +/// +/// ``` +/// %0 = producer +/// %1 = scf.for(%arg1 = %0) +/// %2 = extract %arg1 +/// %3 = scf.for(%arg2 = %2) +/// %4 = extract %args2 +/// ... +/// ``` +/// +/// @param candidateSliceOp: %4 = extract %args2 +/// @param backwardSlice: in-out parameter populated by backward extractSliceOps +/// @return OpResult Producer : %0 = producer +static FailureOr getRealProducerFromExtractSliceOp( + Operation *candidateSliceOp, + SmallVector &backwardSlice, int curDepth = 0, + int maxDepth = 5) { + if (!isa(candidateSliceOp)) + return failure(); + // control recursive time in avoid of stack overflow + if (curDepth > maxDepth) + return failure(); + + auto extractOp = cast(candidateSliceOp); + backwardSlice.push_back(extractOp); + Value rootSource = extractOp.getSourceMutable().get(); + + while (true) { + if (auto iterArg = dyn_cast(rootSource)) { + if (auto outerLoop = dyn_cast( + iterArg.getOwner()->getParentOp())) { + rootSource = outerLoop.getTiedLoopInit(iterArg)->get(); + continue; + } + return failure(); + } else if (auto sliceOp = + rootSource.getDefiningOp()) { + // walk up loop to find larger candidate extractSliceOp + return getRealProducerFromExtractSliceOp(sliceOp, backwardSlice, + curDepth + 1); + } + break; + } + return dyn_cast(rootSource); +} + +/// Recursively find the outer nest loops of given loop(included) while the +/// predict function succeed, sorted from outer to inner. +/// +/// @param loop: target loop, note that this loop will be also included. I.e. +/// if no other nest loops were found, just return itself. +/// @param pred: predict function, the termination condition of recursive +/// process. +/// @return Outer Nest Loops: nest loops outside given target loop(included). +/// +/// E.g. +/// +/// ``` +/// %0 = scf.for() +/// %1 = scf.for() +/// %2 = scf.for() +/// ``` +/// +/// If `%2 = scf.for` is given without specific prediction function, this +/// function will return three nest loops: %0 + %1 + %2. +static SmallVector getOuterNestLoopsWhile( + LoopLikeOpInterface loop, + const std::function &pred) { + SmallVector nestLoops = {loop}; + auto outerLoop = dyn_cast(loop->getParentOp()); + while (outerLoop && succeeded(pred(outerLoop))) { + nestLoops.push_back(outerLoop); + outerLoop = dyn_cast(outerLoop->getParentOp()); + } + // sorted from outer to inner + return {nestLoops.rbegin(), nestLoops.rend()}; +} + +/// Enhanced version for basic implementation of fusing producer, which can deal +/// with multi-level candidates. E.g. +/// +/// ``` +/// %0 = untiled_producer +/// %1 = scf.for(%arg1 = %0) +/// %2 = tensor.extract_slice %arg1 +/// %3 = scf.for(%arg2 = %2) +/// %4 = tensor.extract_slice %args2 +/// %5 = tiled_consumer ins(%4) +/// ``` +/// +/// This utility can fuse untiled producer at `%4 = tensor.extract_slice` within +/// inner loop `%3 = scf.for`. +std::optional +mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, + Operation *candidateSliceOp) { + SmallVector backwardSlice; + if (failed( + getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice))) { + return std::nullopt; + } + + std::optional fuseProducerResult; + // reverse from outer to inner + std::reverse(backwardSlice.begin(), backwardSlice.end()); + // multiple application of `tileAndFuseProducerOfSliceImpl` + for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) { + // get nest loops between next candidate sliceOp and tiled producer. + auto whileProducerOutOfLoopBlock = + [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult { + if (fuseProducerResult) { + Block &body = loop->getRegion(0).front(); + if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp() + ->getBlock() == &body) + return failure(); + } + return success(); + }; + SmallVector outerLoops = + getOuterNestLoopsWhile(sliceOp->getParentOfType(), + whileProducerOutOfLoopBlock); + fuseProducerResult = + tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops); + if (!fuseProducerResult) { + return std::nullopt; + } + } + return fuseProducerResult; +} + +/// Implementation of fusing producer of a single slice by computing the +/// slice of the producer in-place. +std::optional +mlir::scf::tileAndFuseProducerOfSlice( + RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, + MutableArrayRef loops) { + return tileAndFuseProducerOfSliceImpl(rewriter, candidateSliceOp, loops); +} + /// Reconstruct the fused producer from within the tiled-and-fused code. LogicalResult mlir::scf::yieldReplacementForFusedProducer( RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir new file mode 100644 index 0000000000000..ef1c6952a55e1 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir @@ -0,0 +1,86 @@ +// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s + +#map = affine_map<(d0) -> (d0 * 128)> +module { + func.func @gemm_fill_fusion_multi_level_extract_slice(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %cst = arith.constant 0.000000e+00 : f32 + %dest0 = tensor.empty() : tensor<256x256xf32> + %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> { + %iv0 = affine.apply #map(%arg3) + %iv1 = affine.apply #map(%arg4) + %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32> + %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32> + %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32> + %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) { + %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) { + %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32> + %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32> + %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32> + %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32> + %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32> + scf.yield %insert_slice : tensor<128x128xf32> + } + scf.yield %3 : tensor<128x128xf32> + } + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32> + } + } + return %1 : tensor<256x256xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.get_producer_of_operand %matmul[2] + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_producer %yield + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)> +// CHECK: func.func @gemm_fill_fusion_multi_level_extract_slice( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> +// CHECK: %[[FORALL_RESULT:.*]] = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK-SAME: shared_outs(%[[INIT_ARG0:.*]] = %[[dest0]]) +// CHECK-SAME: { +// CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]]) +// CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]]) +// CHECK: %[[FILL_OUT_SLICE0:.*]] = tensor.extract_slice %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1] +// CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1] +// CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1] +// CHECK: %[[LOOP_RESULT1:.*]] = scf.for %[[IV3:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[INIT_ARG1:.*]] = %[[FILL_OUT_SLICE0]]) +// CHECK-SAME: { +// CHECK: %[[LOOP_RESULT2:.*]] = scf.for %[[IV4:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[INIT_ARG2:.*]] = %[[INIT_ARG1]]) +// CHECK-SAME: { +// CHECK: %[[FILL_OUT_SLICE1:.*]] = tensor.extract_slice %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1] +// CHECK: %[[TILED_FILL_OUT:.*]] = linalg.fill +// CHECK-SAME: outs(%[[FILL_OUT_SLICE1]] : +// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1] +// CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1] +// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul +// CHECK-SAME: outs(%[[TILED_FILL_OUT]] : +// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1] +// CHECK: scf.yield %[[INSERT_MAT]] : +// CHECK: } +// CHECK: scf.yield %[[LOOP_RESULT2]] : +// CHECK: } +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]] into %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FORALL_RESULT]] : \ No newline at end of file diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 7aa7b58433f36..b4dad98e2399c 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -160,6 +160,56 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, : DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// TestFuseProducerOp +//===----------------------------------------------------------------------===// + +/// Apply fusing of producer transformation to all payload ops and store both +/// the original producer operation as well as the fused producer operation. +template +static LogicalResult +applyFuseProducer(RewriterBase &rewriter, Operation *transformOp, + Range &&payloadOps, TransformResults &transformResults) { + SmallVector originalProducerOps; + SmallVector fusedProducerOps; + + for (Operation *target : payloadOps) { + rewriter.setInsertionPoint(target); + + std::optional fuseProducerResults = + scf::tileAndFuseProducerOfSlice(rewriter, target); + + if (!fuseProducerResults) + return failure(); + + // Report back the relevant handles to the transform op. + originalProducerOps.push_back(fuseProducerResults->origProducer.getOwner()); + fusedProducerOps.push_back(fuseProducerResults->tiledOps[0]); + } + + transformResults.set(transformOp->getOpResult(0), originalProducerOps); + transformResults.set(transformOp->getOpResult(1), fusedProducerOps); + return success(); +} + +DiagnosedSilenceableFailure +transform::TestFuseProducerOp::apply(TransformRewriter &rewriter, + TransformResults &transformResults, + TransformState &state) { + LogicalResult result = + applyFuseProducer(rewriter, getOperation(), + state.getPayloadOps(getTarget()), transformResults); + return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() + : DiagnosedSilenceableFailure::success(); +} + +void transform::TestFuseProducerOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getTargetMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // TestFuseConsumerOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index d55d746bd6aa9..6e73478c35c4a 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op, + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Fuses the producer of the operation pointed to by the target handle + using the options provided as attributes. + }]; + + let arguments = + (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$producer, + TransformHandleTypeInterface:$fused_producer); + + let assemblyFormat = [{ + $target attr-dict `:` functional-type(operands, results) + }]; +} + def TestFuseConsumerOp : Op, DeclareOpInterfaceMethods, From 23796bfe3d1483950cf9175ff1c87978b6eb0bf1 Mon Sep 17 00:00:00 2001 From: "Song, Yunfei" Date: Tue, 6 Aug 2024 21:25:05 -0700 Subject: [PATCH 2/2] add `isForOpYieldResultOfInnerLoop` check --- .../SCF/Transforms/TileUsingInterface.cpp | 47 +++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index ef4235c6015ad..1b8b5a3e7f3db 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1251,9 +1251,9 @@ static FailureOr getRealProducerFromExtractSliceOp( /// /// If `%2 = scf.for` is given without specific prediction function, this /// function will return three nest loops: %0 + %1 + %2. -static SmallVector getOuterNestLoopsWhile( - LoopLikeOpInterface loop, - const std::function &pred) { +static SmallVector +getOuterNestLoopsWhile(LoopLikeOpInterface loop, + function_ref pred) { SmallVector nestLoops = {loop}; auto outerLoop = dyn_cast(loop->getParentOp()); while (outerLoop && succeeded(pred(outerLoop))) { @@ -1264,6 +1264,21 @@ static SmallVector getOuterNestLoopsWhile( return {nestLoops.rbegin(), nestLoops.rend()}; } +/// Check if it is the ForOp that yield the result of inner loop +static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) { + if (auto forOp = dyn_cast(loop.getOperation())) { + Block::OpListType &opsInLoopBody = forOp.getBody()->getOperations(); + for (auto &&[index, op] : llvm::enumerate(opsInLoopBody)) { + // If the orderIndex of inner loop is the last second one before the + // yieldOp of ForOp, the given loop must yield the result of inner loop. + if (isa(op)) { + return success((index + 2) == opsInLoopBody.size()); + } + } + } + return failure(); +} + /// Enhanced version for basic implementation of fusing producer, which can deal /// with multi-level candidates. E.g. /// @@ -1282,10 +1297,10 @@ std::optional mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp) { SmallVector backwardSlice; - if (failed( - getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice))) { + FailureOr realProducer = + getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice); + if (failed(realProducer)) return std::nullopt; - } std::optional fuseProducerResult; // reverse from outer to inner @@ -1294,14 +1309,18 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) { // get nest loops between next candidate sliceOp and tiled producer. auto whileProducerOutOfLoopBlock = - [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult { - if (fuseProducerResult) { - Block &body = loop->getRegion(0).front(); - if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp() - ->getBlock() == &body) - return failure(); - } - return success(); + [&fuseProducerResult, + &realProducer](LoopLikeOpInterface loop) -> LogicalResult { + // ensure that all surrounding outer loops are just yielding the result of + // the inner loops. + if (failed(isForOpYieldResultOfInnerLoop(loop))) + return failure(); + Operation *originalOp = + fuseProducerResult + ? fuseProducerResult->tiledAndFusedProducer.getDefiningOp() + : realProducer->getDefiningOp(); + Block &body = loop->getRegion(0).front(); + return success(originalOp->getBlock() != &body); }; SmallVector outerLoops = getOuterNestLoopsWhile(sliceOp->getParentOfType(),