Skip to content

Subset arrays #411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,6 @@ docs/

# MacOS
.DS_Store

# Rproj
.Rproj.user
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am unfamiliar with R. What is this directory used for, and should all other users have it ignored too? Otherwise, please put this in your local .git/info/exclude instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

according to @stefanradev93, this should be .Rproj

83 changes: 82 additions & 1 deletion bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
Standardize,
ToArray,
Transform,
RandomSubsample,
Take
)
from .transforms.filter_transform import Predicate

Expand Down Expand Up @@ -541,6 +543,47 @@ def one_hot(self, keys: str | Sequence[str], num_classes: int):
transform = MapTransform({key: OneHot(num_classes=num_classes) for key in keys})
self.transforms.append(transform)
return self

def random_subsample(self,
key: str | Sequence[str],
*,
sample_size: int | float,
axis: int=-1,
**kwargs,
):
"""
Append a :py:class:`~transforms.SubsampleArray` transform to the adapter.

Parameters
----------
predicate : Predicate, optional
Function that indicates which variables should be transformed.
include : str or Sequence of str, optional
Names of variables to include in the transform.
exclude : str or Sequence of str, optional
Names of variables to exclude from the transform.
**kwargs : dict
Additional keyword arguments passed to the transform.

"""


if isinstance(key, Sequence[str]) and len(keys) >1:
TypeError("`key` should be either a string or a list of length one. Only one dataset may be modified at a time.")

if isinstance(key, str):
keys = [key]

transform = MapTransform(
transform_map={
key:RandomSubsample(sample_size=sample_size, axis=axis)
for key in keys
}

)

self.transforms.append(transform)
return self

def rename(self, from_key: str, to_key: str):
"""Append a :py:class:`~transforms.Rename` transform to the adapter.
Expand Down Expand Up @@ -601,6 +644,38 @@ def standardize(
self.transforms.append(transform)
return self

def take(self,
indices,
axis,
*,
predicate: Predicate = None,
include: str | Sequence[str] = None,
exclude: str | Sequence[str] = None,
**kwargs,):
"""
Append a :py:class:`~transforms.Take` transform to the adapter.

Parameters
----------
predicate : Predicate, optional
Function that indicates which variables should be transformed.
include : str or Sequence of str, optional
Names of variables to include in the transform.
exclude : str or Sequence of str, optional
Names of variables to exclude from the transform.
**kwargs : dict
Additional keyword arguments passed to the transform. """
transform = FilterTransform(
transform_constructor=Take(indices=indices, axis=axis),
predicate=predicate,
include=include,
exclude=exclude,
**kwargs,
)
self.transforms.append(transform)
return self


def to_array(
self,
*,
Expand All @@ -618,7 +693,7 @@ def to_array(
include : str or Sequence of str, optional
Names of variables to include in the transform.
exclude : str or Sequence of str, optional
Names of variables to exclude from the transform.
Names of variabxles to exclude from the transform.
**kwargs : dict
Additional keyword arguments passed to the transform.
"""
Expand All @@ -631,3 +706,9 @@ def to_array(
)
self.transforms.append(transform)
return self






2 changes: 2 additions & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from .standardize import Standardize
from .to_array import ToArray
from .transform import Transform
from .random_subsample import RandomSubsample
from .take import Take

from ...utils._docs import _add_imports_to_all

Expand Down
47 changes: 47 additions & 0 deletions bayesflow/adapters/transforms/random_subsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
from keras.saving import register_keras_serializable as serializable

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
class RandomSubsample(ElementwiseTransform):
"""
A transform that takes a random subsample of the data within an axis.

Example: adapter.random_subsample("x", sample_size = 3, axis = -1)

"""

def __init__(
self,
sample_size: int | float,
axis: int = -1,
):
super().__init__()
if isinstance(sample_size, float):
if sample_size <= 0 or sample_size >= 1:
ValueError("Sample size as a percentage must be a float between 0 and 1 exclustive. ")
self.sample_size = sample_size
self.axis = axis


def forward(self, data: np.ndarray):

axis = self.axis
max_sample_size = data.shape[axis]

if isinstance(self.sample_size, int):
sample_size = self.sample_size
else:
sample_size = np.round(self.sample_size * max_sample_size)

sample_indices = np.random.permutation(max_sample_size)[
0 : sample_size - 1
] # random sample without replacement

return np.take(data, sample_indices, axis)

def inverse(self, data, **kwargs):
# non invertible transform
return data
27 changes: 27 additions & 0 deletions bayesflow/adapters/transforms/take.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np
from keras.saving import register_keras_serializable as serializable

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
class Take(ElementwiseTransform):
"""
A transform to reduce the dimensionality of arrays output by the summary network
Axis is a mandatory argument and will default to the last axis.
Example: adapter.take("x", np.arange(0,3), axis=-1)

"""

def __init__(self,indices, axis=-1):
super().__init__()
self.indices = indices
self.axis = axis


def forward(self, data):
return np.take(data, self.indices, self.axis)

def inverse(self, data):
# not a true invertible function
return data
3 changes: 3 additions & 0 deletions tests/test_adapters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def adapter():
.one_hot("o1", 10)
.keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "t1", "t2", "o1"])
.rename("o1", "o2")
.random_subsample("s3", sample_size = 33, axis = 0)
.take("s3", indices = np.arange(0,32), axis = 0)
)

return d
Expand All @@ -47,4 +49,5 @@ def random_data():
"d1": np.random.standard_normal(size=(32, 2)),
"d2": np.random.standard_normal(size=(32, 2)),
"o1": np.random.randint(0, 9, size=(32, 2)),
"s3": np.random.standard_normal(size=(35,2))
}
12 changes: 6 additions & 6 deletions tests/test_links/test_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def check_ordering(output, axis):
assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}."
for i in range(output.ndim):
if i != axis % output.ndim:
assert not np.all(np.diff(output, axis=i) > 0), (
f"is ordered along axis which is not meant to be ordered: {i}."
)
assert not np.all(
np.diff(output, axis=i) > 0
), f"is ordered along axis which is not meant to be ordered: {i}."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this is being reordered now, are you running ruff version 0.11.2?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apparently I was running ruff 0.8.1 but I will update it



@pytest.mark.parametrize("axis", [0, 1, 2])
Expand Down Expand Up @@ -69,6 +69,6 @@ def test_positive_semi_definite(random_matrix_batch):
output = keras.ops.convert_to_numpy(output)
eigenvalues = np.linalg.eig(output).eigenvalues

assert np.all(eigenvalues.real > 0) and np.all(np.isclose(eigenvalues.imag, 0)), (
f"output is not positive semi-definite: real={eigenvalues.real}, imag={eigenvalues.imag}"
)
assert np.all(eigenvalues.real > 0) and np.all(
np.isclose(eigenvalues.imag, 0)
), f"output is not positive semi-definite: real={eigenvalues.real}, imag={eigenvalues.imag}"
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def test_save_and_load(tmp_path, point_inference_network, random_samples, random

for key_outer in out1.keys():
for key_inner in out1[key_outer].keys():
assert keras.ops.all(keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner])), (
"Output of original and loaded model differs significantly."
)
assert keras.ops.all(
keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner])
), "Output of original and loaded model differs significantly."


def test_copy_unequal(point_inference_network, random_samples, random_conditions):
Expand Down