Skip to content

Commit 1c08549

Browse files
committed
Smoke searching functions
1 parent a7bdb9b commit 1c08549

File tree

3 files changed

+49
-8
lines changed

3 files changed

+49
-8
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from hypothesis import given
2+
from hypothesis import strategies as st
3+
4+
from . import _array_module as xp
5+
from . import hypothesis_helpers as hh
6+
from . import xps
7+
8+
9+
# TODO: generate kwargs
10+
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
11+
def test_argmin(x):
12+
xp.argmin(x)
13+
# TODO
14+
15+
16+
# TODO: generate kwargs
17+
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
18+
def test_argmax(x):
19+
xp.argmax(x)
20+
# TODO
21+
22+
23+
# TODO: generate kwargs, skip if opted out
24+
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
25+
def test_nonzero(x):
26+
xp.nonzero(x)
27+
# TODO
28+
29+
30+
# TODO: skip if opted out
31+
@given(
32+
shapes=hh.mutually_broadcastable_shapes(3),
33+
dtypes=hh.mutually_promotable_dtypes(),
34+
data=st.data(),
35+
)
36+
def test_where(shapes, dtypes, data):
37+
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[0]), label="condition")
38+
x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1")
39+
x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2")
40+
xp.where(cond, x1, x2)
41+
# TODO

array_api_tests/test_statistical_functions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,49 +5,49 @@
55
from . import xps
66

77

8-
# TODO generate kwargs
8+
# TODO: generate kwargs
99
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)))
1010
def test_min(x):
1111
xp.min(x)
1212
# TODO
1313

1414

15-
# TODO generate kwargs
15+
# TODO: generate kwargs
1616
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)))
1717
def test_max(x):
1818
xp.max(x)
1919
# TODO
2020

2121

22-
# TODO generate kwargs
22+
# TODO: generate kwargs
2323
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)))
2424
def test_mean(x):
2525
xp.mean(x)
2626
# TODO
2727

2828

29-
# TODO generate kwargs
29+
# TODO: generate kwargs
3030
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)))
3131
def test_prod(x):
3232
xp.prod(x)
3333
# TODO
3434

3535

36-
# TODO generate kwargs
36+
# TODO: generate kwargs
3737
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)))
3838
def test_std(x):
3939
xp.std(x)
4040
# TODO
4141

4242

43-
# TODO generate kwargs
43+
# TODO: generate kwargs
4444
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)))
4545
def test_sum(x):
4646
xp.sum(x)
4747
# TODO
4848

4949

50-
# TODO generate kwargs
50+
# TODO: generate kwargs
5151
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)))
5252
def test_var(x):
5353
xp.var(x)

array_api_tests/test_type_promotion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020

2121
# TODO: move tests not covering elementwise funcs/ops into standalone tests
22-
# result_type, meshgrid, where, tensordor, vecdot
22+
# result_type, meshgrid, tensordor, vecdot
2323

2424

2525
@given(hh.mutually_promotable_dtypes(None))

0 commit comments

Comments
 (0)