Skip to content

Commit 216f8b7

Browse files
committed
Make hh.shapes a wrapper function, rudimentary test_concat
1 parent 730c1ce commit 216f8b7

File tree

8 files changed

+86
-60
lines changed

8 files changed

+86
-60
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,21 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
113113

114114
# Use this to avoid memory errors with NumPy.
115115
# See https://github.com/numpy/numpy/issues/15753
116-
shapes = xps.array_shapes(min_dims=0, min_side=0).filter(
117-
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
118-
)
116+
def shapes(**kw):
117+
if 'min_dims' not in kw.keys():
118+
kw['min_dims'] = 0
119+
if 'min_side' not in kw.keys():
120+
kw['min_side'] = 0
121+
return xps.array_shapes(**kw).filter(
122+
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
123+
)
124+
119125

120126
one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)
121127

122128
# Matrix shapes assume stacks of matrices
123129
@composite
124-
def matrix_shapes(draw, stack_shapes=shapes):
130+
def matrix_shapes(draw, stack_shapes=shapes()):
125131
stack_shape = draw(stack_shapes)
126132
mat_shape = draw(xps.array_shapes(max_dims=2, min_dims=2))
127133
shape = stack_shape + mat_shape
@@ -159,13 +165,13 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
159165
# using something like
160166
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
161167
n = draw(integers(0))
162-
shape = draw(shapes) + (n, n)
168+
shape = draw(shapes()) + (n, n)
163169
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
164170
dtype = draw(dtypes)
165171
return broadcast_to(eye(n, dtype=dtype), shape)
166172

167173
@composite
168-
def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes):
174+
def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes()):
169175
# For now, just generate stacks of diagonal matrices.
170176
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
171177
stack_shape = draw(stack_shapes)

array_api_tests/meta_tests/test_hypothesis_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def valid_shape(shape) -> bool:
3232
)
3333

3434

35-
@given(hh.shapes)
35+
@given(hh.shapes())
3636
def test_shapes(shape):
3737
assert valid_shape(shape)
3838

array_api_tests/test_broadcasting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_broadcast_shapes_explicit_spec():
110110
@pytest.mark.parametrize('func_name', [i for i in
111111
elementwise_functions.__all__ if
112112
nargs(i) > 1])
113-
@given(shape1=shapes, shape2=shapes, data=data())
113+
@given(shape1=shapes(), shape2=shapes(), data=data())
114114
def test_broadcasting_hypothesis(func_name, shape1, shape2, data):
115115
# Internal consistency checks
116116
assert nargs(func_name) == 2

array_api_tests/test_creation_functions.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_arange(start, stop, step, dtype):
7070
or step < 0 and stop <= start)):
7171
assert a.size == ceil(asarray((stop-start)/step)), "arange() produced an array of the incorrect size"
7272

73-
@given(shapes, kwargs(dtype=none() | shared_dtypes))
73+
@given(shapes(), kwargs(dtype=none() | shared_dtypes))
7474
def test_empty(shape, kw):
7575
out = empty(shape, **kw)
7676
dtype = kw.get("dtype", None) or xp.float64
@@ -84,7 +84,7 @@ def test_empty(shape, kw):
8484

8585

8686
@given(
87-
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shapes),
87+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shapes()),
8888
kw=kwargs(dtype=none() | xps.scalar_dtypes())
8989
)
9090
def test_empty_like(x, kw):
@@ -136,7 +136,7 @@ def full_fill_values(draw):
136136

137137

138138
@given(
139-
shape=shapes,
139+
shape=shapes(),
140140
fill_value=full_fill_values(),
141141
kw=shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_kw"),
142142
)
@@ -174,7 +174,7 @@ def full_like_fill_values(draw):
174174

175175

176176
@given(
177-
x=xps.arrays(dtype=shared_dtypes, shape=shapes),
177+
x=xps.arrays(dtype=shared_dtypes, shape=shapes()),
178178
fill_value=full_like_fill_values(),
179179
kw=shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_like_kw"),
180180
)
@@ -245,7 +245,7 @@ def make_one(dtype):
245245
return True
246246

