1
1
from hypothesis import given
2
2
from hypothesis import strategies as st
3
3
4
+ from array_api_tests .algos import broadcast_shapes
5
+ from array_api_tests .test_manipulation_functions import assert_equals as assert_equals_
4
6
from array_api_tests .test_statistical_functions import (
5
7
assert_equals ,
6
8
assert_keepdimable_shape ,
13
15
from . import array_helpers as ah
14
16
from . import dtype_helpers as dh
15
17
from . import hypothesis_helpers as hh
18
+ from . import pytest_helpers as ph
16
19
from . import xps
17
20
18
21
@@ -95,6 +98,7 @@ def test_argmin(x, data):
95
98
assert_equals ("argmin" , int , out_idx , min_i , expected )
96
99
97
100
101
+ # TODO: skip if opted out
98
102
@given (xps .arrays (dtype = xps .scalar_dtypes (), shape = hh .shapes (min_side = 1 )))
99
103
def test_nonzero (x ):
100
104
out = xp .nonzero (x )
@@ -133,7 +137,6 @@ def test_nonzero(x):
133
137
), f"{ f_idx } is in the wrong position, should be { indices .index (idx )} "
134
138
135
139
136
- # TODO: skip if opted out
137
140
@given (
138
141
shapes = hh .mutually_broadcastable_shapes (3 ),
139
142
dtypes = hh .mutually_promotable_dtypes (),
@@ -143,5 +146,17 @@ def test_where(shapes, dtypes, data):
143
146
cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [0 ]), label = "condition" )
144
147
x1 = data .draw (xps .arrays (dtype = dtypes [0 ], shape = shapes [1 ]), label = "x1" )
145
148
x2 = data .draw (xps .arrays (dtype = dtypes [1 ], shape = shapes [2 ]), label = "x2" )
146
- xp .where (cond , x1 , x2 )
147
- # TODO
149
+
150
+ out = xp .where (cond , x1 , x2 )
151
+
152
+ shape = broadcast_shapes (* shapes )
153
+ ph .assert_shape ("where" , out .shape , shape )
154
+ # TODO: generate indices without broadcasting arrays
155
+ _cond = xp .broadcast_to (cond , shape )
156
+ _x1 = xp .broadcast_to (x1 , shape )
157
+ _x2 = xp .broadcast_to (x2 , shape )
158
+ for idx in ah .ndindex (shape ):
159
+ if _cond [idx ]:
160
+ assert_equals_ ("where" , f"_x1[{ idx } ]" , _x1 [idx ], f"out[{ idx } ]" , out [idx ])
161
+ else :
162
+ assert_equals_ ("where" , f"_x2[{ idx } ]" , _x2 [idx ], f"out[{ idx } ]" , out [idx ])
0 commit comments