-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][python] Add normalforms to capture preconditions of transforms #79449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir Author: None (martin-luecke) ChangesThis adds the concept of a Normalform to the transform dialect python extras. Normalforms are defined by a sequence of transforms required to ensure this normalform. The term and design is inspired by a similar concept in the term rewriting community. Each transform handle adheres to a normalform. Initially simply Note that the normalform of a handle is a static property currently without precise info about the actual payload. Normalizing a handle to a normalform might trigger propagation of this form to parent and child handles, depending on the specific normalform. class PerfectForNestForm(Normalform):
propagate_up = False
propagate_down = True
def _impl(cls, handle: OpHandle) -> None:
with handle.apply_patterns():
structured.ApplyTilingCanonicalizationPatternsOp()
loop.ApplyForLoopCanonicalizationPatternsOp()
transform.ApplyCanonicalizationPatternsOp()
handle.apply_licm()
handle.apply_cse() This normalform is propagated only to child handles, as all handles to operations that are nested at a deeper level will have also been impacted by these transforms and consist of perfectly nested loops, if possible. This normalform is not propagated to parent handles as these are not impacted by this specific normalization. With the current design Normalforms are never instantiated. Only ever a type of normalform is used, e.g. class Handle(ir.Value):
@<!-- -->property
def normalform(self) -> Type["Normalform"]: This could possibly also be modeled as an Full diff: https://github.com/llvm/llvm-project/pull/79449.diff 2 Files Affected:
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8d045cad7a4a36f..ccfe4a6babb6fde 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence, Type, TypeVar, Union
from ....extras.meta import region_op
from .... import ir
@@ -20,6 +20,8 @@
)
from .. import structured
+HandleT = TypeVar("HandleT", bound="Handle")
+
class Handle(ir.Value):
"""
@@ -42,6 +44,46 @@ def __init__(
super().__init__(v)
self.parent = parent
self.children = children if children is not None else []
+ self._normalform = Normalform
+
+ @property
+ def normalform(self) -> Type["Normalform"]:
+ """
+ The normalform of this handle. This is a static property of the handle
+ and indicates a group of previously applied transforms. This can be used
+ by subsequent transforms to statically reason about the structure of the
+ payload operations and whether other enabling transforms could possibly
+ be skipped.
+ Setting this property triggers propagation of the normalform to parent
+ and child handles depending on the specific normalform.
+ """
+ return self._normalform
+
+ @normalform.setter
+ def normalform(self, normalform: Type["Normalform"]):
+ self._normalform = normalform
+ if self._normalform.propagate_up:
+ self.propagate_up_normalform(normalform)
+ if self._normalform.propagate_down:
+ self.propagate_down_normalform(normalform)
+
+ def propagate_up_normalform(self, normalform: Type["Normalform"]):
+ if self.parent:
+ # We set the parent normalform directly to avoid infinite recursion
+ # in case this normalform needs to be propagated up and down.
+ self.parent._normalform = normalform
+ self.parent.propagate_up_normalform(normalform)
+
+ def propagate_down_normalform(self, normalform: Type["Normalform"]):
+ for child in self.children:
+ # We set the child normalform directly to avoid infinite recursion
+ # in case this normalform needs to be propagated up and down.
+ child._normalform = normalform
+ child.propagate_down_normalform(normalform)
+
+ def normalize(self: "HandleT", normalform: Type["Normalform"]) -> "HandleT":
+ return normalform.apply(self)
+
@ir.register_value_caster(AnyOpType.get_static_typeid())
@ir.register_value_caster(OperationType.get_static_typeid())
@@ -192,6 +234,45 @@ def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
return op.param
+class Normalform:
+ """
+ Represents the weakest normalform and is the base class for all normalforms.
+ A normalform is defined as a sequence of transforms to be applied to a
+ handle to reach this normalform.
+
+ `propagate_up`: Propagate this normalform up to parent handles.
+ `propagate_down`: Propagate this normalform down to all child handles
+ """
+
+ propagate_up: bool = True
+ propagate_down: bool = True
+
+ def __init__(self):
+ raise TypeError(
+ "Normalform cannot be instantiated directly. Use Type[Normalform]"
+ "instead."
+ )
+
+ @classmethod
+ def _impl(cls, handle: "HandleT") -> "HandleT":
+ """
+ Defines the transforms required to reach this normalform.
+ A normalform may apply arbitrary transforms and thus possibly
+ invalidate `handle`.
+ """
+ return handle
+
+ @classmethod
+ def apply(cls, handle: "HandleT") -> "HandleT":
+ """Apply transforms to a handle to bring it into this normalform."""
+ new_handle = cls._impl(handle)
+ new_handle.children.extend(handle.children)
+ new_handle.parent = handle.parent
+ # Setting this property propagates the normalform accordingly
+ new_handle.normalform = cls
+ return new_handle
+
+
def insert_transform_script(
block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
script: Callable[[OpHandle], None],
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index ea47f170cb63212..2ca7cfc5ead04f3 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -19,6 +19,7 @@
insert_transform_script,
sequence,
apply_patterns,
+ Normalform,
)
from mlir.extras import types as T
@@ -55,6 +56,12 @@ def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]
module.operation.verify()
+def run(f: Callable[[], None]):
+ print("\nTEST:", f.__name__)
+ with ir.Context(), ir.Location.unknown():
+ f()
+
+
# CHECK-LABEL: TEST: test_build_script_at_insertion_point
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
@@ -175,6 +182,88 @@ def test_match_ops_mixed(op: OpHandle):
# CHECK-SAME: -> !transform.any_op
+# CHECK-LABEL: TEST: test_normalform_base
+@build_transform_script
+def test_normalform_base(op: OpHandle):
+ # Normalform is the weakest normalform so op should already be in that form.
+ # Normalization to Normalform should be a no-op.
+ assert op._normalform is Normalform
+ op.normalize(Normalform)
+ assert op._normalform is Normalform
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: transform.yield
+
+
+class DummyNormalform(Normalform):
+ propagate_up: bool = True
+ propagate_down: bool = True
+
+ @classmethod
+ def _impl(cls, handle: OpHandle) -> OpHandle:
+ return handle.print("dummy normalization")
+
+
+# CHECK-LABEL: test_normalform_no_instantiation
+@run
+def test_normalform_no_instantiation():
+ try:
+ DummyNormalform()
+ except TypeError as e:
+ print(e)
+ else:
+ print("Exception not produced")
+
+ # CHECK: Normalform cannot be instantiated directly
+
+
+# CHECK-LABEL: TEST: test_normalform_dummyform
+@build_transform_script
+def test_normalform_dummyform(op: OpHandle):
+ op.normalize(DummyNormalform)
+ assert op._normalform is DummyNormalform
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_up
+@build_transform_script
+def test_normalform_propagate_up(op: OpHandle):
+ nested_handle = op.match_ops("dummy.op")
+ nested_handle.normalize(DummyNormalform)
+ assert nested_handle._normalform is DummyNormalform
+ assert op._normalform is DummyNormalform
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
+ # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_down
+@build_transform_script
+def test_normalform_propagate_down(op: OpHandle):
+ nested_handle = op.match_ops("dummy.op")
+ op.normalize(DummyNormalform)
+ assert nested_handle._normalform is DummyNormalform
+ assert op._normalform is DummyNormalform
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
+ # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_up_and_down
+@build_transform_script
+def test_normalform_propagate_up_and_down(op: OpHandle):
+ nested_handle = op.match_ops("dummy.op1")
+ nested_nested_handle = nested_handle.match_ops("dummy.op2")
+ nested_handle.normalize(DummyNormalform)
+ assert nested_handle._normalform is DummyNormalform
+ assert op._normalform is DummyNormalform
+ assert nested_nested_handle._normalform is DummyNormalform
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op1"]}
+ # CHECK-NEXT: %[[VAL_2:.*]] = transform.structured.match ops{["dummy.op2"]}
+ # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
+
+
# CHECK-LABEL: TEST: test_print_message
@build_transform_script
def test_print_message(op: OpHandle):
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo nits
I'm confused, why is this a python concept? |
4d9bb51
to
2d1cbe4
Compare
This does not have to be a Python concept. In fact, we thought about possibly tracking the normalform of a handle through the MLIR type of that handle. For now, we opted to have this in a shallow frontend abstraction rather than in the dialect itself. Additionally, this design might change in the future as we learn more about it. For instance, I am not sure if modeling a hierarchy of normalforms using inheritance is strong enough to express the possible implications between a variety of normalforms in the same and across dialects. Note that all of this is only concerned with the generation of a transform script. It does not do any analysis / transformations on existing transform IR. In summary, we want to get some more experience with transform pre/post-conditions and introduce modeling of this in the dialect itself when we know better what aspects are key. |
I'm just not sure this belongs to MLIR right now. I believe the python bindings are just that... bindings. That would be a significant departure to start adding "features" that don't exist in MLIR. |
The question is how much lifting the bindings are allowed to do to make IR generation easier. The goal is for instance to be able to write with Context(), InsertionPoint(body), Autonormalize():
matmul = module.match_op(linalg.matmulOp)
matmul.tile(tile_sizes=[4,8]) instead of: with Context(), InsertionPoint(body):
matmul = module.match_op(linalg.matmulOp)
module.apply_tiling_canonicalization()
module.apply_licm()
module.canonicalize()
matmul.tile(tile_sizes=[4,8]) to generate similar IR. I think the Python bindings have slowly transitioned into more than bindings. Rather than simply exposing the C++ API to Python, they provide sensible pythonic frontend abstractions. The feedback about this at the LLVM devmtg was encouraging and to me, they did not feel out of place in this environment of "easing IR generation" so I started upstreaming. Would you be happier with this feature if we had it in C++ as well or even found a way to encode this in the IR? Of course, if there are objections to such extras I will take this down. |
@joker-eph ping |
The "concept of normalform" is just a convenience? Ultimately to me the "bindings" are supposed to be just that: binding to the MLIR concepts. Of course there is some wrapping that can be done to make them more pythonic, safer, etc. But what you're doing here does not seem to fit this category of "let's make the MLIR API more pythonic" but really adding a new concept here. |
@ftynse offered to setup meeting to chat about this if you'd like. |
This adds the concept of a Normalform to the transform dialect python extras.
Normalforms are defined by a sequence of transforms required to ensure this normalform. The term and design is inspired by a similar concept in the term rewriting community.
Each transform handle adheres to a normalform. Initially simply
Normalform
, the weakest normalform, that all other normalforms inherit from.A normalform might imply other normalforms which is currently modeled using inheritance. i.e. stronger normalforms inherit from weaker normalforms.
Note that the normalform of a handle is a static property currently without precise info about the actual payload.
Normalizing a handle to a normalform might trigger propagation of this form to parent and child handles, depending on the specific normalform.
An example for a (conceived) normalform is the
PerfectForNestForm
that aims to achieve perfect loop nests in the IR, if possible, by using canonicalization and loop invariant code motion transforms.e.g.
This normalform is propagated only to child handles, as all handles to operations that are nested at a deeper level will have also been impacted by these transforms and consist of perfectly nested loops, if possible. This normalform is not propagated to parent handles as these are not impacted by this specific normalization.
With the current design Normalforms are never instantiated. Only ever a type of normalform is used, e.g.
This could possibly also be modeled as an
Enum
, but that makes modeling a hierarchy of normalforms more complicated.