247247

248-
@given(shapes, kwargs(dtype=none() | xps.scalar_dtypes()))
248+
@given(shapes(), kwargs(dtype=none() | xps.scalar_dtypes()))
249249
def test_ones(shape, kw):
250250
out = ones(shape, **kw)
251251
dtype = kw.get("dtype", None) or xp.float64
@@ -258,7 +258,7 @@ def test_ones(shape, kw):
258258

259259

260260
@given(
261-
x=xps.arrays(dtype=dtypes, shape=shapes),
261+
x=xps.arrays(dtype=dtypes, shape=shapes()),
262262
kw=kwargs(dtype=none() | xps.scalar_dtypes()),
263263
)
264264
def test_ones_like(x, kw):
@@ -281,7 +281,7 @@ def make_zero(dtype):
281281
return False
282282

283283

284-
@given(shapes, kwargs(dtype=none() | xps.scalar_dtypes()))
284+
@given(shapes(), kwargs(dtype=none() | xps.scalar_dtypes()))
285285
def test_zeros(shape, kw):
286286
out = zeros(shape, **kw)
287287
dtype = kw.get("dtype", None) or xp.float64
@@ -294,7 +294,7 @@ def test_zeros(shape, kw):
294294

295295

296296
@given(
297-
x=xps.arrays(dtype=dtypes, shape=shapes),
297+
x=xps.arrays(dtype=dtypes, shape=shapes()),
298298
kw=kwargs(dtype=none() | xps.scalar_dtypes()),
299299
)
300300
def test_zeros_like(x, kw):

array_api_tests/test_elementwise_functions.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def sanity_check(x1, x2):
4949
except ValueError:
5050
raise RuntimeError("Error in test generation (probably a bug in the test suite")
5151

52-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes))
52+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
5353
def test_abs(x):
5454
if dh.is_int_dtype(x.dtype):
5555
minval = dh.dtype_ranges[x.dtype][0]
@@ -66,7 +66,7 @@ def test_abs(x):
6666
# abs(x) = x for x >= 0
6767
ah.assert_exactly_equal(a[ah.logical_not(less_zero)], x[ah.logical_not(less_zero)])
6868

69-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
69+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
7070
def test_acos(x):
7171
a = xp.acos(x)
7272
ONE = ah.one(x.shape, x.dtype)
@@ -80,7 +80,7 @@ def test_acos(x):
8080
# nan, which is already tested in the special cases.
8181
ah.assert_exactly_equal(domain, codomain)
8282

83-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
83+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
8484
def test_acosh(x):
8585
a = xp.acosh(x)
8686
ONE = ah.one(x.shape, x.dtype)
@@ -102,7 +102,7 @@ def test_add(x1, x2):
102102
ah.assert_exactly_equal(a, b)
103103
# TODO: Test that add is actually addition
104104

105-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
105+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
106106
def test_asin(x):
107107
a = xp.asin(x)
108108
ONE = ah.one(x.shape, x.dtype)
@@ -113,7 +113,7 @@ def test_asin(x):
113113
# mapped to nan, which is already tested in the special cases.
114114
ah.assert_exactly_equal(domain, codomain)
115115

116-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
116+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
117117
def test_asinh(x):
118118
a = xp.asinh(x)
119119
INFINITY = ah.infinity(x.shape, x.dtype)
@@ -123,7 +123,7 @@ def test_asinh(x):
123123
# mapped to nan, which is already tested in the special cases.
124124
ah.assert_exactly_equal(domain, codomain)
125125

126-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
126+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
127127
def test_atan(x):
128128
a = xp.atan(x)
129129
INFINITY = ah.infinity(x.shape, x.dtype)
@@ -170,7 +170,7 @@ def test_atan2(x1, x2):
170170
ah.assert_exactly_equal(ah.logical_or(ah.logical_and(negx1, posx2),
171171
ah.logical_and(negx1, negx2)), nega)
172172

