|
1 | 1 | import math
|
2 | 2 | from collections import deque
|
3 |
| -from typing import Iterable, Union |
| 3 | +from itertools import product |
| 4 | +from typing import Iterable, Iterator, Tuple, Union |
4 | 5 |
|
5 | 6 | from hypothesis import assume, given
|
6 | 7 | from hypothesis import strategies as st
|
@@ -43,6 +44,28 @@ def assert_array_ndindex(
|
43 | 44 | assert out[out_idx] == x[x_idx], msg
|
44 | 45 |
|
45 | 46 |
|
| 47 | +def axis_ndindex( |
| 48 | + shape: Shape, axis: int |
| 49 | +) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: |
| 50 | + iterables = [range(side) for side in shape[:axis]] |
| 51 | + for _ in range(len(shape[axis:])): |
| 52 | + iterables.append([slice(None, None)]) |
| 53 | + yield from product(*iterables) |
| 54 | + |
| 55 | + |
| 56 | +def assert_equals( |
| 57 | + func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw |
| 58 | +): |
| 59 | + msg = ( |
| 60 | + f"{out_repr}={out_val}, should be {x_repr}={x_val} " |
| 61 | + f"[{func_name}({ph.fmt_kw(kw)})]" |
| 62 | + ) |
| 63 | + if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): |
| 64 | + assert xp.isnan(x_val), msg |
| 65 | + else: |
| 66 | + assert out_val == out_val, msg |
| 67 | + |
| 68 | + |
46 | 69 | @st.composite
|
47 | 70 | def concat_shapes(draw, shape, axis):
|
48 | 71 | shape = list(shape)
|
@@ -85,21 +108,35 @@ def test_concat(dtypes, kw, data):
|
85 | 108 | shape = tuple(shape)
|
86 | 109 | ph.assert_result_shape("concat", shapes, out.shape, shape, **kw)
|
87 | 110 |
|
88 |
| - # TODO: adjust indices with nonzero axis |
89 |
| - if axis is None or axis == 0: |
90 |
| - out_indices = ah.ndindex(out.shape) |
91 |
| - for i, x in enumerate(arrays, 1): |
92 |
| - msg_suffix = f" [concat({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}" |
| 111 | + if axis is None: |
| 112 | + out_indices = (i for i in range(out.size)) |
| 113 | + for x_num, x in enumerate(arrays, 1): |
93 | 114 | for x_idx in ah.ndindex(x.shape):
|
94 |
| - out_idx = next(out_indices) |
95 |
| - msg = ( |
96 |
| - f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}" |
| 115 | + out_i = next(out_indices) |
| 116 | + assert_equals( |
| 117 | + "concat", |
| 118 | + f"x{x_num}[{x_idx}]", |
| 119 | + x[x_idx], |
| 120 | + f"out[{out_i}]", |
| 121 | + out[out_i], |
| 122 | + **kw, |
97 | 123 | )
|
98 |
| - msg += msg_suffix |
99 |
| - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): |
100 |
| - assert xp.isnan(out[out_idx]), msg |
101 |
| - else: |
102 |
| - assert out[out_idx] == x[x_idx], msg |
| 124 | + else: |
| 125 | + out_indices = ah.ndindex(out.shape) |
| 126 | + for idx in axis_ndindex(shape, axis): |
| 127 | + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) |
| 128 | + for x_num, x in enumerate(arrays, 1): |
| 129 | + indexed_x = x[idx] |
| 130 | + for x_idx in ah.ndindex(indexed_x.shape): |
| 131 | + out_idx = next(out_indices) |
| 132 | + assert_equals( |
| 133 | + "concat", |
| 134 | + f"x{x_num}[{f_idx}][{x_idx}]", |
| 135 | + indexed_x[x_idx], |
| 136 | + f"out[{out_idx}]", |
| 137 | + out[out_idx], |
| 138 | + **kw, |
| 139 | + ) |
103 | 140 |
|
104 | 141 |
|
105 | 142 | @given(
|
|
0 commit comments