diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 657e4f819..2b5768e99 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -85,7 +85,6 @@ jobs: env: JAX_CHECK_TRACER_LEAKS: 1 run: | - pytest -vs test/contrib/test_infer_discrete.py::test_scan_hmm_smoke pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py index db074b8ef..86bffc791 100644 --- a/numpyro/contrib/funsor/discrete.py +++ b/numpyro/contrib/funsor/discrete.py @@ -56,14 +56,14 @@ def _sample_posterior( model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs) first_available_dim = -_guess_max_plate_nesting(model_trace) - 1 - with funsor.adjoint.AdjointTape() as tape: + with funsor.interpretations.lazy: with block(), enum(first_available_dim=first_available_dim): log_prob, model_tr, log_measures = _enum_log_density( - model, args, kwargs, {}, sum_op, prod_op + model, args, kwargs, {}, sum_op, prod_op, apply_optimizer=False ) with approx: - approx_factors = tape.adjoint(sum_op, prod_op, log_prob) + approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob) # construct a result trace to replay against the model sample_tr = model_tr.copy() diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py index 5e97f82d2..956abe353 100644 --- a/numpyro/contrib/funsor/infer_util.py +++ b/numpyro/contrib/funsor/infer_util.py @@ -194,7 +194,9 @@ def compute_markov_factors( return markov_factors -def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op): +def _enum_log_density( + model, model_args, model_kwargs, params, sum_op, prod_op, apply_optimizer=True +): """Helper function to compute elbo and extract its components from execution traces.""" model = substitute(model, data=params) with plate_to_enum_plate(): @@ -286,6 +288,8 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op): eliminate=sum_vars | prod_vars, plates=prod_vars, ) + if not apply_optimizer: + return lazy_result, model_trace, log_measures result = funsor.optimizer.apply_optimizer(lazy_result) if len(result.inputs) > 0: raise ValueError( diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index b3a9330de..2c3656fef 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -1386,10 +1386,10 @@ def iplate_plate_loss_fn(params): params ) - assert_equal(iplate_iplate_loss, plate_iplate_loss, prec=1e-5) - assert_equal(iplate_iplate_grad, plate_iplate_grad, prec=1e-5) - assert_equal(iplate_iplate_loss, iplate_plate_loss, prec=1e-5) - assert_equal(iplate_iplate_grad, iplate_plate_grad, prec=1e-5) + assert_equal(iplate_iplate_loss, plate_iplate_loss, prec=2e-5) + assert_equal(iplate_iplate_grad, plate_iplate_grad, prec=2e-5) + assert_equal(iplate_iplate_loss, iplate_plate_loss, prec=2e-5) + assert_equal(iplate_iplate_grad, iplate_plate_grad, prec=2e-5) # But promoting both to plates should result in an error. with pytest.raises(ValueError, match="intractable!"): diff --git a/test/contrib/test_infer_discrete.py b/test/contrib/test_infer_discrete.py index 3e67b42e9..84a60294c 100644 --- a/test/contrib/test_infer_discrete.py +++ b/test/contrib/test_infer_discrete.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import logging -import os import numpy as np from numpy.testing import assert_allclose @@ -49,7 +48,7 @@ def log_prob_sum(trace): return log_joint -@pytest.mark.parametrize("length", [1, 2, 10]) +@pytest.mark.parametrize("length", [1, 2, 8]) @pytest.mark.parametrize("temperature", [0, 1]) def test_hmm_smoke(length, temperature): # This should match the example in the infer_discrete docstring. @@ -96,10 +95,6 @@ def hmm(data, hidden_dim=10): ], ) @pytest.mark.parametrize("temperature", [0, 1]) -@pytest.mark.xfail( - os.getenv("JAX_CHECK_TRACER_LEAKS") == "1", - reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/1998", -) def test_scan_hmm_smoke(length, temperature): # This should match the example in the infer_discrete docstring. def hmm(data, hidden_dim=10):