from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, dot, maximum, minimum
- from pytensor.tensor.rewriting.basic import constant_folding, local_useless_switch
+ from pytensor.tensor.rewriting.basic import (
+     broadcasted_by,
+     constant_folding,
+     local_useless_switch,
+ )
from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from pytensor.tensor.shape import shape
@@ -1183,6 +1187,34 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
    return subtensor_merge_replacements


+ def _is_default_scan_buffer(x: TensorVariable) -> bool:
+     node = x.owner
+
+     if node is None:
+         return False
+
+     op = node.op
+     if not (
+         isinstance(op, IncSubtensor)
+         and op.set_instead_of_inc
+         and op.idx_list == [slice(None, ps.int64)]
+     ):
+         return False
+
+     x, y, *_ = node.inputs
+     if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)):
+         return False
+
+     # The value may have been broadcast to fill in the initial taps.
+     # This check is easier than trying to recreate the subtensor that is being set
+     # and checking if that may broadcast the update value, which is what we care about.
+     # TODO: This check is poorly thought out
+     if broadcasted_by(y, x):
+         return False
+
+     return True
+
+
def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
    r"""Graph optimizer that reduces scan memory consumption.

@@ -1523,51 +1555,30 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:

        # 3.2 check orphane outputs to see if we can eliminate any
        required, not_required = scan_can_remove_outs(node.op, orphane_outs)
-         # 3.3. compose replace pairs for those nodes that need not
-         # to store everything in memory ( or ar orphane and required
-         # by the inner function .. )
+
+         # 3.3. compose replace pairs for those nodes that need not store everything in memory
+         # (or are orphaned but required by the inner function)
        replaced_outs = []
        offset = 1 + op_info.n_seqs + op_info.n_mit_mot
-         for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
+         for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
            i = idx + op_info.n_mit_mot
-             if not (isinstance(_val, int) and _val <= 0 and i not in required):
-                 if idx + op_info.n_mit_mot in required:
-                     val = 1
-                 else:
-                     val = _val
+             if not (isinstance(val, int) and val <= 0 and i not in required):
+                 required_orphan = idx + op_info.n_mit_mot in required
                # If the memory for this output has been pre-allocated
                # before going into the scan op (by an alloc node)
                if idx < op_info.n_mit_sot + op_info.n_sit_sot:
-                     # In case the input is still an alloc node, we
-                     # actually have two options:
-                     #   a) the input is a set_subtensor, in that case we
-                     #      can replace the initial tensor by a slice,
-                     #   b) it is not, and we simply take a slice of it.
-                     # TODO: commit change below with Razvan
-                     if (
-                         nw_inputs[offset + idx].owner
-                         and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
-                         and nw_inputs[offset + idx].owner.op.set_instead_of_inc
-                         and isinstance(
-                             nw_inputs[offset + idx].owner.op.idx_list[0], slice
-                         )
-                         # Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
-                         # As it happens in set_subtensor(empty(2)[:], 0)
-                         and not (
-                             nw_inputs[offset + idx].ndim
-                             > nw_inputs[offset + idx].owner.inputs[1].ndim
-                         )
-                     ):
-                         _nw_input = nw_inputs[offset + idx].owner.inputs[1]
-                         cval = pt.as_tensor_variable(val)
-                         initl = pt.as_tensor_variable(init_l[i])
-                         tmp_idx = pt.switch(cval < initl, cval + initl, cval - initl)
-                         nw_input = expand_empty(_nw_input, tmp_idx)
+                     nw_input = nw_inputs[offset + idx]
+
+                     # Check if the input looks like a default pre-allocated Scan buffer
+                     # created via `expand_empty`, which looks like empty(...)[:init.shape[0]].set(init)
+                     # If so, we can just recreate the pre-allocated buffer with a smaller size
+                     if _is_default_scan_buffer(nw_input):
+                         extra_size = 1 if required_orphan else val - init_l[i]
+                         nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
+                     # Otherwise, just trim the buffer with a slice
                    else:
-                         tmp = pt.as_tensor_variable(val)
-                         initl = pt.as_tensor_variable(init_l[i])
-                         tmp = maximum(tmp, initl)
-                         nw_input = nw_inputs[offset + idx][:tmp]
+                         stop = init_l[i] if required_orphan else val
+                         nw_input = nw_input[:stop]

                    nw_inputs[offset + idx] = nw_input
                    replaced_outs.append(op_info.n_mit_mot + idx)
@@ -1591,7 +1602,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                        + op_info.n_shared_outs
                    )
                    if nw_inputs[pos] == node.inputs[0]:
-                         nw_inputs[pos] = val
+                         nw_inputs[pos] = 1 if required_orphan else val
                    odx = op_info.n_mit_mot + idx
                    replaced_outs.append(odx)
                    old_outputs += [
@@ -1603,37 +1614,22 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                            ],
                        )
                    ]
-         # 3.4. Recompute inputs for everything else based on the new
-         # number of steps
+         # 3.4. Recompute inputs for everything else based on the new number of steps
        if global_nsteps is not None:
            for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
                if val == 0:
                    # val == 0 means that we want to keep all intermediate
                    # results for that state, including the initial values.
                    if idx < op_info.n_mit_sot + op_info.n_sit_sot:
                        in_idx = offset + idx
-                         # Number of steps in the initial state
-                         initl = init_l[op_info.n_mit_mot + idx]
-
-                         # If the initial buffer has the form
-                         # inc_subtensor(zeros(...)[...], _nw_input)
-                         # we want to make the zeros tensor as small as
-                         # possible (nw_steps + initl), and call
-                         # inc_subtensor on that instead.
-                         # Otherwise, simply take 0:(nw_steps+initl).
-                         if (
-                             nw_inputs[in_idx].owner
-                             and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
-                             and isinstance(
-                                 nw_inputs[in_idx].owner.op.idx_list[0], slice
-                             )
-                         ):
-                             _nw_input = nw_inputs[in_idx].owner.inputs[1]
-                             nw_input = expand_empty(_nw_input, nw_steps)
-                             nw_inputs[in_idx] = nw_input
+                         nw_input = nw_inputs[in_idx]
+                         if _is_default_scan_buffer(nw_input):
+                             nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
                        else:
-                             # FIXME: This is never used
-                             nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
+                             # Number of steps in the initial state
+                             init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
+                             nw_input = nw_input[: (init_l_pt + nw_steps)]
+                         nw_inputs[in_idx] = nw_input

                elif (
                    idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
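
For context, the "default Scan buffer" pattern that the new `_is_default_scan_buffer` helper matches is the graph produced by `expand_empty`: a set-style `IncSubtensor` that writes the initial state into the leading slice of an `AllocEmpty` output, i.e. `empty(...)[:init.shape[0]].set(init)` as noted in the added comments. A minimal hand-built sketch of that pattern (the variable names below are illustrative assumptions, not part of this diff):

import pytensor.tensor as pt

init = pt.vector("init")       # initial taps supplied by the user
nsteps = pt.iscalar("nsteps")  # number of new steps the scan will write

# Default pre-allocated buffer: empty(nsteps + init.shape[0])[: init.shape[0]].set(init)
buf = pt.empty((nsteps + init.shape[0],))
buf = pt.set_subtensor(buf[: init.shape[0]], init)

The `broadcasted_by(y, x)` guard in the helper rejects cases like `set_subtensor(empty(2)[:], 0)`, where the set value is broadcast to fill the buffer; rebuilding a smaller buffer directly from that value would not be safe, so the rewrite falls back to slicing the existing buffer instead.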