
Adapter keeps track of the transform jacobians #419


Open
wants to merge 19 commits into base: dev

Conversation

Kucharssim (Member)
fixes #245

I have not implemented it for NumpyTransform because I was not sure about implementation w.r.t. serialization. I tried to make sure that if we cannot keep track of the jacobians, we raise an error rather than returning an incorrect output.

@Kucharssim Kucharssim requested a review from LarsKue April 17, 2025 10:11

codecov bot commented Apr 17, 2025

Codecov Report

Attention: Patch coverage is 88.80597% with 15 lines in your changes missing coverage. Please review.

Files with missing lines                              | Patch %  | Lines missing
bayesflow/approximators/continuous_approximator.py    | 0.00%    | 4 ⚠️
bayesflow/adapters/transforms/filter_transform.py     | 85.00%   | 3 ⚠️
bayesflow/adapters/transforms/map_transform.py        | 86.95%   | 3 ⚠️
bayesflow/adapters/transforms/drop.py                 | 50.00%   | 1 ⚠️
...sflow/adapters/transforms/elementwise_transform.py | 50.00%   | 1 ⚠️
bayesflow/adapters/transforms/keep.py                 | 50.00%   | 1 ⚠️
bayesflow/adapters/transforms/numpy_transform.py      | 50.00%   | 1 ⚠️
bayesflow/adapters/transforms/transform.py            | 50.00%   | 1 ⚠️

Files with changed coverage                           | Coverage          | Δ
bayesflow/adapters/adapter.py                         | 82.77% <100.00%>  | +1.14% ⬆️
bayesflow/adapters/transforms/concatenate.py          | 87.69% <100.00%>  | +3.37% ⬆️
bayesflow/adapters/transforms/constrain.py            | 83.49% <100.00%>  | +16.82% ⬆️
bayesflow/adapters/transforms/log.py                  | 100.00% <100.00%> | (ø)
bayesflow/adapters/transforms/rename.py               | 93.33% <100.00%>  | +0.47% ⬆️
bayesflow/adapters/transforms/scale.py                | 100.00% <100.00%> | (ø)
bayesflow/adapters/transforms/sqrt.py                 | 93.75% <100.00%>  | +2.84% ⬆️
bayesflow/adapters/transforms/standardize.py          | 97.67% <100.00%>  | +0.37% ⬆️
bayesflow/adapters/transforms/drop.py                 | 88.23% <50.00%>   | -5.10% ⬇️
...sflow/adapters/transforms/elementwise_transform.py | 78.94% <50.00%>   | -3.41% ⬇️
... and 6 more

... and 7 files with indirect coverage changes


LarsKue commented Apr 17, 2025

Thanks for the PR. Could you add tests for this behavior @Kucharssim? See e.g., tests/test_networks/test_inference_networks.py::test_density_numerically for a related test.

@Kucharssim (Member, Author)

Yes, on it!

@stefanradev93 stefanradev93 deleted the branch bayesflow-org:dev April 22, 2025 14:37

LarsKue commented Apr 22, 2025

This was accidentally closed. We will investigate how to restore the branch and reopen PRs.

@LarsKue LarsKue reopened this Apr 22, 2025
@LarsKue (Contributor) left a comment

Thank you for implementing this! I noted some minor concerns, mostly regarding clarity.


    def log_det_jac(self, data, inverse: bool = False, **kwargs) -> np.ndarray:
        if self.std is None:
            return None
Contributor:
Should this not raise an error?

Member Author:

You are right, and this also revealed a deeper issue with my implementation, thanks! The problem is that log_det_jac here was called before the forward pass. Since standardize is a stateful layer, the result of log_det_jac would be based on out-of-date values (None the first time standardize is called). I flipped the order so that forward is always called before log_det_jac, but in my opinion this also shows it would be better to just compute the transform and the log_det_jac at the same time, as you pointed out earlier... 😅
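A minimal sketch of the "compute both at once" idea (hypothetical names, not the actual BayesFlow API): a stateful standardization updates its statistics and then uses those same statistics for both the transform and its log-det-Jacobian contribution in a single call, so neither can go stale:

```python
import numpy as np

class Standardize:
    """Toy stateful standardization (illustration only, not BayesFlow's class)."""

    def __init__(self):
        self.mean = None
        self.std = None

    def forward_with_ldj(self, x: np.ndarray) -> tuple[np.ndarray, float]:
        # Update the state first, then use the *same* statistics for both the
        # transform and its log-det-Jacobian, avoiding out-of-date values.
        self.mean = np.mean(x, axis=0)
        self.std = np.std(x, axis=0)
        z = (x - self.mean) / self.std
        # Elementwise: dz/dx = 1/std, so log|det J| = -sum(log(std)) per sample.
        ldj = -float(np.sum(np.log(self.std)))
        return z, ldj

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(100, 3))
z, ldj = Standardize().forward_with_ldj(x)
```

Returning the pair from one call removes the ordering dependency between forward and log_det_jac that caused the stale-state bug.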

@Kucharssim Kucharssim requested a review from LarsKue April 24, 2025 17:28
@LarsKue (Contributor) left a comment

Thank you for the changes! I think we are mostly ready to merge this as-is. Can you address the open conversation above and the comments from Copilot? Thanks!

@LarsKue LarsKue requested a review from Copilot April 25, 2025 17:48
@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR enhances the adapter functionality by adding support for tracking the log determinant of the Jacobians across transforms. Key changes include:

  • Introducing log_det_jac methods for several transforms (e.g. standardize, sqrt, scale, constrain, etc.).
  • Modifying the adapter’s forward and inverse methods to optionally return the log determinant of the Jacobian.
  • Adding comprehensive tests to verify the correct computation and exception handling for the Jacobian tracking.
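For background (this is the standard identity, not code from the PR): composing adapter transforms multiplies their Jacobians, so the log-determinants add, each evaluated at that step's input,

$$
f = f_n \circ \cdots \circ f_1, \qquad
\log\bigl|\det J_f(x)\bigr| = \sum_{i=1}^{n} \log\bigl|\det J_{f_i}(x_{i-1})\bigr|,
\qquad x_0 = x,\; x_i = f_i(x_{i-1}),
$$

and the approximator then applies the change of variables

$$
\log p_X(x) = \log p_Y\bigl(f(x)\bigr) + \log\bigl|\det J_f(x)\bigr|.
$$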

Reviewed Changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.

File | Description
tests/test_adapters/test_adapters.py | Updated tests to cover additional keys and Jacobian tracking.
tests/test_adapters/conftest.py | Introduced new fixtures that test the log_det_jac functionality.
bayesflow/approximators/continuous_approximator.py | Modified the forward method to incorporate and apply the change-of-variables formula.
bayesflow/adapters/transforms/* | Added or updated log_det_jac methods across multiple transform files.
bayesflow/adapters/adapter.py | Updated the forward and inverse methods to support the log_det_jac flag.

Comment on lines +76 to +82:

    if ldj is None:
        continue
    elif key in log_det_jac:
        log_det_jac[key] += ldj
    else:
        log_det_jac[key] = ldj

Copilot AI commented Apr 25, 2025
The variable 'ldj' is used outside the 'if key in data:' block, which could lead to a NameError when strict mode is False and a key is missing. Consider nesting the 'if ldj is None:' check inside the block where 'ldj' is defined.

Suggested change (nest the check inside the block where 'ldj' is defined):

    if ldj is None:
        continue
    elif key in log_det_jac:
        log_det_jac[key] += ldj
    else:
        log_det_jac[key] = ldj


    for transform in self.transforms:
-       data = transform(data, stage=stage, **kwargs)
+       transformed_data = transform(data, stage=stage, **kwargs)
+       log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
Copilot AI commented Apr 25, 2025

The log determinant is computed using the original 'data' rather than the transformed output, which may lead to inconsistencies in the contribution of each transform’s Jacobian. Consider calling log_det_jac on the output of the transform to ensure that the logged determinant matches the actual transformation.

Suggested change

-   log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
+   log_det_jac = transform.log_det_jac(transformed_data, log_det_jac, **kwargs)


Contributor:

I don't think that this is correct. The J is a function of the original data.
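A quick numerical illustration of this point (NumPy, with f(x) = log(x) standing in for an adapter transform): the log-det-Jacobian is a function of the input x, so evaluating the same formula at the transformed output gives a different, wrong value:

```python
import numpy as np

x = np.array([0.5, 2.0, 4.0])
y = np.log(x)  # forward transform

# Analytic Jacobian of f(x) = log(x) is 1/x, so log|J| = -log(x):
# a function of the *input* x, not of the output y.
ldj_at_input = -np.log(x)

# Central finite differences confirm the analytic value at the input:
eps = 1e-6
numeric_jac = (np.log(x + eps) - np.log(x - eps)) / (2 * eps)
ldj_numeric = np.log(np.abs(numeric_jac))

# Evaluating the same expression at the output instead would disagree:
ldj_at_output = -np.log(np.abs(y))
```

This is why the adapter passes the pre-transform data to log_det_jac.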
