Commit 0666fd5

Fix bug in input_storage alignment of the JAX backend
When replacing the shared RNG variables, the input order of the FunctionGraph was not explicitly aligned with the input storage of the function being compiled.

1 parent: 3170c7d
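For context, a minimal way to reach the code path this commit fixes is to compile a graph that draws from a shared NumPy Generator with the JAX backend. The sketch below is not part of the commit and uses the public `pytensor.function` entry point (the new test uses the `compile_random_function` helper instead):

```python
import numpy as np

import pytensor
import pytensor.tensor as pt

# A shared Generator feeding a random variable. JAX does not accept the
# Generator as an input, so JAXLinker.fgraph_convert clones and replaces
# it at compile time; that replacement is what this commit re-aligns.
rng = pytensor.shared(np.random.default_rng(123))
next_rng, noise = pt.random.normal(rng=rng).owner.outputs

f = pytensor.function([], noise, updates={rng: next_rng}, mode="JAX")
f()  # draws a sample and advances the cloned RNG in place
```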

2 files changed: +46, -3 lines
pytensor/link/jax/linker.py (+9, -2)
```diff
@@ -23,7 +23,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         # Replace any shared RNG inputs so that their values can be updated in place
         # without affecting the original RNG container. This is necessary because
         # JAX does not accept RandomState/Generators as inputs, and they will have to
-        # be typyfied
+        # be typified
         if shared_rng_inputs:
             warnings.warn(
                 f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
@@ -52,9 +52,16 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
             else:  # no break
                 raise ValueError()
             input_storage[input_storage_idx] = new_inp_storage
+            # We need to change the order of the inputs of the FunctionGraph
+            # so that the new input is in the same position as the old one,
+            # to align with the storage_map. We hope this is safe!
+            old_inp_fgrap_index = fgraph.inputs.index(old_inp)
             fgraph.remove_input(
-                fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
+                old_inp_fgrap_index,
+                reason="JAXLinker.fgraph_convert",
             )
+            fgraph.inputs.remove(new_inp)
+            fgraph.inputs.insert(old_inp_fgrap_index, new_inp)
 
         return jax_funcify(
             fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
```
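To see why the explicit reordering in the second hunk matters, here is a minimal sketch with plain Python lists standing in for `fgraph.inputs` and `input_storage` (the names are illustrative, not from the codebase). `input_storage[i]` is expected to hold the value for `fgraph.inputs[i]`, so a replacement input that ends up appended at the end breaks the positional correspondence:

```python
inputs = ["old_rng", "mu"]               # stand-in for fgraph.inputs
input_storage = ["rng_cell", "mu_cell"]  # aligned with inputs by position

old_idx = inputs.index("old_rng")
inputs.remove("old_rng")  # analogous to fgraph.remove_input(...)
inputs.append("new_rng")  # the cloned replacement lands at the end
# Misaligned: inputs == ["mu", "new_rng"], yet input_storage[0] holds the RNG.

inputs.remove("new_rng")            # the fix applied in this commit:
inputs.insert(old_idx, "new_rng")   # put the new input where the old one was
assert inputs == ["new_rng", "mu"]  # positions match input_storage again
```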

tests/link/jax/test_random.py (+37, -1)
```diff
@@ -10,6 +10,7 @@
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor.random.basic import RandomVariable
+from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.random.utils import RandomStream
 from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
 
@@ -58,7 +59,42 @@ def test_random_updates(rng_ctor):
     )
 
 
-def test_random_updates_input_storage_order():
+@pytest.mark.parametrize("noise_first", (False, True))
+def test_replaced_shared_rng_storage_order(noise_first):
+    # Test that replacing the RNG variable in the linker does not cause
+    # a misalignment between the compiled graph and the storage_map.
+
+    mu = pytensor.shared(np.array(1.0), name="mu")
+    rng = pytensor.shared(np.random.default_rng(123))
+    next_rng, noise = pt.random.normal(rng=rng).owner.outputs
+
+    if noise_first:
+        out = noise * mu
+    else:
+        out = mu * noise
+
+    updates = {
+        mu: pt.grad(out, mu),
+        rng: next_rng,
+    }
+    f = compile_random_function([], [out], updates=updates, mode="JAX")
+
+    # The bug was found when noise used to be the first input of the fgraph.
+    # If this changes, the test may need to be tweaked to keep the same coverage.
+    assert isinstance(
+        f.input_storage[1 - noise_first].type, RandomType
+    ), "Test may need to be tweaked"
+
+    # Confirm that input_storage type and fgraph input order are aligned
+    for storage, fgrapn_input in zip(f.input_storage, f.maker.fgraph.inputs):
+        assert storage.type == fgrapn_input.type
+
+    assert mu.get_value() == 1
+    f()
+    assert mu.get_value() != 1
+
+
+def test_replaced_shared_rng_storage_ordering_equality():
     """Test case described in issue #314.
 
     This happened when we tried to update the input storage after we clone the shared RNG.
```
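A side note on the `f.input_storage[1 - noise_first]` assertion in the new test: Python's `bool` is a subclass of `int`, so the expression picks index 0 when `noise_first` is `True` (the RNG shared variable then precedes `mu` in the compiled function's inputs) and index 1 otherwise. A quick illustration of the arithmetic the assertion relies on:

```python
# bool participates in integer arithmetic:
assert 1 - True == 0   # noise first  -> RNG expected at input_storage[0]
assert 1 - False == 1  # noise second -> RNG expected at input_storage[1]
```

Both the new parametrized test and the renamed `test_replaced_shared_rng_storage_ordering_equality` live in `tests/link/jax/test_random.py` and can be run with a plain pytest invocation of that file.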
