[mlir][SCF] Add scf.for bufferization preprocessing pass #87594
Conversation
Add a bufferization preprocessing pass for `scf.for` loops to support loops where a yielded tensor value does not bufferize to the buffer of the corresponding iter_arg. This preprocessing works around a limitation of `scf.for` bufferization by inserting additional buffer copies for yielded tensors. It can be used to support most cases where One-Shot Bufferize fails to bufferize the IR with the following error message:

```
error: Yield operand #0 is not equivalent to the corresponding iter bbArg
```
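For illustration, a minimal loop of the kind that triggers this error (a hypothetical example, not taken from this patch) yields its iter_args in swapped order, so neither yield operand can bufferize in place into the buffer of its corresponding iter_arg:

```mlir
func.func @swap_loop(%A: tensor<10xf32>, %B: tensor<10xf32>,
                     %lb: index, %ub: index, %step: index)
    -> (tensor<10xf32>, tensor<10xf32>) {
  %r:2 = scf.for %i = %lb to %ub step %step
      iter_args(%a = %A, %b = %B) -> (tensor<10xf32>, tensor<10xf32>) {
    // The yield operands are swapped relative to the iter_args, so neither
    // yielded tensor is equivalent to its corresponding iter_arg. Without
    // preprocessing, One-Shot Bufferize rejects this loop with the error
    // shown above.
    scf.yield %b, %a : tensor<10xf32>, tensor<10xf32>
  }
  return %r#0, %r#1 : tensor<10xf32>, tensor<10xf32>
}
```

With the preprocessing pass, the inserted copies guarantee that each yielded tensor materializes in the buffer of its iter_arg, so the loop becomes bufferizable.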
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes: Add a bufferization preprocessing pass for `scf.for` loops (see the description above).
This commit fixes openxla/iree#16956.

Full diff: https://github.com/llvm/llvm-project/pull/87594.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8cfd..6107219ea94ae1 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -23,6 +23,9 @@ namespace mlir {
/// Creates a pass that bufferizes the SCF dialect.
std::unique_ptr<Pass> createSCFBufferizePass();
+/// Creates a pass that preprocesses SCF loops before One-Shot Bufferize.
+std::unique_ptr<Pass> createSCFLoopBufferizationPreprocessingPass();
+
/// Creates a pass that specializes for loop for unrolling and
/// vectorization.
std::unique_ptr<Pass> createForLoopSpecializationPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..94d3e51a1c9044 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -18,6 +18,27 @@ def SCFBufferize : Pass<"scf-bufferize"> {
"memref::MemRefDialect"];
}
+def SCFLoopBufferizationPreprocessing
+ : Pass<"scf-loop-bufferization-preprocessing"> {
+ let summary = "Preprocess loops before One-Shot Bufferize";
+
+ let description = [{
+ Preprocess `scf.for` loops before running One-Shot Bufferize to support
+ loops where a yielded tensor is not equivalent to the respective iter_arg.
+ Such IR is currently not supported by One-Shot Bufferize.
+
+ This pass inserts a `bufferization.materialize_in_destination` op for every
+ yielded tensor, such that the yielded value is guaranteed to materialize in
+ the future buffer of the iter_arg; this is done by copying the tensor
+ contents into the iter_arg buffer. Such memcpys are a no-op in case the
+ tensor contents already materialize in the iter_arg buffer.
+ }];
+
+ let constructor = "mlir::createSCFLoopBufferizationPreprocessingPass()";
+ let dependentDialects = ["bufferization::BufferizationDialect",
+ "scf::SCFDialect"];
+}
+
// Note: Making these canonicalization patterns would require a dependency
// of the SCF dialect on the Affine/Tensor/MemRef dialects or vice versa.
def SCFForLoopCanonicalization
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 21c618ab633f60..727c4fc7c6396e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -17,6 +17,7 @@
namespace mlir {
#define GEN_PASS_DEF_SCFBUFFERIZE
+#define GEN_PASS_DEF_SCFLOOPBUFFERIZATIONPREPROCESSING
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
@@ -40,8 +41,40 @@ struct SCFBufferizePass : public impl::SCFBufferizeBase<SCFBufferizePass> {
return signalPassFailure();
};
};
+
+struct SCFLoopBufferizationPreprocessingPass
+ : public impl::SCFLoopBufferizationPreprocessingBase<
+ SCFLoopBufferizationPreprocessingPass> {
+ void runOnOperation() override {
+ OpBuilder builder(getOperation()->getContext());
+ getOperation()->walk([&](scf::YieldOp yieldOp) {
+ builder.setInsertionPoint(yieldOp);
+ // TODO: Support scf.while.
+ auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+ if (!forOp)
+ return WalkResult::skip();
+ for (OpOperand &operand : yieldOp->getOpOperands()) {
+ auto tensorType = dyn_cast<TensorType>(operand.get().getType());
+ if (!tensorType)
+ continue;
+ auto bbArg = forOp.getRegionIterArgs()[operand.getOperandNumber()];
+ Value materialized =
+ builder
+ .create<bufferization::MaterializeInDestinationOp>(
+ yieldOp.getLoc(), tensorType, operand.get(), bbArg)
+ .getResult();
+ operand.set(materialized);
+ }
+ return WalkResult::advance();
+ });
+ }
+};
} // namespace
std::unique_ptr<Pass> mlir::createSCFBufferizePass() {
return std::make_unique<SCFBufferizePass>();
}
+
+std::unique_ptr<Pass> mlir::createSCFLoopBufferizationPreprocessingPass() {
+ return std::make_unique<SCFLoopBufferizationPreprocessingPass>();
+}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir
new file mode 100644
index 00000000000000..17661178245088
--- /dev/null
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -scf-loop-bufferization-preprocessing -one-shot-bufferize="bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" -canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @conflict_in_loop(
+// CHECK-SAME: %[[A:.*]]: memref<10xf32>
+func.func @conflict_in_loop(%A: tensor<10xf32>, %f: f32, %idx: index, %lb: index, %ub: index, %step: index) -> f32 {
+ // CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+ %r = scf.for %i = %lb to %ub step %step iter_args(%tA = %A) -> (tensor<10xf32>) {
+ // CHECK: %[[alloc:.*]] = memref.alloc()
+ // CHECK: memref.copy %[[A]], %[[alloc]]
+ // CHECK: memref.store %{{.*}}, %[[alloc]]
+ %0 = tensor.insert %f into %tA[%i] : tensor<10xf32>
+ // CHECK: %[[read:.*]] = memref.load %[[A]]
+ %read = tensor.extract %tA[%idx] : tensor<10xf32>
+ // CHECK: vector.print %[[read]]
+ vector.print %read : f32
+ // CHECK: memref.copy %[[alloc]], %[[A]]
+ scf.yield %0 : tensor<10xf32>
+ }
+
+ // CHECK: memref.load %[[A]]
+ %f0 = tensor.extract %r[%step] : tensor<10xf32>
+ return %f0 : f32
+}
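Schematically, the preprocessing pass rewrites the loop in the test case above roughly as follows before One-Shot Bufferize runs (sketched by hand from the pass implementation, not generated output):

```mlir
%r = scf.for %i = %lb to %ub step %step iter_args(%tA = %A) -> (tensor<10xf32>) {
  %0 = tensor.insert %f into %tA[%i] : tensor<10xf32>
  %read = tensor.extract %tA[%idx] : tensor<10xf32>
  vector.print %read : f32
  // Inserted by the preprocessing pass: make the yielded tensor materialize
  // in the (future) buffer of the iter_arg %tA.
  %1 = bufferization.materialize_in_destination %0 in %tA
      : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  scf.yield %1 : tensor<10xf32>
}
```

As stated in the pass description, such copies are a no-op when the yielded tensor contents already materialize in the iter_arg buffer.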