[mlir][python] Add normalforms to capture preconditions of transforms #79449


Open
wants to merge 1 commit into base: main

Conversation

martin-luecke
Contributor

This adds the concept of a Normalform to the transform dialect python extras.

A normalform is defined by the sequence of transforms required to establish it. The term and design are inspired by a similar concept in the term rewriting community.

Each transform handle adheres to a normalform. Initially this is simply Normalform, the weakest normalform, from which all other normalforms inherit.
A normalform might imply other normalforms, which is currently modeled via inheritance, i.e. stronger normalforms inherit from weaker ones.

Note that the normalform of a handle is a static property and currently carries no precise information about the actual payload.

Normalizing a handle to a normalform might trigger propagation of this form to parent and child handles, depending on the specific normalform.
An example of a (hypothetical) normalform is PerfectForNestForm, which aims to achieve perfect loop nests in the IR, where possible, by applying canonicalization and loop-invariant code motion transforms.
e.g.

class PerfectForNestForm(Normalform):
  propagate_up = False
  propagate_down = True

  @classmethod
  def _impl(cls, handle: OpHandle) -> OpHandle:
    with handle.apply_patterns():
      structured.ApplyTilingCanonicalizationPatternsOp()
      loop.ApplyForLoopCanonicalizationPatternsOp()
      transform.ApplyCanonicalizationPatternsOp()
    handle.apply_licm()
    handle.apply_cse()
    return handle

This normalform is propagated only to child handles: all handles to operations nested at a deeper level have also been affected by these transforms and, where possible, now consist of perfectly nested loops. It is not propagated to parent handles, as those are not affected by this specific normalization.
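
For illustration, here is a hypothetical usage sketch of this propagation behavior; the module/func/loops handles and the "func.func"/"scf.for" payload names are assumed for the example and are not part of this patch:

  func = module.match_ops("func.func")    # child handle of module
  loops = func.match_ops("scf.for")       # child handle of func
  func = func.normalize(PerfectForNestForm)
  # propagate_down is True, so child handles pick up the new form ...
  assert loops.normalform is PerfectForNestForm
  # ... while propagate_up is False, so the parent handle keeps the weakest form.
  assert module.normalform is Normalform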

With the current design, Normalforms are never instantiated; only the type of a normalform is ever used, e.g.

class Handle(ir.Value):
  @property
  def normalform(self) -> Type["Normalform"]:
    return self._normalform

This could possibly also be modeled as an Enum, but that makes modeling a hierarchy of normalforms more complicated.
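
For illustration, a minimal sketch of what the inheritance-based modeling buys us (WeakerForm, StrongerForm and the satisfies helper are hypothetical names, not part of this patch): checking whether a handle is at least in a given normalform becomes a plain issubclass test.

  class WeakerForm(Normalform):      # hypothetical normalform
    pass

  class StrongerForm(WeakerForm):    # inheriting means: StrongerForm implies WeakerForm
    pass

  def satisfies(handle: Handle, required: Type[Normalform]) -> bool:
    # A handle in a stronger normalform automatically satisfies any weaker one.
    return issubclass(handle.normalform, required)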

@llvmbot added the mlir:python (MLIR Python bindings) and mlir labels on Jan 25, 2024
@llvmbot
Member

llvmbot commented Jan 25, 2024

@llvm/pr-subscribers-mlir

Author: None (martin-luecke)


Full diff: https://github.com/llvm/llvm-project/pull/79449.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/transform/extras/__init__.py (+82-1)
  • (modified) mlir/test/python/dialects/transform_extras.py (+89)
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8d045cad7a4a36f..ccfe4a6babb6fde 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,7 +2,7 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence, Type, TypeVar, Union
 
 from ....extras.meta import region_op
 from .... import ir
@@ -20,6 +20,8 @@
 )
 from .. import structured
 
+HandleT = TypeVar("HandleT", bound="Handle")
+
 
 class Handle(ir.Value):
     """
@@ -42,6 +44,46 @@ def __init__(
         super().__init__(v)
         self.parent = parent
         self.children = children if children is not None else []
+        self._normalform = Normalform
+
+    @property
+    def normalform(self) -> Type["Normalform"]:
+        """
+        The normalform of this handle. This is a static property of the handle
+        and indicates a group of previously applied transforms. This can be used
+        by subsequent transforms to statically reason about the structure of the
+        payload operations and whether other enabling transforms could possibly
+        be skipped.
+        Setting this property triggers propagation of the normalform to parent
+        and child handles depending on the specific normalform.
+        """
+        return self._normalform
+
+    @normalform.setter
+    def normalform(self, normalform: Type["Normalform"]):
+        self._normalform = normalform
+        if self._normalform.propagate_up:
+            self.propagate_up_normalform(normalform)
+        if self._normalform.propagate_down:
+            self.propagate_down_normalform(normalform)
+
+    def propagate_up_normalform(self, normalform: Type["Normalform"]):
+        if self.parent:
+            # We set the parent normalform directly to avoid infinite recursion
+            # in case this normalform needs to be propagated up and down.
+            self.parent._normalform = normalform
+            self.parent.propagate_up_normalform(normalform)
+
+    def propagate_down_normalform(self, normalform: Type["Normalform"]):
+        for child in self.children:
+            # We set the child normalform directly to avoid infinite recursion
+            # in case this normalform needs to be propagated up and down.
+            child._normalform = normalform
+            child.propagate_down_normalform(normalform)
+
+    def normalize(self: "HandleT", normalform: Type["Normalform"]) -> "HandleT":
+        return normalform.apply(self)
+
 
 @ir.register_value_caster(AnyOpType.get_static_typeid())
 @ir.register_value_caster(OperationType.get_static_typeid())
@@ -192,6 +234,45 @@ def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
     return op.param
 
 
+class Normalform:
+    """
+    Represents the weakest normalform and is the base class for all normalforms.
+    A normalform is defined as a sequence of transforms to be applied to a
+    handle to reach this normalform.
+
+    `propagate_up`: Propagate this normalform up to parent handles.
+    `propagate_down`: Propagate this normalform down to all child handles
+    """
+
+    propagate_up: bool = True
+    propagate_down: bool = True
+
+    def __init__(self):
+        raise TypeError(
+            "Normalform cannot be instantiated directly. Use Type[Normalform]"
+            "instead."
+        )
+
+    @classmethod
+    def _impl(cls, handle: "HandleT") -> "HandleT":
+        """
+        Defines the transforms required to reach this normalform.
+        A normalform may apply arbitrary transforms and thus possibly
+        invalidate `handle`.
+        """
+        return handle
+
+    @classmethod
+    def apply(cls, handle: "HandleT") -> "HandleT":
+        """Apply transforms to a handle to bring it into this normalform."""
+        new_handle = cls._impl(handle)
+        new_handle.children.extend(handle.children)
+        new_handle.parent = handle.parent
+        # Setting this property propagates the normalform accordingly
+        new_handle.normalform = cls
+        return new_handle
+
+
 def insert_transform_script(
     block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
     script: Callable[[OpHandle], None],
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index ea47f170cb63212..2ca7cfc5ead04f3 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -19,6 +19,7 @@
     insert_transform_script,
     sequence,
     apply_patterns,
+    Normalform,
 )
 from mlir.extras import types as T
 
@@ -55,6 +56,12 @@ def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]
         module.operation.verify()
 
 
+def run(f: Callable[[], None]):
+    print("\nTEST:", f.__name__)
+    with ir.Context(), ir.Location.unknown():
+        f()
+
+
 # CHECK-LABEL: TEST: test_build_script_at_insertion_point
 @build_transform_script_at_insertion_point
 def test_build_script_at_insertion_point(op: OpHandle):
@@ -175,6 +182,88 @@ def test_match_ops_mixed(op: OpHandle):
     # CHECK-SAME:     -> !transform.any_op
 
 
+# CHECK-LABEL: TEST: test_normalform_base
+@build_transform_script
+def test_normalform_base(op: OpHandle):
+    # Normalform is the weakest normalform so op should already be in that form.
+    # Normalization to Normalform should be a no-op.
+    assert op._normalform is Normalform
+    op.normalize(Normalform)
+    assert op._normalform is Normalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.yield
+
+
+class DummyNormalform(Normalform):
+    propagate_up: bool = True
+    propagate_down: bool = True
+
+    @classmethod
+    def _impl(cls, handle: OpHandle) -> OpHandle:
+        return handle.print("dummy normalization")
+
+
+# CHECK-LABEL: test_normalform_no_instantiation
+@run
+def test_normalform_no_instantiation():
+    try:
+        DummyNormalform()
+    except TypeError as e:
+        print(e)
+    else:
+        print("Exception not produced")
+
+    # CHECK: Normalform cannot be instantiated directly
+
+
+# CHECK-LABEL: TEST: test_normalform_dummyform
+@build_transform_script
+def test_normalform_dummyform(op: OpHandle):
+    op.normalize(DummyNormalform)
+    assert op._normalform is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_up
+@build_transform_script
+def test_normalform_propagate_up(op: OpHandle):
+    nested_handle = op.match_ops("dummy.op")
+    nested_handle.normalize(DummyNormalform)
+    assert nested_handle._normalform is DummyNormalform
+    assert op._normalform is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
+    # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_down
+@build_transform_script
+def test_normalform_propagate_down(op: OpHandle):
+    nested_handle = op.match_ops("dummy.op")
+    op.normalize(DummyNormalform)
+    assert nested_handle._normalform is DummyNormalform
+    assert op._normalform is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
+    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_up_and_down
+@build_transform_script
+def test_normalform_propagate_up_and_down(op: OpHandle):
+    nested_handle = op.match_ops("dummy.op1")
+    nested_nested_handle = nested_handle.match_ops("dummy.op2")
+    nested_handle.normalize(DummyNormalform)
+    assert nested_handle._normalform is DummyNormalform
+    assert op._normalform is DummyNormalform
+    assert nested_nested_handle._normalform is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op1"]}
+    # CHECK-NEXT: %[[VAL_2:.*]] = transform.structured.match ops{["dummy.op2"]}
+    # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
+
+
 # CHECK-LABEL: TEST: test_print_message
 @build_transform_script
 def test_print_message(op: OpHandle):

Contributor

@makslevental left a comment


LGTM modulo nits

@joker-eph
Collaborator

I'm confused, why is this a python concept?

@martin-luecke
Contributor Author

martin-luecke commented Jan 26, 2024

@joker-eph

I'm confused, why is this a python concept?

This does not have to be a Python concept. In fact, we thought about possibly tracking the normalform of a handle through the MLIR type of that handle.

For now, we opted to have this in a shallow frontend abstraction rather than in the dialect itself.
We want to get some mileage with this concept and see how it eases the generation of transform scripts.
That includes additional extensions that are harder to realize in the transform core (e.g. requiring new passes or extensive changes to the transform core infrastructure).
For instance, I am working on a follow-up PR that introduces automatic normalization if a handle is not statically known to be in a specific normalform. The goal is to remove some of the guesswork of which canonicalizations / clean-up passes need to be applied before a specific transform.
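
A rough sketch of what that automatic normalization could look like (the ensure helper and required_form parameter are purely illustrative and not part of this PR or the follow-up):

  def ensure(handle: OpHandle, required_form: Type[Normalform]) -> OpHandle:
    # Only normalize if the statically tracked form does not already imply it.
    if not issubclass(handle.normalform, required_form):
      handle = handle.normalize(required_form)
    return handle

A transform that requires, say, PerfectForNestForm would then call ensure(handle, PerfectForNestForm) before emitting its ops.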

Additionally, this design might change in the future as we learn more about it. For instance, I am not sure if modeling a hierarchy of normalforms using inheritance is strong enough to express the possible implications between a variety of normalforms in the same and across dialects.

Note that all of this is only concerned with the generation of a transform script. It does not do any analysis / transformations on existing transform IR.

In summary, we want to get some more experience with transform pre/post-conditions and introduce modeling of this in the dialect itself when we know better what aspects are key.

@joker-eph
Collaborator

In summary, we want to get some more experience with transform pre/post-conditions and introduce modeling of this in the dialect itself when we know better what aspects are key.

I'm just not sure this belongs to MLIR right now. I believe the python bindings are just that... bindings. That would be a significant departure to start adding "features" that don't exist in MLIR.

@martin-luecke
Contributor Author

The question is how much lifting the bindings are allowed to do to make IR generation easier.
In my opinion, these are convenience extras to help generate transform IR more easily.
No analysis or transformation of existing IR is happening here.

The goal is for instance to be able to write

with Context(), InsertionPoint(body), Autonormalize(): 
  matmul = module.match_op(linalg.matmulOp)
  matmul.tile(tile_sizes=[4,8])

instead of:

with Context(), InsertionPoint(body): 
  matmul = module.match_op(linalg.matmulOp)
  module.apply_tiling_canonicalization()
  module.apply_licm()
  module.canonicalize()
  matmul.tile(tile_sizes=[4,8])

to generate similar IR.

I think the Python bindings have slowly transitioned into more than bindings. Rather than simply exposing the C++ API to Python, they provide sensible pythonic frontend abstractions.
For instance, we can annotate Python functions to generate func.func operations with an automatically inserted func.return for the SSA values returned from the Python function, rather than calling into C++ APIs explicitly for this (see the sketch below).
AFAIK the C++ dialect APIs are not concerned with easing IR emission beyond OpBuilders. Abstractions that automatically inject ops exist in Python only and improve usability there a lot.
Like those other Python abstractions, this is not designed as a feature that makes Python a non-optional component of the transform dialect.
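
As a concrete reference for the func.func example above, a minimal sketch using the existing func.FuncOp.from_py_func decorator from the upstream bindings (the add function and f32 types are just an illustration):

  from mlir import ir
  from mlir.dialects import arith, func

  with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
      f32 = ir.F32Type.get()

      # Builds a func.func op and automatically inserts func.return for the
      # value returned from the Python function.
      @func.FuncOp.from_py_func(f32, f32)
      def add(a, b):
        return arith.AddFOp(a, b)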

The feedback about this at the LLVM devmtg was encouraging, and to me these abstractions did not feel out of place in this environment of "easing IR generation", so I started upstreaming.

Would you be happier with this feature if we had it in C++ as well or even found a way to encode this in the IR?

Of course, if there are objections to such extras I will take this down.

@martin-luecke
Contributor Author

@joker-eph ping

@joker-eph
Collaborator

The question is how much lifting the bindings are allowed to do to make IR generation easier.
In my opinion, these are convenience extras to help generate transform IR more easily.

The "concept of normalform" is just a convenience?
From reading this PR description this seems more than that to me, and it seems like a concept that could be purely native.

Ultimately, to me the "bindings" are supposed to be just that: bindings to the MLIR concepts. Of course there is some wrapping that can be done to make them more pythonic, safer, etc. But what you're doing here does not seem to fit the category of "let's make the MLIR API more pythonic"; it really adds a new concept.

@joker-eph
Collaborator

@ftynse offered to set up a meeting to chat about this if you'd like.

Labels
mlir:python (MLIR Python bindings), mlir
4 participants