Skip to content

Commit fa50249

Browse files
committed
frange class to mimic range behaviour for floats
1 parent ebb6c36 commit fa50249

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

array_api_tests/meta/test_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
1+
import pytest
2+
13
from ..test_signatures import extension_module
4+
from ..test_creation_functions import frange
25

36

47
def test_extension_module_is_extension():
5-
assert extension_module('linalg')
8+
assert extension_module("linalg")
69

710

811
def test_extension_func_is_not_extension():
9-
assert not extension_module('linalg.cross')
12+
assert not extension_module("linalg.cross")
13+
14+
15+
@pytest.mark.parametrize(
16+
"r, size, elements",
17+
[
18+
(frange(0, 1, 1), 1, [0]),
19+
(frange(1, 0, -1), 1, [1]),
20+
(frange(0, 1, -1), 0, []),
21+
(frange(0, 1, 2), 1, [0]),
22+
],
23+
)
24+
def test_frange(r, size, elements):
25+
assert len(r) == size
26+
assert list(r) == elements

array_api_tests/test_creation_functions.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
2-
from typing import Union, Any, Tuple
3-
from itertools import takewhile, count
2+
from typing import Union, Any, Tuple, NamedTuple, Iterator
3+
from itertools import count
44

55
from hypothesis import assume, given, strategies as st
66

@@ -79,6 +79,31 @@ def assert_fill(
7979
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg
8080

8181

82+
class frange(NamedTuple):
83+
start: float
84+
stop: float
85+
step: float
86+
87+
def __iter__(self) -> Iterator[float]:
88+
pos_range = self.stop > self.start
89+
pos_step = self.step > 0
90+
if pos_step != pos_range:
91+
return
92+
if pos_range:
93+
for n in count(self.start, self.step):
94+
if n >= self.stop:
95+
break
96+
yield n
97+
else:
98+
for n in count(self.start, self.step):
99+
if n <= self.stop:
100+
break
101+
yield n
102+
103+
def __len__(self) -> int:
104+
return max(math.ceil((self.stop - self.start) / self.step), 0)
105+
106+
82107
# Testing xp.arange() requires bounding the start/stop/step arguments to only
83108
# test argument combinations compliant with the Array API, as well as to not
84109
# produce arrays with sizes not supproted by an array module.
@@ -165,20 +190,8 @@ def test_arange(dtype, data):
165190
assert m <= _stop <= M
166191
assert m <= step <= M
167192

168-
pos_range = _stop > _start
169-
pos_step = step > 0
170-
if _start != _stop and pos_range == pos_step:
171-
if pos_step:
172-
condition = lambda x: x < _stop
173-
else:
174-
condition = lambda x: x > _stop
175-
scalar_type = int if dh.is_int_dtype(_dtype) else float
176-
elements = list(
177-
scalar_type(n) for n in takewhile(condition, count(_start, step))
178-
)
179-
else:
180-
elements = []
181-
size = len(elements)
193+
r = frange(_start, _stop, step)
194+
size = len(r)
182195
assert (
183196
size <= hh.MAX_ARRAY_SIZE
184197
), f"{size=} should be no more than {hh.MAX_ARRAY_SIZE}" # sanity check
@@ -200,8 +213,8 @@ def test_arange(dtype, data):
200213
assert_default_float("arange", out.dtype)
201214
else:
202215
assert out.dtype == dtype
203-
assert out.ndim == 1
204-
if dh.is_int_dtype(step):
216+
assert out.ndim == 1, f"{out.ndim=}, should be 1 [linspace()]"
217+
if dh.is_int_dtype(_dtype):
205218
assert out.size == size
206219
else:
207220
# We check size is roughly as expected to avoid edge cases e.g.
@@ -212,9 +225,10 @@ def test_arange(dtype, data):
212225
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
213226
#
214227
assert math.floor(math.sqrt(size)) <= out.size <= math.ceil(size ** 2)
228+
215229
assume(out.size == size)
216230
if dh.is_int_dtype(_dtype):
217-
ah.assert_exactly_equal(out, ah.asarray(elements, dtype=_dtype))
231+
ah.assert_exactly_equal(out, ah.asarray(list(r), dtype=_dtype))
218232
else:
219233
pass # TODO: either emulate array module behaviour or assert a rough equals
220234

0 commit comments

Comments
 (0)