@@ -19,13 +19,14 @@ def squeeze(x, /, axis):
19
19
20
20
"""
21
21
from collections import defaultdict
22
+ from copy import copy
22
23
from inspect import Parameter , Signature , signature
23
24
from itertools import chain
24
25
from types import FunctionType
25
- from typing import Callable , DefaultDict , Dict , List
26
+ from typing import Any , Callable , DefaultDict , Dict , List , Literal , Sequence , get_args
26
27
27
28
import pytest
28
- from hypothesis import given
29
+ from hypothesis import given , note
29
30
from hypothesis import strategies as st
30
31
31
32
from . import dtype_helpers as dh
@@ -38,17 +39,23 @@ def squeeze(x, /, axis):
38
39
39
40
pytestmark = pytest .mark .ci
40
41
41
-
42
- kind_to_str : Dict [Parameter , str ] = {
42
+ ParameterKind = Literal [
43
+ Parameter .POSITIONAL_ONLY ,
44
+ Parameter .VAR_POSITIONAL ,
45
+ Parameter .POSITIONAL_OR_KEYWORD ,
46
+ Parameter .KEYWORD_ONLY ,
47
+ Parameter .VAR_KEYWORD ,
48
+ ]
49
+ ALL_KINDS = get_args (ParameterKind )
50
+ VAR_KINDS = (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD )
51
+ kind_to_str : Dict [ParameterKind , str ] = {
43
52
Parameter .POSITIONAL_OR_KEYWORD : "normal argument" ,
44
53
Parameter .POSITIONAL_ONLY : "pos-only argument" ,
45
54
Parameter .KEYWORD_ONLY : "keyword-only argument" ,
46
55
Parameter .VAR_POSITIONAL : "star-args (i.e. *args) argument" ,
47
56
Parameter .VAR_KEYWORD : "star-kwargs (i.e. **kwargs) argument" ,
48
57
}
49
58
50
- VAR_KINDS = (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD )
51
-
52
59
53
60
def _test_inspectable_func (sig : Signature , stub_sig : Signature ):
54
61
params = list (sig .parameters .values ())
@@ -89,11 +96,12 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
89
96
], (
90
97
f"{ param .name } is a { kind_to_str [param .kind ]} , "
91
98
f"but should be a { f_stub_kind } "
92
- f"(or at least a { kind_to_str [Parameter .POSITIONAL_OR_KEYWORD ]} )"
99
+ f"(or at least a { kind_to_str [ParameterKind .POSITIONAL_OR_KEYWORD ]} )"
93
100
)
94
101
else :
95
102
pass
96
103
104
+
97
105
def shapes (** kw ) -> st .SearchStrategy [Shape ]:
98
106
if "min_side" not in kw .keys ():
99
107
kw ["min_side" ] = 1
@@ -111,7 +119,7 @@ def shapes(**kw) -> st.SearchStrategy[Shape]:
111
119
** {k : st .just (()) for k in ["__bool__" , "__int__" , "__index__" , "__float__" ]},
112
120
"sort" : shapes (min_dims = 1 ), # for axis=-1,
113
121
** {k : shapes (min_dims = 2 ) for k in matrixy_funcs },
114
- # Override for some matrixy functions
122
+ # Overwrite min_dims=2 shapes for some matrixy functions
115
123
"cross" : shapes (min_side = 3 , max_side = 3 , min_dims = 3 , max_dims = 3 ),
116
124
"outer" : shapes (min_dims = 1 , max_dims = 1 ),
117
125
},
@@ -128,50 +136,98 @@ def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
128
136
return xps .scalar_dtypes ()
129
137
130
138
139
+ func_to_example_values : Dict [str , Dict [ParameterKind , Dict [str , Any ]]] = {
140
+ "broadcast_to" : {
141
+ Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([0 , 1 ])},
142
+ Parameter .POSITIONAL_OR_KEYWORD : {"shape" : (1 , 2 )},
143
+ },
144
+ "cholesky" : {
145
+ Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]])}
146
+ },
147
+ "inv" : {Parameter .POSITIONAL_ONLY : {"x" : xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]])}},
148
+ }
149
+
150
+
151
+ def make_pretty_func (func_name : str , args : Sequence [Any ], kwargs : Dict [str , Any ]):
152
+ f_sig = f"{ func_name } ("
153
+ f_sig += ", " .join (str (a ) for a in args )
154
+ if len (kwargs ) != 0 :
155
+ if len (args ) != 0 :
156
+ f_sig += ", "
157
+ f_sig += ", " .join (f"{ k } ={ v } " for k , v in kwargs .items ())
158
+ f_sig += ")"
159
+ return f_sig
160
+
161
+
131
162
@given (data = st .data ())
132
163
def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature , data ):
133
- if func_name in ["cholesky" , "inv" ]:
134
- func (xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]]))
135
- return
136
- elif func_name == "solve" :
137
- func (xp .asarray ([[1.0 , 2.0 ], [3.0 , 5.0 ]]), xp .asarray ([1.0 , 2.0 ]))
138
- return
139
-
140
- pos_argname_to_example_value = {}
141
- normal_argname_to_example_value = {}
142
- kw_argname_to_example_value = {}
143
- for stub_param in stub_sig .parameters .values ():
144
- if stub_param .name in ["x" , "x1" ]:
164
+ example_values : Dict [ParameterKind , Dict [str , Any ]] = func_to_example_values .get (
165
+ func_name , {}
166
+ )
167
+ for kind in ALL_KINDS :
168
+ example_values .setdefault (kind , {})
169
+
170
+ for param in stub_sig .parameters .values ():
171
+ for name_to_value in example_values .values ():
172
+ if param .name in name_to_value .keys ():
173
+ continue
174
+
175
+ if param .default != Parameter .empty :
176
+ example_value = param .default
177
+ elif param .name in ["x" , "x1" ]:
145
178
dtypes = get_dtypes_strategy (func_name )
146
179
shapes = func_to_shapes [func_name ]
147
180
example_value = data .draw (
148
- xps .arrays (dtype = dtypes , shape = shapes ), label = stub_param .name
181
+ xps .arrays (dtype = dtypes , shape = shapes ), label = param .name
149
182
)
150
- elif stub_param .name == "x2" :
151
- assert "x1" in pos_argname_to_example_value .keys () # sanity check
152
- x1 = pos_argname_to_example_value ["x1" ]
183
+ elif param .name == "x2" :
184
+ # sanity check
185
+ assert "x1" in example_values [Parameter .POSITIONAL_ONLY ].keys ()
186
+ x1 = example_values [Parameter .POSITIONAL_ONLY ]["x1" ]
153
187
example_value = data .draw (
154
188
xps .arrays (dtype = x1 .dtype , shape = x1 .shape ), label = "x2"
155
189
)
190
+ elif param .name == "axes" :
191
+ example_value = ()
192
+ elif param .name == "shape" :
193
+ example_value = ()
156
194
else :
157
- if stub_param .default != Parameter .empty :
158
- example_value = stub_param .default
159
- else :
160
- pytest .skip (f"No example value for argument '{ stub_param .name } '" )
161
-
162
- if stub_param .kind == Parameter .POSITIONAL_ONLY :
163
- pos_argname_to_example_value [stub_param .name ] = example_value
164
- elif stub_param .kind == Parameter .POSITIONAL_OR_KEYWORD :
165
- normal_argname_to_example_value [stub_param .name ] = example_value
166
- elif stub_param .kind == Parameter .KEYWORD_ONLY :
167
- kw_argname_to_example_value [stub_param .name ] = example_value
168
- else :
169
- pytest .skip ()
195
+ pytest .skip (f"No example value for argument '{ param .name } '" )
170
196
171
- if len (normal_argname_to_example_value ) == 0 :
172
- func (* pos_argname_to_example_value .values (), ** kw_argname_to_example_value )
197
+ if param .kind in VAR_KINDS :
198
+ pytest .skip ("TODO" )
199
+ example_values [param .kind ][param .name ] = example_value
200
+
201
+ if len (example_values [Parameter .POSITIONAL_OR_KEYWORD ]) == 0 :
202
+ f_func = make_pretty_func (
203
+ func_name ,
204
+ example_values [Parameter .POSITIONAL_ONLY ].values (),
205
+ example_values [Parameter .KEYWORD_ONLY ],
206
+ )
207
+ note (f"trying { f_func } " )
208
+ func (
209
+ * example_values [Parameter .POSITIONAL_ONLY ].values (),
210
+ ** example_values [Parameter .KEYWORD_ONLY ],
211
+ )
173
212
else :
174
- pass # TODO
213
+ either_argname_value_pairs = list (
214
+ example_values [Parameter .POSITIONAL_OR_KEYWORD ].items ()
215
+ )
216
+ n_either_args = len (either_argname_value_pairs )
217
+ for n_extra_args in reversed (range (n_either_args + 1 )):
218
+ extra_args = [v for _ , v in either_argname_value_pairs [:n_extra_args ]]
219
+ if n_extra_args < n_either_args :
220
+ extra_kwargs = dict (either_argname_value_pairs [n_extra_args :])
221
+ else :
222
+ extra_kwargs = {}
223
+ args = list (example_values [Parameter .POSITIONAL_ONLY ].values ())
224
+ args += extra_args
225
+ kwargs = copy (example_values [Parameter .KEYWORD_ONLY ])
226
+ if len (extra_kwargs ) != 0 :
227
+ kwargs .update (extra_kwargs )
228
+ f_func = make_pretty_func (func_name , args , kwargs )
229
+ note (f"trying { f_func } " )
230
+ func (* args , ** kwargs )
175
231
176
232
177
233
def _test_func_signature (
0 commit comments