diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py new file mode 100644 index 0000000000..7b243a68ac --- /dev/null +++ b/pytensor/xtensor/__init__.py @@ -0,0 +1,14 @@ +import warnings + +import pytensor.xtensor.rewriting +from pytensor.xtensor.spaces import BaseSpace, Dim, DimLike, OrderedSpace +from pytensor.xtensor.type import XTensorType +from pytensor.xtensor.variable import ( + XTensorConstant, + XTensorVariable, + as_xtensor, + as_xtensor_variable, +) + + +warnings.warn("xtensor module is experimental and full of bugs") diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py new file mode 100644 index 0000000000..4c83d2a9af --- /dev/null +++ b/pytensor/xtensor/basic.py @@ -0,0 +1,89 @@ +from itertools import chain +from typing import Any, Iterable, Sequence + +import pytensor.scalar as ps +import pytensor.xtensor as px +from pytensor.graph import Apply, Op +from pytensor.graph.basic import Variable +from pytensor.graph.op import OutputStorageType, ParamsInputType +from pytensor.tensor import TensorType +from pytensor.xtensor.spaces import DimLike, Space + + +class TensorFromXTensor(Op): + def make_node(self, *inputs: Variable) -> Apply: + [x] = inputs + if not isinstance(x.type, px.XTensorType): + raise TypeError(f"x must be have an XTensorType, got {type(x.type)}") + output = TensorType(x.type.dtype, shape=x.type.shape)() + return Apply(self, inputs, [output]) + + def perform( + self, + node: Apply, + inputs: Sequence[Any], + output_storage: OutputStorageType, + params: ParamsInputType = None, + ) -> None: + [x] = inputs + output_storage[0][0] = x.copy() + + +tensor_from_xtensor = TensorFromXTensor() + + +class XTensorFromTensor(Op): + __props__ = ("dims",) + + def __init__(self, dims: Iterable[DimLike]): + super().__init__() + self.dims = Space(dims) + + def make_node(self, *inputs: Variable) -> Apply: + [x] = inputs + if not isinstance(x.type, TensorType): + raise TypeError(f"x must be an TensorType type, got {type(x.type)}") + output = px.XTensorType(x.type.dtype, dims=self.dims, shape=x.type.shape)() + return Apply(self, inputs, [output]) + + def perform( + self, + node: Apply, + inputs: Sequence[Any], + output_storage: OutputStorageType, + params: ParamsInputType = None, + ) -> None: + [x] = inputs + output_storage[0][0] = x.copy() + + +def xtensor_from_tensor(x, dims: Iterable[DimLike]): + return XTensorFromTensor(dims=dims)(x) + + +class XElemwise(Op): + __props__ = ("scalar_op",) + + def __init__(self, scalar_op): + super().__init__() + self.scalar_op = scalar_op + + def make_node(self, *inputs): + # TODO: Check dim lengths match + inputs = [px.as_xtensor_variable(inp) for inp in inputs] + # NOTE: The output will have unordered dims + output_dims = set(chain.from_iterable(inp.type.dims for inp in inputs)) + # TODO: Fix dtype + output_type = px.XTensorType( + "float64", dims=output_dims, shape=(None,) * len(output_dims) + ) + outputs = [output_type() for _ in range(self.scalar_op.nout)] + return Apply(self, inputs, outputs) + + def perform(self, *args, **kwargs) -> None: + raise NotImplementedError( + "xtensor operations must be rewritten as tensor operations" + ) + + +add = XElemwise(ps.add) diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py new file mode 100644 index 0000000000..6ff8b80822 --- /dev/null +++ b/pytensor/xtensor/rewriting/__init__.py @@ -0,0 +1 @@ +import pytensor.xtensor.rewriting.basic diff --git a/pytensor/xtensor/rewriting/basic.py b/pytensor/xtensor/rewriting/basic.py new file mode 100644 index 0000000000..03750fa53f --- /dev/null +++ b/pytensor/xtensor/rewriting/basic.py @@ -0,0 +1,30 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor import expand_dims +from pytensor.tensor.elemwise import Elemwise +from pytensor.xtensor.basic import XElemwise, tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.rewriting.utils import register_xcanonicalize + + +@register_xcanonicalize +@node_rewriter(tracks=[XElemwise]) +def xelemwise_to_elemwise(fgraph, node): + # Convert inputs to TensorVariables and add broadcastable dims + output_dims = node.outputs[0].type.dims + + tensor_inputs = [] + for inp in node.inputs: + inp_dims = inp.type.dims + axis = [i for i, dim in enumerate(output_dims) if dim not in inp_dims] + tensor_inp = tensor_from_xtensor(inp) + tensor_inp = expand_dims(tensor_inp, axis) + tensor_inputs.append(tensor_inp) + + tensor_outs = Elemwise(scalar_op=node.op.scalar_op)( + *tensor_inputs, return_list=True + ) + + # TODO: copy_stack_trace + new_outs = [ + xtensor_from_tensor(tensor_out, dims=output_dims) for tensor_out in tensor_outs + ] + return new_outs diff --git a/pytensor/xtensor/rewriting/utils.py b/pytensor/xtensor/rewriting/utils.py new file mode 100644 index 0000000000..bce5f19932 --- /dev/null +++ b/pytensor/xtensor/rewriting/utils.py @@ -0,0 +1,39 @@ +from typing import Union + +from pytensor.compile import optdb +from pytensor.graph.rewriting.basic import NodeRewriter +from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase + + +optdb.register( + "xcanonicalize", + EquilibriumDB(ignore_newtrees=False), + "fast_run", + "fast_compile", + "xtensor", + position=0, +) + + +def register_xcanonicalize( + node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]): + return register_xcanonicalize( + # FIXME: Signature violation below; 2nd argument isn't a str + inner_rewriter, + node_rewriter, # type: ignore + *tags, + **kwargs, + ) + + return register + + else: + name = kwargs.pop("name", None) or type(node_rewriter).__name__ + optdb["xtensor"].register( + name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs + ) + return node_rewriter diff --git a/pytensor/xtensor/spaces.py b/pytensor/xtensor/spaces.py new file mode 100644 index 0000000000..b1c8e5313b --- /dev/null +++ b/pytensor/xtensor/spaces.py @@ -0,0 +1,177 @@ +import sys +from collections.abc import Iterable, Iterator, Sequence +from typing import ( + FrozenSet, + Protocol, + SupportsIndex, + Union, + cast, + overload, + runtime_checkable, +) + + +@runtime_checkable +class DimLike(Protocol): + """Most basic signature of a dimension.""" + + def __str__(self) -> str: + ... + + def __hash__(self) -> int: + ... + + +class Dim(DimLike): + """The most common type of dimension.""" + + _name: str + + def __init__(self, name: str) -> None: + self._name = name + super().__init__() + + def __str__(self) -> str: + return self._name + + def __repr__(self) -> str: + return f"Dim('{self._name}')" + + def __eq__(self, __value: object) -> bool: + return self._name == str(__value) + + def __hash__(self) -> int: + return self._name.__hash__() + + +class BaseSpace(FrozenSet[DimLike]): + """The most generic type of space is an unordered frozen set of dimensions. + + It implements broadcasting rules for the following tensor operators: + * Addition → Unordered union + * Subtraction → Unordered union + * Multiplication → Unordered union + * Power → Identity + + The logic operators (AND &, OR |, XOR ^) do space math with the frozenset. + """ + + def __add__(self, other: Iterable[DimLike]) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace({*self, *other}) + except Exception as ex: + raise TypeError(f"Can't {other} to space.") from ex + + def __sub__(self, other: Iterable[DimLike]) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace({*self, *other}) + except Exception as ex: + raise TypeError(f"Can't subtract {other} from space.") from ex + + def __mul__(self, other) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace({*self, *other}) + except Exception as ex: + raise TypeError(f"Can't multiply space by {other}.") from ex + + def __truediv__(self, other) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace({*self, *other}) + except Exception as ex: + raise TypeError(f"Can't divide space by {other}.") from ex + + def __pow__(self, other: Iterable[DimLike]) -> "BaseSpace": + return self + + def __and__(self, other: Iterable[DimLike]) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace(set(self) & other) + except Exception as ex: + raise TypeError(f"Can't AND space with {other}.") from ex + + def __or__(self, other: Iterable[DimLike]) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace(set(self) | other) + except Exception as ex: + raise TypeError(f"Can't OR space with {other}.") from ex + + def __xor__(self, other: Iterable[DimLike]) -> "BaseSpace": + try: + other = Space(other) + return BaseSpace(set(self) ^ other) + except Exception as ex: + raise TypeError(f"Can't XOR space with {other}.") from ex + + def __repr__(self) -> str: + return "Space{" + ", ".join(f"'{d}'" for d in self) + "}" + + +class OrderedSpace(BaseSpace, Sequence[DimLike]): + """A very tidied-up space, with a known order of dimensions.""" + + def __init__(self, dims: Sequence[DimLike]) -> None: + self._order = tuple(dims) + super().__init__() + + def __eq__(self, __value) -> bool: + if not isinstance(__value, Sequence) or isinstance(__value, str): + return False + return self._order == tuple(__value) + + def __ne__(self, __value) -> bool: + if not isinstance(__value, Sequence) or isinstance(__value, str): + return True + return self._order != tuple(__value) + + def __iter__(self) -> Iterator[DimLike]: + yield from self._order + + def __reversed__(self) -> Iterator[DimLike]: + yield from reversed(self._order) + + def index( + self, + __value: DimLike, + __start: SupportsIndex = 0, + __stop: SupportsIndex = sys.maxsize, + ) -> int: + return self._order.index(__value, __start, __stop) + + @overload + def __getitem__(self, __key: SupportsIndex) -> DimLike: + """Indexing gives a dim""" + + @overload + def __getitem__(self, __key: slice) -> "OrderedSpace": + """Slicing preserves order""" + + def __getitem__(self, __key) -> Union[DimLike, "OrderedSpace"]: + if isinstance(__key, slice): + return OrderedSpace(self._order[__key]) + return cast(DimLike, self._order[__key]) + + def __repr__(self) -> str: + return "OrderedSpace(" + ", ".join(f"'{d}'" for d in self) + ")" + + +@overload +def Space(dims: Sequence[DimLike]) -> OrderedSpace: + """Sequences of dims give an ordered space.""" + ... + + +@overload +def Space(dims: Iterable[DimLike]) -> BaseSpace: + ... + + +def Space(dims: Iterable[DimLike]) -> Union[OrderedSpace, BaseSpace]: + if isinstance(dims, Sequence): + return OrderedSpace(dims) + return BaseSpace(dims) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py new file mode 100644 index 0000000000..f94ea4aa1d --- /dev/null +++ b/pytensor/xtensor/type.py @@ -0,0 +1,142 @@ +from typing import Iterable, Optional, Union + +import numpy as np + +import pytensor +from pytensor import scalar as aes +from pytensor.graph.basic import Variable +from pytensor.graph.type import HasDataType +from pytensor.tensor.type import TensorType +from pytensor.xtensor.spaces import DimLike, Space + + +class XTensorType(TensorType, HasDataType): + """A `Type` for sparse tensors. + + Notes + ----- + Currently, sparse tensors can only be matrices (i.e. have two dimensions). + + """ + + __props__ = ("dtype", "shape", "dims") + + def __init__( + self, + dtype: Union[str, np.dtype], + *, + dims: Iterable[DimLike], + shape: Optional[Iterable[Optional[Union[bool, int]]]] = None, + name: Optional[str] = None, + ): + super().__init__(dtype, shape=shape, name=name) + self.dims = Space(dims) + + def clone( + self, + dtype=None, + dims=None, + shape=None, + **kwargs, + ): + if dtype is None: + dtype = self.dtype + if dims is None: + dims = self.dims + if shape is None: + shape = self.shape + return type(self)(format, dtype, shape=shape, dims=dims, **kwargs) + + def filter(self, value, strict=False, allow_downcast=None): + # TODO: Implement this + return value + + if isinstance(value, Variable): + raise TypeError( + "Expected an array-like object, but found a Variable: " + "maybe you are trying to call a function on a (possibly " + "shared) variable instead of a numeric array?" + ) + + if ( + isinstance(value, self.format_cls[self.format]) + and value.dtype == self.dtype + ): + return value + + if strict: + raise TypeError( + f"{value} is not sparse, or not the right dtype (is {value.dtype}, " + f"expected {self.dtype})" + ) + + # The input format could be converted here + if allow_downcast: + sp = self.format_cls[self.format](value, dtype=self.dtype) + else: + data = self.format_cls[self.format](value) + up_dtype = aes.upcast(self.dtype, data.dtype) + if up_dtype != self.dtype: + raise TypeError(f"Expected {self.dtype} dtype but got {data.dtype}") + sp = data.astype(up_dtype) + + assert sp.format == self.format + + return sp + + def convert_variable(self, var): + # TODO: Implement this + return var + res = super().convert_variable(var) + + if res is None: + return res + + if not isinstance(res.type, type(self)): + return None + + if res.dims != self.dims: + # TODO: Does this make sense? + return None + + return res + + def __hash__(self): + return super().__hash__() ^ hash(self.dims) + + def __repr__(self): + # TODO: Add `?` for unknown shapes like `TensorType` does + return f"XTensorType({self.dtype}, {self.dims}, {self.shape})" + + def __eq__(self, other): + res = super().__eq__(other) + + if isinstance(res, bool): + return res and other.dims == self.dims + + return res + + def is_super(self, otype): + # TODO: Implement this + return True + + if not super().is_super(otype): + return False + + if self.dims == otype.dims: + return True + + return False + + +# TODO: Implement creater helper xtensor + +pytensor.compile.register_view_op_c_code( + XTensorType, + """ + Py_XDECREF(%(oname)s); + %(oname)s = %(iname)s; + Py_XINCREF(%(oname)s); + """, + 1, +) diff --git a/pytensor/xtensor/variable.py b/pytensor/xtensor/variable.py new file mode 100644 index 0000000000..a5fb9537d7 --- /dev/null +++ b/pytensor/xtensor/variable.py @@ -0,0 +1,366 @@ +import numpy as np +import xarray as xr + +import pytensor.xtensor.basic as xbasic +from pytensor import Variable, _as_symbolic +from pytensor.graph import Apply, Constant +from pytensor.tensor import TensorVariable +from pytensor.tensor.utils import hash_from_ndarray +from pytensor.tensor.variable import TensorConstant, _tensor_py_operators +from pytensor.utils import hash_from_code +from pytensor.xtensor.spaces import OrderedSpace +from pytensor.xtensor.type import XTensorType + + +@_as_symbolic.register(xr.DataArray) +def as_symbolic_sparse(x, **kwargs): + return as_xtensor_variable(x, **kwargs) + + +def as_xtensor_variable(x, name=None, ndim=None, **kwargs): + """ + Wrapper around SparseVariable constructor to construct + a Variable with a sparse matrix with the same dtype and + format. + + Parameters + ---------- + x + A sparse matrix. + + Returns + ------- + object + SparseVariable version of `x`. + + """ + + # TODO + # Verify that sp is sufficiently sparse, and raise a + # warning if it is not + + if isinstance(x, Apply): + if len(x.outputs) != 1: + raise ValueError( + "It is ambiguous which output of a " + "multi-output Op has to be fetched.", + x, + ) + else: + x = x.outputs[0] + if isinstance(x, Variable): + if not isinstance(x.type, XTensorType): + raise TypeError(f"Variable type field must be a XTensorType, got {x.type}") + return x + try: + return constant(x, name=name) + except TypeError as err: + raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err + + +as_xtensor = as_xtensor_variable + + +def constant(x, name=None): + if not isinstance(x, xr.DataArray): + raise TypeError("xtensor.constant must be called on a Xarray DataArray") + try: + return XTensorConstant( + XTensorType(dtype=x.dtype, dims=x.dims, shape=x.shape), + x.values.copy(), + name=name, + ) + except TypeError: + raise TypeError(f"Could not convert {x} to XTensorType") + + +# def sp_ones_like(x): +# """ +# Construct a sparse matrix of ones with the same sparsity pattern. +# +# Parameters +# ---------- +# x +# Sparse matrix to take the sparsity pattern. +# +# Returns +# ------- +# A sparse matrix +# The same as `x` with data changed for ones. +# +# """ +# # TODO: don't restrict to CSM formats +# data, indices, indptr, _shape = csm_properties(x) +# return CSM(format=x.format)(at.ones_like(data), indices, indptr, _shape) +# +# +# def sp_zeros_like(x): +# """ +# Construct a sparse matrix of zeros. +# +# Parameters +# ---------- +# x +# Sparse matrix to take the shape. +# +# Returns +# ------- +# A sparse matrix +# The same as `x` with zero entries for all element. +# +# """ +# +# # TODO: don't restrict to CSM formats +# _, _, indptr, _shape = csm_properties(x) +# return CSM(format=x.format)( +# data=np.array([], dtype=x.type.dtype), +# indices=np.array([], dtype="int32"), +# indptr=at.zeros_like(indptr), +# shape=_shape, +# ) +# +# +# def override_dense(*methods): +# def decorate(cls): +# def native(method): +# original = getattr(cls.__base__, method) +# +# def to_dense(self, *args, **kwargs): +# self = self.toarray() +# new_args = [ +# arg.toarray() +# if hasattr(arg, "type") and isinstance(arg.type, SparseTensorType) +# else arg +# for arg in args +# ] +# warn( +# f"Method {method} is not implemented for sparse variables. The variable will be converted to dense." +# ) +# return original(self, *new_args, **kwargs) +# +# return to_dense +# +# for method in methods: +# setattr(cls, method, native(method)) +# return cls +# +# return decorate + + +# @override_dense( +# "__abs__", +# "__ceil__", +# "__floor__", +# "__trunc__", +# "transpose", +# "any", +# "all", +# "flatten", +# "ravel", +# "arccos", +# "arcsin", +# "arctan", +# "arccosh", +# "arcsinh", +# "arctanh", +# "ceil", +# "cos", +# "cosh", +# "deg2rad", +# "exp", +# "exp2", +# "expm1", +# "floor", +# "log", +# "log10", +# "log1p", +# "log2", +# "rad2deg", +# "sin", +# "sinh", +# "sqrt", +# "tan", +# "tanh", +# "copy", +# "prod", +# "mean", +# "var", +# "std", +# "min", +# "max", +# "argmin", +# "argmax", +# "round", +# "trace", +# "cumsum", +# "cumprod", +# "ptp", +# "squeeze", +# "diagonal", +# "__and__", +# "__or__", +# "__xor__", +# "__pow__", +# "__mod__", +# "__divmod__", +# "__truediv__", +# "__floordiv__", +# "reshape", +# "dimshuffle", +# ) +class _xtensor_py_operators(_tensor_py_operators): + T = property( + lambda self: transpose(self), doc="Return aliased transpose of self (read-only)" + ) + + def astype(self, dtype): + return cast(self, dtype) + + def __neg__(self): + return neg(self) + + def __add__(left, right): + return xbasic.add(left, right) + + def __radd__(right, left): + return add(left, right) + + def __sub__(left, right): + return sub(left, right) + + def __rsub__(right, left): + return sub(left, right) + + def __mul__(left, right): + return mul(left, right) + + def __rmul__(left, right): + return mul(left, right) + + # comparison operators + + def __lt__(self, other): + return lt(self, other) + + def __le__(self, other): + return le(self, other) + + def __gt__(self, other): + return gt(self, other) + + def __ge__(self, other): + return ge(self, other) + + def __dot__(left, right): + return structured_dot(left, right) + + def __rdot__(right, left): + return structured_dot(left, right) + + def sum(self, axis=None, sparse_grad=False): + return sp_sum(self, axis=axis, sparse_grad=sparse_grad) + + dot = __dot__ + + def toarray(self): + return dense_from_sparse(self) + + @property + def shape(self): + # TODO: The plan is that the ShapeFeature in at.opt will do shape + # propagation and remove the dense_from_sparse from the graph. This + # will *NOT* actually expand your sparse matrix just to get the shape. + return shape(dense_from_sparse(self)) + + @property + def ndim(self) -> int: + return self.type.ndim + + @property + def dtype(self): + return self.type.dtype + + # Note that the `size` attribute of sparse matrices behaves differently + # from dense matrices: it is the number of elements stored in the matrix + # rather than the total number of elements that may be stored. Note also + # that stored zeros *do* count in the size. + size = property(lambda self: csm_data(self).size) + + def zeros_like(model): + return sp_zeros_like(model) + + def __getitem__(self, args): + if not isinstance(args, tuple): + args = (args,) + + if len(args) == 2: + scalar_arg_1 = ( + np.isscalar(args[0]) or getattr(args[0], "type", None) == iscalar + ) + scalar_arg_2 = ( + np.isscalar(args[1]) or getattr(args[1], "type", None) == iscalar + ) + if scalar_arg_1 and scalar_arg_2: + ret = get_item_scalar(self, args) + elif isinstance(args[0], list): + ret = get_item_2lists(self, args[0], args[1]) + else: + ret = get_item_2d(self, args) + elif isinstance(args[0], list): + ret = get_item_list(self, args[0]) + else: + ret = get_item_2d(self, args) + return ret + + def conj(self): + return conjugate(self) + + +class XTensorVariable(_xtensor_py_operators, TensorVariable): + pass + + # def __str__(self): + # return f"{self.__class__.__name__}{{{self.format},{self.dtype}}}" + + # def __repr__(self): + # return str(self) + + +class XTensorConstantSignature(tuple): + def __eq__(self, other): + if type(self) != type(other): + return False + + (t0, d0), (t1, d1) = self, other + if t0 != t1 or d0.shape != d1.shape: + return False + + return True + + def __ne__(self, other): + return not self == other + + def __hash__(self): + (a, b) = self + return hash(type(self)) ^ hash(a) ^ hash(type(b)) + + def pytensor_hash(self): + t, d = self + return "".join([hash_from_ndarray(d)] + [hash_from_code(dim) for dim in t.dims]) + + +class XTensorConstant(TensorConstant, _xtensor_py_operators): + def __init__(self, type: XTensorType, data, name=None): + # TODO: Add checks that type and data are compatible + # Check that the type carries ordered dims + if not isinstance(type.dims, OrderedSpace): + raise ValueError(f"XTensor constants require ordered dims, got {type.dims}") + Constant.__init__(self, type, data, name) + + def signature(self): + assert self.data is not None + return XTensorConstantSignature((self.type, self.data)) + + +XTensorType.variable_type = XTensorVariable +XTensorType.constant_type = XTensorConstant diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 1cae4d9152..19dd3c7c73 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -32,4 +32,5 @@ pytensor/tensor/slinalg.py pytensor/tensor/subtensor.py pytensor/tensor/type.py pytensor/tensor/type_other.py -pytensor/tensor/variable.py \ No newline at end of file +pytensor/tensor/variable.py +pytensor/xtensor/variable.py \ No newline at end of file diff --git a/tests/xtensor/test_spaces.py b/tests/xtensor/test_spaces.py new file mode 100644 index 0000000000..a3ba59be80 --- /dev/null +++ b/tests/xtensor/test_spaces.py @@ -0,0 +1,120 @@ +from typing import Sequence + +import pytest + +import pytensor.xtensor.spaces as xsp + + +class TestDims: + def test_str_is_dimlike(self): + assert isinstance("d", xsp.DimLike) + + def test_dim(self): + d0 = xsp.Dim("d0") + assert isinstance(d0, xsp.DimLike) + assert str(d0) == "d0" + assert "d0" in d0.__repr__() + # Dims can compare with strings + assert d0 == "d0" + # They must be hashable to be used as keys + assert isinstance(hash(d0), int) + + +class TestBaseSpace: + def test_type(self): + assert issubclass(xsp.BaseSpace, frozenset) + s1 = xsp.BaseSpace({"d0", "d1"}) + assert "Space" in s1.__repr__() + assert "d0" in s1.__repr__() + assert "d1" in s1.__repr__() + # Spaces are frozensets which makes them convenient to use + assert isinstance(s1, frozenset) + # But they can't be sets, because .add(newdim) would mess up things + assert not isinstance(s1, set) + assert "d0" in s1 + assert "d1" in s1 + assert len(s1) == 2 + assert s1 == {"d1", "d0"} + # Can't index an unordered space + assert not hasattr(s1, "index") + with pytest.raises(TypeError, match="not subscriptable"): + s1[1] + + def test_spacemath(self): + assert xsp.BaseSpace("ab") == {"a", "b"} + # Set logic operations result in spaces + union = xsp.BaseSpace("ab") | {"b", "c"} + assert isinstance(union, xsp.BaseSpace) + assert union == {"a", "b", "c"} + + intersection = xsp.BaseSpace("ab") & {"b", "c"} + assert isinstance(intersection, xsp.BaseSpace) + assert intersection == {"b"} + + xor = xsp.BaseSpace("ab") ^ {"b", "c"} + assert isinstance(xor, xsp.BaseSpace) + assert xor == {"a", "c"} + + def test_tensormath(self): + # Tensors and spaces follow the same basic math rules + addition = xsp.BaseSpace("ab") + {"c"} + assert isinstance(addition, xsp.BaseSpace) + assert addition == {"a", "b", "c"} + + subtraction = xsp.BaseSpace("ab") - {"b", "c"} + assert isinstance(subtraction, xsp.BaseSpace) + assert subtraction == {"a", "b", "c"} + + multiplication = xsp.BaseSpace("ab") * {"c"} + assert isinstance(multiplication, xsp.BaseSpace) + assert multiplication == {"a", "b", "c"} + + division = xsp.BaseSpace("ab") / {"b", "c"} + assert isinstance(division, xsp.BaseSpace) + assert division == {"a", "b", "c"} + + power = xsp.BaseSpace("ba") ** 3 + assert isinstance(power, xsp.BaseSpace) + assert power == {"a", "b"} + + +class TestOrderedSpace: + def test_type(self): + o1 = xsp.OrderedSpace(["b", "a"]) + assert o1.__repr__() == "OrderedSpace('b', 'a')" + assert isinstance(o1, Sequence) + assert len(o1) == 2 + # Addition/multiplication is different compare to tuples + assert not isinstance(o1, tuple) + # And lists would be mutable, but ordered spaces are not + assert not isinstance(o1, list) + + def test_comparison(self): + # Ordered spaces can only be equal to other ordered things + assert xsp.OrderedSpace("a") != {"a"} + assert xsp.OrderedSpace("a") == ("a",) + assert xsp.OrderedSpace("a") == ["a"] + assert xsp.OrderedSpace("a") == xsp.OrderedSpace("a") + # Except for strings, because they could be a dim + assert not xsp.OrderedSpace("a") == "a" + assert xsp.OrderedSpace("a") != "a" + + def test_indexing(self): + b = xsp.Dim("b") + o1 = xsp.OrderedSpace([b, "a"]) + # Ordered spaces can be indexed + assert o1.index("b") == 0 + assert o1[0] is b + sliced = o1[::-1] + assert isinstance(sliced, xsp.OrderedSpace) + assert sliced == ("a", "b") + + +def test_space_function(): + usp = xsp.Space({"a", "b"}) + assert isinstance(usp, xsp.BaseSpace) + assert not isinstance(usp, xsp.OrderedSpace) + + assert isinstance(xsp.Space(["a", "b"]), xsp.OrderedSpace) + assert isinstance(xsp.Space(("a", "b")), xsp.OrderedSpace) + assert isinstance(xsp.Space("ab"), xsp.OrderedSpace)