Closed
Description
`torch.result_type` results are order-dependent when the default dtype is `float64`.
# Reproducer: `xp.result_type(a, b, 1.0)` and `xp.result_type(b, a, 1.0)`
# (same operands, array arguments swapped) return different dtypes for some
# dtype pairs after the global default dtype is switched to float64.
from array_api_compat import torch as xp

# Per the report, the order dependence appears only with a float64 default.
xp.set_default_dtype(xp.float64)

dtypes = ['float32', 'float64', 'int32', 'int64', 'complex64', 'complex128']
for dtype_a in dtypes:
    for dtype_b in dtypes:
        a = xp.asarray([2, 1], dtype=getattr(xp, dtype_a))
        b = xp.asarray([1, -1], dtype=getattr(xp, dtype_b))
        # Only the order of the two array arguments differs between these
        # calls; the report expects both to produce the same dtype.
        dtype_1 = xp.result_type(a, b, 1.0)
        dtype_2 = xp.result_type(b, a, 1.0)
        if dtype_1 != dtype_2:
            print('dtype_1 != dtype_2:', dtype_1, dtype_2)

# Observed output:
# dtype_1 != dtype_2: torch.float64 torch.float32
# dtype_1 != dtype_2: torch.float64 torch.float32
# dtype_1 != dtype_2: torch.float32 torch.float64
# dtype_1 != dtype_2: torch.complex64 torch.complex128
# dtype_1 != dtype_2: torch.float32 torch.float64
# dtype_1 != dtype_2: torch.complex64 torch.complex128
# dtype_1 != dtype_2: torch.complex128 torch.complex64
# dtype_1 != dtype_2: torch.complex128 torch.complex64
Thanks @ev-br for figuring out that it was dependent on the default dtype!
Metadata
Assignees
Labels
No labels