Skip to content

Replace uses of x.size with math.prod(x.shape) #162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,16 +175,17 @@ def test_arange(dtype, data):
#
min_size = math.floor(size * 0.9)
max_size = max(math.ceil(size * 1.1), 1)
out_size = math.prod(out.shape)
assert (
min_size <= out.size <= max_size
), f"{out.size=}, but should be roughly {size} {f_func}"
min_size <= out_size <= max_size
), f"prod(out.shape)={out_size}, but should be roughly {size} {f_func}"
if dh.is_int_dtype(_dtype):
elements = list(r)
assume(out.size == len(elements))
assume(out_size == len(elements))
ph.assert_array_elements("arange", out, xp.asarray(elements, dtype=_dtype))
else:
assume(out.size == size)
if out.size > 0:
assume(out_size == size)
if out_size > 0:
assert xp.equal(
out[0], xp.asarray(_start, dtype=out.dtype)
), f"out[0]={out[0]}, but should be {_start} {f_func}"
Expand Down Expand Up @@ -497,7 +498,8 @@ def test_meshgrid(dtype, data):
for i, shape in enumerate(shapes, 1):
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
arrays.append(x)
assert math.prod(x.size for x in arrays) <= hh.MAX_ARRAY_SIZE # sanity check
# sanity check
assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE
out = xp.meshgrid(*arrays)
for i, x in enumerate(out):
ph.assert_dtype("meshgrid", dtype, x.dtype, repr_name=f"out[{i}].dtype")
Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_concat(dtypes, base_shape, data):
ph.assert_result_shape("concat", shapes, out.shape, shape, **kw)

if _axis is None:
out_indices = (i for i in range(out.size))
out_indices = (i for i in range(math.prod(out.shape)))
for x_num, x in enumerate(arrays, 1):
for x_idx in sh.ndindex(x.shape):
out_i = next(out_indices)
Expand Down
18 changes: 11 additions & 7 deletions array_api_tests/test_searching_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import pytest
from hypothesis import given
from hypothesis import strategies as st
Expand Down Expand Up @@ -90,12 +92,14 @@ def test_nonzero(x):
assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays"
else:
assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}"
size = out[0].size
out_size = math.prod(out[0].shape)
for i in range(len(out)):
assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1"
assert (
out[i].size == size
), f"out[{i}].size={x.size}, but should be out[0].size={size}"
size_at = math.prod(out[i].shape)
assert size_at == out_size, (
f"prod(out[{i}].shape)={size_at}, "
f"but should be prod(out[0].shape)={out_size}"
)
ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype")
indices = []
if x.dtype == xp.bool:
Expand All @@ -107,11 +111,11 @@ def test_nonzero(x):
if x[idx] != 0:
indices.append(idx)
if x.ndim == 0:
assert out[0].size == len(
assert out_size == len(
indices
), f"{out[0].size=}, but should be {len(indices)}"
), f"prod(out[0].shape)={out_size}, but should be {len(indices)}"
else:
for i in range(size):
for i in range(out_size):
idx = tuple(int(x[i]) for x in out)
f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}"
f_element = f"x[{idx}]={x[idx]}"
Expand Down
8 changes: 4 additions & 4 deletions array_api_tests/test_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_unique_all(x):
vals_idx[val] = idx

if dh.is_float_dtype(out.values.dtype):
assume(x.size <= 128) # may not be representable
assume(math.prod(x.shape) <= 128) # may not be representable
expected = sum(v for k, v in counts.items() if math.isnan(k))
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"

Expand Down Expand Up @@ -157,7 +157,7 @@ def test_unique_counts(x):
), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
vals_idx[val] = idx
if dh.is_float_dtype(out.values.dtype):
assume(x.size <= 128) # may not be representable
assume(math.prod(x.shape) <= 128) # may not be representable
expected = sum(v for k, v in counts.items() if math.isnan(k))
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"

Expand Down Expand Up @@ -210,7 +210,7 @@ def test_unique_inverse(x):
else:
assert val == expected, msg
if dh.is_float_dtype(out.values.dtype):
assume(x.size <= 128) # may not be representable
assume(math.prod(x.shape) <= 128) # may not be representable
expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}"

Expand All @@ -234,6 +234,6 @@ def test_unique_values(x):
), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
vals_idx[val] = idx
if dh.is_float_dtype(out.dtype):
assume(x.size <= 128) # may not be representable
assume(math.prod(x.shape) <= 128) # may not be representable
expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
4 changes: 2 additions & 2 deletions array_api_tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_prod(x, data):
dtype=xps.floating_dtypes(),
shape=hh.shapes(min_side=1),
elements={"allow_nan": False},
).filter(lambda x: x.size >= 2),
).filter(lambda x: math.prod(x.shape) >= 2),
data=st.data(),
)
def test_std(x, data):
Expand Down Expand Up @@ -273,7 +273,7 @@ def test_sum(x, data):
dtype=xps.floating_dtypes(),
shape=hh.shapes(min_side=1),
elements={"allow_nan": False},
).filter(lambda x: x.size >= 2),
).filter(lambda x: math.prod(x.shape) >= 2),
data=st.data(),
)
def test_var(x, data):
Expand Down