This issue tracks the current state of support for using more than one device at once in the Array API standard, its helper libraries, and the libraries that implement it.
Supporting multiple devices at the same time is substantially more fragile than pinning one of the available devices at interpreter level and then using it exclusively; the latter generally works as intended.
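For context, a minimal sketch of the interpreter-level pinning baseline, using PyTorch; the GPU index and the use of `CUDA_VISIBLE_DEVICES` are illustrative assumptions, not something prescribed here:

```python
import os

# Pin a single physical GPU before the CUDA runtime initializes
# (illustrative; the index "1" is an arbitrary choice).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch

torch.set_default_device("cuda")  # all new tensors land on the pinned GPU

x = torch.ones(3)  # created on cuda:0, i.e. physical GPU 1
y = x + 1          # stays on the same device; no cross-device logic needed
```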
Array API
- Dictates that the device of the output arrays must always follow from that of the input(s), unless explicitly overridden by a `device=` kwarg, where allowed (sketched after this list).
- Is frequently misinterpreted when it comes to the priority of input arrays vs. the global/context device: docs: clarify input arrays in device placement #919
- There is controversy over what `__array_namespace_info__().default_device()` should return: Clarify definitions of "default device" and "current device" #835
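A minimal sketch of the propagation rule and the `device=` override, using array-api-strict (any conforming namespace would do; the array values are arbitrary):

```python
import array_api_strict as xp

# Pick a non-default device via the standard inspection API.
dev = xp.__array_namespace_info__().devices()[1]

a = xp.asarray([1.0, 2.0], device=dev)  # explicit device= override
b = xp.sin(a)                           # output device follows the input's
assert b.device == a.device
```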
array-api-strict
- Supports three hardcoded devices: `"cpu"`, `"device1"`, and `"device2"`. This is fit for purpose for testing downstream bugs regarding device propagation, e.g. as sketched below.
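For instance, a downstream test suite can parametrize over all three devices; a hypothetical sketch:

```python
import pytest
import array_api_strict as xp


@pytest.mark.parametrize("device", xp.__array_namespace_info__().devices())
def test_device_propagation(device):
    # Outputs must land on the same device as the inputs.
    x = xp.ones(4, device=device)
    assert (x + 1).device == device
```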
array-api-tests
- Devices are untested: Test device support array-api-tests#302
array-api-compat
- Adds a `device=` parameter to NumPy 1.x, CuPy, PyTorch, and Dask (read below).
- Implements helper functions `device()` and `to_device()` to work around non-compliance of wrapped libraries (sketched below).
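A sketch of the two helpers on a NumPy array (the helper names are the actual array-api-compat API; the data is illustrative):

```python
import numpy as np
from array_api_compat import device, to_device

x = np.arange(3)
d = device(x)        # "cpu", even on NumPy 1.x where x.device is missing
y = to_device(x, d)  # no-op round-trip here; dispatches per wrapped library
assert device(y) == d
```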
array-api-extra
- Full support and testing for non-default devices, using array-api-strict only. Actual support from real backends depends entirely on the libraries below.
NumPy
- Supports a single dummy device, `"cpu"` (since NumPy 2.0).
- array-api-compat backports it to NumPy 1.x.
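A minimal sketch, assuming NumPy >= 2.0 (on 1.x, the same calls go through the array-api-compat wrapper instead):

```python
import numpy as np  # >= 2.0

x = np.asarray([1, 2, 3], device="cpu")  # "cpu" is the only accepted value
assert x.device == "cpu"
y = x.to_device("cpu")                   # a no-op, but spec-compliant
```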
CuPy
- Non-compliant support for multiple devices.
- array-api-compat adds a dummy `device=` parameter to functions.
- A compatibility layer is being added at the time of writing by [DNM] ENH: CuPy multi-device support array-api-compat#293. [EDIT] it can't work, as array-api-compat can't patch methods.
- As CuPy doesn't have a `"cpu"` device, it's impossible to test multi-device ops without access to a dual-GPU host (see the sketch below).
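For reference, a sketch of CuPy's context-based (non-Array-API) device model; it requires two GPUs, which is exactly the testing constraint above:

```python
import cupy as cp

with cp.cuda.Device(1):  # requires a second GPU
    x = cp.ones(3)       # allocated on GPU 1

y = cp.ones(3)           # allocated on the current device, GPU 0
# x + y is a cross-device op: depending on configuration it raises or
# silently uses peer access, rather than following the Array API rule.
```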
PyTorch
- Fully supported (with array-api-compat shims).
- However, there's a bug that hampers testing on GPU CI: `asarray`: device does not propagate from input to output after `set_default_device` pytorch/pytorch#150199 (sketched below)
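A sketch of the reported reproducer, assuming a CUDA build of PyTorch:

```python
import torch

torch.set_default_device("cuda")

x = torch.ones(3, device="cpu")
y = torch.asarray(x)
# Per the Array API, y should stay on the CPU (the input's device); with a
# default device set, it reportedly moves to CUDA instead (pytorch#150199).
print(y.device)
```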
JAX
- Bugs in `__array_namespace_info__`: Array API `default_device()` and `devices()` are incorrect jax-ml/jax#27606
- Inside `jax.jit`, input-to-output device propagation works, but it's impossible to call creation functions (`empty`, `zeros`, `full`, etc.) on a non-default device: Missing `.device` attribute inside `@jax.jit` jax-ml/jax#26000 (sketched below)
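A sketch of the `jax.jit` limitation, as reported in the issues above:

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # x is a tracer without a .device attribute (jax-ml/jax#26000), so
    # jnp.zeros(x.shape, device=x.device) is not possible here; creation
    # functions can only land on the default device inside jit.
    return x + jnp.zeros(x.shape)


y = f(jnp.ones(3))  # input-to-output propagation itself works
```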
Dask
- Dask doesn't have a concept of device.
- array-api-compat adds stub support that returns `"cpu"` when wrapping around NumPy and a dummy `DASK_DEVICE` otherwise. Notably, this is stored nowhere and does not survive a round-trip (`device(to_device(x, d)) == d` can fail); see the sketch at the end of this section.
- This is a non-issue when wrapping around NumPy, or when wrapping around CuPy with both client and workers mounting a single GPU.
- Multi-GPU Dask+CuPy support could be achieved by starting separate worker processes on the same host and pinning the GPU at interpreter level. This is extremely inefficient, as it incurs IPC and possibly memory duplication. If a user does so, the client and array-api-compat will never know. `dask-cuda` may improve the situation (not investigated).
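A sketch of the round-trip behavior, using the array-api-compat helpers on a NumPy-backed Dask array (the case where it happens to hold):

```python
import numpy as np
import dask.array as da
from array_api_compat import device, to_device

x = da.from_array(np.arange(6), chunks=3)
d = device(x)          # "cpu" here, since the chunks are NumPy arrays
y = to_device(x, d)
assert device(y) == d  # holds for NumPy-backed Dask; with the dummy
                       # DASK_DEVICE the device is stored nowhere, so the
                       # equivalent round-trip can fail
```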
SciPy
- Largely untested. Initial attempt to test: BUG/TST: `special.logsumexp` on non-default device scipy/scipy#22756 (sketched below)
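The linked test is roughly of this shape (a hypothetical sketch; it requires SciPy's experimental array API support, enabled via `SCIPY_ARRAY_API=1`):

```python
import os

os.environ["SCIPY_ARRAY_API"] = "1"  # must be set before importing scipy

import array_api_strict as xp
from scipy import special

dev = xp.__array_namespace_info__().devices()[1]  # a non-default device
x = xp.asarray([0.0, 1.0, 2.0], device=dev)
res = special.logsumexp(x)
assert res.device == dev  # the property scipy/scipy#22756 tries to test
```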