From 6c566d50238eb5f20b30012f7e3e99838a6a1d76 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 12:47:30 +0100 Subject: [PATCH 1/7] ENH: unary functions overhaul --- array_api_strict/_elementwise_functions.py | 571 ++++-------------- array_api_strict/_helpers.py | 18 +- array_api_strict/_searching_functions.py | 3 + .../tests/test_elementwise_functions.py | 139 +++-- .../tests/test_searching_functions.py | 57 ++ 5 files changed, 282 insertions(+), 506 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index b05e0fd..f18c9f5 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -1,3 +1,6 @@ +from functools import wraps +from types import NoneType + import numpy as np from ._array_object import Array @@ -35,7 +38,7 @@ def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func): return Array._new(np_func(x1._array, x2._array), device=x1.device) -_binary_docstring_template = """ +_docstring_template = """ Array API compatible wrapper for :py:func:`np.%s `. See its docstring for more information. @@ -117,7 +120,7 @@ def inner(x1, x2, /) -> Array: func = _create_binary_func(func_name, dtype_category, np_func) func.__name__ = func_name - func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name) + func.__doc__ = _docstring_template % (numpy_name, numpy_name) func.__annotations__['x1'] = _annotations[dtype_category] func.__annotations__['x2'] = _annotations[dtype_category] @@ -153,115 +156,98 @@ def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array: del func, _create_binary_func -def abs(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.abs `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in abs") - return Array._new(np.abs(x._array), device=x.device) - - -# Note: the function name is different here -def acos(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arccos `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in acos") - return Array._new(np.arccos(x._array), device=x.device) - - -# Note: the function name is different here -def acosh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arccosh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in acosh") - return Array._new(np.arccosh(x._array), device=x.device) - - -# Note: the function name is different here -def asin(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arcsin `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in asin") - return Array._new(np.arcsin(x._array), device=x.device) - - -# Note: the function name is different here -def asinh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arcsinh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in asinh") - return Array._new(np.arcsinh(x._array), device=x.device) - +def _create_unary_func(func_name, dtype_category, np_func_name=None): + allowed_dtypes = _dtype_categories[dtype_category] + np_func_name = np_func_name or func_name + np_func = getattr(np, np_func_name) -# Note: the function name is different here -def atan(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arctan `. + def func(x: Array, /) -> Array: + if not isinstance(x, Array): + raise TypeError(f"Only Array objects are allowed; got {type(x)}") + if x.dtype not in allowed_dtypes: + raise TypeError( + f"Only {dtype_category} dtypes are allowed in {func_name}; " + f"got {x.dtype}." + ) + return Array._new(np_func(x._array), device=x.device) - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in atan") - return Array._new(np.arctan(x._array), device=x.device) - - -# Note: the function name is different here -def atanh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arctanh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in atanh") - return Array._new(np.arctanh(x._array), device=x.device) - - -# Note: the function name is different here -def bitwise_invert(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.invert `. - - See its docstring for more information. - """ - if x.dtype not in _integer_or_boolean_dtypes: - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert") - return Array._new(np.invert(x._array), device=x.device) + func.__name__ = func_name + func.__doc__ = _docstring_template % (np_func_name, np_func_name) + return func + + +def _identity_if_integer(func): + """Hack around NumPy 1.x behaviour for ceil, floor, and trunc + vs. integer inputs + """ + + @wraps(func) + def wrapper(x: Array, /) -> Array: + if isinstance(x, Array) and x.dtype in _integer_dtypes: + return x + return func(x) + + return wrapper + + +abs = _create_unary_func("abs", "numeric") +acos = _create_unary_func("acos", "floating-point", "arccos") +acosh = _create_unary_func("acosh", "floating-point", "arccosh") +asin = _create_unary_func("asin", "floating-point", "arcsin") +asinh = _create_unary_func("asinh", "floating-point", "arcsinh") +atan = _create_unary_func("atan", "floating-point", "arctan") +atanh = _create_unary_func("atanh", "floating-point", "arctanh") +bitwise_invert = _create_unary_func("bitwise_invert", "integer or boolean", "invert") +ceil = _identity_if_integer(_create_unary_func("ceil", "real numeric")) +conj = _create_unary_func("conj", "numeric") +cos = _create_unary_func("cos", "floating-point", "arccos") +cosh = _create_unary_func("cosh", "floating-point", "arccosh") +exp = _create_unary_func("exp", "floating-point") +expm1 = _create_unary_func("expm1", "floating-point") +floor = _identity_if_integer(_create_unary_func("floor", "real numeric")) +imag = _create_unary_func("imag", "complex floating-point") +isfinite = _create_unary_func("isfinite", "numeric") +isinf = _create_unary_func("isinf", "numeric") +isnan = _create_unary_func("isnan", "numeric") +log = _create_unary_func("log", "floating-point") +log1p = _create_unary_func("log1p", "floating-point") +log2 = _create_unary_func("log2", "floating-point") +log10 = _create_unary_func("log10", "floating-point") +logical_not = _create_unary_func("logical_not", "boolean") +negative = _create_unary_func("negative", "numeric") +positive = _create_unary_func("positive", "numeric") +reciprocal = requires_api_version("2024.12")( + _create_unary_func("reciprocal", "floating-point") +) +real = _create_unary_func("real", "numeric") +round = _create_unary_func("round", "numeric") +signbit = requires_api_version("2023.12")( + _create_unary_func("signbit", "real floating-point") +) +sin = _create_unary_func("sin", "floating-point") +sinh = _create_unary_func("sinh", "floating-point") +square = _create_unary_func("square", "numeric") +sqrt = _create_unary_func("sqrt", "floating-point") +tan = _create_unary_func("tan", "floating-point") +tanh = _create_unary_func("tanh", "floating-point") +trunc = _identity_if_integer(_create_unary_func("trunc", "real numeric")) -def ceil(x: Array, /) -> Array: +def sign(x: Array, /) -> Array: """ - Array API compatible wrapper for :py:func:`np.ceil `. + Array API compatible wrapper for :py:func:`np.sign `. See its docstring for more information. """ - if x.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in ceil") - if x.dtype in _integer_dtypes: - # Note: The return dtype of ceil is the same as the input - return x - return Array._new(np.ceil(x._array), device=x.device) + if not isinstance(x, Array): + raise TypeError(f"Only Array objects are allowed; got {type(x)}") + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in sign") + # Special treatment to work around non-compliant NumPy 1.x behaviour + if x.dtype in _complex_floating_dtypes: + return x/abs(x) + return Array._new(np.sign(x._array), device=x.device) -# WARNING: This function is not yet tested by the array-api-tests test suite. # Note: min and max argument names are different and not optional in numpy. @requires_api_version('2023.12') @@ -276,42 +262,40 @@ def clip( See its docstring for more information. """ - if isinstance(min, Array) and x.device != min.device: - raise ValueError(f"Arrays from two different devices ({x.device} and {min.device}) can not be combined.") - if isinstance(max, Array) and x.device != max.device: - raise ValueError(f"Arrays from two different devices ({x.device} and {max.device}) can not be combined.") + if not isinstance(x, Array): + raise TypeError(f"Only Array objects are allowed; got {type(x)}") if (x.dtype not in _real_numeric_dtypes or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes): raise TypeError("Only real numeric dtypes are allowed in clip") - if not isinstance(min, (int, float, Array, type(None))): - raise TypeError("min must be an None, int, float, or an array") - if not isinstance(max, (int, float, Array, type(None))): - raise TypeError("max must be an None, int, float, or an array") - - # Mixed dtype kinds is implementation defined - if (x.dtype in _integer_dtypes - and (isinstance(min, float) or - isinstance(min, Array) and min.dtype in _real_floating_dtypes)): - raise TypeError("min must be integral when x is integral") - if (x.dtype in _integer_dtypes - and (isinstance(max, float) or - isinstance(max, Array) and max.dtype in _real_floating_dtypes)): - raise TypeError("max must be integral when x is integral") - if (x.dtype in _real_floating_dtypes - and (isinstance(min, int) or - isinstance(min, Array) and min.dtype in _integer_dtypes)): - raise TypeError("min must be floating-point when x is floating-point") - if (x.dtype in _real_floating_dtypes - and (isinstance(max, int) or - isinstance(max, Array) and max.dtype in _integer_dtypes)): - raise TypeError("max must be floating-point when x is floating-point") if min is max is None: - # Note: NumPy disallows min = max = None return x + for argname, arg in ("min", min), ("max", max): + if isinstance(arg, Array): + if x.device != arg.device: + raise ValueError( + f"Arrays from two different devices ({x.device} and {arg.device}) " + "can not be combined." + ) + # Disallow subclasses of Python scalars, e.g. np.float64 + elif type(arg) not in (int, float, NoneType): + raise TypeError( + f"{argname} must be None, int, float, or Array; got {type(arg)}" + ) + + # Mixed dtype kinds is implementation defined + if (x.dtype in _integer_dtypes + and (isinstance(arg, float) or + isinstance(arg, Array) and arg.dtype in _real_floating_dtypes)): + raise TypeError(f"{argname} must be integral when x is integral") + if (x.dtype in _real_floating_dtypes + and (isinstance(arg, int) or + isinstance(arg, Array) and arg.dtype in _integer_dtypes)): + raise TypeError(f"{arg} must be floating-point when x is floating-point") + # Normalize to make the below logic simpler if min is not None: min = asarray(min)._array @@ -368,332 +352,3 @@ def _isscalar(a): ib = (out > b) | np.isnan(b) out[ib] = b[ib] return Array._new(out, device=device) - - -def conj(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.conj `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in conj") - return Array._new(np.conj(x._array), device=x.device) - - -def cos(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.cos `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in cos") - return Array._new(np.cos(x._array), device=x.device) - - -def cosh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.cosh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in cosh") - return Array._new(np.cosh(x._array), device=x.device) - - -def exp(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.exp `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in exp") - return Array._new(np.exp(x._array), device=x.device) - - -def expm1(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.expm1 `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in expm1") - return Array._new(np.expm1(x._array), device=x.device) - - -def floor(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.floor `. - - See its docstring for more information. - """ - if x.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in floor") - if x.dtype in _integer_dtypes: - # Note: The return dtype of floor is the same as the input - return x - return Array._new(np.floor(x._array), device=x.device) - - -def imag(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.imag `. - - See its docstring for more information. - """ - if x.dtype not in _complex_floating_dtypes: - raise TypeError("Only complex floating-point dtypes are allowed in imag") - return Array._new(np.imag(x._array), device=x.device) - - -def isfinite(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.isfinite `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in isfinite") - return Array._new(np.isfinite(x._array), device=x.device) - - -def isinf(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.isinf `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in isinf") - return Array._new(np.isinf(x._array), device=x.device) - - -def isnan(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.isnan `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in isnan") - return Array._new(np.isnan(x._array), device=x.device) - - -def log(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log") - return Array._new(np.log(x._array), device=x.device) - - -def log1p(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log1p `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log1p") - return Array._new(np.log1p(x._array), device=x.device) - - -def log2(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log2 `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log2") - return Array._new(np.log2(x._array), device=x.device) - - -def log10(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log10 `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log10") - return Array._new(np.log10(x._array), device=x.device) - - -def logical_not(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_not `. - - See its docstring for more information. - """ - if x.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_not") - return Array._new(np.logical_not(x._array), device=x.device) - - -def negative(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.negative `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in negative") - return Array._new(np.negative(x._array), device=x.device) - - -def positive(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.positive `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in positive") - return Array._new(np.positive(x._array), device=x.device) - - -def real(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.real `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in real") - return Array._new(np.real(x._array), device=x.device) - - -@requires_api_version('2024.12') -def reciprocal(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.reciprocal `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in reciprocal") - return Array._new(np.reciprocal(x._array), device=x.device) - - -def round(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.round `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in round") - return Array._new(np.round(x._array), device=x.device) - - -def sign(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sign `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in sign") - if x.dtype in _complex_floating_dtypes: - return x/abs(x) - return Array._new(np.sign(x._array), device=x.device) - - -@requires_api_version('2023.12') -def signbit(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.signbit `. - - See its docstring for more information. - """ - if x.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in signbit") - return Array._new(np.signbit(x._array), device=x.device) - - -def sin(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sin `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in sin") - return Array._new(np.sin(x._array), device=x.device) - - -def sinh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sinh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in sinh") - return Array._new(np.sinh(x._array), device=x.device) - - -def square(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.square `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in square") - return Array._new(np.square(x._array), device=x.device) - - -def sqrt(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sqrt `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in sqrt") - return Array._new(np.sqrt(x._array), device=x.device) - - -def tan(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.tan `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in tan") - return Array._new(np.tan(x._array), device=x.device) - - -def tanh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.tanh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in tanh") - return Array._new(np.tanh(x._array), device=x.device) - - -def trunc(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.trunc `. - - See its docstring for more information. - """ - if x.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in trunc") - if x.dtype in _integer_dtypes: - # Note: The return dtype of trunc is the same as the input - return x - return Array._new(np.trunc(x._array), device=x.device) diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py index e8c6767..5e4d652 100644 --- a/array_api_strict/_helpers.py +++ b/array_api_strict/_helpers.py @@ -4,7 +4,7 @@ from ._dtypes import _dtype_categories from ._flags import get_array_api_strict_flags -_py_scalars = (bool, int, float, complex) +_PY_SCALARS = (bool, int, float, complex) def _maybe_normalize_py_scalars( @@ -20,20 +20,26 @@ def _maybe_normalize_py_scalars( _allowed_dtypes = _dtype_categories[dtype_category] - if isinstance(x1, _py_scalars): - if isinstance(x2, _py_scalars): + # Disallow subclasses, e.g. np.float64 and np.complex128 + if type(x1) in _PY_SCALARS: + if type(x2) in _PY_SCALARS: raise TypeError(f"Two scalars not allowed, got {type(x1) = } and {type(x2) =}") - # x2 must be an array + if not isinstance(x2, Array): + raise TypeError(f"Argument is neither an Array nor a Python scalar: {type(x2)=} ") if x2.dtype not in _allowed_dtypes: raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x2.dtype}.") x1 = x2._promote_scalar(x1) - elif isinstance(x2, _py_scalars): - # x1 must be an array + elif type(x2) in _PY_SCALARS: + if not isinstance(x1, Array): + raise TypeError(f"Argument is neither an Array nor a Python scalar: {type(x2)=} ") if x1.dtype not in _allowed_dtypes: raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x1.dtype}.") x2 = x1._promote_scalar(x2) else: + if not isinstance(x1, Array) or not isinstance(x2, Array): + raise TypeError(f"Argument(s) are neither Array nor Python scalars: {type(x1)=} and {type(x2)=}") + if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes: raise TypeError(f"Only {dtype_category} dtypes are allowed in {func_name}(...). " f"Got {x1.dtype} and {x2.dtype}.") diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index c42ccc7..48135e9 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -103,6 +103,9 @@ def where( See its docstring for more information. """ + if not isinstance(condition, Array): + raise TypeError(f"`condition` must be an Array; got {type(condition)}") + x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where") # Call result type here just to raise on disallowed type combinations diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 99596b4..c6f615d 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,16 +1,22 @@ from inspect import signature, getmodule -from pytest import raises as assert_raises +import numpy as np +import pytest from numpy.testing import suppress_warnings from .. import asarray, _elementwise_functions +from .._array_object import ALL_DEVICES, CPU_DEVICE, Device +from .._data_type_functions import isdtype from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift from .._dtypes import ( _dtype_categories, _boolean_dtypes, _floating_dtypes, _integer_dtypes, + bool as xp_bool, + complex128, + float64, int8, int16, int32, @@ -104,6 +110,13 @@ def nargs(func): } +elementwise_binary_function_names = [ + func_name + for func_name in elementwise_function_input_types + if nargs(getattr(_elementwise_functions, func_name)) == 2 +] + + def test_nargs(): # Explicitly check number of arguments for a few functions assert nargs(array_api_strict.logaddexp) == 2 @@ -126,33 +139,83 @@ def test_missing_functions(): assert set(mod_funcs) == set(elementwise_function_input_types) -def test_function_device_persists(): - # Test that the device of the input and output array are the same +@pytest.mark.parametrize("device", ALL_DEVICES) +@pytest.mark.parametrize("func_name,types", elementwise_function_input_types.items()) +def test_elementwise_function_device_persists(func_name, types, device): + """Test that the device of the input and output array are the same""" def _array_vals(dtypes): - for d in dtypes: - yield asarray(1., dtype=d) + for dtype in dtypes: + yield asarray(1., dtype=dtype, device=device) + + dtypes = _dtype_categories[types] + func = getattr(_elementwise_functions, func_name) + + for x in _array_vals(dtypes): + if nargs(func) == 2: + # This way we don't have to deal with incompatible + # types of the two arguments. + r = func(x, x) + assert r.device == x.device + + else: + # `atanh` needs a slightly different input value from + # everyone else + if func_name == "atanh": + x -= 0.1 + r = func(x) + assert r.device == x.device + + +@pytest.mark.parametrize("func_name", elementwise_binary_function_names) +def test_elementwise_function_device_mismatch(func_name): + func = getattr(_elementwise_functions, func_name) + dtypes = elementwise_function_input_types[func_name] + if dtypes in ("floating-point", "real floating-point"): + dtype = float64 + elif dtypes == "boolean": + dtype = xp_bool + else: + dtype = int64 + + a = asarray(1, dtype=dtype, device=CPU_DEVICE) + b = asarray(1, dtype=dtype, device=Device("device1")) + _ = func(a, a) + with pytest.raises(ValueError, match="different devices"): + func(a, b) + + +@pytest.mark.parametrize("func_name", elementwise_function_input_types) +def test_elementwise_function_vs_numpy_generics(func_name): + """ + Test that NumPy generics are explicitly disallowed. + + This must notably includes np.float64 and np.complex128, which are + subclasses of float and complex respectively. + """ + func = getattr(_elementwise_functions, func_name) + dtypes = elementwise_function_input_types[func_name] + xp_dtypes = _dtype_categories[dtypes] + np_dtypes = [dtype._np_dtype for dtype in xp_dtypes] + + match = ( + "You are comparing a array_api_strict dtype against a NumPy " + "native dtype object" + ) - # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2024.12") + value = 0.5 if func_name == "atanh" else 1 + for xp_dtype in xp_dtypes: + for np_dtype in np_dtypes: + a = asarray(value, dtype=xp_dtype, device=CPU_DEVICE) + b = np.asarray(value, dtype=np_dtype)[()] - for func_name, types in elementwise_function_input_types.items(): - dtypes = _dtype_categories[types] - func = getattr(_elementwise_functions, func_name) - - for x in _array_vals(dtypes): if nargs(func) == 2: - # This way we don't have to deal with incompatible - # types of the two arguments. - r = func(x, x) - assert r.device == x.device - + _ = func(a, a) + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + func(a, b) else: - # `atanh` needs a slightly different input value from - # everyone else - if func_name == "atanh": - x -= 0.1 - r = func(x) - assert r.device == x.device + _ = func(a) + with pytest.raises(TypeError, match="allowed"): + func(b) def test_function_types(): @@ -168,9 +231,6 @@ def _array_vals(): for d in _floating_dtypes: yield asarray(1.0, dtype=d) - # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2024.12") - for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] @@ -187,23 +247,23 @@ def _array_vals(): or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes ): - assert_raises(TypeError, func, x, y) + with pytest.raises(TypeError): + func(x, y) if x.dtype not in dtypes or y.dtype not in dtypes: - assert_raises(TypeError, func, x, y) + with pytest.raises(TypeError): + func(x, y) else: if x.dtype not in dtypes: - assert_raises(TypeError, func, x) + with pytest.raises(TypeError): + func(x) def test_bitwise_shift_error(): # bitwise shift functions should raise when the second argument is negative - assert_raises( - ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1])) - ) - assert_raises( - ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) - ) - + with pytest.raises(ValueError): + bitwise_left_shift(asarray([1, 1]), asarray([1, -1])) + with pytest.raises(ValueError): + bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) def test_scalars(): @@ -212,9 +272,6 @@ def test_scalars(): # Also check that binary functions accept (array, scalar) and (scalar, array) # arguments, and reject (scalar, scalar) arguments. - # Use the latest version of the standard so that scalars are actually allowed - set_array_api_strict_flags(api_version="2024.12") - def _array_vals(): for d in _integer_dtypes: yield asarray(1, dtype=d) @@ -256,7 +313,5 @@ def _array_vals(): assert func(s, a) == func(conv_scalar, a) assert func(a, s) == func(a, conv_scalar) - with assert_raises(TypeError): + with pytest.raises(TypeError): func(s, s) - - diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 2a3a79e..63072dc 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -3,6 +3,8 @@ import array_api_strict as xp from array_api_strict import ArrayAPIStrictFlags +from .._array_object import ALL_DEVICES, CPU_DEVICE, Device +from .._dtypes import _all_dtypes def test_where_with_scalars(): @@ -23,6 +25,10 @@ def test_where_with_scalars(): with pytest.raises(TypeError, match="Two scalars"): xp.where(x == 1, 42, 44) + # The spec does not allow for condition to be scalar + with pytest.raises(TypeError, match="Array"): + xp.where(True, x, x) + def test_where_mixed_dtypes(): # https://github.com/data-apis/array-api-strict/issues/131 @@ -42,3 +48,54 @@ def test_where_f32(): res = xp.where(xp.asarray([True, False]), 1., xp.asarray([2, 2], dtype=xp.float32)) assert res.dtype == xp.float32 + +@pytest.mark.parametrize("device", ALL_DEVICES) +def test_where_device_persists(device): + """Test that the device of the input and output array are the same""" + + cond = xp.asarray([True, False], device=device) + x1 = xp.asarray([1, 2], device=device) + x2 = xp.asarray([3, 4], device=device) + res = xp.where(cond, x1, x2) + assert res.device == device + res = xp.where(cond, 1, x2) + assert res.device == device + res = xp.where(cond, x1, 2) + assert res.device == device + + +@pytest.mark.parametrize( + "cond_device,x1_device,x2_device", + [ + (CPU_DEVICE, CPU_DEVICE, Device("device1")), + (CPU_DEVICE, Device("device1"), CPU_DEVICE), + (Device("device1"), CPU_DEVICE, CPU_DEVICE), + ] +) +def test_where_device_mismatch(cond_device, x1_device, x2_device): + cond = xp.asarray([True, False], device=cond_device) + x1 = xp.asarray([1, 2], device=x1_device) + x2 = xp.asarray([3, 4], device=x2_device) + with pytest.raises(ValueError, match="device"): + xp.where(cond, x1, x2) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_where_numpy_generics(dtype): + """ + Test that NumPy generics are explicitly disallowed. + + This must notably includes np.float64 and np.complex128, which are + subclasses of float and complex respectively. + """ + cond = xp.asarray(True) + x1 = xp.asarray(1, dtype=dtype) + x2 = xp.asarray(1, dtype=dtype) + _ = xp.where(cond, x1, x2) + + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + xp.where(cond, x1, x2._array[()]) + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + xp.where(cond, x1._array[()], x2) + with pytest.raises(TypeError, match="must be an Array"): + xp.where(cond._array[()], x1, x2) From 8f9da46df25349e9fbe8b132b70ed28aa5ef6f79 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 13:52:25 +0100 Subject: [PATCH 2/7] trivial fix --- array_api_strict/_elementwise_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index f18c9f5..91549ac 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -200,8 +200,8 @@ def wrapper(x: Array, /) -> Array: bitwise_invert = _create_unary_func("bitwise_invert", "integer or boolean", "invert") ceil = _identity_if_integer(_create_unary_func("ceil", "real numeric")) conj = _create_unary_func("conj", "numeric") -cos = _create_unary_func("cos", "floating-point", "arccos") -cosh = _create_unary_func("cosh", "floating-point", "arccosh") +cos = _create_unary_func("cos", "floating-point", "cos") +cosh = _create_unary_func("cosh", "floating-point", "cosh") exp = _create_unary_func("exp", "floating-point") expm1 = _create_unary_func("expm1", "floating-point") floor = _identity_if_integer(_create_unary_func("floor", "real numeric")) From 1d88ff91b573a1910a7d50db4bd0bbb7d7882b17 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 13:56:24 +0100 Subject: [PATCH 3/7] lint --- array_api_strict/_elementwise_functions.py | 3 --- array_api_strict/tests/test_elementwise_functions.py | 9 --------- 2 files changed, 12 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 91549ac..1a3067a 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -7,12 +7,9 @@ from ._creation_functions import asarray from ._data_type_functions import broadcast_to, iinfo from ._dtypes import ( - _boolean_dtypes, _complex_floating_dtypes, _dtype_categories, - _floating_dtypes, _integer_dtypes, - _integer_or_boolean_dtypes, _numeric_dtypes, _real_floating_dtypes, _real_numeric_dtypes, diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index c6f615d..5e28cb9 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -7,7 +7,6 @@ from .. import asarray, _elementwise_functions from .._array_object import ALL_DEVICES, CPU_DEVICE, Device -from .._data_type_functions import isdtype from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift from .._dtypes import ( _dtype_categories, @@ -15,7 +14,6 @@ _floating_dtypes, _integer_dtypes, bool as xp_bool, - complex128, float64, int8, int16, @@ -23,8 +21,6 @@ int64, uint64, ) -from .._flags import set_array_api_strict_flags - from .test_array_object import _check_op_array_scalar, BIG_INT import array_api_strict @@ -197,11 +193,6 @@ def test_elementwise_function_vs_numpy_generics(func_name): xp_dtypes = _dtype_categories[dtypes] np_dtypes = [dtype._np_dtype for dtype in xp_dtypes] - match = ( - "You are comparing a array_api_strict dtype against a NumPy " - "native dtype object" - ) - value = 0.5 if func_name == "atanh" else 1 for xp_dtype in xp_dtypes: for np_dtype in np_dtypes: From 052d387be882005f330d5a01d91fd209fc7db00f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 13:59:23 +0100 Subject: [PATCH 4/7] reduce diff size --- array_api_strict/_elementwise_functions.py | 32 +++++++++++----------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 1a3067a..1ebbfcb 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -230,22 +230,6 @@ def wrapper(x: Array, /) -> Array: trunc = _identity_if_integer(_create_unary_func("trunc", "real numeric")) -def sign(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sign `. - - See its docstring for more information. - """ - if not isinstance(x, Array): - raise TypeError(f"Only Array objects are allowed; got {type(x)}") - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in sign") - # Special treatment to work around non-compliant NumPy 1.x behaviour - if x.dtype in _complex_floating_dtypes: - return x/abs(x) - return Array._new(np.sign(x._array), device=x.device) - - # Note: min and max argument names are different and not optional in numpy. @requires_api_version('2023.12') def clip( @@ -349,3 +333,19 @@ def _isscalar(a): ib = (out > b) | np.isnan(b) out[ib] = b[ib] return Array._new(out, device=device) + + +def sign(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.sign `. + + See its docstring for more information. + """ + if not isinstance(x, Array): + raise TypeError(f"Only Array objects are allowed; got {type(x)}") + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in sign") + # Special treatment to work around non-compliant NumPy 1.x behaviour + if x.dtype in _complex_floating_dtypes: + return x/abs(x) + return Array._new(np.sign(x._array), device=x.device) From 8e4e8bbbaa3a1a50121c107b055c3d005315765b Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 14:04:13 +0100 Subject: [PATCH 5/7] alphabetical order --- array_api_strict/_elementwise_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 1ebbfcb..7cc6cf6 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -207,24 +207,24 @@ def wrapper(x: Array, /) -> Array: isinf = _create_unary_func("isinf", "numeric") isnan = _create_unary_func("isnan", "numeric") log = _create_unary_func("log", "floating-point") +log10 = _create_unary_func("log10", "floating-point") log1p = _create_unary_func("log1p", "floating-point") log2 = _create_unary_func("log2", "floating-point") -log10 = _create_unary_func("log10", "floating-point") logical_not = _create_unary_func("logical_not", "boolean") negative = _create_unary_func("negative", "numeric") positive = _create_unary_func("positive", "numeric") +real = _create_unary_func("real", "numeric") reciprocal = requires_api_version("2024.12")( _create_unary_func("reciprocal", "floating-point") ) -real = _create_unary_func("real", "numeric") round = _create_unary_func("round", "numeric") signbit = requires_api_version("2023.12")( _create_unary_func("signbit", "real floating-point") ) sin = _create_unary_func("sin", "floating-point") sinh = _create_unary_func("sinh", "floating-point") -square = _create_unary_func("square", "numeric") sqrt = _create_unary_func("sqrt", "floating-point") +square = _create_unary_func("square", "numeric") tan = _create_unary_func("tan", "floating-point") tanh = _create_unary_func("tanh", "floating-point") trunc = _identity_if_integer(_create_unary_func("trunc", "real numeric")) From 771d85ad3f0263e74a3566358d3c0212a216134f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 14:07:07 +0100 Subject: [PATCH 6/7] type annotations --- array_api_strict/_elementwise_functions.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 7cc6cf6..cfad896 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from functools import wraps from types import NoneType @@ -153,7 +154,11 @@ def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array: del func, _create_binary_func -def _create_unary_func(func_name, dtype_category, np_func_name=None): +def _create_unary_func( + func_name: str, + dtype_category: str, + np_func_name: str | None = None, +) -> Callable[[Array], Array]: allowed_dtypes = _dtype_categories[dtype_category] np_func_name = np_func_name or func_name np_func = getattr(np, np_func_name) @@ -173,7 +178,7 @@ def func(x: Array, /) -> Array: return func -def _identity_if_integer(func): +def _identity_if_integer(func: Callable[[Array], Array]) -> Callable[[Array], Array]: """Hack around NumPy 1.x behaviour for ceil, floor, and trunc vs. integer inputs """ From ff8ea025c487c79a0753c7229cbf79488eb14e5e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 24 Apr 2025 09:50:07 +0100 Subject: [PATCH 7/7] Code review --- array_api_strict/tests/test_array_object.py | 10 ++++++---- array_api_strict/tests/test_elementwise_functions.py | 11 +++++++---- array_api_strict/tests/test_searching_functions.py | 9 +++++---- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index e950be5..15d88a9 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -412,10 +412,12 @@ def _matmul_array_vals(): @pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) -def test_binary_operators_vs_numpy_generics(op, dtypes): - """Test that np.bool_, np.int64, np.float32, np.float64, np.complex64, np.complex128 - are disallowed in binary operators. - np.float64 and np.complex128 are subclasses of float and complex, so they need +def test_binary_operators_numpy_scalars(op, dtypes): + """ + Test that NumPy scalars (np.generic) are explicitly disallowed. + + This must notably include np.float64 and np.complex128, which are + subclasses of float and complex respectively, so they need special treatment in order to be rejected. """ match = "Expected Array or Python scalar" diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 5e28cb9..0f740d3 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -181,12 +181,13 @@ def test_elementwise_function_device_mismatch(func_name): @pytest.mark.parametrize("func_name", elementwise_function_input_types) -def test_elementwise_function_vs_numpy_generics(func_name): +def test_elementwise_function_numpy_scalars(func_name): """ - Test that NumPy generics are explicitly disallowed. + Test that NumPy scalars (np.generic) are explicitly disallowed. - This must notably includes np.float64 and np.complex128, which are - subclasses of float and complex respectively. + This must notably include np.float64 and np.complex128, which are + subclasses of float and complex respectively, so they need + special treatment in order to be rejected. """ func = getattr(_elementwise_functions, func_name) dtypes = elementwise_function_input_types[func_name] @@ -203,6 +204,8 @@ def test_elementwise_function_vs_numpy_generics(func_name): _ = func(a, a) with pytest.raises(TypeError, match="neither Array nor Python scalars"): func(a, b) + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + func(b, a) else: _ = func(a) with pytest.raises(TypeError, match="allowed"): diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 63072dc..abe1949 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -81,12 +81,13 @@ def test_where_device_mismatch(cond_device, x1_device, x2_device): @pytest.mark.parametrize("dtype", _all_dtypes) -def test_where_numpy_generics(dtype): +def test_where_numpy_scalars(dtype): """ - Test that NumPy generics are explicitly disallowed. + Test that NumPy scalars (np.generic) are explicitly disallowed. - This must notably includes np.float64 and np.complex128, which are - subclasses of float and complex respectively. + This must notably include np.float64 and np.complex128, which are + subclasses of float and complex respectively, so they need + special treatment in order to be rejected. """ cond = xp.asarray(True) x1 = xp.asarray(1, dtype=dtype)