-
Notifications
You must be signed in to change notification settings - Fork 129
Named tensors with typed spaces #477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7c5c9a3
05520fd
c7bafc6
bb3e38a
165a090
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import warnings | ||
|
||
import pytensor.xtensor.rewriting | ||
from pytensor.xtensor.spaces import BaseSpace, Dim, DimLike, OrderedSpace | ||
from pytensor.xtensor.type import XTensorType | ||
from pytensor.xtensor.variable import ( | ||
XTensorConstant, | ||
XTensorVariable, | ||
as_xtensor, | ||
as_xtensor_variable, | ||
) | ||
|
||
|
||
# Runs at import time: every import of this package announces that the
# xtensor API is still experimental.
warnings.warn("xtensor module is experimental and full of bugs")
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from itertools import chain | ||
from typing import Any, Iterable, Sequence | ||
|
||
import pytensor.scalar as ps | ||
import pytensor.xtensor as px | ||
from pytensor.graph import Apply, Op | ||
from pytensor.graph.basic import Variable | ||
from pytensor.graph.op import OutputStorageType, ParamsInputType | ||
from pytensor.tensor import TensorType | ||
from pytensor.xtensor.spaces import DimLike, Space | ||
|
||
|
||
class TensorFromXTensor(Op):
    """Op that turns a variable of ``XTensorType`` into a ``TensorType`` one.

    The output type is built from the input's dtype and static shape only;
    the dims of the input are not carried over.
    """

    def make_node(self, *inputs: Variable) -> Apply:
        """Create the Apply node for exactly one xtensor input.

        Raises
        ------
        TypeError
            If the type of the input is not an ``XTensorType``.
        """
        # Unpacking enforces that exactly one input was given.
        (x,) = inputs
        x_type = x.type
        if isinstance(x_type, px.XTensorType):
            out_type = TensorType(x_type.dtype, shape=x_type.shape)
            return Apply(self, inputs, [out_type()])
        raise TypeError(f"x must be have an XTensorType, got {type(x_type)}")

    def perform(
        self,
        node: Apply,
        inputs: Sequence[Any],
        output_storage: OutputStorageType,
        params: ParamsInputType = None,
    ) -> None:
        """Store a copy of the single input value in the first output cell."""
        (value,) = inputs
        output_storage[0][0] = value.copy()


# Module-level instance used to apply the conversion.
tensor_from_xtensor = TensorFromXTensor()
|
||
|
||
class XTensorFromTensor(Op):
    """Op that labels a ``TensorType`` variable with dims, giving an ``XTensorType``.

    The dims are part of the op's ``__props__``.
    """

    __props__ = ("dims",)

    def __init__(self, dims: Iterable[DimLike]):
        super().__init__()
        # The dims are stored as a space built by ``Space``.
        self.dims = Space(dims)

    def make_node(self, *inputs: Variable) -> Apply:
        """Create the Apply node for exactly one tensor input.

        The output type reuses the input's dtype and static shape and adds
        the dims stored on the op.

        Raises
        ------
        TypeError
            If the type of the input is not a ``TensorType``.
        """
        # Unpacking enforces that exactly one input was given.
        (x,) = inputs
        x_type = x.type
        if isinstance(x_type, TensorType):
            out_type = px.XTensorType(x_type.dtype, dims=self.dims, shape=x_type.shape)
            return Apply(self, inputs, [out_type()])
        raise TypeError(f"x must be an TensorType type, got {type(x_type)}")

    def perform(
        self,
        node: Apply,
        inputs: Sequence[Any],
        output_storage: OutputStorageType,
        params: ParamsInputType = None,
    ) -> None:
        """Store a copy of the single input value in the first output cell."""
        (value,) = inputs
        output_storage[0][0] = value.copy()
|
||
|
||
def xtensor_from_tensor(x, dims: Iterable[DimLike]):
    """Apply a fresh ``XTensorFromTensor`` op, configured with ``dims``, to ``x``."""
    convert = XTensorFromTensor(dims=dims)
    return convert(x)
