1
1
import math
2
- from typing import Union
2
+ from typing import Union , Any , Tuple
3
3
from itertools import takewhile , count
4
4
5
5
from hypothesis import assume , given , strategies as st
13
13
from .typing import Shape , DataType , Array , Scalar
14
14
15
15
16
+ @st .composite
17
+ def specified_kwargs (draw , * keys_values_defaults : Tuple [str , Any , Any ]):
18
+ """Generates valid kwargs given expected defaults.
19
+
20
+ When we can't realistically use hh.kwargs() and thus test whether xp infact
21
+ defaults correctly, this strategy lets us remove generated arguments if they
22
+ are of the default value anyway.
23
+ """
24
+ kw = {}
25
+ for key , value , default in keys_values_defaults :
26
+ if value is not default or draw (st .booleans ()):
27
+ kw [key ] = value
28
+ return kw
29
+
30
+
16
31
def assert_default_float (func_name : str , dtype : DataType ):
17
32
f_dtype = dh .dtype_to_name [dtype ]
18
33
f_default = dh .dtype_to_name [dh .default_float ]
@@ -168,7 +183,15 @@ def test_arange(dtype, data):
168
183
size <= hh .MAX_ARRAY_SIZE
169
184
), f"{ size = } should be no more than { hh .MAX_ARRAY_SIZE } " # sanity check
170
185
171
- out = xp .arange (start , stop = stop , step = step , dtype = dtype )
186
+ kw = data .draw (
187
+ specified_kwargs (
188
+ ("stop" , stop , None ),
189
+ ("step" , step , None ),
190
+ ("dtype" , dtype , None ),
191
+ ),
192
+ label = "kw" ,
193
+ )
194
+ out = xp .arange (start , ** kw )
172
195
173
196
if dtype is None :
174
197
if all_int :
@@ -356,15 +379,22 @@ def test_linspace(num, dtype, endpoint, data):
356
379
m , M = dh .dtype_ranges [_dtype ]
357
380
stop = data .draw (int_stops (start , min_gap , m , M ), label = "stop" )
358
381
359
- out = xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint )
382
+ kw = data .draw (
383
+ specified_kwargs (
384
+ ("dtype" , dtype , None ),
385
+ ("endpoint" , endpoint , True ),
386
+ ),
387
+ label = "kw" ,
388
+ )
389
+ out = xp .linspace (start , stop , num , ** kw )
360
390
361
391
assert_shape ("linspace" , out .shape , num , start = stop , stop = stop , num = num )
362
392
363
393
if endpoint :
364
394
if num > 1 :
365
395
assert ah .equal (
366
396
out [- 1 ], ah .asarray (stop , dtype = out .dtype )
367
- ), f"out[-1]={ out [- 1 ]} , but should be { stop = } [linspace()]"
397
+ ), f"out[-1]={ out [- 1 ]} , but should be { stop = } [linspace({ start = } , { num = } )]"
368
398
else :
369
399
# linspace(..., num, endpoint=True) should return an array equivalent to
370
400
# the first num elements when endpoint=False
@@ -375,8 +405,9 @@ def test_linspace(num, dtype, endpoint, data):
375
405
if num > 0 :
376
406
assert ah .equal (
377
407
out [0 ], ah .asarray (start , dtype = out .dtype )
378
- ), f"out[0]={ out [0 ]} , but should be { start = } [linspace()]"
379
- # TODO: array assertions ala test_arange
408
+ ), f"out[0]={ out [0 ]} , but should be { start = } [linspace({ stop = } , { num = } )]"
409
+
410
+ # TODO: array assertions ala test_arange
380
411
381
412
382
413
def make_one (dtype : DataType ) -> Scalar :
0 commit comments