Skip to content

Commit d53aa9f

Browse files
committed
MAINT: Check essential data functions
1 parent 3c273cd commit d53aa9f

5 files changed

+54
-0
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from itertools import count
44
from typing import Iterator, NamedTuple, Union
55

6+
import pytest
67
from hypothesis import assume, given, note
78
from hypothesis import strategies as st
89

@@ -76,6 +77,7 @@ def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float]
7677
)
7778

7879

80+
@pytest.mark.has_setup_funcs
7981
@given(dtype=st.none() | hh.real_dtypes, data=st.data())
8082
def test_arange(dtype, data):
8183
if dtype is None or dh.is_float_dtype(dtype):
@@ -194,6 +196,7 @@ def test_arange(dtype, data):
194196
), f"out[0]={out[0]}, but should be {_start} {f_func}"
195197

196198

199+
@pytest.mark.has_setup_funcs
197200
@given(shape=hh.shapes(min_side=1), data=st.data())
198201
def test_asarray_scalars(shape, data):
199202
kw = data.draw(
@@ -257,6 +260,7 @@ def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
257260
return s1 == s2
258261

259262

263+
@pytest.mark.has_setup_funcs
260264
@given(
261265
shape=hh.shapes(),
262266
dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes),
@@ -424,6 +428,7 @@ def test_full(shape, fill_value, kw):
424428
ph.assert_fill("full", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value))
425429

426430

431+
@pytest.mark.has_setup_funcs
427432
@given(kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), data=st.data())
428433
def test_full_like(kw, data):
429434
dtype = kw.get("dtype", None) or data.draw(hh.all_dtypes, label="dtype")
@@ -442,6 +447,7 @@ def test_full_like(kw, data):
442447
finite_kw = {"allow_nan": False, "allow_infinity": False}
443448

444449

450+
@pytest.mark.has_setup_funcs
445451
@given(
446452
num=hh.sizes,
447453
dtype=st.none() | hh.real_floating_dtypes,
@@ -492,6 +498,7 @@ def test_linspace(num, dtype, endpoint, data):
492498
ph.assert_array_elements("linspace", out=out, expected=expected)
493499

494500

501+
@pytest.mark.has_setup_funcs
495502
@given(dtype=hh.numeric_dtypes, data=st.data())
496503
def test_meshgrid(dtype, data):
497504
# The number and size of generated arrays is arbitrarily limited to prevent

array_api_tests/test_manipulation_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ def test_repeat(x, kw, data):
347347

348348
reshape_shape = st.shared(hh.shapes(), key="reshape_shape")
349349

350+
@pytest.mark.has_setup_funcs
350351
@pytest.mark.unvectorized
351352
@given(
352353
x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,7 @@ def test_acosh(x):
781781
)
782782

783783

784+
@pytest.mark.has_setup_funcs
784785
@pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes))
785786
@given(data=st.data())
786787
def test_add(ctx, data):
@@ -854,6 +855,7 @@ def test_atanh(x):
854855
)
855856

856857

858+
@pytest.mark.has_setup_funcs
857859
@pytest.mark.parametrize(
858860
"ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes)
859861
)
@@ -873,6 +875,7 @@ def test_bitwise_and(ctx, data):
873875
binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl)
874876

875877

878+
@pytest.mark.has_setup_funcs
876879
@pytest.mark.parametrize(
877880
"ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes)
878881
)
@@ -895,6 +898,7 @@ def test_bitwise_left_shift(ctx, data):
895898
)
896899

897900

901+
@pytest.mark.has_setup_funcs
898902
@pytest.mark.parametrize(
899903
"ctx", make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes)
900904
)
@@ -913,6 +917,7 @@ def test_bitwise_invert(ctx, data):
913917
unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}")
914918

915919

920+
@pytest.mark.has_setup_funcs
916921
@pytest.mark.parametrize(
917922
"ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes)
918923
)
@@ -932,6 +937,7 @@ def test_bitwise_or(ctx, data):
932937
binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl)
933938

934939

940+
@pytest.mark.has_setup_funcs
935941
@pytest.mark.parametrize(
936942
"ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes)
937943
)
@@ -953,6 +959,7 @@ def test_bitwise_right_shift(ctx, data):
953959
)
954960

955961

