Support RandomVariable graphs with scalar shape parameters in JAX backend #1029
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1029      +/-   ##
==========================================
- Coverage   82.12%   82.11%   -0.02%
==========================================
  Files         183      183
  Lines       48111    48122      +11
  Branches     8667     8668       +1
==========================================
+ Hits        39510    39513       +3
- Misses       6435     6439       +4
- Partials     2166     2170       +4
This solves pymc-devs/pymc#7348
LGTM. There are just 3 minor changes that I requested before merging.
if isinstance(node.op, JAXShapeTuple)
for inp in node.inputs
if inp in fgraph_inputs
and all(isinstance(node.op, JAXShapeTuple) for node, _ in clients[inp])
Suggested change:
- and all(isinstance(node.op, JAXShapeTuple) for node, _ in clients[inp])
+ and all(isinstance(client.op, JAXShapeTuple) for client, _ in clients[inp])
I understand that the node name will be overridden in the last level of nesting, and that it won't affect the outermost node variable, but I think it's dangerous to override names in list comprehensions.
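The reviewer's point is about Python name scoping. A minimal standalone sketch (the names here are illustrative, not taken from the PR) shows why the original code happens to work: a generator expression that rebinds `node` gets its own scope, so the outer loop variable is left intact, even though the shadowing reads dangerously:

```python
# Hypothetical stand-in for a clients mapping: input -> list of (client, index)
clients = {"inp": [("client_a", 0), ("client_b", 1)]}

for node in ["outer_node"]:
    # The generator expression rebinds `node`, but in Python 3 generator
    # expressions have their own scope, so the loop variable is untouched.
    ok = all(isinstance(node, str) for node, _ in clients["inp"])
    assert ok
    assert node == "outer_node"  # outer `node` was not overridden

print(node)  # outer_node
```

Renaming the inner variable (as in the suggestion above) changes nothing at runtime but removes the shadowing.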
# Rebuild with strict=False so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
Suggested change:
- new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
+ new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
# Rebuild with strict=True, so output type is updated
# This uses a different path in the dispatch implementation
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
Suggested change:
- new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
+ new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
The comments were inverted with respect to the code
This should make it possible to do forward sampling in more PyMC models that use dims to define variable shapes.
There was already a special rewrite to replace make_vector and expand_dims in the shape of RVs, but without handling these inputs from the outside it wouldn't achieve much for PyTensor users:
pytensor/pytensor/tensor/random/rewriting/jax.py
Lines 38 to 77 in e88117e
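For context, the JAX constraint this PR works around can be sketched in plain JAX (a minimal illustration, not PyTensor code): jax.jit requires array shapes to be concrete at trace time, so a shape coming in as a traced scalar input fails, while marking that argument static works. This is why scalar shape parameters of RandomVariable graphs need special handling as static inputs in the JAX backend:

```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def dynamic_shape(n):
    # `n` is a tracer inside jit, but jnp.zeros needs a concrete shape
    return jnp.zeros(n)

@partial(jax.jit, static_argnums=0)
def static_shape(n):
    # marking `n` static makes it a concrete Python int at trace time
    return jnp.zeros(n)

try:
    dynamic_shape(3)
except TypeError:  # jax.errors.ConcretizationTypeError subclasses TypeError
    print("dynamic shape fails under jit")

print(static_shape(3).shape)  # (3,)
```

Treating the scalar shape parameter as a static argument means the function is retraced when its value changes, which is the usual trade-off for dynamic shapes under jit.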
📚 Documentation preview 📚: https://pytensor--1029.org.readthedocs.build/en/1029/