20
20
from pytensor .link .jax .dispatch .random import numpyro_available # noqa: E402
21
21
22
22
23
- def random_function (* args , ** kwargs ):
23
+ def compile_random_function (* args , ** kwargs ):
24
24
with pytest .warns (
25
25
UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
26
26
):
@@ -35,7 +35,7 @@ def test_random_RandomStream():
35
35
srng = RandomStream (seed = 123 )
36
36
out = srng .normal () - srng .normal ()
37
37
38
- fn = random_function ([], out , mode = jax_mode )
38
+ fn = compile_random_function ([], out , mode = jax_mode )
39
39
jax_res_1 = fn ()
40
40
jax_res_2 = fn ()
41
41
@@ -48,7 +48,7 @@ def test_random_updates(rng_ctor):
48
48
rng = shared (original_value , name = "original_rng" , borrow = False )
49
49
next_rng , x = pt .random .normal (name = "x" , rng = rng ).owner .outputs
50
50
51
- f = random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
51
+ f = compile_random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
52
52
assert f () != f ()
53
53
54
54
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -79,7 +79,7 @@ def test_random_updates_input_storage_order():
79
79
# This function replaces inp by input_shared in the update expression
80
80
# This is what caused the RNG to appear later than inp_shared in the input_storage
81
81
82
- fn = random_function (
82
+ fn = compile_random_function (
83
83
inputs = [],
84
84
outputs = [],
85
85
updates = {inp_shared : inp_update },
@@ -453,7 +453,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
453
453
else :
454
454
rng = shared (np .random .RandomState (29402 ))
455
455
g = rv_op (* dist_params , size = (10_000 ,) + base_size , rng = rng )
456
- g_fn = random_function (dist_params , g , mode = jax_mode )
456
+ g_fn = compile_random_function (dist_params , g , mode = jax_mode )
457
457
samples = g_fn (
458
458
* [
459
459
i .tag .test_value
@@ -477,7 +477,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
477
477
def test_random_bernoulli (size ):
478
478
rng = shared (np .random .RandomState (123 ))
479
479
g = pt .random .bernoulli (0.5 , size = (1000 ,) + size , rng = rng )
480
- g_fn = random_function ([], g , mode = jax_mode )
480
+ g_fn = compile_random_function ([], g , mode = jax_mode )
481
481
samples = g_fn ()
482
482
np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
483
483
@@ -488,7 +488,7 @@ def test_random_mvnormal():
488
488
mu = np .ones (4 )
489
489
cov = np .eye (4 )
490
490
g = pt .random .multivariate_normal (mu , cov , size = (10000 ,), rng = rng )
491
- g_fn = random_function ([], g , mode = jax_mode )
491
+ g_fn = compile_random_function ([], g , mode = jax_mode )
492
492
samples = g_fn ()
493
493
np .testing .assert_allclose (samples .mean (axis = 0 ), mu , atol = 0.1 )
494
494
@@ -503,7 +503,7 @@ def test_random_mvnormal():
503
503
def test_random_dirichlet (parameter , size ):
504
504
rng = shared (np .random .RandomState (123 ))
505
505
g = pt .random .dirichlet (parameter , size = (1000 ,) + size , rng = rng )
506
- g_fn = random_function ([], g , mode = jax_mode )
506
+ g_fn = compile_random_function ([], g , mode = jax_mode )
507
507
samples = g_fn ()
508
508
np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
509
509
@@ -513,29 +513,29 @@ def test_random_choice():
513
513
num_samples = 10000
514
514
rng = shared (np .random .RandomState (123 ))
515
515
g = pt .random .choice (np .arange (4 ), size = num_samples , rng = rng )
516
- g_fn = random_function ([], g , mode = jax_mode )
516
+ g_fn = compile_random_function ([], g , mode = jax_mode )
517
517
samples = g_fn ()
518
518
np .testing .assert_allclose (np .sum (samples == 3 ) / num_samples , 0.25 , 2 )
519
519
520
520
# `replace=False` produces unique results
521
521
rng = shared (np .random .RandomState (123 ))
522
522
g = pt .random .choice (np .arange (100 ), replace = False , size = 99 , rng = rng )
523
- g_fn = random_function ([], g , mode = jax_mode )
523
+ g_fn = compile_random_function ([], g , mode = jax_mode )
524
524
samples = g_fn ()
525
525
assert len (np .unique (samples )) == 99
526
526
527
527
# We can pass an array with probabilities
528
528
rng = shared (np .random .RandomState (123 ))
529
529
g = pt .random .choice (np .arange (3 ), p = np .array ([1.0 , 0.0 , 0.0 ]), size = 10 , rng = rng )
530
- g_fn = random_function ([], g , mode = jax_mode )
530
+ g_fn = compile_random_function ([], g , mode = jax_mode )
531
531
samples = g_fn ()
532
532
np .testing .assert_allclose (samples , np .zeros (10 ))
533
533
534
534
535
535
def test_random_categorical ():
536
536
rng = shared (np .random .RandomState (123 ))
537
537
g = pt .random .categorical (0.25 * np .ones (4 ), size = (10000 , 4 ), rng = rng )
538
- g_fn = random_function ([], g , mode = jax_mode )
538
+ g_fn = compile_random_function ([], g , mode = jax_mode )
539
539
samples = g_fn ()
540
540
np .testing .assert_allclose (samples .mean (axis = 0 ), 6 / 4 , 1 )
541
541
@@ -544,7 +544,7 @@ def test_random_permutation():
544
544
array = np .arange (4 )
545
545
rng = shared (np .random .RandomState (123 ))
546
546
g = pt .random .permutation (array , rng = rng )
547
- g_fn = random_function ([], g , mode = jax_mode )
547
+ g_fn = compile_random_function ([], g , mode = jax_mode )
548
548
permuted = g_fn ()
549
549
with pytest .raises (AssertionError ):
550
550
np .testing .assert_allclose (array , permuted )
@@ -554,7 +554,7 @@ def test_random_geometric():
554
554
rng = shared (np .random .RandomState (123 ))
555
555
p = np .array ([0.3 , 0.7 ])
556
556
g = pt .random .geometric (p , size = (10_000 , 2 ), rng = rng )
557
- g_fn = random_function ([], g , mode = jax_mode )
557
+ g_fn = compile_random_function ([], g , mode = jax_mode )
558
558
samples = g_fn ()
559
559
np .testing .assert_allclose (samples .mean (axis = 0 ), 1 / p , rtol = 0.1 )
560
560
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt ((1 - p ) / p ** 2 ), rtol = 0.1 )
@@ -565,7 +565,7 @@ def test_negative_binomial():
565
565
n = np .array ([10 , 40 ])
566
566
p = np .array ([0.3 , 0.7 ])
567
567
g = pt .random .negative_binomial (n , p , size = (10_000 , 2 ), rng = rng )
568
- g_fn = random_function ([], g , mode = jax_mode )
568
+ g_fn = compile_random_function ([], g , mode = jax_mode )
569
569
samples = g_fn ()
570
570
np .testing .assert_allclose (samples .mean (axis = 0 ), n * (1 - p ) / p , rtol = 0.1 )
571
571
np .testing .assert_allclose (
@@ -579,7 +579,7 @@ def test_binomial():
579
579
n = np .array ([10 , 40 ])
580
580
p = np .array ([0.3 , 0.7 ])
581
581
g = pt .random .binomial (n , p , size = (10_000 , 2 ), rng = rng )
582
- g_fn = random_function ([], g , mode = jax_mode )
582
+ g_fn = compile_random_function ([], g , mode = jax_mode )
583
583
samples = g_fn ()
584
584
np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
585
585
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.1 )
@@ -594,7 +594,7 @@ def test_beta_binomial():
594
594
a = np .array ([1.5 , 13 ])
595
595
b = np .array ([0.5 , 9 ])
596
596
g = pt .random .betabinom (n , a , b , size = (10_000 , 2 ), rng = rng )
597
- g_fn = random_function ([], g , mode = jax_mode )
597
+ g_fn = compile_random_function ([], g , mode = jax_mode )
598
598
samples = g_fn ()
599
599
np .testing .assert_allclose (samples .mean (axis = 0 ), n * a / (a + b ), rtol = 0.1 )
600
600
np .testing .assert_allclose (
@@ -612,7 +612,7 @@ def test_multinomial():
612
612
n = np .array ([10 , 40 ])
613
613
p = np .array ([[0.3 , 0.7 , 0.0 ], [0.1 , 0.4 , 0.5 ]])
614
614
g = pt .random .multinomial (n , p , size = (10_000 , 2 ), rng = rng )
615
- g_fn = random_function ([], g , mode = jax_mode )
615
+ g_fn = compile_random_function ([], g , mode = jax_mode )
616
616
samples = g_fn ()
617
617
np .testing .assert_allclose (samples .mean (axis = 0 ), n [..., None ] * p , rtol = 0.1 )
618
618
np .testing .assert_allclose (
@@ -628,7 +628,7 @@ def test_vonmises_mu_outside_circle():
628
628
mu = np .array ([- 30 , 40 ])
629
629
kappa = np .array ([100 , 10 ])
630
630
g = pt .random .vonmises (mu , kappa , size = (10_000 , 2 ), rng = rng )
631
- g_fn = random_function ([], g , mode = jax_mode )
631
+ g_fn = compile_random_function ([], g , mode = jax_mode )
632
632
samples = g_fn ()
633
633
np .testing .assert_allclose (
634
634
samples .mean (axis = 0 ), (mu + np .pi ) % (2.0 * np .pi ) - np .pi , rtol = 0.1
@@ -728,15 +728,15 @@ def test_random_concrete_shape():
728
728
rng = shared (np .random .RandomState (123 ))
729
729
x_pt = pt .dmatrix ()
730
730
out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
731
- jax_fn = random_function ([x_pt ], out , mode = jax_mode )
731
+ jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
732
732
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
733
733
734
734
735
735
def test_random_concrete_shape_from_param ():
736
736
rng = shared (np .random .RandomState (123 ))
737
737
x_pt = pt .dmatrix ()
738
738
out = pt .random .normal (x_pt , 1 , rng = rng )
739
- jax_fn = random_function ([x_pt ], out , mode = jax_mode )
739
+ jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
740
740
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
741
741
742
742
@@ -755,7 +755,7 @@ def test_random_concrete_shape_subtensor():
755
755
rng = shared (np .random .RandomState (123 ))
756
756
x_pt = pt .dmatrix ()
757
757
out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
758
- jax_fn = random_function ([x_pt ], out , mode = jax_mode )
758
+ jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
759
759
assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
760
760
761
761
@@ -771,7 +771,7 @@ def test_random_concrete_shape_subtensor_tuple():
771
771
rng = shared (np .random .RandomState (123 ))
772
772
x_pt = pt .dmatrix ()
773
773
out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
774
- jax_fn = random_function ([x_pt ], out , mode = jax_mode )
774
+ jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
775
775
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
776
776
777
777
@@ -782,5 +782,5 @@ def test_random_concrete_shape_graph_input():
782
782
rng = shared (np .random .RandomState (123 ))
783
783
size_pt = pt .scalar ()
784
784
out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
785
- jax_fn = random_function ([size_pt ], out , mode = jax_mode )
785
+ jax_fn = compile_random_function ([size_pt ], out , mode = jax_mode )
786
786
assert jax_fn (10 ).shape == (10 ,)
0 commit comments