Skip to content

Commit c5b96d9

Browse files
lucianopazricardoV94
authored andcommitted
Don't store fortran objects in ScipyGer tag
1 parent 0ebc83b commit c5b96d9

File tree

4 files changed

+24
-16
lines changed

4 files changed

+24
-16
lines changed

pytensor/tensor/blas.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -144,20 +144,20 @@
144144
# If check_init_y() == True we need to initialize y when beta == 0.
145145
def check_init_y():
146146
if check_init_y._result is None:
147-
if not have_fblas:
147+
if not have_fblas: # pragma: no cover
148148
check_init_y._result = False
149-
150-
y = float("NaN") * np.ones((2,))
151-
x = np.ones((2,))
152-
A = np.ones((2, 2))
153-
gemv = _blas_gemv_fns[y.dtype]
154-
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
155-
check_init_y._result = np.isnan(y).any()
149+
else:
150+
y = float("NaN") * np.ones((2,))
151+
x = np.ones((2,))
152+
A = np.ones((2, 2))
153+
gemv = _blas_gemv_fns[y.dtype]
154+
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
155+
check_init_y._result = np.isnan(y).any()
156156

157157
return check_init_y._result
158158

159159

160-
check_init_y._result = None
160+
check_init_y._result = None # type: ignore
161161

162162

163163
class Gemv(Op):

pytensor/tensor/blas_scipy.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,13 @@
1919

2020

2121
class ScipyGer(Ger):
22-
def prepare_node(self, node, storage_map, compute_map, impl):
23-
if impl == "py":
24-
node.tag.local_ger = _blas_ger_fns[np.dtype(node.inputs[0].type.dtype)]
25-
2622
def perform(self, node, inputs, output_storage):
2723
cA, calpha, cx, cy = inputs
2824
(cZ,) = output_storage
2925
# N.B. some versions of scipy (e.g. mine) don't actually work
3026
# in-place on a, even when I tell it to.
3127
A = cA
32-
local_ger = node.tag.local_ger
28+
local_ger = _blas_ger_fns[cA.dtype]
3329
if A.size == 0:
3430
# We don't have to compute anything, A is empty.
3531
# We need this special case because Numpy considers it

scripts/mypy-failing.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ pytensor/scalar/basic.py
1717
pytensor/sparse/basic.py
1818
pytensor/sparse/type.py
1919
pytensor/tensor/basic.py
20-
pytensor/tensor/blas.py
2120
pytensor/tensor/blas_c.py
2221
pytensor/tensor/blas_headers.py
2322
pytensor/tensor/elemwise.py
@@ -31,4 +30,4 @@ pytensor/tensor/slinalg.py
3130
pytensor/tensor/subtensor.py
3231
pytensor/tensor/type.py
3332
pytensor/tensor/type_other.py
34-
pytensor/tensor/variable.py
33+
pytensor/tensor/variable.py

tests/tensor/test_blas_scipy.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pickle
2+
13
import numpy as np
24
import pytest
35

@@ -58,6 +60,17 @@ def test_scaled_A_plus_scaled_outer(self):
5860
self.assertFunctionContains(f, gemm_no_inplace)
5961
self.run_f(f) # DebugMode tests correctness
6062

63+
def test_pickle(self):
64+
out = ScipyGer(destructive=False)(self.A, self.a, self.x, self.y)
65+
f = pytensor.function([self.A, self.a, self.x, self.y], out)
66+
new_f = pickle.loads(pickle.dumps(f))
67+
68+
assert isinstance(new_f.maker.fgraph.toposort()[-1].op, ScipyGer)
69+
assert np.allclose(
70+
f(self.Aval, 1.0, self.xval, self.yval),
71+
new_f(self.Aval, 1.0, self.xval, self.yval),
72+
)
73+
6174

6275
class TestBlasStridesScipy(TestBlasStrides):
6376
mode = pytensor.compile.get_default_mode()

0 commit comments

Comments
 (0)