Commit f2ad711

Implement unconditional constant_folding rewrite

1 parent a570dbf commit f2ad711

2 files changed: +105 −50 lines

pytensor/tensor/rewriting/basic.py (+19 −4)
@@ -32,6 +32,7 @@
 from pytensor.graph import FunctionGraph
 from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.rewriting.basic import (
+    NodeProcessingGraphRewriter,
     NodeRewriter,
     RemovalNodeRewriter,
     Rewriter,
@@ -1101,10 +1102,7 @@ def local_useless_split(fgraph, node):
 
 
 @node_rewriter(None)
-def constant_folding(fgraph, node):
-    if not node.op.do_constant_folding(fgraph, node):
-        return False
-
+def unconditional_constant_folding(fgraph, node):
     if not all(isinstance(inp, Constant) for inp in node.inputs):
         return False
 
@@ -1151,6 +1149,23 @@ def constant_folding(fgraph, node):
     return rval
 
 
+topo_unconditional_constant_folding = in2out(
+    unconditional_constant_folding,
+    ignore_newtrees=True,
+    name="topo_unconditional_constant_folding",
+    # Not all Ops have a perform method, so we ignore failures to constant_fold
+    failure_callback=NodeProcessingGraphRewriter.warn_ignore,
+)
+
+
+@node_rewriter(None)
+def constant_folding(fgraph, node):
+    if not node.op.do_constant_folding(fgraph, node):
+        return False
+
+    return unconditional_constant_folding.transform(fgraph, node)
+
+
 topo_constant_folding = in2out(
     constant_folding, ignore_newtrees=True, name="topo_constant_folding"
 )
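For orientation, a minimal usage sketch (mirroring the new `test_unconditional` test below): `topo_unconditional_constant_folding` folds an all-constant node even when the Op's `do_constant_folding` declines, as `Alloc` does when it is used directly as an output.

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.graph.basic import Constant
    from pytensor.graph.fg import FunctionGraph
    from pytensor.tensor.rewriting.basic import (
        topo_constant_folding,
        topo_unconditional_constant_folding,
    )

    x = pt.alloc(np.e, 3, 5)
    fg = FunctionGraph(outputs=[x], clone=False)

    # The default rewrite consults Alloc.do_constant_folding and
    # leaves the output node in place...
    topo_constant_folding.apply(fg)
    assert not isinstance(fg.outputs[0], Constant)

    # ...while the unconditional variant folds it into a Constant.
    topo_unconditional_constant_folding.apply(fg)
    assert isinstance(fg.outputs[0], Constant)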

tests/tensor/rewriting/test_basic.py (+86 −46)
@@ -12,7 +12,8 @@
 from pytensor.compile.mode import get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import equal_computations
+from pytensor.graph import Op
+from pytensor.graph.basic import Constant, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -29,6 +30,7 @@
     TensorFromScalar,
     as_tensor,
     cast,
+    constant,
     join,
     tile,
 )
@@ -65,6 +67,8 @@
     local_merge_alloc,
     local_useless_alloc,
     local_useless_elemwise,
+    topo_constant_folding,
+    topo_unconditional_constant_folding,
     topological_fill_sink,
 )
 from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
@@ -742,56 +746,92 @@ def test_upcast(self):
         ) or (len(topo) > 1)
 
 
-def test_constant_folding():
-    # Test that constant folding get registered at fast_compile
-    # An error removed that registration during the registration.
-    x = dvector()
-    mode = get_mode("FAST_COMPILE").excluding("fusion")
-    f = function([x], [x * 2, x + x], mode=mode)
-    topo = f.maker.fgraph.toposort()
-    assert len(topo) == 2
-
-    # Test that we do not crash when constant folding elemwise scalar
-    # as they should not generate c code.
+class TestConstantFolding:
+    def test_constant_folding(self):
+        # Test that constant folding get registered at fast_compile
+        # An error removed that registration during the registration.
+        x = dvector()
+        mode = get_mode("FAST_COMPILE").excluding("fusion")
+        f = function([x], [x * 2, x + x], mode=mode)
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) == 2
 
-    x = pt.constant(3)
-    assert x.ndim == 0
-    mode = get_mode("FAST_COMPILE").excluding("fusion")
-    f = function([], [x * 2, x + x], mode=mode)
-    topo = f.maker.fgraph.toposort()
-    assert len(topo) == 2
-    assert all(isinstance(n.op, DeepCopyOp) for n in topo)
+        # Test that we do not crash when constant folding elemwise scalar
+        # as they should not generate c code.
 
+        x = pt.constant(3)
+        assert x.ndim == 0
+        mode = get_mode("FAST_COMPILE").excluding("fusion")
+        f = function([], [x * 2, x + x], mode=mode)
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) == 2
+        assert all(isinstance(n.op, DeepCopyOp) for n in topo)
 
-@pytest.mark.xfail(
-    reason="PyTensor rewrites constants before stabilization. "
-    "This breaks stabilization rewrites in some cases. See #504.",
-    raises=AssertionError,
-)
-def test_constant_get_stabilized():
-    # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
-    # This caused some stabilization rewrites to not be activated and that
-    # caused inf values to appear when they should not.
+    @pytest.mark.xfail(
+        reason="PyTensor rewrites constants before stabilization. "
+        "This breaks stabilization rewrites in some cases. See #504.",
+        raises=AssertionError,
+    )
+    def test_constant_get_stabilized(self):
+        # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
+        # This caused some stabilization rewrites to not be activated and that
+        # caused inf values to appear when they should not.
 
-    # We can't simply move the `constant_folding` rewrite to
-    # specialize since this will break other rewrites. We will need to
-    # partially duplicate some canonicalize rewrites to fix this issue.
+        # We can't simply move the `constant_folding` rewrite to
+        # specialize since this will break other rewrites. We will need to
+        # partially duplicate some canonicalize rewrites to fix this issue.
 
-    x2 = scalar()
-    y2 = log(1 + exp(x2))
-    mode = get_default_mode()
-    mode.check_isfinite = False
-    f2 = function([x2], y2, mode=mode)
-
-    assert len(f2.maker.fgraph.toposort()) == 1
-    assert f2.maker.fgraph.toposort()[0].op == softplus
-    assert f2(800) == 800
-
-    x = pt.as_tensor_variable(800)
-    y = log(1 + exp(x))
-    f = function([], y, mode=mode)
-    # When this error is fixed, the following line should be ok.
-    assert f() == 800, f()
+        x2 = scalar()
+        y2 = log(1 + exp(x2))
+        mode = get_default_mode()
+        mode.check_isfinite = False
+        f2 = function([x2], y2, mode=mode)
+
+        assert len(f2.maker.fgraph.toposort()) == 1
+        assert f2.maker.fgraph.toposort()[0].op == softplus
+        assert f2(800) == 800
+
+        x = pt.as_tensor_variable(800)
+        y = log(1 + exp(x))
+        f = function([], y, mode=mode)
+        # When this error is fixed, the following line should be ok.
+        assert f() == 800, f()
+
+    def test_unconditional(self):
+        x = pt.alloc(np.e, *(3, 5))
+        fg = FunctionGraph(outputs=[x], clone=False)
+
+        # Default constant folding doesn't apply to Alloc used as outputs
+        topo_constant_folding.apply(fg)
+        assert not isinstance(fg.outputs[0], Constant)
+
+        # Unconditional constant folding does apply
+        topo_unconditional_constant_folding.apply(fg)
+        assert isinstance(fg.outputs[0], Constant)
+        np.testing.assert_allclose(fg.outputs[0].data, np.full((3, 5), np.e))
+
+    def test_unconditional_no_perform_method(self):
+        """Test that errors are caught when the Op does not have a perform method."""
+
+        class OpNoPerform(Op):
+            itypes = [scalar(dtype="float64").type]
+            otypes = [scalar(dtype="float64").type]
+
+            def perform(self, *args, **kwargs):
+                raise NotImplementedError("This Op cannot be evaluated")
+
+        x = constant(np.array(5.0))
+        out = OpNoPerform()(x)
+
+        fg = FunctionGraph(outputs=[out], clone=False)
+        # Default constant_folding will raise
+        with pytest.raises(NotImplementedError):
+            topo_constant_folding.apply(fg)
+
+        # Unconditional constant folding will be silent
+        topo_unconditional_constant_folding.apply(fg)
+        assert not isinstance(fg.outputs[0], Constant)
+        assert isinstance(fg.outputs[0].owner.op, OpNoPerform)
 
 
 class TestLocalSwitchSink:
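The error tolerance exercised by `test_unconditional_no_perform_method` comes from the `failure_callback=NodeProcessingGraphRewriter.warn_ignore` argument in the first file: the walking rewriter turns an exception raised while folding a node into a warning and leaves that node untouched. A sketch of the same pattern applied to a custom node rewriter (the `always_fails` rewrite and `tolerant_rewrite` name are hypothetical, for illustration only):

    from pytensor.graph.rewriting.basic import (
        NodeProcessingGraphRewriter,
        in2out,
        node_rewriter,
    )

    @node_rewriter(None)
    def always_fails(fgraph, node):
        # Hypothetical rewrite that raises, standing in for constant
        # folding an Op whose perform method is not implemented.
        raise NotImplementedError("cannot rewrite this node")

    # With warn_ignore, applying the rewriter warns instead of raising,
    # and the graph is left unchanged.
    tolerant_rewrite = in2out(
        always_fails,
        ignore_newtrees=True,
        name="tolerant_rewrite",
        failure_callback=NodeProcessingGraphRewriter.warn_ignore,
    )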
