
torch.result_type results are order-dependent when default dtype is float64 #273

Closed
@mdhaber

Description


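`xp.result_type` from `array_api_compat`'s torch namespace returns order-dependent results when the default dtype is `float64`: swapping the two array arguments changes the returned dtype.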

```python
from array_api_compat import torch as xp
xp.set_default_dtype(xp.float64)

dtypes = ['float32', 'float64', 'int32', 'int64', 'complex64', 'complex128']
for dtype_a in dtypes:
    for dtype_b in dtypes:
        a = xp.asarray([2, 1], dtype=getattr(xp, dtype_a))
        b = xp.asarray([1, -1], dtype=getattr(xp, dtype_b))
        # Same three arguments, only the two arrays are swapped
        dtype_1 = xp.result_type(a, b, 1.0)
        dtype_2 = xp.result_type(b, a, 1.0)
        if dtype_1 != dtype_2:
            print('dtype_1 != dtype_2:', dtype_1, dtype_2)

# Output:
# dtype_1 != dtype_2: torch.float64 torch.float32
# dtype_1 != dtype_2: torch.float64 torch.float32
# dtype_1 != dtype_2: torch.float32 torch.float64
# dtype_1 != dtype_2: torch.complex64 torch.complex128
# dtype_1 != dtype_2: torch.float32 torch.float64
# dtype_1 != dtype_2: torch.complex64 torch.complex128
# dtype_1 != dtype_2: torch.complex128 torch.complex64
# dtype_1 != dtype_2: torch.complex128 torch.complex64
```
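The reproducer above only swaps the two arrays while keeping `1.0` last. A small generic helper can instead call any `result_type`-like function on every ordering of its arguments (including the scalar's position) and report the outcomes when they disagree. This is a sketch; the name `order_dependent_cases` is hypothetical and not part of array_api_compat or torch. It assumes the returned values are hashable, which holds for torch dtypes.

```python
import itertools
from typing import Any, Callable, Dict, Sequence, Tuple


def order_dependent_cases(
    result_type: Callable[..., Any], args: Sequence[Any]
) -> Dict[Tuple[int, ...], Any]:
    """Call result_type on every ordering of args (hypothetical helper).

    Returns {index order: result} when the orderings disagree,
    or an empty dict when every ordering gives the same result.
    """
    outcomes = {}
    for order in itertools.permutations(range(len(args))):
        outcomes[order] = result_type(*(args[i] for i in order))
    # Results must be hashable to be collected in a set.
    if len(set(outcomes.values())) > 1:
        return outcomes
    return {}


# Possible usage against the reproducer (needs torch and array_api_compat):
# from array_api_compat import torch as xp
# xp.set_default_dtype(xp.float64)
# a = xp.asarray([2, 1], dtype=xp.float32)
# b = xp.asarray([1, -1], dtype=xp.int32)
# for order, dtype in order_dependent_cases(xp.result_type, [a, b, 1.0]).items():
#     print(order, dtype)
```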

Thanks @ev-br for figuring out that it was dependent on the default dtype!
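One plausible mechanism for the default-dtype dependence, inferred from the output pattern rather than confirmed in this issue, is that `result_type` combines its arguments pairwise from the right. The Python scalar `1.0` is then resolved against the last array first: an integer array combined with `1.0` becomes the default float dtype (`float64` here), which then outranks a `float32` array. The sketch below is a toy pure-Python model of that hypothesis, limited to real dtypes. All names (`promote_pair`, `result_type_right_fold`, `result_type_arrays_first`, `PY_FLOAT`) are hypothetical and are not array_api_compat's code. It also includes one order-independent alternative that promotes the array dtypes first and applies Python scalars last.

```python
from functools import reduce

PY_FLOAT = "python float"  # hypothetical stand-in for a bare Python float such as 1.0
FLOAT_BITS = {"float32": 32, "float64": 64}


def promote_pair(x: str, y: str, default_float: str) -> str:
    """Toy pairwise promotion for real dtypes and Python floats."""
    if x == PY_FLOAT and y == PY_FLOAT:
        return default_float
    if y == PY_FLOAT:
        x, y = y, x  # put the scalar first
    if x == PY_FLOAT:
        # Scalar keeps a float dtype, but an integer dtype becomes the default float.
        return y if y in FLOAT_BITS else default_float
    if x in FLOAT_BITS and y in FLOAT_BITS:
        return x if FLOAT_BITS[x] >= FLOAT_BITS[y] else y
    if x in FLOAT_BITS or y in FLOAT_BITS:
        return x if x in FLOAT_BITS else y  # int + float -> the float dtype
    return "int64" if "int64" in (x, y) else "int32"


def result_type_right_fold(*args: str, default_float: str) -> str:
    """Hypothesized behaviour: f(a, f(b, c)), i.e. combine from the right."""
    return reduce(
        lambda acc, x: promote_pair(x, acc, default_float),
        reversed(args[:-1]),
        args[-1],
    )


def result_type_arrays_first(*args: str, default_float: str) -> str:
    """Order-independent alternative: promote dtypes first, then apply scalars."""
    dtypes = [a for a in args if a != PY_FLOAT]
    if not dtypes:
        return default_float
    result = reduce(lambda acc, x: promote_pair(acc, x, default_float), dtypes)
    if PY_FLOAT in args:
        result = promote_pair(PY_FLOAT, result, default_float)
    return result


for default in ("float32", "float64"):
    fold_1 = result_type_right_fold("float32", "int32", PY_FLOAT, default_float=default)
    fold_2 = result_type_right_fold("int32", "float32", PY_FLOAT, default_float=default)
    fixed_1 = result_type_arrays_first("float32", "int32", PY_FLOAT, default_float=default)
    fixed_2 = result_type_arrays_first("int32", "float32", PY_FLOAT, default_float=default)
    print(default, "right fold:", fold_1, fold_2, "| arrays first:", fixed_1, fixed_2)
# float32 right fold: float32 float32 | arrays first: float32 float32
# float64 right fold: float64 float32 | arrays first: float32 float32
```

With a `float32` default, the scalar-with-integer step never produces a dtype that outranks the other array, so both orders agree. With a `float64` default they diverge exactly as in the first output line. The same mechanism would also account for the complex lines (for example, `complex64` combined with the `float64` produced from an integer array and `1.0` becomes `complex128`), but the toy model leaves complex dtypes out.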
