Skip to content

Commit ce6a7e8

Browse files
committed
Clarify operator-related variable names
1 parent 85212da commit ce6a7e8

File tree

2 files changed

+47
-58
lines changed

2 files changed

+47
-58
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
'dtype_signed',
1010
'func_in_categories',
1111
'func_out_categories',
12-
'binary_op_to_symbol',
13-
'unary_op_to_symbol',
14-
'op_to_func',
12+
'binary_func_to_op',
13+
'unary_func_to_op',
1514
]
1615

1716

@@ -223,7 +222,7 @@
223222
}
224223

225224

226-
binary_op_to_symbol = {
225+
binary_func_to_op = {
227226
'__add__': '+',
228227
'__and__': '&',
229228
'__eq__': '==',
@@ -246,15 +245,15 @@
246245
}
247246

248247

249-
unary_op_to_symbol = {
248+
unary_func_to_op = {
250249
'__abs__': 'abs()',
251250
'__invert__': '~',
252251
'__neg__': '-',
253252
'__pos__': '+',
254253
}
255254

256255

257-
op_to_func = {
256+
_operator_to_elementwise = {
258257
'__abs__': 'abs',
259258
'__add__': 'add',
260259
'__and__': 'bitwise_and',
@@ -265,7 +264,7 @@
265264
'__le__': 'less_equal',
266265
'__lshift__': 'bitwise_left_shift',
267266
'__lt__': 'less',
268-
'__matmul__': 'matmul',
267+
# '__matmul__': 'matmul', # TODO: support matmul
269268
'__mod__': 'remainder',
270269
'__mul__': 'multiply',
271270
'__ne__': 'not_equal',
@@ -279,3 +278,8 @@
279278
'__neg__': 'negative',
280279
'__pos__': 'positive',
281280
}
281+
282+
283+
for op_func, elwise_func in _operator_to_elementwise.items():
284+
func_in_categories[op_func] = func_in_categories[elwise_func]
285+
func_out_categories[op_func] = func_out_categories[elwise_func]

array_api_tests/test_type_promotion.py

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -42,26 +42,26 @@ def generate_params(
4242
yield pytest.param(func, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}")
4343
else:
4444
if in_nargs == 1:
45-
for op, symbol in dh.unary_op_to_symbol.items():
46-
func = dh.op_to_func[op]
45+
for func, op in dh.unary_func_to_op.items():
46+
if func == "__matmul__":
47+
continue
4748
if dh.func_out_categories[func] == out_category:
4849
in_category = dh.func_in_categories[func]
4950
for in_dtype in dh.category_to_dtypes[in_category]:
50-
yield pytest.param(op, symbol, in_dtype, id=f"{op}({in_dtype})")
51+
yield pytest.param(func, op, in_dtype, id=f"{func}({in_dtype})")
5152
else:
52-
for op, symbol in dh.binary_op_to_symbol.items():
53-
if op == "__matmul__":
53+
for func, op in dh.binary_func_to_op.items():
54+
if func == "__matmul__":
5455
continue
55-
func = dh.op_to_func[op]
5656
if dh.func_out_categories[func] == out_category:
5757
in_category = dh.func_in_categories[func]
5858
for ((d1, d2), d3) in dh.promotion_table.items():
5959
if all(d in dh.category_to_dtypes[in_category] for d in (d1, d2)):
6060
if out_category == 'bool':
61-
yield pytest.param(op, symbol, (d1, d2), id=f"{op}({d1}, {d2})")
61+
yield pytest.param(func, op, (d1, d2), id=f"{func}({d1}, {d2})")
6262
else:
6363
if d1 == d3:
64-
yield pytest.param(op, symbol, ((d1, d2), d3), id=f"{op}({d1}, {d2}) -> {d3}")
64+
yield pytest.param(func, op, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}")
6565

6666

6767

@@ -214,11 +214,11 @@ def test_operator_one_arg_return_promoted(unary_op_name, unary_op, shape, dtype,
214214
assert res.dtype == dtype, f"{unary_op}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})"
215215

216216
@pytest.mark.parametrize(
217-
'binary_op_name, binary_op, dtypes',
217+
'func, op, dtypes',
218218
generate_params('operator', in_nargs=2, out_category='bool')
219219
)
220220
@given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data())
221-
def test_operator_two_args_return_bool(binary_op_name, binary_op, dtypes, two_shapes, data):
221+
def test_operator_two_args_return_bool(func, op, dtypes, two_shapes, data):
222222
dtype1, dtype2 = dtypes
223223
fillvalue1 = data.draw(hh.scalars(st.just(dtype1)))
224224
fillvalue2 = data.draw(hh.scalars(st.just(dtype2)))
@@ -232,27 +232,17 @@ def test_operator_two_args_return_bool(binary_op_name, binary_op, dtypes, two_sh
232232
a2 = ah.full(shape2, fillvalue2, dtype=dtype2)
233233

234234
get_locals = lambda: dict(a1=a1, a2=a2)
235-
expression = f'a1 {binary_op} a2'
235+
expression = f'a1 {op} a2'
236236
res = eval(expression, get_locals())
237237

238-
assert res.dtype == xp.bool, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})"
239-
240-
binary_operators_promoted = [binary_op_name for binary_op_name in sorted(set(dh.binary_op_to_symbol) - {'__matmul__'})
241-
if dh.func_out_categories[dh.op_to_func[binary_op_name]] == 'promoted']
242-
operator_two_args_promoted_parametrize_inputs = [(binary_op_name, dtypes)
243-
for binary_op_name in binary_operators_promoted
244-
for dtypes in dh.promotion_table.items()
245-
if all(d in dh.category_to_dtypes[dh.func_in_categories[dh.op_to_func[binary_op_name]]] for d in dtypes[0])
246-
]
247-
operator_two_args_promoted_parametrize_ids = [f"{n}-{d1}-{d2}" for n, ((d1, d2), _)
248-
in operator_two_args_promoted_parametrize_inputs]
238+
assert res.dtype == xp.bool, f"{dtype1} {op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})"
249239

250-
@pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted'))
240+
@pytest.mark.parametrize('func, op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted'))
251241
@given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data())
252-
def test_operator_two_args_return_promoted(binary_op_name, binary_op, dtypes, two_shapes, data):
242+
def test_operator_two_args_return_promoted(func, op, dtypes, two_shapes, data):
253243
(dtype1, dtype2), res_dtype = dtypes
254244
fillvalue1 = data.draw(hh.scalars(st.just(dtype1)))
255-
if binary_op_name in ['>>', '<<']:
245+
if op in ['>>', '<<']:
256246
fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0))
257247
else:
258248
fillvalue2 = data.draw(hh.scalars(st.just(dtype2)))
@@ -267,23 +257,18 @@ def test_operator_two_args_return_promoted(binary_op_name, binary_op, dtypes, tw
267257
a2 = ah.full(shape2, fillvalue2, dtype=dtype2)
268258

269259
get_locals = lambda: dict(a1=a1, a2=a2)
270-
expression = f'a1 {binary_op} a2'
260+
expression = f'a1 {op} a2'
271261
res = eval(expression, get_locals())
272262

273-
assert res.dtype == res_dtype, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})"
274-
275-
operator_inplace_two_args_promoted_parametrize_inputs = [(binary_op, dtypes) for binary_op, dtypes in operator_two_args_promoted_parametrize_inputs
276-
if dtypes[0][0] == dtypes[1]]
277-
operator_inplace_two_args_promoted_parametrize_ids = ['-'.join((n[:2] + 'i' + n[2:], str(d1), str(d2))) for n, ((d1, d2), _)
278-
in operator_inplace_two_args_promoted_parametrize_inputs]
263+
assert res.dtype == res_dtype, f"{dtype1} {op} {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})"
279264

