@@ -225,12 +225,14 @@ def test_arange(dtype, data):
225
225
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
226
226
#
227
227
assert math .floor (math .sqrt (size )) <= out .size <= math .ceil (size ** 2 )
228
-
229
228
assume (out .size == size )
230
229
if dh .is_int_dtype (_dtype ):
231
230
ah .assert_exactly_equal (out , ah .asarray (list (r ), dtype = _dtype ))
232
231
else :
233
- pass # TODO: either emulate array module behaviour or assert a rough equals
232
+ if out .size > 0 :
233
+ assert ah .equal (
234
+ out [0 ], ah .asarray (_start , dtype = out .dtype )
235
+ ), f"out[0]={ out [0 ]} , but should be { _start } [linspace({ start = } , { stop = } )]"
234
236
235
237
236
238
@given (hh .shapes (), hh .kwargs (dtype = st .none () | hh .shared_dtypes ))
@@ -357,15 +359,21 @@ def test_full_like(x, fill_value, kw):
357
359
finite_kw = {"allow_nan" : False , "allow_infinity" : False }
358
360
359
361
360
- @st .composite
361
- def int_stops (draw , start : int , min_gap : int , m : int , M : int ):
362
- sign = draw (st .booleans ().map (int ))
363
- max_gap = abs (M - m )
364
- max_int = math .floor (math .sqrt (max_gap ))
365
- gap = draw (st .just (0 ) | st .integers (1 , max_int ).map (lambda n : min_gap ** n ))
366
- stop = start + sign * gap
367
- assume (m <= stop <= M )
368
- return stop
362
+ def int_stops (
363
+ start : int , num , dtype : DataType , endpoint : bool
364
+ ) -> st .SearchStrategy [int ]:
365
+ min_gap = num
366
+ if endpoint :
367
+ min_gap += 1
368
+ m , M = dh .dtype_ranges [dtype ]
369
+ max_pos_gap = M - start
370
+ max_neg_gap = start - m
371
+ max_pos_mul = max_pos_gap // min_gap
372
+ max_neg_mul = max_neg_gap // min_gap
373
+ return st .one_of (
374
+ st .integers (0 , max_pos_mul ).map (lambda n : start + min_gap * n ),
375
+ st .integers (0 , max_neg_mul ).map (lambda n : start - min_gap * n ),
376
+ )
369
377
370
378
371
379
@given (
@@ -381,17 +389,13 @@ def test_linspace(num, dtype, endpoint, data):
381
389
if dh .is_float_dtype (_dtype ):
382
390
stop = data .draw (xps .from_dtype (_dtype , ** finite_kw ), label = "stop" )
383
391
# avoid overflow errors
384
- delta = ah .asarray (stop - start , dtype = _dtype )
385
- assume (not ah .isnan (delta ))
392
+ assume ( not ah .isnan ( ah . asarray (stop - start , dtype = _dtype )) )
393
+ assume (not ah .isnan (ah . asarray ( start - stop , dtype = _dtype ) ))
386
394
else :
387
395
if num == 0 :
388
396
stop = start
389
397
else :
390
- min_gap = num
391
- if endpoint :
392
- min_gap += 1
393
- m , M = dh .dtype_ranges [_dtype ]
394
- stop = data .draw (int_stops (start , min_gap , m , M ), label = "stop" )
398
+ stop = data .draw (int_stops (start , num , _dtype , endpoint ), label = "stop" )
395
399
396
400
kw = data .draw (
397
401
specified_kwargs (
@@ -403,7 +407,10 @@ def test_linspace(num, dtype, endpoint, data):
403
407
out = xp .linspace (start , stop , num , ** kw )
404
408
405
409
assert_shape ("linspace" , out .shape , num , start = stop , stop = stop , num = num )
406
-
410
+ if num > 0 :
411
+ assert ah .equal (
412
+ out [0 ], ah .asarray (start , dtype = out .dtype )
413
+ ), f"out[0]={ out [0 ]} , but should be { start = } [linspace({ stop = } , { num = } )]"
407
414
if endpoint :
408
415
if num > 1 :
409
416
assert ah .equal (
@@ -415,13 +422,9 @@ def test_linspace(num, dtype, endpoint, data):
415
422
expected = xp .linspace (start , stop , num + 1 , dtype = dtype , endpoint = True )
416
423
expected = expected [:- 1 ]
417
424
ah .assert_exactly_equal (out , expected )
418
-
419
- if num > 0 :
420
- assert ah .equal (
421
- out [0 ], ah .asarray (start , dtype = out .dtype )
422
- ), f"out[0]={ out [0 ]} , but should be { start = } [linspace({ stop = } , { num = } )]"
423
-
424
- # TODO: array assertions ala test_arange
425
+ assert (
426
+ out .size == num
427
+ ), f"{ out .size = } , but should be { num = } [linspace({ start = } , { stop = } )]"
425
428
426
429
427
430
def make_one (dtype : DataType ) -> Scalar :
0 commit comments