Skip to content

Commit 84bd3ef

Browse files
committed
Move oneway strategies to hypothesis_helpers.py
1 parent 78c57d0 commit 84bd3ef

File tree

6 files changed

+52
-62
lines changed

6 files changed

+52
-62
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from operator import mul
55
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
66

7-
from hypothesis import assume
7+
from hypothesis import assume, reject
88
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
99
integers, just, lists, none, one_of,
1010
sampled_from, shared)
@@ -99,6 +99,46 @@ def mutually_promotable_dtypes(
9999
return one_of(strats).map(tuple)
100100

101101

102+
class OnewayPromotableDtypes(NamedTuple):
103+
input_dtype: DataType
104+
result_dtype: DataType
105+
106+
107+
@composite
108+
def oneway_promotable_dtypes(
109+
draw, dtypes: Sequence[DataType]
110+
) -> SearchStrategy[OnewayPromotableDtypes]:
111+
"""Return a strategy for input dtypes that promote to result dtypes."""
112+
d1, d2 = draw(mutually_promotable_dtypes(dtypes=dtypes))
113+
result_dtype = dh.result_type(d1, d2)
114+
if d1 == result_dtype:
115+
return OnewayPromotableDtypes(d2, d1)
116+
elif d2 == result_dtype:
117+
return OnewayPromotableDtypes(d1, d2)
118+
else:
119+
reject()
120+
121+
122+
class OnewayBroadcastableShapes(NamedTuple):
123+
input_shape: Shape
124+
result_shape: Shape
125+
126+
127+
@composite
128+
def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShapes]:
129+
"""Return a strategy for input shapes that broadcast to result shapes."""
130+
result_shape = draw(shapes(min_side=1))
131+
input_shape = draw(
132+
xps.broadcastable_shapes(
133+
result_shape,
134+
# Override defaults so bad shapes are less likely to be generated.
135+
max_side=None if result_shape == () else max(result_shape),
136+
max_dims=len(result_shape),
137+
).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape)
138+
)
139+
return OnewayBroadcastableShapes(input_shape, result_shape)
140+
141+
102142
# shared() allows us to draw either the function or the function name and they
103143
# will both correspond to the same function.
104144

array_api_tests/meta/test_utils.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,12 @@
44

55
from .. import _array_module as xp
66
from .. import dtype_helpers as dh
7+
from .. import hypothesis_helpers as hh
78
from .. import shape_helpers as sh
89
from .. import xps
910
from ..test_creation_functions import frange
1011
from ..test_manipulation_functions import roll_ndindex
11-
from ..test_operators_and_elementwise_functions import (
12-
mock_int_dtype,
13-
oneway_broadcastable_shapes,
14-
oneway_promotable_dtypes,
15-
)
12+
from ..test_operators_and_elementwise_functions import mock_int_dtype
1613

1714

1815
@pytest.mark.parametrize(
@@ -115,11 +112,11 @@ def test_int_to_dtype(x, dtype):
115112
assert mock_int_dtype(x, dtype) == d
116113

117114

118-
@given(oneway_promotable_dtypes(dh.all_dtypes))
115+
@given(hh.oneway_promotable_dtypes(dh.all_dtypes))
119116
def test_oneway_promotable_dtypes(D):
120117
assert D.result_dtype == dh.result_type(*D)
121118

122119

123-
@given(oneway_broadcastable_shapes())
120+
@given(hh.oneway_broadcastable_shapes())
124121
def test_oneway_broadcastable_shapes(S):
125122
assert S.result_shape == sh.broadcast_shapes(*S)

array_api_tests/test_array_object.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
1414
from . import xps
15-
from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
1615
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape
1716

1817
pytestmark = pytest.mark.ci
@@ -108,7 +107,7 @@ def test_getitem(shape, dtype, data):
108107

109108
@given(
110109
shape=hh.shapes(),
111-
dtypes=oneway_promotable_dtypes(dh.all_dtypes),
110+
dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes),
112111
data=st.data(),
113112
)
114113
def test_setitem(shape, dtypes, data):

array_api_tests/test_creation_functions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
1414
from . import xps
15-
from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
1615
from .typing import DataType, Scalar
1716

1817
pytestmark = pytest.mark.ci
@@ -256,7 +255,7 @@ def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
256255

