Skip to content

Commit d12808b

Browse files
committed
Meta test for promote_dtype
1 parent d4f8dea commit d12808b

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed
Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,51 @@
1-
from ..array_helpers import exactly_equal, notequal, int_to_dtype
2-
from ..hypothesis_helpers import integer_dtypes
3-
from ..dtype_helpers import dtype_nbits, dtype_signed
4-
from .._array_module import asarray, nan, equal, all
5-
1+
import pytest
62
from hypothesis import given, assume
73
from hypothesis.strategies import integers
84

5+
from ..array_helpers import exactly_equal, notequal, int_to_dtype, promote_dtypes
6+
from ..hypothesis_helpers import integer_dtypes
7+
from ..dtype_helpers import dtype_nbits, dtype_signed
8+
from .. import _array_module as xp
9+
910
# TODO: These meta-tests currently only work with NumPy
1011

1112
def test_exactly_equal():
12-
a = asarray([0, 0., -0., -0., nan, nan, 1, 1])
13-
b = asarray([0, -1, -0., 0., nan, 1, 1, 2])
13+
a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1])
14+
b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2])
1415

15-
res = asarray([True, False, True, False, True, False, True, False])
16-
assert all(equal(exactly_equal(a, b), res))
16+
res = xp.asarray([True, False, True, False, True, False, True, False])
17+
assert xp.all(xp.equal(exactly_equal(a, b), res))
1718

1819
def test_notequal():
19-
a = asarray([0, 0., -0., -0., nan, nan, 1, 1])
20-
b = asarray([0, -1, -0., 0., nan, 1, 1, 2])
20+
a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1])
21+
b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2])
2122

22-
res = asarray([False, True, False, False, False, True, False, True])
23-
assert all(equal(notequal(a, b), res))
23+
res = xp.asarray([False, True, False, False, False, True, False, True])
24+
assert xp.all(xp.equal(notequal(a, b), res))
2425

2526
@given(integers(), integer_dtypes)
2627
def test_int_to_dtype(x, dtype):
2728
n = dtype_nbits(dtype)
2829
signed = dtype_signed(dtype)
2930
try:
30-
d = asarray(x, dtype=dtype)
31+
d = xp.asarray(x, dtype=dtype)
3132
except OverflowError:
3233
assume(False)
3334
assert int_to_dtype(x, n, signed) == d
35+
36+
@pytest.mark.parametrize(
37+
"dtype1, dtype2, result",
38+
[
39+
(xp.uint8, xp.uint8, xp.uint8),
40+
(xp.uint8, xp.int8, xp.int16),
41+
(xp.int8, xp.int8, xp.int8),
42+
]
43+
)
44+
def test_promote_dtypes(dtype1, dtype2, result):
45+
assert promote_dtypes(dtype1, dtype2) == result
46+
47+
48+
@pytest.mark.parametrize("dtype1, dtype2", [(xp.uint8, xp.float32)])
49+
def test_promote_dtypes_incompatible_dtypes_fail(dtype1, dtype2):
50+
with pytest.raises(ValueError):
51+
promote_dtypes(dtype1, dtype2)

0 commit comments

Comments
 (0)