From c19962bf9e22742674b8ae5742a4f41f53eb3d18 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Tue, 1 Apr 2025 21:25:20 -0400 Subject: [PATCH 1/5] Note that `initialize_model` returns unconstrained parameters. --- numpyro/infer/util.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index fe39a87bf..61f756d3e 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -668,12 +668,12 @@ def initialize_model( Defaults to True. :return: a namedtupe `ModelInfo` which contains the fields (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where - `param_info` is a namedtuple `ParamInfo` containing values from the prior - used to initiate MCMC, their corresponding potential energy, and their gradients; - `postprocess_fn` is a callable that uses inverse transforms - to convert unconstrained HMC samples to constrained values that - lie within the site's support, in addition to returning values - at `deterministic` sites in the model. + `param_info` is a namedtuple `ParamInfo` containing *unconstrained* values from + the prior used to initiate MCMC, their corresponding potential energy, and their + gradients; `postprocess_fn` is a callable that uses inverse transforms to + convert unconstrained HMC samples to constrained values that lie within the + site's support, in addition to returning values at `deterministic` sites in the + model. """ model_kwargs = {} if model_kwargs is None else model_kwargs substituted_model = substitute( From 48dd429147d2bfb62ccef6c00119c96e06760afb Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Tue, 1 Apr 2025 22:01:25 -0400 Subject: [PATCH 2/5] Unify constraint handling for `AutoContinuous`, `AutoDelta`, `AutoNormal`. --- numpyro/infer/autoguide.py | 88 +++++++++++++++++------------------- test/infer/test_autoguide.py | 8 ++++ 2 files changed, 49 insertions(+), 47 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 85351a0a5..53b91d8eb 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -339,6 +339,28 @@ def quantiles(self, params, quantiles): return result +def _maybe_constrain_dist_for_site( + site: dict, base_distribution: dist.Distribution +) -> dist.Distribution: + support = site["fn"].support + + # Short-circuit if the support is real and return the base distribution with the + # correct number of event dimensions. + base_support = support + while isinstance(base_support, constraints.independent): + base_support = base_support.base_constraint + if base_support is constraints.real: + if support.event_dim: + return base_distribution.to_event(support.event_dim) + else: + return base_distribution + + # Transform the distribution to the support of the site. + with helpful_support_errors(site): + transform = biject_to(support) + return dist.TransformedDistribution(base_distribution, transform) + + class AutoNormal(AutoGuide): """ This implementation of :class:`AutoGuide` uses Normal distributions @@ -431,18 +453,11 @@ def __call__(self, *args, **kwargs): constraint=self.scale_constraint, event_dim=event_dim, ) - - site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim) - if site["fn"].support is constraints.real or ( - isinstance(site["fn"].support, constraints.independent) - and site["fn"].support.base_constraint is constraints.real - ): - result[name] = numpyro.sample(name, site_fn) - else: - with helpful_support_errors(site): - transform = biject_to(site["fn"].support) - guide_dist = dist.TransformedDistribution(site_fn, transform) - result[name] = numpyro.sample(name, guide_dist) + unconstrained_dist = dist.Normal(site_loc, site_scale) + constrained_dist = _maybe_constrain_dist_for_site( + site, unconstrained_dist + ) + result[name] = numpyro.sample(name, constrained_dist) return result @@ -528,12 +543,6 @@ def __init__( def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) - with numpyro.handlers.block(): - self._init_locs = { - k: v - for k, v in self._postprocess_fn(self._init_locs).items() - if k in self._init_locs - } for name, site in self.prototype_trace.items(): if site["type"] != "sample" or site["is_observed"]: continue @@ -561,26 +570,22 @@ def __call__(self, *args, **kwargs): if site["type"] != "sample" or site["is_observed"]: continue - event_dim = self._event_dims[name] init_loc = self._init_locs[name] with ExitStack() as stack: for frame in site["cond_indep_stack"]: stack.enter_context(plates[frame.name]) - site_loc = numpyro.param( - "{}_{}_loc".format(name, self.prefix), - init_loc, - constraint=site["fn"].support, - event_dim=event_dim, + site_loc = numpyro.param(f"{name}_{self.prefix}_loc", init_loc) + unconstrained_dist = dist.Delta(site_loc) + constrained_dist = _maybe_constrain_dist_for_site( + site, unconstrained_dist ) - - site_fn = dist.Delta(site_loc).to_event(event_dim) - result[name] = numpyro.sample(name, site_fn) + result[name] = numpyro.sample(name, constrained_dist) return result def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): - locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs} + locs = self.median(params) latent_samples = { k: jnp.broadcast_to(v, sample_shape + jnp.shape(v)) for k, v in locs.items() } @@ -600,7 +605,11 @@ def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): return {**latent_samples, **deterministic_samples} def median(self, params): - locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs} + locs = {} + for name in self._init_locs: + unconstrained = params[f"{name}_{self.prefix}_loc"] + transform = biject_to(self.prototype_trace[name]["fn"].support) + locs[name] = transform(unconstrained) return locs @@ -708,26 +717,11 @@ def __call__(self, *args, **kwargs): # unpack continuous latent samples result = {} - for name, unconstrained_value in self._unpack_latent(latent).items(): site = self.prototype_trace[name] - with helpful_support_errors(site): - transform = biject_to(site["fn"].support) - value = transform(unconstrained_value) - event_ndim = site["fn"].event_dim - if numpyro.get_mask() is False: - log_density = 0.0 - else: - log_density = -transform.log_abs_det_jacobian( - unconstrained_value, value - ) - log_density = sum_rightmost( - log_density, jnp.ndim(log_density) - jnp.ndim(value) + event_ndim - ) - delta_dist = dist.Delta( - value, log_density=log_density, event_dim=event_ndim - ) - result[name] = numpyro.sample(name, delta_dist) + unconstrained_dist = dist.Delta(unconstrained_value) + constrained_dist = _maybe_constrain_dist_for_site(site, unconstrained_dist) + result[name] = numpyro.sample(name, constrained_dist) return result diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index ecf2b3c9f..195bc97fe 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -122,6 +122,14 @@ def body_fn(i, val): predictive_samples = predictive(random.PRNGKey(1), None) assert predictive_samples["obs"].shape == (1000, N, 2) + # Check support of guide and model match. + model_trace = handlers.trace(handlers.seed(model, 0)).get_trace(data) + guide_trace = handlers.trace(handlers.seed(guide, 0)).get_trace(data) + for name, guide_site in guide_trace.items(): + if guide_site["type"] == "sample" and name in model_trace: + model_site = model_trace[name] + assert guide_site["fn"].support == model_site["fn"].support + @pytest.mark.parametrize( "auto_class", From bb58d773a034d37ab1e54b841ef8d4bed68d71fd Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Tue, 1 Apr 2025 22:46:34 -0400 Subject: [PATCH 3/5] Fix `event_dim` for `AutoDelta` for guide subsampling and add test. --- numpyro/infer/autoguide.py | 11 +++++++++-- test/infer/test_autoguide.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 53b91d8eb..ba2c6d808 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -547,7 +547,11 @@ def _setup_prototype(self, *args, **kwargs): if site["type"] != "sample" or site["is_observed"]: continue - event_dim = site["fn"].event_dim + event_dim = ( + site["fn"].event_dim + + jnp.ndim(self._init_locs[name]) + - jnp.ndim(site["value"]) + ) self._event_dims[name] = event_dim # If subsampling, repeat init_value to full size. @@ -571,11 +575,14 @@ def __call__(self, *args, **kwargs): continue init_loc = self._init_locs[name] + event_dim = self._event_dims[name] with ExitStack() as stack: for frame in site["cond_indep_stack"]: stack.enter_context(plates[frame.name]) - site_loc = numpyro.param(f"{name}_{self.prefix}_loc", init_loc) + site_loc = numpyro.param( + f"{name}_{self.prefix}_loc", init_loc, event_dim=event_dim + ) unconstrained_dist = dist.Delta(site_loc) constrained_dist = _maybe_constrain_dist_for_site( site, unconstrained_dist diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 195bc97fe..7eacf3d19 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -484,7 +484,7 @@ def model(x, y): lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(1000)) -@pytest.mark.parametrize("auto_class", [AutoNormal]) +@pytest.mark.parametrize("auto_class", [AutoNormal, AutoDelta]) def test_subsample_guide(auto_class): # The model adapted from tutorial/source/easyguide.ipynb def model(batch, subsample, full_size): From d861328417496cb1e8762d42fa61c5fe07eb3f66 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Wed, 2 Apr 2025 10:33:36 -0400 Subject: [PATCH 4/5] Log `stdout` and `stderr` for failed examples. --- test/test_examples.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test/test_examples.py b/test/test_examples.py index ebb034af9..f92f0da51 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1,8 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import logging import os -from subprocess import check_call +from subprocess import PIPE, CalledProcessError, check_output import sys import pytest @@ -80,9 +81,15 @@ @pytest.mark.parametrize("example", EXAMPLES) @pytest.mark.filterwarnings("ignore:There are not enough devices:UserWarning") @pytest.mark.filterwarnings("ignore:Higgs is a 2.6 GB dataset:UserWarning") -def test_cpu(example): +def test_cpu(request: pytest.FixtureRequest, example): print("Running:\npython examples/{}".format(example)) example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) - check_call([sys.executable, filename] + args) + try: + check_output([sys.executable, filename] + args, text=True, stderr=PIPE) + except CalledProcessError as ex: + logger = logging.getLogger(request.node.name) + logger.error("stdout:\n%s", ex.stdout) + logger.error("stderr:\n%s", ex.stderr) + raise From 8be07a14e1f9d4680b06598dca3f092dc8095d72 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Wed, 2 Apr 2025 14:36:34 -0400 Subject: [PATCH 5/5] Halve default number of hidden factors in `neutra` example. --- examples/neutra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/neutra.py b/examples/neutra.py index c69d2a722..2a9ee4942 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -202,7 +202,7 @@ def main(args): parser.add_argument("-n", "--num-samples", nargs="?", default=4000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) parser.add_argument("--num-chains", nargs="?", default=1, type=int) - parser.add_argument("--hidden-factor", nargs="?", default=8, type=int) + parser.add_argument("--hidden-factor", nargs="?", default=4, type=int) parser.add_argument("--num-iters", nargs="?", default=10000, type=int) parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".') args = parser.parse_args()