Skip to content

Commit 72b2a6d

Browse files
committed
Make sequences just another constant input
Sequences are now demoted to being just another constant in the Scan Op. The user-facing function creates the right indexing graph for iterating over sequences automatically. Some extra logic is added in the `scan_to_loop` rewrite to avoid creating duplicated indexes, while being on guard for Scans created elsewhere. Additionally, the outer graph is kept visible to the scan user function.
1 parent 51a973d commit 72b2a6d

File tree

4 files changed

+200
-148
lines changed

4 files changed

+200
-148
lines changed

pytensor/loop/basic.py

+50-17
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import functools
12
from typing import List, Tuple
23

34
import numpy as np
45

5-
from pytensor import Variable, as_symbolic
6+
from pytensor import Variable, as_symbolic, clone_replace
67
from pytensor.graph import FunctionGraph
8+
from pytensor.graph.basic import Constant, truncated_graph_inputs
79
from pytensor.loop.op import Scan
810
from pytensor.scan.utils import until
9-
from pytensor.tensor import as_tensor, empty_like
11+
from pytensor.tensor import as_tensor, constant, empty_like, minimum
1012

1113

1214
def scan(
@@ -20,6 +22,8 @@ def scan(
2022
if sequences is None and n_steps is None:
2123
raise ValueError("Must provide n_steps when scanning without sequences")
2224

25+
# TODO: init_states should be made opaque to the inner function,
26+
# since any relationship to the outer graph no longer holds
2327
if init_states is None:
2428
init_states = []
2529
else:
@@ -34,20 +38,31 @@ def scan(
3438
sequences = [sequences]
3539
sequences = [as_tensor(s) for s in sequences]
3640

41+
if sequences:
42+
leading_dims = [seq.shape[0] for seq in sequences]
43+
shortest_dim = functools.reduce(minimum, leading_dims)
44+
if n_steps is None:
45+
n_steps = shortest_dim
46+
else:
47+
n_steps = minimum(n_steps, shortest_dim)
48+
3749
if non_sequences is None:
3850
non_sequences = []
3951
else:
4052
if not isinstance(non_sequences, (tuple, list)):
4153
non_sequences = [non_sequences]
4254
non_sequences = [as_symbolic(n) for n in non_sequences]
4355

56+
# Create subsequence inputs for the inner function
57+
idx = constant(0, dtype="int64", name="idx")
58+
symbolic_idx = idx.type(name="idx")
59+
subsequences = [s[symbolic_idx] for s in sequences]
4460
# Note: Old scan order is sequences + init + non_sequences
45-
inner_sequences = [s[0] for s in sequences]
46-
inner_inputs = [i.type() for i in init_states + inner_sequences + non_sequences]
47-
inner_outputs = fn(*inner_inputs)
48-
if not isinstance(inner_outputs, (tuple, list)):
49-
inner_outputs = [inner_outputs]
50-
next_states = [out for out in inner_outputs if not isinstance(out, until)]
61+
fn_inputs = init_states + subsequences + non_sequences
62+
fn_outputs = fn(*fn_inputs)
63+
if not isinstance(fn_outputs, (tuple, list)):
64+
fn_outputs = [fn_outputs]
65+
next_states = [out for out in fn_outputs if not isinstance(out, until)]
5166

5267
if len(next_states) > len(init_states):
5368
if not init_states:
@@ -61,27 +76,45 @@ def scan(
6176
prev_states = []
6277
for i, (init_state, next_state) in enumerate(zip(init_states, next_states)):
6378
if init_state is None:
79+
# next_state may reference idx, let's replace that by the initial value
80+
[next_state] = clone_replace(
81+
output=[next_state], replace={symbolic_idx: idx}
82+
)
6483
init_state = empty_like(next_state)
65-
init_state.name = "empty_init_state"
66-
inner_inputs.insert(i, init_state.type())
84+
init_state.name = (
85+
"empty_init_state" # add 1 offset, since idx is the first state
86+
)
6787
prev_states.append(init_state)
6888

69-
until_condition = [out.condition for out in inner_outputs if isinstance(out, until)]
89+
until_condition = [out.condition for out in fn_outputs if isinstance(out, until)]
7090
if not until_condition:
7191
until_condition = [as_tensor(np.array(True))]
7292
if len(until_condition) > 1:
7393
raise ValueError("Only one until condition can be returned")
7494

75-
update_fg = FunctionGraph(
76-
inputs=inner_inputs, outputs=until_condition + next_states
95+
fgraph_inputs = [symbolic_idx] + prev_states + sequences + non_sequences
96+
fgraph_outputs = until_condition + [symbolic_idx + 1] + next_states
97+
98+
all_fgraph_inputs = truncated_graph_inputs(
99+
fgraph_outputs, ancestors_to_include=fgraph_inputs
100+
)
101+
extra_fgraph_inputs = [
102+
inp
103+
for inp in all_fgraph_inputs
104+
if (not isinstance(inp, Constant) and inp not in fgraph_inputs)
105+
]
106+
fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
107+
update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs)
108+
109+
scan_op = Scan(update_fg=update_fg)
110+
scan_outs = scan_op(
111+
n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
77112
)
78-
scan_op = Scan(update_fg=update_fg, n_sequences=len(sequences))
79-
scan_outs = scan_op(n_steps, *prev_states, *sequences, *non_sequences)
80113
assert isinstance(scan_outs, list)
81114
last_states = scan_outs[: scan_op.n_states]
82115
traces = scan_outs[scan_op.n_states :]
83-
84-
return last_states, traces
116+
# Don't return the inner index state
117+
return last_states[1:], traces[1:]
85118

86119

87120
def map(

0 commit comments

Comments
 (0)