257256
@given(
258257
shape=hh.shapes(),
259-
dtypes=oneway_promotable_dtypes(dh.all_dtypes),
258+
dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes),
260259
data=st.data(),
261260
)
262261
def test_asarray_arrays(shape, dtypes, data):

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union
99

1010
import pytest
11-
from hypothesis import assume, given
11+
from hypothesis import assume, given, reject
1212
from hypothesis import strategies as st
13-
from hypothesis.control import reject
1413

1514
from . import _array_module as xp, api_version
1615
from . import array_helpers as ah
@@ -41,46 +40,6 @@ def all_floating_dtypes() -> st.SearchStrategy[DataType]:
4140
return strat
4241

4342

44-
class OnewayPromotableDtypes(NamedTuple):
45-
input_dtype: DataType
46-
result_dtype: DataType
47-
48-
49-
@st.composite
50-
def oneway_promotable_dtypes(
51-
draw, dtypes: Sequence[DataType]
52-
) -> st.SearchStrategy[OnewayPromotableDtypes]:
53-
"""Return a strategy for input dtypes that promote to result dtypes."""
54-
d1, d2 = draw(hh.mutually_promotable_dtypes(dtypes=dtypes))
55-
result_dtype = dh.result_type(d1, d2)
56-
if d1 == result_dtype:
57-
return OnewayPromotableDtypes(d2, d1)
58-
elif d2 == result_dtype:
59-
return OnewayPromotableDtypes(d1, d2)
60-
else:
61-
reject()
62-
63-
64-
class OnewayBroadcastableShapes(NamedTuple):
65-
input_shape: Shape
66-
result_shape: Shape
67-
68-
69-
@st.composite
70-
def oneway_broadcastable_shapes(draw) -> st.SearchStrategy[OnewayBroadcastableShapes]:
71-
"""Return a strategy for input shapes that broadcast to result shapes."""
72-
result_shape = draw(hh.shapes(min_side=1))
73-
input_shape = draw(
74-
xps.broadcastable_shapes(
75-
result_shape,
76-
# Override defaults so bad shapes are less likely to be generated.
77-
max_side=None if result_shape == () else max(result_shape),
78-
max_dims=len(result_shape),
79-
).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape)
80-
)
81-
return OnewayBroadcastableShapes(input_shape, result_shape)
82-
83-
8443
def mock_int_dtype(n: int, dtype: DataType) -> int:
8544
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
8645
nbits = dh.dtype_nbits[dtype]
@@ -557,7 +516,7 @@ def make_binary_params(
557516
) -> List[Param[BinaryParamContext]]:
558517
if hh.FILTER_UNDEFINED_DTYPES:
559518
dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
560-
shared_oneway_dtypes = st.shared(oneway_promotable_dtypes(dtypes))
519+
shared_oneway_dtypes = st.shared(hh.oneway_promotable_dtypes(dtypes))
561520
left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype)
562521
right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype)
563522

@@ -576,7 +535,7 @@ def make_param(
576535
right_strat = right_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw))
577536
else:
578537
if func_type is FuncType.IOP:
579-
shared_oneway_shapes = st.shared(oneway_broadcastable_shapes())
538+
shared_oneway_shapes = st.shared(hh.oneway_broadcastable_shapes())
580539
left_strat = xps.arrays(
581540
dtype=left_dtypes,
582541
shape=shared_oneway_shapes.map(lambda S: S.result_shape),

array_api_tests/test_special_cases.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@
3535
from . import xps
3636
from ._array_module import mod as xp
3737
from .stubs import category_to_funcs
38-
from .test_operators_and_elementwise_functions import (
39-
oneway_broadcastable_shapes,
40-
oneway_promotable_dtypes,
41-
)
4238

4339
pytestmark = pytest.mark.ci
4440

@@ -1281,8 +1277,8 @@ def test_binary(func_name, func, case, x1, x2, data):
12811277

12821278
@pytest.mark.parametrize("iop_name, iop, case", iop_params)
12831279
@given(
1284-
oneway_dtypes=oneway_promotable_dtypes(dh.float_dtypes),
1285-
oneway_shapes=oneway_broadcastable_shapes(),
1280+
oneway_dtypes=hh.oneway_promotable_dtypes(dh.float_dtypes),
1281+
oneway_shapes=hh.oneway_broadcastable_shapes(),
12861282
data=st.data(),
12871283
)
12881284
def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):

0 commit comments

Comments
 (0)