Skip to content

Commit f6ea2c1

Browse files
committed
Improve int_stops() scope and performance
1 parent fa50249 commit f6ea2c1

File tree

1 file changed

+29
-26
lines changed

1 file changed

+29
-26
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,14 @@ def test_arange(dtype, data):
225225
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
226226
#
227227
assert math.floor(math.sqrt(size)) <= out.size <= math.ceil(size ** 2)
228-
229228
assume(out.size == size)
230229
if dh.is_int_dtype(_dtype):
231230
ah.assert_exactly_equal(out, ah.asarray(list(r), dtype=_dtype))
232231
else:
233-
pass # TODO: either emulate array module behaviour or assert a rough equals
232+
if out.size > 0:
233+
assert ah.equal(
234+
out[0], ah.asarray(_start, dtype=out.dtype)
235+
), f"out[0]={out[0]}, but should be {_start} [linspace({start=}, {stop=})]"
234236

235237

236238
@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes))
@@ -357,15 +359,21 @@ def test_full_like(x, fill_value, kw):
357359
finite_kw = {"allow_nan": False, "allow_infinity": False}
358360

359361

360-
@st.composite
361-
def int_stops(draw, start: int, min_gap: int, m: int, M: int):
362-
sign = draw(st.booleans().map(int))
363-
max_gap = abs(M - m)
364-
max_int = math.floor(math.sqrt(max_gap))
365-
gap = draw(st.just(0) | st.integers(1, max_int).map(lambda n: min_gap ** n))
366-
stop = start + sign * gap
367-
assume(m <= stop <= M)
368-
return stop
362+
def int_stops(
363+
start: int, num, dtype: DataType, endpoint: bool
364+
) -> st.SearchStrategy[int]:
365+
min_gap = num
366+
if endpoint:
367+
min_gap += 1
368+
m, M = dh.dtype_ranges[dtype]
369+
max_pos_gap = M - start
370+
max_neg_gap = start - m
371+
max_pos_mul = max_pos_gap // min_gap
372+
max_neg_mul = max_neg_gap // min_gap
373+
return st.one_of(
374+
st.integers(0, max_pos_mul).map(lambda n: start + min_gap * n),
375+
st.integers(0, max_neg_mul).map(lambda n: start - min_gap * n),
376+
)
369377

370378

371379
@given(
@@ -381,17 +389,13 @@ def test_linspace(num, dtype, endpoint, data):
381389
if dh.is_float_dtype(_dtype):
382390
stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop")
383391
# avoid overflow errors
384-
delta = ah.asarray(stop - start, dtype=_dtype)
385-
assume(not ah.isnan(delta))
392+
assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype)))
393+
assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype)))
386394
else:
387395
if num == 0:
388396
stop = start
389397
else:
390-
min_gap = num
391-
if endpoint:
392-
min_gap += 1
393-
m, M = dh.dtype_ranges[_dtype]
394-
stop = data.draw(int_stops(start, min_gap, m, M), label="stop")
398+
stop = data.draw(int_stops(start, num, _dtype, endpoint), label="stop")
395399

396400
kw = data.draw(
397401
specified_kwargs(
@@ -403,7 +407,10 @@ def test_linspace(num, dtype, endpoint, data):
403407
out = xp.linspace(start, stop, num, **kw)
404408

405409
assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)
406-
410+
if num > 0:
411+
assert ah.equal(
412+
out[0], ah.asarray(start, dtype=out.dtype)
413+
), f"out[0]={out[0]}, but should be {start=} [linspace({stop=}, {num=})]"
407414
if endpoint:
408415
if num > 1:
409416
assert ah.equal(
@@ -415,13 +422,9 @@ def test_linspace(num, dtype, endpoint, data):
415422
expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True)
416423
expected = expected[:-1]
417424
ah.assert_exactly_equal(out, expected)
418-
419-
if num > 0:
420-
assert ah.equal(
421-
out[0], ah.asarray(start, dtype=out.dtype)
422-
), f"out[0]={out[0]}, but should be {start=} [linspace({stop=}, {num=})]"
423-
424-
# TODO: array assertions ala test_arange
425+
assert (
426+
out.size == num
427+
), f"{out.size=}, but should be {num=} [linspace({start=}, {stop=})]"
425428

426429

427430
def make_one(dtype: DataType) -> Scalar:

0 commit comments

Comments
 (0)