|
1 | 1 | import math
|
2 | 2 | from collections import deque
|
3 | 3 | from itertools import product
|
4 |
| -from typing import Iterable, Iterator, Tuple, Union |
| 4 | +from typing import Iterable, Union |
5 | 5 |
|
6 | 6 | from hypothesis import assume, given
|
7 | 7 | from hypothesis import strategies as st
|
@@ -45,15 +45,6 @@ def assert_array_ndindex(
|
45 | 45 | assert out[out_idx] == x[x_idx], msg
|
46 | 46 |
|
47 | 47 |
|
48 |
| -def axis_ndindex( |
49 |
| - shape: Shape, axis: int |
50 |
| -) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: |
51 |
| - iterables = [range(side) for side in shape[:axis]] |
52 |
| - for _ in range(len(shape[axis:])): |
53 |
| - iterables.append([slice(None, None)]) |
54 |
| - yield from product(*iterables) |
55 |
| - |
56 |
| - |
57 | 48 | def assert_equals(
|
58 | 49 | func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
|
59 | 50 | ):
|
@@ -124,7 +115,10 @@ def test_concat(dtypes, kw, data):
|
124 | 115 | )
|
125 | 116 | else:
|
126 | 117 | out_indices = ah.ndindex(out.shape)
|
127 |
| - for idx in axis_ndindex(shape, axis): |
| 118 | + axis_indices = [range(side) for side in shapes[0][:_axis]] |
| 119 | + for _ in range(_axis, len(shape)): |
| 120 | + axis_indices.append([slice(None, None)]) |
| 121 | + for idx in product(*axis_indices): |
128 | 122 | f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
|
129 | 123 | for x_num, x in enumerate(arrays, 1):
|
130 | 124 | indexed_x = x[idx]
|
|
0 commit comments