From 191008bdb7952e028b806f98d4db28db80399bcb Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Mon, 3 Feb 2025 22:12:01 +0530 Subject: [PATCH] [mlir] Fix consumer fusion for producer with multiple results In the case of consumer fusion where the producer is producing multiple results all used by a single consumer, e.g., %results:3 = scf.forall ... -> (tensor<...>, tensor<...>, tensor<...>) { // Produces 3 results scf.yield %a, %b, %c : tensor<...>, tensor<...>, tensor<...>} // Consumer uses all 3 results %final = consumer %results#0, %results#1, %results#2 all other operands of the tiled consumer need to be updated. --- .../SCF/Transforms/TileUsingInterface.cpp | 125 +++++++++++++++-- .../tile-and-fuse-consumer.mlir | 132 +++++++++++++++++- 2 files changed, 238 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index b548f8ce8b560..3c2324d620211 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1949,6 +1949,60 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) { } } +// If the producer of the operand is a loopLikeOp, then finds the last +// insertSlice/parallelInsertSlice in the producer op that uses the block +// argument corresponding to the operand. +static FailureOr +getSliceOpFromConsumerOperand(OpOperand &operand) { + + OpResult producerResult = dyn_cast(operand.get()); + if (!producerResult) + return failure(); + + LoopLikeOpInterface loopLikeOp = + dyn_cast(producerResult.getOwner()); + if (!loopLikeOp) + return failure(); + + // Obtain the BlockArgument corresponding to the result. + BlockArgument bbArg = + loopLikeOp.getRegionIterArgs()[producerResult.getResultNumber()]; + + // Finally return the operation corresponding to the yielded value. + // Also check whether it's an InsertSliceOp. 
+ if (dyn_cast(producerResult.getOwner())) { + OpOperand *yieldVal = loopLikeOp.getTiedLoopYieldedValue(bbArg); + Operation *lastOp = dyn_cast(yieldVal->get()).getOwner(); + auto isInsertSliceOp = isa(lastOp); + if (!isInsertSliceOp) { + return failure(); + } + return lastOp; + } + + auto forallOp = dyn_cast(producerResult.getOwner()); + if (!forallOp) + return failure(); + + // Iterate over the terminator operation of the forallOp to find the last + // parallelInsertSliceOp that uses the blockArgument. + Operation *lastOp = nullptr; + forallOp.getTerminator()->walk([&](tensor::ParallelInsertSliceOp op) { + for (mlir::Value operand : op->getOperands()) { + if (auto maybeBlockArg = dyn_cast(operand)) { + if (maybeBlockArg == bbArg) { + lastOp = op; + } + } + } + }); + + if (!lastOp) + return failure(); + + return lastOp; +} + /// Implementation of fusing consumer of a single slice by computing the /// slice of the consumer in-place for scf loop. FailureOr @@ -1979,6 +2033,26 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, consumerOp, "consumer op's operand doesn't seem to be an OpResult"); } + SmallVector potentialOperands = {*maybeConsumerOpOperand}; + SmallVector potentialOperandResultNos = { + consumerOpOperand->getOperandNumber()}; + SmallVector potentialSliceOps = {candidateSliceOp}; + + // 1b. Get all the other operands of the consumer op and their corresponding + // slice ops. In the case of the consumer using multiple results + // from the producer, we need to update every operand. 
+ for (OpOperand &otherOperand : consumerOp->getOpOperands()) { + if (&otherOperand == *maybeConsumerOpOperand) + continue; + auto maybePotentialSlice = getSliceOpFromConsumerOperand(otherOperand); + if (failed(maybePotentialSlice)) { + continue; + } + potentialSliceOps.push_back(*maybePotentialSlice); + potentialOperands.push_back(&otherOperand); + potentialOperandResultNos.push_back(otherOperand.getOperandNumber()); + } + // There are two possible cases regarding `oldLoopOp` here: // 1. single `scf.forall` or `scf.for`. // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the @@ -2037,18 +2111,29 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, // tensor.insert_slice. In the scf.for case this is a clone of the // candidateSliceOp whereas in the scf.forall case this is created from the // operands of tensor.parallel_insert_slice. - tensor::InsertSliceOp clonedInsertSliceOp; + + SmallVector allClonedInsertSliceOps; + + scf::ForallOp newForallOp; if (auto sliceOp = dyn_cast(candidateSliceOp)) { auto newForallOp = cast(innerMostLoop.getOperation()); rewriter.setInsertionPoint(newForallOp.getTerminator()); - clonedInsertSliceOp = rewriter.create( - loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); } else { - rewriter.setInsertionPoint(candidateSliceOp); - clonedInsertSliceOp = - cast(rewriter.clone(*candidateSliceOp)); + rewriter.setInsertionPoint(potentialSliceOps.back()); + } + + for (auto *candidateSliceOp : potentialSliceOps) { + if (auto sliceOp = + dyn_cast(candidateSliceOp)) { + allClonedInsertSliceOps.push_back(rewriter.create( + loc, sliceOp.getSource(), sliceOp.getDest(), + sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), + sliceOp.getMixedStrides())); + } else { + allClonedInsertSliceOps.push_back( + cast(rewriter.clone(*candidateSliceOp))); + } } // 5.a. Clone consumer op. 
@@ -2056,24 +2141,34 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, // 5.b. Replace all uses of the loop result with the result of the cloned // tensor.insert_slice. - OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber); - rewriter.modifyOpInPlace(clonedConsumerOp, [&]() { - operandToReplace.set(clonedInsertSliceOp.getResult()); - }); + for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) { + OpOperand &operandToReplace = + clonedConsumerOp->getOpOperand(potentialOperandResultNos[it.index()]); + rewriter.modifyOpInPlace(clonedConsumerOp, [&]() { + operandToReplace.set(it.value().getResult()); + }); + } // 6. Perform tiling of the cloned consumer and replace the operand at // `operandNumber` with the source of the cloned tensor.insert_slice op. - auto ossSliceOp = - cast(clonedInsertSliceOp.getOperation()); + auto ossSliceOp = cast( + allClonedInsertSliceOps.front().getOperation()); FailureOr tileAndFuseResult = tensor::replaceInsertSliceWithTiledConsumer( rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber)); + if (failed(tileAndFuseResult)) { return failure(); } + auto tiledConsumerOp = cast(tileAndFuseResult->tiledOps[0]); - rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber), - clonedInsertSliceOp.getSource()); + + // 6b. Update the tiled consumer op with the new operands. + for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) { + rewriter.replaceAllUsesWith( + tiledConsumerOp->getOperand(potentialOperandResultNos[it.index()]), + it.value().getSource()); + } // 7. Reconstruct [nested] loop with new inits. 
YieldTiledValuesFn newYieldValuesFn = diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index a2871b30698c5..14b9ec504c158 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -282,7 +282,7 @@ module { return %unpack : tensor<2048xf32> } } - + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 @@ -343,7 +343,7 @@ module { return %unpack : tensor<2047xf32> } } - + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 @@ -404,7 +404,7 @@ module { return %pack : tensor<4x32x16xf32> } } - + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 @@ -610,7 +610,7 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> // CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> // CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) // CHECK-SAME: { // CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] // CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 
0] [64, 256] [1, 1] @@ -676,3 +676,127 @@ module attributes {transform.with_named_sequence} { // CHECK: } // CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice // CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]] + +// ----- + +module { + func.func @forall_producer_multiple_result_single_consumer(%arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) { + %outs = tensor.empty() : tensor<32x32xf32> + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %3 = linalg.matmul ins(%extracted_slice, %extracted_slice : tensor<32x32xf32>, tensor<32x32xf32>) outs(%outs : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + } + } + %final_out = tensor.empty() : tensor<64x64xf32> + %2 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%1#0, %1#1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%final_out : tensor<64x64xf32>) -> tensor<64x64xf32> + return %2 : tensor<64x64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK-LABEL: func.func 
@forall_producer_multiple_result_single_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<64x64xf32> + +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[LOOP_RESULT:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (2, 2) shared_outs(%[[SHARED0:.+]] = %[[ARG0]], %[[SHARED1:.+]] = %[[ARG0]], %[[SHARED2:.+]] = %[[INIT]]) + +// CHECK: %[[TILE_INIT:.+]] = tensor.empty() : tensor<32x32xf32> +// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1] +// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[TILE_INIT]] : tensor<32x32xf32>) +// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1] +// CHECK: %[[INSERTED_SLICE0:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1] +// CHECK: %[[EXTRACTED_SLICE1:.+]] = tensor.extract_slice %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1] +// CHECK: %[[ADD:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%[[EXTRACTED_SLICE]], %[[MATMUL]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[EXTRACTED_SLICE1]] : tensor<32x32xf32>) + +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ADD]] into %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1] +// CHECK: } + +// CHECK: return %[[LOOP_RESULT]]#2 : tensor<64x64xf32> + + +// ----- + +#map = affine_map<(d0) -> (d0)> +module { + func.func @for_producer_producing_multiple_result_single_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for 
%arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + %5 = tensor.insert_slice %3 into %arg5[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + scf.yield %5, %4 : tensor<64xf32>, tensor<64xf32> + } + %out_operand = tensor.empty() : tensor<64xf32> + %2 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%1#1, %1#0 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand : tensor<64xf32>) -> tensor<64xf32> + return %2 : tensor<64xf32> + } + } + + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } + } + +// CHECK-LABEL: func.func @for_producer_producing_multiple_result_single_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32xf32>, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32xf32>, +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<64xf32> + +// CHECK: %[[C4:.+]] = arith.constant 4 : index +// CHECK: %[[C64:.+]] = arith.constant 64 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[INIT:.+]] = 
tensor.empty() : tensor<64xf32> + +// CHECK: %[[LOOP_RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C64]] step %[[C4]] +// CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[ARG2]], %[[ITER1:.+]] = %[[ARG2]], %[[ITER2:.+]] = %[[INIT]]) +// CHECK-SAME: -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + +// CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[ITER0]][%[[IV]]] [32] [1] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>) +// CHECK-SAME: outs(%[[EXTRACT_SLICE]] : tensor<32xf32>) +// CHECK: ^{{.*}}(%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %[[OUT:.+]]: f32): +// CHECK: %[[MUL:.+]] = arith.mulf %[[IN0]], %[[IN1]] : f32 +// CHECK: %[[ADD:.+]] = arith.addf %[[OUT]], %[[MUL]] : f32 +// CHECK: linalg.yield %[[ADD]] : f32 + +// CHECK: %[[INSERT_SLICE0:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER0]][%[[IV]]] [32] [1] +// CHECK: %[[INSERT_SLICE1:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER1]][%[[IV]]] [32] [1] +// CHECK: %[[EXTRACT_SLICE2:.+]] = tensor.extract_slice %[[ITER2]][%[[IV]]] [32] [1] +// CHECK: %[[BINARY:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn} +// CHECK-SAME: ins(%[[GENERIC]], %[[GENERIC]] : tensor<32xf32>, tensor<32xf32>) +// CHECK-SAME: outs(%[[EXTRACT_SLICE2]] : tensor<32xf32>) +// CHECK: %[[INSERT_SLICE2:.+]] = tensor.insert_slice %[[BINARY]] into %[[ITER2]][%[[IV]]] [32] [1] + +// CHECK: scf.yield %[[INSERT_SLICE1]], %[[INSERT_SLICE0]], %[[INSERT_SLICE2]] + +// CHECK: return %[[LOOP_RESULT]]#2 : tensor<64xf32>