Skip to content

Commit e446a18

Browse files
committed
Remove old custom strategies for test_elementwise.py
1 parent a3e5a01 commit e446a18

File tree

1 file changed

+19
-38
lines changed

1 file changed

+19
-38
lines changed

array_api_tests/test_elementwise_functions.py

Lines changed: 19 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,16 @@
1212
import math
1313

1414
from hypothesis import assume, given
15-
from hypothesis import strategies as st
1615

1716
from . import _array_module as xp
1817
from . import array_helpers as ah
19-
from . import hypothesis_helpers as hh
2018
from . import dtype_helpers as dh
19+
from . import hypothesis_helpers as hh
2120
from . import xps
2221
# We might as well use this implementation rather than requiring
2322
# mod.broadcast_shapes(). See test_equal() and others.
2423
from .test_broadcasting import broadcast_shapes
2524

26-
# integer_scalars = hh.array_scalars(integer_dtypes)
27-
floating_scalars = hh.array_scalars(hh.floating_dtypes)
28-
numeric_scalars = hh.array_scalars(hh.numeric_dtypes)
29-
integer_or_boolean_scalars = hh.array_scalars(hh.integer_or_boolean_dtypes)
30-
boolean_scalars = hh.array_scalars(hh.boolean_dtypes)
31-
32-
two_integer_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.all_int_dtypes)
33-
two_floating_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes)
34-
two_numeric_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.numeric_dtypes)
35-
two_integer_or_boolean_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.bool_and_all_int_dtypes)
36-
two_boolean_dtypes = hh.mutually_promotable_dtypes(dtypes=(xp.bool,))
37-
two_any_dtypes = hh.mutually_promotable_dtypes()
38-
39-
@st.composite
40-
def two_array_scalars(draw, dtype1, dtype2):
41-
# two_dtypes should be a strategy that returns two dtypes (like
42-
# hh.mutually_promotable_dtypes())
43-
return draw(hh.array_scalars(st.just(dtype1))), draw(hh.array_scalars(st.just(dtype2)))
4425

4526
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
4627
def test_abs(x):
@@ -811,44 +792,44 @@ def test_round(x):
811792
ah.assert_exactly_equal(a[round_down], floor[round_down])
812793
ah.assert_exactly_equal(a[round_up], ceil[round_up])
813794

814-
@given(numeric_scalars)
795+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
815796
def test_sign(x):
816-
# a = xp.sign(x)
797+
# out = xp.sign(x)
817798
pass
818799

819-
@given(floating_scalars)
800+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
820801
def test_sin(x):
821-
# a = xp.sin(x)
802+
# out = xp.sin(x)
822803
pass
823804

824-
@given(floating_scalars)
805+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
825806
def test_sinh(x):
826-
# a = xp.sinh(x)
807+
# out = xp.sinh(x)
827808
pass
828809

829-
@given(numeric_scalars)
810+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
830811
def test_square(x):
831-
# a = xp.square(x)
812+
# out = xp.square(x)
832813
pass
833814

834-
@given(floating_scalars)
815+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
835816
def test_sqrt(x):
836-
# a = xp.sqrt(x)
817+
# out = xp.sqrt(x)
837818
pass
838819

839-
@given(two_numeric_dtypes.flatmap(lambda i: two_array_scalars(*i)))
840-
def test_subtract(args):
841-
x1, x2 = args
842-
# a = xp.subtract(x1, x2)
820+
@given(*hh.two_mutual_arrays(dh.numeric_dtypes))
821+
def test_subtract(x1, x2):
822+
# out = xp.subtract(x1, x2)
823+
pass
843824

844-
@given(floating_scalars)
825+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
845826
def test_tan(x):
846-
# a = xp.tan(x)
827+
# out = xp.tan(x)
847828
pass
848829

849-
@given(floating_scalars)
830+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
850831
def test_tanh(x):
851-
# a = xp.tanh(x)
832+
# out = xp.tanh(x)
852833
pass
853834

854835
@given(xps.arrays(dtype=hh.numeric_dtypes, shape=xps.array_shapes()))

0 commit comments

Comments
 (0)