From 3fe901d394c95b4d968fdd668e0d205b72937081 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 13 Jan 2023 11:51:53 +0100 Subject: [PATCH 1/6] Implement new Loop and Scan Operators Co-authored-by: Adrian Seyboldt --- pytensor/loop/__init__.py | 0 pytensor/loop/op.py | 448 ++++++++++++++++++++++++++++++++++++++ tests/loop/__init__.py | 0 tests/loop/test_op.py | 163 ++++++++++++++ 4 files changed, 611 insertions(+) create mode 100644 pytensor/loop/__init__.py create mode 100644 pytensor/loop/op.py create mode 100644 tests/loop/__init__.py create mode 100644 tests/loop/test_op.py diff --git a/pytensor/loop/__init__.py b/pytensor/loop/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pytensor/loop/op.py b/pytensor/loop/op.py new file mode 100644 index 0000000000..ba14e383ed --- /dev/null +++ b/pytensor/loop/op.py @@ -0,0 +1,448 @@ +from typing import Optional + +import numpy as np + +from pytensor import In, Out +from pytensor.compile import optdb, pfunc +from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter +from pytensor.graph.rewriting.basic import in2out +from pytensor.scalar import constant +from pytensor.tensor import ( + NoneConst, + add, + and_, + empty, + get_scalar_constant_value, + set_subtensor, +) +from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.shape import Shape_i +from pytensor.tensor.type import DenseTensorType, TensorType +from pytensor.tensor.type_other import NoneTypeT + + +def validate_loop_update_types(update): + assert update.outputs[0].type.dtype == "bool" + for i, (input_state, output_state) in enumerate( + zip(update.inputs, update.outputs[1:]) + ): + if input_state.type != output_state.type: + raise TypeError( + f"The {i}-th input and output states of the inner loop function have different types: " + f"{input_state.type} vs {output_state.type}." + ) + + +class Loop(Op): + """Represent a do-while loop. + + We represent the loop body as an inner FunctionGraph, which + computes the next state and whether the loop should continue. + + Roughly equivalent to + ``` + def loop(fn, initial_state, constants): + state = initial_state + while True: + resume, state = fn(i, state, *constants) + if not resume: + break + return state + ``` + Multiple initial states and constants can be provided + """ + + def __init__( + self, + update_fg: FunctionGraph, # (*state, *consts) -> (bool, *state) + reverse_fg: Optional[FunctionGraph] = None, + ): + validate_loop_update_types(update_fg) + self.state_types = [out.type for out in update_fg.outputs[1:]] + self.const_types = [ + inp.type for inp in update_fg.inputs[len(self.state_types) :] + ] + self.update_fg = update_fg + self.reverse_fg = reverse_fg + self._fn = None + + @property + def fn(self): + """Lazily compile the inner update function graph.""" + if self._fn is not None: + return self._fn + + fgraph = self.update_fg + wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs] + wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs] + + self._fn = pfunc( + wrapped_inputs, + wrapped_outputs, + mode="FAST_RUN", # TODO: Figure this out + accept_inplace=False, + on_unused_input="ignore", + fgraph=fgraph, + ) + return self._fn + + def make_node(self, *inputs): + assert len(inputs) == len(self.state_types) + len(self.const_types) + + states = inputs[: len(self.state_types)] + states = [ + inp_type.filter_variable(inp) + for inp_type, inp in zip(self.state_types, states) + ] + + consts = inputs[len(self.state_types) :] + consts = [ + inp_type.filter_variable(inp) + for inp_type, inp in zip(self.const_types, consts) + ] + + return Apply( + self, + [*states, *consts], + [state_type() for state_type in self.state_types], + ) + + def infer_shape(self, fgraph, node, input_shapes): + return input_shapes[: len(self.state_types)] + + def perform(self, node, inputs, output_storage): + update_fn = self.fn + + states = inputs[: len(self.state_types)] + consts = inputs[len(self.state_types) :] + resume = True + while resume: + resume, *states = update_fn(*states, *consts) + + for i, state in enumerate(states): + output_storage[i][0] = state + + def L_Op(self, *args): + if not self.reverse_fg: + raise NotImplementedError() + # Use L_Op of self.reverse_fg + ... + + def R_Op(self, *args): + # Use R_op of self.update_fg + ... + + +class Scan(Op): + """Represent a scan. + + This Op can be thought of as a loop that collects intermediate steps + + Roughly equivalent to + ``` + def scan(fn, initial_states, constants, max_iters): + traces = [[]*len(initial_states)] + states = initial_states + for i in range(max_iters): + resume, states = fn(*states, *constants) + for trace, state in zip(traces, states): + trace.append(state) + if not resume: + break + return states, traces + ``` + Not all types of states can be collected, for instance RandomGenerator. For these + `None` is returned in place of the respective traces + + This Op must always be converted to a Loop during compilation. + """ + + def __init__( + self, + update_fg: FunctionGraph, # (*state, *consts) -> (bool, *state) + reverse_fg: Optional[FunctionGraph] = None, + ): + validate_loop_update_types(update_fg) + + self.state_types = [out.type for out in update_fg.outputs[1:]] + self.n_states = len(self.state_types) + self.trace_types: list[Type] = [] + for state_type in self.state_types: + # TODO: Accommodate SparseTensors and Scalars + if isinstance(state_type, DenseTensorType): + self.trace_types.append( + DenseTensorType( + shape=(None, *state_type.shape), dtype=state_type.dtype + ) + ) + else: + # We can't concatenate all types of states, such as RandomTypes + self.trace_types.append(NoneConst.type) + + self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]] + self.n_constants = len(self.constant_types) + + self.update_fg = update_fg.clone(check_integrity=False) + self.reverse_fg = ( + reverse_fg.clone(check_integrity=False) if reverse_fg is not None else None + ) + + # It's more conservative to assume the Op has a while condition + self.has_while_condition = True + try: + self.has_while_condition = not get_scalar_constant_value( + update_fg.outputs[0] + ) + except NotScalarConstantError: + pass + + def make_node(self, max_iters, *inputs): + assert len(inputs) == self.n_states + self.n_constants + + max_iters = TensorType(dtype="int64", shape=()).filter_variable(max_iters) + + states = inputs[: self.n_states] + states = [ + inp_type.filter_variable(inp) + for inp_type, inp in zip(self.state_types, states) + ] + + constants = inputs[self.n_states :] + constants = [ + inp_type.filter_variable(inp) + for inp_type, inp in zip(self.constant_types, constants) + ] + + # If there is no while condition, `max_iters` exclusively defines the number of iterations + # If this value is constant, we can get static type shapes for the leading dimensions of traces + trace_types = self.trace_types + if not self.has_while_condition: + try: + n_iters = int(get_scalar_constant_value(max_iters)) + except NotScalarConstantError: + pass + else: + trace_types = [] + for trace_type in self.trace_types: + if isinstance(trace_type, DenseTensorType): + trace_types.append( + DenseTensorType( + dtype=trace_type.dtype, + shape=(n_iters, *trace_type.shape[1:]), + ) + ) + else: + trace_types.append(trace_type) + + return Apply( + self, + [max_iters, *states, *constants], + [output_type() for output_type in self.state_types + trace_types], + ) + + def infer_shape(self, fgraph, node, input_shapes): + # If there is a while condition, `max_iters` provides only the upper bound for the number of iterations + if self.has_while_condition: + # find the first non-None trace + trace_out = next( + trace + for trace in node.outputs[self.n_states :] + if not isinstance(trace.type, NoneTypeT) + ) + n_iters = Shape_i(0)(trace_out) + else: + n_iters = node.inputs[0] # max_iters + + state_shapes = input_shapes[1 : self.n_states + 1] + trace_shapes = [ + (n_iters, *state_shape) if state_shape is not None else None + for state_shape in state_shapes + ] + return state_shapes + trace_shapes + + def do_constant_folding(self, fgraph, node): + return False + + def perform(self, node, inputs, output_storage): + raise RuntimeError("Scan Op should not be present in compiled graph") + + def L_op(self, *args): + # Use trace outputs + ... + + def R_op(self, *args): + # Use R_op of self.update + ... + + +@node_rewriter([Scan]) +def scan_to_loop(fgraph, node): + """Rewrite a Scan Op into a Loop Op + + It roughly creates the following computational graph + ``` + + def scan(fn, idx, initial_states, constants, max_iters): + idx = 0 + states = initial_states + traces = [empty(max_iters, *initial_state.shape) for initial_state in initial_states] + while True: + resume, states, fn(*states, *traces, *constants) + for trace, state in zip(traces, states): + trace[idx] = state + idx += 1 + if not resume or idx >= max_iters: + break + traces = [trace[: idx] for trace in traces] + return states, traces + ``` + + Traces that are not used anywhere in the graph are omitted from the final Loop + + """ + op: Scan = node.op # type: ignore + + old_states = node.outputs[: op.n_states] + old_traces = node.outputs[op.n_states :] + + # Only include the intermediate states that are used elsewhere + used_traces_idxs = [ + i + for i, trace in enumerate(node.outputs[op.n_states :]) + if fgraph.clients[trace] + ] + + # Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced + for trace_idx in used_traces_idxs: + assert not isinstance(old_states[trace_idx].type, NoneTypeT) + + # Inputs to the new Loop + max_iters = node.inputs[0] + init_states = node.inputs[1 : 1 + op.n_states] + init_traces = [ + empty( + (max_iters, *tuple(init_states[trace_idx].shape)), + dtype=init_states[trace_idx].dtype, + ) + for trace_idx in used_traces_idxs + ] + constants = node.inputs[1 + op.n_states :] + + update_fg = op.update_fg.clone(check_integrity=False) + + # Check if inner_fg computes an index already, otherwise create a new one + has_idx = False + if len(node.inputs) > 1: + try: + outer_inp = node.inputs[1] + outer_is_zero = get_scalar_constant_value(outer_inp) == 0 + except NotScalarConstantError: + pass + else: + if ( + outer_is_zero + and len(update_fg.inputs) > 0 + and len(update_fg.outputs) > 1 + ): + inner_out = update_fg.outputs[1] + if ( + inner_out.owner is not None + and inner_out.owner.op == add + and len(inner_out.owner.inputs) == 2 + ): + left, right = inner_out.owner.inputs + if left is update_fg.inputs[0]: + try: + has_idx = ( + get_scalar_constant_value( + right, only_process_constants=True + ) + == 1 + ) + except NotScalarConstantError: + pass + + if has_idx: + init_idx = outer_inp + inner_idx = inner_out.owner.inputs[0] + inner_next_idx = inner_out + if not has_idx: + init_idx = constant(np.array(0, dtype="int64"), name="idx") + inner_idx = init_idx.type() + inner_idx.name = "idx" + inner_next_idx = inner_idx + 1 + inner_next_idx.name = "next_idx" + + # Inner traces + inner_states = update_fg.inputs[: op.n_states] + inner_traces = [init_trace.type() for init_trace in init_traces] + for s, t in zip(inner_states, inner_traces): + t.name = "trace" + if s.name: + t.name = "_".join((t.name, s.name)) + + inner_constants = update_fg.inputs[op.n_states :] + + # Inner while condition + inner_while_cond, *inner_next_states = update_fg.outputs + inner_next_traces = [ + set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx]) + for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces) + ] + for t in inner_next_traces: + t.name = "next_trace" + inner_max_iters = max_iters.type() + inner_while_cond = and_(inner_while_cond, inner_next_idx < inner_max_iters) + inner_while_cond.name = "while(?)" + + if not has_idx: + init_states = [init_idx] + init_states + inner_states = [inner_idx] + inner_states + inner_next_states = [inner_next_idx] + inner_next_states + + new_update_fg = FunctionGraph( + inputs=[ + *inner_states, + *inner_traces, + *inner_constants, + inner_max_iters, + ], + outputs=[ + inner_while_cond, + *inner_next_states, + *inner_next_traces, + ], + ) + + # TODO: Implement Reverse? + loop_op = Loop(update_fg=new_update_fg) + + new_outs = loop_op(*init_states, *init_traces, *constants, max_iters) + if has_idx: + # idx was part of the original scan, and therefore has a corresponding trace + final_idx = new_outs[0] + else: + final_idx, *new_outs = new_outs + new_states = new_outs[: op.n_states] + new_traces = new_outs[op.n_states :] + + replacements = dict(zip(old_states, new_states)) + for trace_idx, new_trace in zip(used_traces_idxs, new_traces): + # If there is no while condition, the whole trace will be used + if op.has_while_condition: + new_trace = new_trace[:final_idx] + replacements[old_traces[trace_idx]] = new_trace + return replacements + + +# TODO: Create new Loop dataset +# Needs to be executed after `local_shape_to_shape_i`, otherwise shape graphs +# cannot be properly replaced +optdb.register( + "scan_to_loop", + in2out(scan_to_loop), + "fast_compile", + "fast_run", + "not_jax", + position=1.0, +) diff --git a/tests/loop/__init__.py b/tests/loop/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/loop/test_op.py b/tests/loop/test_op.py new file mode 100644 index 0000000000..c3de8f5c36 --- /dev/null +++ b/tests/loop/test_op.py @@ -0,0 +1,163 @@ +import numpy as np + +import pytensor +from pytensor import config, function, shared +from pytensor.compile import DeepCopyOp +from pytensor.graph import FunctionGraph +from pytensor.loop.op import Loop, Scan +from pytensor.tensor import constant, empty, lscalar, scalar, vector +from pytensor.tensor.random import normal +from pytensor.tensor.subtensor import Subtensor +from pytensor.tensor.type_other import NoneTypeT + + +def test_loop_basic(): + i = lscalar("i") + x = scalar("x") + update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2]) + + loop_op = Loop(update_fg=update_fg) + assert len(loop_op.state_types) == 2 + assert len(loop_op.const_types) == 0 + _, y = loop_op(np.array(0, dtype="int64"), x) + assert y.eval({x: 0}) == 20 + + +def test_loop_with_constant(): + i = lscalar("i") + x = scalar("x") + const = scalar("const") + update_fg = FunctionGraph([i, x, const], [(i + 1) < 10, i + 1, x + const]) + + loop_op = Loop(update_fg=update_fg) + assert len(loop_op.state_types) == 2 + assert len(loop_op.const_types) == 1 + _, y = loop_op(np.array(0, dtype="int64"), x, const) + assert y.eval({x: 0, const: 2}) == 20 + + +def test_fori_scan(): + x = scalar("x") + update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2]) + + n_iters = 10 + y, ys = Scan(update_fg=update_fg)(n_iters, x) + + fn = function([x], [y, ys]) + + subtensor_nodes = tuple( + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Subtensor) + ) + assert len(subtensor_nodes) == 0 + loop_nodes = tuple( + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop) + ) + assert len(loop_nodes) == 1 + (loop_node,) = loop_nodes + assert len(loop_node.outputs) == 3 + assert loop_node.outputs[0].type.shape == () + assert loop_node.outputs[1].type.shape == () + assert loop_node.outputs[2].type.shape == (10,) + + y_eval, ys_eval = fn(0) + np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2)) + np.testing.assert_array_equal(ys_eval[-1], y_eval) + + +def test_fori_scan_shape(): + x = scalar("x") + update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2]) + + n_iters = 10 + _, ys = Scan(update_fg=update_fg)(n_iters, x) + + fn = function([x], ys.shape, on_unused_input="ignore") + nodes = tuple(fn.maker.fgraph.apply_nodes) + assert len(nodes) == 1 + assert isinstance(nodes[0].op, DeepCopyOp) + assert fn(0) == 10 + + +def test_while_scan(): + i = lscalar("i") + x = scalar("x") + update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2]) + + max_iters = 1000 + _, y, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x) + + fn = function([x], [y, ys]) + + subtensor_nodes = tuple( + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Subtensor) + ) + assert len(subtensor_nodes) == 1 + loop_nodes = tuple( + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop) + ) + assert len(loop_nodes) == 1 + (loop_node,) = loop_nodes + assert len(loop_node.outputs) == 3 + assert loop_node.outputs[0].type.shape == () + assert loop_node.outputs[1].type.shape == () + assert loop_node.outputs[2].type.shape == (1000,) + + y_eval, ys_eval = fn(0) + np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2)) + np.testing.assert_array_equal(ys_eval[-1], y_eval) + + +def test_while_scan_shape(): + i = lscalar("i") + x = scalar("x") + update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2]) + + max_iters = 1000 + _, _, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x) + + fn = function([x], ys.shape) + loop_nodes = tuple( + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop) + ) + assert len(loop_nodes) == 1 + assert fn(0) == 10 + + +def test_foreach_scan(): + idx = scalar("idx", dtype="int64") + dummy_x0 = empty(()) + xs = vector("xs") + const = scalar("const") + update_fg = FunctionGraph( + [idx, dummy_x0, xs, const], [constant(np.array(True)), idx + 1, xs[idx] * const] + ) + + n_steps = xs.shape[0] + _, _, _, ys = Scan(update_fg=update_fg)(n_steps, 0, dummy_x0, xs, const) + + fn = pytensor.function([xs, const], ys) + + np.testing.assert_almost_equal( + fn(np.arange(10, dtype=config.floatX), 100), np.arange(10) * 100 + ) + + +def test_fori_random_scan(): + rng_test = np.random.default_rng(123) + rng_shared = shared(np.random.default_rng(123)) + n_iters = 5 + + dummy_init = empty(()) + rng = rng_shared.type() + update_fg = FunctionGraph( + [dummy_init, rng], + [constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]], + ) + + _, new_rng, ys, rngs = Scan(update_fg=update_fg)(n_iters, dummy_init, rng_shared) + assert isinstance(rngs.type, NoneTypeT) + + fn = function([], ys, updates={rng_shared: new_rng}) + + np.testing.assert_array_equal(fn(), rng_test.normal(size=5)) + np.testing.assert_array_equal(fn(), rng_test.normal(size=5)) From 78bb829072db6e8746e921f1a81e3babe1dd9e9b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 13 Jan 2023 11:53:09 +0100 Subject: [PATCH 2/6] Implement new scan constructor user facing functions Co-authored-by: Adrian Seyboldt --- pytensor/loop/basic.py | 200 +++++++++++++++++++++++++++++++++++++++ tests/loop/test_basic.py | 125 ++++++++++++++++++++++++ 2 files changed, 325 insertions(+) create mode 100644 pytensor/loop/basic.py create mode 100644 tests/loop/test_basic.py diff --git a/pytensor/loop/basic.py b/pytensor/loop/basic.py new file mode 100644 index 0000000000..56ba41e8ad --- /dev/null +++ b/pytensor/loop/basic.py @@ -0,0 +1,200 @@ +import functools +from typing import List, Tuple + +import numpy as np + +from pytensor import Variable, as_symbolic, clone_replace +from pytensor.graph import FunctionGraph +from pytensor.graph.basic import Constant, truncated_graph_inputs +from pytensor.loop.op import Scan +from pytensor.scan.utils import until +from pytensor.tensor import as_tensor, constant, empty_like, minimum + + +def scan( + fn, + init_states=None, + sequences=None, + non_sequences=None, + n_steps=None, + go_backwards=False, +) -> Tuple[List[Variable], List[Variable]]: + if sequences is None and n_steps is None: + raise ValueError("Must provide n_steps when scanning without sequences") + + if init_states is None: + init_states = [] + else: + if not isinstance(init_states, (tuple, list)): + init_states = [init_states] + init_states = [as_symbolic(i) if i is not None else None for i in init_states] + + if sequences is None: + sequences = [] + else: + if not isinstance(sequences, (tuple, list)): + sequences = [sequences] + sequences = [as_tensor(s) for s in sequences] + + if sequences: + leading_dims = [seq.shape[0] for seq in sequences] + shortest_dim = functools.reduce(minimum, leading_dims) + if n_steps is None: + n_steps = shortest_dim + else: + n_steps = minimum(n_steps, shortest_dim) + + if non_sequences is None: + non_sequences = [] + else: + if not isinstance(non_sequences, (tuple, list)): + non_sequences = [non_sequences] + non_sequences = [as_symbolic(n) for n in non_sequences] + + # Create dummy inputs for the init state. The user function should not + # draw any relationship with the outer initial states, since these are only + # valid in the first iteration + inner_states = [i.type() if i is not None else None for i in init_states] + + # Create subsequence inputs for the inner function + idx = constant(0, dtype="int64", name="idx") + symbolic_idx = idx.type(name="idx") + subsequences = [s[symbolic_idx] for s in sequences] + + # Call user function to retrieve inner outputs. We use the same order as the old Scan, + # although inner_states + subsequences + non_sequences seems more intuitive, + # since subsequences are just a fancy non_sequence + # We don't pass the non-carried outputs [init is None] to the inner function + fn_inputs = ( + subsequences + [i for i in inner_states if i is not None] + non_sequences + ) + fn_outputs = fn(*fn_inputs) + if not isinstance(fn_outputs, (tuple, list)): + fn_outputs = [fn_outputs] + next_states = [out for out in fn_outputs if not isinstance(out, until)] + + if len(next_states) > len(init_states): + if not init_states: + init_states = [None] * len(next_states) + inner_states = init_states + else: + raise ValueError( + "Please provide None as `init` for any output that is not carried over (i.e. it behaves like a map) " + ) + + # Replace None init by dummy empty tensors + prev_states = [] + prev_inner_states = [] + for i, (init_state, inner_state, next_state) in enumerate( + zip(init_states, inner_states, next_states) + ): + if init_state is None: + # next_state may reference idx. We replace that by the initial value, + # so that the shape of the dummy init state does not depend on it. + [next_state] = clone_replace( + output=[next_state], replace={symbolic_idx: idx} + ) + init_state = empty_like(next_state) + init_state.name = "empty_init_state" + inner_state = init_state.type(name="dummy_state") + prev_states.append(init_state) + prev_inner_states.append(inner_state) + + # Flip until to while condition + while_condition = [~out.condition for out in fn_outputs if isinstance(out, until)] + if not while_condition: + while_condition = [as_tensor(np.array(True))] + if len(while_condition) > 1: + raise ValueError("Only one until condition can be returned") + + fgraph_inputs = [symbolic_idx] + prev_inner_states + sequences + non_sequences + fgraph_outputs = while_condition + [symbolic_idx + 1] + next_states + + all_fgraph_inputs = truncated_graph_inputs( + fgraph_outputs, ancestors_to_include=fgraph_inputs + ) + extra_fgraph_inputs = [ + inp + for inp in all_fgraph_inputs + if (not isinstance(inp, Constant) and inp not in fgraph_inputs) + ] + fgraph_inputs = fgraph_inputs + extra_fgraph_inputs + update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs) + + scan_op = Scan(update_fg=update_fg) + scan_outs = scan_op( + n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs + ) + assert isinstance(scan_outs, list) + last_states = scan_outs[: scan_op.n_states] + traces = scan_outs[scan_op.n_states :] + # Don't return the inner index state + return last_states[1:], traces[1:] + + +def map( + fn, + sequences, + non_sequences=None, + go_backwards=False, +): + _, traces = scan( + fn=fn, + sequences=sequences, + non_sequences=non_sequences, + go_backwards=go_backwards, + ) + if len(traces) == 1: + return traces[0] + return traces + + +def reduce( + fn, + init_states, + sequences, + non_sequences=None, + go_backwards=False, +): + final_states, _ = scan( + fn=fn, + init_states=init_states, + sequences=sequences, + non_sequences=non_sequences, + go_backwards=go_backwards, + ) + if len(final_states) == 1: + return final_states[0] + return final_states + + +def filter( + fn, + sequences, + non_sequences=None, + go_backwards=False, +): + if not isinstance(sequences, (tuple, list)): + sequences = [sequences] + + _, masks = scan( + fn=fn, + sequences=sequences, + non_sequences=non_sequences, + go_backwards=go_backwards, + ) + + if not all(mask.dtype == "bool" for mask in masks): + raise TypeError("The output of filter fn should be a boolean variable") + if len(masks) == 1: + masks = [masks[0]] * len(sequences) + elif len(masks) != len(sequences): + raise ValueError( + "filter fn must return one variable or len(sequences), but it returned {len(masks)}" + ) + + filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)] + + if len(filtered_sequences) == 1: + return filtered_sequences[0] + return filtered_sequences diff --git a/tests/loop/test_basic.py b/tests/loop/test_basic.py new file mode 100644 index 0000000000..32852f4719 --- /dev/null +++ b/tests/loop/test_basic.py @@ -0,0 +1,125 @@ +import numpy as np + +import pytensor +from pytensor import config, function, grad +from pytensor.loop.basic import filter, map, reduce, scan +from pytensor.scan import until +from pytensor.tensor import arange, eq, scalar, vector, zeros + + +def test_scan_with_sequences(): + xs = vector("xs") + ys = vector("ys") + _, [zs] = scan( + fn=lambda x, y: x * y, + sequences=[xs, ys], + ) + pytensor.dprint(ys, print_type=True) + np.testing.assert_almost_equal( + zs.eval( + { + xs: np.arange(10, dtype=config.floatX), + ys: np.arange(10, dtype=config.floatX), + } + ), + np.arange(10) ** 2, + ) + + +def test_scan_with_carried_and_non_carried_states(): + x = scalar("x") + _, [ys1, ys2] = scan( + fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2), + init_states=[x, None], + n_steps=10, + ) + fn = function([x], [ys1, ys2]) + res = fn(-1) + np.testing.assert_almost_equal(res[0], np.arange(10)) + np.testing.assert_almost_equal(res[1], np.arange(10) * 2) + + +def test_scan_with_sequence_and_carried_state(): + xs = vector("xs") + _, [ys] = scan( + fn=lambda x, ytm1: (ytm1 + 1) * x, + init_states=[zeros(())], + sequences=[xs], + ) + fn = function([xs], ys) + np.testing.assert_almost_equal(fn([1, 2, 3]), [1, 4, 15]) + + +def test_scan_taking_grads_wrt_non_sequence(): + # Tests sequence + non-carried state + xs = vector("xs") + ys = xs**2 + + _, [J] = scan( + lambda i, ys, xs: grad(ys[i], wrt=xs), + sequences=arange(ys.shape[0]), + non_sequences=[ys, xs], + ) + + f = pytensor.function([xs], J) + np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]]) + + +def test_scan_taking_grads_wrt_sequence(): + # This is not possible with the old Scan + xs = vector("xs") + ys = xs**2 + + _, [J] = scan( + lambda y, xs: grad(y, wrt=xs), + sequences=[ys], + non_sequences=[xs], + ) + + f = pytensor.function([xs], J) + np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]]) + + +def test_while_scan(): + _, [xs] = scan( + fn=lambda x: (x + 1, until((x + 1) >= 9)), + init_states=[-1], + n_steps=20, + ) + + f = pytensor.function([], xs) + np.testing.assert_array_equal(f(), np.arange(10)) + + +def test_map(): + xs = vector("xs") + ys = map( + fn=lambda x: x * 100, + sequences=xs, + ) + np.testing.assert_almost_equal( + ys.eval({xs: np.arange(10, dtype=config.floatX)}), np.arange(10) * 100 + ) + + +def test_reduce(): + xs = vector("xs") + y = reduce( + fn=lambda x, acc: acc + x, + init_states=zeros(()), + sequences=xs, + ) + np.testing.assert_almost_equal( + y.eval({xs: np.arange(10, dtype=config.floatX)}), np.arange(10).cumsum()[-1] + ) + + +def test_filter(): + xs = vector("xs") + ys = filter( + fn=lambda x: eq(x % 2, 0), + sequences=xs, + ) + np.testing.assert_array_equal( + ys.eval({xs: np.arange(0, 20, dtype=config.floatX)}), np.arange(0, 20, 2) + ) From 4da2f6e686b0f94c7475b9e284527ad22abf5b3c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 13 Jan 2023 17:02:16 +0100 Subject: [PATCH 3/6] Add JAX rewrite for new Scan Op --- pytensor/compile/mode.py | 4 +- pytensor/link/jax/dispatch/__init__.py | 1 + pytensor/link/jax/dispatch/loop.py | 52 +++++++++++++ pytensor/link/utils.py | 1 + tests/link/jax/test_basic.py | 10 +-- tests/link/jax/test_loop.py | 104 +++++++++++++++++++++++++ 6 files changed, 165 insertions(+), 7 deletions(-) create mode 100644 pytensor/link/jax/dispatch/loop.py create mode 100644 tests/link/jax/test_loop.py diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 8aecf1a902..dc661fe0de 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -449,7 +449,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): JAX = Mode( JAXLinker(), - RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]), + RewriteDatabaseQuery( + include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt", "not_jax"] + ), ) NUMBA = Mode( NumbaLinker(), diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py index efcac788cf..2f08bbbb94 100644 --- a/pytensor/link/jax/dispatch/__init__.py +++ b/pytensor/link/jax/dispatch/__init__.py @@ -12,5 +12,6 @@ import pytensor.link.jax.dispatch.random import pytensor.link.jax.dispatch.elemwise import pytensor.link.jax.dispatch.scan +import pytensor.link.jax.dispatch.loop # isort: on diff --git a/pytensor/link/jax/dispatch/loop.py b/pytensor/link/jax/dispatch/loop.py new file mode 100644 index 0000000000..2c5b9609d9 --- /dev/null +++ b/pytensor/link/jax/dispatch/loop.py @@ -0,0 +1,52 @@ +import jax + +from pytensor.compile.mode import get_mode +from pytensor.link.jax.dispatch.basic import jax_funcify +from pytensor.loop.op import Scan + + +@jax_funcify.register(Scan) +def jax_funcify_Scan(op, node, global_fgraph, **kwargs): + # TODO: Rewrite as a while loop if only last states are used + if op.has_while_condition: + raise NotImplementedError( + "Scan ops with while condition cannot be transpiled JAX" + ) + + # Apply inner rewrites + # TODO: Not sure this is the right place to do this, should we have a rewrite that + # explicitly triggers the optimization of the inner graphs of Scan? + update_fg = op.update_fg.clone() + rewriter = get_mode("JAX").optimizer + rewriter(update_fg) + + jaxified_scan_inner_fn = jax_funcify(update_fg, **kwargs) + + # Only include the intermediate states that are used elsewhere + used_traces_idxs = [ + i + for i, trace in enumerate(node.outputs[op.n_states :]) + if global_fgraph.clients[trace] + ] + + def scan(max_iters, *outer_inputs): + states = outer_inputs[: op.n_states] + constants = outer_inputs[op.n_states :] + + def scan_fn(carry, _): + resume, *carry = jaxified_scan_inner_fn(*carry, *constants) + assert resume + carry = list(carry) + # Return states as both carry and output to be appended + return carry, [c for i, c in enumerate(carry) if i in used_traces_idxs] + + states, traces = jax.lax.scan( + scan_fn, init=list(states), xs=None, length=max_iters + ) + for i in range(len(states)): + if i not in used_traces_idxs: + traces.insert(i, None) + + return *states, *traces + + return scan diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index fd76e1278e..09d70756f7 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -736,6 +736,7 @@ def fgraph_to_python( global_env = {} body_assigns = [] + kwargs.setdefault("global_fgraph", fgraph) for node in order: compiled_func = op_conversion_fn( node.op, node=node, storage_map=storage_map, **kwargs diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index e20bd255fb..a86c2636c9 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -5,15 +5,13 @@ import pytest from pytensor.compile.function import function -from pytensor.compile.mode import Mode +from pytensor.compile.mode import get_mode from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.configdefaults import config from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op, get_test_value -from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.ifelse import ifelse -from pytensor.link.jax import JAXLinker from pytensor.raise_op import assert_op from pytensor.tensor.type import dscalar, scalar, vector @@ -27,9 +25,9 @@ def set_pytensor_flags(): jax = pytest.importorskip("jax") -opts = RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"]) -jax_mode = Mode(JAXLinker(), opts) -py_mode = Mode("py", opts) +# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs +jax_mode = get_mode("JAX") +py_mode = get_mode("FAST_COMPILE") def compare_jax_and_py( diff --git a/tests/link/jax/test_loop.py b/tests/link/jax/test_loop.py new file mode 100644 index 0000000000..64a28444e5 --- /dev/null +++ b/tests/link/jax/test_loop.py @@ -0,0 +1,104 @@ +import numpy as np +import pytest + +from pytensor import config, function, shared +from pytensor.graph import FunctionGraph +from pytensor.loop.basic import scan +from pytensor.scan import until +from pytensor.tensor import scalar, vector, zeros +from pytensor.tensor.random import normal +from tests.link.jax.test_basic import compare_jax_and_py + + +def test_scan_with_single_sequence(): + xs = vector("xs") + _, [ys] = scan(lambda x: x * 100, sequences=[xs]) + + out_fg = FunctionGraph([xs], [ys]) + compare_jax_and_py(out_fg, [np.arange(10, dtype=config.floatX)]) + + +def test_scan_with_single_sequence_shortened_by_nsteps(): + xs = vector("xs", shape=(10,)) # JAX needs the length to be constant + _, [ys] = scan( + lambda x: x * 100, + sequences=[xs], + n_steps=9, + ) + + out_fg = FunctionGraph([xs], [ys]) + compare_jax_and_py(out_fg, [np.arange(10, dtype=config.floatX)]) + + +def test_scan_with_multiple_sequences(): + # JAX can only handle constant n_steps + xs = vector("xs", shape=(10,)) + ys = vector("ys", shape=(10,)) + _, [zs] = scan( + fn=lambda x, y: x * y, + sequences=[xs, ys], + ) + + out_fg = FunctionGraph([xs, ys], [zs]) + compare_jax_and_py( + out_fg, [np.arange(10, dtype=xs.dtype), np.arange(10, dtype=ys.dtype)] + ) + + +def test_scan_with_carried_and_non_carried_states(): + x = scalar("x") + _, [ys1, ys2] = scan( + fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2), + init_states=[x, None], + n_steps=10, + ) + out_fg = FunctionGraph([x], [ys1, ys2]) + compare_jax_and_py(out_fg, [-1]) + + +def test_scan_with_sequence_and_carried_state(): + xs = vector("xs") + _, [ys] = scan( + fn=lambda x, ytm1: (ytm1 + 1) * x, + init_states=[zeros(())], + sequences=[xs], + ) + out_fg = FunctionGraph([xs], [ys]) + compare_jax_and_py(out_fg, [[1, 2, 3]]) + + +def test_scan_with_rvs(): + rng = shared(np.random.default_rng(123)) + + [next_rng, _], [_, xs] = scan( + fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs, + init_states=[rng, None], + n_steps=10, + ) + + # First without updates + fn = function([], xs, mode="JAX", updates=None) + res1 = fn() + res2 = fn() + assert not set(tuple(np.array(res1))) ^ set(tuple(np.array(res2))) + + # Now with updates + fn = function([], xs, mode="JAX", updates={rng: next_rng}) + res1 = fn() + res2 = fn() + assert not set(tuple(np.array(res1))) & set(tuple(np.array(res2))) + + +def test_while_scan_fails(): + _, [xs] = scan( + fn=lambda x: (x + 1, until((x + 1) >= 9)), + init_states=[-1], + n_steps=20, + ) + + out_fg = FunctionGraph([], [xs]) + with pytest.raises( + NotImplementedError, + match="Scan ops with while condition cannot be transpiled JAX", + ): + compare_jax_and_py(out_fg, []) From e2fdf28bfce36cd1c4bf636975eec08a5173b1fc Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 16 Jan 2023 11:20:29 +0100 Subject: [PATCH 4/6] Override __bool__ of TypedListType --- pytensor/typed_list/basic.py | 20 ++++++++++++++++++++ tests/typed_list/test_type.py | 4 ++++ 2 files changed, 24 insertions(+) diff --git a/pytensor/typed_list/basic.py b/pytensor/typed_list/basic.py index 1b836c0bfa..f458779064 100644 --- a/pytensor/typed_list/basic.py +++ b/pytensor/typed_list/basic.py @@ -19,6 +19,11 @@ def __getitem__(self, index): def __len__(self): return length(self) + def __bool__(self): + # Truthiness of typedLists cannot depend on length, + # just like truthiness of TensorVariables does not depend on size or contents + return True + def append(self, toAppend): return append(self, toAppend) @@ -677,3 +682,18 @@ def perform(self, node, inputs, outputs): All PyTensor variables must have the same type. """ + + +class MakeEmptyList(Op): + __props__ = () + + def make_node(self, ttype): + tl = TypedListType(ttype)() + return Apply(self, [], [tl]) + + def perform(self, node, inputs, outputs): + (out,) = outputs + out[0] = [] + + +make_empty_list = MakeEmptyList() diff --git a/tests/typed_list/test_type.py b/tests/typed_list/test_type.py index 41c1a8326b..a5281e21d1 100644 --- a/tests/typed_list/test_type.py +++ b/tests/typed_list/test_type.py @@ -150,3 +150,7 @@ def test_variable_is_Typed_List_variable(self): )() assert isinstance(mySymbolicVariable, TypedListVariable) + + def test_any(self): + tlist = TypedListType(TensorType(dtype="int64", shape=(None,)))() + assert any([tlist]) From db7068e6ce0755454ecc792132cbf73af7c1fcae Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 16 Jan 2023 11:23:22 +0100 Subject: [PATCH 5/6] Allow non-TensorVariable types to be traced in new Scan Op --- pytensor/link/jax/dispatch/loop.py | 19 +++++++--- pytensor/loop/op.py | 60 ++++++++++++++++++++++-------- tests/link/jax/test_loop.py | 13 ++++++- tests/loop/test_op.py | 37 ++++++++++++++---- 4 files changed, 100 insertions(+), 29 deletions(-) diff --git a/pytensor/link/jax/dispatch/loop.py b/pytensor/link/jax/dispatch/loop.py index 2c5b9609d9..d17d740d15 100644 --- a/pytensor/link/jax/dispatch/loop.py +++ b/pytensor/link/jax/dispatch/loop.py @@ -1,8 +1,10 @@ import jax +from jax.tree_util import tree_flatten, tree_unflatten from pytensor.compile.mode import get_mode from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.loop.op import Scan +from pytensor.typed_list import TypedListType @jax_funcify.register(Scan) @@ -43,10 +45,17 @@ def scan_fn(carry, _): states, traces = jax.lax.scan( scan_fn, init=list(states), xs=None, length=max_iters ) - for i in range(len(states)): - if i not in used_traces_idxs: - traces.insert(i, None) - - return *states, *traces + final_traces = [None] * len(states) + for idx, trace in zip(used_traces_idxs, traces): + if isinstance(op.trace_types[idx], TypedListType): + flattened_trace, treedef = tree_flatten(trace) + transposed_trace = [ + tree_unflatten(treedef, l) for l in zip(*flattened_trace) + ] + final_traces[idx] = transposed_trace + else: + final_traces[idx] = trace + + return *states, *final_traces return scan diff --git a/pytensor/loop/op.py b/pytensor/loop/op.py index ba14e383ed..565d2a3b7c 100644 --- a/pytensor/loop/op.py +++ b/pytensor/loop/op.py @@ -7,18 +7,13 @@ from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter from pytensor.graph.rewriting.basic import in2out from pytensor.scalar import constant -from pytensor.tensor import ( - NoneConst, - add, - and_, - empty, - get_scalar_constant_value, - set_subtensor, -) +from pytensor.tensor import add, and_, empty, get_scalar_constant_value, set_subtensor from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import Shape_i +from pytensor.tensor.subtensor import Subtensor, get_idx_list from pytensor.tensor.type import DenseTensorType, TensorType from pytensor.tensor.type_other import NoneTypeT +from pytensor.typed_list import GetItem, TypedListType, append, make_empty_list def validate_loop_update_types(update): @@ -176,8 +171,7 @@ def __init__( ) ) else: - # We can't concatenate all types of states, such as RandomTypes - self.trace_types.append(NoneConst.type) + self.trace_types.append(TypedListType(state_type)) self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]] self.n_constants = len(self.constant_types) @@ -312,10 +306,6 @@ def scan(fn, idx, initial_states, constants, max_iters): if fgraph.clients[trace] ] - # Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced - for trace_idx in used_traces_idxs: - assert not isinstance(old_states[trace_idx].type, NoneTypeT) - # Inputs to the new Loop max_iters = node.inputs[0] init_states = node.inputs[1 : 1 + op.n_states] @@ -324,6 +314,8 @@ def scan(fn, idx, initial_states, constants, max_iters): (max_iters, *tuple(init_states[trace_idx].shape)), dtype=init_states[trace_idx].dtype, ) + if isinstance(init_states[trace_idx].type, DenseTensorType) + else make_empty_list(init_states[trace_idx].type) for trace_idx in used_traces_idxs ] constants = node.inputs[1 + op.n_states :] @@ -387,6 +379,8 @@ def scan(fn, idx, initial_states, constants, max_iters): inner_while_cond, *inner_next_states = update_fg.outputs inner_next_traces = [ set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx]) + if isinstance(prev_trace.type, DenseTensorType) + else append(prev_trace, inner_next_states[trace_idx]) for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces) ] for t in inner_next_traces: @@ -429,7 +423,7 @@ def scan(fn, idx, initial_states, constants, max_iters): replacements = dict(zip(old_states, new_states)) for trace_idx, new_trace in zip(used_traces_idxs, new_traces): # If there is no while condition, the whole trace will be used - if op.has_while_condition: + if op.has_while_condition and isinstance(new_trace.type, DenseTensorType): new_trace = new_trace[:final_idx] replacements[old_traces[trace_idx]] = new_trace return replacements @@ -446,3 +440,39 @@ def scan(fn, idx, initial_states, constants, max_iters): "not_jax", position=1.0, ) + + +@node_rewriter([Scan]) +def scan_view_last_state(fgraph, node): + """Replace trace[-1] by the last state output of a Scan node""" + replacements = {} + for final_state, trace in zip( + node.outputs[: node.op.n_states], node.outputs[node.op.n_states :] + ): + clients = fgraph.clients[trace] + for client, _ in clients: + if client == "output": + continue + if isinstance(client.op, (Subtensor, GetItem)): + if isinstance(client.op, Subtensor): + idxs = get_idx_list(client.inputs, client.op.idx_list) + if len(idxs) == 1: + idx = idxs[0] + else: + idx = client.inputs[1] + try: + last_index = get_scalar_constant_value(idx) == -1 + except NotScalarConstantError: + continue + if last_index: + replacements[client.default_output()] = final_state + return replacements + + +optdb.register( + "scan_view_last_state", + in2out(scan_view_last_state), + "fast_compile", + "fast_run", + position=0.999, +) diff --git a/tests/link/jax/test_loop.py b/tests/link/jax/test_loop.py index 64a28444e5..78c3f07964 100644 --- a/tests/link/jax/test_loop.py +++ b/tests/link/jax/test_loop.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from jax.tree_util import tree_leaves from pytensor import config, function, shared from pytensor.graph import FunctionGraph @@ -70,7 +71,7 @@ def test_scan_with_sequence_and_carried_state(): def test_scan_with_rvs(): rng = shared(np.random.default_rng(123)) - [next_rng, _], [_, xs] = scan( + [final_rng, _], [rngs, xs] = scan( fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs, init_states=[rng, None], n_steps=10, @@ -83,11 +84,19 @@ def test_scan_with_rvs(): assert not set(tuple(np.array(res1))) ^ set(tuple(np.array(res2))) # Now with updates - fn = function([], xs, mode="JAX", updates={rng: next_rng}) + fn = function([], xs, mode="JAX", updates={rng: final_rng}) res1 = fn() res2 = fn() assert not set(tuple(np.array(res1))) & set(tuple(np.array(res2))) + # Test traced rngs + fn = function([], [rngs, final_rng], mode="JAX") + rngs_res, final_rng_res = fn() + assert isinstance(rngs_res, list) and len(rngs_res) == 10 + assert [np.array(v).tolist() for v in tree_leaves(rngs_res[-1])] == [ + np.array(v).tolist() for v in tree_leaves(final_rng_res) + ] + def test_while_scan_fails(): _, [xs] = scan( diff --git a/tests/loop/test_op.py b/tests/loop/test_op.py index c3de8f5c36..c0e46aa7b8 100644 --- a/tests/loop/test_op.py +++ b/tests/loop/test_op.py @@ -4,11 +4,13 @@ from pytensor import config, function, shared from pytensor.compile import DeepCopyOp from pytensor.graph import FunctionGraph -from pytensor.loop.op import Loop, Scan +from pytensor.graph.rewriting.basic import in2out +from pytensor.loop.op import Loop, Scan, scan_view_last_state from pytensor.tensor import constant, empty, lscalar, scalar, vector from pytensor.tensor.random import normal +from pytensor.tensor.random.type import RandomGeneratorType from pytensor.tensor.subtensor import Subtensor -from pytensor.tensor.type_other import NoneTypeT +from pytensor.typed_list import TypedListType def test_loop_basic(): @@ -154,10 +156,31 @@ def test_fori_random_scan(): [constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]], ) - _, new_rng, ys, rngs = Scan(update_fg=update_fg)(n_iters, dummy_init, rng_shared) - assert isinstance(rngs.type, NoneTypeT) + last_y, last_rng, ys, rngs = Scan(update_fg=update_fg)( + n_iters, dummy_init, rng_shared + ) + assert isinstance(last_rng.type, RandomGeneratorType) + assert isinstance(rngs.type, TypedListType) + assert isinstance(rngs.type.ttype, RandomGeneratorType) + + fn = function([], [ys, rngs], updates={rng_shared: last_rng}) + for i in range(2): + ys_res, rngs_res = fn() + for y_res, rng_res in zip(ys_res, rngs_res): + np.testing.assert_almost_equal(y_res, rng_test.normal()) + assert rng_res.__getstate__() == rng_test.__getstate__() - fn = function([], ys, updates={rng_shared: new_rng}) - np.testing.assert_array_equal(fn(), rng_test.normal(size=5)) - np.testing.assert_array_equal(fn(), rng_test.normal(size=5)) +def test_scan_view_last_state(): + x = scalar("x") + update_fg = FunctionGraph([x], [x > 5, x + 2]) + + n_iters = 10 + y1, ys = Scan(update_fg=update_fg)(n_iters, x) + + y2 = ys[-1] + fgraph = FunctionGraph(outputs=[y2, ys], clone=False) + assert fgraph.outputs[0] is not y1 + in2out(scan_view_last_state).apply(fgraph) + assert fgraph.outputs[0] is y1 + assert fgraph.outputs[1] is ys From 5bc7070a82d3937d44a41be319f930bd49dc1e98 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 16 Jan 2023 14:27:27 +0100 Subject: [PATCH 6/6] Make scan helper return sequences to match old API This was not possible prior to use of TypedListType for non TensorVariable sequences, as it would otherwise not be possible to represent indexing of last sequence state, which is needed e.g., for shared random generator updates. --- pytensor/loop/basic.py | 35 +++++++++++++++++------------------ tests/link/jax/test_loop.py | 15 ++++++++------- tests/loop/test_basic.py | 35 ++++++++++++++++++++++++++++------- 3 files changed, 53 insertions(+), 32 deletions(-) diff --git a/pytensor/loop/basic.py b/pytensor/loop/basic.py index 56ba41e8ad..cbe3efa52b 100644 --- a/pytensor/loop/basic.py +++ b/pytensor/loop/basic.py @@ -1,5 +1,5 @@ import functools -from typing import List, Tuple +from typing import List, Union import numpy as np @@ -18,7 +18,7 @@ def scan( non_sequences=None, n_steps=None, go_backwards=False, -) -> Tuple[List[Variable], List[Variable]]: +) -> Union[Variable, List[Variable]]: if sequences is None and n_steps is None: raise ValueError("Must provide n_steps when scanning without sequences") @@ -126,10 +126,11 @@ def scan( n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs ) assert isinstance(scan_outs, list) - last_states = scan_outs[: scan_op.n_states] - traces = scan_outs[scan_op.n_states :] - # Don't return the inner index state - return last_states[1:], traces[1:] + # Don't return the last states or the trace for the inner index + traces = scan_outs[scan_op.n_states + 1 :] + if len(traces) == 1: + return traces[0] + return traces def map( @@ -138,14 +139,12 @@ def map( non_sequences=None, go_backwards=False, ): - _, traces = scan( + traces = scan( fn=fn, sequences=sequences, non_sequences=non_sequences, go_backwards=go_backwards, ) - if len(traces) == 1: - return traces[0] return traces @@ -156,16 +155,16 @@ def reduce( non_sequences=None, go_backwards=False, ): - final_states, _ = scan( + traces = scan( fn=fn, init_states=init_states, sequences=sequences, non_sequences=non_sequences, go_backwards=go_backwards, ) - if len(final_states) == 1: - return final_states[0] - return final_states + if not isinstance(traces, list): + return traces[-1] + return [trace[-1] for trace in traces] def filter( @@ -177,21 +176,21 @@ def filter( if not isinstance(sequences, (tuple, list)): sequences = [sequences] - _, masks = scan( + masks = scan( fn=fn, sequences=sequences, non_sequences=non_sequences, go_backwards=go_backwards, ) - if not all(mask.dtype == "bool" for mask in masks): - raise TypeError("The output of filter fn should be a boolean variable") - if len(masks) == 1: - masks = [masks[0]] * len(sequences) + if not isinstance(masks, list): + masks = [masks] * len(sequences) elif len(masks) != len(sequences): raise ValueError( "filter fn must return one variable or len(sequences), but it returned {len(masks)}" ) + if not all(mask.dtype == "bool" for mask in masks): + raise TypeError("The output of filter fn should be a boolean variable") filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)] diff --git a/tests/link/jax/test_loop.py b/tests/link/jax/test_loop.py index 78c3f07964..f7efc527fc 100644 --- a/tests/link/jax/test_loop.py +++ b/tests/link/jax/test_loop.py @@ -13,7 +13,7 @@ def test_scan_with_single_sequence(): xs = vector("xs") - _, [ys] = scan(lambda x: x * 100, sequences=[xs]) + ys = scan(lambda x: x * 100, sequences=[xs]) out_fg = FunctionGraph([xs], [ys]) compare_jax_and_py(out_fg, [np.arange(10, dtype=config.floatX)]) @@ -21,7 +21,7 @@ def test_scan_with_single_sequence(): def test_scan_with_single_sequence_shortened_by_nsteps(): xs = vector("xs", shape=(10,)) # JAX needs the length to be constant - _, [ys] = scan( + ys = scan( lambda x: x * 100, sequences=[xs], n_steps=9, @@ -35,7 +35,7 @@ def test_scan_with_multiple_sequences(): # JAX can only handle constant n_steps xs = vector("xs", shape=(10,)) ys = vector("ys", shape=(10,)) - _, [zs] = scan( + zs = scan( fn=lambda x, y: x * y, sequences=[xs, ys], ) @@ -48,7 +48,7 @@ def test_scan_with_multiple_sequences(): def test_scan_with_carried_and_non_carried_states(): x = scalar("x") - _, [ys1, ys2] = scan( + [ys1, ys2] = scan( fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2), init_states=[x, None], n_steps=10, @@ -59,7 +59,7 @@ def test_scan_with_carried_and_non_carried_states(): def test_scan_with_sequence_and_carried_state(): xs = vector("xs") - _, [ys] = scan( + ys = scan( fn=lambda x, ytm1: (ytm1 + 1) * x, init_states=[zeros(())], sequences=[xs], @@ -71,11 +71,12 @@ def test_scan_with_sequence_and_carried_state(): def test_scan_with_rvs(): rng = shared(np.random.default_rng(123)) - [final_rng, _], [rngs, xs] = scan( + [rngs, xs] = scan( fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs, init_states=[rng, None], n_steps=10, ) + final_rng = rngs[-1] # First without updates fn = function([], xs, mode="JAX", updates=None) @@ -99,7 +100,7 @@ def test_scan_with_rvs(): def test_while_scan_fails(): - _, [xs] = scan( + xs = scan( fn=lambda x: (x + 1, until((x + 1) >= 9)), init_states=[-1], n_steps=20, diff --git a/tests/loop/test_basic.py b/tests/loop/test_basic.py index 32852f4719..e66ee31793 100644 --- a/tests/loop/test_basic.py +++ b/tests/loop/test_basic.py @@ -1,16 +1,17 @@ import numpy as np import pytensor -from pytensor import config, function, grad +from pytensor import config, function, grad, shared from pytensor.loop.basic import filter, map, reduce, scan from pytensor.scan import until from pytensor.tensor import arange, eq, scalar, vector, zeros +from pytensor.tensor.random import normal def test_scan_with_sequences(): xs = vector("xs") ys = vector("ys") - _, [zs] = scan( + zs = scan( fn=lambda x, y: x * y, sequences=[xs, ys], ) @@ -28,7 +29,7 @@ def test_scan_with_sequences(): def test_scan_with_carried_and_non_carried_states(): x = scalar("x") - _, [ys1, ys2] = scan( + [ys1, ys2] = scan( fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2), init_states=[x, None], n_steps=10, @@ -41,7 +42,7 @@ def test_scan_with_carried_and_non_carried_states(): def test_scan_with_sequence_and_carried_state(): xs = vector("xs") - _, [ys] = scan( + ys = scan( fn=lambda x, ytm1: (ytm1 + 1) * x, init_states=[zeros(())], sequences=[xs], @@ -55,7 +56,7 @@ def test_scan_taking_grads_wrt_non_sequence(): xs = vector("xs") ys = xs**2 - _, [J] = scan( + J = scan( lambda i, ys, xs: grad(ys[i], wrt=xs), sequences=arange(ys.shape[0]), non_sequences=[ys, xs], @@ -70,7 +71,7 @@ def test_scan_taking_grads_wrt_sequence(): xs = vector("xs") ys = xs**2 - _, [J] = scan( + J = scan( lambda y, xs: grad(y, wrt=xs), sequences=[ys], non_sequences=[xs], @@ -81,7 +82,7 @@ def test_scan_taking_grads_wrt_sequence(): def test_while_scan(): - _, [xs] = scan( + xs = scan( fn=lambda x: (x + 1, until((x + 1) >= 9)), init_states=[-1], n_steps=20, @@ -91,6 +92,26 @@ def test_while_scan(): np.testing.assert_array_equal(f(), np.arange(10)) +def test_scan_rvs(): + rng = shared(np.random.default_rng(123)) + test_rng = np.random.default_rng(123) + + def normal_fn(prev_rng): + next_rng, x = normal(rng=prev_rng).owner.outputs + return next_rng, x + + [rngs, xs] = scan( + fn=normal_fn, + init_states=[rng, None], + n_steps=5, + ) + fn = function([], xs, updates={rng: rngs[-1]}) + + for i in range(3): + res = fn() + np.testing.assert_almost_equal(res, test_rng.normal(size=5)) + + def test_map(): xs = vector("xs") ys = map(