
[mlir] Fix consumer fusion for producer with multiple results #125915

Open · wants to merge 1 commit into base: main

125 changes: 110 additions & 15 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1949,6 +1949,60 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
}
}

// If the producer of the operand is a LoopLikeOpInterface op, find the last
// insertSlice/parallelInsertSlice in the producer op that uses the block
// argument corresponding to the operand.
static FailureOr<Operation *>
getSliceOpFromConsumerOperand(OpOperand &operand) {

OpResult producerResult = dyn_cast<OpResult>(operand.get());
if (!producerResult)
return failure();

LoopLikeOpInterface loopLikeOp =
dyn_cast<LoopLikeOpInterface>(producerResult.getOwner());
if (!loopLikeOp)
return failure();

// Obtain the BlockArgument corresponding to the result.
BlockArgument bbArg =
loopLikeOp.getRegionIterArgs()[producerResult.getResultNumber()];

// Finally return the operation corresponding to the yielded value.
// Also check whether it's an InsertSliceOp.
if (isa<scf::ForOp>(producerResult.getOwner())) {
OpOperand *yieldVal = loopLikeOp.getTiedLoopYieldedValue(bbArg);
auto yieldedValue = dyn_cast<OpResult>(yieldVal->get());
if (!yieldedValue || !isa<tensor::InsertSliceOp>(yieldedValue.getOwner()))
return failure();
return yieldedValue.getOwner();
}

auto forallOp = dyn_cast<scf::ForallOp>(producerResult.getOwner());
if (!forallOp)
return failure();

// Iterate over the terminator operation of the forallOp to find the last
// parallelInsertSliceOp that uses the blockArgument.
Operation *lastOp = nullptr;
forallOp.getTerminator()->walk([&](tensor::ParallelInsertSliceOp op) {
for (mlir::Value operand : op->getOperands()) {
if (auto maybeBlockArg = dyn_cast<BlockArgument>(operand)) {
if (maybeBlockArg == bbArg) {
lastOp = op;
}
}
}
});

if (!lastOp)
return failure();

return lastOp;
}

/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1979,6 +2033,26 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}

SmallVector<OpOperand *> potentialOperands = {*maybeConsumerOpOperand};
Review comment (Contributor): Please leave some comments as to what this is for.

Review comment (Contributor):
The comment below actually is fine. Move that up above these.
Also instead of a SmallVector<OpOperand *> and a SmallVector<Operation *> for the slices, just make a SmallVector<std::tuple<OpOperand *, Operation *>>
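
For illustration only, a rough sketch (not part of this patch) of what that suggestion might look like in place of the two declarations and the loop below; the container name operandSlicePairs is hypothetical, std::tuple would need an #include <tuple>, and every later use of the two vectors would have to be rewritten to read from the tuples (e.g. via std::get):

// Hypothetical restructuring: keep each consumer operand together with the
// slice op that produces its value, instead of two parallel vectors.
SmallVector<std::tuple<OpOperand *, Operation *>> operandSlicePairs;
operandSlicePairs.emplace_back(*maybeConsumerOpOperand, candidateSliceOp);
for (OpOperand &otherOperand : consumerOp->getOpOperands()) {
  if (&otherOperand == *maybeConsumerOpOperand)
    continue;
  FailureOr<Operation *> maybeSlice =
      getSliceOpFromConsumerOperand(otherOperand);
  if (failed(maybeSlice))
    continue;
  operandSlicePairs.emplace_back(&otherOperand, *maybeSlice);
}

Keeping the operand and its slice op in one element makes it harder for the two containers to drift out of sync as the code evolves.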

SmallVector<unsigned> potentialOperandResultNos = {
consumerOpOperand->getOperandNumber()};
SmallVector<Operation *> potentialSliceOps = {candidateSliceOp};

// 1b. Get all the other operands of the consumer op and their corresponding
// slice ops. In the case of the consumer using multiple results
// from the producer, we need to update every operand.
for (OpOperand &otherOperand : consumerOp->getOpOperands()) {
if (&otherOperand == *maybeConsumerOpOperand)
continue;
auto maybePotentialSlice = getSliceOpFromConsumerOperand(otherOperand);
if (failed(maybePotentialSlice)) {
continue;
}
potentialSliceOps.push_back(*maybePotentialSlice);
Review comment (Contributor):
Do you want to check right here that the producer of otherOperand is the same as *maybeConsumerOpOperand?
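
For illustration only, a minimal sketch (not in the patch) of how such a guard could be added inside this loop before the operand is recorded, assuming the producer loop is reached through consumerOpOperand's defining OpResult as it is elsewhere in this function:

// Hypothetical guard: skip operands whose defining op is not the same
// loop-like producer as the operand being fused.
Operation *producerLoop =
    cast<OpResult>(consumerOpOperand->get()).getOwner();
auto otherResult = dyn_cast<OpResult>(otherOperand.get());
if (!otherResult || otherResult.getOwner() != producerLoop)
  continue;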

potentialOperands.push_back(&otherOperand);
potentialOperandResultNos.push_back(otherOperand.getOperandNumber());
}

// There are two possible cases regarding `oldLoopOp` here:
// 1. single `scf.forall` or `scf.for`.
// 2. inner-most `scf.for` inside a nested scf loop structure, where the
@@ -2037,43 +2111,64 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice.
tensor::InsertSliceOp clonedInsertSliceOp;

SmallVector<tensor::InsertSliceOp> allClonedInsertSliceOps;

scf::ForallOp newForallOp;
if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
rewriter.setInsertionPoint(newForallOp.getTerminator());
clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
} else {
rewriter.setInsertionPoint(candidateSliceOp);
clonedInsertSliceOp =
cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
rewriter.setInsertionPoint(potentialSliceOps.back());
}

for (auto *candidateSliceOp : potentialSliceOps) {
if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
allClonedInsertSliceOps.push_back(rewriter.create<tensor::InsertSliceOp>(
loc, sliceOp.getSource(), sliceOp.getDest(),
sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
sliceOp.getMixedStrides()));
} else {
allClonedInsertSliceOps.push_back(
cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp)));
}
}

// 5.a. Clone consumer op.
auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));

// 5.b. Replace all uses of the loop result with the result of the cloned
// tensor.insert_slice.
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
operandToReplace.set(clonedInsertSliceOp.getResult());
});
for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) {
OpOperand &operandToReplace =
clonedConsumerOp->getOpOperand(potentialOperandResultNos[it.index()]);
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
operandToReplace.set(it.value().getResult());
});
}

// 6. Perform tiling of the cloned consumer and replace the operand at
// `operandNumber` with the source of the cloned tensor.insert_slice op.
auto ossSliceOp =
cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(
allClonedInsertSliceOps.front().getOperation());
FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceInsertSliceWithTiledConsumer(
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));

if (failed(tileAndFuseResult)) {
return failure();
}

auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
clonedInsertSliceOp.getSource());

// 6b. Update the tiled consumer op with the new operands.
for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) {
rewriter.replaceAllUsesWith(
tiledConsumerOp->getOperand(potentialOperandResultNos[it.index()]),
it.value().getSource());
}

// 7. Reconstruct [nested] loop with new inits.
YieldTiledValuesFn newYieldValuesFn =

132 changes: 128 additions & 4 deletions mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -282,7 +282,7 @@ module {
return %unpack : tensor<2048xf32>
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -343,7 +343,7 @@ module {
return %unpack : tensor<2047xf32>
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -404,7 +404,7 @@ module {
return %pack : tensor<4x32x16xf32>
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -610,7 +610,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
// CHECK-SAME: {
// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
@@ -676,3 +676,127 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]

// -----

module {
func.func @forall_producer_multiple_result_single_consumer(%arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
%outs = tensor.empty() : tensor<32x32xf32>
%extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
%3 = linalg.matmul ins(%extracted_slice, %extracted_slice : tensor<32x32xf32>, tensor<32x32xf32>) outs(%outs : tensor<32x32xf32>) -> tensor<32x32xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
tensor.parallel_insert_slice %extracted_slice into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
}
}
%final_out = tensor.empty() : tensor<64x64xf32>
%2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#0, %1#1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%final_out : tensor<64x64xf32>) -> tensor<64x64xf32>
return %2 : tensor<64x64xf32>
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// CHECK-LABEL: func.func @forall_producer_multiple_result_single_consumer(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<64x64xf32>

// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64x64xf32>
// CHECK: %[[LOOP_RESULT:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (2, 2) shared_outs(%[[SHARED0:.+]] = %[[ARG0]], %[[SHARED1:.+]] = %[[ARG0]], %[[SHARED2:.+]] = %[[INIT]])

// CHECK: %[[TILE_INIT:.+]] = tensor.empty() : tensor<32x32xf32>
// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[TILE_INIT]] : tensor<32x32xf32>)
// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
// CHECK: %[[INSERTED_SLICE0:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
// CHECK: %[[EXTRACTED_SLICE1:.+]] = tensor.extract_slice %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
// CHECK: %[[ADD:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%[[EXTRACTED_SLICE]], %[[MATMUL]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[EXTRACTED_SLICE1]] : tensor<32x32xf32>)

// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
// CHECK: tensor.parallel_insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
// CHECK: tensor.parallel_insert_slice %[[ADD]] into %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
// CHECK: }

// CHECK: return %[[LOOP_RESULT]]#2 : tensor<64x64xf32>


// -----

#map = affine_map<(d0) -> (d0)>
module {
func.func @for_producer_producing_multiple_result_single_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
%extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
^bb0(%in: f32, %in_16: f32, %out: f32):
%13 = arith.mulf %in, %in_16 : f32
%14 = arith.addf %out, %13 : f32
linalg.yield %14 : f32
} -> tensor<32xf32>
%4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
%5 = tensor.insert_slice %3 into %arg5[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
scf.yield %5, %4 : tensor<64xf32>, tensor<64xf32>
}
%out_operand = tensor.empty() : tensor<64xf32>
%2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %1#0 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand : tensor<64xf32>) -> tensor<64xf32>
return %2 : tensor<64xf32>
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// CHECK-LABEL: func.func @for_producer_producing_multiple_result_single_consumer(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32xf32>,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<64xf32>

// CHECK: %[[C4:.+]] = arith.constant 4 : index
// CHECK: %[[C64:.+]] = arith.constant 64 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64xf32>

// CHECK: %[[LOOP_RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C64]] step %[[C4]]
// CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[ARG2]], %[[ITER1:.+]] = %[[ARG2]], %[[ITER2:.+]] = %[[INIT]])
// CHECK-SAME: -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

// CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[ITER0]][%[[IV]]] [32] [1]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>)
// CHECK-SAME: outs(%[[EXTRACT_SLICE]] : tensor<32xf32>)
// CHECK: ^{{.*}}(%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[MUL:.+]] = arith.mulf %[[IN0]], %[[IN1]] : f32
// CHECK: %[[ADD:.+]] = arith.addf %[[OUT]], %[[MUL]] : f32
// CHECK: linalg.yield %[[ADD]] : f32

// CHECK: %[[INSERT_SLICE0:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER0]][%[[IV]]] [32] [1]
// CHECK: %[[INSERT_SLICE1:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER1]][%[[IV]]] [32] [1]
// CHECK: %[[EXTRACT_SLICE2:.+]] = tensor.extract_slice %[[ITER2]][%[[IV]]] [32] [1]
// CHECK: %[[BINARY:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
// CHECK-SAME: ins(%[[GENERIC]], %[[GENERIC]] : tensor<32xf32>, tensor<32xf32>)
// CHECK-SAME: outs(%[[EXTRACT_SLICE2]] : tensor<32xf32>)
// CHECK: %[[INSERT_SLICE2:.+]] = tensor.insert_slice %[[BINARY]] into %[[ITER2]][%[[IV]]] [32] [1]

// CHECK: scf.yield %[[INSERT_SLICE1]], %[[INSERT_SLICE0]], %[[INSERT_SLICE2]]

// CHECK: return %[[LOOP_RESULT]]#2 : tensor<64xf32>