Skip to content

Commit 19cfff7

Browse files
committed
Rudimentary operator parameters refactor
1 parent c07ba0e commit 19cfff7

File tree

3 files changed

+82
-69
lines changed

3 files changed

+82
-69
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
'dtype_signed',
1010
'func_in_categories',
1111
'func_out_categories',
12-
'binary_operators',
13-
'unary_operators',
14-
'operators_to_functions',
12+
'binary_op_to_symbol',
13+
'unary_op_to_symbol',
14+
'op_to_func',
1515
]
1616

1717

@@ -223,7 +223,7 @@
223223
}
224224

225225

226-
binary_operators = {
226+
binary_op_to_symbol = {
227227
'__add__': '+',
228228
'__and__': '&',
229229
'__eq__': '==',
@@ -246,15 +246,15 @@
246246
}
247247

248248

249-
unary_operators = {
249+
unary_op_to_symbol = {
250250
'__abs__': 'abs()',
251251
'__invert__': '~',
252252
'__neg__': '-',
253253
'__pos__': '+',
254254
}
255255

256256

257-
operators_to_functions = {
257+
op_to_func = {
258258
'__abs__': 'abs',
259259
'__add__': 'add',
260260
'__and__': 'bitwise_and',

array_api_tests/test_signatures.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._array_module import mod, mod_name, ones, eye, float64, bool, int64
66
from .pytest_helpers import raises, doesnt_raise
7-
from .dtype_helpers import func_in_categories, operators_to_functions
7+
from .dtype_helpers import func_in_categories, op_to_func
88

99
from . import function_stubs
1010

@@ -160,9 +160,9 @@ def test_function_positional_args(name):
160160
dtype = None
161161
if (name.startswith('__i') and name not in ['__int__', '__invert__', '__index__']
162162
or name.startswith('__r') and name != '__rshift__'):
163-
n = operators_to_functions[name[:2] + name[3:]]
163+
n = op_to_func[name[:2] + name[3:]]
164164
else:
165-
n = operators_to_functions.get(name, name)
165+
n = op_to_func.get(name, name)
166166
if 'boolean' in func_in_categories.get(n, 'floating'):
167167
dtype = bool
168168
elif 'integer' in func_in_categories.get(n, 'floating'):

array_api_tests/test_type_promotion.py

Lines changed: 73 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,57 @@
2424
dtypes_to_scalars,
2525
func_in_categories,
2626
func_out_categories,
27-
binary_operators,
28-
unary_operators,
29-
operators_to_functions,
27+
binary_op_to_symbol,
28+
unary_op_to_symbol,
29+
op_to_func,
3030
)
3131

3232

3333
def generate_params(
34+
func_family: Literal['elementwise', 'operator'],
3435
in_nargs: int,
3536
out_category: Literal['bool', 'promoted'],
3637
) -> Iterator:
37-
funcs = [
38-
f for f in elementwise_functions.__all__
39-
if nargs(f) == in_nargs and func_out_categories[f] == out_category
40-
]
41-
if in_nargs == 1:
42-
for func in funcs:
43-
in_category = func_in_categories[func]
44-
for in_dtype in category_to_dtypes[in_category]:
45-
yield pytest.param(func, in_dtype, id=f"{func}({in_dtype})")
38+
if func_family == 'elementwise':
39+
funcs = [
40+
f for f in elementwise_functions.__all__
41+
if nargs(f) == in_nargs and func_out_categories[f] == out_category
42+
]
43+
if in_nargs == 1:
44+
for func in funcs:
45+
in_category = func_in_categories[func]
46+
for in_dtype in category_to_dtypes[in_category]:
47+
yield pytest.param(func, in_dtype, id=f"{func}({in_dtype})")
48+
else:
49+
for func, ((d1, d2), d3) in product(funcs, promotion_table.items()):
50+
if all(d in category_to_dtypes[func_in_categories[func]] for d in (d1, d2)):
51+
if out_category == 'bool':
52+
yield pytest.param(func, (d1, d2), id=f"{func}({d1}, {d2})")
53+
else:
54+
yield pytest.param(func, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}")
4655
else:
47-
for func, ((d1, d2), d3) in product(funcs, promotion_table.items()):
48-
if all(d in category_to_dtypes[func_in_categories[func]] for d in (d1, d2)):
49-
if out_category == 'bool':
50-
yield pytest.param(func, (d1, d2), id=f"{func}({d1}, {d2})")
51-
else:
52-
yield pytest.param(func, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}")
56+
if in_nargs == 1:
57+
for op, symbol in unary_op_to_symbol.items():
58+
func = op_to_func[op]
59+
if func_out_categories[func] == out_category:
60+
in_category = func_in_categories[func]
61+
for in_dtype in category_to_dtypes[in_category]:
62+
yield pytest.param(op, symbol, in_dtype, id=f"{op}({in_dtype})")
63+
else:
64+
for op, symbol in binary_op_to_symbol.items():
65+
if op == "__matmul__":
66+
continue
67+
func = op_to_func[op]
68+
if func_out_categories[func] == out_category:
69+
in_category = func_in_categories[func]
70+
for ((d1, d2), d3) in promotion_table.items():
71+
if all(d in category_to_dtypes[in_category] for d in (d1, d2)):
72+
if out_category == 'bool':
73+
yield pytest.param(op, symbol, (d1, d2), id=f"{op}({d1}, {d2})")
74+
else:
75+
if d1 == d3:
76+
yield pytest.param(op, symbol, ((d1, d2), d3), id=f"{op}({d1}, {d2}) -> {d3}")
77+
5378

5479

5580
# TODO: These functions should still do type promotion internally, but we do
@@ -59,7 +84,7 @@ def generate_params(
5984
# array(1.00000001, dtype=float64)) will be wrong if the float64 array is
6085
# downcast to float32. See for instance
6186
# https://github.com/numpy/numpy/issues/10322.
62-
@pytest.mark.parametrize('func, dtypes', generate_params(in_nargs=2, out_category='bool'))
87+
@pytest.mark.parametrize('func, dtypes', generate_params('elementwise', in_nargs=2, out_category='bool'))
6388
# The spec explicitly requires type promotion to work for shape 0
6489
# Unfortunately, data(), isn't compatible with @example, so this is commented
6590
# out for now.
@@ -91,7 +116,7 @@ def test_elementwise_two_args_bool_type_promotion(func, two_shapes, dtypes, data
91116

92117
# TODO: Extend this to all functions (not just elementwise), and handle
93118
# functions that take more than 2 args
94-
@pytest.mark.parametrize('func, dtypes', generate_params(in_nargs=2, out_category='promoted'))
119+
@pytest.mark.parametrize('func, dtypes', generate_params('elementwise', in_nargs=2, out_category='promoted'))
95120
# The spec explicitly requires type promotion to work for shape 0
96121
# Unfortunately, data(), isn't compatible with @example, so this is commented
97122
# out for now.
@@ -124,7 +149,7 @@ def test_elementwise_two_args_promoted_type_promotion(func,
124149

125150
# TODO: Extend this to all functions (not just elementwise), and handle
126151
# functions that take more than 2 args
127-
@pytest.mark.parametrize('func, dtype', generate_params(in_nargs=1, out_category='bool'))
152+
@pytest.mark.parametrize('func, dtype', generate_params('elementwise', in_nargs=1, out_category='bool'))
128153
# The spec explicitly requires type promotion to work for shape 0
129154
# Unfortunately, data(), isn't compatible with @example, so this is commented
130155
# out for now.
@@ -147,7 +172,7 @@ def test_elementwise_one_arg_bool(func, shape, dtype, data):
147172

148173
# TODO: Extend this to all functions (not just elementwise), and handle
149174
# functions that take more than 2 args
150-
@pytest.mark.parametrize('func,dtype', generate_params(in_nargs=1, out_category='promoted'))
175+
@pytest.mark.parametrize('func,dtype', generate_params('elementwise', in_nargs=1, out_category='promoted'))
151176
# The spec explicitly requires type promotion to work for shape 0
152177
# Unfortunately, data(), isn't compatible with @example, so this is commented
153178
# out for now.
@@ -169,29 +194,28 @@ def test_elementwise_one_arg_type_promotion(func, shape,
169194

170195
assert res.dtype == dtype, f"{func}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})"
171196

172-
unary_operators_promoted = [unary_op_name for unary_op_name in sorted(unary_operators)
173-
if func_out_categories[operators_to_functions[unary_op_name]] == 'promoted']
197+
unary_operators_promoted = [unary_op_name for unary_op_name in sorted(unary_op_to_symbol)
198+
if func_out_categories[op_to_func[unary_op_name]] == 'promoted']
174199
operator_one_arg_promoted_parametrize_inputs = [(unary_op_name, dtypes)
175200
for unary_op_name in unary_operators_promoted
176-
for dtypes in category_to_dtypes[func_in_categories[operators_to_functions[unary_op_name]]]
201+
for dtypes in category_to_dtypes[func_in_categories[op_to_func[unary_op_name]]]
177202
]
178203
operator_one_arg_promoted_parametrize_ids = [f"{n}-{d}" for n, d
179204
in operator_one_arg_promoted_parametrize_inputs]
180205

181206

182207
# TODO: Extend this to all functions (not just elementwise), and handle
183208
# functions that take more than 2 args
184-
@pytest.mark.parametrize('unary_op_name,dtype',
185-
operator_one_arg_promoted_parametrize_inputs,
186-
ids=operator_one_arg_promoted_parametrize_ids)
209+
@pytest.mark.parametrize(
210+
'unary_op_name, unary_op, dtype',
211+
generate_params('operator', in_nargs=1, out_category='promoted'),
212+
)
187213
# The spec explicitly requires type promotion to work for shape 0
188214
# Unfortunately, data(), isn't compatible with @example, so this is commented
189215
# out for now.
190216
# @example(shape=(0,))
191217
@given(shape=shapes, data=data())
192-
def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype, data):
193-
unary_op = unary_operators[unary_op_name]
194-
218+
def test_operator_one_arg_type_promotion(unary_op_name, unary_op, shape, dtype, data):
195219
fillvalue = data.draw(scalars(just(dtype)))
196220

197221
if isinstance(dtype, _array_module._UndefinedStub):
@@ -211,24 +235,22 @@ def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype, data):
211235
assert res.dtype == dtype, f"{unary_op}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})"
212236

213237
# Note: the boolean binary operators do not have reversed or in-place variants
214-
binary_operators_bool = [binary_op_name for binary_op_name in sorted(set(binary_operators) - {'__matmul__'})
215-
if func_out_categories[operators_to_functions[binary_op_name]] == 'bool']
238+
binary_operators_bool = [binary_op_name for binary_op_name in sorted(set(binary_op_to_symbol) - {'__matmul__'})
239+
if func_out_categories[op_to_func[binary_op_name]] == 'bool']
216240
operator_two_args_bool_parametrize_inputs = [(binary_op_name, dtypes)
217241
for binary_op_name in binary_operators_bool
218242
for dtypes in promotion_table.keys()
219-
if all(d in category_to_dtypes[func_in_categories[operators_to_functions[binary_op_name]]] for d in dtypes)
243+
if all(d in category_to_dtypes[func_in_categories[op_to_func[binary_op_name]]] for d in dtypes)
220244
]
221245
operator_two_args_bool_parametrize_ids = [f"{n}-{d1}-{d2}" for n, (d1, d2)
222246
in operator_two_args_bool_parametrize_inputs]
223247

224-
@pytest.mark.parametrize('binary_op_name,dtypes',
225-
operator_two_args_bool_parametrize_inputs,
226-
ids=operator_two_args_bool_parametrize_ids)
248+
@pytest.mark.parametrize(
249+
'binary_op_name, binary_op, dtypes',
250+
generate_params('operator', in_nargs=2, out_category='bool')
251+
)
227252
@given(two_shapes=two_mutually_broadcastable_shapes, data=data())
228-
def test_operator_two_args_bool_promotion(binary_op_name, dtypes, two_shapes,
229-
data):
230-
binary_op = binary_operators[binary_op_name]
231-
253+
def test_operator_two_args_bool_promotion(binary_op_name, binary_op, dtypes, two_shapes, data):
232254
dtype1, dtype2 = dtypes
233255
fillvalue1 = data.draw(scalars(just(dtype1)))
234256
fillvalue2 = data.draw(scalars(just(dtype2)))
@@ -247,24 +269,19 @@ def test_operator_two_args_bool_promotion(binary_op_name, dtypes, two_shapes,
247269

248270
assert res.dtype == bool_dtype, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})"
249271

250-
binary_operators_promoted = [binary_op_name for binary_op_name in sorted(set(binary_operators) - {'__matmul__'})
251-
if func_out_categories[operators_to_functions[binary_op_name]] == 'promoted']
272+
binary_operators_promoted = [binary_op_name for binary_op_name in sorted(set(binary_op_to_symbol) - {'__matmul__'})
273+
if func_out_categories[op_to_func[binary_op_name]] == 'promoted']
252274
operator_two_args_promoted_parametrize_inputs = [(binary_op_name, dtypes)
253275
for binary_op_name in binary_operators_promoted
254276
for dtypes in promotion_table.items()
255-
if all(d in category_to_dtypes[func_in_categories[operators_to_functions[binary_op_name]]] for d in dtypes[0])
277+
if all(d in category_to_dtypes[func_in_categories[op_to_func[binary_op_name]]] for d in dtypes[0])
256278
]
257279
operator_two_args_promoted_parametrize_ids = [f"{n}-{d1}-{d2}" for n, ((d1, d2), _)
258280
in operator_two_args_promoted_parametrize_inputs]
259281

260-
@pytest.mark.parametrize('binary_op_name,dtypes',
261-
operator_two_args_promoted_parametrize_inputs,
262-
ids=operator_two_args_promoted_parametrize_ids)
282+
@pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted'))
263283
@given(two_shapes=two_mutually_broadcastable_shapes, data=data())
264-
def test_operator_two_args_promoted_promotion(binary_op_name, dtypes, two_shapes,
265-
data):
266-
binary_op = binary_operators[binary_op_name]
267-
284+
def test_operator_two_args_promoted_promotion(binary_op_name, binary_op, dtypes, two_shapes, data):
268285
(dtype1, dtype2), res_dtype = dtypes
269286
fillvalue1 = data.draw(scalars(just(dtype1)))
270287
if binary_op_name in ['>>', '<<']:
@@ -292,14 +309,10 @@ def test_operator_two_args_promoted_promotion(binary_op_name, dtypes, two_shapes
292309
operator_inplace_two_args_promoted_parametrize_ids = ['-'.join((n[:2] + 'i' + n[2:], str(d1), str(d2))) for n, ((d1, d2), _)
293310
in operator_inplace_two_args_promoted_parametrize_inputs]
294311

295-
@pytest.mark.parametrize('binary_op_name,dtypes',
296-
operator_inplace_two_args_promoted_parametrize_inputs,
297-
ids=operator_inplace_two_args_promoted_parametrize_ids)
312+
@pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted'))
298313
@given(two_shapes=two_broadcastable_shapes(), data=data())
299-
def test_operator_inplace_two_args_promoted_promotion(binary_op_name, dtypes, two_shapes,
314+
def test_operator_inplace_two_args_promoted_promotion(binary_op_name, binary_op, dtypes, two_shapes,
300315
data):
301-
binary_op = binary_operators[binary_op_name]
302-
303316
(dtype1, dtype2), res_dtype = dtypes
304317
fillvalue1 = data.draw(scalars(just(dtype1)))
305318
if binary_op_name in ['>>', '<<']:
@@ -326,8 +339,8 @@ def test_operator_inplace_two_args_promoted_promotion(binary_op_name, dtypes, tw
326339

327340
scalar_promotion_parametrize_inputs = [
328341
pytest.param(binary_op_name, dtype, scalar_type, id=f"{binary_op_name}-{dtype}-{scalar_type.__name__}")
329-
for binary_op_name in sorted(set(binary_operators) - {'__matmul__'})
330-
for dtype in category_to_dtypes[func_in_categories[operators_to_functions[binary_op_name]]]
342+
for binary_op_name in sorted(set(binary_op_to_symbol) - {'__matmul__'})
343+
for dtype in category_to_dtypes[func_in_categories[op_to_func[binary_op_name]]]
331344
for scalar_type in dtypes_to_scalars[dtype]
332345
]
333346

@@ -339,7 +352,7 @@ def test_operator_scalar_promotion(binary_op_name, dtype, scalar_type,
339352
"""
340353
See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
341354
"""
342-
binary_op = binary_operators[binary_op_name]
355+
binary_op = binary_op_to_symbol[binary_op_name]
343356
if binary_op == '@':
344357
pytest.skip("matmul (@) is not supported for scalars")
345358

0 commit comments

Comments
 (0)