Skip to content

Unify constraint handling for AutoContinuous, AutoDelta, AutoNormal. #2015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/neutra.py
Original file line number Diff line number Diff line change
@@ -202,7 +202,7 @@ def main(args):
parser.add_argument("-n", "--num-samples", nargs="?", default=4000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
parser.add_argument("--num-chains", nargs="?", default=1, type=int)
parser.add_argument("--hidden-factor", nargs="?", default=8, type=int)
parser.add_argument("--hidden-factor", nargs="?", default=4, type=int)
parser.add_argument("--num-iters", nargs="?", default=10000, type=int)
parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
args = parser.parse_args()
95 changes: 48 additions & 47 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
@@ -339,6 +339,28 @@ def quantiles(self, params, quantiles):
return result


# Build the guide distribution for one model site from an unconstrained base
# distribution.
#
# Parameters:
#   site: a trace site dict; only ``site["fn"].support`` is read here, and the
#       whole dict is handed to ``helpful_support_errors``.
#   base_distribution: the distribution to adapt to the site's support.
#
# Returns: ``base_distribution`` itself (optionally reinterpreted with
# ``to_event``) when the site's support is real, otherwise a
# ``dist.TransformedDistribution`` wrapping it with ``biject_to(support)``.
#
# NOTE(review): indentation was lost in this view, so whether the final
# ``return`` sits inside or after the ``with`` block cannot be told from here.
def _maybe_constrain_dist_for_site(
site: dict, base_distribution: dist.Distribution
) -> dist.Distribution:
support = site["fn"].support

# Short-circuit if the support is real and return the base distribution with the
# correct number of event dimensions.
base_support = support
# Strip every layer of ``constraints.independent`` to reach the innermost
# constraint, so nested wrappers around ``real`` are recognised too.
while isinstance(base_support, constraints.independent):
base_support = base_support.base_constraint
if base_support is constraints.real:
# The event dimension count comes from the original (wrapped) support,
# not from the unwrapped base constraint.
if support.event_dim:
return base_distribution.to_event(support.event_dim)
else:
return base_distribution

# Transform the distribution to the support of the site.
# ``biject_to`` is looked up under ``helpful_support_errors(site)``, which
# presumably enriches errors for unsupported supports with site details
# (defined elsewhere; confirm).
with helpful_support_errors(site):
transform = biject_to(support)
return dist.TransformedDistribution(base_distribution, transform)


class AutoNormal(AutoGuide):
"""
This implementation of :class:`AutoGuide` uses Normal distributions
@@ -431,18 +453,11 @@ def __call__(self, *args, **kwargs):
constraint=self.scale_constraint,
event_dim=event_dim,
)

site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
if site["fn"].support is constraints.real or (
isinstance(site["fn"].support, constraints.independent)
and site["fn"].support.base_constraint is constraints.real
):
result[name] = numpyro.sample(name, site_fn)
else:
with helpful_support_errors(site):
transform = biject_to(site["fn"].support)
guide_dist = dist.TransformedDistribution(site_fn, transform)
result[name] = numpyro.sample(name, guide_dist)
unconstrained_dist = dist.Normal(site_loc, site_scale)
constrained_dist = _maybe_constrain_dist_for_site(
site, unconstrained_dist
)
result[name] = numpyro.sample(name, constrained_dist)

return result

@@ -528,17 +543,15 @@ def __init__(

def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)
with numpyro.handlers.block():
self._init_locs = {
k: v
for k, v in self._postprocess_fn(self._init_locs).items()
if k in self._init_locs
}
for name, site in self.prototype_trace.items():
if site["type"] != "sample" or site["is_observed"]:
continue

event_dim = site["fn"].event_dim
event_dim = (
site["fn"].event_dim
+ jnp.ndim(self._init_locs[name])
- jnp.ndim(site["value"])
)
self._event_dims[name] = event_dim

# If subsampling, repeat init_value to full size.
@@ -561,26 +574,25 @@ def __call__(self, *args, **kwargs):
if site["type"] != "sample" or site["is_observed"]:
continue

event_dim = self._event_dims[name]
init_loc = self._init_locs[name]
event_dim = self._event_dims[name]
with ExitStack() as stack:
for frame in site["cond_indep_stack"]:
stack.enter_context(plates[frame.name])

site_loc = numpyro.param(
"{}_{}_loc".format(name, self.prefix),
init_loc,
constraint=site["fn"].support,
event_dim=event_dim,
f"{name}_{self.prefix}_loc", init_loc, event_dim=event_dim
)

site_fn = dist.Delta(site_loc).to_event(event_dim)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you are finding the MAP point in unconstrained space, whereas this class is meant to find the MAP point in constrained space.

result[name] = numpyro.sample(name, site_fn)
unconstrained_dist = dist.Delta(site_loc)
constrained_dist = _maybe_constrain_dist_for_site(
site, unconstrained_dist
)
result[name] = numpyro.sample(name, constrained_dist)

return result

def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs}
locs = self.median(params)
latent_samples = {
k: jnp.broadcast_to(v, sample_shape + jnp.shape(v)) for k, v in locs.items()
}
@@ -600,7 +612,11 @@ def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
return {**latent_samples, **deterministic_samples}

def median(self, params):
locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs}
locs = {}
for name in self._init_locs:
unconstrained = params[f"{name}_{self.prefix}_loc"]
transform = biject_to(self.prototype_trace[name]["fn"].support)
locs[name] = transform(unconstrained)
return locs


@@ -708,26 +724,11 @@ def __call__(self, *args, **kwargs):

# unpack continuous latent samples
result = {}

for name, unconstrained_value in self._unpack_latent(latent).items():
site = self.prototype_trace[name]
with helpful_support_errors(site):
transform = biject_to(site["fn"].support)
value = transform(unconstrained_value)
event_ndim = site["fn"].event_dim
if numpyro.get_mask() is False:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this logic to save computation during prediction.

log_density = 0.0
else:
log_density = -transform.log_abs_det_jacobian(
unconstrained_value, value
)
log_density = sum_rightmost(
log_density, jnp.ndim(log_density) - jnp.ndim(value) + event_ndim
)
delta_dist = dist.Delta(
value, log_density=log_density, event_dim=event_ndim
)
result[name] = numpyro.sample(name, delta_dist)
unconstrained_dist = dist.Delta(unconstrained_value)
constrained_dist = _maybe_constrain_dist_for_site(site, unconstrained_dist)
result[name] = numpyro.sample(name, constrained_dist)

return result

12 changes: 6 additions & 6 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
@@ -668,12 +668,12 @@ def initialize_model(
Defaults to True.
:return: a namedtuple `ModelInfo` which contains the fields
(`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
`param_info` is a namedtuple `ParamInfo` containing values from the prior
used to initiate MCMC, their corresponding potential energy, and their gradients;
`postprocess_fn` is a callable that uses inverse transforms
to convert unconstrained HMC samples to constrained values that
lie within the site's support, in addition to returning values
at `deterministic` sites in the model.
`param_info` is a namedtuple `ParamInfo` containing *unconstrained* values from
the prior used to initiate MCMC, their corresponding potential energy, and their
gradients; `postprocess_fn` is a callable that uses inverse transforms to
convert unconstrained HMC samples to constrained values that lie within the
site's support, in addition to returning values at `deterministic` sites in the
model.
"""
model_kwargs = {} if model_kwargs is None else model_kwargs
substituted_model = substitute(
10 changes: 9 additions & 1 deletion test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
@@ -122,6 +122,14 @@ def body_fn(i, val):
predictive_samples = predictive(random.PRNGKey(1), None)
assert predictive_samples["obs"].shape == (1000, N, 2)

# Check support of guide and model match.
model_trace = handlers.trace(handlers.seed(model, 0)).get_trace(data)
guide_trace = handlers.trace(handlers.seed(guide, 0)).get_trace(data)
for name, guide_site in guide_trace.items():
if guide_site["type"] == "sample" and name in model_trace:
model_site = model_trace[name]
assert guide_site["fn"].support == model_site["fn"].support


@pytest.mark.parametrize(
"auto_class",
@@ -476,7 +484,7 @@ def model(x, y):
lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(1000))


@pytest.mark.parametrize("auto_class", [AutoNormal])
@pytest.mark.parametrize("auto_class", [AutoNormal, AutoDelta])
def test_subsample_guide(auto_class):
# The model adapted from tutorial/source/easyguide.ipynb
def model(batch, subsample, full_size):
13 changes: 10 additions & 3 deletions test/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import logging
import os
from subprocess import check_call
from subprocess import PIPE, CalledProcessError, check_output
import sys

import pytest
@@ -80,9 +81,15 @@
@pytest.mark.parametrize("example", EXAMPLES)
@pytest.mark.filterwarnings("ignore:There are not enough devices:UserWarning")
@pytest.mark.filterwarnings("ignore:Higgs is a 2.6 GB dataset:UserWarning")
def test_cpu(example):
def test_cpu(request: pytest.FixtureRequest, example):
print("Running:\npython examples/{}".format(example))
example = example.split()
filename, args = example[0], example[1:]
filename = os.path.join(EXAMPLES_DIR, filename)
check_call([sys.executable, filename] + args)
try:
check_output([sys.executable, filename] + args, text=True, stderr=PIPE)
except CalledProcessError as ex:
logger = logging.getLogger(request.node.name)
logger.error("stdout:\n%s", ex.stdout)
logger.error("stderr:\n%s", ex.stderr)
raise
Loading