Skip to content

Commit 523bf4c

Browse files
committed
Merge remote-tracking branch 'upstream/master' into guard-tests-no-complex-dtypes
2 parents 462d0d3 + f1c3ed2 commit 523bf4c

File tree

7 files changed

+153
-48
lines changed

7 files changed

+153
-48
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
6464

6565

6666
@wraps(xps.arrays)
67-
def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
67+
def arrays_no_scalars(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
6868
"""xps.arrays() without the crazy large numbers."""
6969
if isinstance(dtype, SearchStrategy):
7070
return dtype.flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
@@ -77,6 +77,19 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
7777
return xps.arrays(dtype, *args, elements=elements, **kwargs)
7878

7979

80+
def _f(a, flag):
81+
return a[()] if a.ndim==0 and flag else a
82+
83+
84+
@wraps(xps.arrays)
85+
def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
86+
"""xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars.
87+
88+
Is only relevant for numpy: on all other libraries, array[()] is no-op.
89+
"""
90+
return builds(_f, arrays_no_scalars(dtype, *args, elements=elements, **kwargs), booleans())
91+
92+
8093
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
8194
_sorted_dtypes = [d for category in _dtype_categories for d in category]
8295

@@ -232,6 +245,68 @@ def shapes(**kw):
232245
lambda shape: math.prod(i for i in shape if i) < MAX_ARRAY_SIZE
233246
)
234247

248+
def _factorize(n: int) -> List[int]:
249+
# Simple prime factorization. Only needs to handle n ~ MAX_ARRAY_SIZE
250+
factors = []
251+
while n % 2 == 0:
252+
factors.append(2)
253+
n //= 2
254+
255+
for i in range(3, int(math.sqrt(n)) + 1, 2):
256+
while n % i == 0:
257+
factors.append(i)
258+
n //= i
259+
260+
if n > 1: # n is a prime number greater than 2
261+
factors.append(n)
262+
263+
return factors
264+
265+
MAX_SIDE = MAX_ARRAY_SIZE // 64
266+
# NumPy only supports up to 32 dims. TODO: Get this from the new inspection APIs
267+
MAX_DIMS = min(MAX_ARRAY_SIZE // MAX_SIDE, 32)
268+
269+
270+
@composite
271+
def reshape_shapes(draw, arr_shape, ndims=integers(1, MAX_DIMS)):
272+
"""
273+
Generate shape tuples whose product equals the product of array_shape.
274+
"""
275+
shape = draw(arr_shape)
276+
277+
array_size = math.prod(shape)
278+
279+
n_dims = draw(ndims)
280+
281+
# Handle special cases
282+
if array_size == 0:
283+
# Generate a random tuple, and ensure at least one of the entries is 0
284+
result = list(draw(shapes(min_dims=n_dims, max_dims=n_dims)))
285+
pos = draw(integers(0, n_dims - 1))
286+
result[pos] = 0
287+
return tuple(result)
288+
289+
if array_size == 1:
290+
return tuple(1 for _ in range(n_dims))
291+
292+
# Get prime factorization
293+
factors = _factorize(array_size)
294+
295+
# Distribute prime factors randomly
296+
result = [1] * n_dims
297+
for factor in factors:
298+
pos = draw(integers(0, n_dims - 1))
299+
result[pos] *= factor
300+
301+
assert math.prod(result) == array_size
302+
303+
# An element of the reshape tuple can be -1, which means it is a stand-in
304+
# for the remaining factors.
305+
if draw(booleans()):
306+
pos = draw(integers(0, n_dims - 1))
307+
result[pos] = -1
308+
309+
return tuple(result)
235310

236311
one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)
237312

array_api_tests/test_creation_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,8 @@ def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
263263
data=st.data(),
264264
)
265265
def test_asarray_arrays(shape, dtypes, data):
266-
x = data.draw(hh.arrays(dtype=dtypes.input_dtype, shape=shape), label="x")
266+
# generate arrays only since we draw the copy= kwd below (and np.asarray(scalar, copy=False) error out)
267+
x = data.draw(hh.arrays_no_scalars(dtype=dtypes.input_dtype, shape=shape), label="x")
267268
dtypes_strat = st.just(dtypes.input_dtype)
268269
if dtypes.input_dtype == dtypes.result_dtype:
269270
dtypes_strat |= st.none()

array_api_tests/test_data_type_functions.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Union
33

44
import pytest
5-
from hypothesis import given
5+
from hypothesis import given, assume
66
from hypothesis import strategies as st
77

88
from . import _array_module as xp
@@ -23,26 +23,43 @@ def float32(n: Union[int, float]) -> float:
2323
return struct.unpack("!f", struct.pack("!f", float(n)))[0]
2424

2525