962+
@pytest.mark.has_setup_funcs
956963
@pytest.mark.parametrize(
957964
"ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes)
958965
)
@@ -981,6 +988,7 @@ def test_ceil(x):
981988

982989

983990
@pytest.mark.min_version("2023.12")
991+
@pytest.mark.has_setup_funcs
984992
@given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data())
985993
def test_clip(x, data):
986994
# Ensure that if both min and max are arrays that all three of x, min, max
@@ -1145,6 +1153,7 @@ def test_cosh(x):
11451153
unary_assert_against_refimpl("cosh", x, out, refimpl)
11461154

11471155

1156+
@pytest.mark.has_setup_funcs
11481157
@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes))
11491158
@given(data=st.data())
11501159
def test_divide(ctx, data):
@@ -1168,6 +1177,7 @@ def test_divide(ctx, data):
11681177
)
11691178

11701179

1180+
@pytest.mark.has_setup_funcs
11711181
@pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes))
11721182
@given(data=st.data())
11731183
def test_equal(ctx, data):
@@ -1242,6 +1252,7 @@ def refimpl(z):
12421252
unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True)
12431253

12441254

1255+
@pytest.mark.has_setup_funcs
12451256
@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes))
12461257
@given(data=st.data())
12471258
def test_floor_divide(ctx, data):
@@ -1261,6 +1272,7 @@ def test_floor_divide(ctx, data):
12611272
binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv)
12621273

12631274

1275+
@pytest.mark.has_setup_funcs
12641276
@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.real_dtypes))
12651277
@given(data=st.data())
12661278
def test_greater(ctx, data):
@@ -1281,6 +1293,7 @@ def test_greater(ctx, data):
12811293
)
12821294

12831295

1296+
@pytest.mark.has_setup_funcs
12841297
@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.real_dtypes))
12851298
@given(data=st.data())
12861299
def test_greater_equal(ctx, data):
@@ -1352,6 +1365,7 @@ def test_isnan(x):
13521365
unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool)
13531366

13541367

1368+
@pytest.mark.has_setup_funcs
13551369
@pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes))
13561370
@given(data=st.data())
13571371
def test_less(ctx, data):
@@ -1372,6 +1386,7 @@ def test_less(ctx, data):
13721386
)
13731387

13741388

1389+
@pytest.mark.has_setup_funcs
13751390
@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.real_dtypes))
13761391
@given(data=st.data())
13771392
def test_less_equal(ctx, data):
@@ -1463,6 +1478,7 @@ def logaddexp_refimpl(l: float, r: float) -> float:
14631478

14641479

14651480
@pytest.mark.min_version("2023.12")
1481+
@pytest.mark.has_setup_funcs
14661482
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
14671483
def test_logaddexp(x1, x2):
14681484
out = xp.logaddexp(x1, x2)
@@ -1476,6 +1492,7 @@ def test_logaddexp(x1, x2):
14761492
)
14771493

14781494

1495+
@pytest.mark.has_setup_funcs
14791496
@given(hh.arrays(dtype=xp.bool, shape=hh.shapes()))
14801497
def test_logical_not(x):
14811498
out = xp.logical_not(x)
@@ -1486,6 +1503,7 @@ def test_logical_not(x):
14861503
)
14871504

14881505

1506+
@pytest.mark.has_setup_funcs
14891507
@given(*hh.two_mutual_arrays([xp.bool]))
14901508
def test_logical_and(x1, x2):
14911509
out = xp.logical_and(x1, x2)
@@ -1500,6 +1518,7 @@ def test_logical_and(x1, x2):
15001518
)
15011519

15021520

1521+
@pytest.mark.has_setup_funcs
15031522
@given(*hh.two_mutual_arrays([xp.bool]))
15041523
def test_logical_or(x1, x2):
15051524
out = xp.logical_or(x1, x2)
@@ -1514,6 +1533,7 @@ def test_logical_or(x1, x2):
15141533
)
15151534

15161535

1536+
@pytest.mark.has_setup_funcs
15171537
@given(*hh.two_mutual_arrays([xp.bool]))
15181538
def test_logical_xor(x1, x2):
15191539
out = xp.logical_xor(x1, x2)
@@ -1546,6 +1566,7 @@ def test_minimum(x1, x2):
15461566
)
15471567

15481568

