Skip to content

[mlir][scf] Add a scope op to the scf dialect #89274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

aniragil
Copy link
Contributor

Add to the scf dialect an operation modeling an isolated-from-above
single basic block that is executed once. It provides a localized,
hierarchical alternative to outlining code into a function/call pair.

Add to the scf dialect an operation modeling an isolated-from-above
single basic block that is executed once. It provides a localized,
hierarchical alternative to outlining code into a function/call pair.
@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Gil Rapaport (aniragil)

Changes

Add to the scf dialect an operation modeling an isolated-from-above
single basic block that is executed once. It provides a localized,
hierarchical alternative to outlining code into a function/call pair.


Full diff: https://github.com/llvm/llvm-project/pull/89274.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+36-1)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+19)
  • (modified) mlir/test/Dialect/SCF/invalid.mlir (+26-1)
  • (modified) mlir/test/Dialect/SCF/ops.mlir (+25)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b3d085bfff1af9..6c04fe4cdd651d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -931,6 +931,41 @@ def ReduceReturnOp :
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ScopeOp
+//===----------------------------------------------------------------------===//
+
+def ScopeOp : SCF_Op<"scope",
+    [AutomaticAllocationScope,
+     RecursiveMemoryEffects,
+     IsolatedFromAbove,
+     SingleBlockImplicitTerminator<"scf::YieldOp">]> {
+  let summary = "isolated code scope";
+  let description = [{
+    The 'scope' op encapsulates computations by providing an isolated-from-above,
+    executed-once single basic block. The op takes any number of operands, and
+    its return values are defined by its terminating `scf.yield`. For example:
+
+    ```mlir
+    %p:2 = scf.scope %arg0, %c77, %arg1, %arg2 : (i32, i32, f32, f32) -> (i32, f32) {
+    ^bb0(%a : i32, %b : i32, %c : f32, %d : f32):
+      %add = arith.addi %a, %b : i32
+      %mul = arith.mulf %c, %d : f32
+      scf.yield %add, %mul : i32, f32
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region AnyRegion:$body);
+
+  let assemblyFormat = [{
+    $operands attr-dict `:` functional-type($operands, $results) $body
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // WhileOp
 //===----------------------------------------------------------------------===//
@@ -1155,7 +1190,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
 
 def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
     ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
-                 "WhileOp"]>]> {
+                 "ScopeOp", "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5bca8e85f889d9..6a8d28afdbc251 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3079,6 +3079,25 @@ LogicalResult ReduceReturnOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScopeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult scf::ScopeOp::verify() {
+  Region &body = getBody();
+  Block &block = body.front();
+  Operation *terminator = block.getTerminator();
+  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "expects terminator operands to have the "
+                                 "same type as results of the operation";
+    diag.attachNote(terminator->getLoc()) << "terminator";
+    return diag;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WhileOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 337eb9eeb8fa57..6871881a49458c 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -476,7 +476,7 @@ func.func @parallel_invalid_yield(
 
 func.func @yield_invalid_parent_op() {
   "my.op"() ({
-   // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.while'}}
+   // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.scope, scf.while'}}
    scf.yield
   }) : () -> ()
   return
@@ -747,3 +747,28 @@ func.func @parallel_missing_terminator(%0 : index) {
   return
 }
 
+// -----
+
+func.func @scope_not_isolated_from_above(%arg0 : i32, %arg1 : i32) -> (i32) {
+  // expected-note @below {{required by region isolation constraints}}
+  %p = scf.scope : () -> (i32) {
+  ^bb0():
+    // expected-error @below {{'arith.addi' op using value defined outside the region}}
+    %add = arith.addi %arg0, %arg1 : i32
+    scf.yield %add : i32
+  }
+  return %p : i32
+}
+
+// -----
+
+func.func @scope_yield_results_mismatch(%arg0 : i32, %arg1 : i32) -> (i32) {
+  // expected-error @below {{'scf.scope' op expects terminator operands to have the same type as results of the operation}}
+  %p = scf.scope %arg0, %arg1 : (i32, i32) -> (i32) {
+  ^bb0(%k : i32, %t : i32):
+    %add = arith.addi %k, %t : i32
+    // expected-note @below {{terminator}}
+    scf.yield %add, %add : i32, i32
+  }
+  return %p : i32
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 7f457ef3b6ba0c..70cb72a9ec0bf6 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -441,3 +441,28 @@ func.func @switch(%arg0: index) -> i32 {
 
   return %0 : i32
 }
+
+// CHECK-LABEL: @scope
+// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32
+func.func @scope(%arg0 : i32, %arg1 : f32, %arg2 : f32) -> (f32) {
+  // CHECK: %[[VAL_3:.*]] = arith.constant 77 : i32
+  %c77 = arith.constant 77 : i32
+
+  // CHECK: %[[VAL_4:.*]]:2 = scf.scope %[[VAL_0]], %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : (i32, i32, f32, f32) -> (i32, f32)
+  %p:2 = scf.scope %arg0, %c77, %arg1, %arg2 : (i32, i32, f32, f32) -> (i32, f32) {
+  // CHECK: ^bb0(%[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+  ^bb0(%a : i32, %b : i32, %c : f32, %d : f32):
+    // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
+    %add = arith.addi %a, %b : i32
+    // CHECK: %[[VAL_10:.*]] = arith.mulf %[[VAL_7]], %[[VAL_8]] : f32
+    %mul = arith.mulf %c, %d : f32
+    // CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : i32, f32
+    scf.yield %add, %mul : i32, f32
+  }
+
+  // CHECK: %[[VAL_11:.*]] = arith.sitofp %[[VAL_4:.*]]#0 : i32 to f32
+  %m = arith.sitofp %p#0 : i32 to f32
+  // CHECK: %[[VAL_13:.*]] = arith.subf %[[VAL_11]], %[[VAL_4]]#1 : f32
+  %r = arith.subf %m, %p#1 : f32
+  return %r : f32
+}

@aniragil aniragil marked this pull request as draft April 18, 2024 17:38
@joker-eph
Copy link
Collaborator

Wasn't there another PR already where we talked about this?

@ftynse
Copy link
Member

ftynse commented Apr 19, 2024

Wasn't there another PR already where we talked about this?

There was a draft PR in which you had explicitly asked for this op to be split out: #87352 (comment).

@joker-eph
Copy link
Collaborator

Wasn't there another PR already where we talked about this?

There was a draft PR in which you had explicitly asked for this op to be split out: #87352 (comment).

Thanks!
I wrote "I don't understand the description, and this the motivation for adding this operation right now, can you elaborate?", and I think this is still relevant right now. The PR description does not seem to provide enough context to me.
Without understanding why this op is needed, I'm not sure why we're adding it right now. Can we get the end-to-end example? When and where will it be used?

@aniragil
Copy link
Contributor Author

aniragil commented May 9, 2024

Without understanding why this op is needed, I'm not sure why we're adding it right now. Can we get the end-to-end example? When and where will it be used?

[Sorry for the late response] Right, I uploaded this draft following Renato's EuroLLVM'24 round table, hopefully making the discussion on whether such an op is required a bit more concrete.
As @ftynse mentioned, one use case (and the original motivation) for this op is a proposed transform op that uses it in order to isolate transformations. The round table raised a second possible use case as a means to encapsulate graphs of ops (e.g. linalg.generic) in order to retain their idiom (in response to an issue raised by @mgehre).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants