Skip to content

Commit c0c6ba9

Browse files
committed
Limit the repititions by the total array size in test_repeat
1 parent d218c36 commit c0c6ba9

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def test_repeat(x, kw, data):
301301
shape = x.shape
302302
axis = kw.get("axis", None)
303303
size = math.prod(shape) if axis is None else shape[axis]
304-
repeat_strat = st.integers(1, 4)
304+
repeat_strat = st.integers(1, 10)
305305
repeats = data.draw(repeat_strat
306306
| hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat,
307307
shape=st.sampled_from([(1,), (size,)])),
@@ -314,6 +314,8 @@ def test_repeat(x, kw, data):
314314
else:
315315
n_repititions = int(xp.sum(repeats))
316316

317+
assume(n_repititions <= hh.SQRT_MAX_ARRAY_SIZE)
318+
317319
out = xp.repeat(x, repeats, **kw)
318320
ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
319321
if axis is None:

0 commit comments

Comments
 (0)