diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 8aecf1a902..dc661fe0de 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -449,7 +449,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 
 JAX = Mode(
     JAXLinker(),
-    RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
+    RewriteDatabaseQuery(
+        include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt", "not_jax"]
+    ),
 )
 NUMBA = Mode(
     NumbaLinker(),
diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py
index efcac788cf..2f08bbbb94 100644
--- a/pytensor/link/jax/dispatch/__init__.py
+++ b/pytensor/link/jax/dispatch/__init__.py
@@ -12,5 +12,6 @@
 import pytensor.link.jax.dispatch.random
 import pytensor.link.jax.dispatch.elemwise
 import pytensor.link.jax.dispatch.scan
+import pytensor.link.jax.dispatch.loop
 
 # isort: on
diff --git a/pytensor/link/jax/dispatch/loop.py b/pytensor/link/jax/dispatch/loop.py
new file mode 100644
index 0000000000..d17d740d15
--- /dev/null
+++ b/pytensor/link/jax/dispatch/loop.py
@@ -0,0 +1,61 @@
+import jax
+from jax.tree_util import tree_flatten, tree_unflatten
+
+from pytensor.compile.mode import get_mode
+from pytensor.link.jax.dispatch.basic import jax_funcify
+from pytensor.loop.op import Scan
+from pytensor.typed_list import TypedListType
+
+
+@jax_funcify.register(Scan)
+def jax_funcify_Scan(op, node, global_fgraph, **kwargs):
+    # TODO: Rewrite as a while loop if only last states are used
+    if op.has_while_condition:
+        raise NotImplementedError(
+            "Scan ops with while condition cannot be transpiled JAX"
+        )
+
+    # Apply inner rewrites
+    # TODO: Not sure this is the right place to do this, should we have a rewrite that
+    # explicitly triggers the optimization of the inner graphs of Scan?
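+    # Note: the inner graph is cloned below so these rewrites do not mutate
+    # `op.update_fg`, and the JAX-mode rewrite set is applied so the inner graph
+    # gets the same treatment as the outer graph before it is funcified.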
+    update_fg = op.update_fg.clone()
+    rewriter = get_mode("JAX").optimizer
+    rewriter(update_fg)
+
+    jaxified_scan_inner_fn = jax_funcify(update_fg, **kwargs)
+
+    # Only include the intermediate states that are used elsewhere
+    used_traces_idxs = [
+        i
+        for i, trace in enumerate(node.outputs[op.n_states :])
+        if global_fgraph.clients[trace]
+    ]
+
+    def scan(max_iters, *outer_inputs):
+        states = outer_inputs[: op.n_states]
+        constants = outer_inputs[op.n_states :]
+
+        def scan_fn(carry, _):
+            resume, *carry = jaxified_scan_inner_fn(*carry, *constants)
+            assert resume
+            carry = list(carry)
+            # Return states as both carry and output to be appended
+            return carry, [c for i, c in enumerate(carry) if i in used_traces_idxs]
+
+        states, traces = jax.lax.scan(
+            scan_fn, init=list(states), xs=None, length=max_iters
+        )
+        final_traces = [None] * len(states)
+        for idx, trace in zip(used_traces_idxs, traces):
+            if isinstance(op.trace_types[idx], TypedListType):
+                flattened_trace, treedef = tree_flatten(trace)
+                transposed_trace = [
+                    tree_unflatten(treedef, l) for l in zip(*flattened_trace)
+                ]
+                final_traces[idx] = transposed_trace
+            else:
+                final_traces[idx] = trace
+
+        return *states, *final_traces
+
+    return scan
diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py
index fd76e1278e..09d70756f7 100644
--- a/pytensor/link/utils.py
+++ b/pytensor/link/utils.py
@@ -736,6 +736,7 @@ def fgraph_to_python(
         global_env = {}
 
     body_assigns = []
+    kwargs.setdefault("global_fgraph", fgraph)
     for node in order:
         compiled_func = op_conversion_fn(
             node.op, node=node, storage_map=storage_map, **kwargs
diff --git a/pytensor/loop/__init__.py b/pytensor/loop/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pytensor/loop/basic.py b/pytensor/loop/basic.py
new file mode 100644
index 0000000000..cbe3efa52b
--- /dev/null
+++ b/pytensor/loop/basic.py
@@ -0,0 +1,199 @@
+import functools
+from typing import List, Union
+
+import numpy as np
+
+from pytensor import Variable, as_symbolic, clone_replace
+from pytensor.graph import FunctionGraph
+from pytensor.graph.basic import Constant, truncated_graph_inputs
+from pytensor.loop.op import Scan
+from pytensor.scan.utils import until
+from pytensor.tensor import as_tensor, constant, empty_like, minimum
+
+
+def scan(
+    fn,
+    init_states=None,
+    sequences=None,
+    non_sequences=None,
+    n_steps=None,
+    go_backwards=False,
+) -> Union[Variable, List[Variable]]:
+    if sequences is None and n_steps is None:
+        raise ValueError("Must provide n_steps when scanning without sequences")
+
+    if init_states is None:
+        init_states = []
+    else:
+        if not isinstance(init_states, (tuple, list)):
+            init_states = [init_states]
+        init_states = [as_symbolic(i) if i is not None else None for i in init_states]
+
+    if sequences is None:
+        sequences = []
+    else:
+        if not isinstance(sequences, (tuple, list)):
+            sequences = [sequences]
+        sequences = [as_tensor(s) for s in sequences]
+
+    if sequences:
+        leading_dims = [seq.shape[0] for seq in sequences]
+        shortest_dim = functools.reduce(minimum, leading_dims)
+        if n_steps is None:
+            n_steps = shortest_dim
+        else:
+            n_steps = minimum(n_steps, shortest_dim)
+
+    if non_sequences is None:
+        non_sequences = []
+    else:
+        if not isinstance(non_sequences, (tuple, list)):
+            non_sequences = [non_sequences]
+        non_sequences = [as_symbolic(n) for n in non_sequences]
+
+    # Create dummy inputs for the init state.
The user function should not + # draw any relationship with the outer initial states, since these are only + # valid in the first iteration + inner_states = [i.type() if i is not None else None for i in init_states] + + # Create subsequence inputs for the inner function + idx = constant(0, dtype="int64", name="idx") + symbolic_idx = idx.type(name="idx") + subsequences = [s[symbolic_idx] for s in sequences] + + # Call user function to retrieve inner outputs. We use the same order as the old Scan, + # although inner_states + subsequences + non_sequences seems more intuitive, + # since subsequences are just a fancy non_sequence + # We don't pass the non-carried outputs [init is None] to the inner function + fn_inputs = ( + subsequences + [i for i in inner_states if i is not None] + non_sequences + ) + fn_outputs = fn(*fn_inputs) + if not isinstance(fn_outputs, (tuple, list)): + fn_outputs = [fn_outputs] + next_states = [out for out in fn_outputs if not isinstance(out, until)] + + if len(next_states) > len(init_states): + if not init_states: + init_states = [None] * len(next_states) + inner_states = init_states + else: + raise ValueError( + "Please provide None as `init` for any output that is not carried over (i.e. it behaves like a map) " + ) + + # Replace None init by dummy empty tensors + prev_states = [] + prev_inner_states = [] + for i, (init_state, inner_state, next_state) in enumerate( + zip(init_states, inner_states, next_states) + ): + if init_state is None: + # next_state may reference idx. We replace that by the initial value, + # so that the shape of the dummy init state does not depend on it. + [next_state] = clone_replace( + output=[next_state], replace={symbolic_idx: idx} + ) + init_state = empty_like(next_state) + init_state.name = "empty_init_state" + inner_state = init_state.type(name="dummy_state") + prev_states.append(init_state) + prev_inner_states.append(inner_state) + + # Flip until to while condition + while_condition = [~out.condition for out in fn_outputs if isinstance(out, until)] + if not while_condition: + while_condition = [as_tensor(np.array(True))] + if len(while_condition) > 1: + raise ValueError("Only one until condition can be returned") + + fgraph_inputs = [symbolic_idx] + prev_inner_states + sequences + non_sequences + fgraph_outputs = while_condition + [symbolic_idx + 1] + next_states + + all_fgraph_inputs = truncated_graph_inputs( + fgraph_outputs, ancestors_to_include=fgraph_inputs + ) + extra_fgraph_inputs = [ + inp + for inp in all_fgraph_inputs + if (not isinstance(inp, Constant) and inp not in fgraph_inputs) + ] + fgraph_inputs = fgraph_inputs + extra_fgraph_inputs + update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs) + + scan_op = Scan(update_fg=update_fg) + scan_outs = scan_op( + n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs + ) + assert isinstance(scan_outs, list) + # Don't return the last states or the trace for the inner index + traces = scan_outs[scan_op.n_states + 1 :] + if len(traces) == 1: + return traces[0] + return traces + + +def map( + fn, + sequences, + non_sequences=None, + go_backwards=False, +): + traces = scan( + fn=fn, + sequences=sequences, + non_sequences=non_sequences, + go_backwards=go_backwards, + ) + return traces + + +def reduce( + fn, + init_states, + sequences, + non_sequences=None, + go_backwards=False, +): + traces = scan( + fn=fn, + init_states=init_states, + sequences=sequences, + non_sequences=non_sequences, + go_backwards=go_backwards, + ) + if not 
isinstance(traces, list): + return traces[-1] + return [trace[-1] for trace in traces] + + +def filter( + fn, + sequences, + non_sequences=None, + go_backwards=False, +): + if not isinstance(sequences, (tuple, list)): + sequences = [sequences] + + masks = scan( + fn=fn, + sequences=sequences, + non_sequences=non_sequences, + go_backwards=go_backwards, + ) + + if not isinstance(masks, list): + masks = [masks] * len(sequences) + elif len(masks) != len(sequences): + raise ValueError( + "filter fn must return one variable or len(sequences), but it returned {len(masks)}" + ) + if not all(mask.dtype == "bool" for mask in masks): + raise TypeError("The output of filter fn should be a boolean variable") + + filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)] + + if len(filtered_sequences) == 1: + return filtered_sequences[0] + return filtered_sequences diff --git a/pytensor/loop/op.py b/pytensor/loop/op.py new file mode 100644 index 0000000000..565d2a3b7c --- /dev/null +++ b/pytensor/loop/op.py @@ -0,0 +1,478 @@ +from typing import Optional + +import numpy as np + +from pytensor import In, Out +from pytensor.compile import optdb, pfunc +from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter +from pytensor.graph.rewriting.basic import in2out +from pytensor.scalar import constant +from pytensor.tensor import add, and_, empty, get_scalar_constant_value, set_subtensor +from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.shape import Shape_i +from pytensor.tensor.subtensor import Subtensor, get_idx_list +from pytensor.tensor.type import DenseTensorType, TensorType +from pytensor.tensor.type_other import NoneTypeT +from pytensor.typed_list import GetItem, TypedListType, append, make_empty_list + + +def validate_loop_update_types(update): + assert update.outputs[0].type.dtype == "bool" + for i, (input_state, output_state) in enumerate( + zip(update.inputs, update.outputs[1:]) + ): + if input_state.type != output_state.type: + raise TypeError( + f"The {i}-th input and output states of the inner loop function have different types: " + f"{input_state.type} vs {output_state.type}." + ) + + +class Loop(Op): + """Represent a do-while loop. + + We represent the loop body as an inner FunctionGraph, which + computes the next state and whether the loop should continue. 
+ + Roughly equivalent to + ``` + def loop(fn, initial_state, constants): + state = initial_state + while True: + resume, state = fn(i, state, *constants) + if not resume: + break + return state + ``` + Multiple initial states and constants can be provided + """ + + def __init__( + self, + update_fg: FunctionGraph, # (*state, *consts) -> (bool, *state) + reverse_fg: Optional[FunctionGraph] = None, + ): + validate_loop_update_types(update_fg) + self.state_types = [out.type for out in update_fg.outputs[1:]] + self.const_types = [ + inp.type for inp in update_fg.inputs[len(self.state_types) :] + ] + self.update_fg = update_fg + self.reverse_fg = reverse_fg + self._fn = None + + @property + def fn(self): + """Lazily compile the inner update function graph.""" + if self._fn is not None: + return self._fn + + fgraph = self.update_fg + wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs] + wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs] + + self._fn = pfunc( + wrapped_inputs, + wrapped_outputs, + mode="FAST_RUN", # TODO: Figure this out + accept_inplace=False, + on_unused_input="ignore", + fgraph=fgraph, + ) + return self._fn + + def make_node(self, *inputs): + assert len(inputs) == len(self.state_types) + len(self.const_types) + + states = inputs[: len(self.state_types)] + states = [ + inp_type.filter_variable(inp) + for inp_type, inp in zip(self.state_types, states) + ] + + consts = inputs[len(self.state_types) :] + consts = [ + inp_type.filter_variable(inp) + for inp_type, inp in zip(self.const_types, consts) + ] + + return Apply( + self, + [*states, *consts], + [state_type() for state_type in self.state_types], + ) + + def infer_shape(self, fgraph, node, input_shapes): + return input_shapes[: len(self.state_types)] + + def perform(self, node, inputs, output_storage): + update_fn = self.fn + + states = inputs[: len(self.state_types)] + consts = inputs[len(self.state_types) :] + resume = True + while resume: + resume, *states = update_fn(*states, *consts) + + for i, state in enumerate(states): + output_storage[i][0] = state + + def L_Op(self, *args): + if not self.reverse_fg: + raise NotImplementedError() + # Use L_Op of self.reverse_fg + ... + + def R_Op(self, *args): + # Use R_op of self.update_fg + ... + + +class Scan(Op): + """Represent a scan. + + This Op can be thought of as a loop that collects intermediate steps + + Roughly equivalent to + ``` + def scan(fn, initial_states, constants, max_iters): + traces = [[]*len(initial_states)] + states = initial_states + for i in range(max_iters): + resume, states = fn(*states, *constants) + for trace, state in zip(traces, states): + trace.append(state) + if not resume: + break + return states, traces + ``` + Not all types of states can be collected, for instance RandomGenerator. For these + `None` is returned in place of the respective traces + + This Op must always be converted to a Loop during compilation. 
+ """ + + def __init__( + self, + update_fg: FunctionGraph, # (*state, *consts) -> (bool, *state) + reverse_fg: Optional[FunctionGraph] = None, + ): + validate_loop_update_types(update_fg) + + self.state_types = [out.type for out in update_fg.outputs[1:]] + self.n_states = len(self.state_types) + self.trace_types: list[Type] = [] + for state_type in self.state_types: + # TODO: Accommodate SparseTensors and Scalars + if isinstance(state_type, DenseTensorType): + self.trace_types.append( + DenseTensorType( + shape=(None, *state_type.shape), dtype=state_type.dtype + ) + ) + else: + self.trace_types.append(TypedListType(state_type)) + + self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]] + self.n_constants = len(self.constant_types) + + self.update_fg = update_fg.clone(check_integrity=False) + self.reverse_fg = ( + reverse_fg.clone(check_integrity=False) if reverse_fg is not None else None + ) + + # It's more conservative to assume the Op has a while condition + self.has_while_condition = True + try: + self.has_while_condition = not get_scalar_constant_value( + update_fg.outputs[0] + ) + except NotScalarConstantError: + pass + + def make_node(self, max_iters, *inputs): + assert len(inputs) == self.n_states + self.n_constants + + max_iters = TensorType(dtype="int64", shape=()).filter_variable(max_iters) + + states = inputs[: self.n_states] + states = [ + inp_type.filter_variable(inp) + for inp_type, inp in zip(self.state_types, states) + ] + + constants = inputs[self.n_states :] + constants = [ + inp_type.filter_variable(inp) + for inp_type, inp in zip(self.constant_types, constants) + ] + + # If there is no while condition, `max_iters` exclusively defines the number of iterations + # If this value is constant, we can get static type shapes for the leading dimensions of traces + trace_types = self.trace_types + if not self.has_while_condition: + try: + n_iters = int(get_scalar_constant_value(max_iters)) + except NotScalarConstantError: + pass + else: + trace_types = [] + for trace_type in self.trace_types: + if isinstance(trace_type, DenseTensorType): + trace_types.append( + DenseTensorType( + dtype=trace_type.dtype, + shape=(n_iters, *trace_type.shape[1:]), + ) + ) + else: + trace_types.append(trace_type) + + return Apply( + self, + [max_iters, *states, *constants], + [output_type() for output_type in self.state_types + trace_types], + ) + + def infer_shape(self, fgraph, node, input_shapes): + # If there is a while condition, `max_iters` provides only the upper bound for the number of iterations + if self.has_while_condition: + # find the first non-None trace + trace_out = next( + trace + for trace in node.outputs[self.n_states :] + if not isinstance(trace.type, NoneTypeT) + ) + n_iters = Shape_i(0)(trace_out) + else: + n_iters = node.inputs[0] # max_iters + + state_shapes = input_shapes[1 : self.n_states + 1] + trace_shapes = [ + (n_iters, *state_shape) if state_shape is not None else None + for state_shape in state_shapes + ] + return state_shapes + trace_shapes + + def do_constant_folding(self, fgraph, node): + return False + + def perform(self, node, inputs, output_storage): + raise RuntimeError("Scan Op should not be present in compiled graph") + + def L_op(self, *args): + # Use trace outputs + ... + + def R_op(self, *args): + # Use R_op of self.update + ... 
+ + +@node_rewriter([Scan]) +def scan_to_loop(fgraph, node): + """Rewrite a Scan Op into a Loop Op + + It roughly creates the following computational graph + ``` + + def scan(fn, idx, initial_states, constants, max_iters): + idx = 0 + states = initial_states + traces = [empty(max_iters, *initial_state.shape) for initial_state in initial_states] + while True: + resume, states, fn(*states, *traces, *constants) + for trace, state in zip(traces, states): + trace[idx] = state + idx += 1 + if not resume or idx >= max_iters: + break + traces = [trace[: idx] for trace in traces] + return states, traces + ``` + + Traces that are not used anywhere in the graph are omitted from the final Loop + + """ + op: Scan = node.op # type: ignore + + old_states = node.outputs[: op.n_states] + old_traces = node.outputs[op.n_states :] + + # Only include the intermediate states that are used elsewhere + used_traces_idxs = [ + i + for i, trace in enumerate(node.outputs[op.n_states :]) + if fgraph.clients[trace] + ] + + # Inputs to the new Loop + max_iters = node.inputs[0] + init_states = node.inputs[1 : 1 + op.n_states] + init_traces = [ + empty( + (max_iters, *tuple(init_states[trace_idx].shape)), + dtype=init_states[trace_idx].dtype, + ) + if isinstance(init_states[trace_idx].type, DenseTensorType) + else make_empty_list(init_states[trace_idx].type) + for trace_idx in used_traces_idxs + ] + constants = node.inputs[1 + op.n_states :] + + update_fg = op.update_fg.clone(check_integrity=False) + + # Check if inner_fg computes an index already, otherwise create a new one + has_idx = False + if len(node.inputs) > 1: + try: + outer_inp = node.inputs[1] + outer_is_zero = get_scalar_constant_value(outer_inp) == 0 + except NotScalarConstantError: + pass + else: + if ( + outer_is_zero + and len(update_fg.inputs) > 0 + and len(update_fg.outputs) > 1 + ): + inner_out = update_fg.outputs[1] + if ( + inner_out.owner is not None + and inner_out.owner.op == add + and len(inner_out.owner.inputs) == 2 + ): + left, right = inner_out.owner.inputs + if left is update_fg.inputs[0]: + try: + has_idx = ( + get_scalar_constant_value( + right, only_process_constants=True + ) + == 1 + ) + except NotScalarConstantError: + pass + + if has_idx: + init_idx = outer_inp + inner_idx = inner_out.owner.inputs[0] + inner_next_idx = inner_out + if not has_idx: + init_idx = constant(np.array(0, dtype="int64"), name="idx") + inner_idx = init_idx.type() + inner_idx.name = "idx" + inner_next_idx = inner_idx + 1 + inner_next_idx.name = "next_idx" + + # Inner traces + inner_states = update_fg.inputs[: op.n_states] + inner_traces = [init_trace.type() for init_trace in init_traces] + for s, t in zip(inner_states, inner_traces): + t.name = "trace" + if s.name: + t.name = "_".join((t.name, s.name)) + + inner_constants = update_fg.inputs[op.n_states :] + + # Inner while condition + inner_while_cond, *inner_next_states = update_fg.outputs + inner_next_traces = [ + set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx]) + if isinstance(prev_trace.type, DenseTensorType) + else append(prev_trace, inner_next_states[trace_idx]) + for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces) + ] + for t in inner_next_traces: + t.name = "next_trace" + inner_max_iters = max_iters.type() + inner_while_cond = and_(inner_while_cond, inner_next_idx < inner_max_iters) + inner_while_cond.name = "while(?)" + + if not has_idx: + init_states = [init_idx] + init_states + inner_states = [inner_idx] + inner_states + inner_next_states = [inner_next_idx] + 
inner_next_states + + new_update_fg = FunctionGraph( + inputs=[ + *inner_states, + *inner_traces, + *inner_constants, + inner_max_iters, + ], + outputs=[ + inner_while_cond, + *inner_next_states, + *inner_next_traces, + ], + ) + + # TODO: Implement Reverse? + loop_op = Loop(update_fg=new_update_fg) + + new_outs = loop_op(*init_states, *init_traces, *constants, max_iters) + if has_idx: + # idx was part of the original scan, and therefore has a corresponding trace + final_idx = new_outs[0] + else: + final_idx, *new_outs = new_outs + new_states = new_outs[: op.n_states] + new_traces = new_outs[op.n_states :] + + replacements = dict(zip(old_states, new_states)) + for trace_idx, new_trace in zip(used_traces_idxs, new_traces): + # If there is no while condition, the whole trace will be used + if op.has_while_condition and isinstance(new_trace.type, DenseTensorType): + new_trace = new_trace[:final_idx] + replacements[old_traces[trace_idx]] = new_trace + return replacements + + +# TODO: Create new Loop dataset +# Needs to be executed after `local_shape_to_shape_i`, otherwise shape graphs +# cannot be properly replaced +optdb.register( + "scan_to_loop", + in2out(scan_to_loop), + "fast_compile", + "fast_run", + "not_jax", + position=1.0, +) + + +@node_rewriter([Scan]) +def scan_view_last_state(fgraph, node): + """Replace trace[-1] by the last state output of a Scan node""" + replacements = {} + for final_state, trace in zip( + node.outputs[: node.op.n_states], node.outputs[node.op.n_states :] + ): + clients = fgraph.clients[trace] + for client, _ in clients: + if client == "output": + continue + if isinstance(client.op, (Subtensor, GetItem)): + if isinstance(client.op, Subtensor): + idxs = get_idx_list(client.inputs, client.op.idx_list) + if len(idxs) == 1: + idx = idxs[0] + else: + idx = client.inputs[1] + try: + last_index = get_scalar_constant_value(idx) == -1 + except NotScalarConstantError: + continue + if last_index: + replacements[client.default_output()] = final_state + return replacements + + +optdb.register( + "scan_view_last_state", + in2out(scan_view_last_state), + "fast_compile", + "fast_run", + position=0.999, +) diff --git a/pytensor/typed_list/basic.py b/pytensor/typed_list/basic.py index 1b836c0bfa..f458779064 100644 --- a/pytensor/typed_list/basic.py +++ b/pytensor/typed_list/basic.py @@ -19,6 +19,11 @@ def __getitem__(self, index): def __len__(self): return length(self) + def __bool__(self): + # Truthiness of typedLists cannot depend on length, + # just like truthiness of TensorVariables does not depend on size or contents + return True + def append(self, toAppend): return append(self, toAppend) @@ -677,3 +682,18 @@ def perform(self, node, inputs, outputs): All PyTensor variables must have the same type. 
""" + + +class MakeEmptyList(Op): + __props__ = () + + def make_node(self, ttype): + tl = TypedListType(ttype)() + return Apply(self, [], [tl]) + + def perform(self, node, inputs, outputs): + (out,) = outputs + out[0] = [] + + +make_empty_list = MakeEmptyList() diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index e20bd255fb..a86c2636c9 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -5,15 +5,13 @@ import pytest from pytensor.compile.function import function -from pytensor.compile.mode import Mode +from pytensor.compile.mode import get_mode from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.configdefaults import config from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op, get_test_value -from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.ifelse import ifelse -from pytensor.link.jax import JAXLinker from pytensor.raise_op import assert_op from pytensor.tensor.type import dscalar, scalar, vector @@ -27,9 +25,9 @@ def set_pytensor_flags(): jax = pytest.importorskip("jax") -opts = RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"]) -jax_mode = Mode(JAXLinker(), opts) -py_mode = Mode("py", opts) +# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs +jax_mode = get_mode("JAX") +py_mode = get_mode("FAST_COMPILE") def compare_jax_and_py( diff --git a/tests/link/jax/test_loop.py b/tests/link/jax/test_loop.py new file mode 100644 index 0000000000..f7efc527fc --- /dev/null +++ b/tests/link/jax/test_loop.py @@ -0,0 +1,114 @@ +import numpy as np +import pytest +from jax.tree_util import tree_leaves + +from pytensor import config, function, shared +from pytensor.graph import FunctionGraph +from pytensor.loop.basic import scan +from pytensor.scan import until +from pytensor.tensor import scalar, vector, zeros +from pytensor.tensor.random import normal +from tests.link.jax.test_basic import compare_jax_and_py + + +def test_scan_with_single_sequence(): + xs = vector("xs") + ys = scan(lambda x: x * 100, sequences=[xs]) + + out_fg = FunctionGraph([xs], [ys]) + compare_jax_and_py(out_fg, [np.arange(10, dtype=config.floatX)]) + + +def test_scan_with_single_sequence_shortened_by_nsteps(): + xs = vector("xs", shape=(10,)) # JAX needs the length to be constant + ys = scan( + lambda x: x * 100, + sequences=[xs], + n_steps=9, + ) + + out_fg = FunctionGraph([xs], [ys]) + compare_jax_and_py(out_fg, [np.arange(10, dtype=config.floatX)]) + + +def test_scan_with_multiple_sequences(): + # JAX can only handle constant n_steps + xs = vector("xs", shape=(10,)) + ys = vector("ys", shape=(10,)) + zs = scan( + fn=lambda x, y: x * y, + sequences=[xs, ys], + ) + + out_fg = FunctionGraph([xs, ys], [zs]) + compare_jax_and_py( + out_fg, [np.arange(10, dtype=xs.dtype), np.arange(10, dtype=ys.dtype)] + ) + + +def test_scan_with_carried_and_non_carried_states(): + x = scalar("x") + [ys1, ys2] = scan( + fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2), + init_states=[x, None], + n_steps=10, + ) + out_fg = FunctionGraph([x], [ys1, ys2]) + compare_jax_and_py(out_fg, [-1]) + + +def test_scan_with_sequence_and_carried_state(): + xs = vector("xs") + ys = scan( + fn=lambda x, ytm1: (ytm1 + 1) * x, + init_states=[zeros(())], + sequences=[xs], + ) + out_fg = FunctionGraph([xs], [ys]) + compare_jax_and_py(out_fg, [[1, 2, 3]]) + + +def test_scan_with_rvs(): + rng = shared(np.random.default_rng(123)) + + [rngs, xs] = scan( + 
fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs, + init_states=[rng, None], + n_steps=10, + ) + final_rng = rngs[-1] + + # First without updates + fn = function([], xs, mode="JAX", updates=None) + res1 = fn() + res2 = fn() + assert not set(tuple(np.array(res1))) ^ set(tuple(np.array(res2))) + + # Now with updates + fn = function([], xs, mode="JAX", updates={rng: final_rng}) + res1 = fn() + res2 = fn() + assert not set(tuple(np.array(res1))) & set(tuple(np.array(res2))) + + # Test traced rngs + fn = function([], [rngs, final_rng], mode="JAX") + rngs_res, final_rng_res = fn() + assert isinstance(rngs_res, list) and len(rngs_res) == 10 + assert [np.array(v).tolist() for v in tree_leaves(rngs_res[-1])] == [ + np.array(v).tolist() for v in tree_leaves(final_rng_res) + ] + + +def test_while_scan_fails(): + xs = scan( + fn=lambda x: (x + 1, until((x + 1) >= 9)), + init_states=[-1], + n_steps=20, + ) + + out_fg = FunctionGraph([], [xs]) + with pytest.raises( + NotImplementedError, + match="Scan ops with while condition cannot be transpiled JAX", + ): + compare_jax_and_py(out_fg, []) diff --git a/tests/loop/__init__.py b/tests/loop/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/loop/test_basic.py b/tests/loop/test_basic.py new file mode 100644 index 0000000000..e66ee31793 --- /dev/null +++ b/tests/loop/test_basic.py @@ -0,0 +1,146 @@ +import numpy as np + +import pytensor +from pytensor import config, function, grad, shared +from pytensor.loop.basic import filter, map, reduce, scan +from pytensor.scan import until +from pytensor.tensor import arange, eq, scalar, vector, zeros +from pytensor.tensor.random import normal + + +def test_scan_with_sequences(): + xs = vector("xs") + ys = vector("ys") + zs = scan( + fn=lambda x, y: x * y, + sequences=[xs, ys], + ) + pytensor.dprint(ys, print_type=True) + np.testing.assert_almost_equal( + zs.eval( + { + xs: np.arange(10, dtype=config.floatX), + ys: np.arange(10, dtype=config.floatX), + } + ), + np.arange(10) ** 2, + ) + + +def test_scan_with_carried_and_non_carried_states(): + x = scalar("x") + [ys1, ys2] = scan( + fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2), + init_states=[x, None], + n_steps=10, + ) + fn = function([x], [ys1, ys2]) + res = fn(-1) + np.testing.assert_almost_equal(res[0], np.arange(10)) + np.testing.assert_almost_equal(res[1], np.arange(10) * 2) + + +def test_scan_with_sequence_and_carried_state(): + xs = vector("xs") + ys = scan( + fn=lambda x, ytm1: (ytm1 + 1) * x, + init_states=[zeros(())], + sequences=[xs], + ) + fn = function([xs], ys) + np.testing.assert_almost_equal(fn([1, 2, 3]), [1, 4, 15]) + + +def test_scan_taking_grads_wrt_non_sequence(): + # Tests sequence + non-carried state + xs = vector("xs") + ys = xs**2 + + J = scan( + lambda i, ys, xs: grad(ys[i], wrt=xs), + sequences=arange(ys.shape[0]), + non_sequences=[ys, xs], + ) + + f = pytensor.function([xs], J) + np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]]) + + +def test_scan_taking_grads_wrt_sequence(): + # This is not possible with the old Scan + xs = vector("xs") + ys = xs**2 + + J = scan( + lambda y, xs: grad(y, wrt=xs), + sequences=[ys], + non_sequences=[xs], + ) + + f = pytensor.function([xs], J) + np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]]) + + +def test_while_scan(): + xs = scan( + fn=lambda x: (x + 1, until((x + 1) >= 9)), + init_states=[-1], + n_steps=20, + ) + + f = pytensor.function([], xs) + np.testing.assert_array_equal(f(), np.arange(10)) + + +def test_scan_rvs(): + rng = 
shared(np.random.default_rng(123)) + test_rng = np.random.default_rng(123) + + def normal_fn(prev_rng): + next_rng, x = normal(rng=prev_rng).owner.outputs + return next_rng, x + + [rngs, xs] = scan( + fn=normal_fn, + init_states=[rng, None], + n_steps=5, + ) + fn = function([], xs, updates={rng: rngs[-1]}) + + for i in range(3): + res = fn() + np.testing.assert_almost_equal(res, test_rng.normal(size=5)) + + +def test_map(): + xs = vector("xs") + ys = map( + fn=lambda x: x * 100, + sequences=xs, + ) + np.testing.assert_almost_equal( + ys.eval({xs: np.arange(10, dtype=config.floatX)}), np.arange(10) * 100 + ) + + +def test_reduce(): + xs = vector("xs") + y = reduce( + fn=lambda x, acc: acc + x, + init_states=zeros(()), + sequences=xs, + ) + np.testing.assert_almost_equal( + y.eval({xs: np.arange(10, dtype=config.floatX)}), np.arange(10).cumsum()[-1] + ) + + +def test_filter(): + xs = vector("xs") + ys = filter( + fn=lambda x: eq(x % 2, 0), + sequences=xs, + ) + np.testing.assert_array_equal( + ys.eval({xs: np.arange(0, 20, dtype=config.floatX)}), np.arange(0, 20, 2) + ) diff --git a/tests/loop/test_op.py b/tests/loop/test_op.py new file mode 100644 index 0000000000..c0e46aa7b8 --- /dev/null +++ b/tests/loop/test_op.py @@ -0,0 +1,186 @@ +import numpy as np + +import pytensor +from pytensor import config, function, shared +from pytensor.compile import DeepCopyOp +from pytensor.graph import FunctionGraph +from pytensor.graph.rewriting.basic import in2out +from pytensor.loop.op import Loop, Scan, scan_view_last_state +from pytensor.tensor import constant, empty, lscalar, scalar, vector +from pytensor.tensor.random import normal +from pytensor.tensor.random.type import RandomGeneratorType +from pytensor.tensor.subtensor import Subtensor +from pytensor.typed_list import TypedListType + + +def test_loop_basic(): + i = lscalar("i") + x = scalar("x") + update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2]) + + loop_op = Loop(update_fg=update_fg) + assert len(loop_op.state_types) == 2 + assert len(loop_op.const_types) == 0 + _, y = loop_op(np.array(0, dtype="int64"), x) + assert y.eval({x: 0}) == 20 + + +def test_loop_with_constant(): + i = lscalar("i") + x = scalar("x") + const = scalar("const") + update_fg = FunctionGraph([i, x, const], [(i + 1) < 10, i + 1, x + const]) + + loop_op = Loop(update_fg=update_fg) + assert len(loop_op.state_types) == 2 + assert len(loop_op.const_types) == 1 + _, y = loop_op(np.array(0, dtype="int64"), x, const) + assert y.eval({x: 0, const: 2}) == 20 + + +def test_fori_scan(): + x = scalar("x") + update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2]) + + n_iters = 10 + y, ys = Scan(update_fg=update_fg)(n_iters, x) + + fn = function([x], [y, ys]) + + subtensor_nodes = tuple( + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Subtensor) + ) + assert len(subtensor_nodes) == 0 + loop_nodes = tuple( + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop) + ) + assert len(loop_nodes) == 1 + (loop_node,) = loop_nodes + assert len(loop_node.outputs) == 3 + assert loop_node.outputs[0].type.shape == () + assert loop_node.outputs[1].type.shape == () + assert loop_node.outputs[2].type.shape == (10,) + + y_eval, ys_eval = fn(0) + np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2)) + np.testing.assert_array_equal(ys_eval[-1], y_eval) + + +def test_fori_scan_shape(): + x = scalar("x") + update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2]) + + n_iters = 10 + _, ys = Scan(update_fg=update_fg)(n_iters, x) + + 
fn = function([x], ys.shape, on_unused_input="ignore") + nodes = tuple(fn.maker.fgraph.apply_nodes) + assert len(nodes) == 1 + assert isinstance(nodes[0].op, DeepCopyOp) + assert fn(0) == 10 + + +def test_while_scan(): + i = lscalar("i") + x = scalar("x") + update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2]) + + max_iters = 1000 + _, y, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x) + + fn = function([x], [y, ys]) + + subtensor_nodes = tuple( + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Subtensor) + ) + assert len(subtensor_nodes) == 1 + loop_nodes = tuple( + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop) + ) + assert len(loop_nodes) == 1 + (loop_node,) = loop_nodes + assert len(loop_node.outputs) == 3 + assert loop_node.outputs[0].type.shape == () + assert loop_node.outputs[1].type.shape == () + assert loop_node.outputs[2].type.shape == (1000,) + + y_eval, ys_eval = fn(0) + np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2)) + np.testing.assert_array_equal(ys_eval[-1], y_eval) + + +def test_while_scan_shape(): + i = lscalar("i") + x = scalar("x") + update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2]) + + max_iters = 1000 + _, _, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x) + + fn = function([x], ys.shape) + loop_nodes = tuple( + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop) + ) + assert len(loop_nodes) == 1 + assert fn(0) == 10 + + +def test_foreach_scan(): + idx = scalar("idx", dtype="int64") + dummy_x0 = empty(()) + xs = vector("xs") + const = scalar("const") + update_fg = FunctionGraph( + [idx, dummy_x0, xs, const], [constant(np.array(True)), idx + 1, xs[idx] * const] + ) + + n_steps = xs.shape[0] + _, _, _, ys = Scan(update_fg=update_fg)(n_steps, 0, dummy_x0, xs, const) + + fn = pytensor.function([xs, const], ys) + + np.testing.assert_almost_equal( + fn(np.arange(10, dtype=config.floatX), 100), np.arange(10) * 100 + ) + + +def test_fori_random_scan(): + rng_test = np.random.default_rng(123) + rng_shared = shared(np.random.default_rng(123)) + n_iters = 5 + + dummy_init = empty(()) + rng = rng_shared.type() + update_fg = FunctionGraph( + [dummy_init, rng], + [constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]], + ) + + last_y, last_rng, ys, rngs = Scan(update_fg=update_fg)( + n_iters, dummy_init, rng_shared + ) + assert isinstance(last_rng.type, RandomGeneratorType) + assert isinstance(rngs.type, TypedListType) + assert isinstance(rngs.type.ttype, RandomGeneratorType) + + fn = function([], [ys, rngs], updates={rng_shared: last_rng}) + for i in range(2): + ys_res, rngs_res = fn() + for y_res, rng_res in zip(ys_res, rngs_res): + np.testing.assert_almost_equal(y_res, rng_test.normal()) + assert rng_res.__getstate__() == rng_test.__getstate__() + + +def test_scan_view_last_state(): + x = scalar("x") + update_fg = FunctionGraph([x], [x > 5, x + 2]) + + n_iters = 10 + y1, ys = Scan(update_fg=update_fg)(n_iters, x) + + y2 = ys[-1] + fgraph = FunctionGraph(outputs=[y2, ys], clone=False) + assert fgraph.outputs[0] is not y1 + in2out(scan_view_last_state).apply(fgraph) + assert fgraph.outputs[0] is y1 + assert fgraph.outputs[1] is ys diff --git a/tests/typed_list/test_type.py b/tests/typed_list/test_type.py index 41c1a8326b..a5281e21d1 100644 --- a/tests/typed_list/test_type.py +++ b/tests/typed_list/test_type.py @@ -150,3 +150,7 @@ def test_variable_is_Typed_List_variable(self): )() assert 
isinstance(mySymbolicVariable, TypedListVariable)
+
+    def test_any(self):
+        tlist = TypedListType(TensorType(dtype="int64", shape=(None,)))()
+        assert any([tlist])