Skip to content

Commit 3170c7d

Browse files
committed
Rename helper function
1 parent c5b96d9 commit 3170c7d

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

tests/link/jax/test_random.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
2121

2222

23-
def random_function(*args, **kwargs):
23+
def compile_random_function(*args, **kwargs):
2424
with pytest.warns(
2525
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
2626
):
@@ -35,7 +35,7 @@ def test_random_RandomStream():
3535
srng = RandomStream(seed=123)
3636
out = srng.normal() - srng.normal()
3737

38-
fn = random_function([], out, mode=jax_mode)
38+
fn = compile_random_function([], out, mode=jax_mode)
3939
jax_res_1 = fn()
4040
jax_res_2 = fn()
4141

@@ -48,7 +48,7 @@ def test_random_updates(rng_ctor):
4848
rng = shared(original_value, name="original_rng", borrow=False)
4949
next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs
5050

51-
f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
51+
f = compile_random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
5252
assert f() != f()
5353

5454
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -79,7 +79,7 @@ def test_random_updates_input_storage_order():
7979
# This function replaces inp by input_shared in the update expression
8080
# This is what caused the RNG to appear later than inp_shared in the input_storage
8181

82-
fn = random_function(
82+
fn = compile_random_function(
8383
inputs=[],
8484
outputs=[],
8585
updates={inp_shared: inp_update},
@@ -453,7 +453,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
453453
else:
454454
rng = shared(np.random.RandomState(29402))
455455
g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
456-
g_fn = random_function(dist_params, g, mode=jax_mode)
456+
g_fn = compile_random_function(dist_params, g, mode=jax_mode)
457457
samples = g_fn(
458458
*[
459459
i.tag.test_value
@@ -477,7 +477,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
477477
def test_random_bernoulli(size):
478478
rng = shared(np.random.RandomState(123))
479479
g = pt.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
480-
g_fn = random_function([], g, mode=jax_mode)
480+
g_fn = compile_random_function([], g, mode=jax_mode)
481481
samples = g_fn()
482482
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
483483

@@ -488,7 +488,7 @@ def test_random_mvnormal():
488488
mu = np.ones(4)
489489
cov = np.eye(4)
490490
g = pt.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
491-
g_fn = random_function([], g, mode=jax_mode)
491+
g_fn = compile_random_function([], g, mode=jax_mode)
492492
samples = g_fn()
493493
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
494494

@@ -503,7 +503,7 @@ def test_random_mvnormal():
503503
def test_random_dirichlet(parameter, size):
504504
rng = shared(np.random.RandomState(123))
505505
g = pt.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
506-
g_fn = random_function([], g, mode=jax_mode)
506+
g_fn = compile_random_function([], g, mode=jax_mode)
507507
samples = g_fn()
508508
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
509509

@@ -513,29 +513,29 @@ def test_random_choice():
513513
num_samples = 10000
514514
rng = shared(np.random.RandomState(123))
515515
g = pt.random.choice(np.arange(4), size=num_samples, rng=rng)
516-
g_fn = random_function([], g, mode=jax_mode)
516+
g_fn = compile_random_function([], g, mode=jax_mode)
517517
samples = g_fn()
518518
np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
519519

520520
# `replace=False` produces unique results
521521
rng = shared(np.random.RandomState(123))
522522
g = pt.random.choice(np.arange(100), replace=False, size=99, rng=rng)
523-
g_fn = random_function([], g, mode=jax_mode)
523+
g_fn = compile_random_function([], g, mode=jax_mode)
524524
samples = g_fn()
525525
assert len(np.unique(samples)) == 99
526526

527527
# We can pass an array with probabilities
528528
rng = shared(np.random.RandomState(123))
529529
g = pt.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
530-
g_fn = random_function([], g, mode=jax_mode)
530+
g_fn = compile_random_function([], g, mode=jax_mode)
531531
samples = g_fn()
532532
np.testing.assert_allclose(samples, np.zeros(10))
533533

534534

535535
def test_random_categorical():
536536
rng = shared(np.random.RandomState(123))
537537
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
538-
g_fn = random_function([], g, mode=jax_mode)
538+
g_fn = compile_random_function([], g, mode=jax_mode)
539539
samples = g_fn()
540540
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
541541

@@ -544,7 +544,7 @@ def test_random_permutation():
544544
array = np.arange(4)
545545
rng = shared(np.random.RandomState(123))
546546
g = pt.random.permutation(array, rng=rng)
547-
g_fn = random_function([], g, mode=jax_mode)
547+
g_fn = compile_random_function([], g, mode=jax_mode)
548548
permuted = g_fn()
549549
with pytest.raises(AssertionError):
550550
np.testing.assert_allclose(array, permuted)
@@ -554,7 +554,7 @@ def test_random_geometric():
554554
rng = shared(np.random.RandomState(123))
555555
p = np.array([0.3, 0.7])
556556
g = pt.random.geometric(p, size=(10_000, 2), rng=rng)
557-
g_fn = random_function([], g, mode=jax_mode)
557+
g_fn = compile_random_function([], g, mode=jax_mode)
558558
samples = g_fn()
559559
np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
560560
np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -565,7 +565,7 @@ def test_negative_binomial():
565565
n = np.array([10, 40])
566566
p = np.array([0.3, 0.7])
567567
g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
568-
g_fn = random_function([], g, mode=jax_mode)
568+
g_fn = compile_random_function([], g, mode=jax_mode)
569569
samples = g_fn()
570570
np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
571571
np.testing.assert_allclose(
@@ -579,7 +579,7 @@ def test_binomial():
579579
n = np.array([10, 40])
580580
p = np.array([0.3, 0.7])
581581
g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng)
582-
g_fn = random_function([], g, mode=jax_mode)
582+
g_fn = compile_random_function([], g, mode=jax_mode)
583583
samples = g_fn()
584584
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
585585
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -594,7 +594,7 @@ def test_beta_binomial():
594594
a = np.array([1.5, 13])
595595
b = np.array([0.5, 9])
596596
g = pt.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
597-
g_fn = random_function([], g, mode=jax_mode)
597+
g_fn = compile_random_function([], g, mode=jax_mode)
598598
samples = g_fn()
599599
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
600600
np.testing.assert_allclose(
@@ -612,7 +612,7 @@ def test_multinomial():
612612
n = np.array([10, 40])
613613
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
614614
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
615-
g_fn = random_function([], g, mode=jax_mode)
615+
g_fn = compile_random_function([], g, mode=jax_mode)
616616
samples = g_fn()
617617
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
618618
np.testing.assert_allclose(
@@ -628,7 +628,7 @@ def test_vonmises_mu_outside_circle():
628628
mu = np.array([-30, 40])
629629
kappa = np.array([100, 10])
630630
g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
631-
g_fn = random_function([], g, mode=jax_mode)
631+
g_fn = compile_random_function([], g, mode=jax_mode)
632632
samples = g_fn()
633633
np.testing.assert_allclose(
634634
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
@@ -728,15 +728,15 @@ def test_random_concrete_shape():
728728
rng = shared(np.random.RandomState(123))
729729
x_pt = pt.dmatrix()
730730
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
731-
jax_fn = random_function([x_pt], out, mode=jax_mode)
731+
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
732732
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
733733

734734

735735
def test_random_concrete_shape_from_param():
736736
rng = shared(np.random.RandomState(123))
737737
x_pt = pt.dmatrix()
738738
out = pt.random.normal(x_pt, 1, rng=rng)
739-
jax_fn = random_function([x_pt], out, mode=jax_mode)
739+
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
740740
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
741741

742742

@@ -755,7 +755,7 @@ def test_random_concrete_shape_subtensor():
755755
rng = shared(np.random.RandomState(123))
756756
x_pt = pt.dmatrix()
757757
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
758-
jax_fn = random_function([x_pt], out, mode=jax_mode)
758+
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
759759
assert jax_fn(np.ones((2, 3))).shape == (3,)
760760

761761

@@ -771,7 +771,7 @@ def test_random_concrete_shape_subtensor_tuple():
771771
rng = shared(np.random.RandomState(123))
772772
x_pt = pt.dmatrix()
773773
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
774-
jax_fn = random_function([x_pt], out, mode=jax_mode)
774+
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
775775
assert jax_fn(np.ones((2, 3))).shape == (2,)
776776

777777

@@ -782,5 +782,5 @@ def test_random_concrete_shape_graph_input():
782782
rng = shared(np.random.RandomState(123))
783783
size_pt = pt.scalar()
784784
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
785-
jax_fn = random_function([size_pt], out, mode=jax_mode)
785+
jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
786786
assert jax_fn(10).shape == (10,)

0 commit comments

Comments
 (0)