11
11
from . import hypothesis_helpers as hh
12
12
from . import pytest_helpers as ph
13
13
from . import xps
14
- from .typing import Scalar , ScalarType , Shape
14
+ from .typing import DataType , Scalar , ScalarType , Shape
15
15
16
16
17
17
def axes (ndim : int ) -> st .SearchStrategy [Optional [Union [int , Shape ]]]:
@@ -22,6 +22,11 @@ def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]:
22
22
return st .one_of (axes_strats )
23
23
24
24
25
+ def kwarg_dtypes (dtype : DataType ) -> st .SearchStrategy [Optional [DataType ]]:
26
+ dtypes = [d2 for d1 , d2 in dh .promotion_table if d1 == dtype ]
27
+ return st .none () | st .sampled_from (dtypes )
28
+
29
+
25
30
def normalise_axis (
26
31
axis : Optional [Union [int , Tuple [int , ...]]], ndim : int
27
32
) -> Tuple [int , ...]:
@@ -190,7 +195,7 @@ def test_prod(x, data):
190
195
kw = data .draw (
191
196
hh .kwargs (
192
197
axis = axes (x .ndim ),
193
- dtype = st . none () | st . just ( x .dtype ), # TODO: all valid dtypes
198
+ dtype = kwarg_dtypes ( x .dtype ),
194
199
keepdims = st .booleans (),
195
200
),
196
201
label = "kw" ,
@@ -316,7 +321,7 @@ def test_sum(x, data):
316
321
kw = data .draw (
317
322
hh .kwargs (
318
323
axis = axes (x .ndim ),
319
- dtype = st . none () | st . just ( x .dtype ), # TODO: all valid dtypes
324
+ dtype = kwarg_dtypes ( x .dtype ),
320
325
keepdims = st .booleans (),
321
326
),
322
327
label = "kw" ,
0 commit comments