Open
Description
Follow up to #1281
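In the Numba backend we allow in-placing of the inner sit-sot / oldest mit-sot when we know the buffer is only large enough to store the most recent taps. However, when `n_steps` of a Scan is symbolic, PyTensor doesn't figure this out. Note how, in the second graph, the inner Scan Composite doesn't destroy the input `*0`: in the first graph it carries `d={0: [1]}`, while in the second it has no destroy map. This is presumably because the buffer length is no longer the constant 2 (`shape=(2,)` becomes `shape=(?,)`).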
# Reproducer: compile the same mit-sot Scan with the Numba backend twice --
# once with a constant n_steps and once with a symbolic n_steps -- and print
# the memory map (destroy/view info) of each graph, so we can compare whether
# the inner Composite operates in place on the oldest tap.
from pytensor import function, scan
import pytensor.tensor as pt

for constant_n_steps in (True, False):
    print(f"{constant_n_steps=}")
    init_x = pt.vector("init_x", shape=(2,))
    n_steps = pt.iscalar("n_steps")

    def f_pow2(x_tm2, x_tm1):
        # Inner step for taps [-2, -1]: x[t] = 2 * x[t-1] + x[t-2].
        return 2 * x_tm1 + x_tm2

    trace, _ = scan(
        f_pow2,
        sequences=[],
        outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
        non_sequences=[],
        n_steps=10 if constant_n_steps else n_steps,
    )
    # Only the last value is requested, so PyTensor may keep just the most
    # recent taps in the Scan buffer. n_steps is always listed as an input so
    # both variants share a signature; it is unused when n_steps is the
    # constant 10, hence on_unused_input="ignore".
    fn = function(
        [init_x, n_steps],
        trace[-1],
        on_unused_input="ignore",
        mode="NUMBA",
    )
    fn.dprint(print_memory_map=True, print_shape=True)
# constant_n_steps=True
# Subtensor{i} [id A] shape=() v={0: [0]} 3
# ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] shape=(?,) d={0: [1]} 2
# │ ├─ 10 [id C] shape=()
# │ └─ SetSubtensor{:stop} [id D] shape=(2,) d={0: [0]} 1
# │ ├─ AllocEmpty{dtype='float64'} [id E] shape=(2,) 0
# │ │ └─ 2 [id F] shape=()
# │ ├─ init_x [id G] shape=(2,)
# │ └─ 2 [id H] shape=()
# └─ 1 [id I] shape=()
# Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1]}
# ← Composite{((2.0 * i0) + i1)} [id J] shape=() d={0: [1]}
# ├─ *1-<Scalar(float64, shape=())> [id K] shape=() -> [id D]
# └─ *0-<Scalar(float64, shape=())> [id L] shape=() -> [id D]
# constant_n_steps=False
# Subtensor{i} [id A] shape=() v={0: [0]} 5
# ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] shape=(?,) d={0: [1]} 4
# │ ├─ Composite{...}.0 [id C] shape=() 0
# │ │ └─ n_steps [id D] shape=()
# │ └─ SetSubtensor{:stop} [id E] shape=(?,) d={0: [0]} 3
# │ ├─ AllocEmpty{dtype='float64'} [id F] shape=(?,) 2
# │ │ └─ Composite{...}.2 [id C] shape=() 0
# │ │ └─ ···
# │ ├─ init_x [id G] shape=(2,)
# │ └─ 2 [id H] shape=()
# └─ ScalarFromTensor [id I] shape=() 1
# └─ Composite{...}.1 [id C] shape=() 0
# └─ ···
# Inner graphs:
# Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1]}
# ← Composite{((2.0 * i0) + i1)} [id J] shape=()
# ├─ *1-<Scalar(float64, shape=())> [id K] shape=() -> [id E]
# └─ *0-<Scalar(float64, shape=())> [id L] shape=() -> [id E]