|
||
|
||
class XElemwise(Op):
    """Elementwise op over xtensor variables, parameterised by a scalar op.

    The op only builds graph nodes; ``perform`` always raises because the
    node is expected to be rewritten into tensor operations before running.
    """

    __props__ = ("scalar_op",)

    def __init__(self, scalar_op):
        # NOTE(review): a reviewer on the pull request flagged that this
        # constructor is missing type hints.
        super().__init__()
        self.scalar_op = scalar_op

    def make_node(self, *inputs):
        """Create the Apply node; the output dims are the union of all input dims."""
        # TODO: Check dim lengths match
        xinputs = [px.as_xtensor_variable(inp) for inp in inputs]
        # NOTE: The output will have unordered dims
        output_dims = {dim for xinp in xinputs for dim in xinp.type.dims}
        # No static shape information is propagated: every dim length is unknown.
        unknown_shape = (None,) * len(output_dims)
        # TODO: Fix dtype
        output_type = px.XTensorType("float64", dims=output_dims, shape=unknown_shape)
        # One output variable per output of the scalar op.
        outputs = [output_type() for _ in range(self.scalar_op.nout)]
        return Apply(self, xinputs, outputs)

    def perform(self, *args, **kwargs) -> None:
        """Always raise: this op has no direct implementation."""
        message = "xtensor operations must be rewritten as tensor operations"
        raise NotImplementedError(message)


# Elementwise addition of xtensors.
add = XElemwise(ps.add)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
import pytensor.xtensor.rewriting.basic |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from pytensor.graph import node_rewriter | ||
from pytensor.tensor import expand_dims | ||
from pytensor.tensor.elemwise import Elemwise | ||
from pytensor.xtensor.basic import XElemwise, tensor_from_xtensor, xtensor_from_tensor | ||
from pytensor.xtensor.rewriting.utils import register_xcanonicalize | ||
|
||
|
||
@register_xcanonicalize
@node_rewriter(tracks=[XElemwise])
def xelemwise_to_elemwise(fgraph, node):
    """Replace an ``XElemwise`` node by an ``Elemwise`` on plain tensors.

    Each input is converted with ``tensor_from_xtensor`` and gets new axes at
    the positions of the output dims it lacks.  The ``Elemwise`` results are
    converted back with ``xtensor_from_tensor`` using the output dims.
    """
    output_dims = node.outputs[0].type.dims

    def as_broadcastable_tensor(xinput):
        # Positions, in output-dim order, of the dims this input does not have.
        own_dims = xinput.type.dims
        missing_axes = [
            position
            for position, dim in enumerate(output_dims)
            if dim not in own_dims
        ]
        # NOTE(review): only new axes are inserted; the input's existing axes
        # are never reordered here -- presumably they are assumed to follow
        # the output-dim order already; confirm.
        return expand_dims(tensor_from_xtensor(xinput), missing_axes)

    tensor_inputs = [as_broadcastable_tensor(inp) for inp in node.inputs]

    elemwise = Elemwise(scalar_op=node.op.scalar_op)
    tensor_outs = elemwise(*tensor_inputs, return_list=True)

    # TODO: copy_stack_trace
    return [xtensor_from_tensor(out, dims=output_dims) for out in tensor_outs]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import Union | ||
