24
24
dtypes_to_scalars ,
25
25
func_in_categories ,
26
26
func_out_categories ,
27
- binary_operators ,
28
- unary_operators ,
29
- operators_to_functions ,
27
+ binary_op_to_symbol ,
28
+ unary_op_to_symbol ,
29
+ op_to_func ,
30
30
)
31
31
32
32
33
33
def generate_params (
34
+ func_family : Literal ['elementwise' , 'operator' ],
34
35
in_nargs : int ,
35
36
out_category : Literal ['bool' , 'promoted' ],
36
37
) -> Iterator :
37
- funcs = [
38
- f for f in elementwise_functions .__all__
39
- if nargs (f ) == in_nargs and func_out_categories [f ] == out_category
40
- ]
41
- if in_nargs == 1 :
42
- for func in funcs :
43
- in_category = func_in_categories [func ]
44
- for in_dtype in category_to_dtypes [in_category ]:
45
- yield pytest .param (func , in_dtype , id = f"{ func } ({ in_dtype } )" )
38
+ if func_family == 'elementwise' :
39
+ funcs = [
40
+ f for f in elementwise_functions .__all__
41
+ if nargs (f ) == in_nargs and func_out_categories [f ] == out_category
42
+ ]
43
+ if in_nargs == 1 :
44
+ for func in funcs :
45
+ in_category = func_in_categories [func ]
46
+ for in_dtype in category_to_dtypes [in_category ]:
47
+ yield pytest .param (func , in_dtype , id = f"{ func } ({ in_dtype } )" )
48
+ else :
49
+ for func , ((d1 , d2 ), d3 ) in product (funcs , promotion_table .items ()):
50
+ if all (d in category_to_dtypes [func_in_categories [func ]] for d in (d1 , d2 )):
51
+ if out_category == 'bool' :
52
+ yield pytest .param (func , (d1 , d2 ), id = f"{ func } ({ d1 } , { d2 } )" )
53
+ else :
54
+ yield pytest .param (func , ((d1 , d2 ), d3 ), id = f"{ func } ({ d1 } , { d2 } ) -> { d3 } " )
46
55
else :
47
- for func , ((d1 , d2 ), d3 ) in product (funcs , promotion_table .items ()):
48
- if all (d in category_to_dtypes [func_in_categories [func ]] for d in (d1 , d2 )):
49
- if out_category == 'bool' :
50
- yield pytest .param (func , (d1 , d2 ), id = f"{ func } ({ d1 } , { d2 } )" )
51
- else :
52
- yield pytest .param (func , ((d1 , d2 ), d3 ), id = f"{ func } ({ d1 } , { d2 } ) -> { d3 } " )
56
+ if in_nargs == 1 :
57
+ for op , symbol in unary_op_to_symbol .items ():
58
+ func = op_to_func [op ]
59
+ if func_out_categories [func ] == out_category :
60
+ in_category = func_in_categories [func ]
61
+ for in_dtype in category_to_dtypes [in_category ]:
62
+ yield pytest .param (op , symbol , in_dtype , id = f"{ op } ({ in_dtype } )" )
63
+ else :
64
+ for op , symbol in binary_op_to_symbol .items ():
65
+ if op == "__matmul__" :
66
+ continue
67
+ func = op_to_func [op ]
68
+ if func_out_categories [func ] == out_category :
69
+ in_category = func_in_categories [func ]
70
+ for ((d1 , d2 ), d3 ) in promotion_table .items ():
71
+ if all (d in category_to_dtypes [in_category ] for d in (d1 , d2 )):
72
+ if out_category == 'bool' :
73
+ yield pytest .param (op , symbol , (d1 , d2 ), id = f"{ op } ({ d1 } , { d2 } )" )
74
+ else :
75
+ if d1 == d3 :
76
+ yield pytest .param (op , symbol , ((d1 , d2 ), d3 ), id = f"{ op } ({ d1 } , { d2 } ) -> { d3 } " )
77
+
53
78
54
79
55
80
# TODO: These functions should still do type promotion internally, but we do
@@ -59,7 +84,7 @@ def generate_params(
59
84
# array(1.00000001, dtype=float64)) will be wrong if the float64 array is
60
85
# downcast to float32. See for instance
61
86
# https://github.com/numpy/numpy/issues/10322.
62
- @pytest .mark .parametrize ('func, dtypes' , generate_params (in_nargs = 2 , out_category = 'bool' ))
87
+ @pytest .mark .parametrize ('func, dtypes' , generate_params ('elementwise' , in_nargs = 2 , out_category = 'bool' ))
63
88
# The spec explicitly requires type promotion to work for shape 0
64
89
# Unfortunately, data(), isn't compatible with @example, so this is commented
65
90
# out for now.
@@ -91,7 +116,7 @@ def test_elementwise_two_args_bool_type_promotion(func, two_shapes, dtypes, data
91
116
92
117
# TODO: Extend this to all functions (not just elementwise), and handle
93
118
# functions that take more than 2 args
94
- @pytest .mark .parametrize ('func, dtypes' , generate_params (in_nargs = 2 , out_category = 'promoted' ))
119
+ @pytest .mark .parametrize ('func, dtypes' , generate_params ('elementwise' , in_nargs = 2 , out_category = 'promoted' ))
95
120
# The spec explicitly requires type promotion to work for shape 0
96
121
# Unfortunately, data(), isn't compatible with @example, so this is commented
97
122
# out for now.
@@ -124,7 +149,7 @@ def test_elementwise_two_args_promoted_type_promotion(func,
124
149
125
150
# TODO: Extend this to all functions (not just elementwise), and handle
126
151
# functions that take more than 2 args
127
- @pytest .mark .parametrize ('func, dtype' , generate_params (in_nargs = 1 , out_category = 'bool' ))
152
+ @pytest .mark .parametrize ('func, dtype' , generate_params ('elementwise' , in_nargs = 1 , out_category = 'bool' ))
128
153
# The spec explicitly requires type promotion to work for shape 0
129
154
# Unfortunately, data(), isn't compatible with @example, so this is commented
130
155
# out for now.
@@ -147,7 +172,7 @@ def test_elementwise_one_arg_bool(func, shape, dtype, data):
147
172
148
173
# TODO: Extend this to all functions (not just elementwise), and handle
149
174
# functions that take more than 2 args
150
- @pytest .mark .parametrize ('func,dtype' , generate_params (in_nargs = 1 , out_category = 'promoted' ))
175
+ @pytest .mark .parametrize ('func,dtype' , generate_params ('elementwise' , in_nargs = 1 , out_category = 'promoted' ))
151
176
# The spec explicitly requires type promotion to work for shape 0
152
177
# Unfortunately, data(), isn't compatible with @example, so this is commented
153
178
# out for now.
@@ -169,29 +194,28 @@ def test_elementwise_one_arg_type_promotion(func, shape,
169
194
170
195
assert res .dtype == dtype , f"{ func } ({ dtype } ) returned to { res .dtype } , should have promoted to { dtype } (shape={ shape } )"
171
196
172
- unary_operators_promoted = [unary_op_name for unary_op_name in sorted (unary_operators )
173
- if func_out_categories [operators_to_functions [unary_op_name ]] == 'promoted' ]
197
+ unary_operators_promoted = [unary_op_name for unary_op_name in sorted (unary_op_to_symbol )
198
+ if func_out_categories [op_to_func [unary_op_name ]] == 'promoted' ]
174
199
operator_one_arg_promoted_parametrize_inputs = [(unary_op_name , dtypes )
175
200
for unary_op_name in unary_operators_promoted
176
- for dtypes in category_to_dtypes [func_in_categories [operators_to_functions [unary_op_name ]]]
201
+ for dtypes in category_to_dtypes [func_in_categories [op_to_func [unary_op_name ]]]
177
202
]
178
203
operator_one_arg_promoted_parametrize_ids = [f"{ n } -{ d } " for n , d
179
204
in operator_one_arg_promoted_parametrize_inputs ]
180
205
181
206
182
207
# TODO: Extend this to all functions (not just elementwise), and handle
183
208
# functions that take more than 2 args
184
- @pytest .mark .parametrize ('unary_op_name,dtype' ,
185
- operator_one_arg_promoted_parametrize_inputs ,
186
- ids = operator_one_arg_promoted_parametrize_ids )
209
+ @pytest .mark .parametrize (
210
+ 'unary_op_name, unary_op, dtype' ,
211
+ generate_params ('operator' , in_nargs = 1 , out_category = 'promoted' ),
212
+ )
187
213
# The spec explicitly requires type promotion to work for shape 0
188
214
# Unfortunately, data(), isn't compatible with @example, so this is commented
189
215
# out for now.
190
216
# @example(shape=(0,))
191
217
@given (shape = shapes , data = data ())
192
- def test_operator_one_arg_type_promotion (unary_op_name , shape , dtype , data ):
193
- unary_op = unary_operators [unary_op_name ]
194
-
218
+ def test_operator_one_arg_type_promotion (unary_op_name , unary_op , shape , dtype , data ):
195
219
fillvalue = data .draw (scalars (just (dtype )))
196
220
197
221
if isinstance (dtype , _array_module ._UndefinedStub ):
@@ -211,24 +235,22 @@ def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype, data):
211
235
assert res .dtype == dtype , f"{ unary_op } ({ dtype } ) returned to { res .dtype } , should have promoted to { dtype } (shape={ shape } )"
212
236
213
237
# Note: the boolean binary operators do not have reversed or in-place variants
214
- binary_operators_bool = [binary_op_name for binary_op_name in sorted (set (binary_operators ) - {'__matmul__' })
215
- if func_out_categories [operators_to_functions [binary_op_name ]] == 'bool' ]
238
+ binary_operators_bool = [binary_op_name for binary_op_name in sorted (set (binary_op_to_symbol ) - {'__matmul__' })
239
+ if func_out_categories [op_to_func [binary_op_name ]] == 'bool' ]
216
240
operator_two_args_bool_parametrize_inputs = [(binary_op_name , dtypes )
217
241
for binary_op_name in binary_operators_bool
218
242
for dtypes in promotion_table .keys ()
219
- if all (d in category_to_dtypes [func_in_categories [operators_to_functions [binary_op_name ]]] for d in dtypes )
243
+ if all (d in category_to_dtypes [func_in_categories [op_to_func [binary_op_name ]]] for d in dtypes )
220
244
]
221
245
operator_two_args_bool_parametrize_ids = [f"{ n } -{ d1 } -{ d2 } " for n , (d1 , d2 )
222
246
in operator_two_args_bool_parametrize_inputs ]
223
247
224
- @pytest .mark .parametrize ('binary_op_name,dtypes' ,
225
- operator_two_args_bool_parametrize_inputs ,
226
- ids = operator_two_args_bool_parametrize_ids )
248
+ @pytest .mark .parametrize (
249
+ 'binary_op_name, binary_op, dtypes' ,
250
+ generate_params ('operator' , in_nargs = 2 , out_category = 'bool' )
251
+ )
227
252
@given (two_shapes = two_mutually_broadcastable_shapes , data = data ())
228
- def test_operator_two_args_bool_promotion (binary_op_name , dtypes , two_shapes ,
229
- data ):
230
- binary_op = binary_operators [binary_op_name ]
231
-
253
+ def test_operator_two_args_bool_promotion (binary_op_name , binary_op , dtypes , two_shapes , data ):
232
254
dtype1 , dtype2 = dtypes
233
255
fillvalue1 = data .draw (scalars (just (dtype1 )))
234
256
fillvalue2 = data .draw (scalars (just (dtype2 )))
@@ -247,24 +269,19 @@ def test_operator_two_args_bool_promotion(binary_op_name, dtypes, two_shapes,
247
269
248
270
assert res .dtype == bool_dtype , f"{ dtype1 } { binary_op } { dtype2 } promoted to { res .dtype } , should have promoted to bool (shape={ shape1 , shape2 } )"
249
271
250
- binary_operators_promoted = [binary_op_name for binary_op_name in sorted (set (binary_operators ) - {'__matmul__' })
251
- if func_out_categories [operators_to_functions [binary_op_name ]] == 'promoted' ]
272
+ binary_operators_promoted = [binary_op_name for binary_op_name in sorted (set (binary_op_to_symbol ) - {'__matmul__' })
273
+ if func_out_categories [op_to_func [binary_op_name ]] == 'promoted' ]
252
274
operator_two_args_promoted_parametrize_inputs = [(binary_op_name , dtypes )
253
275
for binary_op_name in binary_operators_promoted
254
276
for dtypes in promotion_table .items ()
255
- if all (d in category_to_dtypes [func_in_categories [operators_to_functions [binary_op_name ]]] for d in dtypes [0 ])
277
+ if all (d in category_to_dtypes [func_in_categories [op_to_func [binary_op_name ]]] for d in dtypes [0 ])
256
278
]
257
279
operator_two_args_promoted_parametrize_ids = [f"{ n } -{ d1 } -{ d2 } " for n , ((d1 , d2 ), _ )
258
280
in operator_two_args_promoted_parametrize_inputs ]
259
281
260
- @pytest .mark .parametrize ('binary_op_name,dtypes' ,
261
- operator_two_args_promoted_parametrize_inputs ,
262
- ids = operator_two_args_promoted_parametrize_ids )
282
+ @pytest .mark .parametrize ('binary_op_name, binary_op, dtypes' , generate_params ('operator' , in_nargs = 2 , out_category = 'promoted' ))
263
283
@given (two_shapes = two_mutually_broadcastable_shapes , data = data ())
264
- def test_operator_two_args_promoted_promotion (binary_op_name , dtypes , two_shapes ,
265
- data ):
266
- binary_op = binary_operators [binary_op_name ]
267
-
284
+ def test_operator_two_args_promoted_promotion (binary_op_name , binary_op , dtypes , two_shapes , data ):
268
285
(dtype1 , dtype2 ), res_dtype = dtypes
269
286
fillvalue1 = data .draw (scalars (just (dtype1 )))
270
287
if binary_op_name in ['>>' , '<<' ]:
@@ -292,14 +309,10 @@ def test_operator_two_args_promoted_promotion(binary_op_name, dtypes, two_shapes
292
309
operator_inplace_two_args_promoted_parametrize_ids = ['-' .join ((n [:2 ] + 'i' + n [2 :], str (d1 ), str (d2 ))) for n , ((d1 , d2 ), _ )
293
310
in operator_inplace_two_args_promoted_parametrize_inputs ]
294
311
295
- @pytest .mark .parametrize ('binary_op_name,dtypes' ,
296
- operator_inplace_two_args_promoted_parametrize_inputs ,
297
- ids = operator_inplace_two_args_promoted_parametrize_ids )
312
+ @pytest .mark .parametrize ('binary_op_name, binary_op, dtypes' , generate_params ('operator' , in_nargs = 2 , out_category = 'promoted' ))
298
313
@given (two_shapes = two_broadcastable_shapes (), data = data ())
299
- def test_operator_inplace_two_args_promoted_promotion (binary_op_name , dtypes , two_shapes ,
314
+ def test_operator_inplace_two_args_promoted_promotion (binary_op_name , binary_op , dtypes , two_shapes ,
300
315
data ):
301
- binary_op = binary_operators [binary_op_name ]
302
-
303
316
(dtype1 , dtype2 ), res_dtype = dtypes
304
317
fillvalue1 = data .draw (scalars (just (dtype1 )))
305
318
if binary_op_name in ['>>' , '<<' ]:
@@ -326,8 +339,8 @@ def test_operator_inplace_two_args_promoted_promotion(binary_op_name, dtypes, tw
326
339
327
340
scalar_promotion_parametrize_inputs = [
328
341
pytest .param (binary_op_name , dtype , scalar_type , id = f"{ binary_op_name } -{ dtype } -{ scalar_type .__name__ } " )
329
- for binary_op_name in sorted (set (binary_operators ) - {'__matmul__' })
330
- for dtype in category_to_dtypes [func_in_categories [operators_to_functions [binary_op_name ]]]
342
+ for binary_op_name in sorted (set (binary_op_to_symbol ) - {'__matmul__' })
343
+ for dtype in category_to_dtypes [func_in_categories [op_to_func [binary_op_name ]]]
331
344
for scalar_type in dtypes_to_scalars [dtype ]
332
345
]
333
346
@@ -339,7 +352,7 @@ def test_operator_scalar_promotion(binary_op_name, dtype, scalar_type,
339
352
"""
340
353
See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
341
354
"""
342
- binary_op = binary_operators [binary_op_name ]
355
+ binary_op = binary_op_to_symbol [binary_op_name ]
343
356
if binary_op == '@' :
344
357
pytest .skip ("matmul (@) is not supported for scalars" )
345
358
0 commit comments