-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][spirv] Allow yielding values from loop regions #135344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
IgWod-IMG
wants to merge
1
commit into
llvm:main
Choose a base branch
from
imaginationtech:img_loop-yield
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+174
−35
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This change extends `spirv.mlir.loop` so it can yield values, the same as `spirv.mlir.selection`.
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Igor Wodiany (IgWod-IMG) ChangesThis change extends Full diff: https://github.com/llvm/llvm-project/pull/135344.diff 7 Files Affected:
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index ae9afbd9fdfe5..1e8c1c7be9f6a 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -734,6 +734,18 @@ func.func @loop(%count : i32) -> () {
}
```
+Similarly to selection, loops can also yield values using `spirv.mlir.merge`. This
+mechanism allows values defined within the loop region to be used outside of it.
+
+For example
+
+```mlir
+%yielded = spirv.mlir.loop -> i32 {
+ // ...
+ spirv.mlir.merge %to_yield : i32
+}
+```
+
### Block argument for Phi
There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index 039af03871411..ef6682ab3630c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -311,17 +311,27 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
The continue block should be the second to last block and it should have a
branch to the loop header block. The loop continue block should be the only
block, except the entry block, branching to the header block.
+
+ Values defined inside the loop regions cannot be directly used
+ outside of them; however, the loop region can yield values. These values are
+ yielded using a `spirv.mlir.merge` op and returned as a result of the loop op.
}];
let arguments = (ins
SPIRV_LoopControlAttr:$loop_control
);
- let results = (outs);
+ let results = (outs Variadic<AnyType>:$results);
let regions = (region AnyRegion:$body);
- let builders = [OpBuilder<(ins)>];
+ let builders = [
+ OpBuilder<(ins)>,
+ OpBuilder<(ins "spirv::LoopControl":$loopControl),
+ [{
+ build($_builder, $_state, TypeRange(), loopControl);
+ }]>
+ ];
let extraClassDeclaration = [{
// Returns the entry block.
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index ed9a30086deca..cf983af6a07ac 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -230,6 +230,11 @@ ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
result))
return failure();
+
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(result.types))
+ return failure();
+
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
}
@@ -237,6 +242,10 @@ void LoopOp::print(OpAsmPrinter &printer) {
auto control = getLoopControl();
if (control != spirv::LoopControl::None)
printer << " control(" << spirv::stringifyLoopControl(control) << ")";
+ if (getNumResults() > 0) {
+ printer << " -> ";
+ printer << getResultTypes();
+ }
printer << ' ';
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 25749ec598f00..2c7a93949307c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2003,7 +2003,8 @@ LogicalResult ControlFlowStructurizer::structurize() {
// block inside the selection (`body.back()`). Values produced by block
// arguments will be yielded by the selection region. We do not update uses or
// erase original block arguments yet. It will be done later in the code.
- if (!isLoop) {
+ // We do not currently support block arguments in loop merge blocks.
+ if (!isLoop)
for (BlockArgument blockArg : mergeBlock->getArguments()) {
// Create new block arguments in the last block ("merge block") of the
// selection region. We create one argument for each argument in
@@ -2013,7 +2014,6 @@ LogicalResult ControlFlowStructurizer::structurize() {
valuesToYield.push_back(body.back().getArguments().back());
outsideUses.push_back(blockArg);
}
- }
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
// cleaned up.
@@ -2025,32 +2025,30 @@ LogicalResult ControlFlowStructurizer::structurize() {
// All internal uses should be removed from original blocks by now, so
// whatever is left is an outside use and will need to be yielded from
- // the newly created selection region.
- if (!isLoop) {
- for (Block *block : constructBlocks) {
- for (Operation &op : *block) {
- if (!op.use_empty())
- for (Value result : op.getResults()) {
- valuesToYield.push_back(mapper.lookupOrNull(result));
- outsideUses.push_back(result);
- }
- }
- for (BlockArgument &arg : block->getArguments()) {
- if (!arg.use_empty()) {
- valuesToYield.push_back(mapper.lookupOrNull(arg));
- outsideUses.push_back(arg);
+ // the newly created selection / loop region.
+ for (Block *block : constructBlocks) {
+ for (Operation &op : *block) {
+ if (!op.use_empty())
+ for (Value result : op.getResults()) {
+ valuesToYield.push_back(mapper.lookupOrNull(result));
+ outsideUses.push_back(result);
}
+ }
+ for (BlockArgument &arg : block->getArguments()) {
+ if (!arg.use_empty()) {
+ valuesToYield.push_back(mapper.lookupOrNull(arg));
+ outsideUses.push_back(arg);
}
}
}
assert(valuesToYield.size() == outsideUses.size());
- // If we need to yield any values from the selection region we will take
- // care of it here.
- if (!isLoop && !valuesToYield.empty()) {
+ // If we need to yield any values from the selection / loop region we will
+ // take care of it here.
+ if (!valuesToYield.empty()) {
LLVM_DEBUG(logger.startLine()
- << "[cf] yielding values from the selection region\n");
+ << "[cf] yielding values from the selection / loop region\n");
// Update `mlir.merge` with values to be yield.
auto mergeOps = body.back().getOps<spirv::MergeOp>();
@@ -2059,25 +2057,40 @@ LogicalResult ControlFlowStructurizer::structurize() {
merge->setOperands(valuesToYield);
// MLIR does not allow changing the number of results of an operation, so
- // we create a new SelectionOp with required list of results and move
- // the region from the initial SelectionOp. The initial operation is then
- // removed. Since we move the region to the new op all links between blocks
- // and remapping we have previously done should be preserved.
+ // we create a new SelectionOp / LoopOp with required list of results and
+ // move the region from the initial SelectionOp / LoopOp. The initial
+ // operation is then removed. Since we move the region to the new op all
+ // links between blocks and remapping we have previously done should be
+ // preserved.
builder.setInsertionPoint(&mergeBlock->front());
- auto selectionOp = builder.create<spirv::SelectionOp>(
- location, TypeRange(ValueRange(outsideUses)),
- static_cast<spirv::SelectionControl>(control));
- selectionOp->getRegion(0).takeBody(body);
+
+ Operation *newOp = nullptr;
+
+ if (isLoop)
+ newOp = builder.create<spirv::LoopOp>(
+ location, TypeRange(ValueRange(outsideUses)),
+ static_cast<spirv::LoopControl>(control));
+ else
+ newOp = builder.create<spirv::SelectionOp>(
+ location, TypeRange(ValueRange(outsideUses)),
+ static_cast<spirv::SelectionControl>(control));
+
+ newOp->getRegion(0).takeBody(body);
// Remove initial op and swap the pointer to the newly created one.
op->erase();
- op = selectionOp;
+ op = newOp;
- // Update all outside uses to use results of the SelectionOp and remove
- // block arguments from the original merge block.
+ // Update all outside uses to use results of the SelectionOp / LoopOp and
+ // remove block arguments from the original merge block.
for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
- outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
- mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
+ outsideUses[i].replaceAllUsesWith(op->getResult(i));
+
+ // We do not support block arguments in loop merge block. Also running this
+ // function with loop would break some of the loop specific code above
+ // dealing with block arguments.
+ if (!isLoop)
+ mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
}
// Check that whether some op in the to-be-erased blocks still has uses. Those
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 5ed59a4134d37..ff3cc92ee8078 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -520,6 +520,13 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
auto mergeID = getBlockID(mergeBlock);
auto loc = loopOp.getLoc();
+ // Before we do anything replace results of the selection operation with
+ // values yielded (with `mlir.merge`) from inside the region.
+ auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
+ assert(loopOp.getNumResults() == mergeOp.getNumOperands());
+ for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
+ loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
+
// This LoopOp is in some MLIR block with preceding and following ops. In the
// binary format, it should reside in separate SPIR-V blocks from its
// preceding and following ops. So we need to emit unconditional branches to
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 107c8a3207b02..8ec0bf5bbaacf 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -426,6 +426,47 @@ func.func @only_entry_and_continue_branch_to_header() -> () {
// -----
+func.func @loop_yield(%count : i32) -> () {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+ // CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
+ %final_i = spirv.mlir.loop -> i32 {
+ // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
+ spirv.Branch ^header(%zero: i32)
+
+ // CHECK-NEXT: ^bb1({{%.*}}: i32):
+ ^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ // CHECK: spirv.BranchConditional %{{.*}}, ^bb2, ^bb4
+ spirv.BranchConditional %cmp, ^body, ^merge
+
+ // CHECK-NEXT: ^bb2:
+ ^body:
+ // CHECK-NEXT: spirv.Branch ^bb3
+ spirv.Branch ^continue
+
+ // CHECK-NEXT: ^bb3:
+ ^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ // CHECK: spirv.Branch ^bb1({{%.*}}: i32)
+ spirv.Branch ^header(%new_i: i32)
+
+ // CHECK-NEXT: ^bb4:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+ spirv.mlir.merge %i : i32
+ }
+
+ // CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %final_i : i32
+
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.merge
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index d89600558f56d..95b87b319ac2d 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -288,3 +288,50 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
spirv.Return
}
}
+
+// -----
+
+// Loop yielding values
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+ spirv.func @loop_yield(%count : i32) -> () "None" {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+// CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
+ %final_i = spirv.mlir.loop -> i32 {
+// CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+ spirv.Branch ^header(%zero: i32)
+
+// CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32):
+ ^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+// CHECK: spirv.BranchConditional %{{.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
+ spirv.BranchConditional %cmp, ^body, ^merge
+
+// CHECK-NEXT: ^[[BODY:.+]]:
+ ^body:
+// CHECK-NEXT: spirv.Branch ^[[CONTINUE:.+]]
+ spirv.Branch ^continue
+
+// CHECK-NEXT: ^[[CONTINUE:.+]]:
+ ^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+// CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+ spirv.Branch ^header(%new_i: i32)
+
+// CHECK-NEXT: ^[[MERGE:.+]]:
+ ^merge:
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+ spirv.mlir.merge %i : i32
+// CHECK-NEXT: }
+ }
+
+// CHECK-NEXT: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %final_i : i32
+
+// CHECK-NEXT: spirv.Return
+ spirv.Return
+ }
+}
|
kuhar
approved these changes
Apr 11, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This change extends
spirv.mlir.loop
so it can yield values, the same asspirv.mlir.selection
.