
[mlir][SCF] Add scf.for bufferization preprocessing pass #87594


Open · matthias-springer wants to merge 1 commit into main

Conversation

matthias-springer (Member)

Add a bufferization preprocessing pass for `scf.for` loops to support loops where a yielded tensor value does not bufferize to a buffer equivalent to that of the corresponding iter_arg. This preprocessing works around a limitation of `scf.for` bufferization by inserting additional buffer copies for yielded tensors.

This preprocessing can be used to support most cases where One-Shot Bufferize fails to bufferize the IR with the following error message:

```
error: Yield operand #0 is not equivalent to the corresponding iter bbArg
```

This commit fixes iree-org/iree#16956.

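For illustration, here is a minimal sketch of the rewrite performed by the preprocessing, using the loop from the new test case below (the exact IR produced by the pass may differ slightly): the yielded value is routed through a `bufferization.materialize_in_destination` op whose destination is the corresponding iter_arg.

```
// Before preprocessing: the loop also reads %tA after the insertion, so the
// yielded tensor %0 cannot be bufferized in place into the iter_arg buffer.
%r = scf.for %i = %lb to %ub step %step iter_args(%tA = %A) -> (tensor<10xf32>) {
  %0 = tensor.insert %f into %tA[%i] : tensor<10xf32>
  %read = tensor.extract %tA[%idx] : tensor<10xf32>
  vector.print %read : f32
  scf.yield %0 : tensor<10xf32>
}

// After preprocessing: the yielded value is materialized in the iter_arg, so
// bufferization inserts a copy into the iter_arg buffer (a no-op when the
// contents already live there).
%r = scf.for %i = %lb to %ub step %step iter_args(%tA = %A) -> (tensor<10xf32>) {
  %0 = tensor.insert %f into %tA[%i] : tensor<10xf32>
  %read = tensor.extract %tA[%idx] : tensor<10xf32>
  vector.print %read : f32
  %1 = bufferization.materialize_in_destination %0 in %tA
      : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  scf.yield %1 : tensor<10xf32>
}
```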
llvmbot (Member) commented Apr 4, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)



Full diff: https://github.com/llvm/llvm-project/pull/87594.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.h (+3)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+21)
  • (modified) mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp (+33)
  • (added) mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir (+23)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8cfd..6107219ea94ae1 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -23,6 +23,9 @@ namespace mlir {
 /// Creates a pass that bufferizes the SCF dialect.
 std::unique_ptr<Pass> createSCFBufferizePass();
 
+/// Creates a pass that preprocesses SCF loops before One-Shot Bufferize.
+std::unique_ptr<Pass> createSCFLoopBufferizationPreprocessingPass();
+
 /// Creates a pass that specializes for loop for unrolling and
 /// vectorization.
 std::unique_ptr<Pass> createForLoopSpecializationPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..94d3e51a1c9044 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -18,6 +18,27 @@ def SCFBufferize : Pass<"scf-bufferize"> {
                            "memref::MemRefDialect"];
 }
 
+def SCFLoopBufferizationPreprocessing
+    : Pass<"scf-loop-bufferization-preprocessing"> {
+  let summary = "Preprocess loops before One-Shot Bufferize";
+
+  let description = [{
+    Preprocess `scf.for` loops before running One-Shot Bufferize to support
+    loops where a yielded tensor is not equivalent to the respective iter_arg.
+    Such IR is currently not supported by One-Shot Bufferize.
+
+    This pass inserts a `bufferization.materialize_in_destination` op for every
+    yielded tensor, such that the yielded value is guaranteed to materialize in
+    the future buffer of the iter_arg; this is done by copying the tensor
+    contents into the iter_arg buffer. Such copies are a no-op if the tensor
+    contents already materialize in the iter_arg buffer.
+  }];
+
+  let constructor = "mlir::createSCFLoopBufferizationPreprocessingPass()";
+  let dependentDialects = ["bufferization::BufferizationDialect",
+                           "scf::SCFDialect"];
+}
+
 // Note: Making these canonicalization patterns would require a dependency
 // of the SCF dialect on the Affine/Tensor/MemRef dialects or vice versa.
 def SCFForLoopCanonicalization
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 21c618ab633f60..727c4fc7c6396e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -17,6 +17,7 @@
 
 namespace mlir {
 #define GEN_PASS_DEF_SCFBUFFERIZE
+#define GEN_PASS_DEF_SCFLOOPBUFFERIZATIONPREPROCESSING
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -40,8 +41,40 @@ struct SCFBufferizePass : public impl::SCFBufferizeBase<SCFBufferizePass> {
       return signalPassFailure();
   };
 };
+
+struct SCFLoopBufferizationPreprocessingPass
+    : public impl::SCFLoopBufferizationPreprocessingBase<
+          SCFLoopBufferizationPreprocessingPass> {
+  void runOnOperation() override {
+    OpBuilder builder(getOperation()->getContext());
+    getOperation()->walk([&](scf::YieldOp yieldOp) {
+      builder.setInsertionPoint(yieldOp);
+      // TODO: Support scf.while.
+      auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+      if (!forOp)
+        return WalkResult::skip();
+      for (OpOperand &operand : yieldOp->getOpOperands()) {
+        auto tensorType = dyn_cast<TensorType>(operand.get().getType());
+        if (!tensorType)
+          continue;
+        auto bbArg = forOp.getRegionIterArgs()[operand.getOperandNumber()];
+        Value materialized =
+            builder
+                .create<bufferization::MaterializeInDestinationOp>(
+                    yieldOp.getLoc(), tensorType, operand.get(), bbArg)
+                .getResult();
+        operand.set(materialized);
+      }
+      return WalkResult::advance();
+    });
+  }
+};
 } // namespace
 
 std::unique_ptr<Pass> mlir::createSCFBufferizePass() {
   return std::make_unique<SCFBufferizePass>();
 }
+
+std::unique_ptr<Pass> mlir::createSCFLoopBufferizationPreprocessingPass() {
+  return std::make_unique<SCFLoopBufferizationPreprocessingPass>();
+}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir
new file mode 100644
index 00000000000000..17661178245088
--- /dev/null
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -scf-loop-bufferization-preprocessing -one-shot-bufferize="bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" -canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @conflict_in_loop(
+//  CHECK-SAME:     %[[A:.*]]: memref<10xf32>
+func.func @conflict_in_loop(%A: tensor<10xf32>, %f: f32, %idx: index, %lb: index, %ub: index, %step: index) -> f32 {
+  // CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+  %r = scf.for %i = %lb to %ub step %step iter_args(%tA = %A) -> (tensor<10xf32>) {
+    // CHECK: %[[alloc:.*]] = memref.alloc()
+    // CHECK: memref.copy %[[A]], %[[alloc]]
+    // CHECK: memref.store %{{.*}}, %[[alloc]]
+    %0 = tensor.insert %f into %tA[%i] : tensor<10xf32>
+    // CHECK: %[[read:.*]] = memref.load %[[A]]
+    %read = tensor.extract %tA[%idx] : tensor<10xf32>
+    // CHECK: vector.print %[[read]]
+    vector.print %read : f32
+    // CHECK: memref.copy %[[alloc]], %[[A]]
+    scf.yield %0 : tensor<10xf32>
+  }
+
+  // CHECK: memref.load %[[A]]
+  %f0 = tensor.extract %r[%step] : tensor<10xf32>
+  return %f0 : f32
+}
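
For reference, a hypothetical C++ sketch of how the new pass could be scheduled ahead of One-Shot Bufferize in a pipeline; the One-Shot Bufferize pass, its options, and the include paths are the existing upstream ones and are not part of this patch, and the function name is a placeholder:

```
// Sketch only: run the scf.for preprocessing right before One-Shot Bufferize.
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

void buildBufferizationPipeline(mlir::OpPassManager &pm) {
  // Pass added by this patch: inserts materialize_in_destination ops for
  // tensors yielded from scf.for loops.
  pm.addPass(mlir::createSCFLoopBufferizationPreprocessingPass());

  // Existing One-Shot Bufferize pass with function boundary bufferization
  // enabled (see the RUN line of the new test for the full mlir-opt flags).
  mlir::bufferization::OneShotBufferizationOptions options;
  options.bufferizeFunctionBoundaries = true;
  pm.addPass(mlir::bufferization::createOneShotBufferizePass(options));
}
```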

Successfully merging this pull request may close these issues:

  • Missing support in scf.for bufferization