26+
def _float_match_complex(complex_dtype):
27+
return xp.float32 if complex_dtype == xp.complex64 else xp.float64
28+
29+
2630
@given(
27-
x_dtype=non_complex_dtypes(),
28-
dtype=non_complex_dtypes(),
31+
x_dtype=hh.all_dtypes,
32+
dtype=hh.all_dtypes,
2933
kw=hh.kwargs(copy=st.booleans()),
3034
data=st.data(),
3135
)
3236
def test_astype(x_dtype, dtype, kw, data):
37+
_complex_dtypes = (xp.complex64, xp.complex128)
38+
3339
if xp.bool in (x_dtype, dtype):
3440
elements_strat = hh.from_dtype(x_dtype)
3541
else:
36-
m1, M1 = dh.dtype_ranges[x_dtype]
37-
m2, M2 = dh.dtype_ranges[dtype]
42+
3843
if dh.is_int_dtype(x_dtype):
3944
cast = int
40-
elif x_dtype == xp.float32:
45+
elif x_dtype in (xp.float32, xp.complex64):
4146
cast = float32
4247
else:
4348
cast = float
49+
50+
real_dtype = x_dtype
51+
if x_dtype in _complex_dtypes:
52+
real_dtype = _float_match_complex(x_dtype)
53+
m1, M1 = dh.dtype_ranges[real_dtype]
54+
55+
real_dtype = dtype
56+
if dtype in _complex_dtypes:
57+
real_dtype = _float_match_complex(x_dtype)
58+
m2, M2 = dh.dtype_ranges[real_dtype]
59+
4460
min_value = cast(max(m1, m2))
4561
max_value = cast(min(M1, M2))
62+
4663
elements_strat = hh.from_dtype(
4764
x_dtype,
4865
min_value=min_value,
@@ -54,6 +71,11 @@ def test_astype(x_dtype, dtype, kw, data):
5471
hh.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
5572
)
5673

74+
# according to the spec, "Casting a complex floating-point array to a real-valued
75+
# data type should not be permitted."
76+
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
77+
assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes)))
78+
5779
out = xp.astype(x, dtype, **kw)
5880

5981
ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype)

array_api_tests/test_fft.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def test_fft(x, data):
120120
assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out)
121121

122122

123+
if hh.complex_dtypes:
123124
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
124125
def test_ifft(x, data):
125126
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
@@ -130,6 +131,7 @@ def test_ifft(x, data):
130131
assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out)
131132

132133

134+
if hh.complex_dtypes:
133135
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
134136
def test_fftn(x, data):
135137
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
@@ -140,6 +142,7 @@ def test_fftn(x, data):
140142
assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out)
141143

142144

145+
if hh.complex_dtypes:
143146
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
144147
def test_ifftn(x, data):
145148
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
@@ -230,21 +233,20 @@ def test_irfftn(x, data):
230233
expected=dh.dtype_components[x.dtype],
231234
)
232235

233-
# TODO: assert shape correctly
234-
# _axes = sh.normalize_axis(axes, x.ndim)
235-
# _s = x.shape if s is None else s
236-
# expected = []
237-
# for i in range(x.ndim):
238-
# if i in _axes:
239-
# side = _s[_axes.index(i)]
240-
# else:
241-
# side = x.shape[i]
242-
# expected.append(side)
243-
# last_axis = max(_axes)
244-
# expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1
245-
# ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
236+
_axes = sh.normalize_axis(axes, x.ndim)
237+
_s = x.shape if s is None else s
238+
expected = []
239+
for i in range(x.ndim):
240+
if i in _axes:
241+
side = _s[_axes.index(i)]
242+
else:
243+
side = x.shape[i]
244+
expected.append(side)
245+
expected[_axes[-1]] = 2*(_s[-1] - 1) if s is None else _s[-1]
246+
ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
246247

247248

249+
if hh.complex_dtypes:
248250
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
249251
def test_hfft(x, data):
250252
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)

array_api_tests/test_linalg.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from . import _array_module as xp
4646
from ._array_module import linalg
4747

48+
4849
def assert_equal(x, y, msg_extra=None):
4950
extra = '' if not msg_extra else f' ({msg_extra})'
5051
if x.dtype in dh.all_float_dtypes:
@@ -60,6 +61,7 @@ def assert_equal(x, y, msg_extra=None):
6061
else:
6162
assert_exactly_equal(x, y, msg_extra=msg_extra)
6263

64+
6365
def _test_stacks(f, *args, res=None, dims=2, true_val=None,
6466
matrix_axes=(-2, -1),
6567
res_axes=None,
@@ -106,6 +108,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
106108
if true_val:
107109
assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra)
108110

