Skip to content

Commit 6addd76

Browse files
committed
Cover everything for test_where
1 parent 5716667 commit 6addd76

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from hypothesis import given
22
from hypothesis import strategies as st
33

4+
from array_api_tests.algos import broadcast_shapes
5+
from array_api_tests.test_manipulation_functions import assert_equals as assert_equals_
46
from array_api_tests.test_statistical_functions import (
57
assert_equals,
68
assert_keepdimable_shape,
@@ -13,6 +15,7 @@
1315
from . import array_helpers as ah
1416
from . import dtype_helpers as dh
1517
from . import hypothesis_helpers as hh
18+
from . import pytest_helpers as ph
1619
from . import xps
1720

1821

@@ -95,6 +98,7 @@ def test_argmin(x, data):
9598
assert_equals("argmin", int, out_idx, min_i, expected)
9699

97100

101+
# TODO: skip if opted out
98102
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
99103
def test_nonzero(x):
100104
out = xp.nonzero(x)
@@ -133,7 +137,6 @@ def test_nonzero(x):
133137
), f"{f_idx} is in the wrong position, should be {indices.index(idx)}"
134138

135139

136-
# TODO: skip if opted out
137140
@given(
138141
shapes=hh.mutually_broadcastable_shapes(3),
139142
dtypes=hh.mutually_promotable_dtypes(),
@@ -143,5 +146,17 @@ def test_where(shapes, dtypes, data):
143146
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[0]), label="condition")
144147
x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1")
145148
x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2")
146-
xp.where(cond, x1, x2)
147-
# TODO
149+
150+
out = xp.where(cond, x1, x2)
151+
152+
shape = broadcast_shapes(*shapes)
153+
ph.assert_shape("where", out.shape, shape)
154+
# TODO: generate indices without broadcasting arrays
155+
_cond = xp.broadcast_to(cond, shape)
156+
_x1 = xp.broadcast_to(x1, shape)
157+
_x2 = xp.broadcast_to(x2, shape)
158+
for idx in ah.ndindex(shape):
159+
if _cond[idx]:
160+
assert_equals_("where", f"_x1[{idx}]", _x1[idx], f"out[{idx}]", out[idx])
161+
else:
162+
assert_equals_("where", f"_x2[{idx}]", _x2[idx], f"out[{idx}]", out[idx])

0 commit comments

Comments
 (0)