@@ -18,15 +18,13 @@ def squeeze(x, /, axis):
18
18
...
19
19
20
20
"""
21
- from collections import defaultdict
22
21
from copy import copy
23
22
from inspect import Parameter , Signature , signature
24
23
from itertools import chain
25
24
from types import FunctionType
26
- from typing import Any , Callable , DefaultDict , Dict , List , Literal , Sequence , get_args
25
+ from typing import Any , Callable , Dict , List , Literal , Sequence , get_args
27
26
28
27
import pytest
29
- from hypothesis import given , note
30
28
from hypothesis import strategies as st
31
29
32
30
from . import dtype_helpers as dh
@@ -35,7 +33,7 @@ def squeeze(x, /, axis):
35
33
from ._array_module import _UndefinedStub
36
34
from ._array_module import mod as xp
37
35
from .stubs import array_methods , category_to_funcs , extension_to_funcs
38
- from .typing import DataType , Shape
36
+ from .typing import DataType
39
37
40
38
pytestmark = pytest .mark .ci
41
39
@@ -53,7 +51,7 @@ def squeeze(x, /, axis):
53
51
Parameter .POSITIONAL_ONLY : "pos-only argument" ,
54
52
Parameter .KEYWORD_ONLY : "keyword-only argument" ,
55
53
Parameter .VAR_POSITIONAL : "star-args (i.e. *args) argument" ,
56
- Parameter .VAR_KEYWORD : "star-kwargs (i.e. **kwargs ) argument" ,
54
+ Parameter .VAR_KEYWORD : "star-kwonly (i.e. **kwonly ) argument" ,
57
55
}
58
56
59
57
@@ -63,14 +61,13 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
63
61
# We're not interested if the array module has additional arguments, so we
64
62
# only iterate through the arguments listed in the spec.
65
63
for i , stub_param in enumerate (stub_params ):
66
- if sig is not None :
67
- assert (
68
- len (params ) >= i + 1
69
- ), f"Argument '{ stub_param .name } ' missing from signature"
70
- param = params [i ]
64
+ assert (
65
+ len (params ) >= i + 1
66
+ ), f"Argument '{ stub_param .name } ' missing from signature"
67
+ param = params [i ]
71
68
72
69
# We're not interested in the name if it isn't actually used
73
- if sig is not None and stub_param .kind not in [
70
+ if stub_param .kind not in [
74
71
Parameter .POSITIONAL_ONLY ,
75
72
* VAR_KINDS ,
76
73
]:
@@ -80,50 +77,17 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
80
77
81
78
f_stub_kind = kind_to_str [stub_param .kind ]
82
79
if stub_param .kind in [Parameter .POSITIONAL_OR_KEYWORD , * VAR_KINDS ]:
83
- if sig is not None :
84
- assert param .kind == stub_param .kind , (
85
- f"{ param .name } is a { kind_to_str [param .kind ]} , "
86
- f"but should be a { f_stub_kind } "
87
- )
88
- else :
89
- pass
80
+ assert param .kind == stub_param .kind , (
81
+ f"{ param .name } is a { kind_to_str [param .kind ]} , "
82
+ f"but should be a { f_stub_kind } "
83
+ )
90
84
else :
91
85
# TODO: allow for kw-only args to be out-of-order
92
- if sig is not None :
93
- assert param .kind in [
94
- stub_param .kind ,
95
- Parameter .POSITIONAL_OR_KEYWORD ,
96
- ], (
97
- f"{ param .name } is a { kind_to_str [param .kind ]} , "
98
- f"but should be a { f_stub_kind } "
99
- f"(or at least a { kind_to_str [ParameterKind .POSITIONAL_OR_KEYWORD ]} )"
100
- )
101
- else :
102
- pass
103
-
104
-
105
- def shapes (** kw ) -> st .SearchStrategy [Shape ]:
106
- if "min_side" not in kw .keys ():
107
- kw ["min_side" ] = 1
108
- return hh .shapes (** kw )
109
-
110
-
111
- matrixy_funcs : List [str ] = [
112
- f .__name__
113
- for f in chain (category_to_funcs ["linear_algebra" ], extension_to_funcs ["linalg" ])
114
- ]
115
- matrixy_funcs += ["__matmul__" , "triu" , "tril" ]
116
- func_to_shapes : DefaultDict [str , st .SearchStrategy [Shape ]] = defaultdict (
117
- shapes ,
118
- {
119
- ** {k : st .just (()) for k in ["__bool__" , "__int__" , "__index__" , "__float__" ]},
120
- "sort" : shapes (min_dims = 1 ), # for axis=-1,
121
- ** {k : shapes (min_dims = 2 ) for k in matrixy_funcs },
122
- # Overwrite min_dims=2 shapes for some matrixy functions
123
- "cross" : shapes (min_side = 3 , max_side = 3 , min_dims = 3 , max_dims = 3 ),
124
- "outer" : shapes (min_dims = 1 , max_dims = 1 ),
125
- },
126
- )
86
+ assert param .kind in [stub_param .kind , Parameter .POSITIONAL_OR_KEYWORD ,], (
87
+ f"{ param .name } is a { kind_to_str [param .kind ]} , "
88
+ f"but should be a { f_stub_kind } "
89
+ f"(or at least a { kind_to_str [ParameterKind .POSITIONAL_OR_KEYWORD ]} )"
90
+ )
127
91
128
92
129
93
def get_dtypes_strategy (func_name : str ) -> st .SearchStrategy [DataType ]:
@@ -136,97 +100,93 @@ def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
136
100
return xps .scalar_dtypes ()
137
101
138
102
139
- func_to_example_values : Dict [str , Dict [ParameterKind , Dict [str , Any ]]] = {
140
- "broadcast_to" : {
141
- Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([0 , 1 ])},
142
- Parameter .POSITIONAL_OR_KEYWORD : {"shape" : (1 , 2 )},
143
- },
144
- "cholesky" : {
145
- Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]])}
146
- },
147
- "inv" : {Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]])}},
148
- }
149
-
150
-
151
- def make_pretty_func (func_name : str , args : Sequence [Any ], kwargs : Dict [str , Any ]):
103
+ def make_pretty_func (func_name : str , args : Sequence [Any ], kwonly : Dict [str , Any ]):
152
104
f_sig = f"{ func_name } ("
153
105
f_sig += ", " .join (str (a ) for a in args )
154
- if len (kwargs ) != 0 :
106
+ if len (kwonly ) != 0 :
155
107
if len (args ) != 0 :
156
108
f_sig += ", "
157
- f_sig += ", " .join (f"{ k } ={ v } " for k , v in kwargs .items ())
109
+ f_sig += ", " .join (f"{ k } ={ v } " for k , v in kwonly .items ())
158
110
f_sig += ")"
159
111
return f_sig
160
112
161
113
162
- @given (data = st .data ())
163
- def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature , data ):
164
- example_values : Dict [ParameterKind , Dict [str , Any ]] = func_to_example_values .get (
165
- func_name , {}
166
- )
167
- for kind in ALL_KINDS :
168
- example_values .setdefault (kind , {})
114
+ matrixy_funcs : List [str ] = [
115
+ f .__name__
116
+ for f in chain (category_to_funcs ["linear_algebra" ], extension_to_funcs ["linalg" ])
117
+ ]
118
+ matrixy_funcs += ["__matmul__" , "triu" , "tril" ]
169
119
170
- for param in stub_sig .parameters .values ():
171
- for name_to_value in example_values .values ():
172
- if param .name in name_to_value .keys ():
173
- continue
174
120
175
- if param .default != Parameter .empty :
176
- example_value = param .default
121
+ def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature ):
122
+ skip_msg = (
123
+ f"Signature for { func_name } () is not inspectable "
124
+ "and is too troublesome to test for otherwise"
125
+ )
126
+ if func_name in [
127
+ "__bool__" ,
128
+ "__int__" ,
129
+ "__index__" ,
130
+ "__float__" ,
131
+ "pow" ,
132
+ "bitwise_left_shift" ,
133
+ "bitwise_right_shift" ,
134
+ "broadcast_to" ,
135
+ "permute_dims" ,
136
+ "sort" ,
137
+ * matrixy_funcs ,
138
+ ]:
139
+ pytest .skip (skip_msg )
140
+
141
+ param_to_value : Dict [Parameter , Any ] = {}
142
+ for param in stub_sig .parameters .values ():
143
+ if param .kind in VAR_KINDS :
144
+ pytest .skip (skip_msg )
145
+ elif param .default != Parameter .empty :
146
+ value = param .default
177
147
elif param .name in ["x" , "x1" ]:
178
148
dtypes = get_dtypes_strategy (func_name )
179
- shapes = func_to_shapes [func_name ]
180
- example_value = data .draw (
181
- xps .arrays (dtype = dtypes , shape = shapes ), label = param .name
182
- )
149
+ value = xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )).example ()
183
150
elif param .name == "x2" :
184
151
# sanity check
185
- assert "x1" in example_values [Parameter .POSITIONAL_ONLY ].keys ()
186
- x1 = example_values [Parameter .POSITIONAL_ONLY ]["x1" ]
187
- example_value = data .draw (
188
- xps .arrays (dtype = x1 .dtype , shape = x1 .shape ), label = "x2"
189
- )
190
- elif param .name == "axes" :
191
- example_value = ()
192
- elif param .name == "shape" :
193
- example_value = ()
152
+ assert "x1" in [p .name for p in param_to_value .keys ()]
153
+ x1 = next (v for p , v in param_to_value .items () if p .name == "x1" )
154
+ value = xps .arrays (dtype = x1 .dtype , shape = x1 .shape ).example ()
194
155
else :
195
- pytest .skip (f"No example value for argument '{ param .name } '" )
196
-
197
- if param .kind in VAR_KINDS :
198
- pytest .skip ("TODO" )
199
- example_values [param .kind ][param .name ] = example_value
200
-
201
- if len (example_values [Parameter .POSITIONAL_OR_KEYWORD ]) == 0 :
202
- f_func = make_pretty_func (
203
- func_name ,
204
- example_values [Parameter .POSITIONAL_ONLY ].values (),
205
- example_values [Parameter .KEYWORD_ONLY ],
206
- )
207
- note (f"trying { f_func } " )
208
- func (
209
- * example_values [Parameter .POSITIONAL_ONLY ].values (),
210
- ** example_values [Parameter .KEYWORD_ONLY ],
211
- )
156
+ pytest .skip (skip_msg )
157
+ param_to_value [param ] = value
158
+
159
+ posonly : List [Any ] = [
160
+ v for p , v in param_to_value .items () if p .kind == Parameter .POSITIONAL_ONLY
161
+ ]
162
+ kwonly : Dict [str , Any ] = {
163
+ p .name : v for p , v in param_to_value .items () if p .kind == Parameter .KEYWORD_ONLY
164
+ }
165
+ if (
166
+ sum (p .kind == Parameter .POSITIONAL_OR_KEYWORD for p in param_to_value .keys ())
167
+ == 0
168
+ ):
169
+ f_func = make_pretty_func (func_name , posonly , kwonly )
170
+ print (f"trying { f_func } " )
171
+ func (* posonly , ** kwonly )
212
172
else :
213
173
either_argname_value_pairs = list (
214
- example_values [Parameter .POSITIONAL_OR_KEYWORD ].items ()
174
+ (p .name , v )
175
+ for p , v in param_to_value .items ()
176
+ if p .kind == Parameter .POSITIONAL_OR_KEYWORD
215
177
)
216
178
n_either_args = len (either_argname_value_pairs )
217
179
for n_extra_args in reversed (range (n_either_args + 1 )):
218
- extra_args = [v for _ , v in either_argname_value_pairs [:n_extra_args ]]
180
+ extra_posargs = [v for _ , v in either_argname_value_pairs [:n_extra_args ]]
219
181
if n_extra_args < n_either_args :
220
182
extra_kwargs = dict (either_argname_value_pairs [n_extra_args :])
221
183
else :
222
184
extra_kwargs = {}
223
- args = list (example_values [Parameter .POSITIONAL_ONLY ].values ())
224
- args += extra_args
225
- kwargs = copy (example_values [Parameter .KEYWORD_ONLY ])
226
- if len (extra_kwargs ) != 0 :
227
- kwargs .update (extra_kwargs )
185
+ args = copy (posonly )
186
+ args += extra_posargs
187
+ kwargs = {** kwonly , ** extra_kwargs }
228
188
f_func = make_pretty_func (func_name , args , kwargs )
229
- note (f"trying { f_func } " )
189
+ print (f"trying { f_func } " )
230
190
func (* args , ** kwargs )
231
191
232
192
@@ -279,11 +239,9 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
279
239
280
240
281
241
@pytest .mark .parametrize ("stub" , array_methods , ids = lambda f : f .__name__ )
282
- @given (data = st .data ())
283
- def test_array_method_signature (stub : FunctionType , data ):
242
+ def test_array_method_signature (stub : FunctionType ):
284
243
dtypes = get_dtypes_strategy (stub .__name__ )
285
- shapes = func_to_shapes [stub .__name__ ]
286
- x = data .draw (xps .arrays (dtype = dtypes , shape = shapes ), label = "x" )
244
+ x = xps .arrays (dtype = dtypes , shape = hh .shapes (min_side = 1 )).example ()
287
245
assert hasattr (x , stub .__name__ ), f"{ stub .__name__ } not found in array object { x !r} "
288
246
method = getattr (x , stub .__name__ )
289
247
# Ignore 'self' arg in stub, which won't be present in instantiated objects.
0 commit comments