173-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
173+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
174174
def test_atanh(x):
175175
a = xp.atanh(x)
176176
ONE = ah.one(x.shape, x.dtype)
@@ -230,7 +230,7 @@ def test_bitwise_left_shift(x1, x2):
230230
vals_shift = ah.int_to_dtype(vals_shift, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype])
231231
assert vals_shift == res
232232

233-
@given(xps.arrays(dtype=hh.integer_or_boolean_dtypes, shape=hh.shapes))
233+
@given(xps.arrays(dtype=hh.integer_or_boolean_dtypes, shape=hh.shapes()))
234234
def test_bitwise_invert(x):
235235
out = xp.bitwise_invert(x)
236236
# Compare against the Python ~ operator.
@@ -322,7 +322,7 @@ def test_bitwise_xor(x1, x2):
322322
vals_xor = ah.int_to_dtype(vals_xor, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype])
323323
assert vals_xor == res
324324

325-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes))
325+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
326326
def test_ceil(x):
327327
# This test is almost identical to test_floor()
328328
a = xp.ceil(x)
@@ -333,7 +333,7 @@ def test_ceil(x):
333333
integers = ah.isintegral(x)
334334
ah.assert_exactly_equal(a[integers], x[integers])
335335

336-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
336+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
337337
def test_cos(x):
338338
a = xp.cos(x)
339339
ONE = ah.one(x.shape, x.dtype)
@@ -344,7 +344,7 @@ def test_cos(x):
344344
# to nan, which is already tested in the special cases.
345345
ah.assert_exactly_equal(domain, codomain)
346346

347-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
347+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
348348
def test_cosh(x):
349349
a = xp.cosh(x)
350350
INFINITY = ah.infinity(x.shape, x.dtype)
@@ -379,7 +379,7 @@ def test_equal(x1, x2):
379379

380380
# First we broadcast the arrays so that they can be indexed uniformly.
381381
# TODO: it should be possible to skip this step if we instead generate
382-
# indices to x1 and x2 that correspond to the broadcasted hh.shapes. This
382+
# indices to x1 and x2 that correspond to the broadcasted shapes. This
383383
# would avoid the dependence in this test on broadcast_to().
384384
shape = broadcast_shapes(x1.shape, x2.shape)
385385
_x1 = xp.broadcast_to(x1, shape)
@@ -414,7 +414,7 @@ def test_equal(x1, x2):
414414
assert aidx.shape == x1idx.shape == x2idx.shape
415415
assert bool(aidx) == (scalar_func(x1idx) == scalar_func(x2idx))
416416

417-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
417+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
418418
def test_exp(x):
419419
a = xp.exp(x)
420420
INFINITY = ah.infinity(x.shape, x.dtype)
@@ -425,7 +425,7 @@ def test_exp(x):
425425
# mapped to nan, which is already tested in the special cases.
426426
ah.assert_exactly_equal(domain, codomain)
427427

428-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
428+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
429429
def test_expm1(x):
430430
a = xp.expm1(x)
431431
INFINITY = ah.infinity(x.shape, x.dtype)
@@ -436,7 +436,7 @@ def test_expm1(x):
436436
# mapped to nan, which is already tested in the special cases.
437437
ah.assert_exactly_equal(domain, codomain)
438438

439-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes))
439+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
440440
def test_floor(x):
441441
# This test is almost identical to test_ceil
442442
a = xp.floor(x)
@@ -529,7 +529,7 @@ def test_greater_equal(x1, x2):
529529
assert aidx.shape == x1idx.shape == x2idx.shape
530530
assert bool(aidx) == (scalar_func(x1idx) >= scalar_func(x2idx))
531531

532-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes))
532+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
533533
def test_isfinite(x):
534534
a = ah.isfinite(x)
535535
TRUE = ah.true(x.shape)
@@ -545,7 +545,7 @@ def test_isfinite(x):
545545
s = float(x[idx])
546546
assert bool(a[idx]) == math.isfinite(s)
547547