|
||
from pytensor.compile import optdb | ||
from pytensor.graph.rewriting.basic import NodeRewriter | ||
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase | ||
|
||
|
||
# Create the equilibrium database that collects the xtensor rewrites and
# register it in the global ``optdb`` at position 0.  It is registered under
# the name "xcanonicalize" with the tags "fast_run", "fast_compile" and
# "xtensor".
optdb.register(
    "xcanonicalize",
    EquilibriumDB(ignore_newtrees=False),
    "fast_run",
    "fast_compile",
    "xtensor",
    position=0,
)
|
||
|
||
def register_xcanonicalize(
    node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
    """Register a rewriter in ``optdb["xtensor"]``; usable directly or as a decorator factory.

    Called with a rewriter, it registers it (always tagged "fast_run" and
    "fast_compile", plus ``tags``) and returns the rewriter unchanged.  The
    registration name is ``kwargs["name"]`` when given and truthy, otherwise
    the name of the rewriter's class.

    Called with a string, it returns a decorator that performs the
    registration with that string passed along as the first extra tag.
    """
    if not isinstance(node_rewriter, str):
        # NOTE(review): the fallback is the *class* name of the rewriter, so
        # two rewriters of the same class registered without an explicit
        # ``name`` get the same name -- confirm this is intended.
        name = kwargs.pop("name", None) or type(node_rewriter).__name__
        optdb["xtensor"].register(
            name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
        )
        return node_rewriter

    first_tag = node_rewriter

    def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]):
        return register_xcanonicalize(
            # FIXME: Signature violation below; 2nd argument isn't a str
            inner_rewriter,
            first_tag,  # type: ignore
            *tags,
            **kwargs,
        )

    return register
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
import sys | ||
from collections.abc import Iterable, Iterator, Sequence | ||
from typing import ( | ||
FrozenSet, | ||
Protocol, | ||
SupportsIndex, | ||
Union, | ||
cast, | ||
overload, | ||
runtime_checkable, | ||
) | ||
|
||
|
||
@runtime_checkable
class DimLike(Protocol):
    """Most basic signature of a dimension: it has ``__str__`` and ``__hash__``."""

    def __str__(self) -> str: ...

    def __hash__(self) -> int: ...
|
||
|
||
class Dim(DimLike):
    """The most common type of dimension.

    A ``Dim`` wraps a name; its string form, equality and hash are all
    derived from that name.
    """

    # NOTE(review): on the pull request, one reviewer recalled that inheriting
    # a concrete class from a protocol is bad practice and suggested making
    # ``Dim`` a plain dataclass instead.  The reply disagreed, citing PEP 544
    # (https://peps.python.org/pep-0544/#explicitly-declaring-implementation):
    # inheriting from the protocol lets type checkers warn about
    # incomplete/incorrect implementations.

    _name: str

    def __init__(self, name: str) -> None:
        self._name = name
        super().__init__()

    def __str__(self) -> str:
        return self._name

    def __repr__(self) -> str:
        return "Dim('{}')".format(self._name)

    def __eq__(self, __value: object) -> bool:
        # Equality is by name: anything whose ``str()`` equals the name matches.
        other_name = str(__value)
        return self._name == other_name

    def __hash__(self) -> int:
        return hash(self._name)
|
||
|
||
class BaseSpace(FrozenSet[DimLike]):
    """The most generic type of space is an unordered frozen set of dimensions.

    It implements broadcasting rules for the following tensor operators:
    * Addition → Unordered union
    * Subtraction → Unordered union
    * Multiplication → Unordered union
    * Division → Unordered union
    * Power → Identity

    The logic operators (AND &, OR |, XOR ^) do space math with the frozenset.

    Every binary operator except power first converts its right operand with
    ``Space`` and re-raises any failure as a ``TypeError`` chained to the
    original exception.
    """

    def _union_with(self, other, failure: str) -> "BaseSpace":
        """Return the unordered union of this space and ``other``.

        Parameters
        ----------
        other
            Anything accepted by ``Space``.
        failure
            ``str.format`` template for the error message; it receives
            ``other`` and is only rendered if building the union fails.

        Raises
        ------
        TypeError
            If ``other`` cannot be converted or combined.
        """
        try:
            return BaseSpace({*self, *Space(other)})
        except Exception as ex:
            raise TypeError(failure.format(other)) from ex

    def __add__(self, other: Iterable[DimLike]) -> "BaseSpace":
        return self._union_with(other, "Can't add {} to space.")

    def __sub__(self, other: Iterable[DimLike]) -> "BaseSpace":
        return self._union_with(other, "Can't subtract {} from space.")

    def __mul__(self, other) -> "BaseSpace":
        return self._union_with(other, "Can't multiply space by {}.")

    def __truediv__(self, other) -> "BaseSpace":
        return self._union_with(other, "Can't divide space by {}.")

    def __pow__(self, other: Iterable[DimLike]) -> "BaseSpace":
        # Raising to a power does not change the dims; ``other`` is ignored.
        return self

    def __and__(self, other: Iterable[DimLike]) -> "BaseSpace":
        """Return the dims present in both spaces, as an unordered space."""
        try:
            other = Space(other)
            return BaseSpace(set(self) & other)
        except Exception as ex:
            raise TypeError(f"Can't AND space with {other}.") from ex

    def __or__(self, other: Iterable[DimLike]) -> "BaseSpace":
        """Return the dims present in either space, as an unordered space."""
        try:
            other = Space(other)
            return BaseSpace(set(self) | other)
        except Exception as ex:
            raise TypeError(f"Can't OR space with {other}.") from ex

    def __xor__(self, other: Iterable[DimLike]) -> "BaseSpace":
        """Return the dims present in exactly one of the spaces, as an unordered space."""
        try:
            other = Space(other)
            return BaseSpace(set(self) ^ other)
        except Exception as ex:
            raise TypeError(f"Can't XOR space with {other}.") from ex

    def __repr__(self) -> str:
        return "Space{" + ", ".join(f"'{d}'" for d in self) + "}"
