From 7c5c9a3d2f3aeabc83e6256fa9c0d538ead66b52 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 2 Aug 2023 14:59:22 +0200 Subject: [PATCH 1/5] . POC named tensors (cherry picked from commit 5b0c4726a0f7940fe7f68d4d07cb08c3656579dd) --- pytensor/xtensor/__init__.py | 8 + pytensor/xtensor/basic.py | 70 +++++ pytensor/xtensor/rewriting/__init__.py | 1 + pytensor/xtensor/rewriting/basic.py | 26 ++ pytensor/xtensor/rewriting/utils.py | 33 +++ pytensor/xtensor/type.py | 147 ++++++++++ pytensor/xtensor/variable.py | 356 +++++++++++++++++++++++++ 7 files changed, 641 insertions(+) create mode 100644 pytensor/xtensor/__init__.py create mode 100644 pytensor/xtensor/basic.py create mode 100644 pytensor/xtensor/rewriting/__init__.py create mode 100644 pytensor/xtensor/rewriting/basic.py create mode 100644 pytensor/xtensor/rewriting/utils.py create mode 100644 pytensor/xtensor/type.py create mode 100644 pytensor/xtensor/variable.py diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py new file mode 100644 index 0000000000..c207db15c6 --- /dev/null +++ b/pytensor/xtensor/__init__.py @@ -0,0 +1,8 @@ +import warnings +import pytensor.xtensor.rewriting + +from pytensor.xtensor.variable import XTensorVariable, XTensorConstant, as_xtensor, as_xtensor_variable +from pytensor.xtensor.type import XTensorType + + +warnings.warn("xtensor module is experimental and full of bugs") \ No newline at end of file diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py new file mode 100644 index 0000000000..ab1f74172e --- /dev/null +++ b/pytensor/xtensor/basic.py @@ -0,0 +1,70 @@ +from itertools import chain + +import pytensor.scalar as ps +from pytensor.graph import Apply, Op +import pytensor.xtensor as px +from pytensor.tensor import TensorType + + +class TensorFromXTensor(Op): + + def make_node(self, x) -> Apply: + if not isinstance(x.type, px.XTensorType): + raise TypeError(f"x must be have an XTensorType, got {type(x.type)}") + output = TensorType(x.type.dtype, shape=x.type.shape)() + return Apply(self, [x], [output]) + + def perform(self, node, inputs, output_storage) -> None: + [x] = inputs + output_storage[0][0] = x.copy() + + +tensor_from_xtensor = TensorFromXTensor() + + +class XTensorFromTensor(Op): + + __props__ = ("dims",) + + def __init__(self, dims): + super().__init__() + self.dims = dims + + def make_node(self, x) -> Apply: + if not isinstance(x.type, TensorType): + raise TypeError(f"x must be an TensorType type, got {type(x.type)}") + output = px.XTensorType(x.type.dtype, dims=self.dims, shape=x.type.shape)() + return Apply(self, [x], [output]) + + def perform(self, node, inputs, output_storage) -> None: + [x] = inputs + output_storage[0][0] = x.copy() + + +def xtensor_from_tensor(x, dims): + return XTensorFromTensor(dims=dims)(x) + + +class XElemwise(Op): + + __props__ = ("scalar_op",) + + def __init__(self, scalar_op): + super().__init__() + self.scalar_op = scalar_op + + def make_node(self, *inputs): + # TODO: Check dim lengths match + inputs = [px.as_xtensor_variable(inp) for inp in inputs] + # TODO: This ordering is different than what xarray does + unique_dims = sorted(set(chain.from_iterable(inp.type.dims for inp in inputs))) + # TODO: Fix dtype + output_type = px.XTensorType("float64", dims=unique_dims, shape=(None,) * len(unique_dims)) + outputs = [output_type() for _ in range(self.scalar_op.nout)] + return Apply(self, inputs, outputs) + + def perform(self, *args, **kwargs) -> None: + raise NotImplementedError("xtensor operations must be rewritten as tensor operations") + + +add = XElemwise(ps.add) diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py new file mode 100644 index 0000000000..53a493f878 --- /dev/null +++ b/pytensor/xtensor/rewriting/__init__.py @@ -0,0 +1 @@ +import pytensor.xtensor.rewriting.basic \ No newline at end of file diff --git a/pytensor/xtensor/rewriting/basic.py b/pytensor/xtensor/rewriting/basic.py new file mode 100644 index 0000000000..d2c13cd379 --- /dev/null +++ b/pytensor/xtensor/rewriting/basic.py @@ -0,0 +1,26 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor import expand_dims +from pytensor.tensor.elemwise import Elemwise +from pytensor.xtensor.basic import tensor_from_xtensor, XElemwise, xtensor_from_tensor +from pytensor.xtensor.rewriting.utils import register_xcanonicalize + + +@register_xcanonicalize +@node_rewriter(tracks=[XElemwise]) +def xelemwise_to_elemwise(fgraph, node): + # Convert inputs to TensorVariables and add broadcastable dims + output_dims = node.outputs[0].type.dims + + tensor_inputs = [] + for inp in node.inputs: + inp_dims = inp.type.dims + axis = [i for i, dim in enumerate(output_dims) if dim not in inp_dims] + tensor_inp = tensor_from_xtensor(inp) + tensor_inp = expand_dims(tensor_inp, axis) + tensor_inputs.append(tensor_inp) + + tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(*tensor_inputs, return_list=True) + + # TODO: copy_stack_trace + new_outs = [xtensor_from_tensor(tensor_out, dims=output_dims) for tensor_out in tensor_outs] + return new_outs diff --git a/pytensor/xtensor/rewriting/utils.py b/pytensor/xtensor/rewriting/utils.py new file mode 100644 index 0000000000..f059b656cc --- /dev/null +++ b/pytensor/xtensor/rewriting/utils.py @@ -0,0 +1,33 @@ +from typing import Union + +from pytensor.compile import optdb +from pytensor.graph.rewriting.basic import NodeRewriter +from pytensor.graph.rewriting.db import RewriteDatabase, EquilibriumDB + + +optdb.register( + "xcanonicalize", + EquilibriumDB(ignore_newtrees=False), + "fast_run", + "fast_compile", + "xtensor", + position=0, +) + + +def register_xcanonicalize( + node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]): + return register_xcanonicalize(inner_rewriter, node_rewriter, *tags, **kwargs) + + return register + + else: + name = kwargs.pop("name", None) or node_rewriter.__name__ + optdb["xtensor"].register( + name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs + ) + return node_rewriter diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py new file mode 100644 index 0000000000..0696233a8c --- /dev/null +++ b/pytensor/xtensor/type.py @@ -0,0 +1,147 @@ +from typing import Iterable, Optional, Union, Sequence, TypeVar + +import numpy as np + +import pytensor +from pytensor import scalar as aes +from pytensor.graph.basic import Variable +from pytensor.graph.type import HasDataType +from pytensor.tensor.type import TensorType + + +_XTensorTypeType = TypeVar("_XTensorTypeType", bound=TensorType) + + +class XTensorType(TensorType, HasDataType): + """A `Type` for sparse tensors. + + Notes + ----- + Currently, sparse tensors can only be matrices (i.e. have two dimensions). + + """ + + __props__ = ("dtype", "shape", "dims") + + def __init__( + self, + dtype: Union[str, np.dtype], + *, + dims: Sequence[str], + shape: Optional[Iterable[Optional[Union[bool, int]]]] = None, + name: Optional[str] = None, + ): + super().__init__(dtype, shape=shape, name=name) + if not isinstance(dims, (list, tuple)): + raise TypeError("dims must be a list or tuple") + dims = tuple(dims) + self.dims = dims + + def clone( + self, + dtype=None, + dims=None, + shape=None, + **kwargs, + ): + if dtype is None: + dtype = self.dtype + if dims is None: + dims = self.dims + if shape is None: + shape = self.shape + return type(self)(format, dtype, shape=shape, dims=dims, **kwargs) + + def filter(self, value, strict=False, allow_downcast=None): + # TODO: Implement this + return value + + if isinstance(value, Variable): + raise TypeError( + "Expected an array-like object, but found a Variable: " + "maybe you are trying to call a function on a (possibly " + "shared) variable instead of a numeric array?" + ) + + if ( + isinstance(value, self.format_cls[self.format]) + and value.dtype == self.dtype + ): + return value + + if strict: + raise TypeError( + f"{value} is not sparse, or not the right dtype (is {value.dtype}, " + f"expected {self.dtype})" + ) + + # The input format could be converted here + if allow_downcast: + sp = self.format_cls[self.format](value, dtype=self.dtype) + else: + data = self.format_cls[self.format](value) + up_dtype = aes.upcast(self.dtype, data.dtype) + if up_dtype != self.dtype: + raise TypeError(f"Expected {self.dtype} dtype but got {data.dtype}") + sp = data.astype(up_dtype) + + assert sp.format == self.format + + return sp + + def convert_variable(self, var): + # TODO: Implement this + return var + res = super().convert_variable(var) + + if res is None: + return res + + if not isinstance(res.type, type(self)): + return None + + if res.dims != self.dims: + # TODO: Does this make sense? + return None + + return res + + def __hash__(self): + return super().__hash__() ^ hash(self.dims) + + def __repr__(self): + # TODO: Add `?` for unknown shapes like `TensorType` does + return f"XTensorType({self.dtype}, {self.dims}, {self.shape})" + + def __eq__(self, other): + res = super().__eq__(other) + + if isinstance(res, bool): + return res and other.dims == self.dims + + return res + + def is_super(self, otype): + # TODO: Implement this + return True + + if not super().is_super(otype): + return False + + if self.dims == otype.dims: + return True + + return False + + +# TODO: Implement creater helper xtensor + +pytensor.compile.register_view_op_c_code( + XTensorType, + """ + Py_XDECREF(%(oname)s); + %(oname)s = %(iname)s; + Py_XINCREF(%(oname)s); + """, + 1, +) diff --git a/pytensor/xtensor/variable.py b/pytensor/xtensor/variable.py new file mode 100644 index 0000000000..bb2a18ac4b --- /dev/null +++ b/pytensor/xtensor/variable.py @@ -0,0 +1,356 @@ +import xarray as xr +import pytensor.xtensor.basic as xbasic + +from pytensor import _as_symbolic, Variable +from pytensor.graph import Apply, Constant +from pytensor.tensor import TensorVariable +from pytensor.tensor.utils import hash_from_ndarray +from pytensor.tensor.var import _tensor_py_operators, TensorConstant +from pytensor.utils import hash_from_code +from pytensor.xtensor.type import XTensorType, _XTensorTypeType + + +@_as_symbolic.register(xr.DataArray) +def as_symbolic_sparse(x, **kwargs): + return as_xtensor_variable(x, **kwargs) + + +def as_xtensor_variable(x, name=None, ndim=None, **kwargs): + """ + Wrapper around SparseVariable constructor to construct + a Variable with a sparse matrix with the same dtype and + format. + + Parameters + ---------- + x + A sparse matrix. + + Returns + ------- + object + SparseVariable version of `x`. + + """ + + # TODO + # Verify that sp is sufficiently sparse, and raise a + # warning if it is not + + if isinstance(x, Apply): + if len(x.outputs) != 1: + raise ValueError( + "It is ambiguous which output of a " + "multi-output Op has to be fetched.", + x, + ) + else: + x = x.outputs[0] + if isinstance(x, Variable): + if not isinstance(x.type, XTensorType): + raise TypeError(f"Variable type field must be a XTensorType, got {x.type}") + return x + try: + return constant(x, name=name) + except TypeError as err: + raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err + + +as_xtensor = as_xtensor_variable + + +def constant(x, name=None): + if not isinstance(x, xr.DataArray): + raise TypeError("xtensor.constant must be called on a Xarray DataArray") + try: + return XTensorConstant( + XTensorType(dtype=x.dtype, dims=x.dims, shape=x.shape), x.values.copy(), name=name + ) + except TypeError: + raise TypeError(f"Could not convert {x} to XTensorType") + + +# def sp_ones_like(x): +# """ +# Construct a sparse matrix of ones with the same sparsity pattern. +# +# Parameters +# ---------- +# x +# Sparse matrix to take the sparsity pattern. +# +# Returns +# ------- +# A sparse matrix +# The same as `x` with data changed for ones. +# +# """ +# # TODO: don't restrict to CSM formats +# data, indices, indptr, _shape = csm_properties(x) +# return CSM(format=x.format)(at.ones_like(data), indices, indptr, _shape) +# +# +# def sp_zeros_like(x): +# """ +# Construct a sparse matrix of zeros. +# +# Parameters +# ---------- +# x +# Sparse matrix to take the shape. +# +# Returns +# ------- +# A sparse matrix +# The same as `x` with zero entries for all element. +# +# """ +# +# # TODO: don't restrict to CSM formats +# _, _, indptr, _shape = csm_properties(x) +# return CSM(format=x.format)( +# data=np.array([], dtype=x.type.dtype), +# indices=np.array([], dtype="int32"), +# indptr=at.zeros_like(indptr), +# shape=_shape, +# ) +# +# +# def override_dense(*methods): +# def decorate(cls): +# def native(method): +# original = getattr(cls.__base__, method) +# +# def to_dense(self, *args, **kwargs): +# self = self.toarray() +# new_args = [ +# arg.toarray() +# if hasattr(arg, "type") and isinstance(arg.type, SparseTensorType) +# else arg +# for arg in args +# ] +# warn( +# f"Method {method} is not implemented for sparse variables. The variable will be converted to dense." +# ) +# return original(self, *new_args, **kwargs) +# +# return to_dense +# +# for method in methods: +# setattr(cls, method, native(method)) +# return cls +# +# return decorate + + +# @override_dense( +# "__abs__", +# "__ceil__", +# "__floor__", +# "__trunc__", +# "transpose", +# "any", +# "all", +# "flatten", +# "ravel", +# "arccos", +# "arcsin", +# "arctan", +# "arccosh", +# "arcsinh", +# "arctanh", +# "ceil", +# "cos", +# "cosh", +# "deg2rad", +# "exp", +# "exp2", +# "expm1", +# "floor", +# "log", +# "log10", +# "log1p", +# "log2", +# "rad2deg", +# "sin", +# "sinh", +# "sqrt", +# "tan", +# "tanh", +# "copy", +# "prod", +# "mean", +# "var", +# "std", +# "min", +# "max", +# "argmin", +# "argmax", +# "round", +# "trace", +# "cumsum", +# "cumprod", +# "ptp", +# "squeeze", +# "diagonal", +# "__and__", +# "__or__", +# "__xor__", +# "__pow__", +# "__mod__", +# "__divmod__", +# "__truediv__", +# "__floordiv__", +# "reshape", +# "dimshuffle", +# ) +class _xtensor_py_operators(_tensor_py_operators): + T = property( + lambda self: transpose(self), doc="Return aliased transpose of self (read-only)" + ) + + def astype(self, dtype): + return cast(self, dtype) + + def __neg__(self): + return neg(self) + + def __add__(left, right): + return xbasic.add(left, right) + + def __radd__(right, left): + return add(left, right) + + def __sub__(left, right): + return sub(left, right) + + def __rsub__(right, left): + return sub(left, right) + + def __mul__(left, right): + return mul(left, right) + + def __rmul__(left, right): + return mul(left, right) + + # comparison operators + + def __lt__(self, other): + return lt(self, other) + + def __le__(self, other): + return le(self, other) + + def __gt__(self, other): + return gt(self, other) + + def __ge__(self, other): + return ge(self, other) + + def __dot__(left, right): + return structured_dot(left, right) + + def __rdot__(right, left): + return structured_dot(left, right) + + def sum(self, axis=None, sparse_grad=False): + return sp_sum(self, axis=axis, sparse_grad=sparse_grad) + + dot = __dot__ + + def toarray(self): + return dense_from_sparse(self) + + @property + def shape(self): + # TODO: The plan is that the ShapeFeature in at.opt will do shape + # propagation and remove the dense_from_sparse from the graph. This + # will *NOT* actually expand your sparse matrix just to get the shape. + return shape(dense_from_sparse(self)) + + ndim = property(lambda self: self.type.ndim) + dtype = property(lambda self: self.type.dtype) + + # Note that the `size` attribute of sparse matrices behaves differently + # from dense matrices: it is the number of elements stored in the matrix + # rather than the total number of elements that may be stored. Note also + # that stored zeros *do* count in the size. + size = property(lambda self: csm_data(self).size) + + def zeros_like(model): + return sp_zeros_like(model) + + def __getitem__(self, args): + if not isinstance(args, tuple): + args = (args,) + + if len(args) == 2: + scalar_arg_1 = ( + np.isscalar(args[0]) or getattr(args[0], "type", None) == iscalar + ) + scalar_arg_2 = ( + np.isscalar(args[1]) or getattr(args[1], "type", None) == iscalar + ) + if scalar_arg_1 and scalar_arg_2: + ret = get_item_scalar(self, args) + elif isinstance(args[0], list): + ret = get_item_2lists(self, args[0], args[1]) + else: + ret = get_item_2d(self, args) + elif isinstance(args[0], list): + ret = get_item_list(self, args[0]) + else: + ret = get_item_2d(self, args) + return ret + + def conj(self): + return conjugate(self) + + + +class XTensorVariable(_xtensor_py_operators, TensorVariable): + pass + + # def __str__(self): + # return f"{self.__class__.__name__}{{{self.format},{self.dtype}}}" + + # def __repr__(self): + # return str(self) + + +class XTensorConstantSignature(tuple): + def __eq__(self, other): + if type(self) != type(other): + return False + + (t0, d0), (t1, d1) = self, other + if t0 != t1 or d0.shape != d1.shape: + return False + + return True + + def __ne__(self, other): + return not self == other + + def __hash__(self): + (a, b) = self + return hash(type(self)) ^ hash(a) ^ hash(type(b)) + + def pytensor_hash(self): + t, d = self + return "".join([hash_from_ndarray(d)] + [hash_from_code(dim) for dim in t.dims]) + + +class XTensorConstant(TensorConstant, _xtensor_py_operators): + + def __init__(self, type: _XTensorTypeType, data, name=None): + # TODO: Add checks that type and data are compatible + Constant.__init__(self, type, data, name) + + def signature(self): + assert self.data is not None + return XTensorConstantSignature((self.type, self.data)) + + +XTensorType.variable_type = XTensorVariable +XTensorType.constant_type = XTensorConstant \ No newline at end of file From 05520fd792dfbe29089b1e00a94602547b099f0e Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Tue, 10 Oct 2023 00:59:58 +0200 Subject: [PATCH 2/5] WIP: Type dims and spaces --- pytensor/xtensor/spaces.py | 147 +++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 pytensor/xtensor/spaces.py diff --git a/pytensor/xtensor/spaces.py b/pytensor/xtensor/spaces.py new file mode 100644 index 0000000000..a6e6b29473 --- /dev/null +++ b/pytensor/xtensor/spaces.py @@ -0,0 +1,147 @@ +from abc import ABC, abstractmethod, abstractproperty +from collections.abc import Iterator, Sequence, Collection, Sized, Iterable, Container, Set, Reversible +import sys +from typing import FrozenSet, Protocol, Tuple, Union, Iterator, overload, SupportsIndex + + +class DimLike(Protocol): + """Most basic signature of a dimension.""" + + def __str__(self) -> str: + ... + + def __hash__(self) -> int: + ... + + +class Dim(DimLike): + """The most common type of dimension.""" + + _name: str + + def __init__(self, name: str) -> None: + self._name = name + super().__init__() + + def __str__(self) -> str: + return self._name + + def __repr__(self) -> str: + return f"Dim('{self._name}')" + + def __eq__(self, __value: object) -> bool: + return self._name == str(__value) + + def __hash__(self) -> int: + return self._name.__hash__() + + +class BaseSpace(FrozenSet[DimLike]): + """The most generic type of space is an unordered frozen set of dimensions. + + It implements the following calculation operators: + * Addition → Unordered union + * Subtraction → Unordered union + * Multiplication → Unordered union + * Power → Identity + + The logic operators (AND &, OR |, XOR ^) do space math with the frozenset. + """ + + def __add__(self, other: Iterable[DimLike]) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace({*self, *other}) + except Exception as ex: + raise TypeError(f"Can't {other} to space.") from ex + + def __sub__(self, other: Iterable[DimLike]) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace({*self, *other}) + except Exception as ex: + raise TypeError(f"Can't subtract {other} from space.") from ex + + def __mul__(self, other) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace({*self, *other}) + except Exception as ex: + raise TypeError(f"Can't multiply space by {other}.") from ex + + def __pow__(self, other: Iterable[DimLike]) -> "BaseSpace": + return self + + def __and__(self, other: Iterable[DimLike]) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace(set(self) & other) + except Exception as ex: + raise TypeError(f"Can't AND space with {other}.") from ex + + def __or__(self, other: Iterable[DimLike]) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace(set(self) | other) + except Exception as ex: + raise TypeError(f"Can't OR space with {other}.") from ex + + def __xor__(self, other: Iterable[DimLike]) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace(set(self) ^ other) + except Exception as ex: + raise TypeError(f"Can't XOR space with {other}.") from ex + + def __repr__(self) -> str: + return "Space{" + ", ".join(f"'{d}'" for d in self) + "}" + + +class OrderedSpace(BaseSpace, Reversible[DimLike]): + """A very tidied-up space, with a known order of dimensions.""" + + def __init__(self, dims: Sequence[DimLike]) -> None: + self._order = tuple(dims) + super().__init__() + + def __iter__(self) -> Iterator[DimLike]: + for d in self._order: + yield d + + def __reversed__(self) -> Iterator[DimLike]: + for d in reversed(self._order): + yield d + + def index(self, __value: DimLike, __start: SupportsIndex = 0, __stop: SupportsIndex = sys.maxsize) -> int: + return self._order.index(__value, __start, __stop) + + @overload + def __getitem__(self, __key: slice) -> "OrderedSpace": + """Slicing an ordered space results in an ordered space.""" + return OrderedSpace(self._order[__key]) + + @overload + def __getitem__(self, __key: SupportsIndex) -> DimLike: + return self._order[__key] + + def __getitem__(self, __key) -> DimLike: + return self._order[__key] + + + def __repr__(self) -> str: + return "OrderedSpace(" + ", ".join(f"'{d}'" for d in self) + ")" + + +@overload +def Space(dims: Sequence[DimLike]) -> OrderedSpace: + """Sequences of dims give an ordered space.""" + ... + +@overload +def Space(dims: Set[DimLike]) -> BaseSpace: + ... + +def Space(dims: Iterable[DimLike]) -> Union[OrderedSpace, BaseSpace]: + if isinstance(dims, Sequence): + return OrderedSpace(dims) + return BaseSpace(dims) From c7bafc6d15a804c7ae8dff996369e573a12ee7d3 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 15 Oct 2023 13:30:33 +0200 Subject: [PATCH 3/5] Add spaces tests and fix typing --- pytensor/xtensor/spaces.py | 81 +++++++++++++++-------- tests/xtensor/test_spaces.py | 120 +++++++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+), 25 deletions(-) create mode 100644 tests/xtensor/test_spaces.py diff --git a/pytensor/xtensor/spaces.py b/pytensor/xtensor/spaces.py index a6e6b29473..0b7d353c9a 100644 --- a/pytensor/xtensor/spaces.py +++ b/pytensor/xtensor/spaces.py @@ -1,9 +1,18 @@ -from abc import ABC, abstractmethod, abstractproperty -from collections.abc import Iterator, Sequence, Collection, Sized, Iterable, Container, Set, Reversible import sys -from typing import FrozenSet, Protocol, Tuple, Union, Iterator, overload, SupportsIndex - - +from collections.abc import Iterable, Iterator, Sequence +from typing import ( + FrozenSet, + Iterator, + Protocol, + SupportsIndex, + Union, + cast, + overload, + runtime_checkable, +) + + +@runtime_checkable class DimLike(Protocol): """Most basic signature of a dimension.""" @@ -25,21 +34,21 @@ def __init__(self, name: str) -> None: def __str__(self) -> str: return self._name - + def __repr__(self) -> str: return f"Dim('{self._name}')" def __eq__(self, __value: object) -> bool: return self._name == str(__value) - + def __hash__(self) -> int: return self._name.__hash__() class BaseSpace(FrozenSet[DimLike]): """The most generic type of space is an unordered frozen set of dimensions. - - It implements the following calculation operators: + + It implements broadcasting rules for the following tensor operators: * Addition → Unordered union * Subtraction → Unordered union * Multiplication → Unordered union @@ -69,6 +78,13 @@ def __mul__(self, other) -> "BaseSpace": except Exception as ex: raise TypeError(f"Can't multiply space by {other}.") from ex + def __truediv__(self, other) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace({*self, *other}) + except Exception as ex: + raise TypeError(f"Can't divide space by {other}.") from ex + def __pow__(self, other: Iterable[DimLike]) -> "BaseSpace": return self @@ -97,36 +113,49 @@ def __repr__(self) -> str: return "Space{" + ", ".join(f"'{d}'" for d in self) + "}" -class OrderedSpace(BaseSpace, Reversible[DimLike]): +class OrderedSpace(BaseSpace, Sequence[DimLike]): """A very tidied-up space, with a known order of dimensions.""" def __init__(self, dims: Sequence[DimLike]) -> None: self._order = tuple(dims) super().__init__() + def __eq__(self, __value) -> bool: + if not isinstance(__value, Sequence) or isinstance(__value, str): + return False + return self._order == tuple(__value) + + def __ne__(self, __value) -> bool: + if not isinstance(__value, Sequence) or isinstance(__value, str): + return True + return self._order != tuple(__value) + def __iter__(self) -> Iterator[DimLike]: - for d in self._order: - yield d + yield from self._order def __reversed__(self) -> Iterator[DimLike]: - for d in reversed(self._order): - yield d - - def index(self, __value: DimLike, __start: SupportsIndex = 0, __stop: SupportsIndex = sys.maxsize) -> int: + yield from reversed(self._order) + + def index( + self, + __value: DimLike, + __start: SupportsIndex = 0, + __stop: SupportsIndex = sys.maxsize, + ) -> int: return self._order.index(__value, __start, __stop) - - @overload - def __getitem__(self, __key: slice) -> "OrderedSpace": - """Slicing an ordered space results in an ordered space.""" - return OrderedSpace(self._order[__key]) @overload def __getitem__(self, __key: SupportsIndex) -> DimLike: - return self._order[__key] + """Indexing gives a dim""" - def __getitem__(self, __key) -> DimLike: - return self._order[__key] + @overload + def __getitem__(self, __key: slice) -> "OrderedSpace": + """Slicing preserves order""" + def __getitem__(self, __key) -> Union[DimLike, "OrderedSpace"]: + if isinstance(__key, slice): + return OrderedSpace(self._order[__key]) + return cast(DimLike, self._order[__key]) def __repr__(self) -> str: return "OrderedSpace(" + ", ".join(f"'{d}'" for d in self) + ")" @@ -137,10 +166,12 @@ def Space(dims: Sequence[DimLike]) -> OrderedSpace: """Sequences of dims give an ordered space.""" ... + @overload -def Space(dims: Set[DimLike]) -> BaseSpace: +def Space(dims: Iterable[DimLike]) -> BaseSpace: ... + def Space(dims: Iterable[DimLike]) -> Union[OrderedSpace, BaseSpace]: if isinstance(dims, Sequence): return OrderedSpace(dims) diff --git a/tests/xtensor/test_spaces.py b/tests/xtensor/test_spaces.py new file mode 100644 index 0000000000..a3ba59be80 --- /dev/null +++ b/tests/xtensor/test_spaces.py @@ -0,0 +1,120 @@ +from typing import Sequence + +import pytest + +import pytensor.xtensor.spaces as xsp + + +class TestDims: + def test_str_is_dimlike(self): + assert isinstance("d", xsp.DimLike) + + def test_dim(self): + d0 = xsp.Dim("d0") + assert isinstance(d0, xsp.DimLike) + assert str(d0) == "d0" + assert "d0" in d0.__repr__() + # Dims can compare with strings + assert d0 == "d0" + # They must be hashable to be used as keys + assert isinstance(hash(d0), int) + + +class TestBaseSpace: + def test_type(self): + assert issubclass(xsp.BaseSpace, frozenset) + s1 = xsp.BaseSpace({"d0", "d1"}) + assert "Space" in s1.__repr__() + assert "d0" in s1.__repr__() + assert "d1" in s1.__repr__() + # Spaces are frozensets which makes them convenient to use + assert isinstance(s1, frozenset) + # But they can't be sets, because .add(newdim) would mess up things + assert not isinstance(s1, set) + assert "d0" in s1 + assert "d1" in s1 + assert len(s1) == 2 + assert s1 == {"d1", "d0"} + # Can't index an unordered space + assert not hasattr(s1, "index") + with pytest.raises(TypeError, match="not subscriptable"): + s1[1] + + def test_spacemath(self): + assert xsp.BaseSpace("ab") == {"a", "b"} + # Set logic operations result in spaces + union = xsp.BaseSpace("ab") | {"b", "c"} + assert isinstance(union, xsp.BaseSpace) + assert union == {"a", "b", "c"} + + intersection = xsp.BaseSpace("ab") & {"b", "c"} + assert isinstance(intersection, xsp.BaseSpace) + assert intersection == {"b"} + + xor = xsp.BaseSpace("ab") ^ {"b", "c"} + assert isinstance(xor, xsp.BaseSpace) + assert xor == {"a", "c"} + + def test_tensormath(self): + # Tensors and spaces follow the same basic math rules + addition = xsp.BaseSpace("ab") + {"c"} + assert isinstance(addition, xsp.BaseSpace) + assert addition == {"a", "b", "c"} + + subtraction = xsp.BaseSpace("ab") - {"b", "c"} + assert isinstance(subtraction, xsp.BaseSpace) + assert subtraction == {"a", "b", "c"} + + multiplication = xsp.BaseSpace("ab") * {"c"} + assert isinstance(multiplication, xsp.BaseSpace) + assert multiplication == {"a", "b", "c"} + + division = xsp.BaseSpace("ab") / {"b", "c"} + assert isinstance(division, xsp.BaseSpace) + assert division == {"a", "b", "c"} + + power = xsp.BaseSpace("ba") ** 3 + assert isinstance(power, xsp.BaseSpace) + assert power == {"a", "b"} + + +class TestOrderedSpace: + def test_type(self): + o1 = xsp.OrderedSpace(["b", "a"]) + assert o1.__repr__() == "OrderedSpace('b', 'a')" + assert isinstance(o1, Sequence) + assert len(o1) == 2 + # Addition/multiplication is different compare to tuples + assert not isinstance(o1, tuple) + # And lists would be mutable, but ordered spaces are not + assert not isinstance(o1, list) + + def test_comparison(self): + # Ordered spaces can only be equal to other ordered things + assert xsp.OrderedSpace("a") != {"a"} + assert xsp.OrderedSpace("a") == ("a",) + assert xsp.OrderedSpace("a") == ["a"] + assert xsp.OrderedSpace("a") == xsp.OrderedSpace("a") + # Except for strings, because they could be a dim + assert not xsp.OrderedSpace("a") == "a" + assert xsp.OrderedSpace("a") != "a" + + def test_indexing(self): + b = xsp.Dim("b") + o1 = xsp.OrderedSpace([b, "a"]) + # Ordered spaces can be indexed + assert o1.index("b") == 0 + assert o1[0] is b + sliced = o1[::-1] + assert isinstance(sliced, xsp.OrderedSpace) + assert sliced == ("a", "b") + + +def test_space_function(): + usp = xsp.Space({"a", "b"}) + assert isinstance(usp, xsp.BaseSpace) + assert not isinstance(usp, xsp.OrderedSpace) + + assert isinstance(xsp.Space(["a", "b"]), xsp.OrderedSpace) + assert isinstance(xsp.Space(("a", "b")), xsp.OrderedSpace) + assert isinstance(xsp.Space("ab"), xsp.OrderedSpace) From bb3e38a8e7deab79aa8c8ef7ef95e9aef59ed368 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 15 Oct 2023 15:01:00 +0200 Subject: [PATCH 4/5] Fix various typing issues in xtensor module --- pytensor/xtensor/__init__.py | 12 +++++-- pytensor/xtensor/basic.py | 46 ++++++++++++++++++-------- pytensor/xtensor/rewriting/__init__.py | 2 +- pytensor/xtensor/rewriting/basic.py | 10 ++++-- pytensor/xtensor/rewriting/utils.py | 12 +++++-- pytensor/xtensor/spaces.py | 1 - pytensor/xtensor/type.py | 2 +- pytensor/xtensor/variable.py | 24 +++++++++----- scripts/mypy-failing.txt | 3 +- 9 files changed, 76 insertions(+), 36 deletions(-) diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index c207db15c6..7b243a68ac 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -1,8 +1,14 @@ import warnings -import pytensor.xtensor.rewriting -from pytensor.xtensor.variable import XTensorVariable, XTensorConstant, as_xtensor, as_xtensor_variable +import pytensor.xtensor.rewriting +from pytensor.xtensor.spaces import BaseSpace, Dim, DimLike, OrderedSpace from pytensor.xtensor.type import XTensorType +from pytensor.xtensor.variable import ( + XTensorConstant, + XTensorVariable, + as_xtensor, + as_xtensor_variable, +) -warnings.warn("xtensor module is experimental and full of bugs") \ No newline at end of file +warnings.warn("xtensor module is experimental and full of bugs") diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index ab1f74172e..91fc8d70a9 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -1,20 +1,29 @@ from itertools import chain +from typing import Any, Sequence import pytensor.scalar as ps -from pytensor.graph import Apply, Op import pytensor.xtensor as px +from pytensor.graph import Apply, Op +from pytensor.graph.basic import Variable +from pytensor.graph.op import OutputStorageType, ParamsInputType from pytensor.tensor import TensorType class TensorFromXTensor(Op): - - def make_node(self, x) -> Apply: + def make_node(self, *inputs: Variable) -> Apply: + [x] = inputs if not isinstance(x.type, px.XTensorType): raise TypeError(f"x must be have an XTensorType, got {type(x.type)}") output = TensorType(x.type.dtype, shape=x.type.shape)() - return Apply(self, [x], [output]) - - def perform(self, node, inputs, output_storage) -> None: + return Apply(self, inputs, [output]) + + def perform( + self, + node: Apply, + inputs: Sequence[Any], + output_storage: OutputStorageType, + params: ParamsInputType = None, + ) -> None: [x] = inputs output_storage[0][0] = x.copy() @@ -23,20 +32,26 @@ def perform(self, node, inputs, output_storage) -> None: class XTensorFromTensor(Op): - __props__ = ("dims",) def __init__(self, dims): super().__init__() self.dims = dims - def make_node(self, x) -> Apply: + def make_node(self, *inputs: Variable) -> Apply: + [x] = inputs if not isinstance(x.type, TensorType): raise TypeError(f"x must be an TensorType type, got {type(x.type)}") output = px.XTensorType(x.type.dtype, dims=self.dims, shape=x.type.shape)() - return Apply(self, [x], [output]) - - def perform(self, node, inputs, output_storage) -> None: + return Apply(self, inputs, [output]) + + def perform( + self, + node: Apply, + inputs: Sequence[Any], + output_storage: OutputStorageType, + params: ParamsInputType = None, + ) -> None: [x] = inputs output_storage[0][0] = x.copy() @@ -46,7 +61,6 @@ def xtensor_from_tensor(x, dims): class XElemwise(Op): - __props__ = ("scalar_op",) def __init__(self, scalar_op): @@ -59,12 +73,16 @@ def make_node(self, *inputs): # TODO: This ordering is different than what xarray does unique_dims = sorted(set(chain.from_iterable(inp.type.dims for inp in inputs))) # TODO: Fix dtype - output_type = px.XTensorType("float64", dims=unique_dims, shape=(None,) * len(unique_dims)) + output_type = px.XTensorType( + "float64", dims=unique_dims, shape=(None,) * len(unique_dims) + ) outputs = [output_type() for _ in range(self.scalar_op.nout)] return Apply(self, inputs, outputs) def perform(self, *args, **kwargs) -> None: - raise NotImplementedError("xtensor operations must be rewritten as tensor operations") + raise NotImplementedError( + "xtensor operations must be rewritten as tensor operations" + ) add = XElemwise(ps.add) diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py index 53a493f878..6ff8b80822 100644 --- a/pytensor/xtensor/rewriting/__init__.py +++ b/pytensor/xtensor/rewriting/__init__.py @@ -1 +1 @@ -import pytensor.xtensor.rewriting.basic \ No newline at end of file +import pytensor.xtensor.rewriting.basic diff --git a/pytensor/xtensor/rewriting/basic.py b/pytensor/xtensor/rewriting/basic.py index d2c13cd379..03750fa53f 100644 --- a/pytensor/xtensor/rewriting/basic.py +++ b/pytensor/xtensor/rewriting/basic.py @@ -1,7 +1,7 @@ from pytensor.graph import node_rewriter from pytensor.tensor import expand_dims from pytensor.tensor.elemwise import Elemwise -from pytensor.xtensor.basic import tensor_from_xtensor, XElemwise, xtensor_from_tensor +from pytensor.xtensor.basic import XElemwise, tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.utils import register_xcanonicalize @@ -19,8 +19,12 @@ def xelemwise_to_elemwise(fgraph, node): tensor_inp = expand_dims(tensor_inp, axis) tensor_inputs.append(tensor_inp) - tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(*tensor_inputs, return_list=True) + tensor_outs = Elemwise(scalar_op=node.op.scalar_op)( + *tensor_inputs, return_list=True + ) # TODO: copy_stack_trace - new_outs = [xtensor_from_tensor(tensor_out, dims=output_dims) for tensor_out in tensor_outs] + new_outs = [ + xtensor_from_tensor(tensor_out, dims=output_dims) for tensor_out in tensor_outs + ] return new_outs diff --git a/pytensor/xtensor/rewriting/utils.py b/pytensor/xtensor/rewriting/utils.py index f059b656cc..bce5f19932 100644 --- a/pytensor/xtensor/rewriting/utils.py +++ b/pytensor/xtensor/rewriting/utils.py @@ -2,7 +2,7 @@ from pytensor.compile import optdb from pytensor.graph.rewriting.basic import NodeRewriter -from pytensor.graph.rewriting.db import RewriteDatabase, EquilibriumDB +from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase optdb.register( @@ -21,12 +21,18 @@ def register_xcanonicalize( if isinstance(node_rewriter, str): def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]): - return register_xcanonicalize(inner_rewriter, node_rewriter, *tags, **kwargs) + return register_xcanonicalize( + # FIXME: Signature violation below; 2nd argument isn't a str + inner_rewriter, + node_rewriter, # type: ignore + *tags, + **kwargs, + ) return register else: - name = kwargs.pop("name", None) or node_rewriter.__name__ + name = kwargs.pop("name", None) or type(node_rewriter).__name__ optdb["xtensor"].register( name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs ) diff --git a/pytensor/xtensor/spaces.py b/pytensor/xtensor/spaces.py index 0b7d353c9a..b1c8e5313b 100644 --- a/pytensor/xtensor/spaces.py +++ b/pytensor/xtensor/spaces.py @@ -2,7 +2,6 @@ from collections.abc import Iterable, Iterator, Sequence from typing import ( FrozenSet, - Iterator, Protocol, SupportsIndex, Union, diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 0696233a8c..bf402bf310 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Union, Sequence, TypeVar +from typing import Iterable, Optional, Sequence, TypeVar, Union import numpy as np diff --git a/pytensor/xtensor/variable.py b/pytensor/xtensor/variable.py index bb2a18ac4b..71eeb707bc 100644 --- a/pytensor/xtensor/variable.py +++ b/pytensor/xtensor/variable.py @@ -1,11 +1,12 @@ +import numpy as np import xarray as xr -import pytensor.xtensor.basic as xbasic -from pytensor import _as_symbolic, Variable +import pytensor.xtensor.basic as xbasic +from pytensor import Variable, _as_symbolic from pytensor.graph import Apply, Constant from pytensor.tensor import TensorVariable from pytensor.tensor.utils import hash_from_ndarray -from pytensor.tensor.var import _tensor_py_operators, TensorConstant +from pytensor.tensor.variable import TensorConstant, _tensor_py_operators from pytensor.utils import hash_from_code from pytensor.xtensor.type import XTensorType, _XTensorTypeType @@ -64,7 +65,9 @@ def constant(x, name=None): raise TypeError("xtensor.constant must be called on a Xarray DataArray") try: return XTensorConstant( - XTensorType(dtype=x.dtype, dims=x.dims, shape=x.shape), x.values.copy(), name=name + XTensorType(dtype=x.dtype, dims=x.dims, shape=x.shape), + x.values.copy(), + name=name, ) except TypeError: raise TypeError(f"Could not convert {x} to XTensorType") @@ -268,8 +271,13 @@ def shape(self): # will *NOT* actually expand your sparse matrix just to get the shape. return shape(dense_from_sparse(self)) - ndim = property(lambda self: self.type.ndim) - dtype = property(lambda self: self.type.dtype) + @property + def ndim(self) -> int: + return self.type.ndim + + @property + def dtype(self): + return self.type.dtype # Note that the `size` attribute of sparse matrices behaves differently # from dense matrices: it is the number of elements stored in the matrix @@ -307,7 +315,6 @@ def conj(self): return conjugate(self) - class XTensorVariable(_xtensor_py_operators, TensorVariable): pass @@ -342,7 +349,6 @@ def pytensor_hash(self): class XTensorConstant(TensorConstant, _xtensor_py_operators): - def __init__(self, type: _XTensorTypeType, data, name=None): # TODO: Add checks that type and data are compatible Constant.__init__(self, type, data, name) @@ -353,4 +359,4 @@ def signature(self): XTensorType.variable_type = XTensorVariable -XTensorType.constant_type = XTensorConstant \ No newline at end of file +XTensorType.constant_type = XTensorConstant diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 1cae4d9152..19dd3c7c73 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -32,4 +32,5 @@ pytensor/tensor/slinalg.py pytensor/tensor/subtensor.py pytensor/tensor/type.py pytensor/tensor/type_other.py -pytensor/tensor/variable.py \ No newline at end of file +pytensor/tensor/variable.py +pytensor/xtensor/variable.py \ No newline at end of file From 165a09086aa0e99b7db420d119568f9903c895c8 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 15 Oct 2023 15:31:51 +0200 Subject: [PATCH 5/5] Make XTensor use spaces --- pytensor/xtensor/basic.py | 15 ++++++++------- pytensor/xtensor/type.py | 13 ++++--------- pytensor/xtensor/variable.py | 8 ++++++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 91fc8d70a9..4c83d2a9af 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -1,5 +1,5 @@ from itertools import chain -from typing import Any, Sequence +from typing import Any, Iterable, Sequence import pytensor.scalar as ps import pytensor.xtensor as px @@ -7,6 +7,7 @@ from pytensor.graph.basic import Variable from pytensor.graph.op import OutputStorageType, ParamsInputType from pytensor.tensor import TensorType +from pytensor.xtensor.spaces import DimLike, Space class TensorFromXTensor(Op): @@ -34,9 +35,9 @@ def perform( class XTensorFromTensor(Op): __props__ = ("dims",) - def __init__(self, dims): + def __init__(self, dims: Iterable[DimLike]): super().__init__() - self.dims = dims + self.dims = Space(dims) def make_node(self, *inputs: Variable) -> Apply: [x] = inputs @@ -56,7 +57,7 @@ def perform( output_storage[0][0] = x.copy() -def xtensor_from_tensor(x, dims): +def xtensor_from_tensor(x, dims: Iterable[DimLike]): return XTensorFromTensor(dims=dims)(x) @@ -70,11 +71,11 @@ def __init__(self, scalar_op): def make_node(self, *inputs): # TODO: Check dim lengths match inputs = [px.as_xtensor_variable(inp) for inp in inputs] - # TODO: This ordering is different than what xarray does - unique_dims = sorted(set(chain.from_iterable(inp.type.dims for inp in inputs))) + # NOTE: The output will have unordered dims + output_dims = set(chain.from_iterable(inp.type.dims for inp in inputs)) # TODO: Fix dtype output_type = px.XTensorType( - "float64", dims=unique_dims, shape=(None,) * len(unique_dims) + "float64", dims=output_dims, shape=(None,) * len(output_dims) ) outputs = [output_type() for _ in range(self.scalar_op.nout)] return Apply(self, inputs, outputs) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index bf402bf310..f94ea4aa1d 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Sequence, TypeVar, Union +from typing import Iterable, Optional, Union import numpy as np @@ -7,9 +7,7 @@ from pytensor.graph.basic import Variable from pytensor.graph.type import HasDataType from pytensor.tensor.type import TensorType - - -_XTensorTypeType = TypeVar("_XTensorTypeType", bound=TensorType) +from pytensor.xtensor.spaces import DimLike, Space class XTensorType(TensorType, HasDataType): @@ -27,15 +25,12 @@ def __init__( self, dtype: Union[str, np.dtype], *, - dims: Sequence[str], + dims: Iterable[DimLike], shape: Optional[Iterable[Optional[Union[bool, int]]]] = None, name: Optional[str] = None, ): super().__init__(dtype, shape=shape, name=name) - if not isinstance(dims, (list, tuple)): - raise TypeError("dims must be a list or tuple") - dims = tuple(dims) - self.dims = dims + self.dims = Space(dims) def clone( self, diff --git a/pytensor/xtensor/variable.py b/pytensor/xtensor/variable.py index 71eeb707bc..a5fb9537d7 100644 --- a/pytensor/xtensor/variable.py +++ b/pytensor/xtensor/variable.py @@ -8,7 +8,8 @@ from pytensor.tensor.utils import hash_from_ndarray from pytensor.tensor.variable import TensorConstant, _tensor_py_operators from pytensor.utils import hash_from_code -from pytensor.xtensor.type import XTensorType, _XTensorTypeType +from pytensor.xtensor.spaces import OrderedSpace +from pytensor.xtensor.type import XTensorType @_as_symbolic.register(xr.DataArray) @@ -349,8 +350,11 @@ def pytensor_hash(self): class XTensorConstant(TensorConstant, _xtensor_py_operators): - def __init__(self, type: _XTensorTypeType, data, name=None): + def __init__(self, type: XTensorType, data, name=None): # TODO: Add checks that type and data are compatible + # Check that the type carries ordered dims + if not isinstance(type.dims, OrderedSpace): + raise ValueError(f"XTensor constants require ordered dims, got {type.dims}") Constant.__init__(self, type, data, name) def signature(self):