Skip to content

Commit f047068

Browse files
authored
BUG: fix torch.result_type cross-kind promotion (#55)
1 parent 546fa3d commit f047068

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
133133
# This doesn't result_type(dtype, dtype) for non-array API dtypes
134134
# because torch.result_type only accepts tensors. This does however, allow
135135
# cross-kind promotion.
136+
x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x
137+
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
136138
return torch.result_type(x, y)
137139

138140
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:

0 commit comments

Comments
 (0)