-
Notifications
You must be signed in to change notification settings - Fork 34
TYP: Type annotations, part 4 #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
362c48a
ad375dc
49f9ba7
4371506
c724a52
14f70af
0a571bc
0172300
8711041
014e20f
7c5408c
924fc3d
5d98aa8
8eb647f
a06d51f
247ee6d
85fce08
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -12,56 +12,51 @@ | |||||
import math | ||||||
import sys | ||||||
import warnings | ||||||
from collections.abc import Collection | ||||||
from types import NoneType | ||||||
from typing import ( | ||||||
TYPE_CHECKING, | ||||||
Any, | ||||||
Final, | ||||||
Literal, | ||||||
SupportsIndex, | ||||||
TypeAlias, | ||||||
TypeGuard, | ||||||
TypeVar, | ||||||
cast, | ||||||
overload, | ||||||
) | ||||||
|
||||||
from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
|
||||||
import cupy as cp | ||||||
import dask.array as da | ||||||
import jax | ||||||
import ndonnx as ndx | ||||||
import numpy as np | ||||||
import numpy.typing as npt | ||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
import torch | ||||||
|
||||||
# TODO: import from typing (requires Python >=3.13) | ||||||
from typing_extensions import TypeIs, TypeVar | ||||||
|
||||||
_SizeT = TypeVar("_SizeT", bound = int | None) | ||||||
from typing_extensions import TypeIs | ||||||
|
||||||
_ZeroGradientArray: TypeAlias = npt.NDArray[np.void] | ||||||
_CupyArray: TypeAlias = Any # cupy has no py.typed | ||||||
|
||||||
_ArrayApiObj: TypeAlias = ( | ||||||
npt.NDArray[Any] | ||||||
| cp.ndarray | ||||||
| da.Array | ||||||
| jax.Array | ||||||
| ndx.Array | ||||||
| sparse.SparseArray | ||||||
| torch.Tensor | ||||||
| SupportsArrayNamespace[Any] | ||||||
| _CupyArray | ||||||
| SupportsArrayNamespace | ||||||
) | ||||||
|
||||||
_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) | ||||||
_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) | ||||||
|
||||||
|
||||||
def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: | ||||||
def _is_jax_zero_gradient_array(x: object) -> TypeIs[_ZeroGradientArray]: | ||||||
crusaderky marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
"""Return True if `x` is a zero-gradient array. | ||||||
|
||||||
These arrays are a design quirk of Jax that may one day be removed. | ||||||
|
@@ -80,7 +75,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: | |||||
) | ||||||
|
||||||
|
||||||
def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: | ||||||
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: | ||||||
""" | ||||||
Return True if `x` is a NumPy array. | ||||||
|
||||||
|
@@ -137,7 +132,7 @@ def is_cupy_array(x: object) -> bool: | |||||
if "cupy" not in sys.modules: | ||||||
return False | ||||||
|
||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
import cupy as cp | ||||||
|
||||||
# TODO: Should we reject ndarray subclasses? | ||||||
return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] | ||||||
|
@@ -280,13 +275,13 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: | |||||
if "sparse" not in sys.modules: | ||||||
return False | ||||||
|
||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
|
||||||
# TODO: Account for other backends. | ||||||
return isinstance(x, sparse.SparseArray) | ||||||
|
||||||
|
||||||
def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] | ||||||
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
""" | ||||||
Return True if `x` is an array API compatible array object. | ||||||
|
||||||
|
@@ -587,7 +582,7 @@ def your_function(x, y): | |||||
|
||||||
namespaces.add(cupy_namespace) | ||||||
else: | ||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
import cupy as cp | ||||||
|
||||||
namespaces.add(cp) | ||||||
elif is_torch_array(x): | ||||||
|
@@ -624,14 +619,14 @@ def your_function(x, y): | |||||
if hasattr(jax.numpy, "__array_api_version__"): | ||||||
jnp = jax.numpy | ||||||
else: | ||||||
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] | ||||||
import jax.experimental.array_api as jnp # type: ignore[no-redef] | ||||||
namespaces.add(jnp) | ||||||
elif is_pydata_sparse_array(x): | ||||||
if use_compat is True: | ||||||
_check_api_version(api_version) | ||||||
raise ValueError("`sparse` does not have an array-api-compat wrapper") | ||||||
else: | ||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
# `sparse` is already an array namespace. We do not have a wrapper | ||||||
# submodule for it. | ||||||
namespaces.add(sparse) | ||||||
|
@@ -640,9 +635,9 @@ def your_function(x, y): | |||||
raise ValueError( | ||||||
"The given array does not have an array-api-compat wrapper" | ||||||
) | ||||||
x = cast("SupportsArrayNamespace[Any]", x) | ||||||
x = cast(SupportsArrayNamespace, x) | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
namespaces.add(x.__array_namespace__(api_version=api_version)) | ||||||
elif isinstance(x, (bool, int, float, complex, type(None))): | ||||||
elif isinstance(x, int | float | complex | NoneType): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
(I'll spare you the pseudo-philosophical rant this time) |
||||||
continue | ||||||
else: | ||||||
# TODO: Support Python scalars? | ||||||
|
@@ -738,7 +733,7 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
return "cpu" | ||||||
elif is_dask_array(x): | ||||||
# Peek at the metadata of the Dask array to determine type | ||||||
if is_numpy_array(x._meta): # pyright: ignore | ||||||
if is_numpy_array(x._meta): | ||||||
# Must be on CPU since backed by numpy | ||||||
return "cpu" | ||||||
return _DASK_DEVICE | ||||||
|
@@ -767,7 +762,7 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
return "cpu" | ||||||
# Return the device of the constituent array | ||||||
return device(inner) # pyright: ignore | ||||||
return x.device # pyright: ignore | ||||||
return x.device # type: ignore # pyright: ignore | ||||||
|
||||||
|
||||||
# Prevent shadowing, used below | ||||||
|
@@ -776,12 +771,12 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
|
||||||
# Based on cupy.array_api.Array.to_device | ||||||
def _cupy_to_device( | ||||||
x: _CupyArray, | ||||||
x: cp.ndarray, | ||||||
device: Device, | ||||||
/, | ||||||
stream: int | Any | None = None, | ||||||
) -> _CupyArray: | ||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
) -> cp.ndarray: | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
import cupy as cp | ||||||
from cupy.cuda import Device as _Device # pyright: ignore | ||||||
from cupy.cuda import stream as stream_module # pyright: ignore | ||||||
from cupy_backends.cuda.api import runtime # pyright: ignore | ||||||
|
@@ -797,10 +792,10 @@ def _cupy_to_device( | |||||
raise ValueError(f"Unsupported device {device!r}") | ||||||
else: | ||||||
# see cupy/cupy#5985 for the reason how we handle device/stream here | ||||||
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] | ||||||
prev_device: Device = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] | ||||||
prev_stream = None | ||||||
if stream is not None: | ||||||
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore | ||||||
prev_stream = stream_module.get_current_stream() # pyright: ignore | ||||||
# stream can be an int as specified in __dlpack__, or a CuPy stream | ||||||
if isinstance(stream, int): | ||||||
stream = cp.cuda.ExternalStream(stream) # pyright: ignore | ||||||
|
@@ -814,7 +809,7 @@ def _cupy_to_device( | |||||
arr = x.copy() | ||||||
finally: | ||||||
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] | ||||||
if stream is not None: | ||||||
if prev_stream is not None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. casual bugfix 🤔 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's actually logically identical to before. But it was convoluted and rightfully the type checker was complaining. |
||||||
prev_stream.use() | ||||||
return arr | ||||||
|
||||||
|
@@ -823,7 +818,7 @@ def _torch_to_device( | |||||
x: torch.Tensor, | ||||||
device: torch.device | str | int, | ||||||
/, | ||||||
stream: None = None, | ||||||
stream: int | Any | None = None, | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
) -> torch.Tensor: | ||||||
if stream is not None: | ||||||
raise NotImplementedError | ||||||
|
@@ -889,7 +884,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - | |||||
# cupy does not yet have to_device | ||||||
return _cupy_to_device(x, device, stream=stream) | ||||||
elif is_torch_array(x): | ||||||
return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] | ||||||
return _torch_to_device(x, device, stream=stream) | ||||||
elif is_dask_array(x): | ||||||
if stream is not None: | ||||||
raise ValueError("The stream argument to to_device() is not supported") | ||||||
|
@@ -914,12 +909,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - | |||||
|
||||||
|
||||||
@overload | ||||||
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... | ||||||
def size(x: HasShape[int]) -> int: ... | ||||||
@overload | ||||||
def size(x: HasShape[Collection[None]]) -> None: ... | ||||||
def size(x: HasShape[int | None]) -> int | None: ... | ||||||
@overload | ||||||
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... | ||||||
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: | ||||||
def size(x: HasShape[float]) -> int | None: ... # Dask special case | ||||||
def size(x: HasShape[float | None]) -> int | None: | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
""" | ||||||
Return the total number of elements of x. | ||||||
|
||||||
|
@@ -934,12 +929,12 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: | |||||
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape | ||||||
if None in x.shape: | ||||||
return None | ||||||
out = math.prod(cast("Collection[SupportsIndex]", x.shape)) | ||||||
out = math.prod(cast(tuple[float, ...], x.shape)) | ||||||
# dask.array.Array.shape can contain NaN | ||||||
return None if math.isnan(out) else out | ||||||
return None if math.isnan(out) else cast(int, out) | ||||||
|
||||||
|
||||||
def is_writeable_array(x: object) -> bool: | ||||||
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
""" | ||||||
Return False if ``x.__setitem__`` is expected to raise; True otherwise. | ||||||
Return False if `x` is not an array API compatible object. | ||||||
|
@@ -956,7 +951,7 @@ def is_writeable_array(x: object) -> bool: | |||||
return is_array_api_obj(x) | ||||||
|
||||||
|
||||||
def is_lazy_array(x: object) -> bool: | ||||||
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
"""Return True if x is potentially a future or it may be otherwise impossible or | ||||||
expensive to eagerly read its contents, regardless of their size, e.g. by | ||||||
calling ``bool(x)`` or ``float(x)``. | ||||||
|
@@ -997,7 +992,7 @@ def is_lazy_array(x: object) -> bool: | |||||
# on __bool__ (dask is one such example, which however is special-cased above). | ||||||
|
||||||
# Select a single point of the array | ||||||
s = size(cast("HasShape[Collection[SupportsIndex | None]]", x)) | ||||||
s = size(cast(HasShape, x)) | ||||||
if s is None: | ||||||
return True | ||||||
xp = array_namespace(x) | ||||||
|
@@ -1044,5 +1039,6 @@ def is_lazy_array(x: object) -> bool: | |||||
|
||||||
_all_ignore = ["sys", "math", "inspect", "warnings"] | ||||||
|
||||||
|
||||||
def __dir__() -> list[str]: | ||||||
return __all__ |
Uh oh!
There was an error while loading. Please reload this page.