2
2
# See https://llvm.org/LICENSE.txt for license information.
3
3
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
4
5
- from typing import Callable , Optional , Sequence , Union
5
+ from typing import Callable , Optional , Sequence , Type , TypeVar , Union
6
6
7
7
from ....extras .meta import region_op
8
8
from .... import ir
20
20
)
21
21
from .. import structured
22
22
23
+ HandleT = TypeVar ("HandleT" , bound = "Handle" )
24
+
23
25
24
26
class Handle (ir .Value ):
25
27
"""
@@ -42,6 +44,46 @@ def __init__(
42
44
super ().__init__ (v )
43
45
self .parent = parent
44
46
self .children = children if children is not None else []
47
+ self ._normalform = Normalform
48
+
49
+ @property
50
+ def normalform (self ) -> Type ["Normalform" ]:
51
+ """
52
+ The normalform of this handle. This is a static property of the handle
53
+ and indicates a group of previously applied transforms. This can be used
54
+ by subsequent transforms to statically reason about the structure of the
55
+ payload operations and whether other enabling transforms could possibly
56
+ be skipped.
57
+ Setting this property triggers propagation of the normalform to parent
58
+ and child handles depending on the specific normalform.
59
+ """
60
+ return self ._normalform
61
+
62
+ @normalform .setter
63
+ def normalform (self , normalform : Type ["Normalform" ]):
64
+ self ._normalform = normalform
65
+ if self ._normalform .propagate_up :
66
+ self .propagate_up_normalform (normalform )
67
+ if self ._normalform .propagate_down :
68
+ self .propagate_down_normalform (normalform )
69
+
70
+ def propagate_up_normalform (self , normalform : Type ["Normalform" ]):
71
+ if self .parent :
72
+ # We set the parent normalform directly to avoid infinite recursion
73
+ # in case this normalform needs to be propagated up and down.
74
+ self .parent ._normalform = normalform
75
+ self .parent .propagate_up_normalform (normalform )
76
+
77
+ def propagate_down_normalform (self , normalform : Type ["Normalform" ]):
78
+ for child in self .children :
79
+ # We set the child normalform directly to avoid infinite recursion
80
+ # in case this normalform needs to be propagated up and down.
81
+ child ._normalform = normalform
82
+ child .propagate_down_normalform (normalform )
83
+
84
+ def normalize (self : "HandleT" , normalform : Type ["Normalform" ]) -> "HandleT" :
85
+ return normalform .apply (self )
86
+
45
87
46
88
@ir .register_value_caster (AnyOpType .get_static_typeid ())
47
89
@ir .register_value_caster (OperationType .get_static_typeid ())
@@ -192,6 +234,45 @@ def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
192
234
return op .param
193
235
194
236
237
+ class Normalform :
238
+ """
239
+ Represents the weakest normalform and is the base class for all normalforms.
240
+ A normalform is defined as a sequence of transforms to be applied to a
241
+ handle to reach this normalform.
242
+
243
+ `propagate_up`: Propagate this normalform up to parent handles.
244
+ `propagate_down`: Propagate this normalform down to all child handles
245
+ """
246
+
247
+ propagate_up : bool = True
248
+ propagate_down : bool = True
249
+
250
+ def __init__ (self ):
251
+ raise TypeError (
252
+ "Normalform cannot be instantiated directly. Use Type[Normalform]"
253
+ "instead."
254
+ )
255
+
256
+ @classmethod
257
+ def _impl (cls , handle : "HandleT" ) -> "HandleT" :
258
+ """
259
+ Defines the transforms required to reach this normalform.
260
+ A normalform may apply arbitrary transforms and thus possibly
261
+ invalidate `handle`.
262
+ """
263
+ return handle
264
+
265
+ @classmethod
266
+ def apply (cls , handle : "HandleT" ) -> "HandleT" :
267
+ """Apply transforms to a handle to bring it into this normalform."""
268
+ new_handle = cls ._impl (handle )
269
+ new_handle .children .extend (handle .children )
270
+ new_handle .parent = handle .parent
271
+ # Setting this property propagates the normalform accordingly
272
+ new_handle .normalform = cls
273
+ return new_handle
274
+
275
+
195
276
def insert_transform_script (
196
277
block_or_insertion_point : Union [ir .Block , ir .InsertionPoint ],
197
278
script : Callable [[OpHandle ], None ],
0 commit comments