Description
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode
n = pt.iscalar("n")
x0 = pt.vector("x0")
xs, _ = pytensor.scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=n)
out = xs[-1] # Invalid when n_steps=0
fn = pytensor.function([n, x0], out)
print(fn(n=0, x0=[0, 1])) # [1. 2.]
fn = pytensor.function([n, x0], out, mode=get_default_mode().excluding("shape_unsafe"))
print(fn(n=0, x0=[0, 1])) # [1. 2.]
fn = pytensor.function([n, x0], out, mode=get_default_mode().excluding("scan_save_mem"))
print(fn(n=0, x0=[0, 1])) # IndexError: index out of bounds
I suspect this comes from the following hack:
pytensor/pytensor/scan/rewriting.py
Lines 1438 to 1445 in 8454c3b
However, removing this hack causes some tests to fail, so other code downstream of it (or perhaps inside the rewrite itself) may be relying on wrong assumptions.
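
For reference, here is a minimal plain-Python sketch of the semantics I would expect, assuming the scan output holds one entry per step and does not include the initial state. `reference_scan` and `add_one` are hypothetical helpers written for illustration only; they are not part of PyTensor. With zero steps the output is empty, so indexing the last element should fail, which matches the behaviour seen when `scan_save_mem` is excluded.

```python
def reference_scan(step, x0, n_steps):
    """Plain-Python stand-in for a single-output scan (hypothetical, not PyTensor API).

    Returns one entry per step; the initial state x0 is not part of the output.
    """
    outputs = []
    prev = list(x0)
    for _ in range(n_steps):
        prev = step(prev)
        outputs.append(prev)
    return outputs


def add_one(xtm1):
    # Same step as the lambda in the report: add 1 to every element.
    return [v + 1 for v in xtm1]


xs = reference_scan(add_one, [0.0, 1.0], n_steps=2)
print(xs[-1])  # [2.0, 3.0]

empty = reference_scan(add_one, [0.0, 1.0], n_steps=0)
try:
    empty[-1]  # No steps were taken, so there is no last element.
except IndexError as err:
    print("IndexError:", err)
```

Under this reading, returning `[1. 2.]` for `n=0` looks like the result of a single step rather than of zero steps.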