Skip to content

Commit ca2ef81

Browse files
committed
Only generate 1D arrays in test_meshgrid, prevent memory errors
1 parent aaf0a7d commit ca2ef81

File tree

1 file changed

+84
-75
lines changed

1 file changed

+84
-75
lines changed
Lines changed: 84 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
33
"""
4+
import math
45
from collections import defaultdict
56
from typing import Tuple, Union, List
67

@@ -24,22 +25,30 @@
2425
@given(hh.mutually_promotable_dtypes(None))
2526
def test_result_type(dtypes):
2627
out = xp.result_type(*dtypes)
27-
ph.assert_dtype('result_type', dtypes, out, out_name='out')
28+
ph.assert_dtype("result_type", dtypes, out, out_name="out")
2829

2930

31+
# The number and size of generated arrays is arbitrarily limited to prevent
32+
# meshgrid() running out of memory.
3033
@given(
31-
dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
34+
dtypes=hh.mutually_promotable_dtypes(5, dtypes=dh.numeric_dtypes),
3235
data=st.data(),
3336
)
3437
def test_meshgrid(dtypes, data):
3538
arrays = []
36-
shapes = data.draw(hh.mutually_broadcastable_shapes(len(dtypes)), label='shapes')
39+
shapes = data.draw(
40+
hh.mutually_broadcastable_shapes(
41+
len(dtypes), min_dims=1, max_dims=1, max_side=5
42+
),
43+
label="shapes",
44+
)
3745
for i, (dtype, shape) in enumerate(zip(dtypes, shapes), 1):
38-
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
46+
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
3947
arrays.append(x)
48+
assert math.prod(x.size for x in arrays) <= hh.MAX_ARRAY_SIZE # sanity check
4049
out = xp.meshgrid(*arrays)
4150
for i, x in enumerate(out):
42-
ph.assert_dtype('meshgrid', dtypes, x.dtype, out_name=f'out[{i}].dtype')
51+
ph.assert_dtype("meshgrid", dtypes, x.dtype, out_name=f"out[{i}].dtype")
4352

4453

4554
@given(
@@ -50,10 +59,10 @@ def test_meshgrid(dtypes, data):
5059
def test_concat(shape, dtypes, data):
5160
arrays = []
5261
for i, dtype in enumerate(dtypes, 1):
53-
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
62+
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
5463
arrays.append(x)
5564
out = xp.concat(arrays)
56-
ph.assert_dtype('concat', dtypes, out.dtype)
65+
ph.assert_dtype("concat", dtypes, out.dtype)
5766

5867

5968
@given(
@@ -64,26 +73,26 @@ def test_concat(shape, dtypes, data):
6473
def test_stack(shape, dtypes, data):
6574
arrays = []
6675
for i, dtype in enumerate(dtypes, 1):
67-
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
76+
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
6877
arrays.append(x)
6978
out = xp.stack(arrays)
70-
ph.assert_dtype('stack', dtypes, out.dtype)
79+
ph.assert_dtype("stack", dtypes, out.dtype)
7180

7281

7382
bitwise_shift_funcs = [
74-
'bitwise_left_shift',
75-
'bitwise_right_shift',
76-
'__lshift__',
77-
'__rshift__',
78-
'__ilshift__',
79-
'__irshift__',
83+
"bitwise_left_shift",
84+
"bitwise_right_shift",
85+
"__lshift__",
86+
"__rshift__",
87+
"__ilshift__",
88+
"__irshift__",
8089
]
8190

8291

8392
# We pass kwargs to the elements strategy used by xps.arrays() so that we don't
8493
# generate array elements that are erroneous or undefined for a function.
8594
func_elements = defaultdict(
86-
lambda: None, {func: {'min_value': 1} for func in bitwise_shift_funcs}
95+
lambda: None, {func: {"min_value": 1} for func in bitwise_shift_funcs}
8796
)
8897

8998

@@ -94,7 +103,7 @@ def make_id(
94103
) -> str:
95104
f_args = dh.fmt_types(in_dtypes)
96105
f_out_dtype = dh.dtype_to_name[out_dtype]
97-
return f'{func_name}({f_args}) -> {f_out_dtype}'
106+
return f"{func_name}({f_args}) -> {f_out_dtype}"
98107

99108

100109
func_params: List[Param[str, Tuple[DataType, ...], DataType]] = []
@@ -128,25 +137,25 @@ def make_id(
128137
raise NotImplementedError()
129138

130139

131-
@pytest.mark.parametrize('func_name, in_dtypes, out_dtype', func_params)
140+
@pytest.mark.parametrize("func_name, in_dtypes, out_dtype", func_params)
132141
@given(data=st.data())
133142
def test_func_promotion(func_name, in_dtypes, out_dtype, data):
134143
func = getattr(xp, func_name)
135144
elements = func_elements[func_name]
136145
if len(in_dtypes) == 1:
137146
x = data.draw(
138147
xps.arrays(dtype=in_dtypes[0], shape=hh.shapes(), elements=elements),
139-
label='x',
148+
label="x",
140149
)
141150
out = func(x)
142151
else:
143152
arrays = []
144153
shapes = data.draw(
145-
hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes'
154+
hh.mutually_broadcastable_shapes(len(in_dtypes)), label="shapes"
146155
)
147156
for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1):
148157
x = data.draw(
149-
xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f'x{i}'
158+
xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f"x{i}"
150159
)
151160
arrays.append(x)
152161
try:
@@ -161,46 +170,46 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
161170
p = pytest.param(
162171
(dtype1, dtype2),
163172
promoted_dtype,
164-
id=make_id('', (dtype1, dtype2), promoted_dtype),
173+
id=make_id("", (dtype1, dtype2), promoted_dtype),
165174
)
166175
promotion_params.append(p)
167176

168177

169-
@pytest.mark.parametrize('in_dtypes, out_dtype', promotion_params)
178+
@pytest.mark.parametrize("in_dtypes, out_dtype", promotion_params)
170179
@given(shapes=hh.mutually_broadcastable_shapes(3), data=st.data())
171180
def test_where(in_dtypes, out_dtype, shapes, data):
172-
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
173-
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
174-
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label='condition')
181+
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1")
182+
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2")
183+
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label="condition")
175184
out = xp.where(cond, x1, x2)
176-
ph.assert_dtype('where', in_dtypes, out.dtype, out_dtype)
185+
ph.assert_dtype("where", in_dtypes, out.dtype, out_dtype)
177186

178187

179188
numeric_promotion_params = promotion_params[1:]
180189

181190

182-
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
191+
@pytest.mark.parametrize("in_dtypes, out_dtype", numeric_promotion_params)
183192
@given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=2), data=st.data())
184193
def test_tensordot(in_dtypes, out_dtype, shapes, data):
185-
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
186-
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
194+
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1")
195+
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2")
187196
out = xp.tensordot(x1, x2)
188-
ph.assert_dtype('tensordot', in_dtypes, out.dtype, out_dtype)
197+
ph.assert_dtype("tensordot", in_dtypes, out.dtype, out_dtype)
189198

190199

191-
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
200+
@pytest.mark.parametrize("in_dtypes, out_dtype", numeric_promotion_params)
192201
@given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=1), data=st.data())
193202
def test_vecdot(in_dtypes, out_dtype, shapes, data):
194-
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
195-
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
203+
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1")
204+
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2")
196205
out = xp.vecdot(x1, x2)
197-
ph.assert_dtype('vecdot', in_dtypes, out.dtype, out_dtype)
206+
ph.assert_dtype("vecdot", in_dtypes, out.dtype, out_dtype)
198207

199208

200209
op_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
201210
op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol}
202211
for op, symbol in op_to_symbol.items():
203-
if op == '__matmul__':
212+
if op == "__matmul__":
204213
continue
205214
valid_in_dtypes = dh.func_in_dtypes[op]
206215
ndtypes = ph.nargs(op)
@@ -209,7 +218,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
209218
out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype
210219
p = pytest.param(
211220
op,
212-
f'{symbol}x',
221+
f"{symbol}x",
213222
(in_dtype,),
214223
out_dtype,
215224
id=make_id(op, (in_dtype,), out_dtype),
@@ -221,42 +230,42 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
221230
out_dtype = xp.bool if dh.func_returns_bool[op] else promoted_dtype
222231
p = pytest.param(
223232
op,
224-
f'x1 {symbol} x2',
233+
f"x1 {symbol} x2",
225234
(in_dtype1, in_dtype2),
226235
out_dtype,
227236
id=make_id(op, (in_dtype1, in_dtype2), out_dtype),
228237
)
229238
op_params.append(p)
230239
# We generate params for abs seperately as it does not have an associated symbol
231-
for in_dtype in dh.func_in_dtypes['__abs__']:
240+
for in_dtype in dh.func_in_dtypes["__abs__"]:
232241
p = pytest.param(
233-
'__abs__',
234-
'abs(x)',
242+
"__abs__",
243+
"abs(x)",
235244
(in_dtype,),
236245
in_dtype,
237-
id=make_id('__abs__', (in_dtype,), in_dtype),
246+
id=make_id("__abs__", (in_dtype,), in_dtype),
238247
)
239248
op_params.append(p)
240249

241250

242-
@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', op_params)
251+
@pytest.mark.parametrize("op, expr, in_dtypes, out_dtype", op_params)
243252
@given(data=st.data())
244253
def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
245254
elements = func_elements[func_name]
246255
if len(in_dtypes) == 1:
247256
x = data.draw(
248257
xps.arrays(dtype=in_dtypes[0], shape=hh.shapes(), elements=elements),
249-
label='x',
258+
label="x",
250259
)
251-
out = eval(expr, {'x': x})
260+
out = eval(expr, {"x": x})
252261
else:
253262
locals_ = {}
254263
shapes = data.draw(
255-
hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes'
264+
hh.mutually_broadcastable_shapes(len(in_dtypes)), label="shapes"
256265
)
257266
for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1):
258-
locals_[f'x{i}'] = data.draw(
259-
xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f'x{i}'
267+
locals_[f"x{i}"] = data.draw(
268+
xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f"x{i}"
260269
)
261270
try:
262271
out = eval(expr, locals_)
@@ -267,7 +276,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
267276

268277
inplace_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
269278
for op, symbol in dh.inplace_op_to_symbol.items():
270-
if op == '__imatmul__':
279+
if op == "__imatmul__":
271280
continue
272281
valid_in_dtypes = dh.func_in_dtypes[op]
273282
for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items():
@@ -278,44 +287,44 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
278287
):
279288
p = pytest.param(
280289
op,
281-
f'x1 {symbol} x2',
290+
f"x1 {symbol} x2",
282291
(in_dtype1, in_dtype2),
283292
promoted_dtype,
284293
id=make_id(op, (in_dtype1, in_dtype2), promoted_dtype),
285294
)
286295
inplace_params.append(p)
287296

288297

289-
@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', inplace_params)
298+
@pytest.mark.parametrize("op, expr, in_dtypes, out_dtype", inplace_params)
290299
@given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data())
291300
def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
292301
assume(len(shapes[0]) >= len(shapes[1]))
293302
elements = func_elements[func_name]
294303
x1 = data.draw(
295-
xps.arrays(dtype=in_dtypes[0], shape=shapes[0], elements=elements), label='x1'
304+
xps.arrays(dtype=in_dtypes[0], shape=shapes[0], elements=elements), label="x1"
296305
)
297306
x2 = data.draw(
298-
xps.arrays(dtype=in_dtypes[1], shape=shapes[1], elements=elements), label='x2'
307+
xps.arrays(dtype=in_dtypes[1], shape=shapes[1], elements=elements), label="x2"
299308
)
300-
locals_ = {'x1': x1, 'x2': x2}
309+
locals_ = {"x1": x1, "x2": x2}
301310
try:
302311
exec(expr, locals_)
303312
except OverflowError:
304313
reject()
305-
x1 = locals_['x1']
306-
ph.assert_dtype(op, in_dtypes, x1.dtype, out_dtype, out_name='x1.dtype')
314+
x1 = locals_["x1"]
315+
ph.assert_dtype(op, in_dtypes, x1.dtype, out_dtype, out_name="x1.dtype")
307316

308317

309318
op_scalar_params: List[Param[str, str, DataType, ScalarType, DataType]] = []
310319
for op, symbol in dh.binary_op_to_symbol.items():
311-
if op == '__matmul__':
320+
if op == "__matmul__":
312321
continue
313322
for in_dtype in dh.func_in_dtypes[op]:
314323
out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype
315324
for in_stype in dh.dtype_to_scalars[in_dtype]:
316325
p = pytest.param(
317326
op,
318-
f'x {symbol} s',
327+
f"x {symbol} s",
319328
in_dtype,
320329
in_stype,
321330
out_dtype,
@@ -324,57 +333,57 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
324333
op_scalar_params.append(p)
325334

326335

327-
@pytest.mark.parametrize('op, expr, in_dtype, in_stype, out_dtype', op_scalar_params)
336+
@pytest.mark.parametrize("op, expr, in_dtype, in_stype, out_dtype", op_scalar_params)
328337
@given(data=st.data())
329338
def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
330339
elements = func_elements[func_name]
331-
kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')}
332-
s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label='scalar')
340+
kw = {k: in_stype is float for k in ("allow_nan", "allow_infinity")}
341+
s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label="scalar")
333342
x = data.draw(
334-
xps.arrays(dtype=in_dtype, shape=hh.shapes(), elements=elements), label='x'
343+
xps.arrays(dtype=in_dtype, shape=hh.shapes(), elements=elements), label="x"
335344
)
336345
try:
337-
out = eval(expr, {'x': x, 's': s})
346+
out = eval(expr, {"x": x, "s": s})
338347
except OverflowError:
339348
reject()
340349
ph.assert_dtype(op, (in_dtype, in_stype), out.dtype, out_dtype)
341350

342351

343352
inplace_scalar_params: List[Param[str, str, DataType, ScalarType]] = []
344353
for op, symbol in dh.inplace_op_to_symbol.items():
345-
if op == '__imatmul__':
354+
if op == "__imatmul__":
346355
continue
347356
for dtype in dh.func_in_dtypes[op]:
348357
for in_stype in dh.dtype_to_scalars[dtype]:
349358
p = pytest.param(
350359
op,
351-
f'x {symbol} s',
360+
f"x {symbol} s",
352361
dtype,
353362
in_stype,
354363
id=make_id(op, (dtype, in_stype), dtype),
355364
)
356365
inplace_scalar_params.append(p)
357366

358367

359-
@pytest.mark.parametrize('op, expr, dtype, in_stype', inplace_scalar_params)
368+
@pytest.mark.parametrize("op, expr, dtype, in_stype", inplace_scalar_params)
360369
@given(data=st.data())
361370
def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
362371
elements = func_elements[func_name]
363-
kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')}
364-
s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label='scalar')
372+
kw = {k: in_stype is float for k in ("allow_nan", "allow_infinity")}
373+
s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label="scalar")
365374
x = data.draw(
366-
xps.arrays(dtype=dtype, shape=hh.shapes(), elements=elements), label='x'
375+
xps.arrays(dtype=dtype, shape=hh.shapes(), elements=elements), label="x"
367376
)
368-
locals_ = {'x': x, 's': s}
377+
locals_ = {"x": x, "s": s}
369378
try:
370379
exec(expr, locals_)
371380
except OverflowError:
372381
reject()
373-
x = locals_['x']
374-
assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}'
375-
ph.assert_dtype(op, (dtype, in_stype), x.dtype, dtype, out_name='x.dtype')
382+
x = locals_["x"]
383+
assert x.dtype == dtype, f"{x.dtype=!s}, but should be {dtype}"
384+
ph.assert_dtype(op, (dtype, in_stype), x.dtype, dtype, out_name="x.dtype")
376385

377386

378-
if __name__ == '__main__':
387+
if __name__ == "__main__":
379388
for (i, j), p in dh.promotion_table.items():
380-
print(f'({i}, {j}) -> {p}')
389+
print(f"({i}, {j}) -> {p}")

0 commit comments

Comments
 (0)