@@ -42,26 +42,26 @@ def generate_params(
42
42
yield pytest .param (func , ((d1 , d2 ), d3 ), id = f"{ func } ({ d1 } , { d2 } ) -> { d3 } " )
43
43
else :
44
44
if in_nargs == 1 :
45
- for op , symbol in dh .unary_op_to_symbol .items ():
46
- func = dh .op_to_func [op ]
45
+ for func , op in dh .unary_func_to_op .items ():
46
+ if func == "__matmul__" :
47
+ continue
47
48
if dh .func_out_categories [func ] == out_category :
48
49
in_category = dh .func_in_categories [func ]
49
50
for in_dtype in dh .category_to_dtypes [in_category ]:
50
- yield pytest .param (op , symbol , in_dtype , id = f"{ op } ({ in_dtype } )" )
51
+ yield pytest .param (func , op , in_dtype , id = f"{ func } ({ in_dtype } )" )
51
52
else :
52
- for op , symbol in dh .binary_op_to_symbol .items ():
53
- if op == "__matmul__" :
53
+ for func , op in dh .binary_func_to_op .items ():
54
+ if func == "__matmul__" :
54
55
continue
55
- func = dh .op_to_func [op ]
56
56
if dh .func_out_categories [func ] == out_category :
57
57
in_category = dh .func_in_categories [func ]
58
58
for ((d1 , d2 ), d3 ) in dh .promotion_table .items ():
59
59
if all (d in dh .category_to_dtypes [in_category ] for d in (d1 , d2 )):
60
60
if out_category == 'bool' :
61
- yield pytest .param (op , symbol , (d1 , d2 ), id = f"{ op } ({ d1 } , { d2 } )" )
61
+ yield pytest .param (func , op , (d1 , d2 ), id = f"{ func } ({ d1 } , { d2 } )" )
62
62
else :
63
63
if d1 == d3 :
64
- yield pytest .param (op , symbol , ((d1 , d2 ), d3 ), id = f"{ op } ({ d1 } , { d2 } ) -> { d3 } " )
64
+ yield pytest .param (func , op , ((d1 , d2 ), d3 ), id = f"{ func } ({ d1 } , { d2 } ) -> { d3 } " )
65
65
66
66
67
67
@@ -214,11 +214,11 @@ def test_operator_one_arg_return_promoted(unary_op_name, unary_op, shape, dtype,
214
214
assert res .dtype == dtype , f"{ unary_op } ({ dtype } ) returned to { res .dtype } , should have promoted to { dtype } (shape={ shape } )"
215
215
216
216
@pytest .mark .parametrize (
217
- 'binary_op_name, binary_op , dtypes' ,
217
+ 'func, op , dtypes' ,
218
218
generate_params ('operator' , in_nargs = 2 , out_category = 'bool' )
219
219
)
220
220
@given (two_shapes = hh .two_mutually_broadcastable_shapes , data = st .data ())
221
- def test_operator_two_args_return_bool (binary_op_name , binary_op , dtypes , two_shapes , data ):
221
+ def test_operator_two_args_return_bool (func , op , dtypes , two_shapes , data ):
222
222
dtype1 , dtype2 = dtypes
223
223
fillvalue1 = data .draw (hh .scalars (st .just (dtype1 )))
224
224
fillvalue2 = data .draw (hh .scalars (st .just (dtype2 )))
@@ -232,27 +232,17 @@ def test_operator_two_args_return_bool(binary_op_name, binary_op, dtypes, two_sh
232
232
a2 = ah .full (shape2 , fillvalue2 , dtype = dtype2 )
233
233
234
234
get_locals = lambda : dict (a1 = a1 , a2 = a2 )
235
- expression = f'a1 { binary_op } a2'
235
+ expression = f'a1 { op } a2'
236
236
res = eval (expression , get_locals ())
237
237
238
- assert res .dtype == xp .bool , f"{ dtype1 } { binary_op } { dtype2 } promoted to { res .dtype } , should have promoted to bool (shape={ shape1 , shape2 } )"
239
-
240
- binary_operators_promoted = [binary_op_name for binary_op_name in sorted (set (dh .binary_op_to_symbol ) - {'__matmul__' })
241
- if dh .func_out_categories [dh .op_to_func [binary_op_name ]] == 'promoted' ]
242
- operator_two_args_promoted_parametrize_inputs = [(binary_op_name , dtypes )
243
- for binary_op_name in binary_operators_promoted
244
- for dtypes in dh .promotion_table .items ()
245
- if all (d in dh .category_to_dtypes [dh .func_in_categories [dh .op_to_func [binary_op_name ]]] for d in dtypes [0 ])
246
- ]
247
- operator_two_args_promoted_parametrize_ids = [f"{ n } -{ d1 } -{ d2 } " for n , ((d1 , d2 ), _ )
248
- in operator_two_args_promoted_parametrize_inputs ]
238
+ assert res .dtype == xp .bool , f"{ dtype1 } { op } { dtype2 } promoted to { res .dtype } , should have promoted to bool (shape={ shape1 , shape2 } )"
249
239
250
- @pytest .mark .parametrize ('binary_op_name, binary_op , dtypes' , generate_params ('operator' , in_nargs = 2 , out_category = 'promoted' ))
240
+ @pytest .mark .parametrize ('func, op , dtypes' , generate_params ('operator' , in_nargs = 2 , out_category = 'promoted' ))
251
241
@given (two_shapes = hh .two_mutually_broadcastable_shapes , data = st .data ())
252
- def test_operator_two_args_return_promoted (binary_op_name , binary_op , dtypes , two_shapes , data ):
242
+ def test_operator_two_args_return_promoted (func , op , dtypes , two_shapes , data ):
253
243
(dtype1 , dtype2 ), res_dtype = dtypes
254
244
fillvalue1 = data .draw (hh .scalars (st .just (dtype1 )))
255
- if binary_op_name in ['>>' , '<<' ]:
245
+ if op in ['>>' , '<<' ]:
256
246
fillvalue2 = data .draw (hh .scalars (st .just (dtype2 )).filter (lambda x : x > 0 ))
257
247
else :
258
248
fillvalue2 = data .draw (hh .scalars (st .just (dtype2 )))
@@ -267,23 +257,18 @@ def test_operator_two_args_return_promoted(binary_op_name, binary_op, dtypes, tw
267
257
a2 = ah .full (shape2 , fillvalue2 , dtype = dtype2 )
268
258
269
259
get_locals = lambda : dict (a1 = a1 , a2 = a2 )
270
- expression = f'a1 { binary_op } a2'
260
+ expression = f'a1 { op } a2'
271
261
res = eval (expression , get_locals ())
272
262
273
- assert res .dtype == res_dtype , f"{ dtype1 } { binary_op } { dtype2 } promoted to { res .dtype } , should have promoted to { res_dtype } (shape={ shape1 , shape2 } )"
274
-
275
- operator_inplace_two_args_promoted_parametrize_inputs = [(binary_op , dtypes ) for binary_op , dtypes in operator_two_args_promoted_parametrize_inputs
276
- if dtypes [0 ][0 ] == dtypes [1 ]]
277
- operator_inplace_two_args_promoted_parametrize_ids = ['-' .join ((n [:2 ] + 'i' + n [2 :], str (d1 ), str (d2 ))) for n , ((d1 , d2 ), _ )
278
- in operator_inplace_two_args_promoted_parametrize_inputs ]
263
+ assert res .dtype == res_dtype , f"{ dtype1 } { op } { dtype2 } promoted to { res .dtype } , should have promoted to { res_dtype } (shape={ shape1 , shape2 } )"
279
264
280
- @pytest .mark .parametrize ('binary_op_name, binary_op , dtypes' , generate_params ('operator' , in_nargs = 2 , out_category = 'promoted' ))
265
+ @pytest .mark .parametrize ('func, op , dtypes' , generate_params ('operator' , in_nargs = 2 , out_category = 'promoted' ))
281
266
@given (two_shapes = hh .two_broadcastable_shapes (), data = st .data ())
282
- def test_operator_inplace_two_args_return_promoted (binary_op_name , binary_op , dtypes , two_shapes ,
267
+ def test_operator_inplace_two_args_return_promoted (func , op , dtypes , two_shapes ,
283
268
data ):
284
269
(dtype1 , dtype2 ), res_dtype = dtypes
285
270
fillvalue1 = data .draw (hh .scalars (st .just (dtype1 )))
286
- if binary_op_name in ['>>' , '<<' ]:
271
+ if func in ['>>' , '<<' ]:
287
272
fillvalue2 = data .draw (hh .scalars (st .just (dtype2 )).filter (lambda x : x > 0 ))
288
273
else :
289
274
fillvalue2 = data .draw (hh .scalars (st .just (dtype2 )))
@@ -299,29 +284,29 @@ def test_operator_inplace_two_args_return_promoted(binary_op_name, binary_op, dt
299
284
get_locals = lambda : dict (a1 = a1 , a2 = a2 )
300
285
301
286
res_locals = get_locals ()
302
- expression = f'a1 { binary_op } = a2'
287
+ expression = f'a1 { op } = a2'
303
288
exec (expression , res_locals )
304
289
res = res_locals ['a1' ]
305
290
306
- assert res .dtype == res_dtype , f"{ dtype1 } { binary_op } = { dtype2 } promoted to { res .dtype } , should have promoted to { res_dtype } (shape={ shape1 , shape2 } )"
291
+ assert res .dtype == res_dtype , f"{ dtype1 } { op } = { dtype2 } promoted to { res .dtype } , should have promoted to { res_dtype } (shape={ shape1 , shape2 } )"
307
292
308
293
scalar_promotion_parametrize_inputs = [
309
- pytest .param (binary_op_name , dtype , scalar_type , id = f"{ binary_op_name } -{ dtype } -{ scalar_type .__name__ } " )
310
- for binary_op_name in sorted (set (dh .binary_op_to_symbol ) - {'__matmul__' })
311
- for dtype in dh .category_to_dtypes [dh .func_in_categories [dh . op_to_func [ binary_op_name ] ]]
294
+ pytest .param (func , dtype , scalar_type , id = f"{ func } -{ dtype } -{ scalar_type .__name__ } " )
295
+ for func in sorted (set (dh .binary_func_to_op ) - {'__matmul__' })
296
+ for dtype in dh .category_to_dtypes [dh .func_in_categories [func ]]
312
297
for scalar_type in dh .dtypes_to_scalars [dtype ]
313
298
]
314
299
315
- @pytest .mark .parametrize ('binary_op_name ,dtype,scalar_type' ,
300
+ @pytest .mark .parametrize ('func ,dtype,scalar_type' ,
316
301
scalar_promotion_parametrize_inputs )
317
302
@given (shape = hh .shapes , python_scalars = st .data (), data = st .data ())
318
- def test_operator_scalar_arg_return_promoted (binary_op_name , dtype , scalar_type ,
303
+ def test_operator_scalar_arg_return_promoted (func , dtype , scalar_type ,
319
304
shape , python_scalars , data ):
320
305
"""
321
306
See https://st.data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-hh.scalars
322
307
"""
323
- binary_op = dh .binary_op_to_symbol [ binary_op_name ]
324
- if binary_op == '@' :
308
+ op = dh .binary_func_to_op [ func ]
309
+ if op == '@' :
325
310
pytest .skip ("matmul (@) is not supported for hh.scalars" )
326
311
327
312
if dtype in dh .category_to_dtypes ['integer' ]:
@@ -344,23 +329,23 @@ def test_operator_scalar_arg_return_promoted(binary_op_name, dtype, scalar_type,
344
329
# 2. Execute the operation for `array <op> 0-D array` (or `0-D array <op>
345
330
# array` if `scalar` was the left-hand argument).
346
331
347
- array_scalar = f'a { binary_op } s'
348
- array_scalar_expected = f'a { binary_op } scalar_as_array'
332
+ array_scalar = f'a { op } s'
333
+ array_scalar_expected = f'a { op } scalar_as_array'
349
334
res = eval (array_scalar , get_locals ())
350
335
expected = eval (array_scalar_expected , get_locals ())
351
336
ah .assert_exactly_equal (res , expected )
352
337
353
- scalar_array = f's { binary_op } a'
354
- scalar_array_expected = f'scalar_as_array { binary_op } a'
338
+ scalar_array = f's { op } a'
339
+ scalar_array_expected = f'scalar_as_array { op } a'
355
340
res = eval (scalar_array , get_locals ())
356
341
expected = eval (scalar_array_expected , get_locals ())
357
342
ah .assert_exactly_equal (res , expected )
358
343
359
344
# Test in-place operators
360
- if binary_op in ['==' , '!=' , '<' , '>' , '<=' , '>=' ]:
345
+ if op in ['==' , '!=' , '<' , '>' , '<=' , '>=' ]:
361
346
return
362
- array_scalar = f'a { binary_op } = s'
363
- array_scalar_expected = f'a { binary_op } = scalar_as_array'
347
+ array_scalar = f'a { op } = s'
348
+ array_scalar_expected = f'a { op } = scalar_as_array'
364
349
a = ah .full (shape , fillvalue , dtype = dtype )
365
350
res_locals = get_locals ()
366
351
exec (array_scalar , get_locals ())
0 commit comments