Skip to content

Commit 7efcae0

Browse files
committed
Generate all valid dtypes in test_prod and test_sum
1 parent 816bf41 commit 7efcae0

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from . import hypothesis_helpers as hh
1212
from . import pytest_helpers as ph
1313
from . import xps
14-
from .typing import Scalar, ScalarType, Shape
14+
from .typing import DataType, Scalar, ScalarType, Shape
1515

1616

1717
def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]:
@@ -22,6 +22,11 @@ def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]:
2222
return st.one_of(axes_strats)
2323

2424

25+
def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
26+
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
27+
return st.none() | st.sampled_from(dtypes)
28+
29+
2530
def normalise_axis(
2631
axis: Optional[Union[int, Tuple[int, ...]]], ndim: int
2732
) -> Tuple[int, ...]:
@@ -190,7 +195,7 @@ def test_prod(x, data):
190195
kw = data.draw(
191196
hh.kwargs(
192197
axis=axes(x.ndim),
193-
dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes
198+
dtype=kwarg_dtypes(x.dtype),
194199
keepdims=st.booleans(),
195200
),
196201
label="kw",
@@ -316,7 +321,7 @@ def test_sum(x, data):
316321
kw = data.draw(
317322
hh.kwargs(
318323
axis=axes(x.ndim),
319-
dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes
324+
dtype=kwarg_dtypes(x.dtype),
320325
keepdims=st.booleans(),
321326
),
322327
label="kw",

0 commit comments

Comments
 (0)