|
10 | 10 | from pytensor.graph.basic import Constant
|
11 | 11 | from pytensor.graph.fg import FunctionGraph
|
12 | 12 | from pytensor.tensor.random.basic import RandomVariable
|
| 13 | +from pytensor.tensor.random.type import RandomType |
13 | 14 | from pytensor.tensor.random.utils import RandomStream
|
14 | 15 | from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
|
15 | 16 |
|
@@ -58,7 +59,42 @@ def test_random_updates(rng_ctor):
|
58 | 59 | )
|
59 | 60 |
|
60 | 61 |
|
61 |
| -def test_random_updates_input_storage_order(): |
@pytest.mark.parametrize("noise_first", (False, True))
def test_replaced_shared_rng_storage_order(noise_first):
    """Replacing the RNG variable in the linker must not cause a
    misalignment between the compiled graph and the storage_map.

    Parametrized on ``noise_first`` so the noise term appears as either
    the first or second fgraph input, covering both input orderings.
    """
    mu = pytensor.shared(np.array(1.0), name="mu")
    rng = pytensor.shared(np.random.default_rng(123))
    next_rng, noise = pt.random.normal(rng=rng).owner.outputs

    # Multiplication order controls which variable becomes the first fgraph input.
    if noise_first:
        out = noise * mu
    else:
        out = mu * noise

    updates = {
        mu: pt.grad(out, mu),
        rng: next_rng,
    }
    f = compile_random_function([], [out], updates=updates, mode="JAX")

    # The bug was found when noise used to be the first input of the fgraph.
    # If this changes, the test may need to be tweaked to keep the same coverage.
    assert isinstance(
        f.input_storage[1 - noise_first].type, RandomType
    ), "Test may need to be tweaked"

    # Confirm that input_storage type and fgraph input order are aligned
    for storage, fgraph_input in zip(f.input_storage, f.maker.fgraph.inputs):
        assert storage.type == fgraph_input.type

    # The update on mu only runs correctly if the storage map is aligned.
    assert mu.get_value() == 1
    f()
    assert mu.get_value() != 1
| 96 | + |
| 97 | +def test_replaced_shared_rng_storage_ordering_equality(): |
62 | 98 | """Test case described in issue #314.
|
63 | 99 |
|
64 | 100 | This happened when we tried to update the input storage after we clone the shared RNG.
|
|
0 commit comments