Skip to content

Commit 730c1ce

Browse files
committed
Rudimentary test_meshgrid
1 parent bcdb6aa commit 730c1ce

File tree

1 file changed

+36
-10
lines changed

1 file changed

+36
-10
lines changed

array_api_tests/test_type_promotion.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,44 @@
2121
ScalarType = Union[Type[bool], Type[int], Type[float]]
2222

2323

24-
multi_promotable_dtypes: st.SearchStrategy[Tuple[DT, ...]] = st.one_of(
25-
st.lists(st.just(xp.bool), min_size=2),
26-
st.lists(st.sampled_from(dh.all_int_dtypes), min_size=2).filter(
27-
lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l))
28-
),
29-
st.lists(st.sampled_from(dh.float_dtypes), min_size=2),
30-
).map(tuple)
24+
def multi_promotable_dtypes(
25+
allow_bool: bool = True,
26+
) -> st.SearchStrategy[Tuple[DT, ...]]:
27+
strats = [
28+
st.lists(st.sampled_from(dh.all_int_dtypes), min_size=2).filter(
29+
lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l))
30+
),
31+
st.lists(st.sampled_from(dh.float_dtypes), min_size=2),
32+
]
33+
if allow_bool:
34+
strats.append(st.lists(st.just(xp.bool), min_size=2))
35+
return st.one_of(strats).map(tuple)
36+
37+
38+
@given(multi_promotable_dtypes())
39+
def test_result_type(dtypes):
40+
out = xp.result_type(*dtypes)
41+
expected = dh.result_type(*dtypes)
42+
assert out == expected, f'{out=!s}, but should be {expected}'
3143

3244

33-
@given(multi_promotable_dtypes)
34-
def test_result_type(dtypes):
35-
assert xp.result_type(*dtypes) == dh.result_type(*dtypes)
45+
@given(
46+
dtypes=multi_promotable_dtypes(allow_bool=False),
47+
kw=hh.kwargs(indexing=st.sampled_from(['xy', 'ij'])),
48+
data=st.data(),
49+
)
50+
def test_meshgrid(dtypes, kw, data):
51+
arrays = []
52+
shapes = data.draw(hh.mutually_broadcastable_shapes(len(dtypes)), label='shapes')
53+
for i, (dtype, shape) in enumerate(zip(dtypes, shapes), 1):
54+
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
55+
arrays.append(x)
56+
out = xp.meshgrid(*arrays, **kw)
57+
expected = dh.result_type(*dtypes)
58+
for i in range(len(out)):
59+
assert (
60+
out[i].dtype == expected
61+
), f'out[{i}]={out[i].dtype}, but should be {expected}'
3662

3763

3864
bitwise_shift_funcs = [

0 commit comments

Comments
 (0)