From 362c48a7d060e880150b1a60adb69d398afae690 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Fri, 18 Apr 2025 12:23:29 +0100
Subject: [PATCH 01/12] Type annotations, part 4

---
 array_api_compat/_internal.py           |   4 +-
 array_api_compat/common/_aliases.py     |  16 +-
 array_api_compat/common/_helpers.py     |  76 +++++-----
 array_api_compat/common/_linalg.py      |  13 +-
 array_api_compat/common/_typing.py      |  68 +--------
 array_api_compat/cupy/_aliases.py       |  28 ++--
 array_api_compat/cupy/fft.py            |   9 +-
 array_api_compat/cupy/linalg.py         |   2 +-
 array_api_compat/dask/array/__init__.py |   2 +-
 array_api_compat/dask/array/_aliases.py |   2 +-
 array_api_compat/dask/array/_info.py    |  80 ++--------
 array_api_compat/dask/array/fft.py      |   2 +-
 array_api_compat/dask/array/linalg.py   |  19 +--
 array_api_compat/numpy/__init__.py      |  11 +-
 array_api_compat/numpy/_aliases.py      |  35 ++---
 array_api_compat/numpy/_info.py         |  12 +-
 array_api_compat/numpy/_typing.py       |   1 -
 array_api_compat/numpy/linalg.py        |   4 +-
 array_api_compat/torch/_aliases.py      | 190 ++++++++++++------------
 array_api_compat/torch/fft.py           |  19 +--
 array_api_compat/torch/linalg.py        |  14 +-
 pyproject.toml                          |  58 +++++---
 22 files changed, 269 insertions(+), 396 deletions(-)

diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py
index cd8d939f..b1925492 100644
--- a/array_api_compat/_internal.py
+++ b/array_api_compat/_internal.py
@@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
 specification for more details.
 
 """
-        wrapped_f.__signature__ = new_sig  # pyright: ignore[reportAttributeAccessIssue]
-        return wrapped_f  # pyright: ignore[reportReturnType]
+        wrapped_f.__signature__ = new_sig  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
+        return wrapped_f  # type: ignore[return-value] # pyright: ignore[reportReturnType]
 
     return inner
 
diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 8ea9162a..f7bfc44d 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -5,11 +5,13 @@
 from __future__ import annotations
 
 import inspect
-from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
+from collections.abc import Sequence
+from types import NoneType
+from typing import TYPE_CHECKING, Any, NamedTuple, cast
 
 from ._helpers import _check_device, array_namespace
 from ._helpers import device as _get_device
-from ._helpers import is_cupy_namespace as _is_cupy_namespace
+from ._helpers import is_cupy_namespace
 from ._typing import Array, Device, DType, Namespace
 
 if TYPE_CHECKING:
@@ -381,8 +383,8 @@ def clip(
     # TODO: np.clip has other ufunc kwargs
     out: Array | None = None,
 ) -> Array:
-    def _isscalar(a: object) -> TypeIs[int | float | None]:
-        return isinstance(a, (int, float, type(None)))
+    def _isscalar(a: object) -> TypeIs[float | None]:
+        return isinstance(a, int | float | NoneType)
 
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
@@ -450,7 +452,7 @@ def reshape(
     shape: tuple[int, ...],
     xp: Namespace,
     *,
-    copy: Optional[bool] = None,
+    copy: bool | None = None,
     **kwargs: object,
 ) -> Array:
     if copy is True:
@@ -657,7 +659,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
         out = xp.sign(x, **kwargs)
     # CuPy sign() does not propagate nans. See
     # https://github.com/data-apis/array-api-compat/issues/136
-    if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
+    if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
         out[xp.isnan(x)] = xp.nan
     return out[()]
 
@@ -720,7 +722,7 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
     "finfo",
     "iinfo",
 ]
-_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
+_all_ignore = ["is_cupy_namespace", "inspect", "array_namespace", "NamedTuple"]
 
 
 def __dir__() -> list[str]:
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index db3e4cd7..c3b3a4f1 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -12,16 +12,14 @@
 import math
 import sys
 import warnings
-from collections.abc import Collection
+from types import NoneType
 from typing import (
     TYPE_CHECKING,
     Any,
     Final,
     Literal,
-    SupportsIndex,
     TypeAlias,
     TypeGuard,
-    TypeVar,
     cast,
     overload,
 )
@@ -29,39 +27,36 @@
 from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
 
 if TYPE_CHECKING:
-
+    import cupy as cp
     import dask.array as da
     import jax
     import ndonnx as ndx
     import numpy as np
     import numpy.typing as npt
-    import sparse  # pyright: ignore[reportMissingTypeStubs]
+    import sparse
     import torch
 
     # TODO: import from typing (requires Python >=3.13)
-    from typing_extensions import TypeIs, TypeVar
-
-    _SizeT = TypeVar("_SizeT", bound = int | None)
+    from typing_extensions import TypeIs
 
     _ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
-    _CupyArray: TypeAlias = Any  # cupy has no py.typed
 
     _ArrayApiObj: TypeAlias = (
         npt.NDArray[Any]
+        | cp.ndarray
         | da.Array
         | jax.Array
         | ndx.Array
         | sparse.SparseArray
         | torch.Tensor
-        | SupportsArrayNamespace[Any]
-        | _CupyArray
+        | SupportsArrayNamespace
     )
 
 _API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
 _API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
 
 
-def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
+def _is_jax_zero_gradient_array(x: object) -> TypeIs[_ZeroGradientArray]:
     """Return True if `x` is a zero-gradient array.
 
     These arrays are a design quirk of Jax that may one day be removed.
@@ -80,7 +75,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
     )
 
 
-def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
+def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
     """
     Return True if `x` is a NumPy array.
 
@@ -137,7 +132,7 @@ def is_cupy_array(x: object) -> bool:
     if "cupy" not in sys.modules:
         return False
 
-    import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
+    import cupy as cp
 
     # TODO: Should we reject ndarray subclasses?
     return isinstance(x, cp.ndarray)  # pyright: ignore[reportUnknownMemberType]
@@ -280,13 +275,13 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
     if "sparse" not in sys.modules:
         return False
 
-    import sparse  # pyright: ignore[reportMissingTypeStubs]
+    import sparse
 
     # TODO: Account for other backends.
     return isinstance(x, sparse.SparseArray)
 
 
-def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]:  # pyright: ignore[reportUnknownParameterType]
+def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
     """
     Return True if `x` is an array API compatible array object.
 
@@ -587,7 +582,7 @@ def your_function(x, y):
 
                 namespaces.add(cupy_namespace)
             else:
-                import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
+                import cupy as cp
 
                 namespaces.add(cp)
         elif is_torch_array(x):
@@ -624,14 +619,14 @@ def your_function(x, y):
                 if hasattr(jax.numpy, "__array_api_version__"):
                     jnp = jax.numpy
                 else:
-                    import jax.experimental.array_api as jnp  # pyright: ignore[reportMissingImports]
+                    import jax.experimental.array_api as jnp  # type: ignore[no-redef]
             namespaces.add(jnp)
         elif is_pydata_sparse_array(x):
             if use_compat is True:
                 _check_api_version(api_version)
                 raise ValueError("`sparse` does not have an array-api-compat wrapper")
             else:
-                import sparse  # pyright: ignore[reportMissingTypeStubs]
+                import sparse
             # `sparse` is already an array namespace. We do not have a wrapper
             # submodule for it.
             namespaces.add(sparse)
@@ -640,9 +635,9 @@ def your_function(x, y):
                 raise ValueError(
                     "The given array does not have an array-api-compat wrapper"
                 )
-            x = cast("SupportsArrayNamespace[Any]", x)
+            x = cast(SupportsArrayNamespace, x)
             namespaces.add(x.__array_namespace__(api_version=api_version))
-        elif isinstance(x, (bool, int, float, complex, type(None))):
+        elif isinstance(x, int | float | complex | NoneType):
             continue
         else:
             # TODO: Support Python scalars?
@@ -738,7 +733,7 @@ def device(x: _ArrayApiObj, /) -> Device:
         return "cpu"
     elif is_dask_array(x):
         # Peek at the metadata of the Dask array to determine type
-        if is_numpy_array(x._meta):  # pyright: ignore
+        if is_numpy_array(x._meta):
             # Must be on CPU since backed by numpy
             return "cpu"
         return _DASK_DEVICE
@@ -767,7 +762,7 @@ def device(x: _ArrayApiObj, /) -> Device:
             return "cpu"
         # Return the device of the constituent array
         return device(inner)  # pyright: ignore
-    return x.device  # pyright: ignore
+    return x.device  # type: ignore  # pyright: ignore
 
 
 # Prevent shadowing, used below
