Closed
Description
torch.result_type does not accept combinations of Python scalars and dtypes. (These are permitted by the array API standard, which only requires that at least one of the arguments be an array or dtype.)
from array_api_compat import torch as xp

types = ['scalar', 'array', 'dtype']
dtypes = [float, int, complex]
# Try every pairing of argument kind (scalar/array/dtype) and Python type.
for type_a in types:
    for type_b in types:
        for dtype_a in dtypes:
            for dtype_b in dtypes:
                scalar_a = dtype_a(1)
                scalar_b = dtype_b(1)
                array_a = xp.asarray(scalar_a)
                array_b = xp.asarray(scalar_b)
                in1 = dict(scalar=scalar_a, array=array_a, dtype=array_a.dtype)[type_a]
                in2 = dict(scalar=scalar_b, array=array_b, dtype=array_b.dtype)[type_b]
                try:
                    xp.result_type(in1, in2)
                except Exception:
                    # Print each argument pair that result_type rejects.
                    print(in1, in2)
Result:
1.0 torch.float32
1.0 torch.int64
1.0 torch.complex64
1 torch.float32
1 torch.int64
1 torch.complex64
(1+0j) torch.float32
(1+0j) torch.int64
(1+0j) torch.complex64
torch.float32 1.0
torch.float32 1
torch.float32 (1+0j)
torch.int64 1.0
torch.int64 1
torch.int64 (1+0j)
torch.complex64 1.0
torch.complex64 1
torch.complex64 (1+0j)
This seems to affect every combination of a Python scalar and a torch dtype, in either order, and only those combinations (e.g., array/scalar and scalar/array are fine).
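As a quick cross-check of that pattern, the sketch below enumerates the same argument-kind and type combinations as the reproducer, without importing torch, and keeps only the scalar/dtype pairings. The names `kinds`, `is_scalar_dtype_pair`, and `failing` are hypothetical and exist only for this illustration. It does not call `result_type`; it only counts how many pairs the stated pattern predicts, so the count can be compared with the number of lines printed above.

```python
from itertools import product

# Same argument kinds and Python types as the reproducer, written as labels.
kinds = ["scalar", "array", "dtype"]
dtypes = ["float", "int", "complex"]


def is_scalar_dtype_pair(kind_a, kind_b):
    """Hypothetical helper: True when one argument is a Python scalar
    and the other is a dtype, in either order."""
    return {kind_a, kind_b} == {"scalar", "dtype"}


# Walk the pairs in the same order as the nested loops in the reproducer.
failing = [
    (kind_a, dtype_a, kind_b, dtype_b)
    for kind_a, kind_b in product(kinds, repeat=2)
    for dtype_a, dtype_b in product(dtypes, repeat=2)
    if is_scalar_dtype_pair(kind_a, kind_b)
]

print(len(failing))
```

The predicted count is 2 orderings times 3 scalar types times 3 dtypes, which agrees with the 18 pairs listed in the result.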