1
+ """
2
+ We're not interested in being 100% strict - instead we focus on areas which
3
+ could affect interop, e.g. with
4
+
5
+ def add(x1, x2, /):
6
+ ...
7
+
8
+ x1 and x2 don't need to be pos-only for the purposes of interoperability, but with
9
+
10
+ def squeeze(x, /, axis):
11
+ ...
12
+
13
+ axis has to be pos-or-keyword to support both styles
14
+
15
+ >>> squeeze(x, 0)
16
+ ...
17
+ >>> squeeze(x, axis=0)
18
+ ...
19
+
20
+ """
21
+ from collections import defaultdict
1
22
from inspect import Parameter , Signature , signature
23
+ from itertools import chain
2
24
from types import FunctionType
3
- from typing import Callable , Dict
25
+ from typing import Callable , DefaultDict , Dict , List
4
26
5
27
import pytest
6
28
from hypothesis import given
29
+ from hypothesis import strategies as st
7
30
31
+ from . import dtype_helpers as dh
8
32
from . import hypothesis_helpers as hh
9
33
from . import xps
34
+ from ._array_module import _UndefinedStub
10
35
from ._array_module import mod as xp
11
36
from .stubs import array_methods , category_to_funcs , extension_to_funcs
37
+ from .typing import DataType , Shape
12
38
13
39
pytestmark = pytest .mark .ci
14
40
41
+
15
42
kind_to_str : Dict [Parameter , str ] = {
16
43
Parameter .POSITIONAL_OR_KEYWORD : "normal argument" ,
17
44
Parameter .POSITIONAL_ONLY : "pos-only argument" ,
20
47
Parameter .VAR_KEYWORD : "star-kwargs (i.e. **kwargs) argument" ,
21
48
}
22
49
50
+ VAR_KINDS = (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD )
23
51
24
- def _test_signature (
25
- func : Callable , stub : FunctionType , ignore_first_stub_param : bool = False
26
- ):
27
- """
28
- Signature of function is correct enough to not affect interoperability
29
-
30
- We're not interested in being 100% strict - instead we focus on areas which
31
- could affect interop, e.g. with
32
-
33
- def add(x1, x2, /):
34
- ...
35
52
36
- x1 and x2 don't need to be pos-only for the purposes of interoperability, but with
37
-
38
- def squeeze(x, /, axis):
39
- ...
40
-
41
- axis has to be pos-or-keyword to support both styles
42
-
43
- >>> squeeze(x, 0)
44
- ...
45
- >>> squeeze(x, axis=0)
46
- ...
47
-
48
- """
49
- try :
50
- sig = signature (func )
51
- except ValueError :
52
- pytest .skip (
53
- msg = f"type({ stub .__name__ } )={ type (func )} not supported by inspect.signature()"
54
- )
53
+ def _test_inspectable_func (sig : Signature , stub_sig : Signature ):
55
54
params = list (sig .parameters .values ())
56
-
57
- stub_sig = signature (stub )
58
55
stub_params = list (stub_sig .parameters .values ())
59
- if ignore_first_stub_param :
60
- stub_params = stub_params [1 :]
61
- stub = Signature (
62
- parameters = stub_params , return_annotation = stub_sig .return_annotation
63
- )
64
-
65
56
# We're not interested if the array module has additional arguments, so we
66
57
# only iterate through the arguments listed in the spec.
67
58
for i , stub_param in enumerate (stub_params ):
68
- assert (
69
- len (params ) >= i + 1
70
- ), f"Argument '{ stub_param .name } ' missing from signature"
71
- param = params [i ]
59
+ if sig is not None :
60
+ assert (
61
+ len (params ) >= i + 1
62
+ ), f"Argument '{ stub_param .name } ' missing from signature"
63
+ param = params [i ]
72
64
73
65
# We're not interested in the name if it isn't actually used
74
- if stub_param .kind not in [
66
+ if sig is not None and stub_param .kind not in [
75
67
Parameter .POSITIONAL_ONLY ,
76
- Parameter .VAR_POSITIONAL ,
77
- Parameter .VAR_KEYWORD ,
68
+ * VAR_KINDS ,
78
69
]:
79
70
assert (
80
71
param .name == stub_param .name
81
72
), f"Expected argument '{ param .name } ' to be named '{ stub_param .name } '"
82
73
83
- if (
84
- stub_param .name in ["x" , "x1" , "x2" ]
85
- and stub_param .kind != Parameter .POSITIONAL_ONLY
86
- ):
87
- pytest .skip (
88
- f"faulty spec - argument { stub_param .name } should be a "
89
- f"{ kind_to_str [Parameter .POSITIONAL_ONLY ]} "
90
- )
91
- f_kind = kind_to_str [param .kind ]
92
74
f_stub_kind = kind_to_str [stub_param .kind ]
93
- if stub_param .kind in [
94
- Parameter . POSITIONAL_OR_KEYWORD ,
95
- Parameter . VAR_POSITIONAL ,
96
- Parameter . VAR_KEYWORD ,
97
- ]:
98
- assert (
99
- param . kind == stub_param . kind
100
- ), f" { param . name } is a { f_kind } , but should be a { f_stub_kind } "
75
+ if stub_param .kind in [Parameter . POSITIONAL_OR_KEYWORD , * VAR_KINDS ]:
76
+ if sig is not None :
77
+ assert param . kind == stub_param . kind , (
78
+ f" { param . name } is a { kind_to_str [ param . kind ] } , "
79
+ f"but should be a { f_stub_kind } "
80
+ )
81
+ else :
82
+ pass
101
83
else :
102
84
# TODO: allow for kw-only args to be out-of-order
103
- assert param .kind in [stub_param .kind , Parameter .POSITIONAL_OR_KEYWORD ], (
104
- f"{ param .name } is a { f_kind } , "
105
- f"but should be a { f_stub_kind } "
106
- f"(or at least a { kind_to_str [Parameter .POSITIONAL_OR_KEYWORD ]} )"
85
+ if sig is not None :
86
+ assert param .kind in [
87
+ stub_param .kind ,
88
+ Parameter .POSITIONAL_OR_KEYWORD ,
89
+ ], (
90
+ f"{ param .name } is a { kind_to_str [param .kind ]} , "
91
+ f"but should be a { f_stub_kind } "
92
+ f"(or at least a { kind_to_str [Parameter .POSITIONAL_OR_KEYWORD ]} )"
93
+ )
94
+ else :
95
+ pass
96
+
97
+ def shapes (** kw ) -> st .SearchStrategy [Shape ]:
98
+ if "min_side" not in kw .keys ():
99
+ kw ["min_side" ] = 1
100
+ return hh .shapes (** kw )
101
+
102
+
103
+ matrixy_funcs : List [str ] = [
104
+ f .__name__
105
+ for f in chain (category_to_funcs ["linear_algebra" ], extension_to_funcs ["linalg" ])
106
+ ]
107
+ matrixy_funcs += ["__matmul__" , "triu" , "tril" ]
108
+ func_to_shapes : DefaultDict [str , st .SearchStrategy [Shape ]] = defaultdict (
109
+ shapes ,
110
+ {
111
+ ** {k : st .just (()) for k in ["__bool__" , "__int__" , "__index__" , "__float__" ]},
112
+ "sort" : shapes (min_dims = 1 ), # for axis=-1,
113
+ ** {k : shapes (min_dims = 2 ) for k in matrixy_funcs },
114
+ # Override for some matrixy functions
115
+ "cross" : shapes (min_side = 3 , max_side = 3 , min_dims = 3 , max_dims = 3 ),
116
+ "outer" : shapes (min_dims = 1 , max_dims = 1 ),
117
+ },
118
+ )
119
+
120
+
121
+ def get_dtypes_strategy (func_name : str ) -> st .SearchStrategy [DataType ]:
122
+ if func_name in dh .func_in_dtypes .keys ():
123
+ dtypes = dh .func_in_dtypes [func_name ]
124
+ if hh .FILTER_UNDEFINED_DTYPES :
125
+ dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
126
+ return st .sampled_from (dtypes )
127
+ else :
128
+ return xps .scalar_dtypes ()
129
+
130
+
131
+ @given (data = st .data ())
132
+ def _test_uninspectable_func (func_name : str , func : Callable , stub_sig : Signature , data ):
133
+ if func_name in ["cholesky" , "inv" ]:
134
+ func (xp .asarray ([[1.0 , 0.0 ], [0.0 , 1.0 ]]))
135
+ return
136
+ elif func_name == "solve" :
137
+ func (xp .asarray ([[1.0 , 2.0 ], [3.0 , 5.0 ]]), xp .asarray ([1.0 , 2.0 ]))
138
+ return
139
+
140
+ pos_argname_to_example_value = {}
141
+ normal_argname_to_example_value = {}
142
+ kw_argname_to_example_value = {}
143
+ for stub_param in stub_sig .parameters .values ():
144
+ if stub_param .name in ["x" , "x1" ]:
145
+ dtypes = get_dtypes_strategy (func_name )
146
+ shapes = func_to_shapes [func_name ]
147
+ example_value = data .draw (
148
+ xps .arrays (dtype = dtypes , shape = shapes ), label = stub_param .name
107
149
)
150
+ elif stub_param .name == "x2" :
151
+ assert "x1" in pos_argname_to_example_value .keys () # sanity check
152
+ x1 = pos_argname_to_example_value ["x1" ]
153
+ example_value = data .draw (
154
+ xps .arrays (dtype = x1 .dtype , shape = x1 .shape ), label = "x2"
155
+ )
156
+ else :
157
+ if stub_param .default != Parameter .empty :
158
+ example_value = stub_param .default
159
+ else :
160
+ pytest .skip (f"No example value for argument '{ stub_param .name } '" )
161
+
162
+ if stub_param .kind == Parameter .POSITIONAL_ONLY :
163
+ pos_argname_to_example_value [stub_param .name ] = example_value
164
+ elif stub_param .kind == Parameter .POSITIONAL_OR_KEYWORD :
165
+ normal_argname_to_example_value [stub_param .name ] = example_value
166
+ elif stub_param .kind == Parameter .KEYWORD_ONLY :
167
+ kw_argname_to_example_value [stub_param .name ] = example_value
168
+ else :
169
+ pytest .skip ()
170
+
171
+ if len (normal_argname_to_example_value ) == 0 :
172
+ func (* pos_argname_to_example_value .values (), ** kw_argname_to_example_value )
173
+ else :
174
+ pass # TODO
175
+
176
+
177
+ def _test_func_signature (
178
+ func : Callable , stub : FunctionType , ignore_first_stub_param : bool = False
179
+ ):
180
+ stub_sig = signature (stub )
181
+ if ignore_first_stub_param :
182
+ stub_params = list (stub_sig .parameters .values ())
183
+ del stub_params [0 ]
184
+ stub_sig = Signature (
185
+ parameters = stub_params , return_annotation = stub_sig .return_annotation
186
+ )
187
+
188
+ try :
189
+ sig = signature (func )
190
+ _test_inspectable_func (sig , stub_sig )
191
+ except ValueError :
192
+ _test_uninspectable_func (stub .__name__ , func , stub_sig )
108
193
109
194
110
195
@pytest .mark .parametrize (
@@ -115,7 +200,7 @@ def squeeze(x, /, axis):
115
200
def test_func_signature (stub : FunctionType ):
116
201
assert hasattr (xp , stub .__name__ ), f"{ stub .__name__ } not found in array module"
117
202
func = getattr (xp , stub .__name__ )
118
- _test_signature (func , stub )
203
+ _test_func_signature (func , stub )
119
204
120
205
121
206
extension_and_stub_params = []
@@ -134,13 +219,16 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
134
219
mod , stub .__name__
135
220
), f"{ stub .__name__ } not found in { extension } extension"
136
221
func = getattr (mod , stub .__name__ )
137
- _test_signature (func , stub )
222
+ _test_func_signature (func , stub )
138
223
139
224
140
225
@pytest .mark .parametrize ("stub" , array_methods , ids = lambda f : f .__name__ )
141
- @given (x = xps .arrays (dtype = xps .scalar_dtypes (), shape = hh .shapes ()))
142
- def test_array_method_signature (stub : FunctionType , x ):
226
+ @given (data = st .data ())
227
+ def test_array_method_signature (stub : FunctionType , data ):
228
+ dtypes = get_dtypes_strategy (stub .__name__ )
229
+ shapes = func_to_shapes [stub .__name__ ]
230
+ x = data .draw (xps .arrays (dtype = dtypes , shape = shapes ), label = "x" )
143
231
assert hasattr (x , stub .__name__ ), f"{ stub .__name__ } not found in array object { x !r} "
144
232
method = getattr (x , stub .__name__ )
145
233
# Ignore 'self' arg in stub, which won't be present in instantiated objects.
146
- _test_signature (method , stub , ignore_first_stub_param = True )
234
+ _test_func_signature (method , stub , ignore_first_stub_param = True )
0 commit comments