Skip to content

Commit a5f294e

Browse files
committed
Specify hh.specified_kwargs() arguments as named tuple, test it
1 parent 460034a commit a5f294e

File tree

3 files changed

+74
-37
lines changed

3 files changed

+74
-37
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
1+
import itertools
12
from functools import reduce
2-
from operator import mul
33
from math import sqrt
4-
import itertools
5-
from typing import Tuple, Optional, List
4+
from operator import mul
5+
from typing import Any, List, NamedTuple, Optional, Tuple
66

77
from hypothesis import assume
8-
from hypothesis.strategies import (lists, integers, sampled_from,
9-
shared, floats, just, composite, one_of,
10-
none, booleans, SearchStrategy)
8+
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
9+
integers, just, lists, none, one_of,
10+
sampled_from, shared)
1111

12-
from .pytest_helpers import nargs
13-
from .array_helpers import ndindex
14-
from .typing import DataType, Shape
15-
from . import dtype_helpers as dh
16-
from ._array_module import (full, float32, float64, bool as bool_dtype,
17-
_UndefinedStub, eye, broadcast_to)
1812
from . import _array_module as xp
13+
from . import dtype_helpers as dh
1914
from . import xps
20-
15+
from ._array_module import _UndefinedStub
16+
from ._array_module import bool as bool_dtype
17+
from ._array_module import broadcast_to, eye, float32, float64, full
18+
from .array_helpers import ndindex
2119
from .function_stubs import elementwise_functions
22-
20+
from .pytest_helpers import nargs
21+
from .typing import DataType, Shape
2322

2423
# Set this to True to not fail tests just because a dtype isn't implemented.
2524
# If no compatible dtype is implemented for a given test, the test will fail
@@ -382,3 +381,24 @@ def test_f(x, kw):
382381
if draw(booleans()):
383382
result[k] = draw(strat)
384383
return result
384+
385+
386+
class KVD(NamedTuple):
387+
keyword: str
388+
value: Any
389+
default: Any
390+
391+
392+
@composite
393+
def specified_kwargs(draw, *keys_values_defaults: KVD):
394+
"""Generates valid kwargs given expected defaults.
395+
396+
When we can't realistically use hh.kwargs() and thus test whether xp infact
397+
defaults correctly, this strategy lets us remove generated arguments if they
398+
are of the default value anyway.
399+
"""
400+
kw = {}
401+
for keyword, value, default in keys_values_defaults:
402+
if value is not default or draw(booleans()):
403+
kw[keyword] = value
404+
return kw

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from hypothesis import given, strategies as st, settings
55

66
from .. import _array_module as xp
7+
from .. import xps
78
from .._array_module import _UndefinedStub
89
from .. import array_helpers as ah
910
from .. import dtype_helpers as dh
@@ -76,6 +77,37 @@ def run(kw):
7677
assert len(c_results) > 0
7778
assert all(isinstance(kw["c"], str) for kw in c_results)
7879

80+
81+
def test_specified_kwargs():
82+
results = []
83+
84+
@given(n=st.integers(0, 10), d=st.none() | xps.scalar_dtypes(), data=st.data())
85+
@settings(max_examples=100)
86+
def run(n, d, data):
87+
kw = data.draw(
88+
hh.specified_kwargs(
89+
hh.KVD("n", n, 0),
90+
hh.KVD("d", d, None),
91+
),
92+
label="kw",
93+
)
94+
results.append(kw)
95+
run()
96+
97+
assert all(isinstance(kw, dict) for kw in results)
98+
99+
assert any(len(kw) == 0 for kw in results)
100+
101+
assert any("n" not in kw.keys() for kw in results)
102+
assert any("n" in kw.keys() and kw["n"] == 0 for kw in results)
103+
assert any("n" in kw.keys() and kw["n"] != 0 for kw in results)
104+
105+
assert any("d" not in kw.keys() for kw in results)
106+
assert any("d" in kw.keys() and kw["d"] is None for kw in results)
107+
assert any("d" in kw.keys() and kw["d"] is xp.float64 for kw in results)
108+
109+
110+
79111
@given(m=hh.symmetric_matrices(hh.shared_floating_dtypes,
80112
finite=st.shared(st.booleans(), key='finite')),
81113
dtype=hh.shared_floating_dtypes,

array_api_tests/test_creation_functions.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import count
3-
from typing import Any, Iterator, NamedTuple, Tuple, Union
3+
from typing import Iterator, NamedTuple, Union
44

55
from hypothesis import assume, given
66
from hypothesis import strategies as st
@@ -14,21 +14,6 @@
1414
from .typing import DataType, Scalar
1515

1616

17-
@st.composite
18-
def specified_kwargs(draw, *keys_values_defaults: Tuple[str, Any, Any]):
19-
"""Generates valid kwargs given expected defaults.
20-
21-
When we can't realistically use hh.kwargs() and thus test whether xp infact
22-
defaults correctly, this strategy lets us remove generated arguments if they
23-
are of the default value anyway.
24-
"""
25-
kw = {}
26-
for key, value, default in keys_values_defaults:
27-
if value is not default or draw(st.booleans()):
28-
kw[key] = value
29-
return kw
30-
31-
3217
class frange(NamedTuple):
3318
start: float
3419
stop: float
@@ -147,10 +132,10 @@ def test_arange(dtype, data):
147132
), f"{size=} should be no more than {hh.MAX_ARRAY_SIZE}" # sanity check
148133

149134
kw = data.draw(
150-
specified_kwargs(
151-
("stop", stop, None),
152-
("step", step, None),
153-
("dtype", dtype, None),
135+
hh.specified_kwargs(
136+
hh.KVD("stop", stop, None),
137+
hh.KVD("step", step, None),
138+
hh.KVD("dtype", dtype, None),
154139
),
155140
label="kw",
156141
)
@@ -360,9 +345,9 @@ def test_linspace(num, dtype, endpoint, data):
360345
stop = data.draw(int_stops(start, num, _dtype, endpoint), label="stop")
361346

362347
kw = data.draw(
363-
specified_kwargs(
364-
("dtype", dtype, None),
365-
("endpoint", endpoint, True),
348+
hh.specified_kwargs(
349+
hh.KVD("dtype", dtype, None),
350+
hh.KVD("endpoint", endpoint, True),
366351
),
367352
label="kw",
368353
)

0 commit comments

Comments
 (0)