Skip to content

Commit ecf7817

Browse files
committed
Filter undefined dtypes, skip test cases where appropiate
1 parent 65182a5 commit ecf7817

File tree

4 files changed

+39
-29
lines changed

4 files changed

+39
-29
lines changed

xptests/hypothesis_helpers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,13 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
6767
key += 1
6868
return key
6969

70-
promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
70+
_promotable_dtypes = list(dh.promotion_table.keys())
71+
if FILTER_UNDEFINED_DTYPES:
72+
_promotable_dtypes = [
73+
(d1, d2) for d1, d2 in _promotable_dtypes
74+
if not isinstance(d1, _UndefinedStub) or not isinstance(d2, _UndefinedStub)
75+
]
76+
promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(_promotable_dtypes, key=_dtypes_sorter)
7177

7278
def mutually_promotable_dtypes(
7379
max_size: Optional[int] = 2,

xptests/test_array2scalar.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,16 @@
1717

1818
def make_param(method_name: str, dtype: DataType) -> Param:
1919
stype = method_stype[method_name]
20+
if isinstance(dtype, xp._UndefinedStub):
21+
marks = pytest.mark.skip(reason=f"xp.{dtype.name} not defined")
22+
else:
23+
marks = ()
2024
return pytest.param(
21-
method_name, dtype, stype, id=f"{method_name}({dh.dtype_to_name[dtype]})"
25+
method_name,
26+
dtype,
27+
stype,
28+
id=f"{method_name}({dh.dtype_to_name[dtype]})",
29+
marks=marks,
2230
)
2331

2432

xptests/test_elementwise_functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
def make_unary_params(
5454
elwise_func_name: str, dtypes: Sequence[DataType]
5555
) -> List[UnaryParam]:
56+
if hh.FILTER_UNDEFINED_DTYPES:
57+
dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
5658
strat = xps.arrays(dtype=st.sampled_from(dtypes), shape=hh.shapes())
5759
func = getattr(xp, elwise_func_name)
5860
op_name = func_to_op[elwise_func_name]
@@ -93,6 +95,8 @@ class FuncType(Enum):
9395
def make_binary_params(
9496
elwise_func_name: str, dtypes: Sequence[DataType]
9597
) -> List[BinaryParam]:
98+
if hh.FILTER_UNDEFINED_DTYPES:
99+
dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
96100
dtypes_strat = st.sampled_from(dtypes)
97101

98102
def make_param(

xptests/test_type_promotion.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""
22
https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
33
"""
4-
import math
54
from collections import defaultdict
6-
from typing import Tuple, Union, List
5+
from typing import List, Tuple, Union
76

87
import pytest
98
from hypothesis import assume, given, reject
@@ -14,9 +13,8 @@
1413
from . import hypothesis_helpers as hh
1514
from . import pytest_helpers as ph
1615
from . import xps
17-
from .typing import DataType, ScalarType, Param
1816
from .function_stubs import elementwise_functions
19-
17+
from .typing import DataType, Param, ScalarType
2018

2119
# TODO: move tests not covering elementwise funcs/ops into standalone tests
2220
# result_type, meshgrid, tensordor, vecdot
@@ -28,29 +26,6 @@ def test_result_type(dtypes):
2826
ph.assert_dtype("result_type", dtypes, out, repr_name="out")
2927

3028

31-
# The number and size of generated arrays is arbitrarily limited to prevent
32-
# meshgrid() running out of memory.
33-
@given(
34-
dtypes=hh.mutually_promotable_dtypes(5, dtypes=dh.numeric_dtypes),
35-
data=st.data(),
36-
)
37-
def test_meshgrid(dtypes, data):
38-
arrays = []
39-
shapes = data.draw(
40-
hh.mutually_broadcastable_shapes(
41-
len(dtypes), min_dims=1, max_dims=1, max_side=5
42-
),
43-
label="shapes",
44-
)
45-
for i, (dtype, shape) in enumerate(zip(dtypes, shapes), 1):
46-
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
47-
arrays.append(x)
48-
assert math.prod(x.size for x in arrays) <= hh.MAX_ARRAY_SIZE # sanity check
49-
out = xp.meshgrid(*arrays)
50-
for i, x in enumerate(out):
51-
ph.assert_dtype("meshgrid", dtypes, x.dtype, repr_name=f"out[{i}].dtype")
52-
53-
5429
bitwise_shift_funcs = [
5530
"bitwise_left_shift",
5631
"bitwise_right_shift",
@@ -78,6 +53,14 @@ def make_id(
7853
return f"{func_name}({f_args}) -> {f_out_dtype}"
7954

8055

56+
def mark_stubbed_dtypes(*dtypes):
57+
for dtype in dtypes:
58+
if isinstance(dtype, xp._UndefinedStub):
59+
return pytest.mark.skip(reason=f"xp.{dtype.name} not defined")
60+
else:
61+
return ()
62+
63+
8164
func_params: List[Param[str, Tuple[DataType, ...], DataType]] = []
8265
for func_name in elementwise_functions.__all__:
8366
valid_in_dtypes = dh.func_in_dtypes[func_name]
@@ -90,6 +73,7 @@ def make_id(
9073
(in_dtype,),
9174
out_dtype,
9275
id=make_id(func_name, (in_dtype,), out_dtype),
76+
marks=mark_stubbed_dtypes(in_dtype, out_dtype),
9377
)
9478
func_params.append(p)
9579
elif ndtypes == 2:
@@ -103,6 +87,7 @@ def make_id(
10387
(in_dtype1, in_dtype2),
10488
out_dtype,
10589
id=make_id(func_name, (in_dtype1, in_dtype2), out_dtype),
90+
marks=mark_stubbed_dtypes(in_dtype1, in_dtype2, out_dtype),
10691
)
10792
func_params.append(p)
10893
else:
@@ -143,6 +128,7 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
143128
(dtype1, dtype2),
144129
promoted_dtype,
145130
id=make_id("", (dtype1, dtype2), promoted_dtype),
131+
marks=mark_stubbed_dtypes(dtype1, dtype2, promoted_dtype),
146132
)
147133
promotion_params.append(p)
148134

@@ -194,6 +180,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
194180
(in_dtype,),
195181
out_dtype,
196182
id=make_id(op, (in_dtype,), out_dtype),
183+
marks=mark_stubbed_dtypes(in_dtype, out_dtype),
197184
)
198185
op_params.append(p)
199186
else:
@@ -206,6 +193,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
206193
(in_dtype1, in_dtype2),
207194
out_dtype,
208195
id=make_id(op, (in_dtype1, in_dtype2), out_dtype),
196+
marks=mark_stubbed_dtypes(in_dtype1, in_dtype2, out_dtype),
209197
)
210198
op_params.append(p)
211199
# We generate params for abs seperately as it does not have an associated symbol
@@ -216,6 +204,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
216204
(in_dtype,),
217205
in_dtype,
218206
id=make_id("__abs__", (in_dtype,), in_dtype),
207+
marks=mark_stubbed_dtypes(in_dtype),
219208
)
220209
op_params.append(p)
221210

@@ -263,6 +252,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
263252
(in_dtype1, in_dtype2),
264253
promoted_dtype,
265254
id=make_id(op, (in_dtype1, in_dtype2), promoted_dtype),
255+
marks=mark_stubbed_dtypes(in_dtype1, in_dtype2, promoted_dtype),
266256
)
267257
inplace_params.append(p)
268258

@@ -301,6 +291,7 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
301291
in_stype,
302292
out_dtype,
303293
id=make_id(op, (in_dtype, in_stype), out_dtype),
294+
marks=mark_stubbed_dtypes(in_dtype, out_dtype),
304295
)
305296
op_scalar_params.append(p)
306297

@@ -333,6 +324,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
333324
dtype,
334325
in_stype,
335326
id=make_id(op, (dtype, in_stype), dtype),
327+
marks=mark_stubbed_dtypes(dtype),
336328
)
337329
inplace_scalar_params.append(p)
338330

0 commit comments

Comments
 (0)