Skip to content

Commit 4d09e46

Browse files
committedFeb 11, 2022
tolerance
1 parent bca6004 commit 4d09e46

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed
 

‎tests/test_nnj.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from copy import deepcopy
22
from typing import Callable
33

4-
import numpy
54
import pytest
65
import torch
76

87
from stochman import nnj
98

9+
_ = torch.manual_seed(42)
10+
1011
_batch_size = 2
1112
_features = 5
1213
_dims = 6
@@ -144,7 +145,7 @@ def test_jacobians(self, model, input_shape, device, dtype):
144145
input = torch.randn(*input_shape, device=device, dtype=dtype)
145146
_, jac = model(input, jacobian=True)
146147
jacnum = _compare_jacobian(model, input).to(device)
147-
assert torch.isclose(jac, jacnum, atol=1e-5).all(), "jacobians did not match"
148+
assert torch.isclose(jac, jacnum, atol=1e-4).all(), "jacobians did not match"
148149

149150
@pytest.mark.parametrize("return_jac", [True, False])
150151
def test_jac_return(self, model, input_shape, device, return_jac):

0 commit comments

Comments
 (0)