1
1
import math
2
- from typing import Union , Any , Tuple
3
- from itertools import takewhile , count
2
+ from typing import Union , Any , Tuple , NamedTuple , Iterator
3
+ from itertools import count
4
4
5
5
from hypothesis import assume , given , strategies as st
6
6
@@ -79,6 +79,31 @@ def assert_fill(
79
79
assert ah .all (ah .equal (out , ah .asarray (fill_value , dtype = dtype ))), msg
80
80
81
81
82
+ class frange (NamedTuple ):
83
+ start : float
84
+ stop : float
85
+ step : float
86
+
87
+ def __iter__ (self ) -> Iterator [float ]:
88
+ pos_range = self .stop > self .start
89
+ pos_step = self .step > 0
90
+ if pos_step != pos_range :
91
+ return
92
+ if pos_range :
93
+ for n in count (self .start , self .step ):
94
+ if n >= self .stop :
95
+ break
96
+ yield n
97
+ else :
98
+ for n in count (self .start , self .step ):
99
+ if n <= self .stop :
100
+ break
101
+ yield n
102
+
103
+ def __len__ (self ) -> int :
104
+ return max (math .ceil ((self .stop - self .start ) / self .step ), 0 )
105
+
106
+
82
107
# Testing xp.arange() requires bounding the start/stop/step arguments to only
83
108
# test argument combinations compliant with the Array API, as well as to not
84
109
# produce arrays with sizes not supproted by an array module.
@@ -165,20 +190,8 @@ def test_arange(dtype, data):
165
190
assert m <= _stop <= M
166
191
assert m <= step <= M
167
192
168
- pos_range = _stop > _start
169
- pos_step = step > 0
170
- if _start != _stop and pos_range == pos_step :
171
- if pos_step :
172
- condition = lambda x : x < _stop
173
- else :
174
- condition = lambda x : x > _stop
175
- scalar_type = int if dh .is_int_dtype (_dtype ) else float
176
- elements = list (
177
- scalar_type (n ) for n in takewhile (condition , count (_start , step ))
178
- )
179
- else :
180
- elements = []
181
- size = len (elements )
193
+ r = frange (_start , _stop , step )
194
+ size = len (r )
182
195
assert (
183
196
size <= hh .MAX_ARRAY_SIZE
184
197
), f"{ size = } should be no more than { hh .MAX_ARRAY_SIZE } " # sanity check
@@ -200,8 +213,8 @@ def test_arange(dtype, data):
200
213
assert_default_float ("arange" , out .dtype )
201
214
else :
202
215
assert out .dtype == dtype
203
- assert out .ndim == 1
204
- if dh .is_int_dtype (step ):
216
+ assert out .ndim == 1 , f" { out . ndim = } , should be 1 [linspace()]"
217
+ if dh .is_int_dtype (_dtype ):
205
218
assert out .size == size
206
219
else :
207
220
# We check size is roughly as expected to avoid edge cases e.g.
@@ -212,9 +225,10 @@ def test_arange(dtype, data):
212
225
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
213
226
#
214
227
assert math .floor (math .sqrt (size )) <= out .size <= math .ceil (size ** 2 )
228
+
215
229
assume (out .size == size )
216
230
if dh .is_int_dtype (_dtype ):
217
- ah .assert_exactly_equal (out , ah .asarray (elements , dtype = _dtype ))
231
+ ah .assert_exactly_equal (out , ah .asarray (list ( r ) , dtype = _dtype ))
218
232
else :
219
233
pass # TODO: either emulate array module behaviour or assert a rough equals
220
234
0 commit comments