@@ -59,12 +59,14 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
59
59
60
60
61
61
def assert_shape (
62
- func_name : str , out_shape : Shape , expected : Union [int , Shape ], / , ** kw
62
+ func_name : str , out_shape : Union [ int , Shape ] , expected : Union [int , Shape ], / , ** kw
63
63
):
64
- f_kw = ", " . join ( f" { k } = { v } " for k , v in kw . items ())
65
- msg = f"out.shape= { out_shape } , but should be { expected } [ { func_name } ( { f_kw } )]"
64
+ if isinstance ( out_shape , int ):
65
+ out_shape = ( out_shape ,)
66
66
if isinstance (expected , int ):
67
67
expected = (expected ,)
68
+ f_kw = ", " .join (f"{ k } ={ v } " for k , v in kw .items ())
69
+ msg = f"out.shape={ out_shape } , but should be { expected } [{ func_name } ({ f_kw } )]"
68
70
assert out_shape == expected , msg
69
71
70
72
@@ -183,7 +185,7 @@ def test_arange(dtype, data):
183
185
else :
184
186
_dtype = dtype
185
187
186
- # sanity check
188
+ # sanity checks
187
189
if dh .is_int_dtype (_dtype ):
188
190
m , M = dh .dtype_ranges [_dtype ]
189
191
assert m <= _start <= M
@@ -213,9 +215,10 @@ def test_arange(dtype, data):
213
215
assert_default_float ("arange" , out .dtype )
214
216
else :
215
217
assert out .dtype == dtype
216
- assert out .ndim == 1 , f"{ out .ndim = } , should be 1 [linspace()]"
218
+ assert out .ndim == 1 , f"{ out .ndim = } , but should be 1 [linspace()]"
219
+ f_func = f"[linspace({ start = } , { stop = } , { step = } )]"
217
220
if dh .is_int_dtype (_dtype ):
218
- assert out .size == size
221
+ assert out .size == size , f" { out . size = } , but should be { size } { f_func } "
219
222
else :
220
223
# We check size is roughly as expected to avoid edge cases e.g.
221
224
#
@@ -224,7 +227,11 @@ def test_arange(dtype, data):
224
227
# >>> xp.arange(2, step=0.3333333333333333)
225
228
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
226
229
#
227
- assert math .floor (math .sqrt (size )) <= out .size <= math .ceil (size ** 2 )
230
+ min_size = math .floor (size * 0.9 )
231
+ max_size = math .ceil (size * 1.1 )
232
+ assert (
233
+ min_size <= out .size <= max_size
234
+ ), f"{ out .size = } , but should be roughly { size } { f_func } "
228
235
assume (out .size == size )
229
236
if dh .is_int_dtype (_dtype ):
230
237
ah .assert_exactly_equal (out , ah .asarray (list (r ), dtype = _dtype ))
@@ -407,24 +414,22 @@ def test_linspace(num, dtype, endpoint, data):
407
414
out = xp .linspace (start , stop , num , ** kw )
408
415
409
416
assert_shape ("linspace" , out .shape , num , start = stop , stop = stop , num = num )
417
+ f_func = f"[linspace({ start = } , { stop = } , { num = } )]"
410
418
if num > 0 :
411
419
assert ah .equal (
412
420
out [0 ], ah .asarray (start , dtype = out .dtype )
413
- ), f"out[0]={ out [0 ]} , but should be { start = } [linspace( { stop = } , { num = } )] "
421
+ ), f"out[0]={ out [0 ]} , but should be { start } { f_func } "
414
422
if endpoint :
415
423
if num > 1 :
416
424
assert ah .equal (
417
425
out [- 1 ], ah .asarray (stop , dtype = out .dtype )
418
- ), f"out[-1]={ out [- 1 ]} , but should be { stop = } [linspace( { start = } , { num = } )] "
426
+ ), f"out[-1]={ out [- 1 ]} , but should be { stop } { f_func } "
419
427
else :
420
428
# linspace(..., num, endpoint=True) should return an array equivalent to
421
429
# the first num elements when endpoint=False
422
430
expected = xp .linspace (start , stop , num + 1 , dtype = dtype , endpoint = True )
423
431
expected = expected [:- 1 ]
424
432
ah .assert_exactly_equal (out , expected )
425
- assert (
426
- out .size == num
427
- ), f"{ out .size = } , but should be { num = } [linspace({ start = } , { stop = } )]"
428
433
429
434
430
435
def make_one (dtype : DataType ) -> Scalar :
0 commit comments