Inplace on sit-sot / mit-sot when nsteps is symbolic #1283

Open
@ricardoV94

Description

Follow up to #1281

In the Numba backend we allow inplace updates of the inner sit-sot / oldest mit-sot tap when we know the buffer is only large enough to store the most recent taps. However, when the n_steps passed to a Scan is symbolic, PyTensor doesn't figure this out. Note how, in the second graph, the inner Scan's Composite doesn't destroy the input *0: it has no d={0: [1]} entry, whereas the Composite in the first graph does.

from pytensor import function, scan
import pytensor.tensor as pt

for constant_n_steps in (True, False):
    print(f"{constant_n_steps=}")
    init_x = pt.vector("init_x", shape=(2,))
    n_steps = pt.iscalar("n_steps")

    # Single recurrent output with two taps (-2, -1)
    def f_pow2(x_tm2, x_tm1):
        return 2 * x_tm1 + x_tm2

    trace, _ = scan(
        f_pow2,
        sequences=[],
        outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
        non_sequences=[],
        # Constant on the first iteration, symbolic on the second
        n_steps=10 if constant_n_steps else n_steps,
    )
    # Only the last step of the trace is requested
    fn = function([init_x, n_steps], trace[-1], on_unused_input="ignore", mode="NUMBA")
    fn.dprint(print_memory_map=True, print_shape=True)

# constant_n_steps=True
# Subtensor{i} [id A] shape=() v={0: [0]} 3
#  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] shape=(?,) d={0: [1]} 2
#  │  ├─ 10 [id C] shape=()
#  │  └─ SetSubtensor{:stop} [id D] shape=(2,) d={0: [0]} 1
#  │     ├─ AllocEmpty{dtype='float64'} [id E] shape=(2,) 0
#  │     │  └─ 2 [id F] shape=()
#  │     ├─ init_x [id G] shape=(2,)
#  │     └─ 2 [id H] shape=()
#  └─ 1 [id I] shape=()

# Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1]}
#  ← Composite{((2.0 * i0) + i1)} [id J] shape=() d={0: [1]}
#     ├─ *1-<Scalar(float64, shape=())> [id K] shape=() -> [id D]
#     └─ *0-<Scalar(float64, shape=())> [id L] shape=() -> [id D]


# constant_n_steps=False
# Subtensor{i} [id A] shape=() v={0: [0]} 5
#  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] shape=(?,) d={0: [1]} 4
#  │  ├─ Composite{...}.0 [id C] shape=() 0
#  │  │  └─ n_steps [id D] shape=()
#  │  └─ SetSubtensor{:stop} [id E] shape=(?,) d={0: [0]} 3
#  │     ├─ AllocEmpty{dtype='float64'} [id F] shape=(?,) 2
#  │     │  └─ Composite{...}.2 [id C] shape=() 0
#  │     │     └─ ···
#  │     ├─ init_x [id G] shape=(2,)
#  │     └─ 2 [id H] shape=()
#  └─ ScalarFromTensor [id I] shape=() 1
#     └─ Composite{...}.1 [id C] shape=() 0
#        └─ ···

# Inner graphs:
# Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1]}
#  ← Composite{((2.0 * i0) + i1)} [id J] shape=()
#     ├─ *1-<Scalar(float64, shape=())> [id K] shape=() -> [id E]
#     └─ *0-<Scalar(float64, shape=())> [id L] shape=() -> [id E]
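
For readers less familiar with the optimization being discussed, here is a minimal pure-Python sketch of why overwriting the oldest tap is safe when the buffer only holds the most recent taps. It is a conceptual illustration only, not PyTensor's or Numba's actual implementation; the function names scan_full_trace and scan_rolling_buffer are hypothetical and exist only for this example. The recurrence is the same one used in f_pow2 above.

```python
def scan_full_trace(init, n_steps):
    """Reference version: keep every step of the trace (no truncation)."""
    trace = list(init)
    for _ in range(n_steps):
        # x_t = 2 * x_{t-1} + x_{t-2}
        trace.append(2 * trace[-1] + trace[-2])
    return trace[-1]


def scan_rolling_buffer(init, n_steps):
    """Keep only the two most recent taps and overwrite the oldest in place.

    The slot holding x_{t-2} is read once per step and never needed again,
    so the new value x_t can be written directly into it.
    """
    buf = list(init)  # copy, so the caller's initial state is not modified
    oldest = 0        # index of the slot currently holding x_{t-2}
    for _ in range(n_steps):
        newest = 1 - oldest
        # Read both taps, then destroy the oldest one with the new value
        buf[oldest] = 2 * buf[newest] + buf[oldest]
        oldest = newest
    # After the loop, the most recent value sits in the non-oldest slot
    return buf[1 - oldest]


if __name__ == "__main__":
    for n in (0, 1, 5, 10):
        assert scan_rolling_buffer([1, 1], n) == scan_full_trace([1, 1], n)
    print("rolling buffer matches full trace")
```

Nothing in the rolling version depends on n_steps being known ahead of time: the buffer size is fixed by the number of taps. This is the property that, per the description above, PyTensor currently only recognizes when n_steps is constant.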
