From 7e4d303f776dbe910952a5a10e3da9ba57a9fb0f Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Thu, 3 Apr 2025 11:49:22 +0100
Subject: [PATCH 1/3] MAINT: clarify `default_device` output

---
 array_api_compat/common/_aliases.py  |  2 +-
 array_api_compat/cupy/_info.py       | 14 ++++++++++++--
 array_api_compat/dask/array/_info.py |  4 ++--
 array_api_compat/numpy/_info.py      |  4 ++--
 array_api_compat/torch/_info.py      | 25 ++++++++++++++++++-------
 5 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 46cbb359..351b5bd6 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -18,7 +18,7 @@
 
 # These functions are modified from the NumPy versions.
 
-# Creation functions add the device keyword (which does nothing for NumPy)
+# Creation functions add the device keyword (which does nothing for NumPy and Dask)
 
 def arange(
     start: Union[int, float],
diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py
index 790621e4..66d3c4ae 100644
--- a/array_api_compat/cupy/_info.py
+++ b/array_api_compat/cupy/_info.py
@@ -26,6 +26,7 @@
     complex128,
 )
 
+
 class __array_namespace_info__:
     """
     Get the array API inspection namespace for CuPy.
@@ -117,7 +118,7 @@ def default_device(self):
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new CuPy arrays.
 
         Examples
@@ -126,6 +127,15 @@ def default_device(self):
         >>> info.default_device()
         Device(0)
 
+        Notes
+        -----
+        This method returns the static default device when CuPy is initialized.
+        However, the *current* device used by creation functions (``empty`` etc.)
+        can be changed globally or with a context manager.
+
+        See Also
+        --------
+        https://github.com/data-apis/array-api/issues/835
         """
         return cuda.Device(0)
 
@@ -312,7 +322,7 @@ def devices(self):
 
         Returns
         -------
-        devices : list of str
+        devices : list[Device]
             The devices supported by CuPy.
 
         See Also
diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py
index fc70b5a2..97ceec67 100644
--- a/array_api_compat/dask/array/_info.py
+++ b/array_api_compat/dask/array/_info.py
@@ -130,7 +130,7 @@ def default_device(self):
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new Dask arrays.
 
         Examples
@@ -335,7 +335,7 @@ def devices(self):
 
         Returns
         -------
-        devices : list of str
+        devices : list[Device]
             The devices supported by Dask.
 
         See Also
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index e706d118..a30ee352 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -119,7 +119,7 @@ def default_device(self):
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new NumPy arrays.
 
         Examples
@@ -326,7 +326,7 @@ def devices(self):
 
         Returns
         -------
-        devices : list of str
+        devices : list[Device]
             The devices supported by NumPy.
 
         See Also
diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py
index 34fbcb21..b0486a58 100644
--- a/array_api_compat/torch/_info.py
+++ b/array_api_compat/torch/_info.py
@@ -102,15 +102,24 @@ def default_device(self):
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new PyTorch arrays.
 
         Examples
         --------
         >>> info = np.__array_namespace_info__()
         >>> info.default_device()
-        'cpu'
+        device(type='cpu')
 