@@ -776,12 +771,12 @@ def device(x: _ArrayApiObj, /) -> Device:
 
 # Based on cupy.array_api.Array.to_device
 def _cupy_to_device(
-    x: _CupyArray,
+    x: cp.ndarray,
     device: Device,
     /,
     stream: int | Any | None = None,
-) -> _CupyArray:
-    import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
+) -> cp.ndarray:
+    import cupy as cp
     from cupy.cuda import Device as _Device  # pyright: ignore
     from cupy.cuda import stream as stream_module  # pyright: ignore
     from cupy_backends.cuda.api import runtime  # pyright: ignore
@@ -797,10 +792,10 @@ def _cupy_to_device(
         raise ValueError(f"Unsupported device {device!r}")
     else:
         # see cupy/cupy#5985 for the reason how we handle device/stream here
-        prev_device: Any = runtime.getDevice()  # pyright: ignore[reportUnknownMemberType]
+        prev_device: Device = runtime.getDevice()  # pyright: ignore[reportUnknownMemberType]
         prev_stream = None
         if stream is not None:
-            prev_stream: Any = stream_module.get_current_stream()  # pyright: ignore
+            prev_stream = stream_module.get_current_stream()  # pyright: ignore
             # stream can be an int as specified in __dlpack__, or a CuPy stream
             if isinstance(stream, int):
                 stream = cp.cuda.ExternalStream(stream)  # pyright: ignore
@@ -814,7 +809,7 @@ def _cupy_to_device(
             arr = x.copy()
         finally:
             runtime.setDevice(prev_device)  # pyright: ignore[reportUnknownMemberType]
-            if stream is not None:
+            if prev_stream is not None:
                 prev_stream.use()
         return arr
 
@@ -823,7 +818,7 @@ def _torch_to_device(
     x: torch.Tensor,
     device: torch.device | str | int,
     /,
-    stream: None = None,
+    stream: int | Any | None = None,
 ) -> torch.Tensor:
     if stream is not None:
         raise NotImplementedError
@@ -889,7 +884,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
         # cupy does not yet have to_device
         return _cupy_to_device(x, device, stream=stream)
     elif is_torch_array(x):
-        return _torch_to_device(x, device, stream=stream)  # pyright: ignore[reportArgumentType]
+        return _torch_to_device(x, device, stream=stream)
     elif is_dask_array(x):
         if stream is not None:
             raise ValueError("The stream argument to to_device() is not supported")
@@ -914,12 +909,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
 
 
 @overload
-def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
+def size(x: HasShape[int]) -> int: ...
 @overload
-def size(x: HasShape[Collection[None]]) -> None: ...
+def size(x: HasShape[int | None]) -> int | None: ...
 @overload
-def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
-def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
+def size(x: HasShape[float]) -> int | None: ...  # Dask special case
+def size(x: HasShape[float | None]) -> int | None:
     """
     Return the total number of elements of x.
 
@@ -934,12 +929,12 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
     # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
     if None in x.shape:
         return None
-    out = math.prod(cast("Collection[SupportsIndex]", x.shape))
+    out = math.prod(cast(tuple[float, ...], x.shape))
     # dask.array.Array.shape can contain NaN
-    return None if math.isnan(out) else out
+    return None if math.isnan(out) else cast(int, out)
 
 
-def is_writeable_array(x: object) -> bool:
+def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
     """
     Return False if ``x.__setitem__`` is expected to raise; True otherwise.
     Return False if `x` is not an array API compatible object.
@@ -956,7 +951,7 @@ def is_writeable_array(x: object) -> bool:
     return is_array_api_obj(x)
 
 
-def is_lazy_array(x: object) -> bool:
+def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
     """Return True if x is potentially a future or it may be otherwise impossible or
     expensive to eagerly read its contents, regardless of their size, e.g. by
     calling ``bool(x)`` or ``float(x)``.
@@ -997,7 +992,7 @@ def is_lazy_array(x: object) -> bool:
     # on __bool__ (dask is one such example, which however is special-cased above).
 
     # Select a single point of the array
-    s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
+    s = size(cast(HasShape, x))
     if s is None:
         return True
     xp = array_namespace(x)
@@ -1044,5 +1039,6 @@ def is_lazy_array(x: object) -> bool:
 
 _all_ignore = ["sys", "math", "inspect", "warnings"]
 
+
 def __dir__() -> list[str]:
     return __all__
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index 7e002aed..cf7cf90b 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -8,7 +8,7 @@
 if np.__version__[0] == "2":
     from numpy.lib.array_utils import normalize_axis_tuple
 else:
-    from numpy.core.numeric import normalize_axis_tuple
+    from numpy.core.numeric import normalize_axis_tuple  # type: ignore[no-redef]
 
 from .._internal import get_xp
 from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
@@ -186,15 +186,17 @@ def vector_norm(
     if keepdims:
         # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
         # above to avoid matrix norm logic.
         shape = list(x.shape)
-        _axis = cast(
+        # `axis` must still be the caller's value here (not the post-reshape
+        # `_axis`), so that every reduced dimension is restored with size 1.
+        _axes = cast(
             "tuple[int, ...]",
             normalize_axis_tuple(  # pyright: ignore[reportCallIssue]
                 range(x.ndim) if axis is None else axis,
                 x.ndim,
             ),
         )
-        for i in _axis:
+        for i in _axes:
             shape[i] = 1
         res = xp.reshape(res, tuple(shape))
 
diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index d7deade1..c94f73fc 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-from collections.abc import Mapping
 from types import ModuleType as Namespace
 from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar
 
@@ -26,13 +25,13 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
     def __len__(self, /) -> int: ...
 
 
-class SupportsArrayNamespace(Protocol[_T_co]):
-    def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
+class SupportsArrayNamespace(Protocol):
+    def __array_namespace__(self, /, *, api_version: str | None) -> Namespace: ...
 
 
 class HasShape(Protocol[_T_co]):
     @property
-    def shape(self, /) -> _T_co: ...
+    def shape(self, /) -> tuple[_T_co, ...]: ...
 
 
 # Return type of `__array_namespace_info__.default_dtypes`
@@ -70,72 +69,11 @@ def shape(self, /) -> _T_co: ...
 DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
 
 
-# `__array_namespace_info__.dtypes(kind="bool")`
-class DTypesBool(TypedDict):
-    bool: DType
-
-
-# `__array_namespace_info__.dtypes(kind="signed integer")`
-class DTypesSigned(TypedDict):
-    int8: DType
-    int16: DType
-    int32: DType
-    int64: DType
-
-
-# `__array_namespace_info__.dtypes(kind="unsigned integer")`
-class DTypesUnsigned(TypedDict):
-    uint8: DType
-    uint16: DType
-    uint32: DType
-    uint64: DType
-
-
-# `__array_namespace_info__.dtypes(kind="integral")`
-class DTypesIntegral(DTypesSigned, DTypesUnsigned):
-    pass
-
-
-# `__array_namespace_info__.dtypes(kind="real floating")`
-class DTypesReal(TypedDict):
-    float32: DType
-    float64: DType
-
-
-# `__array_namespace_info__.dtypes(kind="complex floating")`
-class DTypesComplex(TypedDict):
-    complex64: DType
-    complex128: DType
-
-
-# `__array_namespace_info__.dtypes(kind="numeric")`
-class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
-    pass
-
-
-# `__array_namespace_info__.dtypes(kind=None)` (default)
-class DTypesAll(DTypesBool, DTypesNumeric):
-    pass
-
-
-# `__array_namespace_info__.dtypes(kind=?)` (fallback)
-DTypesAny: TypeAlias = Mapping[str, DType]
-
-
 __all__ = [
     "Array",
     "Capabilities",
     "DType",
     "DTypeKind",
-    "DTypesAny",
-    "DTypesAll",
-    "DTypesBool",
-    "DTypesNumeric",
-    "DTypesIntegral",
-    "DTypesSigned",
-    "DTypesUnsigned",
-    "DTypesReal",
-    "DTypesComplex",
     "DefaultDTypes",
     "Device",
     "HasShape",
diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py
index fd1460ae..da4be14b 100644
--- a/array_api_compat/cupy/_aliases.py
+++ b/array_api_compat/cupy/_aliases.py
@@ -1,9 +1,8 @@
 from __future__ import annotations
 
-from typing import Optional
+from builtins import bool as py_bool
 
 import cupy as cp
-
 from ..common import _aliases, _helpers
 from ..common._typing import NestedSequence, SupportsBufferProtocol
 from .._internal import get_xp
@@ -69,18 +68,13 @@
 
 # asarray also adds the copy keyword, which is not present in numpy 1.0.
 def asarray(
-    obj: (
-        Array 
-        | bool | int | float | complex 
-        | NestedSequence[bool | int | float | complex] 
-        | SupportsBufferProtocol
-    ),
+    obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
     /,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    copy: Optional[bool] = _copy_default,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: py_bool | None = _copy_default,
+    **kwargs: object,
 ) -> Array:
     """
     Array API compatibility wrapper for asarray().
@@ -115,8 +109,8 @@ def astype(
     dtype: DType,
     /,
     *,
-    copy: bool = True,
-    device: Optional[Device] = None,
+    copy: py_bool = True,
+    device: Device | None = None,
 ) -> Array:
     if device is None:
         return x.astype(dtype=dtype, copy=copy)
@@ -127,8 +121,8 @@ def astype(
 # cupy.count_nonzero does not have keepdims
 def count_nonzero(
     x: Array,
-    axis=None,
-    keepdims=False
+    axis: int | tuple[int, ...] | None = None,
+    keepdims: py_bool = False,
 ) -> Array:
    result = cp.count_nonzero(x, axis)
    if keepdims:
@@ -161,4 +155,4 @@ def count_nonzero(
                               'bitwise_invert', 'bitwise_right_shift',
                               'bool', 'concat', 'count_nonzero', 'pow', 'sign']
 
-_all_ignore = ['cp', 'get_xp']
+_all_ignore = ['cp', 'get_xp', 'py_bool']
diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py
index 307e0f72..2bd11940 100644
--- a/array_api_compat/cupy/fft.py
+++ b/array_api_compat/cupy/fft.py
@@ -1,10 +1,11 @@
-from cupy.fft import * # noqa: F403
+from cupy.fft import *  # noqa: F403
+
 # cupy.fft doesn't have __all__. If it is added, replace this with
 #
 # from cupy.fft import __all__ as linalg_all
-_n = {}
-exec('from cupy.fft import *', _n)
-del _n['__builtins__']
+_n: dict[str, object] = {}
+exec("from cupy.fft import *", _n)
+del _n["__builtins__"]
 fft_all = list(_n)
 del _n
 
diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py
index 7fcdd498..7bc3536e 100644
--- a/array_api_compat/cupy/linalg.py
+++ b/array_api_compat/cupy/linalg.py
@@ -2,7 +2,7 @@
 # cupy.linalg doesn't have __all__. If it is added, replace this with
 #
 # from cupy.linalg import __all__ as linalg_all
-_n = {}
+_n: dict[str, object] = {}
 exec('from cupy.linalg import *', _n)
 del _n['__builtins__']
 linalg_all = list(_n)
diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py
index 1e47b960..6d2ea7cd 100644
--- a/array_api_compat/dask/array/__init__.py
+++ b/array_api_compat/dask/array/__init__.py
@@ -3,7 +3,7 @@
 from dask.array import *  # noqa: F403
 
 # These imports may overwrite names from the import * above.
-from ._aliases import *  # noqa: F403
+from ._aliases import *  # type: ignore[assignment] # noqa: F403
 
 __array_api_version__: Final = "2024.12"
 
diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index 9687a9cd..86870e9b 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -146,7 +146,7 @@ def arange(
 
 # asarray also adds the copy keyword, which is not present in numpy 1.0.
 def asarray(
-    obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol,
+    obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
     /,
     *,
     dtype: DType | None = None,
diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py
index 9e4d736f..5e3e9018 100644
--- a/array_api_compat/dask/array/_info.py
+++ b/array_api_compat/dask/array/_info.py
@@ -12,9 +12,9 @@
 
 from __future__ import annotations
 
-from typing import Literal as L
-from typing import TypeAlias, overload
+from typing import Literal, TypeAlias
 
+import dask.array as da
 from numpy import bool_ as bool
 from numpy import (
     complex64,
@@ -33,24 +33,10 @@
     uint64,
 )
 
-from ...common._helpers import _DASK_DEVICE, _dask_device
-from ...common._typing import (
-    Capabilities,
-    DefaultDTypes,
-    DType,
-    DTypeKind,
-    DTypesAll,
-    DTypesAny,
-    DTypesBool,
-    DTypesComplex,
-    DTypesIntegral,
-    DTypesNumeric,
-    DTypesReal,
-    DTypesSigned,
-    DTypesUnsigned,
-)
+from ...common._helpers import _DASK_DEVICE, _check_device, _dask_device
+from ...common._typing import Capabilities, DefaultDTypes, DType, DTypeKind
 
-_Device: TypeAlias = L["cpu"] | _dask_device
+Device: TypeAlias = Literal["cpu"] | _dask_device
 
 
 class __array_namespace_info__:
@@ -142,7 +128,7 @@ def capabilities(self) -> Capabilities:
             "max dimensions": 64,
         }
 
-    def default_device(self) -> L["cpu"]:
+    def default_device(self) -> Device:
         """
         The default device used for new Dask arrays.
 
@@ -169,7 +155,7 @@ def default_device(self) -> L["cpu"]:
         """
         return "cpu"
 
-    def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes:
+    def default_dtypes(self, /, *, device: Device | None = None) -> DefaultDTypes:
         """
         The default data types used for new Dask arrays.
 
@@ -208,11 +194,7 @@ def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes:
          'indexing': dask.int64}
 
         """
-        if device not in ["cpu", _DASK_DEVICE, None]:
-            raise ValueError(
-                f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, '
-                f"but received: {device!r}"
-            )
+        _check_device(da, device)
         return {
             "real floating": dtype(float64),
             "complex floating": dtype(complex128),
@@ -220,41 +202,9 @@ def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes:
             "indexing": dtype(intp),
         }
 
-    @overload
-    def dtypes(
-        self, /, *, device: _Device | None = None, kind: None = None
-    ) -> DTypesAll: ...
-    @overload
-    def dtypes(
-        self, /, *, device: _Device | None = None, kind: L["bool"]
-    ) -> DTypesBool: ...
-    @overload
-    def dtypes(
-        self, /, *, device: _Device | None = None, kind: L["signed integer"]
-    ) -> DTypesSigned: ...
-    @overload
-    def dtypes(
-        self, /, *, device: _Device | None = None, kind: L["unsigned integer"]
-    ) -> DTypesUnsigned: ...
-    @overload
-    def dtypes(
-        self, /, *, device: _Device | None = None, kind: L["integral"]
-    ) -> DTypesIntegral: ...
-    @overload
-    def dtypes(
-        self, /, *, device: _Device | None = None, kind: L["real floating"]
-    ) -> DTypesReal: ...
-    @overload
-    def dtypes(
-        self, /, *, device: _Device | None = None, kind: L["complex floating"]
-    ) -> DTypesComplex: ...
-    @overload
-    def dtypes(
-        self, /, *, device: _Device | None = None, kind: L["numeric"]
-    ) -> DTypesNumeric: ...
     def dtypes(
-        self, /, *, device: _Device | None = None, kind: DTypeKind | None = None
-    ) -> DTypesAny:
+        self, /, *, device: Device | None = None, kind: DTypeKind | None = None
+    ) -> dict[str, DType]:
         """
         The array API data types supported by Dask.
 
@@ -308,11 +258,7 @@ def dtypes(
          'int64': dask.int64}
 
         """
-        if device not in ["cpu", _DASK_DEVICE, None]:
-            raise ValueError(
-                'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
-                f" {device}"
-            )
+        _check_device(da, device)
         if kind is None:
             return {
                 "bool": dtype(bool),
@@ -381,14 +327,14 @@ def dtypes(
                 "complex64": dtype(complex64),
                 "complex128": dtype(complex128),
             }
-        if isinstance(kind, tuple):  # type: ignore[reportUnnecessaryIsinstanceCall]
+        if isinstance(kind, tuple):
             res: dict[str, DType] = {}
             for k in kind:
                 res.update(self.dtypes(kind=k))
             return res
         raise ValueError(f"unsupported kind: {kind!r}")
 
-    def devices(self) -> list[_Device]:
+    def devices(self) -> list[Device]:
         """
         The devices supported by Dask.
 
diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py
index 3f40dffe..68c4280e 100644
--- a/array_api_compat/dask/array/fft.py
+++ b/array_api_compat/dask/array/fft.py
@@ -2,7 +2,7 @@
 # dask.array.fft doesn't have __all__. If it is added, replace this with
 #
 # from dask.array.fft import __all__ as linalg_all
-_n = {}
+_n: dict[str, object] = {}
 exec('from dask.array.fft import *', _n)
 for k in ("__builtins__", "Sequence", "annotations", "warnings"):
     _n.pop(k, None)
diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py
index 0825386e..06f596bc 100644
--- a/array_api_compat/dask/array/linalg.py
+++ b/array_api_compat/dask/array/linalg.py
@@ -4,21 +4,22 @@
 
 import dask.array as da
 
-# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
-from dask.array import matmul, outer, tensordot
-
 # Exports
 from dask.array.linalg import *  # noqa: F403
+from dask.array import outer
+# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
+from dask.array import matmul, tensordot
+
 
 from ..._internal import get_xp
 from ...common import _linalg
-from ...common._typing import Array as _Array
+from ...common._typing import Array
 from ._aliases import matrix_transpose, vecdot
 
 # dask.array.linalg doesn't have __all__. If it is added, replace this with
 #
 # from dask.array.linalg import __all__ as linalg_all
-_n = {}
+_n: dict[str, object] = {}
 exec('from dask.array.linalg import *', _n)
 for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'):
     _n.pop(k, None)
@@ -33,8 +34,8 @@
 # supports the mode keyword on QR
 # https://github.com/dask/dask/issues/10388
 #qr = get_xp(da)(_linalg.qr)
-def qr(
-    x: _Array,
+def qr(  # type: ignore[no-redef]
+    x: Array,
     mode: Literal["reduced", "complete"] = "reduced",
     **kwargs: object,
 ) -> QRResult:
@@ -50,12 +51,12 @@ def qr(
 # Wrap the svd functions to not pass full_matrices to dask
 # when full_matrices=False (as that is the default behavior for dask),
 # and dask doesn't have the full_matrices keyword
-def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult:
+def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult:  # type: ignore[no-redef]
     if full_matrices:
         raise ValueError("full_matrics=True is not supported by dask.")
     return da.linalg.svd(x, coerce_signs=False, **kwargs)
 
-def svdvals(x: _Array) -> _Array:
+def svdvals(x: Array) -> Array:
     # TODO: can't avoid computing U or V for dask
     _, s, _ =  svd(x)
     return s
diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py
index f7b558ba..ae9406a6 100644
--- a/array_api_compat/numpy/__init__.py
+++ b/array_api_compat/numpy/__init__.py
@@ -10,7 +10,7 @@
 from numpy import round as round
 
 # These imports may overwrite names from the import * above.
-from ._aliases import *  # noqa: F403
+from ._aliases import *  # type: ignore[assignment,no-redef] # noqa: F403
 
 # Don't know why, but we have to do an absolute import to import linalg. If we
 # instead do
@@ -23,13 +23,6 @@
 
 __import__(__package__ + ".fft")
 
-from ..common._helpers import *  # noqa: F403
-from .linalg import matrix_transpose, vecdot  # noqa: F401
-
-try:
-    # Used in asarray(). Not present in older versions.
-    from numpy import _CopyMode  # noqa: F401
-except ImportError:
-    pass
+from .linalg import matrix_transpose, vecdot  # type: ignore[no-redef] # noqa: F401
 
 __array_api_version__: Final = "2024.12"
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index d8792611..918f501f 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from builtins import bool as py_bool
-from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast
+from typing import Any, cast
 
 import numpy as np
 
@@ -12,13 +12,6 @@
 from ._info import __array_namespace_info__
 from ._typing import Array, Device, DType
 
-if TYPE_CHECKING:
-    from typing_extensions import Buffer, TypeIs
-
-# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`:
-# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10
-_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
-
 bool = np.bool_
 
 # Basic renames
@@ -74,14 +67,6 @@
 iinfo = get_xp(np)(_aliases.iinfo)
 
 
-def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]:  # pyright: ignore[reportUnusedFunction]
-    try:
-        memoryview(obj)  # pyright: ignore[reportArgumentType]
-    except TypeError:
-        return False
-    return True
-
-
 # asarray also adds the copy keyword, which is not present in numpy 1.0.
 # asarray() is different enough between numpy, cupy, and dask, the logic
 # complicated enough that it's easier to define it separately for each module
@@ -92,7 +77,7 @@ def asarray(
     *,
     dtype: DType | None = None,
     device: Device | None = None,
-    copy: _Copy | None = None,
+    copy: py_bool | None = None,
     **kwargs: Any,
 ) -> Array:
     """
@@ -104,13 +89,13 @@ def asarray(
     _helpers._check_device(np, device)
 
     if copy is None:
-        copy = np._CopyMode.IF_NEEDED
-    elif copy is False:
-        copy = np._CopyMode.NEVER
-    elif copy is True:
-        copy = np._CopyMode.ALWAYS
+        np1_copy = np._CopyMode.IF_NEEDED  # type: ignore[attr-defined]
+    elif copy:
+        np1_copy = np._CopyMode.ALWAYS  # type: ignore[attr-defined]
+    else:
+        np1_copy = np._CopyMode.NEVER  # type: ignore[attr-defined]
 
-    return np.array(obj, copy=copy, dtype=dtype, **kwargs)  # pyright: ignore
+    return np.array(obj, copy=np1_copy, dtype=dtype, **kwargs)
 
 
 def astype(
@@ -134,7 +119,7 @@ def count_nonzero(
 ) -> Array:
     # NOTE: this is currently incorrectly typed in numpy, but will be fixed in
     # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
-    result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims))  # pyright: ignore[reportArgumentType, reportCallIssue]
+    result = cast(Any, np.count_nonzero(x, axis=axis, keepdims=keepdims))  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType, reportCallIssue]
     if axis is None and not keepdims:
         return np.asarray(result)
     return result
@@ -145,7 +130,7 @@ def count_nonzero(
 if hasattr(np, "vecdot"):
     vecdot = np.vecdot
 else:
-    vecdot = get_xp(np)(_aliases.vecdot)
+    vecdot = get_xp(np)(_aliases.vecdot)  # type: ignore[assignment]
 
 if hasattr(np, "isdtype"):
     isdtype = np.isdtype
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index f307f62c..11126e5d 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -7,6 +7,7 @@
 more details.
 
 """
+
 from __future__ import annotations
 
 from numpy import bool_ as bool
@@ -27,6 +28,7 @@
     uint64,
 )
 
+from ..common._typing import DefaultDTypes
 from ._typing import Device, DType
 
 
@@ -62,7 +64,7 @@ class __array_namespace_info__:
 
     """
 
-    __module__ = 'numpy'
+    __module__ = "numpy"
 
     def capabilities(self):
         """
@@ -139,7 +141,7 @@ def default_dtypes(
         self,
         *,
         device: Device | None = None,
-    ) -> dict[str, dtype[intp | float64 | complex128]]:
+    ) -> DefaultDTypes:
         """
         The default data types used for new NumPy arrays.
 
@@ -181,8 +183,7 @@ def default_dtypes(
         """
         if device not in ["cpu", None]:
             raise ValueError(
-                'Device not understood. Only "cpu" is allowed, but received:'
-                f' {device}'
+                f'Device not understood. Only "cpu" is allowed, but received: {device}'
             )
         return {
             "real floating": dtype(float64),
@@ -253,8 +254,7 @@ def dtypes(
         """
         if device not in ["cpu", None]:
             raise ValueError(
-                'Device not understood. Only "cpu" is allowed, but received:'
-                f' {device}'
+                f'Device not understood. Only "cpu" is allowed, but received: {device}'
             )
         if kind is None:
             return {
diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py
index e771c788..617cfb71 100644
--- a/array_api_compat/numpy/_typing.py
+++ b/array_api_compat/numpy/_typing.py
@@ -7,7 +7,6 @@
 Device: TypeAlias = Literal["cpu"]
 
 if TYPE_CHECKING:
-
     # NumPy 1.x on Python 3.10 fails to parse np.dtype[]
     DType: TypeAlias = np.dtype[
         np.bool_
diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py
index 2d3e731d..9a618be9 100644
--- a/array_api_compat/numpy/linalg.py
+++ b/array_api_compat/numpy/linalg.py
@@ -65,7 +65,7 @@
 # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
 def solve(x1: Array, x2: Array, /) -> Array:
     try:
-        from numpy.linalg._linalg import (
+        from numpy.linalg._linalg import (  # type: ignore[attr-defined]
             _assert_stacked_2d,
             _assert_stacked_square,
             _commonType,
@@ -74,7 +74,7 @@ def solve(x1: Array, x2: Array, /) -> Array:
             isComplexType,
         )
     except ImportError:
-        from numpy.linalg.linalg import (
+        from numpy.linalg.linalg import (  # type: ignore[attr-defined]
             _assert_stacked_2d,
             _assert_stacked_square,
             _commonType,
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 027a0261..5a7d1870 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
+from collections.abc import Sequence
 from functools import reduce as _reduce, wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any
 
 import torch
 
@@ -96,9 +97,7 @@ def _fix_promotion(x1, x2, only_scalar=True):
 _py_scalars = (bool, int, float, complex)
 
 
-def result_type(
-    *arrays_and_dtypes: Array | DType | bool | int | float | complex
-) -> DType:
+def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType:
     num = len(arrays_and_dtypes)
 
     if num == 0:
@@ -129,10 +128,7 @@ def result_type(
         return _reduce(_result_type, others + scalars)
 
 
-def _result_type(
-    x: Array | DType | bool | int | float | complex,
-    y: Array | DType | bool | int | float | complex,
-) -> DType:
+def _result_type(x: Array | DType | complex, y: Array | DType | complex) -> DType:
     if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
         xdt = x if isinstance(x, torch.dtype) else x.dtype
         ydt = y if isinstance(y, torch.dtype) else y.dtype
@@ -150,7 +146,7 @@ def _result_type(
     return torch.result_type(x, y)
 
 
-def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
+def can_cast(from_: DType | Array, to: DType, /) -> bool:
     if not isinstance(from_, torch.dtype):
         from_ = from_.dtype
     return torch.can_cast(from_, to)
@@ -194,12 +190,7 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
 
 
 def asarray(
-    obj: (
-    Array 
-        | bool | int | float | complex 
-        | NestedSequence[bool | int | float | complex] 
-        | SupportsBufferProtocol
-    ),
+    obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
     /,
     *,
     dtype: DType | None = None,
@@ -218,13 +209,13 @@ def asarray(
 # of 'axis'.
 
 # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745
-def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
+def max(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
     # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return torch.clone(x)
     return torch.amax(x, axis, keepdims=keepdims)
 
-def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
+def min(x: Array, /, *, axis: int | tuple[int, ...] |None = None, keepdims: bool = False) -> Array:
     # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return torch.clone(x)
@@ -240,7 +231,15 @@ def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
 
 # torch.sort also returns a tuple
 # https://github.com/pytorch/pytorch/issues/70921
-def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array:
+def sort(
+    x: Array,
+    /,
+    *,
+    axis: int = -1,
+    descending: bool = False,
+    stable: bool = True,
+    **kwargs: object,
+) -> Array:
     return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values
 
 def _normalize_axes(axis, ndim):
@@ -307,10 +306,10 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
 def prod(x: Array,
          /,
          *,
-         axis: Optional[Union[int, Tuple[int, ...]]] = None,
-         dtype: Optional[DType] = None,
+         axis: int | tuple[int, ...] | None = None,
+         dtype: DType | None = None,
          keepdims: bool = False,
-         **kwargs) -> Array:
+         **kwargs: object) -> Array:
 
     if axis == ():
         return _sum_prod_no_axis(x, dtype)
@@ -331,10 +330,10 @@ def prod(x: Array,
 def sum(x: Array,
          /,
          *,
-         axis: Optional[Union[int, Tuple[int, ...]]] = None,
-         dtype: Optional[DType] = None,
+         axis: int | tuple[int, ...] | None = None,
+         dtype: DType | None = None,
          keepdims: bool = False,
-         **kwargs) -> Array:
+         **kwargs: object) -> Array:
 
     if axis == ():
         return _sum_prod_no_axis(x, dtype)
@@ -350,9 +349,9 @@ def sum(x: Array,
 def any(x: Array,
         /,
         *,
-        axis: Optional[Union[int, Tuple[int, ...]]] = None,
+        axis: int | tuple[int, ...] | None = None,
         keepdims: bool = False,
-        **kwargs) -> Array:
+        **kwargs: object) -> Array:
 
     if axis == ():
         return x.to(torch.bool)
@@ -374,9 +373,9 @@ def any(x: Array,
 def all(x: Array,
         /,
         *,
-        axis: Optional[Union[int, Tuple[int, ...]]] = None,
+        axis: int | tuple[int, ...] | None = None,
         keepdims: bool = False,
-        **kwargs) -> Array:
+        **kwargs: object) -> Array:
 
     if axis == ():
         return x.to(torch.bool)
@@ -398,9 +397,9 @@ def all(x: Array,
 def mean(x: Array,
          /,
          *,
-         axis: Optional[Union[int, Tuple[int, ...]]] = None,
+         axis: int | tuple[int, ...] | None = None,
          keepdims: bool = False,
-         **kwargs) -> Array:
+         **kwargs: object) -> Array:
     # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return torch.clone(x)
@@ -415,10 +414,10 @@ def mean(x: Array,
 def std(x: Array,
         /,
         *,
-        axis: Optional[Union[int, Tuple[int, ...]]] = None,
-        correction: Union[int, float] = 0.0,
+        axis: int | tuple[int, ...] | None = None,
+        correction: float = 0.0,
         keepdims: bool = False,
-        **kwargs) -> Array:
+        **kwargs: object) -> Array:
     # Note, float correction is not supported
     # https://github.com/pytorch/pytorch/issues/61492. We don't try to
     # implement it here for now.
@@ -446,10 +445,10 @@ def std(x: Array,
 def var(x: Array,
         /,
         *,
-        axis: Optional[Union[int, Tuple[int, ...]]] = None,
-        correction: Union[int, float] = 0.0,
+        axis: int | tuple[int, ...] | None = None,
+        correction: float = 0.0,
         keepdims: bool = False,
-        **kwargs) -> Array:
+        **kwargs: object) -> Array:
     # Note, float correction is not supported
     # https://github.com/pytorch/pytorch/issues/61492. We don't try to
     # implement it here for now.
@@ -472,11 +471,11 @@ def var(x: Array,
 
 # torch.concat doesn't support dim=None
 # https://github.com/pytorch/pytorch/issues/70925
-def concat(arrays: Union[Tuple[Array, ...], List[Array]],
+def concat(arrays: tuple[Array, ...] | list[Array],
            /,
            *,
-           axis: Optional[int] = 0,
-           **kwargs) -> Array:
+           axis: int | None = 0,
+           **kwargs: object) -> Array:
     if axis is None:
         arrays = tuple(ar.flatten() for ar in arrays)
         axis = 0
@@ -485,7 +484,7 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]],
 # torch.squeeze only accepts int dim and doesn't require it
 # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
 # added at https://github.com/pytorch/pytorch/pull/89017.
-def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
+def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array:
     if isinstance(axis, int):
         axis = (axis,)
     for a in axis:
@@ -499,27 +498,27 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
     return x
 
 # torch.broadcast_to uses size instead of shape
-def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array:
+def broadcast_to(x: Array, /, shape: tuple[int, ...], **kwargs: object) -> Array:
     return torch.broadcast_to(x, shape, **kwargs)
 
 # torch.permute uses dims instead of axes
-def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
+def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
     return torch.permute(x, axes)
 
 # The axis parameter doesn't work for flip() and roll()
 # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
 # accept axis=None
-def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array:
+def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array:
     if axis is None:
         axis = tuple(range(x.ndim))
     # torch.flip doesn't accept dim as an int but the method does
     # https://github.com/pytorch/pytorch/issues/18095
     return x.flip(axis, **kwargs)
 
-def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array:
+def roll(x: Array, /, shift: int | tuple[int, ...], *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array:
     return torch.roll(x, shift, axis, **kwargs)
 
-def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]:
+def nonzero(x: Array, /, **kwargs: object) -> tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return torch.nonzero(x, as_tuple=True, **kwargs)
@@ -532,8 +531,8 @@ def diff(
     *,
     axis: int = -1,
     n: int = 1,
-    prepend: Optional[Array] = None,
-    append: Optional[Array] = None,
+    prepend: Array | None = None,
+    append: Array | None = None,
 ) -> Array:
     return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
 
@@ -543,7 +542,7 @@ def count_nonzero(
     x: Array,
     /,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    axis: int | tuple[int, ...] | None = None,
     keepdims: bool = False,
 ) -> Array:
     result = torch.count_nonzero(x, dim=axis)
@@ -560,12 +559,7 @@ def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Arr
     return torch.repeat_interleave(x, repeats, axis)
 
 
-def where(
-    condition: Array, 
-    x1: Array | bool | int | float | complex, 
-    x2: Array | bool | int | float | complex,
-    /,
-) -> Array:
+def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array:
     x1, x2 = _fix_promotion(x1, x2)
     return torch.where(condition, x1, x2)
 
@@ -573,10 +567,10 @@ def where(
 # torch.reshape doesn't have the copy keyword
 def reshape(x: Array,
             /,
-            shape: Tuple[int, ...],
+            shape: tuple[int, ...],
             *,
-            copy: Optional[bool] = None,
-            **kwargs) -> Array:
+            copy: bool | None = None,
+            **kwargs: object) -> Array:
     if copy is not None:
         raise NotImplementedError("torch.reshape doesn't yet support the copy keyword")
     return torch.reshape(x, shape, **kwargs)
@@ -585,14 +579,14 @@ def reshape(x: Array,
 # (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
 # keyword argument combinations
 # (https://github.com/pytorch/pytorch/issues/70914)
-def arange(start: Union[int, float],
+def arange(start: float,
            /,
-           stop: Optional[Union[int, float]] = None,
-           step: Union[int, float] = 1,
+           stop: float | None = None,
+           step: float = 1,
            *,
-           dtype: Optional[DType] = None,
-           device: Optional[Device] = None,
-           **kwargs) -> Array:
+           dtype: DType | None = None,
+           device: Device | None = None,
+           **kwargs: object) -> Array:
     if stop is None:
         start, stop = 0, start
     if step > 0 and stop <= start or step < 0 and stop >= start:
@@ -607,13 +601,13 @@ def arange(start: Union[int, float],
 # torch.eye does not accept None as a default for the second argument and
 # doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
 def eye(n_rows: int,
-        n_cols: Optional[int] = None,
+        n_cols: int | None = None,
         /,
         *,
         k: int = 0,
-        dtype: Optional[DType] = None,
-        device: Optional[Device] = None,
-        **kwargs) -> Array:
+        dtype: DType | None = None,
+        device: Device | None = None,
+        **kwargs: object) -> Array:
     if n_cols is None:
         n_cols = n_rows
     z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs)
@@ -622,52 +616,52 @@ def eye(n_rows: int,
     return z
 
 # torch.linspace doesn't have the endpoint parameter
-def linspace(start: Union[int, float],
-             stop: Union[int, float],
+def linspace(start: float,
+             stop: float,
              /,
              num: int,
              *,
-             dtype: Optional[DType] = None,
-             device: Optional[Device] = None,
+             dtype: DType | None = None,
+             device: Device | None = None,
              endpoint: bool = True,
-             **kwargs) -> Array:
+             **kwargs: object) -> Array:
     if not endpoint:
         return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1]
     return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs)
 
 # torch.full does not accept an int size
 # https://github.com/pytorch/pytorch/issues/70906
-def full(shape: Union[int, Tuple[int, ...]],
-         fill_value: bool | int | float | complex,
+def full(shape: int | tuple[int, ...],
+         fill_value: complex,
          *,
-         dtype: Optional[DType] = None,
-         device: Optional[Device] = None,
-         **kwargs) -> Array:
+         dtype: DType | None = None,
+         device: Device | None = None,
+         **kwargs: object) -> Array:
     if isinstance(shape, int):
         shape = (shape,)
 
     return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs)
 
 # ones, zeros, and empty do not accept shape as a keyword argument
-def ones(shape: Union[int, Tuple[int, ...]],
+def ones(shape: int | tuple[int, ...],
          *,
-         dtype: Optional[DType] = None,
-         device: Optional[Device] = None,
-         **kwargs) -> Array:
+         dtype: DType | None = None,
+         device: Device | None = None,
+         **kwargs: object) -> Array:
     return torch.ones(shape, dtype=dtype, device=device, **kwargs)
 
-def zeros(shape: Union[int, Tuple[int, ...]],
+def zeros(shape: int | tuple[int, ...],
          *,
-         dtype: Optional[DType] = None,
-         device: Optional[Device] = None,
-         **kwargs) -> Array:
+         dtype: DType | None = None,
+         device: Device | None = None,
+         **kwargs: object) -> Array:
     return torch.zeros(shape, dtype=dtype, device=device, **kwargs)
 
-def empty(shape: Union[int, Tuple[int, ...]],
+def empty(shape: int | tuple[int, ...],
          *,
-         dtype: Optional[DType] = None,
-         device: Optional[Device] = None,
-         **kwargs) -> Array:
+         dtype: DType | None = None,
+         device: Device | None = None,
+         **kwargs: object) -> Array:
     return torch.empty(shape, dtype=dtype, device=device, **kwargs)
 
 # tril and triu do not call the keyword argument k
@@ -689,14 +683,14 @@ def astype(
     /,
     *,
     copy: bool = True,
-    device: Optional[Device] = None,
+    device: Device | None = None,
 ) -> Array:
     if device is not None:
         return x.to(device, dtype=dtype, copy=copy)
     return x.to(dtype=dtype, copy=copy)
 
 
-def broadcast_arrays(*arrays: Array) -> List[Array]:
+def broadcast_arrays(*arrays: Array) -> list[Array]:
     shape = torch.broadcast_shapes(*[a.shape for a in arrays])
     return [torch.broadcast_to(a, shape) for a in arrays]
 
@@ -734,7 +728,7 @@ def unique_inverse(x: Array) -> UniqueInverseResult:
 def unique_values(x: Array) -> Array:
     return torch.unique(x)
 
-def matmul(x1: Array, x2: Array, /, **kwargs) -> Array:
+def matmul(x1: Array, x2: Array, /, **kwargs: object) -> Array:
     # torch.matmul doesn't type promote (but differently from _fix_promotion)
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     return torch.matmul(x1, x2, **kwargs)
@@ -752,8 +746,8 @@ def tensordot(
     x2: Array,
     /,
     *, 
-    axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, 
-    **kwargs,
+    axes: int | tuple[Sequence[int], Sequence[int]] = 2, 
+    **kwargs: object,
 ) -> Array:
     # Note: torch.tensordot fails with integer dtypes when there is only 1
     # element in the axis (https://github.com/pytorch/pytorch/issues/84530).
@@ -762,8 +756,10 @@ def tensordot(
 
 
 def isdtype(
-    dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]],
-    *, _tuple=True, # Disallow nested tuples
+    dtype: DType, 
+    kind: DType | str | tuple[DType | str, ...],
+    *,
+    _tuple: bool = True, # Disallow nested tuples
 ) -> bool:
     """
     Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
@@ -797,7 +793,7 @@ def isdtype(
     else:
         return dtype == kind
 
-def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array:
+def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: object) -> Array:
     if axis is None:
         if x.ndim != 1:
             raise ValueError("axis must be specified when ndim > 1")
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
index 50e6a0d0..ddf87c65 100644
--- a/array_api_compat/torch/fft.py
+++ b/array_api_compat/torch/fft.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
-from typing import Union, Sequence, Literal
+from collections.abc import Sequence
+from typing import Literal
 
 import torch
 import torch.fft
@@ -17,7 +18,7 @@ def fftn(
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
@@ -28,7 +29,7 @@ def ifftn(
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
@@ -39,7 +40,7 @@ def rfftn(
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
@@ -50,7 +51,7 @@ def irfftn(
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
@@ -58,8 +59,8 @@ def fftshift(
     x: Array,
     /,
     *,
-    axes: Union[int, Sequence[int]] = None,
-    **kwargs,
+    axes: int | Sequence[int] = None,
+    **kwargs: object,
 ) -> Array:
     return torch.fft.fftshift(x, dim=axes, **kwargs)
 
@@ -67,8 +68,8 @@ def ifftshift(
     x: Array,
     /,
     *,
-    axes: Union[int, Sequence[int]] = None,
-    **kwargs,
+    axes: int | Sequence[int] = None,
+    **kwargs: object,
 ) -> Array:
     return torch.fft.ifftshift(x, dim=axes, **kwargs)
 
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index 1ff7319d..490b7bd1 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -1,8 +1,6 @@
 from __future__ import annotations
 
 import torch
-from typing import Optional, Union, Tuple
-
 from torch.linalg import * # noqa: F403
 
 # torch.linalg doesn't define __all__
@@ -31,7 +29,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
     x1, x2 = torch.broadcast_tensors(x1, x2)
     return torch_linalg.cross(x1, x2, dim=axis)
 
-def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array:
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array:
     from ._aliases import isdtype
 
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
@@ -53,7 +51,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array:
         return res[..., 0, 0]
     return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
 
-def solve(x1: Array, x2: Array, /, **kwargs) -> Array:
+def solve(x1: Array, x2: Array, /, **kwargs: object) -> Array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
     # whenever
@@ -74,7 +72,7 @@ def solve(x1: Array, x2: Array, /, **kwargs) -> Array:
     return torch.linalg.solve(x1, x2, **kwargs)
 
 # torch.trace doesn't support the offset argument and doesn't support stacking
-def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array:
+def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array:
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
 
@@ -82,11 +80,11 @@ def vector_norm(
     x: Array,
     /,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    axis: int | tuple[int, ...] | None = None,
     keepdims: bool = False,
     # float stands for inf | -inf, which are not valid for Literal
-    ord: Union[int, float] = 2,
-    **kwargs,
+    ord: float = 2,
+    **kwargs: object,
 ) -> Array:
     # torch.vector_norm incorrectly treats axis=() the same as axis=None
     if axis == ():
diff --git a/pyproject.toml b/pyproject.toml
index aacebd11..86310358 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,11 +43,11 @@ dev = [
     "array-api-strict",
     "dask[array]>=2024.9.0",
     "jax[cpu]",
+    "ndonnx",
     "numpy>=1.22",
     "pytest",
     "torch",
     "sparse>=0.15.1",
-    "ndonnx"
 ]
 
 [project.urls]
@@ -61,7 +61,7 @@ version = {attr = "array_api_compat.__version__"}
 include = ["array_api_compat*"]
 namespaces = false
 
-[toolint]
+[tool.ruff.lint]
 preview = true
 select = [
 # Defaults
@@ -79,20 +79,44 @@ ignore = [
   "E722"
 ]
 
-[tool.ruff.lint]
-preview = true
-select = [
-# Defaults
-"E4", "E7", "E9", "F",
-# Undefined export
-"F822",
-# Useless import alias
-"PLC0414"
-]
 
-ignore = [
-  # Module import not at top of file
-  "E402",
-  # Do not use bare `except`
-  "E722"
+[tool.mypy]
+files = ["array_api_compat"]
+python_version = "3.10"
+disallow_incomplete_defs = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = false  # TODO
+ignore_missing_imports = true
+no_implicit_optional = true
+show_error_codes = true
+warn_redundant_casts = true
+warn_unused_ignores = true
+warn_unreachable = true
+
+[[tool.mypy.overrides]]
+module = ["cupy.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"]
+ignore_missing_imports = true
+
+
+[tool.pyright]
+include = ["src", "tests"]
+pythonVersion = "3.10"
+pythonPlatform = "All"
+
+reportAny = false
+reportExplicitAny = false
+# missing type stubs
+reportAttributeAccessIssue = false
+reportUnknownMemberType = false
+reportUnknownVariableType = false
+# Redundant with mypy checks
+reportMissingImports = false
+reportMissingTypeStubs = false
+# false positives for input validation
+reportUnreachable = false
+# ruff handles this
+reportUnusedParameter = false
+
+executionEnvironments = [
+  { root = "array_api_compat" },
 ]

From ad375dc4a63c2a2914484ccc75ab235a9f274d3c Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Fri, 18 Apr 2025 15:07:44 +0100
Subject: [PATCH 02/12] Fix CopyMode

---
 array_api_compat/numpy/_aliases.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 918f501f..0c75d47d 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -88,14 +88,14 @@ def asarray(
     """
     _helpers._check_device(np, device)
 
+    # None is unsupported in NumPy 1.0, but we can use an internal enum
+    # False in NumPy 1.0 means None in NumPy 2.0 and in the Array API
     if copy is None:
-        np1_copy = np._CopyMode.IF_NEEDED  # type: ignore[attr-defined]
-    elif copy:
-        np1_copy = np._CopyMode.ALWAYS  # type: ignore[attr-defined]
-    else:
-        np1_copy = np._CopyMode.NEVER  # type: ignore[attr-defined]
+        copy = np._CopyMode.IF_NEEDED  # type: ignore[assignment,attr-defined]
+    elif copy is False:
+        copy = np._CopyMode.NEVER  # type: ignore[assignment,attr-defined]
 
-    return np.array(obj, copy=np1_copy, dtype=dtype, **kwargs)
+    return np.array(obj, copy=copy, dtype=dtype, **kwargs)
 
 
 def astype(

From 49f9ba72709a4b8ab6886ecd047f8ef1545d7871 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Fri, 18 Apr 2025 16:07:39 +0100
Subject: [PATCH 03/12] revert

---
 array_api_compat/common/_linalg.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index cf7cf90b..f483af41 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -164,7 +164,7 @@ def vector_norm(
     if axis is None:
         # Note: xp.linalg.norm() doesn't handle 0-D arrays
         _x = x.ravel()
-        axis = 0
+        _axis = 0
     elif isinstance(axis, tuple):
         # Note: The axis argument supports any number of axes, whereas
         # xp.linalg.norm() only supports a single axis for vector norm.
@@ -176,24 +176,25 @@ def vector_norm(
         newshape = axis + rest
         _x = xp.transpose(x, newshape).reshape(
             (math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
-        axis = 0
+        _axis = 0
     else:
         _x = x
+        _axis = axis
 
-    res = xp.linalg.norm(_x, axis=axis, ord=ord)
+    res = xp.linalg.norm(_x, axis=_axis, ord=ord)
 
     if keepdims:
         # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
         # above to avoid matrix norm logic.
         shape = list(x.shape)
-        axis = cast(
+        _axis = cast(
             "tuple[int, ...]",
             normalize_axis_tuple(  # pyright: ignore[reportCallIssue]
                 range(x.ndim) if axis is None else axis,
                 x.ndim,
             ),
         )
-        for i in axis:
+        for i in _axis:
             shape[i] = 1
         res = xp.reshape(res, tuple(shape))
 

From c724a5220d8d8dfaec9fe93b8defef095bfaa1ed Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Mon, 21 Apr 2025 10:06:14 +0100
Subject: [PATCH 04/12] Revert `_all_ignore`

---
 array_api_compat/common/_aliases.py | 2 +-
 array_api_compat/common/_helpers.py | 1 -
 array_api_compat/cupy/_aliases.py   | 2 +-
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index f7bfc44d..3b2f74ee 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -722,7 +722,7 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
     "finfo",
     "iinfo",
 ]
-_all_ignore = ["is_cupy_namespace", "inspect", "array_namespace", "NamedTuple"]
+_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
 
 
 def __dir__() -> list[str]:
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index c3b3a4f1..c85eebc8 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -1039,6 +1039,5 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
 
 _all_ignore = ["sys", "math", "inspect", "warnings"]
 
-
 def __dir__() -> list[str]:
     return __all__
diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py
index da4be14b..f74827ff 100644
--- a/array_api_compat/cupy/_aliases.py
+++ b/array_api_compat/cupy/_aliases.py
@@ -155,4 +155,4 @@ def count_nonzero(
                               'bitwise_invert', 'bitwise_right_shift',
                               'bool', 'concat', 'count_nonzero', 'pow', 'sign']
 
-_all_ignore = ['cp', 'get_xp', 'py_bool']
+_all_ignore = ['cp', 'get_xp']

From 14f70afcb3bfde9fe50520c3ab92ca8b22a7397b Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Mon, 21 Apr 2025 11:46:46 +0100
Subject: [PATCH 05/12] code review

---
 array_api_compat/common/_helpers.py | 2 +-
 pyproject.toml                      | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index c85eebc8..cec985f6 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -56,7 +56,7 @@
 _API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
 
 
-def _is_jax_zero_gradient_array(x: object) -> TypeIs[_ZeroGradientArray]:
+def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
     """Return True if `x` is a zero-gradient array.
 
     These arrays are a design quirk of Jax that may one day be removed.
diff --git a/pyproject.toml b/pyproject.toml
index 86310358..e5a02c00 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,7 +100,6 @@ ignore_missing_imports = true
 
 [tool.pyright]
 include = ["src", "tests"]
-pythonVersion = "3.10"
 pythonPlatform = "All"
 
 reportAny = false

From 0a571bc1088ac9b6bdf6e7118bae56c27a7af410 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Mon, 21 Apr 2025 11:53:00 +0100
Subject: [PATCH 06/12] code review

---
 pyproject.toml | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index e5a02c00..ec054417 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -82,11 +82,10 @@ ignore = [
 
 [tool.mypy]
 files = ["array_api_compat"]
-python_version = "3.10"
 disallow_incomplete_defs = true
 disallow_untyped_decorators = true
 disallow_untyped_defs = false  # TODO
-ignore_missing_imports = true
+ignore_missing_imports = false
 no_implicit_optional = true
 show_error_codes = true
 warn_redundant_casts = true
@@ -94,7 +93,7 @@ warn_unused_ignores = true
 warn_unreachable = true
 
 [[tool.mypy.overrides]]
-module = ["cupy.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"]
+module = ["cupy.*", "cupy_backends.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"]
 ignore_missing_imports = true
 
 

From 017230030cfa43d46e8b0df4eaf72037cf1cd20a Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Mon, 21 Apr 2025 11:53:10 +0100
Subject: [PATCH 07/12] JustInt mypy ignores

---
 array_api_compat/common/_typing.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index cd8050b7..6cfab829 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -33,32 +33,29 @@
 # - docs: https://github.com/jorenham/optype/blob/master/README.md#just
 # - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
 @final
-class JustInt(Protocol):
-    @property
+class JustInt(Protocol):  # type: ignore[misc]
+    @property  # type: ignore[override]
     def __class__(self, /) -> type[int]: ...
     @__class__.setter
     def __class__(self, value: type[int], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
 
 
 @final
-class JustFloat(Protocol):
-    @property
+class JustFloat(Protocol):  # type: ignore[misc]
+    @property  # type: ignore[override]
     def __class__(self, /) -> type[float]: ...
     @__class__.setter
     def __class__(self, value: type[float], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
 
 
 @final
-class JustComplex(Protocol):
-    @property
+class JustComplex(Protocol):  # type: ignore[misc]
+    @property  # type: ignore[override]
     def __class__(self, /) -> type[complex]: ...
     @__class__.setter
     def __class__(self, value: type[complex], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
 
 
-#
-
-
 class NestedSequence(Protocol[_T_co]):
     def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
     def __len__(self, /) -> int: ...

From 014e20fc595d2c96450aba8c4cc7644127124567 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Wed, 23 Apr 2025 15:29:46 +0100
Subject: [PATCH 08/12] lint

---
 array_api_compat/common/_helpers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index d384279a..5b133626 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -12,7 +12,7 @@
 import math
 import sys
 import warnings
-from collections.abc import Collection, Hashable
+from collections.abc import Hashable
 from functools import lru_cache
 from types import NoneType
 from typing import (

From 924fc3dadd2eebd24b0e8c14d1e08e1a80ae7143 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Thu, 24 Apr 2025 13:35:59 +0100
Subject: [PATCH 09/12] fix merge

---
 array_api_compat/cupy/_aliases.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py
index ca4a14d9..9ce18d43 100644
--- a/array_api_compat/cupy/_aliases.py
+++ b/array_api_compat/cupy/_aliases.py
@@ -71,7 +71,7 @@ def asarray(
     *,
     dtype: DType | None = None,
     device: Device | None = None,
-    copy: py_bool | None = _copy_default,
+    copy: py_bool | None = None,
     **kwargs: object,
 ) -> Array:
     """

From 8eb647f36e8d568e3637f4d065258c86b1086b52 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Thu, 15 May 2025 09:12:24 +0100
Subject: [PATCH 10/12] lint

---
 array_api_compat/torch/_aliases.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index b3bb991d..7a449001 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -3,7 +3,7 @@
 from collections.abc import Sequence
 from functools import reduce as _reduce, wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
-from typing import Any, List, Optional, Sequence, Tuple, Union, Literal
+from typing import Any, Literal
 
 import torch
 
@@ -824,7 +824,7 @@ def sign(x: Array, /) -> Array:
         return out
 
 
-def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]:
+def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]:
     # enforce the default of 'xy'
     # TODO: is the return type a list or a tuple
     return list(torch.meshgrid(*arrays, indexing='xy'))

From 247ee6d28f542ab119ed8dbf4bc278b5e0ad9c89 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Mon, 19 May 2025 12:49:28 +0100
Subject: [PATCH 11/12] Reverts and tweaks

---
 array_api_compat/common/_aliases.py |  3 +--
 array_api_compat/common/_helpers.py | 24 +++++++++++-------------
 array_api_compat/common/_linalg.py  |  4 ++--
 array_api_compat/common/_typing.py  |  6 +++---
 array_api_compat/cupy/_aliases.py   |  3 ++-
 array_api_compat/numpy/_aliases.py  |  4 ++--
 array_api_compat/numpy/_info.py     |  9 +++++----
 array_api_compat/numpy/_typing.py   |  1 +
 8 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 3b2f74ee..51732b91 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -6,7 +6,6 @@
 
 import inspect
 from collections.abc import Sequence
-from types import NoneType
 from typing import TYPE_CHECKING, Any, NamedTuple, cast
 
 from ._helpers import _check_device, array_namespace
@@ -384,7 +383,7 @@ def clip(
     out: Array | None = None,
 ) -> Array:
     def _isscalar(a: object) -> TypeIs[float | None]:
-        return isinstance(a, int | float | NoneType)
+        return isinstance(a, int | float) or a is None
 
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 1a8071f1..26e24f15 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -12,14 +12,14 @@
 import math
 import sys
 import warnings
-from collections.abc import Hashable
+from collections.abc import Collection, Hashable
 from functools import lru_cache
-from types import NoneType
 from typing import (
     TYPE_CHECKING,
     Any,
     Final,
     Literal,
+    SupportsIndex,
     TypeAlias,
     TypeGuard,
     cast,
@@ -51,7 +51,7 @@
         | ndx.Array
         | sparse.SparseArray
         | torch.Tensor
-        | SupportsArrayNamespace
+        | SupportsArrayNamespace[Any]
     )
 
 _API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
@@ -630,9 +630,9 @@ def your_function(x, y):
                 raise ValueError(
                     "The given array does not have an array-api-compat wrapper"
                 )
-            x = cast(SupportsArrayNamespace, x)
+            x = cast("SupportsArrayNamespace[Any]", x)
             namespaces.add(x.__array_namespace__(api_version=api_version))
-        elif isinstance(x, int | float | complex | NoneType):
+        elif isinstance(x, int | float | complex) or x is None:
             continue
         else:
             # TODO: Support Python scalars?
@@ -890,12 +890,10 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
 
 
 @overload
-def size(x: HasShape[int]) -> int: ...
+def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
 @overload
-def size(x: HasShape[int | None]) -> int | None: ...
-@overload
-def size(x: HasShape[float]) -> int | None: ...  # Dask special case
-def size(x: HasShape[float | None]) -> int | None:
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
     """
     Return the total number of elements of x.
 
@@ -910,9 +908,9 @@ def size(x: HasShape[float | None]) -> int | None:
     # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
     if None in x.shape:
         return None
-    out = math.prod(cast(tuple[float, ...], x.shape))
+    out = math.prod(cast("Collection[SupportsIndex]", x.shape))
     # dask.array.Array.shape can contain NaN
-    return None if math.isnan(out) else cast(int, out)
+    return None if math.isnan(out) else out
 
 
 @lru_cache(100)
@@ -1003,7 +1001,7 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
     # on __bool__ (dask is one such example, which however is special-cased above).
 
     # Select a single point of the array
-    s = size(cast(HasShape, x))
+    s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
     if s is None:
         return True
     xp = array_namespace(x)
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index 87c1029d..3fd9d860 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -187,14 +187,14 @@ def vector_norm(
         # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
         # above to avoid matrix norm logic.
         shape = list(x.shape)
-        _axis = cast(
+        axes = cast(
             "tuple[int, ...]",
             normalize_axis_tuple(  # pyright: ignore[reportCallIssue]
                 range(x.ndim) if axis is None else axis,
                 x.ndim,
             ),
         )
-        for i in _axis:
+        for i in axes:
             shape[i] = 1
         res = xp.reshape(res, tuple(shape))
 
diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index 6cfab829..59a79751 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -61,13 +61,13 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
     def __len__(self, /) -> int: ...
 
 
-class SupportsArrayNamespace(Protocol):
-    def __array_namespace__(self, /, *, api_version: str | None) -> Namespace: ...
+class SupportsArrayNamespace(Protocol[_T_co]):
+    def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
 
 
 class HasShape(Protocol[_T_co]):
     @property
-    def shape(self, /) -> tuple[_T_co, ...]: ...
+    def shape(self, /) -> _T_co: ...
 
 
 # Return type of `__array_namespace_info__.default_dtypes`
diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py
index 80627738..c89b8775 100644
--- a/array_api_compat/cupy/_aliases.py
+++ b/array_api_compat/cupy/_aliases.py
@@ -3,6 +3,7 @@
 from builtins import bool as py_bool
 
 import cupy as cp
+
 from ..common import _aliases, _helpers
 from ..common._typing import NestedSequence, SupportsBufferProtocol
 from .._internal import get_xp
@@ -119,7 +120,7 @@ def count_nonzero(
 
 
 # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
-def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
+def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
     return cp.take_along_axis(x, indices, axis=axis)
 
 
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index f5d7e030..5a05a820 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -119,14 +119,14 @@ def count_nonzero(
 ) -> Array:
     # NOTE: this is currently incorrectly typed in numpy, but will be fixed in
     # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
-    result = cast(Any, np.count_nonzero(x, axis=axis, keepdims=keepdims))  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType, reportCallIssue]
+    result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims))  # pyright: ignore[reportArgumentType, reportCallIssue]
     if axis is None and not keepdims:
         return np.asarray(result)
     return result
 
 
 # take_along_axis: axis defaults to -1 but in numpy axis is a required arg
-def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
+def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
     return np.take_along_axis(x, indices, axis=axis)
 
 
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index 11126e5d..c625c13e 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -7,7 +7,6 @@
 more details.
 
 """
-
 from __future__ import annotations
 
 from numpy import bool_ as bool
@@ -64,7 +63,7 @@ class __array_namespace_info__:
 
     """
 
-    __module__ = "numpy"
+    __module__ = 'numpy'
 
     def capabilities(self):
         """
@@ -183,7 +182,8 @@ def default_dtypes(
         """
         if device not in ["cpu", None]:
             raise ValueError(
-                f'Device not understood. Only "cpu" is allowed, but received: {device}'
+                'Device not understood. Only "cpu" is allowed, but received:'
+                f' {device}'
             )
         return {
             "real floating": dtype(float64),
@@ -254,7 +254,8 @@ def dtypes(
         """
         if device not in ["cpu", None]:
             raise ValueError(
-                f'Device not understood. Only "cpu" is allowed, but received: {device}'
+                'Device not understood. Only "cpu" is allowed, but received:'
+                f' {device}'
             )
         if kind is None:
             return {
diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py
index 617cfb71..e771c788 100644
--- a/array_api_compat/numpy/_typing.py
+++ b/array_api_compat/numpy/_typing.py
@@ -7,6 +7,7 @@
 Device: TypeAlias = Literal["cpu"]
 
 if TYPE_CHECKING:
+
     # NumPy 1.x on Python 3.10 fails to parse np.dtype[]
     DType: TypeAlias = np.dtype[
         np.bool_

From 85fce08cf71a269f8ff690ea88b17fa587384066 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Mon, 19 May 2025 12:53:36 +0100
Subject: [PATCH 12/12] Fix test_all

---
 array_api_compat/cupy/_aliases.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py
index c89b8775..c0473ca4 100644
--- a/array_api_compat/cupy/_aliases.py
+++ b/array_api_compat/cupy/_aliases.py
@@ -148,4 +148,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
                               'bool', 'concat', 'count_nonzero', 'pow', 'sign',
                               'take_along_axis']
 
-_all_ignore = ['cp', 'get_xp']
+
+def __dir__() -> list[str]:
+    return __all__