Skip to content

Commit 12da6f7

Browse files
committed
add tests for adapter jacobians
1 parent 00ed2e5 commit 12da6f7

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

tests/test_adapters/conftest.py

+40
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,52 @@ def random_data():
4949
"z1": np.random.standard_normal(size=(32, 2)),
5050
"p1": np.random.lognormal(size=(32, 2)),
5151
"p2": np.random.lognormal(size=(32, 2)),
52+
"p3": np.random.lognormal(size=(32, 2)),
53+
"n1": 1 - np.random.lognormal(size=(32, 2)),
5254
"s1": np.random.standard_normal(size=(32, 3, 2)),
5355
"s2": np.random.standard_normal(size=(32, 3, 2)),
5456
"t1": np.zeros((3, 2)),
5557
"t2": np.ones((32, 3, 2)),
5658
"d1": np.random.standard_normal(size=(32, 2)),
5759
"d2": np.random.standard_normal(size=(32, 2)),
5860
"o1": np.random.randint(0, 9, size=(32, 2)),
61+
"u1": np.random.uniform(low=-1, high=2, size=(32, 1)),
5962
"key_to_split": np.random.standard_normal(size=(32, 10)),
6063
}
64+
65+
66+
@pytest.fixture()
67+
def adapter_jacobian():
68+
from bayesflow.adapters import Adapter
69+
70+
adapter = (
71+
Adapter()
72+
.scale("x1", by=2)
73+
.log("p1", p1=True)
74+
.sqrt("p2")
75+
.constrain("p3", lower=0)
76+
.constrain("n1", upper=1)
77+
.constrain("u1", lower=-1, upper=2)
78+
.concatenate(["p1", "p2", "p3"], into="p")
79+
.rename("u1", "u")
80+
)
81+
82+
return adapter
83+
84+
85+
@pytest.fixture()
86+
def adapter_jacobian_inverse():
87+
from bayesflow.adapters import Adapter
88+
89+
adapter = (
90+
Adapter()
91+
.standardize("x1", mean=1, std=2)
92+
.log("p1")
93+
.sqrt("p2")
94+
.constrain("p3", lower=0, method="log")
95+
.constrain("n1", upper=1, method="log")
96+
.constrain("u1", lower=-1, upper=2)
97+
.scale(["p1", "p2", "p3"], by=3.5)
98+
)
99+
100+
return adapter

tests/test_adapters/test_adapters.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_cycle_consistency(adapter, random_data):
1313
deprocessed = adapter(processed, inverse=True)
1414

1515
for key, value in random_data.items():
16-
if key in ["d1", "d2"]:
16+
if key in ["d1", "d2", "p3", "n1", "u1"]:
1717
# dropped
1818
continue
1919
assert key in deprocessed
@@ -230,3 +230,36 @@ def test_to_dict_transform():
230230

231231
# category should have 5 one-hot categories, even though it was only passed 4
232232
assert processed["category"].shape[-1] == 5
233+
234+
235+
def test_jacobian(adapter_jacobian, random_data):
236+
d, jacobian = adapter_jacobian(random_data, jacobian=True)
237+
238+
assert np.allclose(jacobian["x1"], np.log(2))
239+
240+
p1 = -np.log1p(random_data["p1"])
241+
p2 = -0.5 * np.log(random_data["p2"]) + 0.5
242+
p3 = random_data["p3"] - np.log(np.exp(random_data["p3"]) - 1)
243+
p = np.sum(p1, axis=-1) + np.sum(p2, axis=-1) + np.sum(p3, axis=-1)
244+
245+
assert np.allclose(jacobian["p"], p)
246+
247+
n1 = -(random_data["n1"] - 1)
248+
n1 = n1 - np.log(np.exp(n1) - 1)
249+
n1 = np.sum(n1, axis=-1)
250+
251+
assert np.allclose(jacobian["n1"], n1)
252+
253+
u1 = random_data["u1"]
254+
u1 = (u1 + 1) / 3
255+
u1 = -np.log(u1) - np.log1p(-u1) - np.log(3)
256+
257+
assert np.allclose(jacobian["u"], u1[:, 0])
258+
259+
260+
def test_jacobian_inverse(adapter_jacobian_inverse, random_data):
261+
d, forward_jacobian = adapter_jacobian_inverse(random_data, jacobian=True)
262+
d, inverse_jacobian = adapter_jacobian_inverse(d, inverse=True, jacobian=True)
263+
264+
for key in forward_jacobian.keys():
265+
assert np.allclose(forward_jacobian[key], -inverse_jacobian[key])

0 commit comments

Comments
 (0)