Skip to content

Commit 2876e61

Browse files
committed
Accept single dtype in dh.resut_type() and thus ph.assert_dtype()
1 parent 1fde197 commit 2876e61

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,11 @@ class MinMax(NamedTuple):
151151
}
152152

153153

154-
def result_type(*dtypes):
155-
if len(dtypes) < 2:
154+
def result_type(*dtypes: DataType):
155+
if len(dtypes) == 0:
156156
raise ValueError()
157+
elif len(dtypes) == 1:
158+
return dtypes[0]
157159
result = promotion_table[dtypes[0], dtypes[1]]
158160
for i in range(2, len(dtypes)):
159161
result = promotion_table[result, dtypes[i]]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from pytest import raises
2+
3+
from .. import pytest_helpers as ph
4+
from .. import _array_module as xp
5+
6+
7+
def test_assert_dtype():
8+
ph.assert_dtype("promoted_func", (xp.uint8, xp.int8), xp.int16)
9+
with raises(AssertionError):
10+
ph.assert_dtype("bad_func", (xp.uint8, xp.int8), xp.float32)
11+
ph.assert_dtype("bool_func", (xp.uint8, xp.int8), xp.bool, xp.bool)
12+
ph.assert_dtype("single_promoted_func", (xp.uint8,), xp.uint8)
13+
ph.assert_dtype("single_bool_func", (xp.uint8,), xp.bool, xp.bool)

0 commit comments

Comments
 (0)