Named tensors with typed spaces #477

Draft: wants to merge 5 commits into main
14 changes: 14 additions & 0 deletions pytensor/xtensor/__init__.py
@@ -0,0 +1,14 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor.spaces import BaseSpace, Dim, DimLike, OrderedSpace
from pytensor.xtensor.type import XTensorType
from pytensor.xtensor.variable import (
XTensorConstant,
XTensorVariable,
as_xtensor,
as_xtensor_variable,
)


warnings.warn("xtensor module is experimental and full of bugs")
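A quick orientation (a hypothetical usage note, not part of the diff): the re-exports above make the user-facing names importable straight from pytensor.xtensor, and importing the package emits the experimental warning and loads the rewrite registrations as a side effect. The dim names below are made up.

# Hypothetical sketch of the package-level API.
import pytensor.xtensor as px  # emits the UserWarning above

space = px.OrderedSpace(["batch", "channel"])
xt_type = px.XTensorType("float64", dims=space, shape=(None, None))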
89 changes: 89 additions & 0 deletions pytensor/xtensor/basic.py
@@ -0,0 +1,89 @@
from itertools import chain
from typing import Any, Iterable, Sequence

import pytensor.scalar as ps
import pytensor.xtensor as px
from pytensor.graph import Apply, Op
from pytensor.graph.basic import Variable
from pytensor.graph.op import OutputStorageType, ParamsInputType
from pytensor.tensor import TensorType
from pytensor.xtensor.spaces import DimLike, Space


class TensorFromXTensor(Op):
def make_node(self, *inputs: Variable) -> Apply:
[x] = inputs
if not isinstance(x.type, px.XTensorType):
raise TypeError(f"x must be have an XTensorType, got {type(x.type)}")
output = TensorType(x.type.dtype, shape=x.type.shape)()
return Apply(self, inputs, [output])

def perform(
self,
node: Apply,
inputs: Sequence[Any],
output_storage: OutputStorageType,
params: ParamsInputType = None,
) -> None:
[x] = inputs
output_storage[0][0] = x.copy()


tensor_from_xtensor = TensorFromXTensor()


class XTensorFromTensor(Op):
__props__ = ("dims",)

def __init__(self, dims: Iterable[DimLike]):
Member: dims should be Sequence since Iterable can exhaust...

Member Author: What do you mean by "exhaust" here?
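A minimal illustration of the concern (not part of the diff): a generator satisfies Iterable[DimLike] but can only be consumed once, so storing it and iterating it a second time silently yields nothing.

# Hypothetical example of an Iterable "exhausting"; the dims are made up.
dims = (d for d in ("a", "b"))  # a generator is an Iterable, but single-use
print(list(dims))  # ['a', 'b']
print(list(dims))  # []  -- already exhausted; a Sequence like ("a", "b") is reusable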

super().__init__()
self.dims = Space(dims)

def make_node(self, *inputs: Variable) -> Apply:
[x] = inputs
if not isinstance(x.type, TensorType):
raise TypeError(f"x must be an TensorType type, got {type(x.type)}")
output = px.XTensorType(x.type.dtype, dims=self.dims, shape=x.type.shape)()
return Apply(self, inputs, [output])

def perform(
self,
node: Apply,
inputs: Sequence[Any],
output_storage: OutputStorageType,
params: ParamsInputType = None,
) -> None:
[x] = inputs
output_storage[0][0] = x.copy()


def xtensor_from_tensor(x, dims: Iterable[DimLike]):
return XTensorFromTensor(dims=dims)(x)


class XElemwise(Op):
__props__ = ("scalar_op",)

def __init__(self, scalar_op):
Member: missing type hints

super().__init__()
self.scalar_op = scalar_op

def make_node(self, *inputs):
# TODO: Check dim lengths match
inputs = [px.as_xtensor_variable(inp) for inp in inputs]
# NOTE: The output will have unordered dims
output_dims = set(chain.from_iterable(inp.type.dims for inp in inputs))
# TODO: Fix dtype
output_type = px.XTensorType(
"float64", dims=output_dims, shape=(None,) * len(output_dims)
)
outputs = [output_type() for _ in range(self.scalar_op.nout)]
return Apply(self, inputs, outputs)

def perform(self, *args, **kwargs) -> None:
raise NotImplementedError(
"xtensor operations must be rewritten as tensor operations"
)


add = XElemwise(ps.add)
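A hedged usage sketch of the ops defined in this file (the dim names are made up; the resulting graph only becomes executable once the xcanonicalize rewrite replaces XElemwise with a regular Elemwise):

# Hypothetical sketch: attach named dims to plain tensors, add them, strip the names.
import pytensor.tensor as pt
from pytensor.xtensor.basic import add, tensor_from_xtensor, xtensor_from_tensor

x = xtensor_from_tensor(pt.matrix("x"), dims=("batch", "channel"))
y = xtensor_from_tensor(pt.vector("y"), dims=("channel",))

z = add(x, y)                      # XElemwise: output dims are the (unordered) union
z_tensor = tensor_from_xtensor(z)  # back to an ordinary TensorVariable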
1 change: 1 addition & 0 deletions pytensor/xtensor/rewriting/__init__.py
@@ -0,0 +1 @@
import pytensor.xtensor.rewriting.basic
30 changes: 30 additions & 0 deletions pytensor/xtensor/rewriting/basic.py
@@ -0,0 +1,30 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import expand_dims
from pytensor.tensor.elemwise import Elemwise
from pytensor.xtensor.basic import XElemwise, tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import register_xcanonicalize


@register_xcanonicalize
@node_rewriter(tracks=[XElemwise])
def xelemwise_to_elemwise(fgraph, node):
# Convert inputs to TensorVariables and add broadcastable dims
output_dims = node.outputs[0].type.dims

tensor_inputs = []
for inp in node.inputs:
inp_dims = inp.type.dims
axis = [i for i, dim in enumerate(output_dims) if dim not in inp_dims]
tensor_inp = tensor_from_xtensor(inp)
tensor_inp = expand_dims(tensor_inp, axis)
tensor_inputs.append(tensor_inp)

tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
*tensor_inputs, return_list=True
)

# TODO: copy_stack_trace
new_outs = [
xtensor_from_tensor(tensor_out, dims=output_dims) for tensor_out in tensor_outs
]
return new_outs
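To make the alignment step concrete, here is a hypothetical walk-through of the axis computation above, with made-up dims:

# An input that is missing some of the output dims gets broadcastable axes
# inserted at the positions of those missing dims before the Elemwise is applied.
output_dims = ("a", "b", "c")  # dims of the XElemwise output
inp_dims = ("b",)              # dims of one particular input
axis = [i for i, dim in enumerate(output_dims) if dim not in inp_dims]
print(axis)  # [0, 2] -> expand_dims(tensor_inp, [0, 2]) aligns the input with the output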
39 changes: 39 additions & 0 deletions pytensor/xtensor/rewriting/utils.py
@@ -0,0 +1,39 @@
from typing import Union

from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase


optdb.register(
"xcanonicalize",
EquilibriumDB(ignore_newtrees=False),
"fast_run",
"fast_compile",
"xtensor",
position=0,
)


def register_xcanonicalize(
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):

def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]):
return register_xcanonicalize(
# FIXME: Signature violation below; 2nd argument isn't a str
inner_rewriter,
node_rewriter, # type: ignore
*tags,
**kwargs,
)

return register

else:
name = kwargs.pop("name", None) or type(node_rewriter).__name__
optdb["xtensor"].register(
name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
)
return node_rewriter
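A hedged aside (not part of the diff): because the database is registered in optdb under the "xtensor", "fast_run" and "fast_compile" tags, it runs in the default compilation modes and can be switched off by excluding its tag.

# Hypothetical: build a compilation mode with the xtensor canonicalization disabled.
from pytensor.compile.mode import get_default_mode

mode_without_xtensor = get_default_mode().excluding("xtensor")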
177 changes: 177 additions & 0 deletions pytensor/xtensor/spaces.py
@@ -0,0 +1,177 @@
import sys
from collections.abc import Iterable, Iterator, Sequence
from typing import (
FrozenSet,
Protocol,
SupportsIndex,
Union,
cast,
overload,
runtime_checkable,
)


@runtime_checkable
class DimLike(Protocol):
"""Most basic signature of a dimension."""

def __str__(self) -> str:
...

def __hash__(self) -> int:
...


class Dim(DimLike):
Member: As far as I remember, it is a bad practice to inherit a base class from the protocol

Member: Dim can be just a dataclass instead

Member Author: According to the explanation in the PEP I would disagree: https://peps.python.org/pep-0544/#explicitly-declaring-implementation

By inheriting the protocol, we enable type checkers to warn about incomplete/incorrect implementations.

"""The most common type of dimension."""

_name: str

def __init__(self, name: str) -> None:
self._name = name
super().__init__()

def __str__(self) -> str:
return self._name

def __repr__(self) -> str:
return f"Dim('{self._name}')"

def __eq__(self, __value: object) -> bool:
return self._name == str(__value)

def __hash__(self) -> int:
return self._name.__hash__()


class BaseSpace(FrozenSet[DimLike]):
"""The most generic type of space is an unordered frozen set of dimensions.

It implements broadcasting rules for the following tensor operators:
* Addition → Unordered union
* Subtraction → Unordered union
    * Multiplication → Unordered union
    * Division → Unordered union
    * Power → Identity

The logic operators (AND &, OR |, XOR ^) do space math with the frozenset.
"""

def __add__(self, other: Iterable[DimLike]) -> "BaseSpace":
try:
other = Space(other)
return BaseSpace({*self, *other})
except Exception as ex:
raise TypeError(f"Can't {other} to space.") from ex

def __sub__(self, other: Iterable[DimLike]) -> "BaseSpace":
try:
other = Space(other)
return BaseSpace({*self, *other})
except Exception as ex:
raise TypeError(f"Can't subtract {other} from space.") from ex

def __mul__(self, other) -> "BaseSpace":
try:
other = Space(other)
return BaseSpace({*self, *other})
except Exception as ex:
raise TypeError(f"Can't multiply space by {other}.") from ex

def __truediv__(self, other) -> "BaseSpace":
try:
other = Space(other)
return BaseSpace({*self, *other})
except Exception as ex:
raise TypeError(f"Can't divide space by {other}.") from ex

def __pow__(self, other: Iterable[DimLike]) -> "BaseSpace":
return self

def __and__(self, other: Iterable[DimLike]) -> "BaseSpace":
try:
other = Space(other)
return BaseSpace(set(self) & other)
except Exception as ex:
raise TypeError(f"Can't AND space with {other}.") from ex

def __or__(self, other: Iterable[DimLike]) -> "BaseSpace":
try:
other = Space(other)
return BaseSpace(set(self) | other)
except Exception as ex:
raise TypeError(f"Can't OR space with {other}.") from ex

def __xor__(self, other: Iterable[DimLike]) -> "BaseSpace":
try:
other = Space(other)
return BaseSpace(set(self) ^ other)
except Exception as ex:
raise TypeError(f"Can't XOR space with {other}.") from ex

def __repr__(self) -> str:
return "Space{" + ", ".join(f"'{d}'" for d in self) + "}"


class OrderedSpace(BaseSpace, Sequence[DimLike]):
"""A very tidied-up space, with a known order of dimensions."""

def __init__(self, dims: Sequence[DimLike]) -> None:
self._order = tuple(dims)
super().__init__()

def __eq__(self, __value) -> bool:
if not isinstance(__value, Sequence) or isinstance(__value, str):
return False
return self._order == tuple(__value)

def __ne__(self, __value) -> bool:
if not isinstance(__value, Sequence) or isinstance(__value, str):
return True
return self._order != tuple(__value)

def __iter__(self) -> Iterator[DimLike]:
yield from self._order

def __reversed__(self) -> Iterator[DimLike]:
yield from reversed(self._order)

def index(
self,
__value: DimLike,
__start: SupportsIndex = 0,
__stop: SupportsIndex = sys.maxsize,
) -> int:
return self._order.index(__value, __start, __stop)

@overload
def __getitem__(self, __key: SupportsIndex) -> DimLike:
"""Indexing gives a dim"""

@overload
def __getitem__(self, __key: slice) -> "OrderedSpace":
"""Slicing preserves order"""

def __getitem__(self, __key) -> Union[DimLike, "OrderedSpace"]:
if isinstance(__key, slice):
return OrderedSpace(self._order[__key])
return cast(DimLike, self._order[__key])

def __repr__(self) -> str:
return "OrderedSpace(" + ", ".join(f"'{d}'" for d in self) + ")"


@overload
def Space(dims: Sequence[DimLike]) -> OrderedSpace:
"""Sequences of dims give an ordered space."""
...


@overload
def Space(dims: Iterable[DimLike]) -> BaseSpace:
...


def Space(dims: Iterable[DimLike]) -> Union[OrderedSpace, BaseSpace]:
if isinstance(dims, Sequence):
return OrderedSpace(dims)
return BaseSpace(dims)
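A hypothetical sketch (not part of the diff) of how the Space factory dispatches and how the broadcasting rules above behave; the dim names are made up:

# Sequences give an OrderedSpace; any other Iterable gives an unordered BaseSpace.
from pytensor.xtensor.spaces import BaseSpace, Dim, OrderedSpace, Space

ordered = Space(["batch", "channel"])      # a list is a Sequence -> OrderedSpace
unordered = Space({Dim("batch"), "time"})  # a set is not a Sequence -> BaseSpace

assert isinstance(ordered, OrderedSpace)
assert not isinstance(unordered, OrderedSpace)

# +, -, *, / broadcast to the unordered union of dims; ** is the identity.
assert (ordered + unordered) == {"batch", "channel", "time"}
assert ordered**unordered == ordered
# The set operators (&, |, ^) do plain frozenset math.
assert (ordered & unordered) == {"batch"}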