From 041717ed6df4ca94dcb96c72e9c5056537cade84 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 3 Mar 2025 17:14:54 +0100 Subject: [PATCH 1/2] add a test for count_nonzero Parrot the test from test_{argmin, argmax} --- array_api_tests/test_searching_functions.py | 38 +++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 579cbbc2..2fea0670 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -88,6 +88,44 @@ def test_argmin(x, data): ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected) +@pytest.mark.min_version("2024.12") +@given( + x=hh.arrays( + dtype=hh.real_dtypes, + shape=hh.shapes(min_dims=1, min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_count_nonzero(x, data): + kw = data.draw( + hh.kwargs( + axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)), + keepdims=st.booleans(), + ), + label="kw", + ) + keepdims = kw.get("keepdims", False) + + out = xp.count_nonzero(x, **kw) + + ph.assert_default_index("count_nonzero", out.dtype) + axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "count_nonzero", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + + for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): + count = int(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = sum(el != 0 for el in elements) + ph.assert_scalar_equals("count_nonzero", type_=int, idx=out_idx, out=count, expected=expected) + + @given(hh.arrays(dtype=hh.all_dtypes, shape=())) def test_nonzero_zerodim_error(x): with pytest.raises(Exception): From f5a38826b14572413f2cb3650d46bf74de37cfd7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 3 Mar 2025 17:05:36 +0000 Subject: [PATCH 2/2] BUG: work around/comment on the strategy for x in count_nonzero On torch, work around count_nonzero not implemented for uints On jax, there are problems with integers > iinfo(jnp.int32) --- array_api_tests/test_searching_functions.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 2fea0670..412085c5 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -88,10 +88,18 @@ def test_argmin(x, data): ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected) +# XXX: dtype= stanza below is to work around unsigned int dtypes in torch +# (count_nonzero_cpu not implemented for uint32 etc) +# XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on +# the problem is tha for ints >iinfo(int32) it runs into essentially this: +# >>> jnp.asarray[2147483648], dtype=jnp.int64) +# .... https://github.com/jax-ml/jax/pull/6047 ... +# Explicitly limiting the range in elements(...) runs into problems with +# hypothesis where floating-point numbers are not exactly representable. @pytest.mark.min_version("2024.12") @given( x=hh.arrays( - dtype=hh.real_dtypes, + dtype=st.sampled_from(dh.int_dtypes + dh.real_float_dtypes + dh.complex_dtypes + (xp.bool,)), shape=hh.shapes(min_dims=1, min_side=1), elements={"allow_nan": False}, ),