Skip to content

Commit 4d9bb51

Browse files
committed
[mlir][python] Add normalforms to capture preconditions for transforms
1 parent 03cf0e9 commit 4d9bb51

File tree

2 files changed

+171
-1
lines changed

2 files changed

+171
-1
lines changed

mlir/python/mlir/dialects/transform/extras/__init__.py

+82-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
from typing import Callable, Optional, Sequence, Union
5+
from typing import Callable, Optional, Sequence, Type, TypeVar, Union
66

77
from ....extras.meta import region_op
88
from .... import ir
@@ -20,6 +20,8 @@
2020
)
2121
from .. import structured
2222

23+
HandleT = TypeVar("HandleT", bound="Handle")
24+
2325

2426
class Handle(ir.Value):
2527
"""
@@ -42,6 +44,46 @@ def __init__(
4244
super().__init__(v)
4345
self.parent = parent
4446
self.children = children if children is not None else []
47+
self._normalform = Normalform
48+
49+
@property
50+
def normalform(self) -> Type["Normalform"]:
51+
"""
52+
The normalform of this handle. This is a static property of the handle
53+
and indicates a group of previously applied transforms. This can be used
54+
by subsequent transforms to statically reason about the structure of the
55+
payload operations and whether other enabling transforms could possibly
56+
be skipped.
57+
Setting this property triggers propagation of the normalform to parent
58+
and child handles depending on the specific normalform.
59+
"""
60+
return self._normalform
61+
62+
@normalform.setter
63+
def normalform(self, normalform: Type["Normalform"]):
64+
self._normalform = normalform
65+
if self._normalform.propagate_up:
66+
self.propagate_up_normalform(normalform)
67+
if self._normalform.propagate_down:
68+
self.propagate_down_normalform(normalform)
69+
70+
def propagate_up_normalform(self, normalform: Type["Normalform"]):
71+
if self.parent:
72+
# We set the parent normalform directly to avoid infinite recursion
73+
# in case this normalform needs to be propagated up and down.
74+
self.parent._normalform = normalform
75+
self.parent.propagate_up_normalform(normalform)
76+
77+
def propagate_down_normalform(self, normalform: Type["Normalform"]):
78+
for child in self.children:
79+
# We set the child normalform directly to avoid infinite recursion
80+
# in case this normalform needs to be propagated up and down.
81+
child._normalform = normalform
82+
child.propagate_down_normalform(normalform)
83+
84+
def normalize(self: "HandleT", normalform: Type["Normalform"]) -> "HandleT":
85+
return normalform.apply(self)
86+
4587

4688
@ir.register_value_caster(AnyOpType.get_static_typeid())
4789
@ir.register_value_caster(OperationType.get_static_typeid())
@@ -192,6 +234,45 @@ def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
192234
return op.param
193235

194236

237+
class Normalform:
238+
"""
239+
Represents the weakest normalform and is the base class for all normalforms.
240+
A normalform is defined as a sequence of transforms to be applied to a
241+
handle to reach this normalform.
242+
243+
`propagate_up`: Propagate this normalform up to parent handles.
244+
`propagate_down`: Propagate this normalform down to all child handles
245+
"""
246+
247+
propagate_up: bool = True
248+
propagate_down: bool = True
249+
250+
def __init__(self):
251+
raise TypeError(
252+
"Normalform cannot be instantiated directly. Use Type[Normalform]"
253+
"instead."
254+
)
255+
256+
@classmethod
257+
def _impl(cls, handle: "HandleT") -> "HandleT":
258+
"""
259+
Defines the transforms required to reach this normalform.
260+
A normalform may apply arbitrary transforms and thus possibly
261+
invalidate `handle`.
262+
"""
263+
return handle
264+
265+
@classmethod
266+
def apply(cls, handle: "HandleT") -> "HandleT":
267+
"""Apply transforms to a handle to bring it into this normalform."""
268+
new_handle = cls._impl(handle)
269+
new_handle.children.extend(handle.children)
270+
new_handle.parent = handle.parent
271+
# Setting this property propagates the normalform accordingly
272+
new_handle.normalform = cls
273+
return new_handle
274+
275+
195276
def insert_transform_script(
196277
block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
197278
script: Callable[[OpHandle], None],

mlir/test/python/dialects/transform_extras.py

+89
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
insert_transform_script,
2020
sequence,
2121
apply_patterns,
22+
Normalform,
2223
)
2324
from mlir.extras import types as T
2425

@@ -55,6 +56,12 @@ def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]
5556
module.operation.verify()
5657

5758

59+
def run(f: Callable[[], None]):
60+
print("\nTEST:", f.__name__)
61+
with ir.Context(), ir.Location.unknown():
62+
f()
63+
64+
5865
# CHECK-LABEL: TEST: test_build_script_at_insertion_point
5966
@build_transform_script_at_insertion_point
6067
def test_build_script_at_insertion_point(op: OpHandle):
@@ -175,6 +182,88 @@ def test_match_ops_mixed(op: OpHandle):
175182
# CHECK-SAME: -> !transform.any_op
176183

177184

185+
# CHECK-LABEL: TEST: test_normalform_base
186+
@build_transform_script
187+
def test_normalform_base(op: OpHandle):
188+
# Normalform is the weakest normalform so op should already be in that form.
189+
# Normalization to Normalform should be a no-op.
190+
assert op._normalform is Normalform
191+
op.normalize(Normalform)
192+
assert op._normalform is Normalform
193+
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
194+
# CHECK-NEXT: transform.yield
195+
196+
197+
class DummyNormalform(Normalform):
198+
propagate_up: bool = True
199+
propagate_down: bool = True
200+
201+
@classmethod
202+
def _impl(cls, handle: OpHandle) -> OpHandle:
203+
return handle.print("dummy normalization")
204+
205+
206+
# CHECK-LABEL: test_normalform_no_instantiation
207+
@run
208+
def test_normalform_no_instantiation():
209+
try:
210+
DummyNormalform()
211+
except TypeError as e:
212+
print(e)
213+
else:
214+
print("Exception not produced")
215+
216+
# CHECK: Normalform cannot be instantiated directly
217+
218+
219+
# CHECK-LABEL: TEST: test_normalform_dummyform
220+
@build_transform_script
221+
def test_normalform_dummyform(op: OpHandle):
222+
op.normalize(DummyNormalform)
223+
assert op._normalform is DummyNormalform
224+
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
225+
# CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
226+
227+
228+
# CHECK-LABEL: TEST: test_normalform_propagate_up
229+
@build_transform_script
230+
def test_normalform_propagate_up(op: OpHandle):
231+
nested_handle = op.match_ops("dummy.op")
232+
nested_handle.normalize(DummyNormalform)
233+
assert nested_handle._normalform is DummyNormalform
234+
assert op._normalform is DummyNormalform
235+
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
236+
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
237+
# CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
238+
239+
240+
# CHECK-LABEL: TEST: test_normalform_propagate_down
241+
@build_transform_script
242+
def test_normalform_propagate_down(op: OpHandle):
243+
nested_handle = op.match_ops("dummy.op")
244+
op.normalize(DummyNormalform)
245+
assert nested_handle._normalform is DummyNormalform
246+
assert op._normalform is DummyNormalform
247+
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
248+
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
249+
# CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
250+
251+
252+
# CHECK-LABEL: TEST: test_normalform_propagate_up_and_down
253+
@build_transform_script
254+
def test_normalform_propagate_up_and_down(op: OpHandle):
255+
nested_handle = op.match_ops("dummy.op1")
256+
nested_nested_handle = nested_handle.match_ops("dummy.op2")
257+
nested_handle.normalize(DummyNormalform)
258+
assert nested_handle._normalform is DummyNormalform
259+
assert op._normalform is DummyNormalform
260+
assert nested_nested_handle._normalform is DummyNormalform
261+
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
262+
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op1"]}
263+
# CHECK-NEXT: %[[VAL_2:.*]] = transform.structured.match ops{["dummy.op2"]}
264+
# CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
265+
266+
178267
# CHECK-LABEL: TEST: test_print_message
179268
@build_transform_script
180269
def test_print_message(op: OpHandle):

0 commit comments

Comments
 (0)