Skip to content

Commit 4ac0610

Browse files
committed
Test unary/binary/inplace ops alongside scalars
1 parent b3b2519 commit 4ac0610

File tree

3 files changed

+1040
-503
lines changed

3 files changed

+1040
-503
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
'binary_op_to_symbol',
3131
'unary_op_to_symbol',
3232
'inplace_op_to_symbol',
33+
'op_to_func',
3334
'fmt_types',
3435
]
3536

@@ -332,7 +333,7 @@ def result_type(*dtypes: DataType):
332333
}
333334

334335

335-
_op_to_func = {
336+
op_to_func = {
336337
'__abs__': 'abs',
337338
'__add__': 'add',
338339
'__and__': 'bitwise_and',
@@ -359,7 +360,7 @@ def result_type(*dtypes: DataType):
359360
}
360361

361362

362-
for op, elwise_func in _op_to_func.items():
363+
for op, elwise_func in op_to_func.items():
363364
func_in_dtypes[op] = func_in_dtypes[elwise_func]
364365
func_returns_bool[op] = func_returns_bool[elwise_func]
365366

array_api_tests/pytest_helpers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,19 @@ def assert_default_int(func_name: str, dtype: DataType):
114114

115115

116116
def assert_shape(
117-
func_name: str, out_shape: Union[int, Shape], expected: Union[int, Shape], /, **kw
117+
func_name: str,
118+
out_shape: Union[int, Shape],
119+
expected: Union[int, Shape],
120+
/,
121+
out_name="out.shape",
122+
**kw,
118123
):
119124
if isinstance(out_shape, int):
120125
out_shape = (out_shape,)
121126
if isinstance(expected, int):
122127
expected = (expected,)
123128
msg = (
124-
f"out.shape={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]"
129+
f"{out_name}={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]"
125130
)
126131
assert out_shape == expected, msg
127132

0 commit comments

Comments
 (0)