
Commit 65826e7

ricardoV94 authored and lucianopaz committed
Handle invalid BroadcastTo shape in C backend
1 parent 24b67a8 commit 65826e7

File tree

2 files changed: +62 −20


pytensor/tensor/extra_ops.py

+28 −6
@@ -1643,6 +1643,11 @@ def make_node(self, a, *shape):
 
         shape, static_shape = at.infer_static_shape(shape)
 
+        if len(shape) < a.ndim:
+            raise ValueError(
+                f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
+            )
+
         out = TensorType(dtype=a.type.dtype, shape=static_shape)()
 
         # Attempt to prevent in-place operations on this view-based output
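The new make_node guard turns a rank-reducing broadcast into a build-time error. A minimal sketch of the resulting behavior (mirroring test_make_node_error_handling added below; at is the pytensor.tensor alias already used in this file):

    import pytensor.tensor as at
    from pytensor.tensor.extra_ops import broadcast_to

    x = at.zeros((3, 4))        # symbolic 2-d input
    broadcast_to(x, (2, 3, 4))  # fine: target rank >= input rank
    broadcast_to(x, (5,))       # ValueError: "Broadcast target shape has 1 dims,
                                #  which is shorter than input with 2 dims"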
@@ -1686,9 +1691,12 @@ def infer_shape(self, fgraph, node, ins_shapes):
         return [node.inputs[1:]]
 
     def c_code(self, node, name, inputs, outputs, sub):
+        inp_dims = node.inputs[0].ndim
+        out_dims = node.outputs[0].ndim
+        new_dims = out_dims - inp_dims
+
         (x, *shape) = inputs
         (out,) = outputs
-        ndims = len(shape)
         fail = sub["fail"]
 
         # TODO: Could just use `PyArray_Return`, no?
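new_dims counts the leading axes that broadcasting prepends to the input; shapes align on the right as in NumPy, so input axis i corresponds to output axis i + new_dims. A plain-Python sketch of that index arithmetic, with made-up shapes:

    inp_shape = (3, 1, 7)                        # inp_dims = 3
    out_shape = (2, 6, 3, 4, 7)                  # out_dims = 5
    new_dims = len(out_shape) - len(inp_shape)   # 2 new leading axes
    for i, d in enumerate(inp_shape):
        # the compatibility rule the generated C code enforces below
        assert d == 1 or d == out_shape[i + new_dims]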
@@ -1701,20 +1709,34 @@ def c_code(self, node, name, inputs, outputs, sub):
 
         src = (
             """
-            npy_intp itershape[%(ndims)s] = {%(dims_array)s};
+            npy_intp itershape[%(out_dims)s] = {%(dims_array)s};
 
+            NpyIter *iter;
             PyArrayObject *ops[1] = {%(x)s};
             npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
             npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
             PyArray_Descr *op_dtypes[1] = {NULL};
-            int oa_ndim = %(ndims)s;
+            int oa_ndim = %(out_dims)s;
             int* op_axes[1] = {NULL};
             npy_intp buffersize = 0;
 
-            NpyIter *iter = NpyIter_AdvancedNew(
+            for(int i = 0; i < %(inp_dims)s; i++)
+            {
+                if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
+                {
+                    PyErr_Format(PyExc_ValueError,
+                        "Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
+                        i,
+                        (long long int) itershape[i + %(new_dims)s],
+                        (long long int) PyArray_DIMS(%(x)s)[i]
+                    );
+                    %(fail)s
+                }
+            }
+
+            iter = NpyIter_AdvancedNew(
                 1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
             );
-
             %(out)s = NpyIter_GetIterView(iter, 0);
 
             if(%(out)s == NULL){
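The inserted loop applies NumPy's per-dimension broadcasting rule at run time, before the iterator is created: each input dimension must be 1 or equal to the right-aligned target dimension, otherwise the op raises ValueError instead of handing an inconsistent itershape to NpyIter_AdvancedNew. This is the NumPy reference behavior it matches:

    import numpy as np

    x = np.ones((3, 4))
    np.broadcast_to(x, (2, 3, 4))  # ok: 3 and 4 match, one new leading axis
    np.broadcast_to(x, (2, 5, 4))  # ValueError: 3 is neither 1 nor 5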
@@ -1733,7 +1755,7 @@ def c_code(self, node, name, inputs, outputs, sub):
         return src
 
     def c_code_cache_version(self):
-        return (1,)
+        return (2,)
 
 
 broadcast_to_ = BroadcastTo()
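Note: bumping c_code_cache_version from (1,) to (2,) invalidates previously compiled modules for this op, so the new C code with the shape check is recompiled rather than served from PyTensor's C-code cache.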

tests/tensor/test_extra_ops.py

+34 −14
@@ -1253,41 +1253,52 @@ def test_avoid_useless_subtensors(self):
     @pytest.mark.parametrize("linker", ["cvm", "py"])
     def test_perform(self, linker):
 
-        a = pytensor.shared(5)
+        a = pytensor.shared(np.full((3, 1, 1), 5))
+        s_0 = iscalar("s_0")
         s_1 = iscalar("s_1")
-        shape = (s_1, 1)
+        shape = (s_0, s_1, 1)
 
         bcast_res = broadcast_to(a, shape)
-        assert bcast_res.broadcastable == (False, True)
+        assert bcast_res.broadcastable == (False, False, True)
 
         bcast_fn = pytensor.function(
-            [s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
+            [s_0, s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
         )
         bcast_fn.vm.allow_gc = False
 
-        bcast_at = bcast_fn(4)
-        bcast_np = np.broadcast_to(5, (4, 1))
+        bcast_at = bcast_fn(3, 4)
+        bcast_np = np.broadcast_to(5, (3, 4, 1))
 
         assert np.array_equal(bcast_at, bcast_np)
 
-        bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
-        bcast_in = bcast_fn.vm.storage_map[a]
-        bcast_out = bcast_fn.vm.storage_map[bcast_var]
+        with pytest.raises(ValueError):
+            bcast_fn(5, 4)
 
         if linker != "py":
+            bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
+            bcast_in = bcast_fn.vm.storage_map[a]
+            bcast_out = bcast_fn.vm.storage_map[bcast_var]
             assert np.shares_memory(bcast_out[0], bcast_in[0])
 
+    def test_make_node_error_handling(self):
+        with pytest.raises(
+            ValueError,
+            match="Broadcast target shape has 1 dims, which is shorter than input with 2 dims",
+        ):
+            broadcast_to(at.zeros((3, 4)), (5,))
+
     @pytest.mark.skipif(
         not config.cxx, reason="G++ not available, so we need to skip this test."
     )
-    def test_memory_leak(self):
+    @pytest.mark.parametrize("valid", (True, False))
+    def test_memory_leak(self, valid):
         import gc
         import tracemalloc
 
         from pytensor.link.c.cvm import CVM
 
         n = 100_000
-        x = pytensor.shared(np.ones(n, dtype=np.float64))
+        x = pytensor.shared(np.ones((1, n), dtype=np.float64))
         y = broadcast_to(x, (5, n))
 
         f = pytensor.function([], y, mode=Mode(optimizer=None, linker="cvm"))
13031314
blocks_last = None
13041315
block_diffs = []
13051316
for i in range(1, 50):
1306-
x.set_value(np.ones(n))
1307-
_ = f()
1317+
if valid:
1318+
x.set_value(np.ones((1, n)))
1319+
_ = f()
1320+
else:
1321+
x.set_value(np.ones((2, n)))
1322+
try:
1323+
_ = f()
1324+
except ValueError:
1325+
pass
1326+
else:
1327+
raise RuntimeError("Should have failed")
13081328
_ = gc.collect()
13091329
blocks_i, _ = tracemalloc.get_traced_memory()
13101330
if blocks_last is not None:
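The leak test now also drives the failing branch, so the error path in the generated C code is checked for leaks as well. The measurement is plain tracemalloc bookkeeping; a condensed sketch of the pattern, with run_once as a hypothetical stand-in for one valid or failing call of f:

    import gc
    import tracemalloc

    tracemalloc.start()
    blocks_last, block_diffs = None, []
    for _ in range(50):
        run_once()  # hypothetical: one call of f(), possibly raising ValueError
        gc.collect()
        blocks_i, _ = tracemalloc.get_traced_memory()
        if blocks_last is not None:
            block_diffs.append(blocks_i - blocks_last)
        blocks_last = blocks_i
    tracemalloc.stop()
    assert all(d <= 0 for d in block_diffs)  # traced memory never grows between iterations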
@@ -1313,7 +1333,7 @@ def test_memory_leak(self):
             blocks_last = blocks_i
 
         tracemalloc.stop()
-        assert np.allclose(np.mean(block_diffs), 0)
+        assert np.all(np.array(block_diffs) <= (0 + 1e-8))
 
     @pytest.mark.parametrize(
         "fn,input_dims",
