Skip to content

Commit 22f9815

Browse files
committed
Move all_floating_dtypes() into hypothesis_helpers.py
And use it in `test_fft.py`
1 parent 7c31597 commit 22f9815

File tree

3 files changed

+36
-33
lines changed

3 files changed

+36
-33
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
integers, just, lists, none, one_of,
1212
sampled_from, shared)
1313

14-
from . import _array_module as xp
14+
from . import _array_module as xp, api_version
1515
from . import dtype_helpers as dh
1616
from . import shape_helpers as sh
1717
from . import xps
@@ -141,6 +141,13 @@ def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShape
141141
return OnewayBroadcastableShapes(input_shape, result_shape)
142142

143143

144+
def all_floating_dtypes() -> SearchStrategy[DataType]:
145+
strat = xps.floating_dtypes()
146+
if api_version >= "2022.12":
147+
strat |= xps.complex_dtypes()
148+
return strat
149+
150+
144151
# shared() allows us to draw either the function or the function name and they
145152
# will both correspond to the same function.
146153

array_api_tests/test_fft.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from array_api_tests.typing import DataType
77

88
from . import _array_module as xp
9+
from . import dtype_helpers as dh
910
from . import hypothesis_helpers as hh
1011
from . import pytest_helpers as ph
1112
from . import xps
@@ -23,36 +24,38 @@
2324
def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType):
2425
if in_dtype == xp.float32:
2526
expected = xp.complex64
26-
else:
27-
assert in_dtype == xp.float64 # sanity check
27+
elif in_dtype == xp.float64:
2828
expected = xp.complex128
29+
else:
30+
assert dh.is_float_dtype(in_dtype) # sanity check
31+
expected = in_dtype
2932
ph.assert_dtype(
3033
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
3134
)
3235

3336

34-
@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat))
37+
@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat))
3538
def test_fft(x):
3639
out = xp.fft.fft(x)
3740
assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype)
3841
ph.assert_shape("fft", out_shape=out.shape, expected=x.shape)
3942

4043

41-
@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat))
44+
@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat))
4245
def test_ifft(x):
4346
out = xp.fft.ifft(x)
4447
assert_fft_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype)
4548
ph.assert_shape("ifft", out_shape=out.shape, expected=x.shape)
4649

4750

48-
@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat))
51+
@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat))
4952
def test_fftn(x):
5053
out = xp.fft.fftn(x)
5154
assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype)
5255
ph.assert_shape("fftn", out_shape=out.shape, expected=x.shape)
5356

5457

55-
@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat))
58+
@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat))
5659
def test_ifftn(x):
5760
out = xp.fft.ifftn(x)
5861
assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype)

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,6 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3333
return xps.boolean_dtypes() | all_integer_dtypes()
3434

3535

36-
def all_floating_dtypes() -> st.SearchStrategy[DataType]:
37-
strat = xps.floating_dtypes()
38-
if api_version >= "2022.12":
39-
strat |= xps.complex_dtypes()
40-
return strat
41-
42-
4336
def mock_int_dtype(n: int, dtype: DataType) -> int:
4437
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
4538
nbits = dh.dtype_nbits[dtype]
@@ -714,7 +707,7 @@ def test_abs(ctx, data):
714707
)
715708

716709

717-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
710+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
718711
def test_acos(x):
719712
out = xp.acos(x)
720713
ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -724,7 +717,7 @@ def test_acos(x):
724717
)
725718

726719

727-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
720+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
728721
def test_acosh(x):
729722
out = xp.acosh(x)
730723
ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -748,7 +741,7 @@ def test_add(ctx, data):
748741
binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add)
749742

750743

751-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
744+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
752745
def test_asin(x):
753746
out = xp.asin(x)
754747
ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -758,15 +751,15 @@ def test_asin(x):
758751
)
759752

760753

761-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
754+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
762755
def test_asinh(x):
763756
out = xp.asinh(x)
764757
ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype)
765758
ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape)
766759
unary_assert_against_refimpl("asinh", x, out, math.asinh)
767760

768761

769-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
762+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
770763
def test_atan(x):
771764
out = xp.atan(x)
772765
ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -782,7 +775,7 @@ def test_atan2(x1, x2):
782775
binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2)
783776

784777

785-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
778+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
786779
def test_atanh(x):
787780
out = xp.atanh(x)
788781
ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -932,15 +925,15 @@ def test_conj(x):
932925
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
933926

934927

935-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
928+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
936929
def test_cos(x):
937930
out = xp.cos(x)
938931
ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype)
939932
ph.assert_shape("cos", out_shape=out.shape, expected=x.shape)
940933
unary_assert_against_refimpl("cos", x, out, math.cos)
941934

942935

943-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
936+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
944937
def test_cosh(x):
945938
out = xp.cosh(x)
946939
ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1001,15 +994,15 @@ def test_equal(ctx, data):
1001994
)
1002995

1003996

1004-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
997+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
1005998
def test_exp(x):
1006999
out = xp.exp(x)
10071000
ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype)
10081001
ph.assert_shape("exp", out_shape=out.shape, expected=x.shape)
10091002
unary_assert_against_refimpl("exp", x, out, math.exp)
10101003

10111004

1012-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
1005+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
10131006
def test_expm1(x):
10141007
out = xp.expm1(x)
10151008
ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1158,7 +1151,7 @@ def test_less_equal(ctx, data):
11581151
)
11591152

11601153

1161-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
1154+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
11621155
def test_log(x):
11631156
out = xp.log(x)
11641157
ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1168,7 +1161,7 @@ def test_log(x):
11681161
)
11691162

11701163

1171-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
1164+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
11721165
def test_log1p(x):
11731166
out = xp.log1p(x)
11741167
ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1178,7 +1171,7 @@ def test_log1p(x):
11781171
)
11791172

11801173

1181-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
1174+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
11821175
def test_log2(x):
11831176
out = xp.log2(x)
11841177
ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1188,7 +1181,7 @@ def test_log2(x):
11881181
)
11891182

11901183

1191-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
1184+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
11921185
def test_log10(x):
11931186
out = xp.log10(x)
11941187
ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1379,15 +1372,15 @@ def test_sign(x):
13791372
)
13801373

13811374

1382-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
1375+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
13831376
def test_sin(x):
13841377
out = xp.sin(x)
13851378
ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype)
13861379
ph.assert_shape("sin", out_shape=out.shape, expected=x.shape)
13871380
unary_assert_against_refimpl("sin", x, out, math.sin)
13881381

13891382

1390-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
1383+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
13911384
def test_sinh(x):
13921385
out = xp.sinh(x)
13931386
ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1405,7 +1398,7 @@ def test_square(x):
14051398
)
14061399

14071400

1408-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
1401+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
14091402
def test_sqrt(x):
14101403
out = xp.sqrt(x)
14111404
ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1429,15 +1422,15 @@ def test_subtract(ctx, data):
14291422
binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub)
14301423

14311424

1432-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
1425+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
14331426
def test_tan(x):
14341427
out = xp.tan(x)
14351428
ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype)
14361429
ph.assert_shape("tan", out_shape=out.shape, expected=x.shape)
14371430
unary_assert_against_refimpl("tan", x, out, math.tan)
14381431

14391432

1440-
@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
1433+
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
14411434
def test_tanh(x):
14421435
out = xp.tanh(x)
14431436
ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype)

0 commit comments

Comments
 (0)