
Commit cc8c499

AdvH039 authored and ricardoV94 committed
Stop using FunctionGraph and tag.test_value in linker tests
Co-authored-by: Adv <adhvaithhundi.221ds003@nitk.edu.in>
1 parent 51ea1a0 commit cc8c499


41 files changed (+1098, -1597 lines)
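The refactor applies one pattern throughout: tests no longer wrap graphs in FunctionGraph or stash numerical values on tag.test_value; they pass symbolic inputs, symbolic outputs, and plain test values directly to the comparison helpers. A hedged before/after sketch of the pattern (imports elided; a_val and out are illustrative placeholders, not names from the diff):

    # Before: wrap the graph and read values back off variable tags
    a = tensor3("a")
    a.tag.test_value = a_val
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # After: pass symbolic inputs/outputs and numerical test values directly
    a = tensor3("a")
    a_test_value = a_val
    compare_jax_and_py([a], [out], [a_test_value])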

tests/link/jax/test_basic.py

Lines changed: 24 additions & 27 deletions
@@ -7,12 +7,12 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
 from pytensor.compile.mode import JAX, Mode
-from pytensor.compile.sharedvalue import SharedVariable, shared
+from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.graph import RewriteDatabaseQuery
-from pytensor.graph.basic import Apply
+from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import Op, get_test_value
+from pytensor.graph.op import Op
 from pytensor.ifelse import ifelse
 from pytensor.link.jax import JAXLinker
 from pytensor.raise_op import assert_op
@@ -34,25 +34,28 @@ def set_pytensor_flags():
 
 
 def compare_jax_and_py(
-    fgraph: FunctionGraph,
+    graph_inputs: Iterable[Variable],
+    graph_outputs: Variable | Iterable[Variable],
     test_inputs: Iterable,
+    *,
     assert_fn: Callable | None = None,
     must_be_device_array: bool = True,
     jax_mode=jax_mode,
     py_mode=py_mode,
 ):
-    """Function to compare python graph output and jax compiled output for testing equality
+    """Function to compare python function output and jax compiled output for testing equality
 
-    In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
-    this function which then compiles the graphs in both jax and python, runs the calculation
-    in both and checks if the results are the same
+    The inputs and outputs are then passed to this function which then compiles the given function in both
+    jax and python, runs the calculation in both and checks if the results are the same
 
     Parameters
     ----------
-    fgraph: FunctionGraph
-        PyTensor function Graph object
+    graph_inputs:
+        Symbolic inputs to the graph
+    outputs:
+        Symbolic outputs of the graph
     test_inputs: iter
-        Numerical inputs for testing the function graph
+        Numerical inputs for testing the function.
     assert_fn: func, opt
         Assert function used to check for equality between python and jax. If not
         provided uses np.testing.assert_allclose
@@ -68,8 +71,10 @@ def compare_jax_and_py(
     if assert_fn is None:
         assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
 
-    fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
-    pytensor_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
+    if any(inp.owner is not None for inp in graph_inputs):
+        raise ValueError("Inputs must be root variables")
+
+    pytensor_jax_fn = function(graph_inputs, graph_outputs, mode=jax_mode)
     jax_res = pytensor_jax_fn(*test_inputs)
 
     if must_be_device_array:
@@ -78,10 +83,10 @@
     else:
         assert isinstance(jax_res, jax.Array)
 
-    pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
+    pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
     py_res = pytensor_py_fn(*test_inputs)
 
-    if len(fgraph.outputs) > 1:
+    if isinstance(graph_outputs, list | tuple):
         for j, p in zip(jax_res, py_res, strict=True):
             assert_fn(j, p)
     else:
@@ -187,16 +192,14 @@ def test_jax_ifelse():
     false_vals = np.r_[-1, -2, -3]
 
     x = ifelse(np.array(True), true_vals, false_vals)
-    x_fg = FunctionGraph([], [x])
 
-    compare_jax_and_py(x_fg, [])
+    compare_jax_and_py([], [x], [])
 
     a = dscalar("a")
-    a.tag.test_value = np.array(0.2, dtype=config.floatX)
+    a_test = np.array(0.2, dtype=config.floatX)
     x = ifelse(a < 0.5, true_vals, false_vals)
-    x_fg = FunctionGraph([a], [x])  # I.e. False
 
-    compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
+    compare_jax_and_py([a], [x], [a_test])
 
 
 def test_jax_checkandraise():
@@ -209,22 +212,16 @@ def test_jax_checkandraise():
     function((p,), res, mode=jax_mode)
 
 
-def set_test_value(x, v):
-    x.tag.test_value = v
-    return x
-
-
 def test_OpFromGraph():
     x, y, z = matrices("xyz")
     ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
     ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)
 
     o1, o2 = ofg_2(y, z)
     out = ofg_1(x, o1) + o2
-    out_fg = FunctionGraph([x, y, z], [out])
 
     xv = np.ones((2, 2), dtype=config.floatX)
     yv = np.ones((2, 2), dtype=config.floatX) * 3
     zv = np.ones((2, 2), dtype=config.floatX) * 5
 
-    compare_jax_and_py(out_fg, [xv, yv, zv])
+    compare_jax_and_py([x, y, z], [out], [xv, yv, zv])
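Under the new signature, graph_inputs must be root variables (no owner); derived expressions are only valid as outputs, and the helper raises otherwise. A minimal usage sketch, assuming a PyTensor checkout with JAX installed (the variables here are illustrative, not part of the commit):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.configdefaults import config
    from tests.link.jax.test_basic import compare_jax_and_py

    x = pt.vector("x")                          # root variable: valid input
    out = pt.exp(x) + 1                         # derived expression: valid output
    x_test = np.arange(3, dtype=config.floatX)  # plain numerical test value

    compare_jax_and_py([x], [out], [x_test])
    # Passing a derived expression as an input, e.g. compare_jax_and_py([out], ...),
    # raises ValueError("Inputs must be root variables").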

tests/link/jax/test_blas.py

Lines changed: 5 additions & 8 deletions
@@ -4,8 +4,6 @@
 from pytensor.compile.function import function
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
-from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.link.jax import JAXLinker
 from pytensor.tensor import blas as pt_blas
@@ -16,21 +14,20 @@
 def test_jax_BatchedDot():
     # tensor3 . tensor3
     a = tensor3("a")
-    a.tag.test_value = (
+    a_test_value = (
         np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
     )
     b = tensor3("b")
-    b.tag.test_value = (
+    b_test_value = (
         np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
     )
     out = pt_blas.BatchedDot()(a, b)
-    fgraph = FunctionGraph([a, b], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py([a, b], [out], [a_test_value, b_test_value])
 
     # A dimension mismatch should raise a TypeError for compatibility
-    inputs = [get_test_value(a)[:-1], get_test_value(b)]
+    inputs = [a_test_value[:-1], b_test_value]
     opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
     jax_mode = Mode(JAXLinker(), opts)
-    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
+    pytensor_jax_fn = function([a, b], [out], mode=jax_mode)
     with pytest.raises(TypeError):
         pytensor_jax_fn(*inputs)

tests/link/jax/test_blockwise.py

Lines changed: 1 addition & 3 deletions
@@ -2,7 +2,6 @@
 import pytest
 
 from pytensor import config
-from pytensor.graph import FunctionGraph
 from pytensor.tensor import tensor
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.math import Dot, matmul
@@ -32,8 +31,7 @@ def test_matmul(matmul_op):
 
     out = matmul_op(a, b)
     assert isinstance(out.owner.op, Blockwise)
-    fg = FunctionGraph([a, b], [out])
-    fn, _ = compare_jax_and_py(fg, test_values)
+    fn, _ = compare_jax_and_py([a, b], [out], test_values)
 
     # Check we are not adding any unnecessary stuff
     jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))

tests/link/jax/test_einsum.py

Lines changed: 2 additions & 5 deletions
@@ -2,7 +2,6 @@
 import pytest
 
 import pytensor.tensor as pt
-from pytensor.graph import FunctionGraph
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
@@ -22,8 +21,7 @@ def test_jax_einsum():
     }
     x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
     out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
-    fg = FunctionGraph([x_pt, y_pt, z_pt], [out])
-    compare_jax_and_py(fg, [x, y, z])
+    compare_jax_and_py([x_pt, y_pt, z_pt], [out], [x, y, z])
 
 
 def test_ellipsis_einsum():
@@ -34,5 +32,4 @@ def test_ellipsis_einsum():
     x_pt = pt.tensor("x", shape=x.shape)
     y_pt = pt.tensor("y", shape=y.shape)
     out = pt.einsum(subscripts, x_pt, y_pt)
-    fg = FunctionGraph([x_pt, y_pt], [out])
-    compare_jax_and_py(fg, [x, y])
+    compare_jax_and_py([x_pt, y_pt], [out], [x, y])

tests/link/jax/test_elemwise.py

Lines changed: 23 additions & 33 deletions
@@ -6,8 +6,6 @@
 import pytensor.tensor as pt
 from pytensor.compile import get_mode
 from pytensor.configdefaults import config
-from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
 from pytensor.tensor import elemwise as pt_elemwise
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.math import prod
@@ -26,87 +24,81 @@ def test_jax_Dimshuffle():
     a_pt = matrix("a")
 
     x = a_pt.T
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py(
+        [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
+    )
 
     x = a_pt.dimshuffle([0, 1, "x"])
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py(
+        [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
+    )
 
     a_pt = tensor(dtype=config.floatX, shape=(None, 1))
     x = a_pt.dimshuffle((0,))
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
 
     a_pt = tensor(dtype=config.floatX, shape=(None, 1))
     x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
 
 
 def test_jax_CAReduce():
     a_pt = vector("a")
     a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
 
     x = pt_sum(a_pt, axis=None)
-    x_fg = FunctionGraph([a_pt], [x])
 
-    compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)])
 
     a_pt = matrix("a")
     a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
 
     x = pt_sum(a_pt, axis=0)
-    x_fg = FunctionGraph([a_pt], [x])
 
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
 
     x = pt_sum(a_pt, axis=1)
-    x_fg = FunctionGraph([a_pt], [x])
 
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
 
     a_pt = matrix("a")
     a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
 
     x = prod(a_pt, axis=0)
-    x_fg = FunctionGraph([a_pt], [x])
 
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
 
     x = pt_all(a_pt)
-    x_fg = FunctionGraph([a_pt], [x])
 
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
 
 
 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_softmax(axis):
     x = matrix("x")
-    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
     out = softmax(x, axis=axis)
-    fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py([x], [out], [x_test_value])
 
 
 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_logsoftmax(axis):
     x = matrix("x")
-    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
     out = log_softmax(x, axis=axis)
-    fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    compare_jax_and_py([x], [out], [x_test_value])
 
 
 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_softmax_grad(axis):
     dy = matrix("dy")
-    dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+    dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
     sm = matrix("sm")
-    sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
     out = SoftmaxGrad(axis=axis)(dy, sm)
-    fgraph = FunctionGraph([dy, sm], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    compare_jax_and_py([dy, sm], [out], [dy_test_value, sm_test_value])
 
 
 @pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
 def test_multiple_input_multiply():
     x, y, z = vectors("xyz")
     out = pt.mul(x, y, z)
-
-    fg = FunctionGraph(outputs=[out], clone=False)
-    compare_jax_and_py(fg, [[1.5], [2.5], [3.5]])
+    compare_jax_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]])
