Skip to content

Commit 1b30535

Browse files
committed
specified_kwargs() strategy to test default kwargs
1 parent 7fe9c96 commit 1b30535

File tree

1 file changed

+37
-6
lines changed

1 file changed

+37
-6
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Union
2+
from typing import Union, Any, Tuple
33
from itertools import takewhile, count
44

55
from hypothesis import assume, given, strategies as st
@@ -13,6 +13,21 @@
1313
from .typing import Shape, DataType, Array, Scalar
1414

1515

16+
@st.composite
17+
def specified_kwargs(draw, *keys_values_defaults: Tuple[str, Any, Any]):
18+
"""Generates valid kwargs given expected defaults.
19+
20+
When we can't realistically use hh.kwargs() and thus test whether xp infact
21+
defaults correctly, this strategy lets us remove generated arguments if they
22+
are of the default value anyway.
23+
"""
24+
kw = {}
25+
for key, value, default in keys_values_defaults:
26+
if value is not default or draw(st.booleans()):
27+
kw[key] = value
28+
return kw
29+
30+
1631
def assert_default_float(func_name: str, dtype: DataType):
1732
f_dtype = dh.dtype_to_name[dtype]
1833
f_default = dh.dtype_to_name[dh.default_float]
@@ -168,7 +183,15 @@ def test_arange(dtype, data):
168183
size <= hh.MAX_ARRAY_SIZE
169184
), f"{size=} should be no more than {hh.MAX_ARRAY_SIZE}" # sanity check
170185

171-
out = xp.arange(start, stop=stop, step=step, dtype=dtype)
186+
kw = data.draw(
187+
specified_kwargs(
188+
("stop", stop, None),
189+
("step", step, None),
190+
("dtype", dtype, None),
191+
),
192+
label="kw",
193+
)
194+
out = xp.arange(start, **kw)
172195

173196
if dtype is None:
174197
if all_int:
@@ -356,15 +379,22 @@ def test_linspace(num, dtype, endpoint, data):
356379
m, M = dh.dtype_ranges[_dtype]
357380
stop = data.draw(int_stops(start, min_gap, m, M), label="stop")
358381

359-
out = xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
382+
kw = data.draw(
383+
specified_kwargs(
384+
("dtype", dtype, None),
385+
("endpoint", endpoint, True),
386+
),
387+
label="kw",
388+
)
389+
out = xp.linspace(start, stop, num, **kw)
360390

361391
assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)
362392

363393
if endpoint:
364394
if num > 1:
365395
assert ah.equal(
366396
out[-1], ah.asarray(stop, dtype=out.dtype)
367-
), f"out[-1]={out[-1]}, but should be {stop=} [linspace()]"
397+
), f"out[-1]={out[-1]}, but should be {stop=} [linspace({start=}, {num=})]"
368398
else:
369399
# linspace(..., num, endpoint=True) should return an array equivalent to
370400
# the first num elements when endpoint=False
@@ -375,8 +405,9 @@ def test_linspace(num, dtype, endpoint, data):
375405
if num > 0:
376406
assert ah.equal(
377407
out[0], ah.asarray(start, dtype=out.dtype)
378-
), f"out[0]={out[0]}, but should be {start=} [linspace()]"
379-
# TODO: array assertions ala test_arange
408+
), f"out[0]={out[0]}, but should be {start=} [linspace({stop=}, {num=})]"
409+
410+
# TODO: array assertions ala test_arange
380411

381412

382413
def make_one(dtype: DataType) -> Scalar:

0 commit comments

Comments
 (0)