Convert scf.forall to scf.for with shared outputs #133032
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository; in that case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-scf

Author: Saiem Irfan (CursedKeyboard)

Changes

@matthias-springer unable to ping the actual guy so I'll ping the name I see the most in mlir.

Full diff: https://github.com/llvm/llvm-project/pull/133032.diff

3 Files Affected:
- mlir/lib/Dialect/SCF/IR/SCF.cpp
- mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
- mlir/test/Dialect/SCF/forall-to-for.mlir
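For context, here is a minimal before/after sketch of the rewrite this patch implements, adapted from the test added below (the "some.computation" op, %init, and the index constants are illustrative placeholders, not part of the patch). An scf.forall with a tensor in shared_outs becomes a nest of scf.for loops that thread the tensor through iter_args; each tensor.parallel_insert_slice in the scf.forall.in_parallel terminator is rebuilt as a tensor.insert_slice whose result is yielded:

// Before: parallel loop writing into a shared output.
%res = scf.forall (%i, %j) in (%c4, %c2) shared_outs(%o = %init) -> (tensor<4x2xf32>) {
  %v = "some.computation"() : () -> tensor<1x1xf32>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %v into %o[%i, %j] [1, 1] [1, 1]
        : tensor<1x1xf32> into tensor<4x2xf32>
  }
}

// After: sequential loops threading the tensor through iter_args.
%res = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc0 = %init) -> (tensor<4x2xf32>) {
  %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%acc1 = %acc0) -> (tensor<4x2xf32>) {
    %v = "some.computation"() : () -> tensor<1x1xf32>
    %updated = tensor.insert_slice %v into %acc1[%i, %j] [1, 1] [1, 1]
        : tensor<1x1xf32> into tensor<4x2xf32>
    scf.yield %updated : tensor<4x2xf32>
  }
  scf.yield %inner : tensor<4x2xf32>
}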
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 1cfb866db0b51..e41be8cbc1aa1 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -19,12 +19,14 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <optional>
using namespace mlir;
using namespace mlir::scf;
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index a2f03f1e1056e..c8960039a6ce1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -14,7 +14,12 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
@@ -35,16 +40,108 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
SmallVector<Value> steps = forallOp.getStep(rewriter);
- LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
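+  // The shared outputs of the forall op become the initial iter_args of the
+  // new loop nest.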
+ SmallVector<Value> iterArgs;
+ for (auto result : forallOp->getResults()) {
+ iterArgs.push_back(forallOp.getTiedOpOperand(result)->get());
+ }
+
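+  // Gather the tensor.parallel_insert_slice ops from the in_parallel
+  // terminator; their order is assumed to match the order of the forall
+  // results.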
+  InParallelOp threadReduction = forallOp.getTerminator();
+ SmallVector<tensor::ParallelInsertSliceOp> regionArgToSlice;
+ for (auto &op : threadReduction.getBody()->getOperations()) {
+ auto parallelInsert = dyn_cast<tensor::ParallelInsertSliceOp>(op);
+ if (!parallelInsert) {
+ return op.emitOpError() << "expected parallel insert slice op";
+ }
+ regionArgToSlice.push_back(parallelInsert);
+ }
+
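+  // Body builder for the innermost loop: rewrite each parallel_insert_slice
+  // as a tensor.insert_slice into the corresponding iter_arg and return the
+  // updated tensors so they get yielded.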
+  auto build = [&](OpBuilder &rewriter, Location loc, ValueRange ivs,
+                   ValueRange regionArgs) -> ValueVector {
+ SmallVector<Value> res;
+ for (auto [i, val] : llvm::enumerate(regionArgs)) {
+ tensor::ParallelInsertSliceOp sliceOp = regionArgToSlice[i];
+
+      // Remap offsets/sizes/strides that reference the original forall
+      // induction variables to the new loop induction variables. Note the
+      // mutation through `OpFoldResult &`: iterating by value would silently
+      // discard the remapping.
+      SmallVector<Value> originalInductionVars = forallOp.getInductionVars();
+      auto remapInductionVars = [&](SmallVector<OpFoldResult> &mixedValues) {
+        for (OpFoldResult &ofr : mixedValues) {
+          if (!ofr.is<Value>())
+            continue;
+          auto *it = llvm::find(originalInductionVars, ofr.get<Value>());
+          if (it != originalInductionVars.end())
+            ofr = ivs[std::distance(originalInductionVars.begin(), it)];
+        }
+      };
+
+      SmallVector<OpFoldResult> sliceOpOffsets = sliceOp.getMixedOffsets();
+      SmallVector<OpFoldResult> sliceOpSizes = sliceOp.getMixedSizes();
+      SmallVector<OpFoldResult> sliceOpStrides = sliceOp.getMixedStrides();
+      remapInductionVars(sliceOpOffsets);
+      remapInductionVars(sliceOpSizes);
+      remapInductionVars(sliceOpStrides);
+
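+      // Rebuild the parallel insert as a sequential insert into the current
+      // iteration's tensor.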
+ res.push_back(rewriter.create<tensor::InsertSliceOp>(
+ sliceOp->getLoc(), sliceOp.getSource(), val, sliceOpOffsets,
+ sliceOpSizes, sliceOpStrides));
+ }
+ return res;
+ };
+  // Build the new loop nest; the body builder above places the
+  // tensor.insert_slice ops in the innermost loop and yields the results.
+ LoopNest loopNest =
+ scf::buildLoopNest(rewriter, loc, lbs, ubs, steps, iterArgs, build);
SmallVector<Value> ivs = llvm::map_to_vector(
loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
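+  // Replace all uses of the forall results with the results of the
+  // outermost new loop.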
+ rewriter.replaceAllOpUsesWith(forallOp,
+ {loopNest.loops.front()->getResults()});
+ // Erase the parallel inserts and associated shared outputs.
+ for (tensor::ParallelInsertSliceOp insertSlice :
+ llvm::make_early_inc_range(regionArgToSlice)) {
+ auto loopBlockArg = dyn_cast<BlockArgument>(insertSlice.getDest());
+    if (!loopBlockArg || loopBlockArg.getOwner()->getParentOp() != forallOp)
+      return insertSlice->emitOpError()
+             << "expected destination to be block argument in loop";
+ rewriter.eraseOp(insertSlice);
+ rewriter.modifyOpInPlace(forallOp, [&]() {
+ forallOp.getBody()->eraseArgument(loopBlockArg.getArgNumber());
+ });
+ }
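+  // The in_parallel terminator contains no ops at this point; erase it.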
+ rewriter.eraseOp(forallOp.getTerminator());
+
Block *innermostBlock = loopNest.loops.back().getBody();
- rewriter.eraseOp(forallOp.getBody()->getTerminator());
+
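+  // Inline the original loop body at the start of the innermost loop, so
+  // that the newly built insert_slice ops execute after it.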
rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
- innermostBlock->getTerminator()->getIterator(),
- ivs);
+ innermostBlock->front().getIterator(), ivs);
rewriter.eraseOp(forallOp);
if (results) {
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
index e7d183fb9d2b5..4d8390f0b62c4 100644
--- a/mlir/test/Dialect/SCF/forall-to-for.mlir
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for,canonicalize))' -split-input-file | FileCheck %s
func.func private @callee(%i: index, %j: index)
@@ -55,3 +55,40 @@ func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
}
return
}
+
+// -----
+
+func.func @nested_with_result() -> tensor<4x2xf32> {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<4x2xf32>
+ %res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
+ %1 = tensor.empty() : tensor<1x1xf32>
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] :
+ tensor<1x1xf32> into tensor<4x2xf32>
+ }
+ }
+ return %res: tensor<4x2xf32>
+}
+
+// CHECK-LABEL: func.func @nested_with_result() -> tensor<4x2xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[FILL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[REDUCED_RES:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK: %[[OUTER:.*]] = scf.for %[[IV_OUTER:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[OUTER_RES:.*]] = %[[REDUCED_RES]]) -> (tensor<4x2xf32>) {
+// CHECK: %[[INNER:.*]] = scf.for %[[IV_INNER:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[INNER_RES:.*]] = %[[OUTER_RES]]) -> (tensor<4x2xf32>) {
+// CHECK: %[[ITERATION_TENS:.*]] = tensor.empty() : tensor<1x1xf32>
+// CHECK: %[[ITERATION_RES:.*]] = linalg.fill ins(%[[FILL]] : f32) outs(%[[ITERATION_TENS]] : tensor<1x1xf32>) -> tensor<1x1xf32>
+// CHECK: %[[UPDATED_RES:.*]] = tensor.insert_slice %[[ITERATION_RES]] into %[[INNER_RES]]{{\[}}%[[IV_OUTER]], %[[IV_INNER]]] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<4x2xf32>
+// CHECK: scf.yield %[[UPDATED_RES]] : tensor<4x2xf32>
+// CHECK: }
+// CHECK: scf.yield %[[INNER]] : tensor<4x2xf32>
+// CHECK: }
+// CHECK: return %[[OUTER]] : tensor<4x2xf32>
+// CHECK: }
\ No newline at end of file