Skip to content

Commit df568d9

Browse files
committed
Handle invalid BroadcastTo shape in C backend
1 parent db18c97 commit df568d9

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

pytensor/tensor/extra_ops.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -1643,6 +1643,9 @@ def make_node(self, a, *shape):
16431643

16441644
shape, static_shape = at.infer_static_shape(shape)
16451645

1646+
if len(shape) < a.ndim:
1647+
raise ValueError("Broadcast shape cannot be shorter than input shape")
1648+
16461649
out = TensorType(dtype=a.type.dtype, shape=static_shape)()
16471650

16481651
# Attempt to prevent in-place operations on this view-based output
@@ -1715,6 +1718,23 @@ def c_code(self, node, name, inputs, outputs, sub):
17151718
1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
17161719
);
17171720
1721+
1722+
int offset = %(ndims)s - PyArray_NDIM(%(x)s);
1723+
1724+
for(int i = 0; i < PyArray_NDIM(%(x)s); i++)
1725+
{
1726+
if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + offset]))
1727+
{
1728+
PyErr_Format(PyExc_ValueError,
1729+
"Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
1730+
i,
1731+
(long long int) itershape[i + offset],
1732+
(long long int) PyArray_DIMS(%(x)s)[i]
1733+
);
1734+
%(fail)s
1735+
}
1736+
}
1737+
17181738
%(out)s = NpyIter_GetIterView(iter, 0);
17191739
17201740
if(%(out)s == NULL){
@@ -1733,7 +1753,7 @@ def c_code(self, node, name, inputs, outputs, sub):
17331753
return src
17341754

17351755
def c_code_cache_version(self):
1736-
return (1,)
1756+
return (2,)
17371757

17381758

17391759
broadcast_to_ = BroadcastTo()

tests/tensor/test_extra_ops.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -1253,30 +1253,39 @@ def test_avoid_useless_subtensors(self):
12531253
@pytest.mark.parametrize("linker", ["cvm", "py"])
12541254
def test_perform(self, linker):
12551255

1256-
a = pytensor.shared(5)
1256+
a = pytensor.shared(np.full((3, 1, 1), 5))
1257+
s_0 = iscalar("s_0")
12571258
s_1 = iscalar("s_1")
1258-
shape = (s_1, 1)
1259+
shape = (s_0, s_1, 1)
12591260

12601261
bcast_res = broadcast_to(a, shape)
1261-
assert bcast_res.broadcastable == (False, True)
1262+
assert bcast_res.broadcastable == (False, False, True)
12621263

12631264
bcast_fn = pytensor.function(
1264-
[s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
1265+
[s_0, s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
12651266
)
12661267
bcast_fn.vm.allow_gc = False
12671268

1268-
bcast_at = bcast_fn(4)
1269-
bcast_np = np.broadcast_to(5, (4, 1))
1269+
bcast_at = bcast_fn(3, 4)
1270+
bcast_np = np.broadcast_to(5, (3, 4, 1))
12701271

12711272
assert np.array_equal(bcast_at, bcast_np)
12721273

1273-
bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
1274-
bcast_in = bcast_fn.vm.storage_map[a]
1275-
bcast_out = bcast_fn.vm.storage_map[bcast_var]
1274+
with pytest.raises(ValueError):
1275+
bcast_fn(5, 4)
12761276

12771277
if linker != "py":
1278+
bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
1279+
bcast_in = bcast_fn.vm.storage_map[a]
1280+
bcast_out = bcast_fn.vm.storage_map[bcast_var]
12781281
assert np.shares_memory(bcast_out[0], bcast_in[0])
12791282

1283+
def test_make_node_error_handling(self):
1284+
with pytest.raises(
1285+
ValueError, match="Broadcast shape cannot be shorter than input shape"
1286+
):
1287+
broadcast_to(at.zeros((3, 4)), (5,))
1288+
12801289
@pytest.mark.skipif(
12811290
not config.cxx, reason="G++ not available, so we need to skip this test."
12821291
)

0 commit comments

Comments
 (0)