-
Notifications
You must be signed in to change notification settings - Fork 256
Unify constraint handling for AutoContinuous
, AutoDelta
, AutoNormal
.
#2015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR unifies constraint handling across auto guides by constructing distributions in the unconstrained space and then transforming them to the appropriate constrained support. Key changes include updating tests to verify matching supports between guide and model, revising docstrings to emphasize unconstrained values, and refactoring AutoGuide implementations to use a common helper for constraint handling.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
File | Description |
---|---|
test/infer/test_autoguide.py | Added tests to verify that guide and model distributions have matching supports. |
numpyro/infer/util.py | Updated docstrings to clarify that parameters are now in unconstrained space. |
numpyro/infer/autoguide.py | Refactored constraint handling in auto guides including updated event dimension logic. |
Comments suppressed due to low confidence (1)
test/infer/test_autoguide.py:487
- Consider adding AutoContinuous to the parameterized tests to ensure its constraint handling behavior is also validated.
@pytest.mark.parametrize("auto_class", [AutoNormal, AutoDelta])
c1a47d1
to
65f1b29
Compare
65f1b29
to
d861328
Compare
Interestingly, the |
I think it is memory issue. Could you try to set a smaller number for this flag: https://github.com/pyro-ppl/numpyro/blob/master/examples/neutra.py#L205 in the test? |
The example didn't pass with a smaller number of hidden features. I had a look at the last passing and first failing runs. The dependencies are almost the same and don't look like they would cause this issue. Maybe something changed about the runner? On a related note, should we consider locking the requirements using a version-locked 20c20
< coverage==7.7.1
---
> coverage==7.8.0
32c32
< flax==0.10.4
---
> flax==0.10.5
35c35
< fsspec==2025.3.0
---
> fsspec==2025.3.2
98c98
< -e git+https://github.com/pyro-ppl/numpyro@19fbd57d96973d7d33ec594ad99110ff544e9ea7#egg=numpyro
---
> -e git+https://github.com/pyro-ppl/numpyro@4027928e31e859cfc7eacc41d8b53baba36137de#egg=numpyro
133c133
< rich==13.9.4
---
> rich==14.0.0
170c170
< xarray==2025.3.0
---
> xarray==2025.3.1 |
Could you mark this test as xfail in CI instead? like in https://github.com/pyro-ppl/numpyro/blob/master/test/test_examples.py#L52-L59 Re version lock: how does it work? what if users want to just install numpyro without upgrading/degrading other dependencies? |
) | ||
|
||
site_fn = dist.Delta(site_loc).to_event(event_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here you are finding the MAP point in unconstrained space. This class gets MAP point in constrained space.
transform = biject_to(site["fn"].support) | ||
value = transform(unconstrained_value) | ||
event_ndim = site["fn"].event_dim | ||
if numpyro.get_mask() is False: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we need this logic to save computation for prediction.
This PR unifies constraint handling for several auto guides. In short, distributions for all variables in the guide are constructed in unconstrained space. If the variable has non-real support, the distribution is transformed.
The original motivation was to ensure that the support of random variables in the guide and model match.
AutoDelta
andAutoContinuous
did not meet that requirement because they useDelta
distributions in constrained space.