Skip to content

[mlir][spirv] Allow yielding values from loop regions #135344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

IgWod-IMG
Copy link
Contributor

This change extends spirv.mlir.loop so it can yield values, the same as spirv.mlir.selection.

This change extends `spirv.mlir.loop` so it can yield values, the
same as `spirv.mlir.selection`.
@llvmbot
Copy link
Member

llvmbot commented Apr 11, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Igor Wodiany (IgWod-IMG)

Changes

This change extends spirv.mlir.loop so it can yield values, the same as spirv.mlir.selection.


Full diff: https://github.com/llvm/llvm-project/pull/135344.diff

7 Files Affected:

  • (modified) mlir/docs/Dialects/SPIR-V.md (+12)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td (+12-2)
  • (modified) mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp (+9)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+46-33)
  • (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+7)
  • (modified) mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir (+41)
  • (modified) mlir/test/Target/SPIRV/loop.mlir (+47)
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index ae9afbd9fdfe5..1e8c1c7be9f6a 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -734,6 +734,18 @@ func.func @loop(%count : i32) -> () {
 }
 ```
 
+Similarly to selection, loops can also yield values using `spirv.mlir.merge`. This
+mechanism allows values defined within the loop region to be used outside of it.
+
+For example
+
+```mlir
+%yielded = spirv.mlir.loop -> i32 {
+  // ...
+  spirv.mlir.merge %to_yield : i32
+}
+```
+
 ### Block argument for Phi
 
 There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index 039af03871411..ef6682ab3630c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -311,17 +311,27 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
     The continue block should be the second to last block and it should have a
     branch to the loop header block. The loop continue block should be the only
     block, except the entry block, branching to the header block.
+
+    Values defined inside the loop regions cannot be directly used
+    outside of them; however, the loop region can yield values. These values are
+    yielded using a `spirv.mlir.merge` op and returned as a result of the loop op.
   }];
 
   let arguments = (ins
     SPIRV_LoopControlAttr:$loop_control
   );
 
-  let results = (outs);
+  let results = (outs Variadic<AnyType>:$results);
 
   let regions = (region AnyRegion:$body);
 
-  let builders = [OpBuilder<(ins)>];
+  let builders = [
+    OpBuilder<(ins)>,
+    OpBuilder<(ins "spirv::LoopControl":$loopControl),
+    [{
+      build($_builder, $_state, TypeRange(), loopControl);
+    }]>
+  ];
 
   let extraClassDeclaration = [{
     // Returns the entry block.
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index ed9a30086deca..cf983af6a07ac 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -230,6 +230,11 @@ ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
                                                                         result))
     return failure();
+
+  if (succeeded(parser.parseOptionalArrow()))
+    if (parser.parseTypeList(result.types))
+      return failure();
+
   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
 }
 
@@ -237,6 +242,10 @@ void LoopOp::print(OpAsmPrinter &printer) {
   auto control = getLoopControl();
   if (control != spirv::LoopControl::None)
     printer << " control(" << spirv::stringifyLoopControl(control) << ")";
+  if (getNumResults() > 0) {
+    printer << " -> ";
+    printer << getResultTypes();
+  }
   printer << ' ';
   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                       /*printBlockTerminators=*/true);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 25749ec598f00..2c7a93949307c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2003,7 +2003,8 @@ LogicalResult ControlFlowStructurizer::structurize() {
   // block inside the selection (`body.back()`). Values produced by block
   // arguments will be yielded by the selection region. We do not update uses or
   // erase original block arguments yet. It will be done later in the code.
-  if (!isLoop) {
+  // We do not currently support block arguments in loop merge blocks.
+  if (!isLoop)
     for (BlockArgument blockArg : mergeBlock->getArguments()) {
       // Create new block arguments in the last block ("merge block") of the
       // selection region. We create one argument for each argument in
@@ -2013,7 +2014,6 @@ LogicalResult ControlFlowStructurizer::structurize() {
       valuesToYield.push_back(body.back().getArguments().back());
       outsideUses.push_back(blockArg);
     }
-  }
 
   // All the blocks cloned into the SelectionOp/LoopOp's region can now be
   // cleaned up.
@@ -2025,32 +2025,30 @@ LogicalResult ControlFlowStructurizer::structurize() {
 
   // All internal uses should be removed from original blocks by now, so
   // whatever is left is an outside use and will need to be yielded from
-  // the newly created selection region.
-  if (!isLoop) {
-    for (Block *block : constructBlocks) {
-      for (Operation &op : *block) {
-        if (!op.use_empty())
-          for (Value result : op.getResults()) {
-            valuesToYield.push_back(mapper.lookupOrNull(result));
-            outsideUses.push_back(result);
-          }
-      }
-      for (BlockArgument &arg : block->getArguments()) {
-        if (!arg.use_empty()) {
-          valuesToYield.push_back(mapper.lookupOrNull(arg));
-          outsideUses.push_back(arg);
+  // the newly created selection / loop region.
+  for (Block *block : constructBlocks) {
+    for (Operation &op : *block) {
+      if (!op.use_empty())
+        for (Value result : op.getResults()) {
+          valuesToYield.push_back(mapper.lookupOrNull(result));
+          outsideUses.push_back(result);
         }
+    }
+    for (BlockArgument &arg : block->getArguments()) {
+      if (!arg.use_empty()) {
+        valuesToYield.push_back(mapper.lookupOrNull(arg));
+        outsideUses.push_back(arg);
       }
     }
   }
 
   assert(valuesToYield.size() == outsideUses.size());
 
-  // If we need to yield any values from the selection region we will take
-  // care of it here.
-  if (!isLoop && !valuesToYield.empty()) {
+  // If we need to yield any values from the selection / loop region we will
+  // take care of it here.
+  if (!valuesToYield.empty()) {
     LLVM_DEBUG(logger.startLine()
-               << "[cf] yielding values from the selection region\n");
+               << "[cf] yielding values from the selection / loop region\n");
 
     // Update `mlir.merge` with values to be yield.
     auto mergeOps = body.back().getOps<spirv::MergeOp>();
@@ -2059,25 +2057,40 @@ LogicalResult ControlFlowStructurizer::structurize() {
     merge->setOperands(valuesToYield);
 
     // MLIR does not allow changing the number of results of an operation, so
-    // we create a new SelectionOp with required list of results and move
-    // the region from the initial SelectionOp. The initial operation is then
-    // removed. Since we move the region to the new op all links between blocks
-    // and remapping we have previously done should be preserved.
+    // we create a new SelectionOp / LoopOp with required list of results and
+    // move the region from the initial SelectionOp / LoopOp. The initial
+    // operation is then removed. Since we move the region to the new op all
+    // links between blocks and remapping we have previously done should be
+    // preserved.
     builder.setInsertionPoint(&mergeBlock->front());
-    auto selectionOp = builder.create<spirv::SelectionOp>(
-        location, TypeRange(ValueRange(outsideUses)),
-        static_cast<spirv::SelectionControl>(control));
-    selectionOp->getRegion(0).takeBody(body);
+
+    Operation *newOp = nullptr;
+
+    if (isLoop)
+      newOp = builder.create<spirv::LoopOp>(
+          location, TypeRange(ValueRange(outsideUses)),
+          static_cast<spirv::LoopControl>(control));
+    else
+      newOp = builder.create<spirv::SelectionOp>(
+          location, TypeRange(ValueRange(outsideUses)),
+          static_cast<spirv::SelectionControl>(control));
+
+    newOp->getRegion(0).takeBody(body);
 
     // Remove initial op and swap the pointer to the newly created one.
     op->erase();
-    op = selectionOp;
+    op = newOp;
 
-    // Update all outside uses to use results of the SelectionOp and remove
-    // block arguments from the original merge block.
+    // Update all outside uses to use results of the SelectionOp / LoopOp and
+    // remove block arguments from the original merge block.
     for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
-      outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
-    mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
+      outsideUses[i].replaceAllUsesWith(op->getResult(i));
+
+    // We do not support block arguments in loop merge block. Also running this
+    // function with loop would break some of the loop specific code above
+    // dealing with block arguments.
+    if (!isLoop)
+      mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
   }
 
   // Check that whether some op in the to-be-erased blocks still has uses. Those
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 5ed59a4134d37..ff3cc92ee8078 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -520,6 +520,13 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   auto mergeID = getBlockID(mergeBlock);
   auto loc = loopOp.getLoc();
 
+  // Before we do anything replace results of the selection operation with
+  // values yielded (with `mlir.merge`) from inside the region.
+  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
+  assert(loopOp.getNumResults() == mergeOp.getNumOperands());
+  for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
+    loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
+
   // This LoopOp is in some MLIR block with preceding and following ops. In the
   // binary format, it should reside in separate SPIR-V blocks from its
   // preceding and following ops. So we need to emit unconditional branches to
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 107c8a3207b02..8ec0bf5bbaacf 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -426,6 +426,47 @@ func.func @only_entry_and_continue_branch_to_header() -> () {
 
 // -----
 
+func.func @loop_yield(%count : i32) -> () {
+  %zero = spirv.Constant 0: i32
+  %one = spirv.Constant 1: i32
+  %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+  // CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
+  %final_i = spirv.mlir.loop -> i32 {
+    // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
+    spirv.Branch ^header(%zero: i32)
+
+  // CHECK-NEXT: ^bb1({{%.*}}: i32):
+  ^header(%i : i32):
+    %cmp = spirv.SLessThan %i, %count : i32
+    // CHECK: spirv.BranchConditional %{{.*}}, ^bb2, ^bb4
+    spirv.BranchConditional %cmp, ^body, ^merge
+
+  // CHECK-NEXT: ^bb2:
+  ^body:
+    // CHECK-NEXT: spirv.Branch ^bb3
+    spirv.Branch ^continue
+
+  // CHECK-NEXT: ^bb3:
+  ^continue:
+    %new_i = spirv.IAdd %i, %one : i32
+    // CHECK: spirv.Branch ^bb1({{%.*}}: i32)
+    spirv.Branch ^header(%new_i: i32)
+
+  // CHECK-NEXT: ^bb4:
+  ^merge:
+    // CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+    spirv.mlir.merge %i : i32
+  }
+
+  // CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+  spirv.Store "Function" %var, %final_i : i32
+
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.merge
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index d89600558f56d..95b87b319ac2d 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -288,3 +288,50 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
     spirv.Return
   }
 }
+
+// -----
+
+// Loop yielding values
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+  spirv.func @loop_yield(%count : i32) -> () "None" {
+    %zero = spirv.Constant 0: i32
+    %one = spirv.Constant 1: i32
+    %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+// CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
+    %final_i = spirv.mlir.loop -> i32 {
+// CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+      spirv.Branch ^header(%zero: i32)
+
+// CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32):
+    ^header(%i : i32):
+      %cmp = spirv.SLessThan %i, %count : i32
+// CHECK: spirv.BranchConditional %{{.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
+      spirv.BranchConditional %cmp, ^body, ^merge
+
+// CHECK-NEXT: ^[[BODY:.+]]:
+    ^body:
+// CHECK-NEXT: spirv.Branch ^[[CONTINUE:.+]]
+      spirv.Branch ^continue
+
+// CHECK-NEXT: ^[[CONTINUE:.+]]:
+    ^continue:
+      %new_i = spirv.IAdd %i, %one : i32
+// CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+      spirv.Branch ^header(%new_i: i32)
+
+// CHECK-NEXT: ^[[MERGE:.+]]:
+    ^merge:
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+      spirv.mlir.merge %i : i32
+// CHECK-NEXT: }
+    }
+
+// CHECK-NEXT: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+    spirv.Store "Function" %var, %final_i : i32
+
+// CHECK-NEXT: spirv.Return
+    spirv.Return
+  }
+}

@kuhar kuhar requested a review from andfau-amd April 11, 2025 14:11
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants