diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 5c0fa6c1..76a6a072 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -175,16 +175,17 @@ def test_arange(dtype, data): # min_size = math.floor(size * 0.9) max_size = max(math.ceil(size * 1.1), 1) + out_size = math.prod(out.shape) assert ( - min_size <= out.size <= max_size - ), f"{out.size=}, but should be roughly {size} {f_func}" + min_size <= out_size <= max_size + ), f"prod(out.shape)={out_size}, but should be roughly {size} {f_func}" if dh.is_int_dtype(_dtype): elements = list(r) - assume(out.size == len(elements)) + assume(out_size == len(elements)) ph.assert_array_elements("arange", out, xp.asarray(elements, dtype=_dtype)) else: - assume(out.size == size) - if out.size > 0: + assume(out_size == size) + if out_size > 0: assert xp.equal( out[0], xp.asarray(_start, dtype=out.dtype) ), f"out[0]={out[0]}, but should be {_start} {f_func}" @@ -497,7 +498,8 @@ def test_meshgrid(dtype, data): for i, shape in enumerate(shapes, 1): x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) - assert math.prod(x.size for x in arrays) <= hh.MAX_ARRAY_SIZE # sanity check + # sanity check + assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE out = xp.meshgrid(*arrays) for i, x in enumerate(out): ph.assert_dtype("meshgrid", dtype, x.dtype, repr_name=f"out[{i}].dtype") diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index b9d9e03d..a30a0030 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -91,7 +91,7 @@ def test_concat(dtypes, base_shape, data): ph.assert_result_shape("concat", shapes, out.shape, shape, **kw) if _axis is None: - out_indices = (i for i in range(out.size)) + out_indices = (i for i in range(math.prod(out.shape))) for x_num, x in enumerate(arrays, 1): for x_idx in sh.ndindex(x.shape): out_i = next(out_indices) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 6b134bb0..41f5a77c 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -1,3 +1,5 @@ +import math + import pytest from hypothesis import given from hypothesis import strategies as st @@ -90,12 +92,14 @@ def test_nonzero(x): assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays" else: assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" - size = out[0].size + out_size = math.prod(out[0].shape) for i in range(len(out)): assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1" - assert ( - out[i].size == size - ), f"out[{i}].size={x.size}, but should be out[0].size={size}" + size_at = math.prod(out[i].shape) + assert size_at == out_size, ( + f"prod(out[{i}].shape)={size_at}, " + f"but should be prod(out[0].shape)={out_size}" + ) ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") indices = [] if x.dtype == xp.bool: @@ -107,11 +111,11 @@ def test_nonzero(x): if x[idx] != 0: indices.append(idx) if x.ndim == 0: - assert out[0].size == len( + assert out_size == len( indices - ), f"{out[0].size=}, but should be {len(indices)}" + ), f"prod(out[0].shape)={out_size}, but should be {len(indices)}" else: - for i in range(size): + for i in range(out_size): idx = tuple(int(x[i]) for x in out) f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}" f_element = f"x[{idx}]={x[idx]}" diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 5bae6147..5e415858 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -110,7 +110,7 @@ def test_unique_all(x): vals_idx[val] = idx if dh.is_float_dtype(out.values.dtype): - assume(x.size <= 128) # may not be representable + assume(math.prod(x.shape) <= 128) # may not be representable expected = sum(v for k, v in counts.items() if math.isnan(k)) assert nans == expected, f"{nans} NaNs in out, but should be {expected}" @@ -157,7 +157,7 @@ def test_unique_counts(x): ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" vals_idx[val] = idx if dh.is_float_dtype(out.values.dtype): - assume(x.size <= 128) # may not be representable + assume(math.prod(x.shape) <= 128) # may not be representable expected = sum(v for k, v in counts.items() if math.isnan(k)) assert nans == expected, f"{nans} NaNs in out, but should be {expected}" @@ -210,7 +210,7 @@ def test_unique_inverse(x): else: assert val == expected, msg if dh.is_float_dtype(out.values.dtype): - assume(x.size <= 128) # may not be representable + assume(math.prod(x.shape) <= 128) # may not be representable expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}" @@ -234,6 +234,6 @@ def test_unique_values(x): ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" vals_idx[val] = idx if dh.is_float_dtype(out.dtype): - assume(x.size <= 128) # may not be representable + assume(math.prod(x.shape) <= 128) # may not be representable expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) assert nans == expected, f"{nans} NaNs in out, but should be {expected}" diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index c88d0a53..dbc08651 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -175,7 +175,7 @@ def test_prod(x, data): dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1), elements={"allow_nan": False}, - ).filter(lambda x: x.size >= 2), + ).filter(lambda x: math.prod(x.shape) >= 2), data=st.data(), ) def test_std(x, data): @@ -273,7 +273,7 @@ def test_sum(x, data): dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1), elements={"allow_nan": False}, - ).filter(lambda x: x.size >= 2), + ).filter(lambda x: math.prod(x.shape) >= 2), data=st.data(), ) def test_var(x, data):