Skip to content

Support RandomVariable graphs with scalar shape parameters in JAX backend #1029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Oct 11, 2024

This should make it possible to do forward sampling in more PyMC models that use dims to define variables shapes

def test_random_scalar_shape_input():
    dim0 = pt.scalar("dim0", dtype=int)
    dim1 = pt.scalar("dim1", dtype=int)

    out = pt.random.normal(0, 1, size=dim0)
    jax_fn = compile_random_function([dim0], out)
    assert jax_fn(np.array(2)).shape == (2,)
    assert jax_fn(np.array(3)).shape == (3,)

    out = pt.random.normal(0, 1, size=[dim0, dim1])
    jax_fn = compile_random_function([dim0, dim1], out)
    assert jax_fn(np.array(2), np.array(3)).shape == (2, 3)
    assert jax_fn(np.array(4), np.array(5)).shape == (4, 5)

These was already a special rewrite to replace make_vector, expand_dims in the shape of RVs, but without handling these inputs from the outside it wouldn't achieve much for PyTensor users:

@node_rewriter([RandomVariable])
def size_parameter_as_tuple(fgraph, node):
"""Replace `MakeVector` and `DimShuffle` (when used to transform a scalar
into a 1d vector) when they are found as the input of a `size` or `shape`
parameter by `JAXShapeTuple` during transpilation.
The JAX implementations of `MakeVector` and `DimShuffle` always return JAX
`TracedArrays`, but JAX only accepts concrete values as inputs for the `size`
or `shape` parameter. When these `Op`s are used to convert scalar or tuple
inputs, however, we can avoid tracing by making them return a tuple of their
inputs instead.
Note that JAX does not accept scalar inputs for the `size` or `shape`
parameters, and this rewrite also ensures that scalar inputs are turned into
tuples during transpilation.
"""
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
size_arg = node.inputs[1]
size_node = size_arg.owner
if size_node is None:
return
if isinstance(size_node.op, JAXShapeTuple):
return
if isinstance(size_node.op, MakeVector) or (
isinstance(size_node.op, DimShuffle)
and size_node.op.input_ndim == 0
and size_node.op.new_order == ("x",)
):
# Here PyTensor converted a tuple or list to a tensor
new_size_args = JAXShapeTuple()(*size_node.inputs)
new_inputs = list(node.inputs)
new_inputs[1] = new_size_args
new_node = node.clone_with_new_inputs(new_inputs)
return new_node.outputs


📚 Documentation preview 📚: https://pytensor--1029.org.readthedocs.build/en/1029/

@ricardoV94 ricardoV94 added enhancement New feature or request jax labels Oct 11, 2024
@ricardoV94 ricardoV94 changed the title Allow running RandomVariable graphs with scalar shape parameters in JAX backend Support RandomVariable graphs with scalar shape parameters in JAX backend Oct 11, 2024
@ricardoV94 ricardoV94 force-pushed the jax_scalar_rv_shapes branch 2 times, most recently from ff899b1 to 9b9bcba Compare October 11, 2024 09:44
Copy link

codecov bot commented Oct 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.11%. Comparing base (0824dba) to head (1e5c487).

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1029      +/-   ##
==========================================
- Coverage   82.12%   82.11%   -0.02%     
==========================================
  Files         183      183              
  Lines       48111    48122      +11     
  Branches     8667     8668       +1     
==========================================
+ Hits        39510    39513       +3     
- Misses       6435     6439       +4     
- Partials     2166     2170       +4     
Files with missing lines Coverage Δ
pytensor/link/jax/linker.py 96.22% <100.00%> (+0.98%) ⬆️

... and 1 file with indirect coverage changes

@ricardoV94
Copy link
Member Author

ricardoV94 commented Oct 11, 2024

This solves pymc-devs/pymc#7348

Copy link
Member

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. There are just 3 minor changes that I requested before merging.

if isinstance(node.op, JAXShapeTuple)
for inp in node.inputs
if inp in fgraph_inputs
and all(isinstance(node.op, JAXShapeTuple) for node, _ in clients[inp])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
and all(isinstance(node.op, JAXShapeTuple) for node, _ in clients[inp])
and all(isinstance(client.op, JAXShapeTuple) for client, _ in clients[inp])

I understand that the node name will be overriden in the last level of nesting, and that it won't affect the outermost node variable, but I think it's dangerous to override names in list comprehensions.


# Rebuild with strict=False so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)


# Rebuild with strict=True, so output type is updated
# This uses a different path in the dispatch implementation
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)

The comments were inverted with respect to the code

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request jax
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants