Skip to content

Commit 1e5c487

Browse files
committed
Allow running JAX functions with scalar inputs for RV shapes
1 parent 194b871 commit 1e5c487

File tree

2 files changed

+83
-14
lines changed

2 files changed

+83
-14
lines changed

pytensor/link/jax/linker.py

+36-7
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
from numpy.random import Generator, RandomState
44

55
from pytensor.compile.sharedvalue import SharedVariable, shared
6-
from pytensor.graph.basic import Constant
76
from pytensor.link.basic import JITLinker
87

98

109
class JAXLinker(JITLinker):
1110
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
1211

12+
def __init__(self, *args, **kwargs):
13+
self.scalar_shape_inputs: tuple[int] = () # type: ignore[annotation-unchecked]
14+
super().__init__(*args, **kwargs)
15+
1316
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
1417
from pytensor.link.jax.dispatch import jax_funcify
18+
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
1519
from pytensor.tensor.random.type import RandomType
1620

1721
shared_rng_inputs = [
@@ -65,19 +69,44 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
6569
fgraph.inputs.remove(new_inp)
6670
fgraph.inputs.insert(old_inp_fgrap_index, new_inp)
6771

72+
fgraph_inputs = fgraph.inputs
73+
clients = fgraph.clients
74+
# Detect scalar shape inputs that are used only in JAXShapeTuple nodes
75+
scalar_shape_inputs = [
76+
inp
77+
for node in fgraph.apply_nodes
78+
if isinstance(node.op, JAXShapeTuple)
79+
for inp in node.inputs
80+
if inp in fgraph_inputs
81+
and all(isinstance(node.op, JAXShapeTuple) for node, _ in clients[inp])
82+
]
83+
self.scalar_shape_inputs = tuple(
84+
fgraph_inputs.index(inp) for inp in scalar_shape_inputs
85+
)
86+
6887
return jax_funcify(
6988
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
7089
)
7190

7291
def jit_compile(self, fn):
7392
import jax
7493

75-
# I suppose we can consider `Constant`s to be "static" according to
76-
# JAX.
77-
static_argnums = [
78-
n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
79-
]
80-
return jax.jit(fn, static_argnums=static_argnums)
94+
jit_fn = jax.jit(fn, static_argnums=self.scalar_shape_inputs)
95+
96+
if not self.scalar_shape_inputs:
97+
return jit_fn
98+
99+
def convert_scalar_shape_inputs(
100+
*args, scalar_shape_inputs=self.scalar_shape_inputs
101+
):
102+
return jit_fn(
103+
*(
104+
int(arg) if i in scalar_shape_inputs else arg
105+
for i, arg in enumerate(args)
106+
)
107+
)
108+
109+
return convert_scalar_shape_inputs
81110

82111
def create_thunk_inputs(self, storage_map):
83112
from pytensor.link.jax.dispatch import jax_typify

tests/link/jax/test_random.py

+47-7
Original file line numberDiff line numberDiff line change
@@ -867,15 +867,55 @@ def test_random_concrete_shape_subtensor_tuple(self):
867867
jax_fn = compile_random_function([x_pt], out)
868868
assert jax_fn(np.ones((2, 3))).shape == (2,)
869869

870+
def test_random_scalar_shape_input(self):
871+
dim0 = pt.scalar("dim0", dtype=int)
872+
dim1 = pt.scalar("dim1", dtype=int)
873+
874+
out = pt.random.normal(0, 1, size=dim0)
875+
jax_fn = compile_random_function([dim0], out)
876+
assert jax_fn(np.array(2)).shape == (2,)
877+
assert jax_fn(np.array(3)).shape == (3,)
878+
879+
out = pt.random.normal(0, 1, size=[dim0, dim1])
880+
jax_fn = compile_random_function([dim0, dim1], out)
881+
assert jax_fn(np.array(2), np.array(3)).shape == (2, 3)
882+
assert jax_fn(np.array(4), np.array(5)).shape == (4, 5)
883+
870884
@pytest.mark.xfail(
871-
reason="`size_pt` should be specified as a static argument", strict=True
885+
raises=TypeError, reason="Cannot convert scalar input to integer"
872886
)
873-
def test_random_concrete_shape_graph_input(self):
874-
rng = shared(np.random.default_rng(123))
875-
size_pt = pt.scalar()
876-
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
877-
jax_fn = compile_random_function([size_pt], out)
878-
assert jax_fn(10).shape == (10,)
887+
def test_random_scalar_shape_input_not_supported(self):
888+
dim = pt.scalar("dim", dtype=int)
889+
out1 = pt.random.normal(0, 1, size=dim)
890+
# An operation that wouldn't work if we replaced 0d array by integer
891+
out2 = dim[...].set(1)
892+
jax_fn = compile_random_function([dim], [out1, out2])
893+
894+
res1, res2 = jax_fn(np.array(2))
895+
assert res1.shape == (2,)
896+
assert res2 == 1
897+
898+
@pytest.mark.xfail(
899+
raises=TypeError, reason="Cannot convert scalar input to integer"
900+
)
901+
def test_random_scalar_shape_input_not_supported2(self):
902+
dim = pt.scalar("dim", dtype=int)
903+
# This could theoretically be supported
904+
# but would require knowing that * 2 is a safe operation for a python integer
905+
out = pt.random.normal(0, 1, size=dim * 2)
906+
jax_fn = compile_random_function([dim], out)
907+
assert jax_fn(np.array(2)).shape == (4,)
908+
909+
@pytest.mark.xfail(
910+
raises=TypeError, reason="Cannot convert tensor input to shape tuple"
911+
)
912+
def test_random_vector_shape_graph_input(self):
913+
shape = pt.vector("shape", shape=(2,), dtype=int)
914+
out = pt.random.normal(0, 1, size=shape)
915+
916+
jax_fn = compile_random_function([shape], out)
917+
assert jax_fn(np.array([2, 3])).shape == (2, 3)
918+
assert jax_fn(np.array([4, 5])).shape == (4, 5)
879919

880920
def test_constant_shape_after_graph_rewriting(self):
881921
size = pt.vector("size", shape=(2,), dtype=int)

0 commit comments

Comments
 (0)