
Commit 43744bf

Avoid PyTensor function overhead in OpFromGraph
Also provide a pure C implementation when all Ops allow it. The LazyLinker no longer complains about thunks that return outputs, since it can itself be a thunk. Adding a Python wrapper that hides the outputs would incur considerable overhead, and modifying the LazyLinker to optionally not return outputs seemed unnecessarily complex.
1 parent 676296c commit 43744bf
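
For context, an `OpFromGraph` wraps an inner graph so it can be applied as a single `Op`. Before this commit, executing such a node went through a fully fledged compiled PyTensor function, paying its Python call overhead on every evaluation; now `make_thunk` hands the inner graph to a linker directly. A minimal sketch of the construct being optimized (the inner graph is illustrative, mirroring the benchmark added below):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function

# Wrap a small inner graph as a single reusable Op
x = pt.vector("x")
softmax_step = OpFromGraph([x], [pt.exp(x) / pt.exp(x).sum()])

# Apply it like any other Op; evaluating this node is where the per-call
# overhead addressed by this commit used to accumulate
y = pt.vector("y")
fn = function([y], softmax_step(y))
print(fn(np.random.default_rng(0).normal(size=10)))
```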

File tree: 5 files changed, +174 −30 lines

  pytensor/compile/builders.py          (+109 −6)
  pytensor/link/c/c_code/lazylinker_c.c (+2 −15)
  pytensor/link/c/lazylinker_c.py       (+1 −1)
  pytensor/tensor/rewriting/basic.py    (+1 −5)
  tests/compile/test_builders.py        (+61 −3)

pytensor/compile/builders.py (+109 −6)

@@ -6,8 +6,11 @@
 from functools import partial
 from typing import Union, cast
 
-from pytensor.compile.function import function
-from pytensor.compile.function.pfunc import rebuild_collect_shared
+from pytensor.compile import get_default_mode, insert_deepcopy
+from pytensor.compile.function.pfunc import pfunc, rebuild_collect_shared
+from pytensor.compile.function.types import add_supervisor_to_fgraph
+from pytensor.compile.io import In, Out
+from pytensor.compile.mode import Mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType, Rop, grad
@@ -21,7 +24,7 @@
 )
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.null_type import NullType
-from pytensor.graph.op import HasInnerGraph, Op
+from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.utils import MissingInputError
 
@@ -433,6 +436,9 @@ def __init__(
         assert isinstance(name, str), "name must be None or string object"
         self.name = name
         self.destroy_map = destroy_map if destroy_map is not None else {}
+        self._rewritten_fgraph = {}
+        self._wrapped_inputs = {}
+        self._wrapped_outputs = {}
 
     def __eq__(self, other):
         # TODO: recognize a copy
@@ -847,14 +853,58 @@ def infer_shape(self, fgraph, node, shapes):
 
         return ret
 
+    def _rewrite_fgraph(self, impl):
+        if self._rewritten_fgraph.get(impl, None) is None:
+            mode = get_default_mode()
+            if impl == "py":
+                mode = mode.excluding("cxx")
+            rewriter = mode.optimizer
+
+            # We are cloning fgraph too many times, but one of the existing tests checks for this
+            # TestOpFromGraph.test_outputs_consistency
+            fgraph = self.fgraph.clone()
+            self._wrapped_inputs[impl] = temp_wrapped_inputs = [
+                In(inp, borrow=False, mutable=False) for inp in fgraph.inputs
+            ]
+            # These are just temporary because the graph rewrite may change them
+            temp_wrapped_outputs = [
+                Out(out, borrow=True) for out in self.fgraph.outputs
+            ]
+            add_supervisor_to_fgraph(
+                fgraph,
+                temp_wrapped_inputs,
+                accept_inplace=False,
+            )
+            with config.change_flags(compute_test_value="off"):
+                rewriter(fgraph)
+            insert_deepcopy(fgraph, temp_wrapped_inputs, temp_wrapped_outputs)
+            self._wrapped_outputs[impl] = [
+                Out(out, borrow=True) for out in fgraph.outputs
+            ]
+            self._rewritten_fgraph[impl] = fgraph
+
+        return (
+            self._rewritten_fgraph[impl],
+            self._wrapped_inputs[impl],
+            self._wrapped_outputs[impl],
+        )
+
     @property
     def fn(self):
-        """Lazily compile the inner function graph."""
         if getattr(self, "_fn", None) is not None:
             return self._fn
 
-        self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
-        self._fn.trust_input = True
+        fgraph, wrapped_inputs, wrapped_outputs = self._rewrite_fgraph(impl=None)
+
+        self._fn = pfunc(
+            wrapped_inputs,
+            wrapped_outputs,
+            mode=Mode(linker=get_default_mode().linker, optimizer=None),
+            accept_inplace=True,
+            on_unused_input="ignore",
+            fgraph=fgraph,
+            trust_input=True,
+        )
 
         return self._fn
 
@@ -871,6 +921,59 @@ def clone(self):
         res.fgraph = res.fgraph.clone()
         return res
 
+    def prepare_node(
+        self,
+        node: Apply,
+        storage_map: StorageMapType | None,
+        compute_map: ComputeMapType | None,
+        impl: str | None,
+    ) -> None:
+        self._rewrite_fgraph(impl)
+        self.fn
+
+    def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
+        from pytensor.link.c.basic import CLinker
+        from pytensor.link.vm import VMLinker
+
+        self.prepare_node(node, storage_map, compute_map, impl)
+        fg, _, _ = self._rewrite_fgraph(impl)
+        fg_no_recycling = [
+            new_o
+            for (new_o, old_o) in zip(fg.outputs, node.outputs, strict=True)
+            if old_o in no_recycling
+        ]
+
+        node_input_storage = [storage_map[r] for r in node.inputs]
+        node_output_storage = [storage_map[r] for r in node.outputs]
+        node_compute_map = [compute_map[r] for r in node.outputs]
+
+        def create_thunk(linker):
+            linker.accept(fg, no_recycling=fg_no_recycling)
+            thunk, _, _ = linker.make_thunk(
+                input_storage=node_input_storage,
+                output_storage=node_output_storage,
+            )
+            return thunk
+
+            def thunk_wrapper(thunk=thunk, node_compute_map=node_compute_map):
+                thunk()
+                for cm in node_compute_map:
+                    cm[0] = True
+
+            return thunk_wrapper
+
+        if impl != "py":
+            try:
+                # We default to CLinker because it generates code for the whole graph that the compiler can reason about.
+                # Whereas the VMLinker will compile each node separately and call them in a pre-defined VM.
+                # It also has less overhead
+                return create_thunk(linker=CLinker())
+            except NotImplementedError:
+                # Some Op doesn't have a C implementation, VM it is
+                return create_thunk(VMLinker(use_cloop=True, c_thunks=True))
+        else:
+            return create_thunk(VMLinker(use_cloop=False, c_thunks=False))
+
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
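
The `create_thunk` helper above plugs the rewritten inner `FunctionGraph` straight into the outer function's storage cells, bypassing the `Function` call machinery entirely. A standalone sketch of that linker protocol, using the same `accept`/`make_thunk` API and the same CLinker-then-VM fallback (the toy graph and values are illustrative):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import CLinker
from pytensor.link.vm import VMLinker

x = pt.dscalar("x")
fgraph = FunctionGraph([x], [x + 1.0])

try:
    # Whole-graph C implementation, as make_thunk tries first
    thunk, inputs, outputs = CLinker().accept(fgraph).make_thunk()
except NotImplementedError:
    # Fallback when some Op lacks C code, as in the diff above
    thunk, inputs, outputs = (
        VMLinker(use_cloop=True, c_thunks=True).accept(fgraph).make_thunk()
    )

inputs[0].data = np.asarray(2.0)  # write into the input's storage container
thunk()  # runs the compiled graph in place; any return value can be ignored
print(outputs[0].data)  # 3.0
```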

pytensor/link/c/c_code/lazylinker_c.c (+2 −15)

@@ -676,20 +676,7 @@ static int lazy_rec_eval(CLazyLinker *self, Py_ssize_t var_idx, PyObject *one,
       // rval is new ref
       if (rval) // pycall returned normally (no exception)
       {
-        if (rval == Py_None) {
-          Py_DECREF(rval); // ignore a return of None
-        } else if (PyList_Check(rval)) {
-          PyErr_SetString(PyExc_TypeError,
-                          "non-lazy thunk should return None, not list");
-          err = 1;
-          goto pyfail;
-        } else // don't know what it returned, but it wasn't right.
-        {
-          PyErr_SetObject(PyExc_TypeError, rval);
-          err = 1;
-          // We don't release rval since we put it in the error above
-          goto fail;
-        }
+        Py_DECREF(rval); // ignore whatever was returned
       } else // pycall returned NULL (internal error)
       {
         err = 1;
@@ -981,7 +968,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = {
 };
 
 static PyObject *get_version(PyObject *dummy, PyObject *args) {
-  PyObject *result = PyFloat_FromDouble(0.3);
+  PyObject *result = PyFloat_FromDouble(0.4);
   return result;
 }
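
The C change above relaxes the non-lazy thunk contract: any return value is now discarded, where previously anything other than `None` raised a `TypeError`. A small Python-side illustration of the two thunk flavors the lazy linker now treats identically (storage cells are one-element lists, following PyTensor's `storage_map` convention):

```python
# One storage cell per output variable, as in PyTensor's storage_map
output_storage = [[None]]

def quiet_thunk():
    # Classic non-lazy thunk: fill the storage, return None
    output_storage[0][0] = 42

def returning_thunk():
    # A thunk produced by another linker (the lazy linker itself can be a
    # thunk) may also return its outputs; this used to be rejected
    output_storage[0][0] = 42
    return output_storage

for thunk in (quiet_thunk, returning_thunk):
    output_storage[0][0] = None
    thunk()  # the linker now discards whatever comes back
    assert output_storage[0][0] == 42
```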

pytensor/link/c/lazylinker_c.py (+1 −1)

@@ -14,7 +14,7 @@
 _logger = logging.getLogger(__file__)
 
 force_compile = False
-version = 0.3  # must match constant returned in function get_version()
+version = 0.4  # must match constant returned in function get_version()
 lazylinker_ext: ModuleType | None = None
pytensor/tensor/rewriting/basic.py (+1 −5)

@@ -1120,15 +1120,11 @@ def unconditional_constant_folding(fgraph, node):
         compute_map[o] = [False]
 
     thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
-    required = thunk()
-
-    # A node whose inputs are all provided should always return successfully
-    assert not required
+    thunk()
 
     rval = []
     for output in node.outputs:
         data = storage_map[output][0]
-        assert compute_map[output][0], (output, data)
 
         # TODO: `Type` itself should provide an interface for constructing
         # instances appropriate for a given constant.
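
`unconditional_constant_folding` evaluates a node eagerly by wiring up storage and compute maps by hand and invoking the node's thunk. Since thunks may now return their outputs and are not required to touch `compute_map`, the return-value and compute-map assertions are dropped. A simplified sketch of that evaluation pattern (the toy node is illustrative; the real rewrite also wraps the results in constants):

```python
import pytensor.tensor as pt

# A foldable node: both inputs are constants
node = (pt.constant(2.0) + pt.constant(3.0)).owner

# One-element storage cells keyed by variable, as in the rewrite above
storage_map = {i: [i.data] for i in node.inputs}
compute_map = {i: [True] for i in node.inputs}
for o in node.outputs:
    storage_map[o] = [None]
    compute_map[o] = [False]

thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
thunk()  # the return value is no longer inspected

print(storage_map[node.outputs[0]][0])  # 5.0
```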

tests/compile/test_builders.py (+61 −3)

@@ -4,6 +4,7 @@
 import pytest
 
 import pytensor.tensor as pt
+from pytensor import scan
 from pytensor.compile import shared
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
@@ -15,9 +16,10 @@
     grad,
     verify_grad,
 )
-from pytensor.graph.basic import equal_computations
+from pytensor.graph.basic import Apply, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.null_type import NullType, null_type
+from pytensor.graph.op import Op
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.graph.utils import MissingInputError
 from pytensor.printing import debugprint
@@ -622,14 +624,15 @@ def test_outputs_consistency(self):
         """Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`."""
 
         x = scalar("x")
-        op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN")
+        op = OpFromGraph([x], [x**2 / x])
 
         # Confirm that the inner-graph is as expected
         assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
 
         # These outputs of the compiled `op.fgraph` should differ from the
         # original, uncompiled `op.fgraph` outputs
-        fn = op.fn
+        with config.change_flags(mode="FAST_RUN"):
+            fn = op.fn
         new_inputs = fn.maker.fgraph.inputs
         new_outputs = fn.maker.fgraph.outputs
         assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x])
@@ -740,3 +743,58 @@ def test_debugprint():
 
     for truth, out in zip(exp_res.split("\n"), lines, strict=True):
         assert truth.strip() == out.strip()
+
+
+@pytest.mark.parametrize("kind", ("ofg", "inlined", "scan"))
+@pytest.mark.parametrize("c_op", (True, False), ids=lambda x: f"c_op={x}")
+def test_benchmark(c_op, kind, benchmark):
+    class ExpWithoutC(Op):
+        def make_node(self, x):
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            output_storage[0][0] = np.exp(inputs[0])
+
+    exp_without_c = ExpWithoutC()
+
+    n = 25
+
+    def _f(x):
+        if isinstance(x, np.ndarray):
+            y = np.exp(x)
+        else:
+            if c_op:
+                y = pt.exp(x)
+            else:
+                y = exp_without_c(x)
+        y /= y.sum()
+        return y
+
+    x = pt.vector("x")
+
+    if kind == "ofg":
+        f = OpFromGraph([x], [_f(x)])
+    else:
+        f = _f
+
+    if kind == "scan":
+        # Scan is included for a reference of how bad the overhead can be
+        outs, _ = scan(fn=f, outputs_info=[x], n_steps=n)
+        out = outs[-1]
+    else:
+        out = x
+        for i in range(n):
+            out = f(out)
+
+    compiled_fn = function([x], out, trust_input=True, mode="FAST_RUN")
+    compiled_fn.vm.allow_gc = False
+
+    rng = np.random.default_rng(1)
+    x_test = rng.normal(size=(10,))
+
+    res = benchmark(compiled_fn, x_test)
+
+    expected_res = x_test
+    for i in range(n):
+        expected_res = _f(expected_res)
+    np.testing.assert_allclose(res, expected_res)
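
The same OFG-versus-inlined equivalence that the benchmark asserts can also be checked directly, without the pytest-benchmark fixture. A minimal sketch with a shortened chain (names and chain length are illustrative):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function

x = pt.vector("x")
step = OpFromGraph([x], [pt.exp(x) / pt.exp(x).sum()])

out_ofg, out_inline = x, x
for _ in range(3):  # shortened from the test's n = 25
    out_ofg = step(out_ofg)
    out_inline = pt.exp(out_inline) / pt.exp(out_inline).sum()

fn_ofg = function([x], out_ofg, mode="FAST_RUN")
fn_inline = function([x], out_inline, mode="FAST_RUN")

x_test = np.random.default_rng(1).normal(size=10)
np.testing.assert_allclose(fn_ofg(x_test), fn_inline(x_test))
```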
