Skip to content

[mlir][python] Add normalforms to capture preconditions of transforms #79449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion mlir/python/mlir/dialects/transform/extras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable, Optional, Sequence, Union
from typing import Callable, Optional, Sequence, Type, TypeVar, Union

from ....extras.meta import region_op
from .... import ir
Expand All @@ -20,6 +20,8 @@
)
from .. import structured

HandleT = TypeVar("HandleT", bound="Handle")


class Handle(ir.Value):
"""
Expand All @@ -42,6 +44,46 @@ def __init__(
super().__init__(v)
self.parent = parent
self.children = children if children is not None else []
self._normalForm = NormalForm

@property
def normalForm(self) -> Type["NormalForm"]:
"""
The normalform of this handle. This is a static property of the handle
and indicates a group of previously applied transforms. This can be used
by subsequent transforms to statically reason about the structure of the
payload operations and whether other enabling transforms could possibly
be skipped.
Setting this property triggers propagation of the normalform to parent
and child handles depending on the specific normalform.
"""
return self._normalForm

@normalForm.setter
def normalForm(self, normalForm: Type["NormalForm"]):
self._normalForm = normalForm
if self._normalForm.propagate_up:
self.propagate_up_normalform(normalForm)
if self._normalForm.propagate_down:
self.propagate_down_normalform(normalForm)

def propagate_up_normalform(self, normalForm: Type["NormalForm"]):
if self.parent:
# We set the parent normalform directly to avoid infinite recursion
# in case this normalform needs to be propagated up and down.
self.parent._normalForm = normalForm
self.parent.propagate_up_normalform(normalForm)

def propagate_down_normalform(self, normalForm: Type["NormalForm"]):
for child in self.children:
# We set the child normalform directly to avoid infinite recursion
# in case this normalform needs to be propagated up and down.
child._normalForm = normalForm
child.propagate_down_normalform(normalForm)

def normalize(self: HandleT, normalForm: Type["NormalForm"]) -> HandleT:
return normalForm.apply(self)


@ir.register_value_caster(AnyOpType.get_static_typeid())
@ir.register_value_caster(OperationType.get_static_typeid())
Expand Down Expand Up @@ -192,6 +234,45 @@ def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
return op.param


class NormalForm:
"""
Represents the weakest normalform and is the base class for all normalforms.
A normalform is defined as a sequence of transforms to be applied to a
handle to reach this normalform.

`propagate_up`: Propagate this normalform up to parent handles.
`propagate_down`: Propagate this normalform down to all child handles
"""

propagate_up: bool = True
propagate_down: bool = True

def __init__(self):
raise TypeError(
"NormalForm cannot be instantiated directly. Use Type[NormalForm]"
"instead."
)

@classmethod
def _impl(cls, handle: HandleT) -> HandleT:
"""
Defines the transforms required to reach this normalform.
A normalform may apply arbitrary transforms and thus possibly
invalidate `handle`.
"""
return handle

@classmethod
def apply(cls, handle: HandleT) -> HandleT:
"""Apply transforms to a handle to bring it into this normalform."""
new_handle = cls._impl(handle)
new_handle.children.extend(handle.children)
new_handle.parent = handle.parent
# Setting this property propagates the normalform accordingly
new_handle.normalForm = cls
return new_handle


def insert_transform_script(
block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
script: Callable[[OpHandle], None],
Expand Down
89 changes: 89 additions & 0 deletions mlir/test/python/dialects/transform_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
insert_transform_script,
sequence,
apply_patterns,
NormalForm,
)
from mlir.extras import types as T

Expand Down Expand Up @@ -55,6 +56,12 @@ def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]
module.operation.verify()


def run(f: Callable[[], None]):
print("\nTEST:", f.__name__)
with ir.Context(), ir.Location.unknown():
f()


# CHECK-LABEL: TEST: test_build_script_at_insertion_point
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
Expand Down Expand Up @@ -175,6 +182,88 @@ def test_match_ops_mixed(op: OpHandle):
# CHECK-SAME: -> !transform.any_op


# CHECK-LABEL: TEST: test_normalform_base
@build_transform_script
def test_normalform_base(op: OpHandle):
# Normalform is the weakest normalform so op should already be in that form.
# Normalization to Normalform should be a no-op.
assert op._normalForm is NormalForm
op.normalize(NormalForm)
assert op._normalForm is NormalForm
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: transform.yield


class DummyNormalform(NormalForm):
propagate_up: bool = True
propagate_down: bool = True

@classmethod
def _impl(cls, handle: OpHandle) -> OpHandle:
return handle.print("dummy normalization")


# CHECK-LABEL: test_normalform_no_instantiation
@run
def test_normalform_no_instantiation():
try:
DummyNormalform()
except TypeError as e:
print(e)
else:
print("Exception not produced")

# CHECK: NormalForm cannot be instantiated directly


# CHECK-LABEL: TEST: test_normalform_dummyform
@build_transform_script
def test_normalform_dummyform(op: OpHandle):
op.normalize(DummyNormalform)
assert op._normalForm is DummyNormalform
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}


# CHECK-LABEL: TEST: test_normalform_propagate_up
@build_transform_script
def test_normalform_propagate_up(op: OpHandle):
nested_handle = op.match_ops("dummy.op")
nested_handle.normalize(DummyNormalform)
assert nested_handle._normalForm is DummyNormalform
assert op._normalForm is DummyNormalform
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
# CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}


# CHECK-LABEL: TEST: test_normalform_propagate_down
@build_transform_script
def test_normalform_propagate_down(op: OpHandle):
nested_handle = op.match_ops("dummy.op")
op.normalize(DummyNormalform)
assert nested_handle._normalForm is DummyNormalform
assert op._normalForm is DummyNormalform
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
# CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}


# CHECK-LABEL: TEST: test_normalform_propagate_up_and_down
@build_transform_script
def test_normalform_propagate_up_and_down(op: OpHandle):
nested_handle = op.match_ops("dummy.op1")
nested_nested_handle = nested_handle.match_ops("dummy.op2")
nested_handle.normalize(DummyNormalform)
assert nested_handle._normalForm is DummyNormalform
assert op._normalForm is DummyNormalform
assert nested_nested_handle._normalForm is DummyNormalform
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op1"]}
# CHECK-NEXT: %[[VAL_2:.*]] = transform.structured.match ops{["dummy.op2"]}
# CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}


# CHECK-LABEL: TEST: test_print_message
@build_transform_script
def test_print_message(op: OpHandle):
Expand Down