
Commit 0ebc83b

Fix vectorize_graph bug when replacements were provided for only some outputs of a node
The provided replacement could be silently ignored and overwritten by the new output of the vectorized node. The changes also avoid vectorizing multiple-output nodes when none of their unreplaced outputs are needed.
1 parent c4ae6e3 commit 0ebc83b
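
A minimal sketch of the failure mode being fixed (an illustration, not part of the commit; it assumes pytensor is installed and mirrors the #569 regression covered by the tests below): a replacement is provided for the draws of a random variable, which is only one output of its two-output node, and before this change that replacement could be silently overwritten.

    import pytensor.tensor as pt
    from pytensor.graph.basic import equal_computations
    from pytensor.graph.replace import vectorize_graph

    beta = pt.random.normal(name="beta")   # the draws output of a two-output (rng, draws) node
    out = beta * pt.exp(beta + 1)

    new_beta = pt.tensor("new_beta", shape=(3,))

    # Only one output of the random node is given a replacement; previously the
    # direct use of `beta` in `out` could end up pointing at a freshly vectorized
    # draw instead of `new_beta`. With this fix the replacement is honoured.
    new_out = vectorize_graph(out, replace={beta: new_beta})
    assert equal_computations([new_out], [new_beta * pt.exp(new_beta + 1)])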

File tree: 4 files changed, +116 -8 lines changed

pytensor/graph/basic.py
pytensor/graph/replace.py
tests/graph/test_basic.py
tests/graph/test_replace.py

pytensor/graph/basic.py

Lines changed: 4 additions & 3 deletions
@@ -1439,15 +1439,16 @@ def io_toposort(
         order = []
         while todo:
             cur = todo.pop()
-            # We suppose that all outputs are always computed
-            if cur.outputs[0] in computed:
+            if all(out in computed for out in cur.outputs):
                 continue
             if all(i in computed or i.owner is None for i in cur.inputs):
                 computed.update(cur.outputs)
                 order.append(cur)
             else:
                 todo.append(cur)
-                todo.extend(i.owner for i in cur.inputs if i.owner)
+                todo.extend(
+                    i.owner for i in cur.inputs if (i.owner and i not in computed)
+                )
         return order
 
     compute_deps = None
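
A small sketch of what this io_toposort change means in practice (an illustration assuming pytensor is installed; a random variable is used only because its node conveniently has two outputs, an updated RNG and the draws):

    import pytensor.tensor as pt
    from pytensor.graph.basic import io_toposort

    draws = pt.random.normal(name="draws")
    rv_node = draws.owner                       # two outputs: (next_rng, draws)
    out = draws * pt.exp(pt.vector("v"))

    # `draws` is supplied as an already-computed input and the RNG output is not
    # needed downstream, so the random node stays out of the ordering; before
    # this change it could still end up being scheduled.
    assert rv_node not in io_toposort([draws], [out])

    # Without precomputed inputs the node is ordered as usual.
    assert rv_node in io_toposort([], [out])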

pytensor/graph/replace.py

Lines changed: 5 additions & 0 deletions
@@ -306,6 +306,11 @@ def vectorize_graph(
         vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
         vect_node = vectorize_node(node, *vect_inputs)
         for output, vect_output in zip(node.outputs, vect_node.outputs):
+            if output in vect_vars:
+                # This can happen when some outputs of a multi-output node are given a replacement,
+                # while some of the remaining outputs are still needed in the graph.
+                # We make sure we don't overwrite the provided replacement with the newly vectorized output
+                continue
             vect_vars[output] = vect_output
 
     seq_vect_outputs = [vect_vars[out] for out in seq_outputs]

tests/graph/test_basic.py

Lines changed: 40 additions & 1 deletion
@@ -34,7 +34,7 @@
 from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
 from pytensor.tensor.type_other import NoneConst
 from pytensor.tensor.variable import TensorVariable
-from tests.graph.utils import MyInnerGraphOp
+from tests.graph.utils import MyInnerGraphOp, op_multiple_outputs
 
 
 class MyType(Type):
@@ -287,6 +287,45 @@ def test_outputs_clients(self):
         all = io_toposort([], o0.outputs)
         assert all == [o0]
 
+    def test_multi_output_nodes(self):
+        l0, r0 = op_multiple_outputs(shared(0.0))
+        l1, r1 = op_multiple_outputs(shared(0.0))
+
+        v0 = r0 + 1
+        v1 = pt.exp(v0)
+        out = r1 * v1
+
+        # When either r0 or r1 is provided as an input, the respective node shouldn't be part of the toposort
+        assert set(io_toposort([], [out])) == {
+            r0.owner,
+            r1.owner,
+            v0.owner,
+            v1.owner,
+            out.owner,
+        }
+        assert set(io_toposort([r0], [out])) == {
+            r1.owner,
+            v0.owner,
+            v1.owner,
+            out.owner,
+        }
+        assert set(io_toposort([r1], [out])) == {
+            r0.owner,
+            v0.owner,
+            v1.owner,
+            out.owner,
+        }
+        assert set(io_toposort([r0, r1], [out])) == {v0.owner, v1.owner, out.owner}
+
+        # When l0 and/or l1 are provided, we still need to compute the respective nodes
+        assert set(io_toposort([l0, l1], [out])) == {
+            r0.owner,
+            r1.owner,
+            v0.owner,
+            v1.owner,
+            out.owner,
+        }
+
 
 class TestEval:
     def setup_method(self):

tests/graph/test_replace.py

Lines changed: 67 additions & 4 deletions
@@ -5,10 +5,15 @@
 import pytensor.tensor as pt
 from pytensor import config, function, shared
 from pytensor.graph.basic import equal_computations, graph_inputs
-from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
+from pytensor.graph.replace import (
+    clone_replace,
+    graph_replace,
+    vectorize_graph,
+    vectorize_node,
+)
 from pytensor.tensor import dvector, fvector, vector
 from tests import unittest_tools as utt
-from tests.graph.utils import MyOp, MyVariable
+from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs
 
 
 class TestCloneReplace:
@@ -227,8 +232,6 @@ def test_graph_replace_disconnected(self):
 
 
 class TestVectorizeGraph:
-    # TODO: Add tests with multiple outputs, constants, and other singleton types
-
     def test_basic(self):
         x = pt.vector("x")
         y = pt.exp(x) / pt.sum(pt.exp(x))
@@ -260,3 +263,63 @@ def test_multiple_outputs(self):
         new_y1_res, new_y2_res = fn(new_x_test)
         np.testing.assert_allclose(new_y1_res, [0, 3, 6])
         np.testing.assert_allclose(new_y2_res, [2, 5, 8])
+
+    def test_multi_output_node(self):
+        x = pt.scalar("x")
+        node = op_multiple_outputs.make_node(x)
+        y1, y2 = node.outputs
+        out = pt.add(y1, y2)
+
+        new_x = pt.vector("new_x")
+        new_y1 = pt.vector("new_y1")
+        new_y2 = pt.vector("new_y2")
+
+        # Cases where either x or both of y1 and y2 are given replacements
+        new_out = vectorize_graph(out, {x: new_x})
+        expected_new_out = pt.add(*vectorize_node(node, new_x).outputs)
+        assert equal_computations([new_out], [expected_new_out])
+
+        new_out = vectorize_graph(out, {y1: new_y1, y2: new_y2})
+        expected_new_out = pt.add(new_y1, new_y2)
+        assert equal_computations([new_out], [expected_new_out])
+
+        new_out = vectorize_graph(out, {x: new_x, y1: new_y1, y2: new_y2})
+        expected_new_out = pt.add(new_y1, new_y2)
+        assert equal_computations([new_out], [expected_new_out])
+
+        # Special case where x is given a replacement as well as only one of y1 and y2
+        # The graph combines the replaced variable with the other vectorized output
+        new_out = vectorize_graph(out, {x: new_x, y1: new_y1})
+        expected_new_out = pt.add(new_y1, vectorize_node(node, new_x).outputs[1])
+        assert equal_computations([new_out], [expected_new_out])
+
+    def test_multi_output_node_random_variable(self):
+        """This is a regression test for #569.
+
+        Functionally, it covers the same case as `test_multi_output_node`
+        """
+
+        # RandomVariables have two outputs, a hidden RNG and the visible draws
+        beta0 = pt.random.normal(name="beta0")
+        beta1 = pt.random.normal(name="beta1")
+
+        out1 = beta0 + 1
+        out2 = beta1 * pt.exp(out1)
+
+        # We replace the second output of each RandomVariable
+        new_beta0 = pt.tensor("new_beta0", shape=(3,))
+        new_beta1 = pt.tensor("new_beta1", shape=(3,))
+
+        new_outs = vectorize_graph(
+            [out1, out2],
+            replace={
+                beta0: new_beta0,
+                beta1: new_beta1,
+            },
+        )
+
+        expected_new_outs = [
+            new_beta0 + 1,
+            new_beta1 * pt.exp(new_beta0 + 1),
+        ]
+        assert equal_computations(new_outs, expected_new_outs)
