@@ -867,15 +867,55 @@ def test_random_concrete_shape_subtensor_tuple(self):
867
867
jax_fn = compile_random_function ([x_pt ], out )
868
868
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
869
869
870
+ def test_random_scalar_shape_input (self ):
871
+ dim0 = pt .scalar ("dim0" , dtype = int )
872
+ dim1 = pt .scalar ("dim1" , dtype = int )
873
+
874
+ out = pt .random .normal (0 , 1 , size = dim0 )
875
+ jax_fn = compile_random_function ([dim0 ], out )
876
+ assert jax_fn (np .array (2 )).shape == (2 ,)
877
+ assert jax_fn (np .array (3 )).shape == (3 ,)
878
+
879
+ out = pt .random .normal (0 , 1 , size = [dim0 , dim1 ])
880
+ jax_fn = compile_random_function ([dim0 , dim1 ], out )
881
+ assert jax_fn (np .array (2 ), np .array (3 )).shape == (2 , 3 )
882
+ assert jax_fn (np .array (4 ), np .array (5 )).shape == (4 , 5 )
883
+
870
884
@pytest .mark .xfail (
871
- reason = "`size_pt` should be specified as a static argument" , strict = True
885
+ raises = TypeError , reason = "Cannot convert scalar input to integer"
872
886
)
873
- def test_random_concrete_shape_graph_input (self ):
874
- rng = shared (np .random .default_rng (123 ))
875
- size_pt = pt .scalar ()
876
- out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
877
- jax_fn = compile_random_function ([size_pt ], out )
878
- assert jax_fn (10 ).shape == (10 ,)
887
+ def test_random_scalar_shape_input_not_supported (self ):
888
+ dim = pt .scalar ("dim" , dtype = int )
889
+ out1 = pt .random .normal (0 , 1 , size = dim )
890
+ # An operation that wouldn't work if we replaced 0d array by integer
891
+ out2 = dim [...].set (1 )
892
+ jax_fn = compile_random_function ([dim ], [out1 , out2 ])
893
+
894
+ res1 , res2 = jax_fn (np .array (2 ))
895
+ assert res1 .shape == (2 ,)
896
+ assert res2 == 1
897
+
898
+ @pytest .mark .xfail (
899
+ raises = TypeError , reason = "Cannot convert scalar input to integer"
900
+ )
901
+ def test_random_scalar_shape_input_not_supported2 (self ):
902
+ dim = pt .scalar ("dim" , dtype = int )
903
+ # This could theoretically be supported
904
+ # but would require knowing that * 2 is a safe operation for a python integer
905
+ out = pt .random .normal (0 , 1 , size = dim * 2 )
906
+ jax_fn = compile_random_function ([dim ], out )
907
+ assert jax_fn (np .array (2 )).shape == (4 ,)
908
+
909
+ @pytest .mark .xfail (
910
+ raises = TypeError , reason = "Cannot convert tensor input to shape tuple"
911
+ )
912
+ def test_random_vector_shape_graph_input (self ):
913
+ shape = pt .vector ("shape" , shape = (2 ,), dtype = int )
914
+ out = pt .random .normal (0 , 1 , size = shape )
915
+
916
+ jax_fn = compile_random_function ([shape ], out )
917
+ assert jax_fn (np .array ([2 , 3 ])).shape == (2 , 3 )
918
+ assert jax_fn (np .array ([4 , 5 ])).shape == (4 , 5 )
879
919
880
920
def test_constant_shape_after_graph_rewriting (self ):
881
921
size = pt .vector ("size" , shape = (2 ,), dtype = int )
0 commit comments