Skip to content

Commit 2ff39c8

Browse files
committed
WIP for test_elementwise_function_one_arg_type_promotion()
1 parent 9b4c57b commit 2ff39c8

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

array_api_tests/test_type_promotion.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,40 @@ def test_elementwise_function_two_arg_type_promotion(func_name, shape, dtypes):
352352

353353
assert res.dtype == res_dtype, f"{func_name}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape})"
354354

355+
356+
elementwise_function_one_arg_func_names = [func_name for func_name in
357+
elementwise_functions.__all__ if
358+
nargs(func_name) == 1]
359+
for func_name in elementwise_function_one_arg_func_names:
360+
assert elementwise_function_output_types[func_name] in ['promoted', 'same'], func_name
361+
elementwise_function_one_arg_parametrize_inputs = [(func_name, dtypes)
362+
for func_name in elementwise_function_one_arg_func_names
363+
for dtypes in input_types[elementwise_function_input_types[func_name]]]
364+
elementwise_function_one_arg_parametrize_ids = ['-'.join((n, d)) for n, d
365+
in elementwise_function_two_arg_parametrize_inputs]
366+
367+
# TODO: Extend this to all functions (not just elementwise), and handle
368+
# functions that take more than 2 args
369+
@pytest.mark.parametrize('func_name,dtype_name',
370+
elementwise_function_one_arg_parametrize_inputs, ids=elementwise_function_one_arg_parametrize_ids)
371+
# The spec explicitly requires type promotion to work for shape 0
372+
@example(shape=(0,))
373+
@given(shape=shapes)
374+
def test_elementwise_function_one_arg_type_promotion(func_name, shape, dtype_name):
375+
assert nargs(func_name) == 2
376+
func = getattr(_array_module, func_name)
377+
378+
dtype = dtype_mapping[dtype_name]
379+
380+
for i in [func, dtype]:
381+
if isinstance(i, _array_module._UndefinedStub):
382+
func._raise()
383+
384+
x = ones(shape, dtype=dtype)
385+
res = func(x)
386+
387+
assert res.dtype == dtype, f"{func_name}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})"
388+
355389
@pytest.mark.parametrize('binary_op', sorted(set(binary_operators.values()) - {'@'}))
356390
@pytest.mark.parametrize('scalar_type,dtype', [(s, d) for s in scalar_to_dtype
357391
for d in scalar_to_dtype[s]])

0 commit comments

Comments
 (0)