548-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes))
548+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
549549
def test_isinf(x):
550550
a = xp.isinf(x)
551551
FALSE = ah.false(x.shape)
@@ -560,7 +560,7 @@ def test_isinf(x):
560560
s = float(x[idx])
561561
assert bool(a[idx]) == math.isinf(s)
562562

563-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes))
563+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
564564
def test_isnan(x):
565565
a = ah.isnan(x)
566566
FALSE = ah.false(x.shape)
@@ -633,7 +633,7 @@ def test_less_equal(x1, x2):
633633
assert aidx.shape == x1idx.shape == x2idx.shape
634634
assert bool(aidx) == (scalar_func(x1idx) <= scalar_func(x2idx))
635635

636-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
636+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
637637
def test_log(x):
638638
a = xp.log(x)
639639
INFINITY = ah.infinity(x.shape, x.dtype)
@@ -644,7 +644,7 @@ def test_log(x):
644644
# mapped to nan, which is already tested in the special cases.
645645
ah.assert_exactly_equal(domain, codomain)
646646

647-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
647+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
648648
def test_log1p(x):
649649
a = xp.log1p(x)
650650
INFINITY = ah.infinity(x.shape, x.dtype)
@@ -655,7 +655,7 @@ def test_log1p(x):
655655
# mapped to nan, which is already tested in the special cases.
656656
ah.assert_exactly_equal(domain, codomain)
657657

658-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
658+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
659659
def test_log2(x):
660660
a = xp.log2(x)
661661
INFINITY = ah.infinity(x.shape, x.dtype)
@@ -666,7 +666,7 @@ def test_log2(x):
666666
# mapped to nan, which is already tested in the special cases.
667667
ah.assert_exactly_equal(domain, codomain)
668668

669-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes))
669+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
670670
def test_log10(x):
671671
a = xp.log10(x)
672672
INFINITY = ah.infinity(x.shape, x.dtype)
@@ -698,7 +698,7 @@ def test_logical_and(x1, x2):
698698
for idx in ah.ndindex(shape):
699699
assert a[idx] == (bool(_x1[idx]) and bool(_x2[idx]))
700700

701-
@given(xps.arrays(dtype=xp.bool, shape=hh.shapes))
701+
@given(xps.arrays(dtype=xp.bool, shape=hh.shapes()))
702702
def test_logical_not(x):
703703
a = ah.logical_not(x)
704704

@@ -740,7 +740,7 @@ def test_multiply(x1, x2):
740740
# multiply is commutative
741741
ah.assert_exactly_equal(a, b)
742742

743-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes))
743+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
744744
def test_negative(x):
745745
out = ah.negative(x)
746746

@@ -790,7 +790,7 @@ def test_not_equal(x1, x2):
790790
assert bool(aidx) == (scalar_func(x1idx) != scalar_func(x2idx))
791791

792792

793-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes))
793+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
794794
def test_positive(x):
795795
out = xp.positive(x)
796796
# Positive does nothing
@@ -817,7 +817,7 @@ def test_remainder(x1, x2):
817817
not_nan = ah.logical_not(ah.logical_or(ah.isnan(out), ah.isnan(x2)))
818818
ah.assert_same_sign(out[not_nan], x2[not_nan])
819819

820-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes))
820+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
821821
def test_round(x):
822822
a = xp.round(x)
823823

array_api_tests/test_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def test_slicing(size, s):
5858
for i in range(len(sliced_list)):
5959
assert sliced_array[i] == sliced_list[i], "Slice index did not give the same elements as slicing an equivalent Python list"
6060

61-
@given(shared(shapes, key='array_shapes'),
62-
multiaxis_indices(shapes=shared(shapes, key='array_shapes')))
61+
@given(shared(shapes(), key='array_shapes'),
62+
multiaxis_indices(shapes=shared(shapes(), key='array_shapes')))
6363
def test_multiaxis_indexing(shape, idx):
6464
# NOTE: Out of bounds indices (both integer and slices) are out of scope
6565
# for the spec. If you get a (valid) out of bounds error, it indicates a

0 commit comments

Comments
 (0)