|
21 | 21 | ScalarType = Union[Type[bool], Type[int], Type[float]]
|
22 | 22 |
|
23 | 23 |
|
24 |
| -multi_promotable_dtypes: st.SearchStrategy[Tuple[DT, ...]] = st.one_of( |
25 |
| - st.lists(st.just(xp.bool), min_size=2), |
26 |
| - st.lists(st.sampled_from(dh.all_int_dtypes), min_size=2).filter( |
27 |
| - lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l)) |
28 |
| - ), |
29 |
| - st.lists(st.sampled_from(dh.float_dtypes), min_size=2), |
30 |
| -).map(tuple) |
| 24 | +def multi_promotable_dtypes( |
| 25 | + allow_bool: bool = True, |
| 26 | +) -> st.SearchStrategy[Tuple[DT, ...]]: |
| 27 | + strats = [ |
| 28 | + st.lists(st.sampled_from(dh.all_int_dtypes), min_size=2).filter( |
| 29 | + lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l)) |
| 30 | + ), |
| 31 | + st.lists(st.sampled_from(dh.float_dtypes), min_size=2), |
| 32 | + ] |
| 33 | + if allow_bool: |
| 34 | + strats.append(st.lists(st.just(xp.bool), min_size=2)) |
| 35 | + return st.one_of(strats).map(tuple) |
| 36 | + |
| 37 | + |
| 38 | +@given(multi_promotable_dtypes()) |
| 39 | +def test_result_type(dtypes): |
| 40 | + out = xp.result_type(*dtypes) |
| 41 | + expected = dh.result_type(*dtypes) |
| 42 | + assert out == expected, f'{out=!s}, but should be {expected}' |
31 | 43 |
|
32 | 44 |
|
33 |
| -@given(multi_promotable_dtypes) |
34 |
| -def test_result_type(dtypes): |
35 |
| - assert xp.result_type(*dtypes) == dh.result_type(*dtypes) |
| 45 | +@given( |
| 46 | + dtypes=multi_promotable_dtypes(allow_bool=False), |
| 47 | + kw=hh.kwargs(indexing=st.sampled_from(['xy', 'ij'])), |
| 48 | + data=st.data(), |
| 49 | +) |
| 50 | +def test_meshgrid(dtypes, kw, data): |
| 51 | + arrays = [] |
| 52 | + shapes = data.draw(hh.mutually_broadcastable_shapes(len(dtypes)), label='shapes') |
| 53 | + for i, (dtype, shape) in enumerate(zip(dtypes, shapes), 1): |
| 54 | + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}') |
| 55 | + arrays.append(x) |
| 56 | + out = xp.meshgrid(*arrays, **kw) |
| 57 | + expected = dh.result_type(*dtypes) |
| 58 | + for i in range(len(out)): |
| 59 | + assert ( |
| 60 | + out[i].dtype == expected |
| 61 | + ), f'out[{i}]={out[i].dtype}, but should be {expected}' |
36 | 62 |
|
37 | 63 |
|
38 | 64 | bitwise_shift_funcs = [
|
|
0 commit comments