diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py
index 2d75e76d5c..7b76265197 100644
--- a/pytensor/link/jax/linker.py
+++ b/pytensor/link/jax/linker.py
@@ -23,7 +23,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         # Replace any shared RNG inputs so that their values can be updated in place
         # without affecting the original RNG container. This is necessary because
         # JAX does not accept RandomState/Generators as inputs, and they will have to
-        # be typyfied
+        # be typified
         if shared_rng_inputs:
             warnings.warn(
                 f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
@@ -52,9 +52,16 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
                 else:  # no break
                     raise ValueError()
                 input_storage[input_storage_idx] = new_inp_storage
+                # We need to change the order of the inputs of the FunctionGraph
+                # so that the new input is in the same position as the old one,
+                # to align with the storage_map. We hope this is safe!
+                old_inp_fgraph_index = fgraph.inputs.index(old_inp)
                 fgraph.remove_input(
-                    fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
+                    old_inp_fgraph_index,
+                    reason="JAXLinker.fgraph_convert",
                 )
+                fgraph.inputs.remove(new_inp)
+                fgraph.inputs.insert(old_inp_fgraph_index, new_inp)
 
         return jax_funcify(
             fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py
index 811e8122de..4e813fd128 100644
--- a/tests/link/jax/test_random.py
+++ b/tests/link/jax/test_random.py
@@ -10,6 +10,7 @@
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor.random.basic import RandomVariable
+from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.random.utils import RandomStream
 from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
 
@@ -20,7 +21,7 @@
 from pytensor.link.jax.dispatch.random import numpyro_available  # noqa: E402
 
 
-def random_function(*args, **kwargs):
+def compile_random_function(*args, **kwargs):
     with pytest.warns(
         UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
     ):
@@ -35,7 +36,7 @@ def test_random_RandomStream():
     srng = RandomStream(seed=123)
     out = srng.normal() - srng.normal()
 
-    fn = random_function([], out, mode=jax_mode)
+    fn = compile_random_function([], out, mode=jax_mode)
     jax_res_1 = fn()
     jax_res_2 = fn()
 
@@ -48,7 +49,7 @@ def test_random_updates(rng_ctor):
     rng = shared(original_value, name="original_rng", borrow=False)
     next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs
 
-    f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
+    f = compile_random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
     assert f() != f()
 
     # Check that original rng variable content was not overwritten when calling jax_typify
@@ -58,7 +59,42 @@
     )
 
 
-def test_random_updates_input_storage_order():
+@pytest.mark.parametrize("noise_first", (False, True))
+def test_replaced_shared_rng_storage_order(noise_first):
+    # Test that replacing the RNG variable in the linker does not cause
+    # a misalignment between the compiled graph and the storage_map.
+
+    mu = pytensor.shared(np.array(1.0), name="mu")
+    rng = pytensor.shared(np.random.default_rng(123))
+    next_rng, noise = pt.random.normal(rng=rng).owner.outputs
+
+    if noise_first:
+        out = noise * mu
+    else:
+        out = mu * noise
+
+    updates = {
+        mu: pt.grad(out, mu),
+        rng: next_rng,
+    }
+    f = compile_random_function([], [out], updates=updates, mode="JAX")
+
+    # The bug was found when noise used to be the first input of the fgraph.
+    # If this changes, the test may need to be tweaked to keep the same coverage.
+    assert isinstance(
+        f.input_storage[1 - noise_first].type, RandomType
+    ), "Test may need to be tweaked"
+
+    # Confirm that input_storage type and fgraph input order are aligned
+    for storage, fgraph_input in zip(f.input_storage, f.maker.fgraph.inputs):
+        assert storage.type == fgraph_input.type
+
+    assert mu.get_value() == 1
+    f()
+    assert mu.get_value() != 1
+
+
+def test_replaced_shared_rng_storage_ordering_equality():
     """Test case described in issue #314.
 
     This happened when we tried to update the input storage after we clone the shared RNG.
@@ -79,7 +115,7 @@
 
     # This function replaces inp by input_shared in the update expression
     # This is what caused the RNG to appear later than inp_shared in the input_storage
-    fn = random_function(
+    fn = compile_random_function(
         inputs=[],
         outputs=[],
         updates={inp_shared: inp_update},
@@ -453,7 +489,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
     else:
         rng = shared(np.random.RandomState(29402))
     g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
-    g_fn = random_function(dist_params, g, mode=jax_mode)
+    g_fn = compile_random_function(dist_params, g, mode=jax_mode)
     samples = g_fn(
         *[
             i.tag.test_value
@@ -477,7 +513,7 @@
 def test_random_bernoulli(size):
     rng = shared(np.random.RandomState(123))
     g = pt.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
 
@@ -488,7 +524,7 @@ def test_random_mvnormal():
     mu = np.ones(4)
     cov = np.eye(4)
     g = pt.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
 
@@ -503,7 +539,7 @@
 def test_random_dirichlet(parameter, size):
     rng = shared(np.random.RandomState(123))
     g = pt.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
 
@@ -513,21 +549,21 @@ def test_random_choice():
     num_samples = 10000
     rng = shared(np.random.RandomState(123))
     g = pt.random.choice(np.arange(4), size=num_samples, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
 
     # `replace=False` produces unique results
     rng = shared(np.random.RandomState(123))
     g = pt.random.choice(np.arange(100), replace=False, size=99, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     assert len(np.unique(samples)) == 99
 
     # We can pass an array with probabilities
     rng = shared(np.random.RandomState(123))
     g = pt.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples, np.zeros(10))
 
@@ -535,7 +571,7 @@
 def test_random_categorical():
     rng = shared(np.random.RandomState(123))
     g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
 
@@ -544,7 +580,7 @@ def test_random_permutation():
     array = np.arange(4)
     rng = shared(np.random.RandomState(123))
     g = pt.random.permutation(array, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     permuted = g_fn()
     with pytest.raises(AssertionError):
         np.testing.assert_allclose(array, permuted)
@@ -554,7 +590,7 @@ def test_random_geometric():
     rng = shared(np.random.RandomState(123))
     p = np.array([0.3, 0.7])
     g = pt.random.geometric(p, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
     np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -565,7 +601,7 @@ def test_negative_binomial():
     n = np.array([10, 40])
     p = np.array([0.3, 0.7])
    g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
     np.testing.assert_allclose(
@@ -579,7 +615,7 @@ def test_binomial():
     n = np.array([10, 40])
     p = np.array([0.3, 0.7])
     g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
     np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -594,7 +630,7 @@ def test_beta_binomial():
     a = np.array([1.5, 13])
     b = np.array([0.5, 9])
     g = pt.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
     np.testing.assert_allclose(
@@ -612,7 +648,7 @@ def test_multinomial():
     n = np.array([10, 40])
     p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
     g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
     np.testing.assert_allclose(
@@ -628,7 +664,7 @@ def test_vonmises_mu_outside_circle():
     mu = np.array([-30, 40])
     kappa = np.array([100, 10])
     g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(
         samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
@@ -728,7 +764,7 @@ def test_random_concrete_shape():
     rng = shared(np.random.RandomState(123))
     x_pt = pt.dmatrix()
     out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
-    jax_fn = random_function([x_pt], out, mode=jax_mode)
+    jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (2, 3)
 
@@ -736,7 +772,7 @@ def test_random_concrete_shape_from_param():
     rng = shared(np.random.RandomState(123))
     x_pt = pt.dmatrix()
     out = pt.random.normal(x_pt, 1, rng=rng)
-    jax_fn = random_function([x_pt], out, mode=jax_mode)
+    jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (2, 3)
 
@@ -755,7 +791,7 @@ def test_random_concrete_shape_subtensor():
     rng = shared(np.random.RandomState(123))
     x_pt = pt.dmatrix()
     out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
-    jax_fn = random_function([x_pt], out, mode=jax_mode)
+    jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (3,)
 
@@ -771,7 +807,7 @@ def test_random_concrete_shape_subtensor_tuple():
     rng = shared(np.random.RandomState(123))
     x_pt = pt.dmatrix()
     out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
-    jax_fn = random_function([x_pt], out, mode=jax_mode)
+    jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (2,)
 
@@ -782,5 +818,5 @@ def test_random_concrete_shape_graph_input():
    rng = shared(np.random.RandomState(123))
     size_pt = pt.scalar()
     out = pt.random.normal(0, 1, size=size_pt, rng=rng)
-    jax_fn = random_function([size_pt], out, mode=jax_mode)
+    jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
     assert jax_fn(10).shape == (10,)
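
Note for reviewers: below is a minimal standalone sketch of what the new `test_replaced_shared_rng_storage_order` exercises, for reproducing the alignment check outside pytest. It is illustrative only: it assumes a working JAX install and calls `pytensor.function` directly instead of the `compile_random_function` helper, so the "RandomType SharedVariables ... will not be used" UserWarning is simply emitted rather than asserted on.

```python
import numpy as np

import pytensor
import pytensor.tensor as pt

# One ordinary shared input plus one shared RNG input. JAXLinker.fgraph_convert
# clones the shared RNG (JAX does not accept RandomState/Generators as inputs);
# before this patch the clone was appended at the end of fgraph.inputs while
# input_storage kept the old position, so storage containers and graph inputs
# could end up paired incorrectly.
mu = pytensor.shared(np.array(1.0), name="mu")
rng = pytensor.shared(np.random.default_rng(123))
next_rng, noise = pt.random.normal(rng=rng).owner.outputs
out = noise * mu

# The updates keep both shared variables as inputs of the compiled fgraph.
f = pytensor.function(
    [], out, updates={mu: pt.grad(out, mu), rng: next_rng}, mode="JAX"
)

# With the in-place reordering fix (fgraph.inputs.remove/insert in the linker),
# input_storage and fgraph.inputs stay aligned pairwise.
for storage, fgraph_input in zip(f.input_storage, f.maker.fgraph.inputs):
    assert storage.type == fgraph_input.type
f()  # runs without reading the RNG storage where mu's value is expected
```

Reinserting the cloned RNG at the old input's index, rather than rebuilding `input_storage`, is what keeps the zip above aligned without touching the linker's storage containers.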