@@ -352,6 +352,40 @@ def test_elementwise_function_two_arg_type_promotion(func_name, shape, dtypes):
352
352
353
353
assert res .dtype == res_dtype , f"{ func_name } ({ dtype1 } , { dtype2 } ) promoted to { res .dtype } , should have promoted to { res_dtype } (shape={ shape } )"
354
354
355
+
356
+ elementwise_function_one_arg_func_names = [func_name for func_name in
357
+ elementwise_functions .__all__ if
358
+ nargs (func_name ) == 1 ]
359
+ for func_name in elementwise_function_one_arg_func_names :
360
+ assert elementwise_function_output_types [func_name ] in ['promoted' , 'same' ], func_name
361
+ elementwise_function_one_arg_parametrize_inputs = [(func_name , dtypes )
362
+ for func_name in elementwise_function_one_arg_func_names
363
+ for dtypes in input_types [elementwise_function_input_types [func_name ]]]
364
+ elementwise_function_one_arg_parametrize_ids = ['-' .join ((n , d )) for n , d
365
+ in elementwise_function_two_arg_parametrize_inputs ]
366
+
367
+ # TODO: Extend this to all functions (not just elementwise), and handle
368
+ # functions that take more than 2 args
369
+ @pytest .mark .parametrize ('func_name,dtype_name' ,
370
+ elementwise_function_one_arg_parametrize_inputs , ids = elementwise_function_one_arg_parametrize_ids )
371
+ # The spec explicitly requires type promotion to work for shape 0
372
+ @example (shape = (0 ,))
373
+ @given (shape = shapes )
374
+ def test_elementwise_function_one_arg_type_promotion (func_name , shape , dtype_name ):
375
+ assert nargs (func_name ) == 2
376
+ func = getattr (_array_module , func_name )
377
+
378
+ dtype = dtype_mapping [dtype_name ]
379
+
380
+ for i in [func , dtype ]:
381
+ if isinstance (i , _array_module ._UndefinedStub ):
382
+ func ._raise ()
383
+
384
+ x = ones (shape , dtype = dtype )
385
+ res = func (x )
386
+
387
+ assert res .dtype == dtype , f"{ func_name } ({ dtype } ) returned to { res .dtype } , should have promoted to { dtype } (shape={ shape } )"
388
+
355
389
@pytest .mark .parametrize ('binary_op' , sorted (set (binary_operators .values ()) - {'@' }))
356
390
@pytest .mark .parametrize ('scalar_type,dtype' , [(s , d ) for s in scalar_to_dtype
357
391
for d in scalar_to_dtype [s ]])
0 commit comments