Skip to content

Unify constraint handling for AutoContinuous, AutoDelta, AutoNormal. #2015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

tillahoffmann
Copy link
Collaborator

This PR unifies constraint handling for several auto guides. In short, distributions for all variables in the guide are constructed in unconstrained space. If the variable has non-real support, the distribution is transformed.

The original motivation was to ensure that the support of random variables in the guide and model match. AutoDelta and AutoContinuous did not meet that requirement because they use Delta distributions in constrained space.

@tillahoffmann tillahoffmann requested a review from Copilot April 2, 2025 02:54
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR unifies constraint handling across auto guides by constructing distributions in the unconstrained space and then transforming them to the appropriate constrained support. Key changes include updating tests to verify matching supports between guide and model, revising docstrings to emphasize unconstrained values, and refactoring AutoGuide implementations to use a common helper for constraint handling.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

File Description
test/infer/test_autoguide.py Added tests to verify that guide and model distributions have matching supports.
numpyro/infer/util.py Updated docstrings to clarify that parameters are now in unconstrained space.
numpyro/infer/autoguide.py Refactored constraint handling in auto guides including updated event dimension logic.
Comments suppressed due to low confidence (1)

test/infer/test_autoguide.py:487

  • Consider adding AutoContinuous to the parameterized tests to ensure its constraint handling behavior is also validated.
@pytest.mark.parametrize("auto_class", [AutoNormal, AutoDelta])

@tillahoffmann
Copy link
Collaborator Author

Interestingly, the neutra.py example fails with a SIGSEGV signal. I think this is not related to the changes here, as it also crops up in #3979. I don't get the error locally. Any idea what might be the cause?

@fehiepsi
Copy link
Member

fehiepsi commented Apr 2, 2025

I think it is memory issue. Could you try to set a smaller number for this flag: https://github.com/pyro-ppl/numpyro/blob/master/examples/neutra.py#L205 in the test?

@tillahoffmann
Copy link
Collaborator Author

The example didn't pass with a smaller number of hidden features. I had a look at the last passing and first failing runs. The dependencies are almost the same and don't look like they would cause this issue. Maybe something changed about the runner?

On a related note, should we consider locking the requirements using a version-locked requirements.txt or uv.lock? There were a few instances of hard-to-debug CI failures because of versions changes if I remember correctly.

20c20
< coverage==7.7.1
---
> coverage==7.8.0
32c32
< flax==0.10.4
---
> flax==0.10.5
35c35
< fsspec==2025.3.0
---
> fsspec==2025.3.2
98c98
< -e git+https://github.com/pyro-ppl/numpyro@19fbd57d96973d7d33ec594ad99110ff544e9ea7#egg=numpyro
---
> -e git+https://github.com/pyro-ppl/numpyro@4027928e31e859cfc7eacc41d8b53baba36137de#egg=numpyro
133c133
< rich==13.9.4
---
> rich==14.0.0
170c170
< xarray==2025.3.0
---
> xarray==2025.3.1

@fehiepsi
Copy link
Member

Could you mark this test as xfail in CI instead? like in https://github.com/pyro-ppl/numpyro/blob/master/test/test_examples.py#L52-L59

Re version lock: how does it work? what if users want to just install numpyro without upgrading/degrading other dependencies?

)

site_fn = dist.Delta(site_loc).to_event(event_dim)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you are finding the MAP point in unconstrained space. This class gets MAP point in constrained space.

transform = biject_to(site["fn"].support)
value = transform(unconstrained_value)
event_ndim = site["fn"].event_dim
if numpyro.get_mask() is False:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need this logic to save computation for prediction.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants