Fix bug in storage_input alignment of the JAX backend #587

Merged: 2 commits, Jan 12, 2024

pytensor/link/jax/linker.py: 11 changes (9 additions & 2 deletions)
@@ -23,7 +23,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
# Replace any shared RNG inputs so that their values can be updated in place
# without affecting the original RNG container. This is necessary because
# JAX does not accept RandomState/Generators as inputs, and they will have to
- # be typyfied
+ # be typified
if shared_rng_inputs:
warnings.warn(
f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
@@ -52,9 +52,16 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
else: # no break
raise ValueError()
input_storage[input_storage_idx] = new_inp_storage
+ # We need to change the order of the inputs of the FunctionGraph
+ # so that the new input is in the same position as the old one,
+ # to align with the storage_map. We hope this is safe!
+ old_inp_fgrap_index = fgraph.inputs.index(old_inp)
fgraph.remove_input(
-     fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
+     old_inp_fgrap_index,
+     reason="JAXLinker.fgraph_convert",
)
+ fgraph.inputs.remove(new_inp)
+ fgraph.inputs.insert(old_inp_fgrap_index, new_inp)

return jax_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
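A note on the fix above: the linker relies on a positional invariant, namely that input_storage[i] holds the container for fgraph.inputs[i]. The toy model below is a minimal sketch using plain Python lists in place of the FunctionGraph inputs and the storage map (the helper names are illustrative, not PyTensor API); it shows how appending the replacement input to the end breaks that pairing, while remove-and-reinsert at the old index preserves it.

# Toy model of the alignment invariant (not PyTensor API):
# storage[i] must hold the value destined for inputs[i].

def replace_appending(inputs, storage, old, new, value):
    # Buggy variant: the replacement input drifts to the end of the list.
    idx = inputs.index(old)
    storage[idx] = value  # storage is updated at the old position...
    inputs.remove(old)
    inputs.append(new)  # ...but the new input lands last

def replace_in_place(inputs, storage, old, new, value):
    # Fixed variant: reinsert the replacement at the old index.
    idx = inputs.index(old)
    storage[idx] = value
    inputs.remove(old)
    inputs.insert(idx, new)  # inputs stay aligned with storage

inputs, storage = ["rng", "mu"], ["rng_val", "mu_val"]
replace_appending(inputs, storage, "rng", "rng_clone", "clone_val")
assert inputs == ["mu", "rng_clone"]
assert storage == ["clone_val", "mu_val"]  # misaligned: "mu" would now read "clone_val"

inputs, storage = ["rng", "mu"], ["rng_val", "mu_val"]
replace_in_place(inputs, storage, "rng", "rng_clone", "clone_val")
assert inputs == ["rng_clone", "mu"]
assert storage == ["clone_val", "mu_val"]  # aligned: same index, same pairing
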
tests/link/jax/test_random.py: 86 changes (61 additions & 25 deletions)
@@ -10,6 +10,7 @@
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.random.basic import RandomVariable
+ from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.utils import RandomStream
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value

@@ -20,7 +21,7 @@
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402


- def random_function(*args, **kwargs):
+ def compile_random_function(*args, **kwargs):
with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
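
The helper's body is collapsed in this view; presumably it just forwards to pytensor's function compiler inside the pytest.warns context, so every compilation in this module also asserts the RNG-replacement warning emitted by the JAX linker. A sketch of that assumed body (assuming function is pytensor.function, imported elsewhere in the module):

def compile_random_function(*args, **kwargs):
    with pytest.warns(
        UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
    ):
        # Assumed body: compile while the warning assertion is active
        return function(*args, **kwargs)
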
@@ -35,7 +36,7 @@ def test_random_RandomStream():
srng = RandomStream(seed=123)
out = srng.normal() - srng.normal()

- fn = random_function([], out, mode=jax_mode)
+ fn = compile_random_function([], out, mode=jax_mode)
jax_res_1 = fn()
jax_res_2 = fn()

@@ -48,7 +49,7 @@ def test_random_updates(rng_ctor):
rng = shared(original_value, name="original_rng", borrow=False)
next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs

- f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
+ f = compile_random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
assert f() != f()

# Check that original rng variable content was not overwritten when calling jax_typify
@@ -58,7 +59,42 @@
)


- def test_random_updates_input_storage_order():
+ @pytest.mark.parametrize("noise_first", (False, True))
+ def test_replaced_shared_rng_storage_order(noise_first):
+     # Test that replacing the RNG variable in the linker does not cause
+     # a misalignment between the compiled graph and the storage_map.
+ 
+     mu = pytensor.shared(np.array(1.0), name="mu")
+     rng = pytensor.shared(np.random.default_rng(123))
+     next_rng, noise = pt.random.normal(rng=rng).owner.outputs
+ 
+     if noise_first:
+         out = noise * mu
+     else:
+         out = mu * noise
+ 
+     updates = {
+         mu: pt.grad(out, mu),
+         rng: next_rng,
+     }
+     f = compile_random_function([], [out], updates=updates, mode="JAX")
+ 
+     # The bug was found when noise used to be the first input of the fgraph.
+     # If this changes, the test may need to be tweaked to keep the same coverage.
+     assert isinstance(
+         f.input_storage[1 - noise_first].type, RandomType
+     ), "Test may need to be tweaked"
+ 
+     # Confirm that input_storage type and fgraph input order are aligned
+     for storage, fgraph_input in zip(f.input_storage, f.maker.fgraph.inputs):
+         assert storage.type == fgraph_input.type
+ 
+     assert mu.get_value() == 1
+     f()
+     assert mu.get_value() != 1
+ 
+ 
+ def test_replaced_shared_rng_storage_ordering_equality():
"""Test case described in issue #314.

This happened when we tried to update the input storage after cloning the shared RNG.
@@ -79,7 +115,7 @@ def test_random_updates_input_storage_order():
# This function replaces inp by inp_shared in the update expression
# This is what caused the RNG to appear later than inp_shared in the input_storage

- fn = random_function(
+ fn = compile_random_function(
inputs=[],
outputs=[],
updates={inp_shared: inp_update},
@@ -453,7 +489,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
else:
rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
- g_fn = random_function(dist_params, g, mode=jax_mode)
+ g_fn = compile_random_function(dist_params, g, mode=jax_mode)
samples = g_fn(
*[
i.tag.test_value
@@ -477,7 +513,7 @@
def test_random_bernoulli(size):
rng = shared(np.random.RandomState(123))
g = pt.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)

@@ -488,7 +524,7 @@ def test_random_mvnormal():
mu = np.ones(4)
cov = np.eye(4)
g = pt.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)

@@ -503,7 +539,7 @@ def test_random_mvnormal():
def test_random_dirichlet(parameter, size):
rng = shared(np.random.RandomState(123))
g = pt.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)

@@ -513,29 +549,29 @@ def test_random_choice():
num_samples = 10000
rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(4), size=num_samples, rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)

# `replace=False` produces unique results
rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(100), replace=False, size=99, rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
assert len(np.unique(samples)) == 99

# We can pass an array with probabilities
rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples, np.zeros(10))


def test_random_categorical():
rng = shared(np.random.RandomState(123))
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)

@@ -544,7 +580,7 @@ def test_random_permutation():
array = np.arange(4)
rng = shared(np.random.RandomState(123))
g = pt.random.permutation(array, rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
permuted = g_fn()
with pytest.raises(AssertionError):
np.testing.assert_allclose(array, permuted)
@@ -554,7 +590,7 @@ def test_random_geometric():
rng = shared(np.random.RandomState(123))
p = np.array([0.3, 0.7])
g = pt.random.geometric(p, size=(10_000, 2), rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -565,7 +601,7 @@ def test_negative_binomial():
n = np.array([10, 40])
p = np.array([0.3, 0.7])
g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
np.testing.assert_allclose(
@@ -579,7 +615,7 @@ def test_binomial():
n = np.array([10, 40])
p = np.array([0.3, 0.7])
g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -594,7 +630,7 @@ def test_beta_binomial():
a = np.array([1.5, 13])
b = np.array([0.5, 9])
g = pt.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
np.testing.assert_allclose(
@@ -612,7 +648,7 @@ def test_multinomial():
n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
np.testing.assert_allclose(
@@ -628,7 +664,7 @@ def test_vonmises_mu_outside_circle():
mu = np.array([-30, 40])
kappa = np.array([100, 10])
g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
- g_fn = random_function([], g, mode=jax_mode)
+ g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
@@ -728,15 +764,15 @@ def test_random_concrete_shape():
rng = shared(np.random.RandomState(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
- jax_fn = random_function([x_pt], out, mode=jax_mode)
+ jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)


def test_random_concrete_shape_from_param():
rng = shared(np.random.RandomState(123))
x_pt = pt.dmatrix()
out = pt.random.normal(x_pt, 1, rng=rng)
- jax_fn = random_function([x_pt], out, mode=jax_mode)
+ jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)


@@ -755,7 +791,7 @@ def test_random_concrete_shape_subtensor():
rng = shared(np.random.RandomState(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
- jax_fn = random_function([x_pt], out, mode=jax_mode)
+ jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (3,)


@@ -771,7 +807,7 @@ def test_random_concrete_shape_subtensor_tuple():
rng = shared(np.random.RandomState(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
- jax_fn = random_function([x_pt], out, mode=jax_mode)
+ jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2,)


@@ -782,5 +818,5 @@ def test_random_concrete_shape_graph_input():
rng = shared(np.random.RandomState(123))
size_pt = pt.scalar()
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
- jax_fn = random_function([size_pt], out, mode=jax_mode)
+ jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
assert jax_fn(10).shape == (10,)
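
Stepping back, the scenario this PR fixes reduces to a graph with one shared parameter and one shared RNG, compiled in JAX mode with updates on both: before the linker fix, the cloned RNG input could sit at a different position in fgraph.inputs than its storage slot, so one update could read or write the other's container. A minimal end-to-end sketch, assuming PyTensor with a working JAX backend; it mirrors the new test above rather than adding coverage:

import numpy as np
import pytensor
import pytensor.tensor as pt

mu = pytensor.shared(np.array(1.0), name="mu")
rng = pytensor.shared(np.random.default_rng(123))
next_rng, noise = pt.random.normal(rng=rng).owner.outputs
out = noise * mu

# Updates touch both shared variables; an alignment bug shows up as the
# wrong storage cell being read or written for one of them.
# Compiling emits the RNG-replacement UserWarning from the JAX linker.
f = pytensor.function(
    [],
    [out],
    updates={mu: pt.grad(out, mu), rng: next_rng},
    mode="JAX",
)

f()
assert mu.get_value() != 1.0  # the gradient update reached mu, not the RNG slot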