From d07c07fdded9c26d942746b603d8481451312903 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Wed, 30 Aug 2023 16:05:58 +0100 Subject: [PATCH] BUG: fix `torch.result_type` cross-kind promotion --- array_api_compat/torch/_aliases.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index af5510d7..63d99e45 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -133,6 +133,8 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow # cross-kind promotion. + x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x + y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y return torch.result_type(x, y) def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: