Adapter keeps track of the transform jacobians #419
base: dev
Conversation
Thanks for the PR. Could you add tests for this behavior @Kucharssim? See e.g.,
Yes, on it!
This was accidentally closed. We will investigate how to restore the branch and reopen PRs.
Thank you for implementing this! I noted some minor concerns, mostly regarding clarity.
```python
def log_det_jac(self, data, inverse: bool = False, **kwargs) -> np.ndarray:
    if self.std is None:
        return None
```
Should this not raise an error?
You are right, but this also revealed a deeper issue with my implementation, thanks! The problem is that `log_det_jac` here was called before the `forward` pass, and since `standardize` is a stateful layer, the result of `log_det_jac` would be based on values that are out of date (here `None` the first time `standardize` is called). I flipped the order so that `forward` is always called before `log_det_jac`, but this imo also shows that it would be better to just compute the transform and the `log_det_jac` at the same time, as you pointed out earlier... 😅
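To make the ordering problem concrete, here is a minimal sketch with a hypothetical stateful layer (the class and its behavior are illustrative only, not the actual BayesFlow `standardize` implementation):

```python
import numpy as np

class ToyStandardize:
    """Illustrative only: a stateful standardization layer whose
    statistics are (re)computed during the forward pass."""

    def __init__(self):
        self.mean = None
        self.std = None

    def forward(self, x: np.ndarray) -> np.ndarray:
        # statistics are updated here, so they only exist after forward ran
        self.mean, self.std = x.mean(), x.std()
        return (x - self.mean) / self.std

    def log_det_jac(self, x: np.ndarray) -> np.ndarray:
        # for y = (x - mean) / std, dy/dx = 1 / std, so each element
        # contributes -log(std); this depends on the stored state
        if self.std is None:
            raise RuntimeError("log_det_jac called before the first forward pass")
        return np.full(x.shape[0], -np.log(self.std))

layer = ToyStandardize()
x = np.random.randn(8)

# calling log_det_jac first would raise (or, on later batches, silently
# use stale statistics); forward must run first
y = layer.forward(x)
ldj = layer.log_det_jac(x)
```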
Thank you for the changes! I think we are mostly ready to merge this as-is. Can you address the open conversation above and the comments from Copilot? Thanks!
Pull Request Overview
This PR enhances the adapter functionality by adding support for tracking the log determinant of the Jacobians across transforms. Key changes include:
- Introducing log_det_jac methods for several transforms (e.g., standardize, sqrt, scale, constrain).
- Modifying the adapter’s forward and inverse methods to optionally return the log determinant of the Jacobian (see the usage sketch after this list).
- Adding comprehensive tests to verify the correct computation and exception handling for the Jacobian tracking.
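As a rough usage sketch of the new option (the chained-transform construction and the `log_det_jac=True` keyword are assumptions inferred from the file summary below, not a verified API):

```python
import numpy as np
from bayesflow.adapters import Adapter

# hypothetical adapter with two invertible transforms on key "x"
adapter = Adapter().sqrt("x").standardize("x")

data = {"x": np.random.rand(32, 4)}

# forward pass that additionally returns the accumulated log
# determinant of the Jacobian per key (keyword assumed)
transformed, log_det_jac = adapter(data, log_det_jac=True)
```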
Reviewed Changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.
File | Description
--- | ---
tests/test_adapters/test_adapters.py | Updated tests to cover additional keys and Jacobian tracking.
tests/test_adapters/conftest.py | Introduced new fixtures that test the log_det_jac functionality.
bayesflow/approximators/continuous_approximator.py | Modified the forward method to incorporate and apply the change-of-variables formula.
bayesflow/adapters/transforms/* | Added or updated log_det_jac methods across multiple transform files.
bayesflow/adapters/adapter.py | Updated forward and inverse methods to support the log_det_jac flag.
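For context, the change-of-variables correction referenced for continuous_approximator.py is the standard identity below, stated generically (how exactly the approximator wires it in is in the diff itself):

```latex
% log-density of x under an invertible transform z = f(x),
% with Jacobian J_f evaluated at x
\log p_X(x) = \log p_Z\big(f(x)\big) + \log \lvert \det J_f(x) \rvert
```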
```python
if ldj is None:
    continue
elif key in log_det_jac:
    log_det_jac[key] += ldj
else:
    log_det_jac[key] = ldj
```
The variable `ldj` is used outside the `if key in data:` block, which could lead to a NameError when strict mode is False and a key is missing. Consider nesting the `if ldj is None:` check inside the block where `ldj` is defined.
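A sketch of the nesting Copilot suggests, with the surrounding block reconstructed from the comment above (the enclosing loop and the way `ldj` is produced are assumptions, since the excerpt does not show them):

```python
for key in keys:  # hypothetical enclosing loop
    if key in data:
        ldj = compute_ldj(data[key])  # hypothetical; ldj only exists in this branch
        if ldj is None:
            continue
        elif key in log_det_jac:
            log_det_jac[key] += ldj
        else:
            log_det_jac[key] = ldj
    # if the key is missing (strict mode False), ldj is never referenced,
    # so the NameError cannot occur
```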
```diff
 for transform in self.transforms:
-    data = transform(data, stage=stage, **kwargs)
+    transformed_data = transform(data, stage=stage, **kwargs)
+    log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
```
The log determinant is computed using the original `data` rather than the transformed output, which may lead to inconsistencies in the contribution of each transform’s Jacobian. Consider calling log_det_jac on the output of the transform to ensure that the logged determinant matches the actual transformation.
Suggested change:

```diff
-log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
+log_det_jac = transform.log_det_jac(transformed_data, log_det_jac, **kwargs)
```
I don't think that this is correct. The Jacobian is a function of the original data.
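To make that concrete, a minimal numerical check with the square-root transform (a standalone sketch, independent of the actual BayesFlow code):

```python
import numpy as np

x = np.array([4.0])
y = np.sqrt(x)  # forward transform

# d/dx sqrt(x) = 1 / (2 * sqrt(x)), so the log det Jacobian must be
# evaluated at the original input x, not at the output y
ldj_at_x = np.sum(-np.log(2.0 * np.sqrt(x)))  # -log(4), the correct value
ldj_at_y = np.sum(-np.log(2.0 * np.sqrt(y)))  # -log(2 * sqrt(2)), a different number

print(ldj_at_x, ldj_at_y)  # the two disagree: J depends on the original data
```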
fixes #245
I have not implemented it for `NumpyTransform` because I was not sure about the implementation w.r.t. serialization. I tried to make sure that if we cannot keep track of the Jacobians, we raise an error rather than returning an incorrect output.
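A sketch of that intended failure mode (the `apply` construction and the exact exception are assumptions for illustration, not the verified behavior):

```python
import numpy as np
from bayesflow.adapters import Adapter

# a custom NumPy transform for which no log_det_jac is implemented
# (construction assumed for illustration)
adapter = Adapter().apply(forward=np.log1p, inverse=np.expm1)

data = {"x": np.random.rand(16, 2)}

# asking for the Jacobian here should raise an error rather than
# silently returning an incorrect value
transformed, ldj = adapter(data, log_det_jac=True)  # expected to raise
```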