111+
109112
def _test_namedtuple(res, fields, func_name):
110113
"""
111114
Test that res is a namedtuple with the correct fields.
@@ -121,6 +124,7 @@ def _test_namedtuple(res, fields, func_name):
121124
assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field"
122125
assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}"
123126

127+
124128
@pytest.mark.unvectorized
125129
@pytest.mark.xp_extension('linalg')
126130
@given(
@@ -901,6 +905,15 @@ def true_trace(x_stack, offset=0):
901905

902906
_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)
903907

908+
909+
def _conj(x):
910+
# XXX: replace with xp.dtype when all array libraries implement it
911+
if x.dtype in (xp.complex64, xp.complex128):
912+
return xp.conj(x)
913+
else:
914+
return x
915+
916+
904917
def _test_vecdot(namespace, x1, x2, data):
905918
vecdot = namespace.vecdot
906919
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
@@ -925,11 +938,8 @@ def _test_vecdot(namespace, x1, x2, data):
925938
ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape],
926939
out_shape=res.shape, expected=expected_shape)
927940

928-
if x1.dtype in dh.int_dtypes:
929-
def true_val(x, y, axis=-1):
930-
return xp.sum(xp.multiply(x, y), dtype=res.dtype)
931-
else:
932-
true_val = None
941+
def true_val(x, y, axis=-1):
942+
return xp.sum(xp.multiply(_conj(x), y), dtype=res.dtype)
933943

934944
_test_stacks(vecdot, x1, x2, res=res, dims=0,
935945
matrix_axes=(axis,), true_val=true_val)
@@ -944,6 +954,7 @@ def true_val(x, y, axis=-1):
944954
def test_linalg_vecdot(x1, x2, data):
945955
_test_vecdot(linalg, x1, x2, data)
946956

957+
947958
@pytest.mark.unvectorized
948959
@given(
949960
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
@@ -952,10 +963,12 @@ def test_linalg_vecdot(x1, x2, data):
952963
def test_vecdot(x1, x2, data):
953964
_test_vecdot(_array_module, x1, x2, data)
954965

966+
955967
# Insanely large orders might not work. There isn't a limit specified in the
956968
# spec, so we just limit to reasonable values here.
957969
max_ord = 100
958970

971+
959972
@pytest.mark.unvectorized
960973
@pytest.mark.xp_extension('linalg')
961974
@given(

array_api_tests/test_manipulation_functions.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
from . import xps
1515
from .typing import Array, Shape
1616

17-
MAX_SIDE = hh.MAX_ARRAY_SIZE // 64
18-
MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims
19-
2017

2118
def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
2219
key = "shape"
@@ -66,7 +63,7 @@ def test_concat(dtypes, base_shape, data):
6663
shape_strat = hh.shapes()
6764
else:
6865
_axis = axis if axis >= 0 else len(base_shape) + axis
69-
shape_strat = st.integers(0, MAX_SIDE).map(
66+
shape_strat = st.integers(0, hh.MAX_SIDE).map(
7067
lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :]
7168
)
7269
arrays = []
@@ -348,26 +345,14 @@ def test_repeat(x, kw, data):
348345
kw=kw)
349346
start = end
350347

351-
@st.composite
352-
def reshape_shapes(draw, shape):
353-
size = 1 if len(shape) == 0 else math.prod(shape)
354-
rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size))
355-
assume(all(side <= MAX_SIDE for side in rshape))
356-
if len(rshape) != 0 and size > 0 and draw(st.booleans()):
357-
index = draw(st.integers(0, len(rshape) - 1))
358-
rshape[index] = -1
359-
return tuple(rshape)
360-
348+
reshape_shape = st.shared(hh.shapes(), key="reshape_shape")
361349

362350
@pytest.mark.unvectorized
363-
@pytest.mark.skip("flaky") # TODO: fix!
364351
@given(
365-
x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(max_side=MAX_SIDE)),
366-
data=st.data(),
352+
x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),
353+
shape=hh.reshape_shapes(reshape_shape),
367354
)
368-
def test_reshape(x, data):
369-
shape = data.draw(reshape_shapes(x.shape))
370-
355+
def test_reshape(x, shape):
371356
out = xp.reshape(x, shape)
372357

373358
ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype)

array_api_tests/test_special_cases.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
from decimal import ROUND_HALF_EVEN, Decimal
2121
from enum import Enum, auto
2222
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Literal
23-
from warnings import warn
23+
from warnings import warn, filterwarnings, catch_warnings
2424

2525
import pytest
2626
from hypothesis import given, note, settings, assume
2727
from hypothesis import strategies as st
28+
from hypothesis.errors import NonInteractiveExampleWarning
2829

2930
from array_api_tests.typing import Array, DataType
3031

@@ -1250,7 +1251,13 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
12501251

12511252
@pytest.mark.parametrize("func_name, func, case", unary_params)
12521253
def test_unary(func_name, func, case):
1253-
in_value = case.cond_from_dtype(xp.float64).example()
1254+
with catch_warnings():
1255+
# XXX: We are using example here to generate one example draw, but
1256+
# hypothesis issues a warning from this. We should consider either
1257+
# drawing multiple examples like a normal test, or just hard-coding a
1258+
# single example test case without using hypothesis.
1259+
filterwarnings('ignore', category=NonInteractiveExampleWarning)
1260+
in_value = case.cond_from_dtype(xp.float64).example()
12541261
x = xp.asarray(in_value, dtype=xp.float64)
12551262
out = func(x)
12561263
out_value = float(out)

0 commit comments

Comments
 (0)