-
Notifications
You must be signed in to change notification settings - Fork 63
Subset arrays #411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
eodole
wants to merge
15
commits into
bayesflow-org:dev
Choose a base branch
from
eodole:subset_arrays
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Subset arrays #411
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
69e236d
made initial backend functions for adapter subsetting, need to still …
eodole 9c0da4c
added subsample functionality, to do would be adding them to testing …
eodole d57aee4
made the take function and ran the linter
eodole 8d834da
changed name of subsampling function
eodole 6c1d503
changed documentation, to be consistent with external notation, rathe…
eodole 2e83846
small formation change to documentation
eodole dee4534
changed subsample to have sample size and axis in the constructor
eodole 71dc35a
moved transforms in the adapter.py so they're in alphabetical order l…
eodole 6c34a5d
changed random_subsample to maptransform rather than filter transform
eodole c3640cb
updated documentation with new naming convention
eodole f17322f
added arguments of take to the constructor
eodole 5312c5f
added feature to specify a percentage of the data to subsample rather…
eodole 5361c04
changed subsample in adapter.py to allow float as an input for the sa…
eodole 504344b
renamed subsample_array and associated classes/functions to RandomSub…
eodole 4218b70
included TypeError to force users to only subsample one dataset at a …
eodole File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,3 +39,6 @@ docs/ | |
|
||
# MacOS | ||
.DS_Store | ||
|
||
# Rproj | ||
.Rproj.user | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import numpy as np | ||
from keras.saving import register_keras_serializable as serializable | ||
|
||
from .elementwise_transform import ElementwiseTransform | ||
|
||
|
||
@serializable(package="bayesflow.adapters") | ||
class RandomSubsample(ElementwiseTransform): | ||
""" | ||
A transform that takes a random subsample of the data within an axis. | ||
|
||
Example: adapter.random_subsample("x", sample_size = 3, axis = -1) | ||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
sample_size: int | float, | ||
axis: int = -1, | ||
): | ||
super().__init__() | ||
if isinstance(sample_size, float): | ||
if sample_size <= 0 or sample_size >= 1: | ||
ValueError("Sample size as a percentage must be a float between 0 and 1 exclustive. ") | ||
self.sample_size = sample_size | ||
self.axis = axis | ||
|
||
|
||
def forward(self, data: np.ndarray): | ||
|
||
axis = self.axis | ||
max_sample_size = data.shape[axis] | ||
|
||
if isinstance(self.sample_size, int): | ||
sample_size = self.sample_size | ||
else: | ||
sample_size = np.round(self.sample_size * max_sample_size) | ||
|
||
sample_indices = np.random.permutation(max_sample_size)[ | ||
0 : sample_size - 1 | ||
] # random sample without replacement | ||
|
||
return np.take(data, sample_indices, axis) | ||
|
||
def inverse(self, data, **kwargs): | ||
# non invertible transform | ||
return data |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import numpy as np | ||
from keras.saving import register_keras_serializable as serializable | ||
|
||
from .elementwise_transform import ElementwiseTransform | ||
|
||
|
||
@serializable(package="bayesflow.adapters") | ||
class Take(ElementwiseTransform): | ||
""" | ||
A transform to reduce the dimensionality of arrays output by the summary network | ||
Axis is a mandatory argument and will default to the last axis. | ||
Example: adapter.take("x", np.arange(0,3), axis=-1) | ||
|
||
""" | ||
|
||
def __init__(self,indices, axis=-1): | ||
super().__init__() | ||
self.indices = indices | ||
self.axis = axis | ||
|
||
|
||
def forward(self, data): | ||
return np.take(data, self.indices, self.axis) | ||
|
||
def inverse(self, data): | ||
# not a true invertible function | ||
return data |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,9 +30,9 @@ def check_ordering(output, axis): | |
assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}." | ||
for i in range(output.ndim): | ||
if i != axis % output.ndim: | ||
assert not np.all(np.diff(output, axis=i) > 0), ( | ||
f"is ordered along axis which is not meant to be ordered: {i}." | ||
) | ||
assert not np.all( | ||
np.diff(output, axis=i) > 0 | ||
), f"is ordered along axis which is not meant to be ordered: {i}." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure why this is being reordered now, are you running ruff version There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. apparently I was running ruff 0.8.1 but I will update it |
||
|
||
|
||
@pytest.mark.parametrize("axis", [0, 1, 2]) | ||
|
@@ -69,6 +69,6 @@ def test_positive_semi_definite(random_matrix_batch): | |
output = keras.ops.convert_to_numpy(output) | ||
eigenvalues = np.linalg.eig(output).eigenvalues | ||
|
||
assert np.all(eigenvalues.real > 0) and np.all(np.isclose(eigenvalues.imag, 0)), ( | ||
f"output is not positive semi-definite: real={eigenvalues.real}, imag={eigenvalues.imag}" | ||
) | ||
assert np.all(eigenvalues.real > 0) and np.all( | ||
np.isclose(eigenvalues.imag, 0) | ||
), f"output is not positive semi-definite: real={eigenvalues.real}, imag={eigenvalues.imag}" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am unfamiliar with R. What is this directory used for, and should all other users have it ignored too? Otherwise, please put this in your local
.git/info/exclude
instead.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
according to @stefanradev93, this should be
.Rproj