Skip to content

Commit 0189c79

Browse files
committed
Fix test_concat axes iteration
1 parent 0c13438 commit 0189c79

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
from collections import deque
33
from itertools import product
4-
from typing import Iterable, Iterator, Tuple, Union
4+
from typing import Iterable, Union
55

66
from hypothesis import assume, given
77
from hypothesis import strategies as st
@@ -45,15 +45,6 @@ def assert_array_ndindex(
4545
assert out[out_idx] == x[x_idx], msg
4646

4747

48-
def axis_ndindex(
49-
shape: Shape, axis: int
50-
) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]:
51-
iterables = [range(side) for side in shape[:axis]]
52-
for _ in range(len(shape[axis:])):
53-
iterables.append([slice(None, None)])
54-
yield from product(*iterables)
55-
56-
5748
def assert_equals(
5849
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
5950
):
@@ -124,7 +115,10 @@ def test_concat(dtypes, kw, data):
124115
)
125116
else:
126117
out_indices = ah.ndindex(out.shape)
127-
for idx in axis_ndindex(shape, axis):
118+
axis_indices = [range(side) for side in shapes[0][:_axis]]
119+
for _ in range(_axis, len(shape)):
120+
axis_indices.append([slice(None, None)])
121+
for idx in product(*axis_indices):
128122
f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
129123
for x_num, x in enumerate(arrays, 1):
130124
indexed_x = x[idx]

0 commit comments

Comments
 (0)