1
+ import functools
1
2
from typing import List , Tuple
2
3
3
4
import numpy as np
4
5
5
- from pytensor import Variable , as_symbolic
6
+ from pytensor import Variable , as_symbolic , clone_replace
6
7
from pytensor .graph import FunctionGraph
8
+ from pytensor .graph .basic import Constant , truncated_graph_inputs
7
9
from pytensor .loop .op import Scan
8
10
from pytensor .scan .utils import until
9
- from pytensor .tensor import as_tensor , empty_like
11
+ from pytensor .tensor import as_tensor , constant , empty_like , minimum
10
12
11
13
12
14
def scan (
@@ -20,6 +22,8 @@ def scan(
20
22
if sequences is None and n_steps is None :
21
23
raise ValueError ("Must provide n_steps when scanning without sequences" )
22
24
25
+ # TODO: init_states should be made opaque to the inner function,
26
+ # since any relationship to the outer graph no longer holds
23
27
if init_states is None :
24
28
init_states = []
25
29
else :
@@ -34,20 +38,31 @@ def scan(
34
38
sequences = [sequences ]
35
39
sequences = [as_tensor (s ) for s in sequences ]
36
40
41
+ if sequences :
42
+ leading_dims = [seq .shape [0 ] for seq in sequences ]
43
+ shortest_dim = functools .reduce (minimum , leading_dims )
44
+ if n_steps is None :
45
+ n_steps = shortest_dim
46
+ else :
47
+ n_steps = minimum (n_steps , shortest_dim )
48
+
37
49
if non_sequences is None :
38
50
non_sequences = []
39
51
else :
40
52
if not isinstance (non_sequences , (tuple , list )):
41
53
non_sequences = [non_sequences ]
42
54
non_sequences = [as_symbolic (n ) for n in non_sequences ]
43
55
56
+ # Create subsequence inputs for the inner function
57
+ idx = constant (0 , dtype = "int64" , name = "idx" )
58
+ symbolic_idx = idx .type (name = "idx" )
59
+ subsequences = [s [symbolic_idx ] for s in sequences ]
44
60
# Note: Old scan order is sequences + init + non_sequences
45
- inner_sequences = [s [0 ] for s in sequences ]
46
- inner_inputs = [i .type () for i in init_states + inner_sequences + non_sequences ]
47
- inner_outputs = fn (* inner_inputs )
48
- if not isinstance (inner_outputs , (tuple , list )):
49
- inner_outputs = [inner_outputs ]
50
- next_states = [out for out in inner_outputs if not isinstance (out , until )]
61
+ fn_inputs = init_states + subsequences + non_sequences
62
+ fn_outputs = fn (* fn_inputs )
63
+ if not isinstance (fn_outputs , (tuple , list )):
64
+ fn_outputs = [fn_outputs ]
65
+ next_states = [out for out in fn_outputs if not isinstance (out , until )]
51
66
52
67
if len (next_states ) > len (init_states ):
53
68
if not init_states :
@@ -61,27 +76,45 @@ def scan(
61
76
prev_states = []
62
77
for i , (init_state , next_state ) in enumerate (zip (init_states , next_states )):
63
78
if init_state is None :
79
+ # next_state may reference idx, let's replace that by the initial value
80
+ [next_state ] = clone_replace (
81
+ output = [next_state ], replace = {symbolic_idx : idx }
82
+ )
64
83
init_state = empty_like (next_state )
65
- init_state .name = "empty_init_state"
66
- inner_inputs .insert (i , init_state .type ())
84
+ init_state .name = (
85
+ "empty_init_state" # add 1 offset, since idx is the first state
86
+ )
67
87
prev_states .append (init_state )
68
88
69
- until_condition = [out .condition for out in inner_outputs if isinstance (out , until )]
89
+ until_condition = [out .condition for out in fn_outputs if isinstance (out , until )]
70
90
if not until_condition :
71
91
until_condition = [as_tensor (np .array (True ))]
72
92
if len (until_condition ) > 1 :
73
93
raise ValueError ("Only one until condition can be returned" )
74
94
75
- update_fg = FunctionGraph (
76
- inputs = inner_inputs , outputs = until_condition + next_states
95
+ fgraph_inputs = [symbolic_idx ] + prev_states + sequences + non_sequences
96
+ fgraph_outputs = until_condition + [symbolic_idx + 1 ] + next_states
97
+
98
+ all_fgraph_inputs = truncated_graph_inputs (
99
+ fgraph_outputs , ancestors_to_include = fgraph_inputs
100
+ )
101
+ extra_fgraph_inputs = [
102
+ inp
103
+ for inp in all_fgraph_inputs
104
+ if (not isinstance (inp , Constant ) and inp not in fgraph_inputs )
105
+ ]
106
+ fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
107
+ update_fg = FunctionGraph (inputs = fgraph_inputs , outputs = fgraph_outputs )
108
+
109
+ scan_op = Scan (update_fg = update_fg )
110
+ scan_outs = scan_op (
111
+ n_steps , idx , * prev_states , * sequences , * non_sequences , * extra_fgraph_inputs
77
112
)
78
- scan_op = Scan (update_fg = update_fg , n_sequences = len (sequences ))
79
- scan_outs = scan_op (n_steps , * prev_states , * sequences , * non_sequences )
80
113
assert isinstance (scan_outs , list )
81
114
last_states = scan_outs [: scan_op .n_states ]
82
115
traces = scan_outs [scan_op .n_states :]
83
-
84
- return last_states , traces
116
+ # Don't return the inner index state
117
+ return last_states [ 1 :] , traces [ 1 :]
85
118
86
119
87
120
def map (
0 commit comments