From 845d8b552e368d64a5b2a12896f711bd7b412e01 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Thu, 3 Apr 2025 11:30:31 +0100
Subject: [PATCH 1/3] ENH: torch dtype promotions

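Build _promotion_table from a single ordering of each mixed-dtype pair and
derive the mirrored and same-dtype entries programmatically; register the
uint16/uint32/uint64 promotions only when torch >= 2.3 provides those dtypes
(_HAS_LARGE_UINT); and factor the axis=() upcasting workaround shared by
sum() and prod() into a _sum_prod_no_axis() helper.

A rough sketch of the intended behaviour (illustrative only, assuming
torch >= 2.3; `xp` is the compat namespace):

    import array_api_compat.torch as xp

    # Mirrored lookups come from the programmatic update, e.g.
    # (int16, int8) is derived from the explicit (int8, int16) entry:
    assert xp.result_type(xp.int16, xp.int8) == xp.int16

    # axis=() reductions upcast via the shared _sum_prod_no_axis() helper:
    x = xp.asarray([1, 2, 3], dtype=xp.int16)
    assert xp.sum(x, axis=()).dtype == xp.int64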
---
 array_api_compat/torch/_aliases.py | 118 +++++++++++++++--------------
 1 file changed, 61 insertions(+), 57 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index a2ed1449..5a69d27e 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -22,9 +22,9 @@
 try:
     # torch >=2.3
     _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+    _HAS_LARGE_UINT = True
 except AttributeError:
-    pass
-
+    _HAS_LARGE_UINT = False
 
 _array_api_dtypes = {
     torch.bool,
@@ -35,47 +35,23 @@
     torch.complex128,
 }
 
-_promotion_table  = {
-    # bool
-    (torch.bool, torch.bool): torch.bool,
+_promotion_table = {
     # ints
-    (torch.int8, torch.int8): torch.int8,
     (torch.int8, torch.int16): torch.int16,
     (torch.int8, torch.int32): torch.int32,
     (torch.int8, torch.int64): torch.int64,
-    (torch.int16, torch.int8): torch.int16,
-    (torch.int16, torch.int16): torch.int16,
     (torch.int16, torch.int32): torch.int32,
     (torch.int16, torch.int64): torch.int64,
-    (torch.int32, torch.int8): torch.int32,
-    (torch.int32, torch.int16): torch.int32,
-    (torch.int32, torch.int32): torch.int32,
     (torch.int32, torch.int64): torch.int64,
-    (torch.int64, torch.int8): torch.int64,
-    (torch.int64, torch.int16): torch.int64,
-    (torch.int64, torch.int32): torch.int64,
-    (torch.int64, torch.int64): torch.int64,
-    # uints
-    (torch.uint8, torch.uint8): torch.uint8,
     # ints and uints (mixed sign)
-    (torch.int8, torch.uint8): torch.int16,
-    (torch.int16, torch.uint8): torch.int16,
-    (torch.int32, torch.uint8): torch.int32,
-    (torch.int64, torch.uint8): torch.int64,
     (torch.uint8, torch.int8): torch.int16,
     (torch.uint8, torch.int16): torch.int16,
     (torch.uint8, torch.int32): torch.int32,
     (torch.uint8, torch.int64): torch.int64,
     # floats
-    (torch.float32, torch.float32): torch.float32,
     (torch.float32, torch.float64): torch.float64,
-    (torch.float64, torch.float32): torch.float64,
-    (torch.float64, torch.float64): torch.float64,
     # complexes
-    (torch.complex64, torch.complex64): torch.complex64,
     (torch.complex64, torch.complex128): torch.complex128,
-    (torch.complex128, torch.complex64): torch.complex128,
-    (torch.complex128, torch.complex128): torch.complex128,
     # Mixed float and complex
     (torch.float32, torch.complex64): torch.complex64,
     (torch.float32, torch.complex128): torch.complex128,
@@ -83,6 +59,31 @@
     (torch.float64, torch.complex128): torch.complex128,
 }
 
+if _HAS_LARGE_UINT:  # torch >=2.3
+    _promotion_table.update(
+        {
+            # uints
+            (torch.uint8, torch.uint16): torch.uint16,
+            (torch.uint8, torch.uint32): torch.uint32,
+            (torch.uint8, torch.uint64): torch.uint64,
+            (torch.uint16, torch.uint32): torch.uint32,
+            (torch.uint16, torch.uint64): torch.uint64,
+            (torch.uint32, torch.uint64): torch.uint64,
+            # ints and uints (mixed sign)
+            (torch.uint16, torch.int8): torch.int32,
+            (torch.uint16, torch.int16): torch.int32,
+            (torch.uint16, torch.int32): torch.int32,
+            (torch.uint16, torch.int64): torch.int64,
+            (torch.uint32, torch.int8): torch.int64,
+            (torch.uint32, torch.int16): torch.int64,
+            (torch.uint32, torch.int32): torch.int64,
+            (torch.uint32, torch.int64): torch.int64,
+        }
+    )
+
+_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
+_promotion_table.update({(a, a): a for a in _array_api_dtypes})
+
 
 def _two_arg(f):
     @_wraps(f)
@@ -301,6 +302,31 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
             out = torch.unsqueeze(out, a)
     return out
 
+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+    """
+    Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+    Works around https://github.com/pytorch/pytorch/issues/29137
+    """
+    if dtype is not None:
+        return x.clone() if dtype == x.dtype else x.to(dtype)
+
+    if x.dtype in (torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
+        return x.to(torch.uint64)
+
+    if x.dtype == torch.uint8:
+        # We can't upcast uint8 according to the spec because there is no
+        # torch.uint64, so at least upcast to int64 which is what prod does
+        # when axis=None.
+        return x.to(torch.int64)
+
+    return x.clone()
+
+
 def prod(x: Array,
          /,
          *,
@@ -308,20 +334,9 @@ def prod(x: Array,
          dtype: Optional[DType] = None,
          keepdims: bool = False,
          **kwargs) -> Array:
-    ndim = x.ndim
 
-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     # torch.prod doesn't support multiple axes
     # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
@@ -330,7 +345,7 @@ def prod(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.prod(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -343,25 +358,14 @@ def sum(x: Array,
          dtype: Optional[DType] = None,
          keepdims: bool = False,
          **kwargs) -> Array:
-    ndim = x.ndim
 
-    # https://github.com/pytorch/pytorch/issues/29137.
-    # Make sure it upcasts.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     if axis is None:
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.sum(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -372,7 +376,7 @@ def any(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
+
     if axis == ():
         return x.to(torch.bool)
     # torch.any doesn't support multiple axes
@@ -384,7 +388,7 @@ def any(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.any(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res.to(torch.bool)
 
     # torch.any doesn't return bool for uint8
@@ -396,7 +400,7 @@ def all(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
+
     if axis == ():
         return x.to(torch.bool)
     # torch.all doesn't support multiple axes
@@ -408,7 +412,7 @@ def all(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.all(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res.to(torch.bool)
 
     # torch.all doesn't return bool for uint8

From 39d285c1e778c547a6dcf04fd9a4625cd1022cbf Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Thu, 10 Apr 2025 10:58:56 +0100
Subject: [PATCH 2/3] Tweak _result_type

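Add type annotations to _result_type, normalise how the dtype is extracted
from array vs. dtype arguments, and replace the membership test with a
try/except lookup into _promotion_table.

Roughly, for illustration:

    import torch
    from array_api_compat.torch import result_type

    # Spec-defined pairs resolve through the promotion table:
    assert result_type(torch.int8, torch.uint8) == torch.int16

    # Pairs not in the table (e.g. anything involving torch.float16) fall
    # through to the existing torch.result_type-based path.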
---
 array_api_compat/torch/_aliases.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 5a69d27e..826632ad 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -151,13 +151,18 @@ def result_type(
         return _reduce(_result_type, others + scalars)
 
 
-def _result_type(x, y):
+def _result_type(
+    x: Array | DType | bool | int | float | complex,
+    y: Array | DType | bool | int | float | complex,
+) -> DType:
     if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
-        xdt = x.dtype if not isinstance(x, torch.dtype) else x
-        ydt = y.dtype if not isinstance(y, torch.dtype) else y
+        xdt = x if isinstance(x, torch.dtype) else x.dtype
+        ydt = y if isinstance(y, torch.dtype) else y.dtype
 
-        if (xdt, ydt) in _promotion_table:
+        try:
             return _promotion_table[xdt, ydt]
+        except KeyError:
+            pass
 
     # This doesn't result_type(dtype, dtype) for non-array API dtypes
     # because torch.result_type only accepts tensors. This does however, allow

From 1ae5c1b217e54e6b954ca9ac17ee06ed984fb849 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Thu, 10 Apr 2025 10:59:05 +0100
Subject: [PATCH 3/3] Revert uint promotions

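Drop the uint16/uint32/uint64 promotion entries and the _HAS_LARGE_UINT flag
added in the first patch, while keeping the programmatic mirroring/identity
fill of _promotion_table and the shared _sum_prod_no_axis() helper. uint8
reverts to upcasting to int64 for axis=() reductions.

A rough check of the intended behaviour (illustrative only):

    import torch
    from array_api_compat.torch import sum as xp_sum

    # uint8 with axis=() upcasts to int64 again, as before this series:
    x = torch.arange(4, dtype=torch.uint8)
    assert xp_sum(x, axis=()).dtype == torch.int64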
---
 array_api_compat/torch/_aliases.py | 40 +++++-------------------------
 1 file changed, 6 insertions(+), 34 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 826632ad..5370803f 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -22,9 +22,9 @@
 try:
     # torch >=2.3
     _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
-    _HAS_LARGE_UINT = True
 except AttributeError:
-    _HAS_LARGE_UINT = False
+    pass
+
 
 _array_api_dtypes = {
     torch.bool,
@@ -59,28 +59,6 @@
     (torch.float64, torch.complex128): torch.complex128,
 }
 
-if _HAS_LARGE_UINT:  # torch >=2.3
-    _promotion_table.update(
-        {
-            # uints
-            (torch.uint8, torch.uint16): torch.uint16,
-            (torch.uint8, torch.uint32): torch.uint32,
-            (torch.uint8, torch.uint64): torch.uint64,
-            (torch.uint16, torch.uint32): torch.uint32,
-            (torch.uint16, torch.uint64): torch.uint64,
-            (torch.uint32, torch.uint64): torch.uint64,
-            # ints and uints (mixed sign)
-            (torch.uint16, torch.int8): torch.int32,
-            (torch.uint16, torch.int16): torch.int32,
-            (torch.uint16, torch.int32): torch.int32,
-            (torch.uint16, torch.int64): torch.int64,
-            (torch.uint32, torch.int8): torch.int64,
-            (torch.uint32, torch.int16): torch.int64,
-            (torch.uint32, torch.int32): torch.int64,
-            (torch.uint32, torch.int64): torch.int64,
-        }
-    )
-
 _promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
 _promotion_table.update({(a, a): a for a in _array_api_dtypes})
 
@@ -317,16 +295,10 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
     if dtype is not None:
         return x.clone() if dtype == x.dtype else x.to(dtype)
 
-    if x.dtype in (torch.int8, torch.int16, torch.int32):
-        return x.to(torch.int64)
-
-    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
-        return x.to(torch.uint64)
-
-    if x.dtype == torch.uint8:
-        # We can't upcast uint8 according to the spec because there is no
-        # torch.uint64, so at least upcast to int64 which is what prod does
-        # when axis=None.
+    # We can't upcast uint8 according to the spec, because torch.uint64 is
+    # missing or too limited to use; upcast to int64 instead, which is what
+    # prod does when axis=None.
+    if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
         return x.to(torch.int64)
 
     return x.clone()