+        Notes
+        -----
+        This method returns the static default device when PyTorch is initialized.
+        However, the *current* device used by creation functions (``empty`` etc.)
+        can be changed at runtime.
+
+        See Also
+        --------
+        https://github.com/data-apis/array-api/issues/835
         """
         return torch.device("cpu")
 
@@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None):
 
         Parameters
         ----------
-        device : str, optional
-            The device to get the default data types for. For PyTorch, only
-            ``'cpu'`` is allowed.
+        device : Device, optional
+            The device to get the default data types for.
+            Unused for PyTorch, as all devices use the same default dtypes.
 
         Returns
         -------
@@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None):
 
         Parameters
         ----------
-        device : str, optional
+        device : Device, optional
             The device to get the data types for.
+            Unused for PyTorch, as all devices use the same dtypes.
         kind : str or tuple of str, optional
             The kind of data types to return. If ``None``, all data types are
             returned. If a string, only data types of that kind are returned.
@@ -310,7 +320,7 @@ def devices(self):
 
         Returns
         -------
-        devices : list of str
+        devices : list[Device]
             The devices supported by PyTorch.
 
         See Also
@@ -333,6 +343,7 @@ def devices(self):
         # device:
         try:
             torch.device('notadevice')
+            raise AssertionError("unreachable")  # pragma: nocover
         except RuntimeError as e:
             # The error message is something like:
             # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"

From 04e226802752a00b9a8e8f2b5e1a5d9a77cf07c6 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Fri, 11 Apr 2025 17:36:01 +0100
Subject: [PATCH 2/3] Remove outdated comment on 'max rank'

---
 array_api_compat/cupy/_info.py  | 1 -
 array_api_compat/numpy/_info.py | 1 -
 array_api_compat/torch/_info.py | 1 -
 3 files changed, 3 deletions(-)

diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py
index 66d3c4ae..f8342c6b 100644
--- a/array_api_compat/cupy/_info.py
+++ b/array_api_compat/cupy/_info.py
@@ -101,7 +101,6 @@ def capabilities(self):
         return {
             "boolean indexing": True,
             "data-dependent shapes": True,
-            # 'max rank' will be part of the 2024.12 standard
             "max dimensions": 64,
         }
 
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index a30ee352..9a3b11a9 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -100,7 +100,6 @@ def capabilities(self):
         return {
             "boolean indexing": True,
             "data-dependent shapes": True,
-            # 'max rank' will be part of the 2024.12 standard
             "max dimensions": 64,
         }
 
diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py
index b0486a58..b1952c6e 100644
--- a/array_api_compat/torch/_info.py
+++ b/array_api_compat/torch/_info.py
@@ -85,7 +85,6 @@ def capabilities(self):
         return {
             "boolean indexing": True,
             "data-dependent shapes": True,
-            # 'max rank' will be part of the 2024.12 standard
             "max dimensions": 64,
         }
 

From 4172d855d49d7eec412e2f60bb34cb4cdb348f51 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Fri, 11 Apr 2025 17:39:44 +0100
Subject: [PATCH 3/3] Update docs

---
 array_api_compat/cupy/_info.py       |  5 +++--
 array_api_compat/dask/array/_info.py | 15 ++++++++-------
 array_api_compat/numpy/_info.py      |  3 ++-
 array_api_compat/torch/_info.py      | 15 ++++++++-------
 4 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py
index f8342c6b..78e48a33 100644
--- a/array_api_compat/cupy/_info.py
+++ b/array_api_compat/cupy/_info.py
@@ -50,7 +50,7 @@ class __array_namespace_info__:
 
     Examples
     --------
-    >>> info = np.__array_namespace_info__()
+    >>> info = xp.__array_namespace_info__()
     >>> info.default_dtypes()
     {'real floating': cupy.float64,
      'complex floating': cupy.complex128,
@@ -95,7 +95,8 @@ def capabilities(self):
         >>> info = xp.__array_namespace_info__()
         >>> info.capabilities()
         {'boolean indexing': True,
-         'data-dependent shapes': True}
+         'data-dependent shapes': True,
+         'max dimensions': 64}
 
         """
         return {
diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py
index 97ceec67..614f43d9 100644
--- a/array_api_compat/dask/array/_info.py
+++ b/array_api_compat/dask/array/_info.py
@@ -50,7 +50,7 @@ class __array_namespace_info__:
 
     Examples
     --------
-    >>> info = np.__array_namespace_info__()
+    >>> info = xp.__array_namespace_info__()
     >>> info.default_dtypes()
     {'real floating': dask.float64,
      'complex floating': dask.complex128,
@@ -103,10 +103,11 @@ def capabilities(self):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.capabilities()
         {'boolean indexing': True,
-         'data-dependent shapes': True}
+         'data-dependent shapes': True,
+         'max dimensions': 64}
 
         """
         return {
@@ -135,7 +136,7 @@ def default_device(self):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.default_device()
         'cpu'
 
@@ -173,7 +174,7 @@ def default_dtypes(self, *, device=None):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.default_dtypes()
         {'real floating': dask.float64,
          'complex floating': dask.complex128,
@@ -239,7 +240,7 @@ def dtypes(self, *, device=None, kind=None):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.dtypes(kind='signed integer')
         {'int8': dask.int8,
          'int16': dask.int16,
@@ -347,7 +348,7 @@ def devices(self):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.devices()
         ['cpu', DASK_DEVICE]
 
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index 9a3b11a9..365855b8 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -94,7 +94,8 @@ def capabilities(self):
         >>> info = np.__array_namespace_info__()
         >>> info.capabilities()
         {'boolean indexing': True,
-         'data-dependent shapes': True}
+         'data-dependent shapes': True,
+         'max dimensions': 64}
 
         """
         return {
diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py
index b1952c6e..818e5d37 100644
--- a/array_api_compat/torch/_info.py
+++ b/array_api_compat/torch/_info.py
@@ -34,7 +34,7 @@ class __array_namespace_info__:
 
     Examples
     --------
-    >>> info = np.__array_namespace_info__()
+    >>> info = xp.__array_namespace_info__()
     >>> info.default_dtypes()
     {'real floating': numpy.float64,
      'complex floating': numpy.complex128,
@@ -76,10 +76,11 @@ def capabilities(self):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.capabilities()
         {'boolean indexing': True,
-         'data-dependent shapes': True}
+         'data-dependent shapes': True,
+         'max dimensions': 64}
 
         """
         return {
@@ -106,7 +107,7 @@ def default_device(self):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.default_device()
         device(type='cpu')
 
@@ -147,7 +148,7 @@ def default_dtypes(self, *, device=None):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.default_dtypes()
         {'real floating': torch.float32,
          'complex floating': torch.complex64,
@@ -296,7 +297,7 @@ def dtypes(self, *, device=None, kind=None):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.dtypes(kind='signed integer')
         {'int8': numpy.int8,
          'int16': numpy.int16,
@@ -331,7 +332,7 @@ def devices(self):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.devices()
         [device(type='cpu'), device(type='mps', index=0), device(type='meta')]