1569+
@pytest.mark.has_setup_funcs
15491570
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
15501571
@given(data=st.data())
15511572
def test_multiply(ctx, data):
@@ -1577,6 +1598,7 @@ def test_negative(ctx, data):
15771598
)
15781599

15791600

1601+
@pytest.mark.has_setup_funcs
15801602
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes))
15811603
@given(data=st.data())
15821604
def test_not_equal(ctx, data):
@@ -1598,6 +1620,7 @@ def test_not_equal(ctx, data):
15981620

15991621

16001622
@pytest.mark.min_version("2024.12")
1623+
@pytest.mark.has_setup_funcs
16011624
@given(
16021625
shapes=hh.two_mutually_broadcastable_shapes,
16031626
dtype=hh.real_floating_dtypes,
@@ -1617,6 +1640,8 @@ def test_nextafter(shapes, dtype, data):
16171640
out=out
16181641
)
16191642

1643+
1644+
@pytest.mark.has_setup_funcs
16201645
@pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes))
16211646
@given(data=st.data())
16221647
def test_positive(ctx, data):
@@ -1629,6 +1654,7 @@ def test_positive(ctx, data):
16291654
ph.assert_array_elements(ctx.func_name, out=out, expected=x)
16301655

16311656

1657+
@pytest.mark.has_setup_funcs
16321658
@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes))
16331659
@given(data=st.data())
16341660
def test_pow(ctx, data):
@@ -1676,6 +1702,7 @@ def test_reciprocal(x):
16761702

16771703

16781704
@pytest.mark.skip(reason="flaky")
1705+
@pytest.mark.has_setup_funcs
16791706
@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes))
16801707
@given(data=st.data())
16811708
def test_remainder(ctx, data):
@@ -1770,6 +1797,7 @@ def test_sqrt(x):
17701797
)
17711798

17721799

1800+
@pytest.mark.has_setup_funcs
17731801
@pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes))
17741802
@given(data=st.data())
17751803
def test_subtract(ctx, data):
@@ -1923,6 +1951,7 @@ def test_binary_with_scalars_bitwise_shifts(func_data, x1x2):
19231951

19241952

19251953
@pytest.mark.min_version("2024.12")
1954+
@pytest.mark.has_setup_funcs
19261955
@pytest.mark.unvectorized
19271956
@given(
19281957
x1x2=hh.array_and_py_scalar([xp.int32]),

array_api_tests/test_utility_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def test_any(x, data):
6767

6868
@pytest.mark.unvectorized
6969
@pytest.mark.min_version("2024.12")
70+
@pytest.mark.has_setup_funcs
7071
@given(
7172
x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
7273
data=st.data(),

conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ def pytest_configure(config):
9898
"markers",
9999
"unvectorized: asserts against values via element-wise iteration (not performative!)",
100100
)
101+
config.addinivalue_line(
102+
"markers",
103+
"has_setup_funcs: run when essential draw data setup functions used "
104+
"by Hypothesis are available in the namespace",
105+
)
101106
# Hypothesis
102107
deadline = None if config.getoption("--hypothesis-disable-deadline") else 800
103108
settings.register_profile(
@@ -202,6 +207,9 @@ def pytest_collection_modifyitems(config, items):
202207
# ------------------------------------------------------
203208

204209
xfail_mark = get_xfail_mark()
210+
211+
essential_funcs = ["asarray", "isnan", "reshape", "zeros"]
212+
HAS_ESSENTIAL_FUNCS = all(hasattr(xp, func_name) for func_name in essential_funcs)
205213

206214
for item in items:
207215
markers = list(item.iter_markers())
@@ -245,6 +253,14 @@ def pytest_collection_modifyitems(config, items):
245253
reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}"
246254
)
247255
)
256+
# skip if namespace doesn't support essential draw data setup functions
257+
if any(m.name == "has_setup_funcs" for m in markers) and not HAS_ESSENTIAL_FUNCS:
258+
item.add_marker(
259+
mark.skip(reason="At least one of the essential data setup "
260+
"functions is not present in the namespace: "
261+
f"{essential_funcs}")
262+
)
263+
248264
# reduce max generated Hypothesis example for unvectorized tests
249265
if any(m.name == "unvectorized" for m in markers):
250266
# TODO: limit generated examples when settings already applied

0 commit comments

Comments
 (0)