From f0bb9a8bc217fc8e7cc8e9acd57b812b30662b36 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 7 Mar 2025 12:37:26 -0500 Subject: [PATCH 1/3] use functional interface funsor.adjoint.adjoint --- .github/workflows/ci.yml | 1 - numpyro/contrib/funsor/discrete.py | 4 ++-- test/contrib/test_infer_discrete.py | 5 ----- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a89ddff42..0cf421d34 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -84,7 +84,6 @@ jobs: JAX_CHECK_TRACER_LEAKS: 1 run: | pytest -vs test/contrib/einstein/test_steinvi.py::test_run_smoke -k ASVGD - pytest -vs test/contrib/test_infer_discrete.py::test_scan_hmm_smoke pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py index db074b8ef..0baaef3e1 100644 --- a/numpyro/contrib/funsor/discrete.py +++ b/numpyro/contrib/funsor/discrete.py @@ -56,14 +56,14 @@ def _sample_posterior( model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs) first_available_dim = -_guess_max_plate_nesting(model_trace) - 1 - with funsor.adjoint.AdjointTape() as tape: + with funsor.interpretations.lazy: with block(), enum(first_available_dim=first_available_dim): log_prob, model_tr, log_measures = _enum_log_density( model, args, kwargs, {}, sum_op, prod_op ) with approx: - approx_factors = tape.adjoint(sum_op, prod_op, log_prob) + approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob) # construct a result trace to replay against the model sample_tr = model_tr.copy() diff --git a/test/contrib/test_infer_discrete.py b/test/contrib/test_infer_discrete.py index 3e67b42e9..c37f4efae 100644 --- a/test/contrib/test_infer_discrete.py +++ b/test/contrib/test_infer_discrete.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import logging -import os import numpy as np from numpy.testing import assert_allclose @@ -96,10 +95,6 @@ def hmm(data, hidden_dim=10): ], ) @pytest.mark.parametrize("temperature", [0, 1]) -@pytest.mark.xfail( - os.getenv("JAX_CHECK_TRACER_LEAKS") == "1", - reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/1998", -) def test_scan_hmm_smoke(length, temperature): # This should match the example in the infer_discrete docstring. def hmm(data, hidden_dim=10): From fe7af199b6185b4ec0d939d605f38e0296b1a564 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 9 Mar 2025 12:32:21 -0400 Subject: [PATCH 2/3] do not apply optimizer before adjoint --- numpyro/contrib/funsor/discrete.py | 2 +- numpyro/contrib/funsor/infer_util.py | 6 +++++- test/contrib/test_infer_discrete.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py index 0baaef3e1..86bffc791 100644 --- a/numpyro/contrib/funsor/discrete.py +++ b/numpyro/contrib/funsor/discrete.py @@ -59,7 +59,7 @@ def _sample_posterior( with funsor.interpretations.lazy: with block(), enum(first_available_dim=first_available_dim): log_prob, model_tr, log_measures = _enum_log_density( - model, args, kwargs, {}, sum_op, prod_op + model, args, kwargs, {}, sum_op, prod_op, apply_optimizer=False ) with approx: diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py index 5e97f82d2..956abe353 100644 --- a/numpyro/contrib/funsor/infer_util.py +++ b/numpyro/contrib/funsor/infer_util.py @@ -194,7 +194,9 @@ def compute_markov_factors( return markov_factors -def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op): +def _enum_log_density( + model, model_args, model_kwargs, params, sum_op, prod_op, apply_optimizer=True +): """Helper function to compute elbo and extract its components from execution traces.""" model = substitute(model, data=params) with plate_to_enum_plate(): @@ -286,6 +288,8 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op): eliminate=sum_vars | prod_vars, plates=prod_vars, ) + if not apply_optimizer: + return lazy_result, model_trace, log_measures result = funsor.optimizer.apply_optimizer(lazy_result) if len(result.inputs) > 0: raise ValueError( diff --git a/test/contrib/test_infer_discrete.py b/test/contrib/test_infer_discrete.py index c37f4efae..84a60294c 100644 --- a/test/contrib/test_infer_discrete.py +++ b/test/contrib/test_infer_discrete.py @@ -48,7 +48,7 @@ def log_prob_sum(trace): return log_joint -@pytest.mark.parametrize("length", [1, 2, 10]) +@pytest.mark.parametrize("length", [1, 2, 8]) @pytest.mark.parametrize("temperature", [0, 1]) def test_hmm_smoke(length, temperature): # This should match the example in the infer_discrete docstring. From f0fe41ab50358a0648a22dd90a8ea7897ee0f61e Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 9 Mar 2025 14:41:10 -0400 Subject: [PATCH 3/3] adjust precision of enum plates_6 --- test/contrib/test_enum_elbo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index b3a9330de..2c3656fef 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -1386,10 +1386,10 @@ def iplate_plate_loss_fn(params): params ) - assert_equal(iplate_iplate_loss, plate_iplate_loss, prec=1e-5) - assert_equal(iplate_iplate_grad, plate_iplate_grad, prec=1e-5) - assert_equal(iplate_iplate_loss, iplate_plate_loss, prec=1e-5) - assert_equal(iplate_iplate_grad, iplate_plate_grad, prec=1e-5) + assert_equal(iplate_iplate_loss, plate_iplate_loss, prec=2e-5) + assert_equal(iplate_iplate_grad, plate_iplate_grad, prec=2e-5) + assert_equal(iplate_iplate_loss, iplate_plate_loss, prec=2e-5) + assert_equal(iplate_iplate_grad, iplate_plate_grad, prec=2e-5) # But promoting both to plates should result in an error. with pytest.raises(ValueError, match="intractable!"):