Commit 8454c3b

Fix bug in ScanSaveMem with broadcasted initial value

1 parent 249dfae commit 8454c3b
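
For context, a minimal sketch of the kind of graph that triggered the bug, distilled from the regression test added below (this snippet is illustrative, not part of the commit):

import numpy as np
import pytensor
import pytensor.tensor as pt

# The initial state is a broadcasted alloc: a (possibly 0d) value filled
# into the two initial taps.
val = pt.scalar("val")
init = pt.full((2,), val)

ys, _ = pytensor.scan(
    fn=lambda ytm2, ytm1: ytm1 + ytm2,
    outputs_info=[{"initial": init, "taps": (-2, -1)}],
    n_steps=100,
)

# Only the last 50 steps are used, so ScanSaveMem shrinks the storage buffer.
# Before this fix the rewrite could slice the unbroadcasted `val` instead of
# the broadcasted buffer, producing a wrong graph.
fn = pytensor.function([val], ys[-50:])
assert fn(0.0).shape == (50,)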

2 files changed (+75, -70 lines)

pytensor/scan/rewriting.py (+59, -63)
@@ -58,7 +58,11 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Dot, dot, maximum, minimum
-from pytensor.tensor.rewriting.basic import constant_folding, local_useless_switch
+from pytensor.tensor.rewriting.basic import (
+    broadcasted_by,
+    constant_folding,
+    local_useless_switch,
+)
 from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
 from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
 from pytensor.tensor.shape import shape
@@ -1183,6 +1187,34 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
     return subtensor_merge_replacements


+def _is_default_scan_buffer(x: TensorVariable) -> bool:
+    node = x.owner
+
+    if node is None:
+        return False
+
+    op = node.op
+    if not (
+        isinstance(op, IncSubtensor)
+        and op.set_instead_of_inc
+        and op.idx_list == [slice(None, ps.int64)]
+    ):
+        return False
+
+    x, y, *_ = node.inputs
+    if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)):
+        return False
+
+    # The value may have been broadcast to fill in the initial taps.
+    # This check is easier than trying to recreate the subtensor that is being set
+    # and checking if that may broadcast the update value, which is what we care about.
+    # TODO: This check is poorly thought out
+    if broadcasted_by(y, x):
+        return False
+
+    return True
+
+
 def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
     r"""Graph optimizer that reduces scan memory consumption.

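For reference, a rough sketch of the "default Scan buffer" pattern this helper matches, written with public PyTensor ops (`init` and `n_steps` are illustrative names, not from the source):

import pytensor.tensor as pt

init = pt.vector("init")  # the user-provided initial taps
n_steps = 100

# expand_empty builds an AllocEmpty buffer whose leading rows are set to the
# initial state through a set_subtensor (IncSubtensor with set_instead_of_inc):
buf = pt.empty((init.shape[0] + n_steps,))
buf = pt.set_subtensor(buf[: init.shape[0]], init)

# The broadcasted_by(y, x) guard rejects buffers where the value being set is
# broadcast by the buffer, as in pt.set_subtensor(pt.empty((2,))[:], pt.constant(0.0)):
# there the scalar fill value is not a safe stand-in for the buffer's leading rows.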
@@ -1523,51 +1555,30 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:

         # 3.2 check orphane outputs to see if we can eliminate any
         required, not_required = scan_can_remove_outs(node.op, orphane_outs)
-        # 3.3. compose replace pairs for those nodes that need not
-        # to store everything in memory ( or ar orphane and required
-        # by the inner function .. )
+
+        # 3.3. compose replace pairs for those nodes that need not store everything in memory
+        # (or are orphan but required by the inner function)
         replaced_outs = []
         offset = 1 + op_info.n_seqs + op_info.n_mit_mot
-        for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
+        for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
             i = idx + op_info.n_mit_mot
-            if not (isinstance(_val, int) and _val <= 0 and i not in required):
-                if idx + op_info.n_mit_mot in required:
-                    val = 1
-                else:
-                    val = _val
+            if not (isinstance(val, int) and val <= 0 and i not in required):
+                required_orphan = idx + op_info.n_mit_mot in required
                 # If the memory for this output has been pre-allocated
                 # before going into the scan op (by an alloc node)
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
-                    # In case the input is still an alloc node, we
-                    # actually have two options:
-                    #   a) the input is a set_subtensor, in that case we
-                    #      can replace the initial tensor by a slice,
-                    #   b) it is not, and we simply take a slice of it.
-                    # TODO: commit change below with Razvan
-                    if (
-                        nw_inputs[offset + idx].owner
-                        and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
-                        and nw_inputs[offset + idx].owner.op.set_instead_of_inc
-                        and isinstance(
-                            nw_inputs[offset + idx].owner.op.idx_list[0], slice
-                        )
-                        # Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
-                        # As it happens in set_subtensor(empty(2)[:], 0)
-                        and not (
-                            nw_inputs[offset + idx].ndim
-                            > nw_inputs[offset + idx].owner.inputs[1].ndim
-                        )
-                    ):
-                        _nw_input = nw_inputs[offset + idx].owner.inputs[1]
-                        cval = pt.as_tensor_variable(val)
-                        initl = pt.as_tensor_variable(init_l[i])
-                        tmp_idx = pt.switch(cval < initl, cval + initl, cval - initl)
-                        nw_input = expand_empty(_nw_input, tmp_idx)
+                    nw_input = nw_inputs[offset + idx]
+
+                    # Check if the input looks like a default pre-allocated Scan buffer
+                    # created via `expand_empty`, which looks like empty(...)[:init.shape[0]].set(init)
+                    # If so, we can just recreate the pre-allocated buffer with a smaller size
+                    if _is_default_scan_buffer(nw_input):
+                        extra_size = 1 if required_orphan else val - init_l[i]
+                        nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
+                    # Otherwise, just trim the buffer with a slice
                     else:
-                        tmp = pt.as_tensor_variable(val)
-                        initl = pt.as_tensor_variable(init_l[i])
-                        tmp = maximum(tmp, initl)
-                        nw_input = nw_inputs[offset + idx][:tmp]
+                        stop = init_l[i] if required_orphan else val
+                        nw_input = nw_input[:stop]

                     nw_inputs[offset + idx] = nw_input
                     replaced_outs.append(op_info.n_mit_mot + idx)
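
To make the new size arithmetic concrete, a worked example with made-up numbers (not taken from the source):

# A SIT-SOT state with init_l[i] == 2 initial taps of which only the last
# val == 50 entries are ever read. The rewrite rebuilds the default buffer as
# expand_empty(init, extra_size), so initial taps + extra rows must equal val:
val, n_init_taps = 50, 2
extra_size = val - n_init_taps  # 48; the new buffer holds 2 + 48 = 50 rows
assert n_init_taps + extra_size == val
# A required orphan only ever stores one step, hence extra_size = 1 in that case.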
@@ -1591,7 +1602,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                         + op_info.n_shared_outs
                     )
                     if nw_inputs[pos] == node.inputs[0]:
-                        nw_inputs[pos] = val
+                        nw_inputs[pos] = 1 if required_orphan else val
                     odx = op_info.n_mit_mot + idx
                     replaced_outs.append(odx)
                     old_outputs += [
@@ -1603,37 +1614,22 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                             ],
                         )
                     ]
-        # 3.4. Recompute inputs for everything else based on the new
-        # number of steps
+        # 3.4. Recompute inputs for everything else based on the new number of steps
         if global_nsteps is not None:
             for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
                 if val == 0:
                     # val == 0 means that we want to keep all intermediate
                     # results for that state, including the initial values.
                     if idx < op_info.n_mit_sot + op_info.n_sit_sot:
                         in_idx = offset + idx
-                        # Number of steps in the initial state
-                        initl = init_l[op_info.n_mit_mot + idx]
-
-                        # If the initial buffer has the form
-                        # inc_subtensor(zeros(...)[...], _nw_input)
-                        # we want to make the zeros tensor as small as
-                        # possible (nw_steps + initl), and call
-                        # inc_subtensor on that instead.
-                        # Otherwise, simply take 0:(nw_steps+initl).
-                        if (
-                            nw_inputs[in_idx].owner
-                            and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
-                            and isinstance(
-                                nw_inputs[in_idx].owner.op.idx_list[0], slice
-                            )
-                        ):
-                            _nw_input = nw_inputs[in_idx].owner.inputs[1]
-                            nw_input = expand_empty(_nw_input, nw_steps)
-                            nw_inputs[in_idx] = nw_input
+                        nw_input = nw_inputs[in_idx]
+                        if _is_default_scan_buffer(nw_input):
+                            nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
                         else:
-                            # FIXME: This is never used
-                            nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
+                            # Number of steps in the initial state
+                            init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
+                            nw_input = nw_input[: (init_l_pt + nw_steps)]
+                        nw_inputs[in_idx] = nw_input

                 elif (
                     idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
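
Both branches lean on `expand_empty` from pytensor.scan.utils; a small sketch of what it does:

import pytensor.tensor as pt
from pytensor.scan.utils import expand_empty

init = pt.vector("init")
# Pads `init` with 100 uninitialized rows along axis 0: the result has shape
# (init.shape[0] + 100,) with its head set to `init`, which is exactly the
# empty(...)[:init.shape[0]].set(init) pattern that _is_default_scan_buffer detects.
buf = expand_empty(init, 100)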

tests/scan/test_rewriting.py (+16, -7)
@@ -1634,21 +1634,30 @@ def test_while_scan_taps_and_map(self):
         assert stored_ys_steps == 2
         assert stored_zs_steps == 1

-    def test_vector_zeros_init(self):
+    @pytest.mark.parametrize("val_ndim", (0, 1))
+    @pytest.mark.parametrize("keep_beginning", (False, True))
+    def test_broadcasted_init(self, keep_beginning, val_ndim):
+        # Regression test when the original value is a broadcasted alloc
+        # The scan save mem rewrite used to wrongly slice on the unbroadcasted value
+        val_shape = (1,) * val_ndim
+        val = pt.tensor("val", shape=val_shape)
+
+        init = pt.full((2,), val)
         ys, _ = pytensor.scan(
-            fn=lambda ytm2, ytm1: ytm1 + ytm2,
-            outputs_info=[{"initial": pt.zeros(2), "taps": range(-2, 0)}],
+            fn=lambda *args: pt.add(*args),
+            outputs_info=[{"initial": init, "taps": (-2, -1)}],
             n_steps=100,
         )

-        fn = pytensor.function([], ys[-50:], mode=self.mode)
-        assert tuple(fn().shape) == (50,)
+        out = ys[:-50] if keep_beginning else ys[-50:]
+        fn = pytensor.function([val], out, mode=self.mode)
+        assert fn(np.zeros(val_shape)).shape == (50,)

         # Check that rewrite worked
         [scan_node] = (n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
         _, ys_trace = scan_node.inputs
-        debug_fn = pytensor.function([], ys_trace.shape[0], accept_inplace=True)
-        assert debug_fn() == 50
+        debug_fn = pytensor.function([val], ys_trace.shape[0], accept_inplace=True)
+        assert debug_fn(np.zeros(val_shape)) == (52 if keep_beginning else 50)


 def test_inner_replace_dot():
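
The parametrization covers trimming from either end of the buffer and both scalar and vector fill values; the new test can be run on its own with a standard pytest selector:

pytest tests/scan/test_rewriting.py -k test_broadcasted_init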