|
||
|
||
class OrderedSpace(BaseSpace, Sequence[DimLike]):
    """A very tidied-up space, with a known order of dimensions.

    The order is kept in ``_order``; iteration, indexing, slicing, equality
    and hashing all go through that tuple.
    """

    def __init__(self, dims: Sequence[DimLike]) -> None:
        self._order = tuple(dims)
        super().__init__()

    def __eq__(self, __value) -> bool:
        """Compare equal to any non-string sequence holding the same dims in the same order."""
        if not isinstance(__value, Sequence) or isinstance(__value, str):
            return False
        return self._order == tuple(__value)

    def __ne__(self, __value) -> bool:
        return not self.__eq__(__value)

    def __hash__(self) -> int:
        # Defining ``__eq__`` in a class body implicitly sets ``__hash__`` to
        # ``None``, which would make ordered spaces unhashable even though
        # the frozenset parent is hashable.  Hash the ordered dims so the
        # hash agrees with ``__eq__``.
        return hash(self._order)

    def __iter__(self) -> Iterator[DimLike]:
        yield from self._order

    def __reversed__(self) -> Iterator[DimLike]:
        yield from reversed(self._order)

    def index(
        self,
        __value: DimLike,
        __start: SupportsIndex = 0,
        __stop: SupportsIndex = sys.maxsize,
    ) -> int:
        """Return the position of a dim in the order, as ``tuple.index`` does."""
        return self._order.index(__value, __start, __stop)

    @overload
    def __getitem__(self, __key: SupportsIndex) -> DimLike:
        """Indexing gives a dim"""

    @overload
    def __getitem__(self, __key: slice) -> "OrderedSpace":
        """Slicing preserves order"""

    def __getitem__(self, __key) -> Union[DimLike, "OrderedSpace"]:
        if isinstance(__key, slice):
            return OrderedSpace(self._order[__key])
        return cast(DimLike, self._order[__key])

    def __repr__(self) -> str:
        return "OrderedSpace(" + ", ".join(f"'{d}'" for d in self) + ")"
|
||
|
||
@overload
def Space(dims: Sequence[DimLike]) -> OrderedSpace:
    """Sequences of dims give an ordered space."""


@overload
def Space(dims: Iterable[DimLike]) -> BaseSpace:
    """Any other iterable of dims gives an unordered space."""


def Space(dims: Iterable[DimLike]) -> Union[OrderedSpace, BaseSpace]:
    """Build a space from ``dims``, choosing the class from the argument's type.

    A ``Sequence`` yields an ``OrderedSpace``; every other iterable yields a
    ``BaseSpace``.
    """
    space_type = OrderedSpace if isinstance(dims, Sequence) else BaseSpace
    return space_type(dims)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`dims` should be a `Sequence`, since an `Iterable` can be exhausted...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean by "exhaust" here?