diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py
index 849b74ed73..c3be28e117 100644
--- a/pytensor/link/basic.py
+++ b/pytensor/link/basic.py
@@ -609,6 +609,9 @@ def create_thunk_inputs(self, storage_map: Dict[Variable, List[Any]]) -> List[An
     def jit_compile(self, fn: Callable) -> Callable:
         """JIT compile a converted ``FunctionGraph``."""
 
+    def typify(self, var: Variable):
+        return var
+
     def output_filter(self, var: Variable, out: Any) -> Any:
         """Apply a filter to the data output by a JITed function call."""
         return out
@@ -735,7 +738,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
         return (
             fn,
             [
-                Container(input, storage)
+                Container(self.typify(input), storage)
                 for input, storage in zip(fgraph.inputs, input_storage)
             ],
             [
diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py
index b35759f837..8c90886172 100644
--- a/pytensor/link/jax/dispatch/basic.py
+++ b/pytensor/link/jax/dispatch/basic.py
@@ -87,11 +87,11 @@ def assert_fn(x, *inputs):
 def jnp_safe_copy(x):
     try:
         res = jnp.copy(x)
-    except NotImplementedError:
-        warnings.warn(
-            "`jnp.copy` is not implemented yet. Using the object's `copy` method."
-        )
+    except (NotImplementedError, TypeError):
         if hasattr(x, "copy"):
+            warnings.warn(
+                "`jnp.copy` is not implemented yet. Using the object's `copy` method."
+            )
             res = jnp.array(x.copy())
         else:
             warnings.warn(f"Object has no `copy` method: {x}")
diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py
index 0981234db0..85c2d3d66a 100644
--- a/pytensor/link/jax/dispatch/random.py
+++ b/pytensor/link/jax/dispatch/random.py
@@ -10,6 +10,7 @@
 import pytensor.tensor.random.basic as aer
 from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
 from pytensor.link.jax.dispatch.shape import JAXShapeTuple
+from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.shape import Shape, Shape_i
 
 
@@ -57,8 +58,7 @@ def jax_typify_RandomState(state, **kwargs):
     state = state.get_state(legacy=False)
     state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
     # XXX: Is this a reasonable approach?
-    state["jax_state"] = state["state"]["key"][0:2]
-    return state
+    return state["state"]["key"][0:2]
 
 
 @jax_typify.register(Generator)
@@ -83,7 +83,36 @@ def jax_typify_Generator(rng, **kwargs):
     state_32 = _coerce_to_uint32_array(state["state"]["state"])
     state["state"]["inc"] = inc_32[0] << 32 | inc_32[1]
     state["state"]["state"] = state_32[0] << 32 | state_32[1]
-    return state
+    return state["jax_state"]
+
+
+class RandomPRNGKeyType(RandomType[jax.random.PRNGKey]):
+    """JAX-compatible PRNGKey type.
+
+    This type is not exposed to users directly.
+
+    It is introduced by the JIT linker in place of any RandomType input
+    variables used in the original function. Nodes in the function graph will
+    still show the original types as inputs and outputs.
+    """
+
+    def filter(self, data, strict: bool = False, allow_downcast=None):
+        # PRNGs are just JAX Arrays, we assume this is a valid one!
+        if isinstance(data, jax.Array):
+            return data
+
+        if strict:
+            raise TypeError()
+
+        return jax_typify(data)
+
+
+random_prng_key_type = RandomPRNGKeyType()
+
+
+@jax_typify.register(RandomType)
+def jax_typify_RandomType(type):
+    return random_prng_key_type()
 
 
 @jax_funcify.register(aer.RandomVariable)
@@ -130,12 +159,10 @@ def jax_sample_fn_generic(op):
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, *parameters):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
 
@@ -157,13 +184,11 @@ def jax_sample_fn_loc_scale(op):
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, *parameters):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         loc, scale = parameters
         sample = loc + jax_op(sampling_key, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
 
@@ -175,12 +200,10 @@ def jax_sample_fn_no_dtype(op):
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, *parameters):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         sample = jax_op(sampling_key, *parameters, shape=size)
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
 
@@ -201,15 +224,13 @@ def jax_sample_fn_uniform(op):
         name = "randint"
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, *parameters):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         minval, maxval = parameters
         sample = jax_op(
             sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
         )
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
 
@@ -226,13 +247,11 @@ def jax_sample_fn_shape_rate(op):
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, *parameters):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         (shape, rate) = parameters
         sample = jax_op(sampling_key, shape, size, dtype) / rate
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
 
@@ -241,13 +260,11 @@ def sample_fn(rng, size, dtype, *parameters):
 def jax_sample_fn_exponential(op):
     """JAX implementation of `ExponentialRV`."""
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, *parameters):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         (scale,) = parameters
         sample = jax.random.exponential(sampling_key, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
 
@@ -256,8 +273,7 @@ def sample_fn(rng, size, dtype, *parameters):
 def jax_sample_fn_t(op):
     """JAX implementation of `StudentTRV`."""
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, *parameters):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         (
             df,
@@ -265,8 +281,7 @@ def sample_fn(rng, size, dtype, *parameters):
             scale,
         ) = parameters
         sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
 
@@ -275,13 +290,11 @@ def sample_fn(rng, size, dtype, *parameters):
 def jax_funcify_choice(op):
     """JAX implementation of `ChoiceRV`."""
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, *parameters):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         (a, p, replace) = parameters
         smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
-        rng["jax_state"] = rng_key
-        return (rng, smpl_value)
+        return (rng_key, smpl_value)
 
     return sample_fn
 
@@ -290,13 +303,11 @@ def sample_fn(rng, size, dtype, *parameters):
 def jax_sample_fn_permutation(op):
     """JAX implementation of `PermutationRV`."""
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, *parameters):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         (x,) = parameters
         sample = jax.random.permutation(sampling_key, x)
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
 
@@ -311,15 +322,12 @@ def jax_sample_fn_binomial(op):
 
     from numpyro.distributions.util import binomial
 
-    def sample_fn(rng, size, dtype, n, p):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, n, p):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
 
         sample = binomial(key=sampling_key, n=n, p=p, shape=size)
 
-        rng["jax_state"] = rng_key
-
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
 
@@ -334,15 +342,12 @@ def jax_sample_fn_multinomial(op):
 
     from numpyro.distributions.util import multinomial
 
-    def sample_fn(rng, size, dtype, n, p):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, n, p):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
 
         sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
 
-        rng["jax_state"] = rng_key
-
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
 
@@ -357,8 +362,7 @@ def jax_sample_fn_vonmises(op):
 
     from numpyro.distributions.util import von_mises_centered
 
-    def sample_fn(rng, size, dtype, mu, kappa):
-        rng_key = rng["jax_state"]
+    def sample_fn(rng_key, size, dtype, mu, kappa):
         rng_key, sampling_key = jax.random.split(rng_key, 2)
 
         sample = von_mises_centered(
@@ -366,8 +370,6 @@ def sample_fn(rng, size, dtype, mu, kappa):
         )
         sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
 
-        rng["jax_state"] = rng_key
-
-        return (rng, sample)
+        return (rng_key, sample)
 
     return sample_fn
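For reference, every rewritten sample_fn above now follows the same calling convention: take a PRNG key, split it, draw with the split-off key, and return the advanced key next to the draw. A minimal standalone sketch of that convention (not part of the patch; normal_sample_fn is an illustrative stand-in for the closures the dispatchers generate, using only public JAX APIs):

    import jax

    def normal_sample_fn(rng_key, size, dtype, loc, scale):
        # Split the incoming key so the draw and the returned state are distinct.
        rng_key, sampling_key = jax.random.split(rng_key, 2)
        sample = loc + jax.random.normal(sampling_key, size, dtype) * scale
        # Return the advanced key alongside the draw, matching the
        # (next_rng, sample) output pair of PyTensor RandomVariable nodes.
        return (rng_key, sample)

    key = jax.random.PRNGKey(0)
    key, draw = normal_sample_fn(key, (3,), "float32", 0.0, 1.0)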
diff --git a/pytensor/link/jax/dispatch/sparse.py b/pytensor/link/jax/dispatch/sparse.py
index 4e58199ae0..81b38413ea 100644
--- a/pytensor/link/jax/dispatch/sparse.py
+++ b/pytensor/link/jax/dispatch/sparse.py
@@ -1,38 +1,66 @@
 import jax.experimental.sparse as jsp
 from scipy.sparse import spmatrix
 
-from pytensor.graph.basic import Constant
+from pytensor.graph.type import HasDataType
 from pytensor.link.jax.dispatch import jax_funcify, jax_typify
-from pytensor.sparse.basic import Dot, StructuredDot
+from pytensor.sparse.basic import Dot, StructuredDot, Transpose
 from pytensor.sparse.type import SparseTensorType
+from pytensor.tensor import TensorType
 
 
 @jax_typify.register(spmatrix)
 def jax_typify_spmatrix(matrix, dtype=None, **kwargs):
-    # Note: This changes the type of the constants from CSR/CSC to BCOO
-    # We could add BCOO as a PyTensor type but this would only be useful for JAX graphs
-    # and it would break the premise of one graph -> multiple backends.
-    # The same situation happens with RandomGenerators...
     return jsp.BCOO.from_scipy_sparse(matrix)
 
 
+class BCOOType(TensorType, HasDataType):
+    """JAX-compatible BCOO type.
+
+    This type is not exposed to users directly.
+
+    It is introduced by the JIT linker in place of any SparseTensorType input
+    variables used in the original function. Nodes in the function graph will
+    still show the original types as inputs and outputs.
+    """
+
+    def filter(self, data, strict: bool = False, allow_downcast=None):
+        if isinstance(data, jsp.BCOO):
+            return data
+
+        if strict:
+            raise TypeError()
+
+        return jax_typify(data)
+
+
+@jax_typify.register(SparseTensorType)
+def jax_typify_SparseTensorType(type):
+    return BCOOType(
+        dtype=type.dtype,
+        shape=type.shape,
+        name=type.name,
+        broadcastable=type.broadcastable,
+    )
+
+
 @jax_funcify.register(Dot)
 @jax_funcify.register(StructuredDot)
 def jax_funcify_sparse_dot(op, node, **kwargs):
-    for input in node.inputs:
-        if isinstance(input.type, SparseTensorType) and not isinstance(input, Constant):
-            raise NotImplementedError(
-                "JAX sparse dot only implemented for constant sparse inputs"
-            )
-
-    if isinstance(node.outputs[0].type, SparseTensorType):
-        raise NotImplementedError("JAX sparse dot only implemented for dense outputs")
-
     @jsp.sparsify
     def sparse_dot(x, y):
         out = x @ y
-        if isinstance(out, jsp.BCOO):
+        if isinstance(out, jsp.BCOO) and not isinstance(
+            node.outputs[0].type, SparseTensorType
+        ):
             out = out.todense()
         return out
 
     return sparse_dot
+
+
+@jax_funcify.register(Transpose)
+def jax_funcify_sparse_transpose(op, **kwargs):
+    def sparse_transpose(x):
+        return x.T
+
+    return sparse_transpose
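For context, the conversion that jax_typify_spmatrix and BCOOType.filter rely on is the standard SciPy-to-BCOO round trip. A minimal sketch (not part of the patch; only public scipy/JAX APIs, variable names are illustrative):

    import jax.experimental.sparse as jsp
    import numpy as np
    import scipy.sparse

    x_sp = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
    x_bcoo = jsp.BCOO.from_scipy_sparse(x_sp)  # CSR/CSC input becomes a BCOO array
    np.testing.assert_allclose(x_bcoo.todense(), x_sp.toarray())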
diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py
index 2d75e76d5c..cd8931c481 100644
--- a/pytensor/link/jax/linker.py
+++ b/pytensor/link/jax/linker.py
@@ -3,7 +3,7 @@
 from numpy.random import Generator, RandomState
 
 from pytensor.compile.sharedvalue import SharedVariable, shared
-from pytensor.graph.basic import Constant
+from pytensor.graph.basic import Constant, Variable
 from pytensor.link.basic import JITLinker
 
 
@@ -12,8 +12,28 @@ class JAXLinker(JITLinker):
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         from pytensor.link.jax.dispatch import jax_funcify
+        from pytensor.sparse.type import SparseTensorType
         from pytensor.tensor.random.type import RandomType
 
+        if any(
+            isinstance(inp.type, RandomType) and not isinstance(inp, SharedVariable)
+            for inp in fgraph.inputs
+        ):
+            warnings.warn(
+                "RandomTypes are implicitly converted to random PRNGKey arrays in JAX. "
+                "Input values should be provided in this format to avoid a conversion overhead."
+            )
+
+        if any(
+            isinstance(inp.type, SparseTensorType)
+            and not isinstance(inp, SharedVariable)
+            for inp in fgraph.inputs
+        ):
+            warnings.warn(
+                "SparseTypes are implicitly converted to sparse BCOO arrays in JAX. "
+                "Input values should be provided in this format to avoid a conversion overhead."
+            )
+
         shared_rng_inputs = [
             inp
             for inp in fgraph.inputs
@@ -70,6 +90,11 @@ def jit_compile(self, fn):
         ]
         return jax.jit(fn, static_argnums=static_argnums)
 
+    def typify(self, var: Variable):
+        from pytensor.link.jax.dispatch import jax_typify
+
+        return jax_typify(var.type)
+
     def create_thunk_inputs(self, storage_map):
         from pytensor.link.jax.dispatch import jax_typify
 
diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py
index 54e4e09307..64d6dc7a1f 100644
--- a/tests/link/jax/test_random.py
+++ b/tests/link/jax/test_random.py
@@ -5,11 +5,11 @@
 import pytensor
 import pytensor.tensor as at
 import pytensor.tensor.random as aer
-from pytensor.compile.function import function
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor.random.basic import RandomVariable
+from pytensor.tensor.random.type import random_generator_type, random_state_type
 from pytensor.tensor.random.utils import RandomStream
 from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
 
@@ -24,7 +24,40 @@ def random_function(*args, **kwargs):
     with pytest.warns(
         UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
     ):
-        return function(*args, **kwargs)
+        return pytensor.function(*args, **kwargs)
+
+
+@pytest.mark.parametrize("random_type", ("generator", "state"))
+def test_rng_io(random_type):
+    """Test explicit (non-shared) input and output RNG types in JAX."""
+    if random_type == "generator":
+        rng = random_generator_type("rng")
+        np_rng = np.random.default_rng(0)
+    else:
+        rng = random_state_type("rng")
+        np_rng = np.random.RandomState(0)
+    jx_rng = jax.random.PRNGKey(0)
+
+    next_rng, x = aer.normal(rng=rng).owner.outputs
+
+    with pytest.warns(
+        UserWarning,
+        match="RandomTypes are implicitly converted to random PRNGKey arrays",
+    ):
+        fn = pytensor.function([rng], [next_rng, x], mode="JAX")
+
+    # Inputs - RNG outputs
+    assert isinstance(fn(np_rng)[0], jax.Array)
+    assert isinstance(fn(jx_rng)[0], jax.Array)
+
+    # Inputs - Value outputs
+    assert fn(np_rng)[1] == fn(np_rng)[1]
+    assert fn(jx_rng)[1] == fn(jx_rng)[1]
+    assert fn(np_rng)[1] != fn(jx_rng)[1]
+
+    # Chained Inputs - RNG / Value outputs
+    assert fn(fn(np_rng)[0])[1] != fn(np_rng)[1]
+    assert fn(fn(jx_rng)[0])[1] != fn(jx_rng)[1]
 
 
 def test_random_RandomStream():
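The user-facing pattern exercised by test_rng_io above looks roughly like the sketch below (expected usage under this patch, not a verbatim excerpt). With mode="JAX", an explicit RandomType input is replaced by a PRNGKey array, so either a NumPy Generator or a jax.random.PRNGKey can be passed; compiling emits the conversion warning added in linker.py.

    import jax
    import numpy as np
    import pytensor
    import pytensor.tensor.random as aer
    from pytensor.tensor.random.type import random_generator_type

    rng = random_generator_type("rng")
    next_rng, x = aer.normal(rng=rng).owner.outputs
    # Emits the "implicitly converted to random PRNGKey arrays" warning.
    fn = pytensor.function([rng], [next_rng, x], mode="JAX")

    new_key, draw = fn(np.random.default_rng(0))  # NumPy RNG converted on entry
    new_key, draw = fn(jax.random.PRNGKey(0))     # PRNGKey passed through as-is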
diff --git a/tests/link/jax/test_sparse.py b/tests/link/jax/test_sparse.py
index 0c377bdcd8..06ef2c17b6 100644
--- a/tests/link/jax/test_sparse.py
+++ b/tests/link/jax/test_sparse.py
@@ -2,13 +2,51 @@
 import pytest
 import scipy.sparse
 
+
+jax = pytest.importorskip("jax")
+from jax.experimental.sparse import BCOO
+
 import pytensor.sparse as ps
 import pytensor.tensor as pt
 from pytensor import function
-from pytensor.graph import FunctionGraph
+from pytensor.graph import Constant, FunctionGraph
+from pytensor.tensor.type import DenseTensorType
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
+def assert_bcoo_arrays_allclose(a1, a2):
+    assert isinstance(a1, BCOO)
+    assert isinstance(a2, BCOO)
+    np.testing.assert_allclose(a1.todense(), a2.todense())
+
+
+@pytest.mark.parametrize("sparse_type", ("csc", "csr"))
+def test_sparse_io(sparse_type):
+    """Test explicit (non-shared) input and output sparse types in JAX."""
+    sparse_mat = ps.matrix(format=sparse_type, name="csc", dtype="float32")
+    sparse_mat_out = sparse_mat.T
+
+    with pytest.warns(
+        UserWarning,
+        match="SparseTypes are implicitly converted to sparse BCOO arrays",
+    ):
+        fn = function([sparse_mat], sparse_mat_out, mode="JAX")
+
+    sp_sparse_mat = scipy.sparse.random(
+        5, 40, density=0.25, format=sparse_type, dtype="float32"
+    )
+    jx_sparse_mat = BCOO.from_scipy_sparse(sp_sparse_mat)
+
+    sp_res = fn(sp_sparse_mat)
+    jx_res = fn(jx_sparse_mat)
+    assert_bcoo_arrays_allclose(sp_res, jx_sparse_mat.T)
+    assert_bcoo_arrays_allclose(jx_res, jx_sparse_mat.T)
+
+    # Chained applications
+    assert_bcoo_arrays_allclose(fn(fn(sp_sparse_mat)), jx_sparse_mat)
+    assert_bcoo_arrays_allclose(fn(fn(jx_sparse_mat)), jx_sparse_mat)
+
+
 @pytest.mark.parametrize(
     "op, x_type, y_type",
     [
@@ -19,57 +57,62 @@
         # structured_dot only allows matrix @ matrix
         (ps.structured_dot, pt.matrix, ps.matrix),
         (ps.structured_dot, ps.matrix, pt.matrix),
+        (ps.structured_dot, ps.matrix, ps.matrix),
     ],
 )
-def test_sparse_dot_constant_sparse(x_type, y_type, op):
+@pytest.mark.parametrize("x_constant", (False, True))
+@pytest.mark.parametrize("y_constant", (False, True))
+def test_sparse_dot(x_type, y_type, op, x_constant, y_constant):
     inputs = []
     test_values = []
 
     if x_type is ps.matrix:
-        x_sp = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
-        x_pt = ps.as_sparse_variable(x_sp, name="x")
+        x_test = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
+        x_pt = ps.as_sparse_variable(x_test, name="x")
     else:
-        x_pt = x_type("x", dtype="float32")
-        if x_pt.ndim == 1:
+        if x_type is pt.vector:
             x_test = np.arange(40, dtype="float32")
         else:
             x_test = np.arange(5 * 40, dtype="float32").reshape(5, 40)
+        x_pt = pt.as_tensor_variable(x_test, name="x")
+        assert isinstance(x_pt, Constant)
+
+    if not x_constant:
+        x_pt = x_pt.type(name="x")
+
     inputs.append(x_pt)
     test_values.append(x_test)
 
     if y_type is ps.matrix:
-        y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
-        y_pt = ps.as_sparse_variable(y_sp, name="y")
+        y_test = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
+        y_pt = ps.as_sparse_variable(y_test, name="y")
     else:
-        y_pt = y_type("y", dtype="float32")
-        if y_pt.ndim == 1:
+        if y_type is pt.vector:
             y_test = np.arange(40, dtype="float32")
         else:
             y_test = np.arange(40 * 3, dtype="float32").reshape(40, 3)
+        y_pt = pt.as_tensor_variable(y_test, name="y")
+        assert isinstance(y_pt, Constant)
+
+    if not y_constant:
+        y_pt = y_pt.type(name="y")
+
     inputs.append(y_pt)
     test_values.append(y_test)
 
     dot_pt = op(x_pt, y_pt)
     fgraph = FunctionGraph(inputs, [dot_pt])
-    compare_jax_and_py(fgraph, test_values)
-
-
-def test_sparse_dot_non_const_raises():
-    x_pt = pt.vector("x")
-
-    y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
-    y_pt = ps.as_sparse_variable(y_sp, name="y").type()
-
-    out = ps.dot(x_pt, y_pt)
-
-    msg = "JAX sparse dot only implemented for constant sparse inputs"
-
-    with pytest.raises(NotImplementedError, match=msg):
-        function([x_pt, y_pt], out, mode="JAX")
-
-    y_pt_shared = ps.shared(y_sp, name="y")
-    out = ps.dot(x_pt, y_pt_shared)
+    def assert_fn(x, y):
+        [x] = x
+        [y] = y
+        if hasattr(x, "todense"):
+            x = x.todense()
+        if hasattr(y, "todense"):
+            y = y.todense()
+        np.testing.assert_allclose(x, y)
 
-    with pytest.raises(NotImplementedError, match=msg):
-        function([x_pt], out, mode="JAX")
+    compare_jax_and_py(
+        fgraph,
+        test_values,
+        must_be_device_array=isinstance(dot_pt.type, DenseTensorType),
+        assert_fn=assert_fn,
+    )
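With the constant-only restriction removed, a non-constant sparse operand can now be fed to a sparse dot in JAX mode, mirroring test_sparse_dot above. A sketch of the expected usage under this patch (shapes and values are illustrative, not taken from the test):

    import numpy as np
    import scipy.sparse
    from jax.experimental.sparse import BCOO

    import pytensor
    import pytensor.sparse as ps
    import pytensor.tensor as pt

    x = pt.matrix("x", dtype="float32")
    y = ps.matrix(format="csc", name="y", dtype="float32")
    out = ps.dot(x, y)  # previously required y to be a constant in JAX mode
    fn = pytensor.function([x, y], out, mode="JAX")  # warns about the BCOO conversion

    x_val = np.ones((5, 40), dtype="float32")
    y_val = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
    res = fn(x_val, y_val)                          # SciPy input converted on entry
    res = fn(x_val, BCOO.from_scipy_sparse(y_val))  # BCOO accepted directly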