diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py index 8d045cad7a4a3..8f7d04ddb1f9e 100644 --- a/mlir/python/mlir/dialects/transform/extras/__init__.py +++ b/mlir/python/mlir/dialects/transform/extras/__init__.py @@ -2,7 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Optional, Sequence, Type, TypeVar, Union from ....extras.meta import region_op from .... import ir @@ -20,6 +20,8 @@ ) from .. import structured +HandleT = TypeVar("HandleT", bound="Handle") + class Handle(ir.Value): """ @@ -42,6 +44,46 @@ def __init__( super().__init__(v) self.parent = parent self.children = children if children is not None else [] + self._normalForm = NormalForm + + @property + def normalForm(self) -> Type["NormalForm"]: + """ + The normalform of this handle. This is a static property of the handle + and indicates a group of previously applied transforms. This can be used + by subsequent transforms to statically reason about the structure of the + payload operations and whether other enabling transforms could possibly + be skipped. + Setting this property triggers propagation of the normalform to parent + and child handles depending on the specific normalform. + """ + return self._normalForm + + @normalForm.setter + def normalForm(self, normalForm: Type["NormalForm"]): + self._normalForm = normalForm + if self._normalForm.propagate_up: + self.propagate_up_normalform(normalForm) + if self._normalForm.propagate_down: + self.propagate_down_normalform(normalForm) + + def propagate_up_normalform(self, normalForm: Type["NormalForm"]): + if self.parent: + # We set the parent normalform directly to avoid infinite recursion + # in case this normalform needs to be propagated up and down. + self.parent._normalForm = normalForm + self.parent.propagate_up_normalform(normalForm) + + def propagate_down_normalform(self, normalForm: Type["NormalForm"]): + for child in self.children: + # We set the child normalform directly to avoid infinite recursion + # in case this normalform needs to be propagated up and down. + child._normalForm = normalForm + child.propagate_down_normalform(normalForm) + + def normalize(self: HandleT, normalForm: Type["NormalForm"]) -> HandleT: + return normalForm.apply(self) + @ir.register_value_caster(AnyOpType.get_static_typeid()) @ir.register_value_caster(OperationType.get_static_typeid()) @@ -192,6 +234,45 @@ def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle: return op.param +class NormalForm: + """ + Represents the weakest normalform and is the base class for all normalforms. + A normalform is defined as a sequence of transforms to be applied to a + handle to reach this normalform. + + `propagate_up`: Propagate this normalform up to parent handles. + `propagate_down`: Propagate this normalform down to all child handles + """ + + propagate_up: bool = True + propagate_down: bool = True + + def __init__(self): + raise TypeError( + "NormalForm cannot be instantiated directly. Use Type[NormalForm]" + "instead." + ) + + @classmethod + def _impl(cls, handle: HandleT) -> HandleT: + """ + Defines the transforms required to reach this normalform. + A normalform may apply arbitrary transforms and thus possibly + invalidate `handle`. + """ + return handle + + @classmethod + def apply(cls, handle: HandleT) -> HandleT: + """Apply transforms to a handle to bring it into this normalform.""" + new_handle = cls._impl(handle) + new_handle.children.extend(handle.children) + new_handle.parent = handle.parent + # Setting this property propagates the normalform accordingly + new_handle.normalForm = cls + return new_handle + + def insert_transform_script( block_or_insertion_point: Union[ir.Block, ir.InsertionPoint], script: Callable[[OpHandle], None], diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py index ea47f170cb632..b329609f795b5 100644 --- a/mlir/test/python/dialects/transform_extras.py +++ b/mlir/test/python/dialects/transform_extras.py @@ -19,6 +19,7 @@ insert_transform_script, sequence, apply_patterns, + NormalForm, ) from mlir.extras import types as T @@ -55,6 +56,12 @@ def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None] module.operation.verify() +def run(f: Callable[[], None]): + print("\nTEST:", f.__name__) + with ir.Context(), ir.Location.unknown(): + f() + + # CHECK-LABEL: TEST: test_build_script_at_insertion_point @build_transform_script_at_insertion_point def test_build_script_at_insertion_point(op: OpHandle): @@ -175,6 +182,88 @@ def test_match_ops_mixed(op: OpHandle): # CHECK-SAME: -> !transform.any_op +# CHECK-LABEL: TEST: test_normalform_base +@build_transform_script +def test_normalform_base(op: OpHandle): + # Normalform is the weakest normalform so op should already be in that form. + # Normalization to Normalform should be a no-op. + assert op._normalForm is NormalForm + op.normalize(NormalForm) + assert op._normalForm is NormalForm + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: transform.yield + + +class DummyNormalform(NormalForm): + propagate_up: bool = True + propagate_down: bool = True + + @classmethod + def _impl(cls, handle: OpHandle) -> OpHandle: + return handle.print("dummy normalization") + + +# CHECK-LABEL: test_normalform_no_instantiation +@run +def test_normalform_no_instantiation(): + try: + DummyNormalform() + except TypeError as e: + print(e) + else: + print("Exception not produced") + + # CHECK: NormalForm cannot be instantiated directly + + +# CHECK-LABEL: TEST: test_normalform_dummyform +@build_transform_script +def test_normalform_dummyform(op: OpHandle): + op.normalize(DummyNormalform) + assert op._normalForm is DummyNormalform + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"} + + +# CHECK-LABEL: TEST: test_normalform_propagate_up +@build_transform_script +def test_normalform_propagate_up(op: OpHandle): + nested_handle = op.match_ops("dummy.op") + nested_handle.normalize(DummyNormalform) + assert nested_handle._normalForm is DummyNormalform + assert op._normalForm is DummyNormalform + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]} + # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"} + + +# CHECK-LABEL: TEST: test_normalform_propagate_down +@build_transform_script +def test_normalform_propagate_down(op: OpHandle): + nested_handle = op.match_ops("dummy.op") + op.normalize(DummyNormalform) + assert nested_handle._normalForm is DummyNormalform + assert op._normalForm is DummyNormalform + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]} + # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"} + + +# CHECK-LABEL: TEST: test_normalform_propagate_up_and_down +@build_transform_script +def test_normalform_propagate_up_and_down(op: OpHandle): + nested_handle = op.match_ops("dummy.op1") + nested_nested_handle = nested_handle.match_ops("dummy.op2") + nested_handle.normalize(DummyNormalform) + assert nested_handle._normalForm is DummyNormalform + assert op._normalForm is DummyNormalform + assert nested_nested_handle._normalForm is DummyNormalform + # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) { + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op1"]} + # CHECK-NEXT: %[[VAL_2:.*]] = transform.structured.match ops{["dummy.op2"]} + # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"} + + # CHECK-LABEL: TEST: test_print_message @build_transform_script def test_print_message(op: OpHandle):