280-
@pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted'))
265+
@pytest.mark.parametrize('func, op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted'))
281266
@given(two_shapes=hh.two_broadcastable_shapes(), data=st.data())
282-
def test_operator_inplace_two_args_return_promoted(binary_op_name, binary_op, dtypes, two_shapes,
267+
def test_operator_inplace_two_args_return_promoted(func, op, dtypes, two_shapes,
283268
data):
284269
(dtype1, dtype2), res_dtype = dtypes
285270
fillvalue1 = data.draw(hh.scalars(st.just(dtype1)))
286-
if binary_op_name in ['>>', '<<']:
271+
if func in ['>>', '<<']:
287272
fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0))
288273
else:
289274
fillvalue2 = data.draw(hh.scalars(st.just(dtype2)))
@@ -299,29 +284,29 @@ def test_operator_inplace_two_args_return_promoted(binary_op_name, binary_op, dt
299284
get_locals = lambda: dict(a1=a1, a2=a2)
300285

301286
res_locals = get_locals()
302-
expression = f'a1 {binary_op}= a2'
287+
expression = f'a1 {op}= a2'
303288
exec(expression, res_locals)
304289
res = res_locals['a1']
305290

306-
assert res.dtype == res_dtype, f"{dtype1} {binary_op}= {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})"
291+
assert res.dtype == res_dtype, f"{dtype1} {op}= {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})"
307292

308293
scalar_promotion_parametrize_inputs = [
309-
pytest.param(binary_op_name, dtype, scalar_type, id=f"{binary_op_name}-{dtype}-{scalar_type.__name__}")
310-
for binary_op_name in sorted(set(dh.binary_op_to_symbol) - {'__matmul__'})
311-
for dtype in dh.category_to_dtypes[dh.func_in_categories[dh.op_to_func[binary_op_name]]]
294+
pytest.param(func, dtype, scalar_type, id=f"{func}-{dtype}-{scalar_type.__name__}")
295+
for func in sorted(set(dh.binary_func_to_op) - {'__matmul__'})
296+
for dtype in dh.category_to_dtypes[dh.func_in_categories[func]]
312297
for scalar_type in dh.dtypes_to_scalars[dtype]
313298
]
314299

315-
@pytest.mark.parametrize('binary_op_name,dtype,scalar_type',
300+
@pytest.mark.parametrize('func,dtype,scalar_type',
316301
scalar_promotion_parametrize_inputs)
317302
@given(shape=hh.shapes, python_scalars=st.data(), data=st.data())
318-
def test_operator_scalar_arg_return_promoted(binary_op_name, dtype, scalar_type,
303+
def test_operator_scalar_arg_return_promoted(func, dtype, scalar_type,
319304
shape, python_scalars, data):
320305
"""
321306
See https://st.data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-hh.scalars
322307
"""
323-
binary_op = dh.binary_op_to_symbol[binary_op_name]
324-
if binary_op == '@':
308+
op = dh.binary_func_to_op[func]
309+
if op == '@':
325310
pytest.skip("matmul (@) is not supported for hh.scalars")
326311

327312
if dtype in dh.category_to_dtypes['integer']:
@@ -344,23 +329,23 @@ def test_operator_scalar_arg_return_promoted(binary_op_name, dtype, scalar_type,
344329
# 2. Execute the operation for `array <op> 0-D array` (or `0-D array <op>
345330
# array` if `scalar` was the left-hand argument).
346331

347-
array_scalar = f'a {binary_op} s'
348-
array_scalar_expected = f'a {binary_op} scalar_as_array'
332+
array_scalar = f'a {op} s'
333+
array_scalar_expected = f'a {op} scalar_as_array'
349334
res = eval(array_scalar, get_locals())
350335
expected = eval(array_scalar_expected, get_locals())
351336
ah.assert_exactly_equal(res, expected)
352337

353-
scalar_array = f's {binary_op} a'
354-
scalar_array_expected = f'scalar_as_array {binary_op} a'
338+
scalar_array = f's {op} a'
339+
scalar_array_expected = f'scalar_as_array {op} a'
355340
res = eval(scalar_array, get_locals())
356341
expected = eval(scalar_array_expected, get_locals())
357342
ah.assert_exactly_equal(res, expected)
358343

359344
# Test in-place operators
360-
if binary_op in ['==', '!=', '<', '>', '<=', '>=']:
345+
if op in ['==', '!=', '<', '>', '<=', '>=']:
361346
return
362-
array_scalar = f'a {binary_op}= s'
363-
array_scalar_expected = f'a {binary_op}= scalar_as_array'
347+
array_scalar = f'a {op}= s'
348+
array_scalar_expected = f'a {op}= scalar_as_array'
364349
a = ah.full(shape, fillvalue, dtype=dtype)
365350
res_locals = get_locals()
366351
exec(array_scalar, get_locals())

0 commit comments

Comments
 (0)