This RFC proposes adding a new API to the array API specification for computing the logarithm of a sum of exponentials (logsumexp).
Overview
The array API specification currently includes logaddexp, which performs an element-wise operation on two input arrays, but does not include the reduction logsumexp. This API is commonly implemented in accelerator libraries for better numerical stability in deep learning applications.
- logaddexp:
logsumexp can be implemented directly as log(sum(exp(x))); however, that naive implementation is numerically unstable: exp overflows for large inputs and underflows to zero for very negative ones.
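To make the stability point concrete, here is a minimal pure-Python sketch. The helper names naive_logsumexp and stable_logsumexp are hypothetical and exist only for illustration. The stable version uses the standard max-shift identity log(sum(exp(x))) = m + log(sum(exp(x - m))) with m = max(x), so every exponent is at most zero.

```python
import math


def naive_logsumexp(xs):
    # Hypothetical helper: direct translation of log(sum(exp(x))).
    # math.exp raises OverflowError once an input exceeds about 709.78,
    # and math.log raises ValueError if every exp underflows to 0.0.
    return math.log(sum(math.exp(x) for x in xs))


def stable_logsumexp(xs):
    # Hypothetical helper: shift by the maximum m so every exponent is <= 0,
    # using log(sum(exp(x))) == m + log(sum(exp(x - m))).
    m = max(xs)
    if math.isinf(m):
        # All inputs -inf gives -inf; any +inf input gives +inf.
        # Returning early avoids computing inf - inf (nan).
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))
```

With xs = [1000.0, 1000.0], naive_logsumexp raises OverflowError because math.exp(1000.0) exceeds the float64 range, while stable_logsumexp returns 1000 + log(2). With xs = [-1000.0, -1000.0], each exp underflows to 0.0 and math.log(0.0) raises ValueError; the shifted version again returns a finite value.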
Prior art
- NumPy: (not currently implemented)
  - NumPy does, however, implement logaddexp.reduce.
- Dask: (not currently implemented)
- SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
- CuPy: https://docs.cupy.dev/en/v13.0.0b1/reference/generated/cupyx.scipy.special.logsumexp.html
  - In the scipy.special namespace.
- PyTorch: https://pytorch.org/docs/stable/generated/torch.logsumexp.html (also an alias in torch.special: https://pytorch.org/docs/stable/special.html#torch.special.logsumexp)
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/math/reduce_logsumexp
- JAX: jax.nn.logsumexp and jax.scipy.special.logsumexp (same function, exposed in two places)
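NumPy's workaround from the list above, logaddexp.reduce, folds the binary logaddexp along an axis. A minimal sketch, assuming NumPy is installed, checking that the fold agrees with the max-shift formula on inputs where a plain np.exp would overflow:

```python
import numpy as np

# Inputs large enough that np.exp(x) on its own would overflow to inf in float64.
x = np.array([1000.0, 1000.0, 999.0])

# NumPy's ufunc machinery folds the binary logaddexp along axis 0.
folded = np.logaddexp.reduce(x)

# The same quantity via the max-shift identity: m + log(sum(exp(x - m))).
m = np.max(x)
shifted = m + np.log(np.sum(np.exp(x - m)))

print(folded, shifted)
```

The fold works today, but it relies on NumPy's ufunc methods, which the array API specification does not define; a standalone logsumexp would give every conforming library a common spelling.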
Proposal:
def logsumexp(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None, keepdims: bool = False) -> array
The dtype keyword argument is included for consistency with sum and similar reductions.
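A minimal reference sketch of the proposed signature, using NumPy as a stand-in for a conforming array namespace. Two behaviors are assumptions rather than part of the proposal text: dtype is treated as a cast applied before the reduction (by analogy with sum), and slices whose maximum is infinite are shifted by zero so that x - m never evaluates inf - inf.

```python
from typing import Optional, Tuple, Union

import numpy as np


def logsumexp(
    x: np.ndarray,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[np.dtype] = None,
    keepdims: bool = False,
) -> np.ndarray:
    # Assumption: dtype behaves like a cast applied before the reduction.
    if dtype is not None:
        x = x.astype(dtype)
    # Per-slice maximum, kept with size-1 dims so it broadcasts against x.
    m = np.max(x, axis=axis, keepdims=True)
    # If a slice's maximum is +/-inf, shift by 0 instead so x - m never hits inf - inf.
    m = np.where(np.isfinite(m), m, np.zeros_like(m))
    with np.errstate(divide="ignore"):
        # A slice that is entirely -inf sums to 0 and log(0) gives -inf;
        # the errstate block only silences NumPy's divide-by-zero warning.
        out = np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True)) + m
    if not keepdims:
        # Drop the reduced size-1 dims (all dims when axis is None).
        out = np.squeeze(out, axis=axis)
    return out
```

For example, logsumexp(np.zeros((2, 3, 4)), axis=(0, 2)) reduces 8 zeros per slice and returns an array of shape (3,) filled with log(8).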
Related
cc @kgryte
Status: Stage 1