
Scan save mem rewrite masks issues with steps=0 #1288

Open
@ricardoV94

Description

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

n = pt.iscalar("n")
x0 = pt.vector("x0")
xs, _ = pytensor.scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=n)

out = xs[-1]  # Invalid when n_steps=0: xs is empty

fn = pytensor.function([n, x0], out)
print(fn(n=0, x0=[0, 1]))  # [1. 2.]  (wrong: should raise IndexError)

fn = pytensor.function([n, x0], out, mode=get_default_mode().excluding("shape_unsafe"))
print(fn(n=0, x0=[0, 1]))  # [1. 2.]  (still wrong)

fn = pytensor.function([n, x0], out, mode=get_default_mode().excluding("scan_save_mem"))
print(fn(n=0, x0=[0, 1]))  # IndexError: index out of bounds  (expected)
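For reference, the intended semantics of this particular scan can be modelled in plain NumPy. This is an illustrative sketch, not PyTensor's implementation (the helper name `reference_scan` is hypothetical): `xs` stacks one output per step and does not include `x0`, so with `n_steps=0` it is empty and `xs[-1]` must fail, which only the graph compiled without `scan_save_mem` reproduces.

```python
import numpy as np


def reference_scan(x0, n_steps):
    """Hypothetical NumPy model of
    pytensor.scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=n).

    Returns xs with shape (n_steps, len(x0)); x0 itself is not included.
    """
    x0 = np.asarray(x0, dtype=float)
    outputs = []
    x = x0
    for _ in range(n_steps):
        x = x + 1
        outputs.append(x)
    # np.stack rejects an empty list, so build the empty result explicitly
    if not outputs:
        return np.empty((0,) + x0.shape, dtype=x0.dtype)
    return np.stack(outputs)


xs = reference_scan([0, 1], n_steps=3)
print(xs[-1])  # [3. 4.]

xs = reference_scan([0, 1], n_steps=0)
print(xs.shape)  # (0, 2)
try:
    xs[-1]
except IndexError as err:
    print("IndexError:", err)
```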

I suspect the problem comes from this hack in the ScanSaveMem rewrite:

# FIXME: This is not correct. Scan with 0 steps seems to be supported
# Make sure the ScanSaveMem optimization never makes the new
# number of steps to be 0 (this could happen, for instance, if
# the optimization detects that the outputs of the Scan go through
# subtensor nodes that end up taking no elements) because Scan with
# 0 iterations are not supported. Make sure the new number of steps
# is at least 1.
nw_steps = select_max(nw_steps, 1)
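A toy model shows why this clamp yields `[1. 2.]` instead of an error for `out = xs[-1]`. It is a hypothetical simplification, not the actual ScanSaveMem code: the helper `save_mem_last` and its single-slot buffer are assumptions made for illustration. Only the last state is kept, and the step count is raised to at least 1, so with `n_steps=0` one iteration still runs and `x0 + 1` is returned.

```python
import numpy as np


def save_mem_last(x0, n_steps, clamp=True):
    """Hypothetical model of a memory-saving scan that only keeps xs[-1].

    clamp=True mimics the effect of `nw_steps = select_max(nw_steps, 1)`.
    """
    nw_steps = max(n_steps, 1) if clamp else n_steps
    if nw_steps == 0:
        # Without the clamp there is no last output, just like xs[-1] on an empty xs
        raise IndexError("xs[-1] requested but scan ran 0 steps")
    last = np.asarray(x0, dtype=float)
    for _ in range(nw_steps):
        last = last + 1  # single-slot buffer: overwrite instead of storing every step
    return last


print(save_mem_last([0, 1], n_steps=0))  # [1. 2.]  (same wrong result as the compiled function)
print(save_mem_last([0, 1], n_steps=3))  # [3. 4.]
```

With `clamp=False` the `n_steps=0` call raises `IndexError`, agreeing with the unoptimized graph, while for every `n_steps >= 1` the two variants return the same value. That is why the clamp only becomes visible at zero steps.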

However, removing this hack makes some tests fail, so other code downstream of it (or perhaps inside the rewrite itself) may be making incorrect assumptions about the number of steps.
