From ee4ccea5669ec40d95e95f48308a93c9ab5a77af Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 12 Jan 2024 13:41:33 +0100
Subject: [PATCH 1/2] Rename helper function

---
 tests/link/jax/test_random.py | 48 +++++++++++++++++------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py
index 811e8122de..10b10ff543 100644
--- a/tests/link/jax/test_random.py
+++ b/tests/link/jax/test_random.py
@@ -20,7 +20,7 @@
 from pytensor.link.jax.dispatch.random import numpyro_available  # noqa: E402
 
 
-def random_function(*args, **kwargs):
+def compile_random_function(*args, **kwargs):
     with pytest.warns(
         UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
     ):
@@ -35,7 +35,7 @@ def test_random_RandomStream():
     srng = RandomStream(seed=123)
     out = srng.normal() - srng.normal()
 
-    fn = random_function([], out, mode=jax_mode)
+    fn = compile_random_function([], out, mode=jax_mode)
     jax_res_1 = fn()
     jax_res_2 = fn()
 
@@ -48,7 +48,7 @@ def test_random_updates(rng_ctor):
     rng = shared(original_value, name="original_rng", borrow=False)
     next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs
 
-    f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
+    f = compile_random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
     assert f() != f()
 
     # Check that original rng variable content was not overwritten when calling jax_typify
@@ -79,7 +79,7 @@ def test_random_updates_input_storage_order():
 
     # This function replaces inp by input_shared in the update expression
     # This is what caused the RNG to appear later than inp_shared in the input_storage
-    fn = random_function(
+    fn = compile_random_function(
         inputs=[],
         outputs=[],
         updates={inp_shared: inp_update},
@@ -453,7 +453,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
     else:
         rng = shared(np.random.RandomState(29402))
     g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
-    g_fn = random_function(dist_params, g, mode=jax_mode)
+    g_fn = compile_random_function(dist_params, g, mode=jax_mode)
     samples = g_fn(
         *[
             i.tag.test_value
@@ -477,7 +477,7 @@ def test_random_bernoulli(size):
     rng = shared(np.random.RandomState(123))
     g = pt.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
 
@@ -488,7 +488,7 @@ def test_random_mvnormal():
     mu = np.ones(4)
     cov = np.eye(4)
     g = pt.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
 
@@ -503,7 +503,7 @@ def test_random_dirichlet(parameter, size):
     rng = shared(np.random.RandomState(123))
     g = pt.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
 
@@ -513,21 +513,21 @@ def test_random_choice():
     num_samples = 10000
     rng = shared(np.random.RandomState(123))
    g = pt.random.choice(np.arange(4), size=num_samples, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
 
     # `replace=False` produces unique results
     rng = shared(np.random.RandomState(123))
     g = pt.random.choice(np.arange(100), replace=False, size=99, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     assert len(np.unique(samples)) == 99
 
     # We can pass an array with probabilities
     rng = shared(np.random.RandomState(123))
     g = pt.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples, np.zeros(10))
 
@@ -535,7 +535,7 @@ def test_random_categorical():
     rng = shared(np.random.RandomState(123))
     g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
 
@@ -544,7 +544,7 @@ def test_random_permutation():
     array = np.arange(4)
     rng = shared(np.random.RandomState(123))
     g = pt.random.permutation(array, rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     permuted = g_fn()
     with pytest.raises(AssertionError):
         np.testing.assert_allclose(array, permuted)
@@ -554,7 +554,7 @@ def test_random_geometric():
     rng = shared(np.random.RandomState(123))
     p = np.array([0.3, 0.7])
     g = pt.random.geometric(p, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
     np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -565,7 +565,7 @@ def test_negative_binomial():
     n = np.array([10, 40])
     p = np.array([0.3, 0.7])
     g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
     np.testing.assert_allclose(
@@ -579,7 +579,7 @@ def test_binomial():
     n = np.array([10, 40])
     p = np.array([0.3, 0.7])
     g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
     np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -594,7 +594,7 @@ def test_beta_binomial():
     a = np.array([1.5, 13])
     b = np.array([0.5, 9])
     g = pt.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
     np.testing.assert_allclose(
@@ -612,7 +612,7 @@ def test_multinomial():
     n = np.array([10, 40])
     p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
     g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
     np.testing.assert_allclose(
@@ -628,7 +628,7 @@ def test_vonmises_mu_outside_circle():
     mu = np.array([-30, 40])
     kappa = np.array([100, 10])
     g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
-    g_fn = random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(
         samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
     )
@@ -728,7 +728,7 @@ def test_random_concrete_shape():
     rng = shared(np.random.RandomState(123))
     x_pt = pt.dmatrix()
     out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
-    jax_fn = random_function([x_pt], out, mode=jax_mode)
+    jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (2, 3)
 
@@ -736,7 +736,7 @@ def test_random_concrete_shape_from_param():
     rng = shared(np.random.RandomState(123))
     x_pt = pt.dmatrix()
     out = pt.random.normal(x_pt, 1, rng=rng)
-    jax_fn = random_function([x_pt], out, mode=jax_mode)
+    jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (2, 3)
 
@@ -755,7 +755,7 @@ def test_random_concrete_shape_subtensor():
     rng = shared(np.random.RandomState(123))
     x_pt = pt.dmatrix()
     out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
-    jax_fn = random_function([x_pt], out, mode=jax_mode)
+    jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (3,)
 
@@ -771,7 +771,7 @@ def test_random_concrete_shape_subtensor_tuple():
     rng = shared(np.random.RandomState(123))
     x_pt = pt.dmatrix()
     out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
-    jax_fn = random_function([x_pt], out, mode=jax_mode)
+    jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (2,)
 
@@ -782,5 +782,5 @@ def test_random_concrete_shape_graph_input():
     rng = shared(np.random.RandomState(123))
     size_pt = pt.scalar()
     out = pt.random.normal(0, 1, size=size_pt, rng=rng)
-    jax_fn = random_function([size_pt], out, mode=jax_mode)
+    jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
     assert jax_fn(10).shape == (10,)

From 4f79fd470ff0c4abe3ca3584552008c7a04e5340 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 12 Jan 2024 13:40:00 +0100
Subject: [PATCH 2/2] Fix bug in input_storage alignment of the JAX backend

When replacing the Shared RNG variables, the input order of the
FunctionGraph was not explicitly aligned with the input storage of the
function being compiled.
---
 pytensor/link/jax/linker.py   | 11 ++++++++--
 tests/link/jax/test_random.py | 38 ++++++++++++++++++++++++++++++++++-
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py
index 2d75e76d5c..7b76265197 100644
--- a/pytensor/link/jax/linker.py
+++ b/pytensor/link/jax/linker.py
@@ -23,7 +23,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         # Replace any shared RNG inputs so that their values can be updated in place
         # without affecting the original RNG container. This is necessary because
         # JAX does not accept RandomState/Generators as inputs, and they will have to
-        # be typyfied
+        # be typified
         if shared_rng_inputs:
             warnings.warn(
                 f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
@@ -52,9 +52,16 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
                 else:  # no break
                     raise ValueError()
                 input_storage[input_storage_idx] = new_inp_storage
+                # We need to change the order of the inputs of the FunctionGraph
+                # so that the new input is in the same position as the old one,
+                # to align with the storage_map. We hope this is safe!
+                old_inp_fgraph_index = fgraph.inputs.index(old_inp)
                 fgraph.remove_input(
-                    fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
+                    old_inp_fgraph_index,
+                    reason="JAXLinker.fgraph_convert",
                 )
+                fgraph.inputs.remove(new_inp)
+                fgraph.inputs.insert(old_inp_fgraph_index, new_inp)
 
         return jax_funcify(
             fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py
index 10b10ff543..4e813fd128 100644
--- a/tests/link/jax/test_random.py
+++ b/tests/link/jax/test_random.py
@@ -10,6 +10,7 @@
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor.random.basic import RandomVariable
+from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.random.utils import RandomStream
 from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
 
@@ -58,7 +59,42 @@ def test_random_updates(rng_ctor):
     )
 
 
-def test_random_updates_input_storage_order():
+@pytest.mark.parametrize("noise_first", (False, True))
+def test_replaced_shared_rng_storage_order(noise_first):
+    # Test that replacing the RNG variable in the linker does not cause
+    # a misalignment between the compiled graph and the storage_map.
+
+    mu = pytensor.shared(np.array(1.0), name="mu")
+    rng = pytensor.shared(np.random.default_rng(123))
+    next_rng, noise = pt.random.normal(rng=rng).owner.outputs
+
+    if noise_first:
+        out = noise * mu
+    else:
+        out = mu * noise
+
+    updates = {
+        mu: pt.grad(out, mu),
+        rng: next_rng,
+    }
+    f = compile_random_function([], [out], updates=updates, mode="JAX")
+
+    # The bug was found when noise used to be the first input of the fgraph.
+    # If this changes, the test may need to be tweaked to keep the same coverage.
+    assert isinstance(
+        f.input_storage[1 - noise_first].type, RandomType
+    ), "Test may need to be tweaked"
+
+    # Confirm that input_storage type and fgraph input order are aligned
+    for storage, fgraph_input in zip(f.input_storage, f.maker.fgraph.inputs):
+        assert storage.type == fgraph_input.type
+
+    assert mu.get_value() == 1
+    f()
+    assert mu.get_value() != 1
+
+
+def test_replaced_shared_rng_storage_ordering_equality():
     """Test case described in issue #314.
 
     This happened when we tried to update the input storage after we clone the shared RNG.
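
As an illustration of the invariant the second patch restores, here is a minimal, self-contained Python sketch. It uses plain lists with hypothetical names ("noise", "old_rng", "mu") as stand-ins for fgraph.inputs and input_storage; it is not pytensor code, only a model of the reordering logic in JAXLinker.fgraph_convert.

# `inputs` and `storage` model fgraph.inputs and input_storage: two
# parallel lists where storage[i] holds the runtime value for inputs[i].
inputs = ["noise", "old_rng", "mu"]
storage = ["noise_value", "old_rng_value", "mu_value"]

# The linker swaps the shared RNG for a clone and updates its storage
# slot in place, at the old position:
idx = inputs.index("old_rng")
storage[idx] = "new_rng_value"

# Buggy replacement: drop the old input and leave the clone appended at
# the end. The RNG input lands at index 2 while its storage sits at
# index 1, so inputs[i] and storage[i] no longer describe the same variable.
buggy = [x for x in inputs if x != "old_rng"] + ["new_rng"]
assert buggy == ["noise", "mu", "new_rng"]  # misaligned with storage

# Patched replacement: re-insert the clone at the old input's index,
# which is what the fgraph.inputs.remove/insert pair above achieves.
fixed = [x for x in inputs if x != "old_rng"]
fixed.insert(idx, "new_rng")
assert fixed == ["noise", "new_rng", "mu"]  # aligned with storage again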