diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 3accb2c6..b72c8030 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -1,7 +1,7 @@ import math import pytest -from hypothesis import given, note +from hypothesis import given, note, assume from hypothesis import strategies as st from . import _array_module as xp @@ -106,13 +106,15 @@ def test_argmin(x, data): def test_count_nonzero(x, data): kw = data.draw( hh.kwargs( - axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)), + axis=hh.axes(x.ndim), keepdims=st.booleans(), ), label="kw", ) keepdims = kw.get("keepdims", False) + assume(kw.get("axis", None) != ()) # TODO clarify in the spec + out = xp.count_nonzero(x, **kw) ph.assert_default_index("count_nonzero", out.dtype)