diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index b05e0fd..cfad896 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -1,15 +1,16 @@ +from collections.abc import Callable +from functools import wraps +from types import NoneType + import numpy as np from ._array_object import Array from ._creation_functions import asarray from ._data_type_functions import broadcast_to, iinfo from ._dtypes import ( - _boolean_dtypes, _complex_floating_dtypes, _dtype_categories, - _floating_dtypes, _integer_dtypes, - _integer_or_boolean_dtypes, _numeric_dtypes, _real_floating_dtypes, _real_numeric_dtypes, @@ -35,7 +36,7 @@ def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func): return Array._new(np_func(x1._array, x2._array), device=x1.device) -_binary_docstring_template = """ +_docstring_template = """ Array API compatible wrapper for :py:func:`np.%s `. See its docstring for more information. @@ -117,7 +118,7 @@ def inner(x1, x2, /) -> Array: func = _create_binary_func(func_name, dtype_category, np_func) func.__name__ = func_name - func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name) + func.__doc__ = _docstring_template % (numpy_name, numpy_name) func.__annotations__['x1'] = _annotations[dtype_category] func.__annotations__['x2'] = _annotations[dtype_category] @@ -153,115 +154,86 @@ def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array: del func, _create_binary_func -def abs(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.abs `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in abs") - return Array._new(np.abs(x._array), device=x.device) - - -# Note: the function name is different here -def acos(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arccos `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in acos") - return Array._new(np.arccos(x._array), device=x.device) - - -# Note: the function name is different here -def acosh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arccosh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in acosh") - return Array._new(np.arccosh(x._array), device=x.device) - - -# Note: the function name is different here -def asin(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arcsin `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in asin") - return Array._new(np.arcsin(x._array), device=x.device) - - -# Note: the function name is different here -def asinh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arcsinh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in asinh") - return Array._new(np.arcsinh(x._array), device=x.device) - - -# Note: the function name is different here -def atan(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arctan `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in atan") - return Array._new(np.arctan(x._array), device=x.device) - - -# Note: the function name is different here -def atanh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arctanh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in atanh") - return Array._new(np.arctanh(x._array), device=x.device) - - -# Note: the function name is different here -def bitwise_invert(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.invert `. - - See its docstring for more information. - """ - if x.dtype not in _integer_or_boolean_dtypes: - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert") - return Array._new(np.invert(x._array), device=x.device) - +def _create_unary_func( + func_name: str, + dtype_category: str, + np_func_name: str | None = None, +) -> Callable[[Array], Array]: + allowed_dtypes = _dtype_categories[dtype_category] + np_func_name = np_func_name or func_name + np_func = getattr(np, np_func_name) + + def func(x: Array, /) -> Array: + if not isinstance(x, Array): + raise TypeError(f"Only Array objects are allowed; got {type(x)}") + if x.dtype not in allowed_dtypes: + raise TypeError( + f"Only {dtype_category} dtypes are allowed in {func_name}; " + f"got {x.dtype}." + ) + return Array._new(np_func(x._array), device=x.device) -def ceil(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.ceil `. - - See its docstring for more information. - """ - if x.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in ceil") - if x.dtype in _integer_dtypes: - # Note: The return dtype of ceil is the same as the input - return x - return Array._new(np.ceil(x._array), device=x.device) + func.__name__ = func_name + func.__doc__ = _docstring_template % (np_func_name, np_func_name) + return func + + +def _identity_if_integer(func: Callable[[Array], Array]) -> Callable[[Array], Array]: + """Hack around NumPy 1.x behaviour for ceil, floor, and trunc + vs. integer inputs + """ + + @wraps(func) + def wrapper(x: Array, /) -> Array: + if isinstance(x, Array) and x.dtype in _integer_dtypes: + return x + return func(x) + + return wrapper + + +abs = _create_unary_func("abs", "numeric") +acos = _create_unary_func("acos", "floating-point", "arccos") +acosh = _create_unary_func("acosh", "floating-point", "arccosh") +asin = _create_unary_func("asin", "floating-point", "arcsin") +asinh = _create_unary_func("asinh", "floating-point", "arcsinh") +atan = _create_unary_func("atan", "floating-point", "arctan") +atanh = _create_unary_func("atanh", "floating-point", "arctanh") +bitwise_invert = _create_unary_func("bitwise_invert", "integer or boolean", "invert") +ceil = _identity_if_integer(_create_unary_func("ceil", "real numeric")) +conj = _create_unary_func("conj", "numeric") +cos = _create_unary_func("cos", "floating-point", "cos") +cosh = _create_unary_func("cosh", "floating-point", "cosh") +exp = _create_unary_func("exp", "floating-point") +expm1 = _create_unary_func("expm1", "floating-point") +floor = _identity_if_integer(_create_unary_func("floor", "real numeric")) +imag = _create_unary_func("imag", "complex floating-point") +isfinite = _create_unary_func("isfinite", "numeric") +isinf = _create_unary_func("isinf", "numeric") +isnan = _create_unary_func("isnan", "numeric") +log = _create_unary_func("log", "floating-point") +log10 = _create_unary_func("log10", "floating-point") +log1p = _create_unary_func("log1p", "floating-point") +log2 = _create_unary_func("log2", "floating-point") +logical_not = _create_unary_func("logical_not", "boolean") +negative = _create_unary_func("negative", "numeric") +positive = _create_unary_func("positive", "numeric") +real = _create_unary_func("real", "numeric") +reciprocal = requires_api_version("2024.12")( + _create_unary_func("reciprocal", "floating-point") +) +round = _create_unary_func("round", "numeric") +signbit = requires_api_version("2023.12")( + _create_unary_func("signbit", "real floating-point") +) +sin = _create_unary_func("sin", "floating-point") +sinh = _create_unary_func("sinh", "floating-point") +sqrt = _create_unary_func("sqrt", "floating-point") +square = _create_unary_func("square", "numeric") +tan = _create_unary_func("tan", "floating-point") +tanh = _create_unary_func("tanh", "floating-point") +trunc = _identity_if_integer(_create_unary_func("trunc", "real numeric")) -# WARNING: This function is not yet tested by the array-api-tests test suite. # Note: min and max argument names are different and not optional in numpy. @requires_api_version('2023.12') @@ -276,42 +248,40 @@ def clip( See its docstring for more information. """ - if isinstance(min, Array) and x.device != min.device: - raise ValueError(f"Arrays from two different devices ({x.device} and {min.device}) can not be combined.") - if isinstance(max, Array) and x.device != max.device: - raise ValueError(f"Arrays from two different devices ({x.device} and {max.device}) can not be combined.") + if not isinstance(x, Array): + raise TypeError(f"Only Array objects are allowed; got {type(x)}") if (x.dtype not in _real_numeric_dtypes or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes): raise TypeError("Only real numeric dtypes are allowed in clip") - if not isinstance(min, (int, float, Array, type(None))): - raise TypeError("min must be an None, int, float, or an array") - if not isinstance(max, (int, float, Array, type(None))): - raise TypeError("max must be an None, int, float, or an array") - - # Mixed dtype kinds is implementation defined - if (x.dtype in _integer_dtypes - and (isinstance(min, float) or - isinstance(min, Array) and min.dtype in _real_floating_dtypes)): - raise TypeError("min must be integral when x is integral") - if (x.dtype in _integer_dtypes - and (isinstance(max, float) or - isinstance(max, Array) and max.dtype in _real_floating_dtypes)): - raise TypeError("max must be integral when x is integral") - if (x.dtype in _real_floating_dtypes - and (isinstance(min, int) or - isinstance(min, Array) and min.dtype in _integer_dtypes)): - raise TypeError("min must be floating-point when x is floating-point") - if (x.dtype in _real_floating_dtypes - and (isinstance(max, int) or - isinstance(max, Array) and max.dtype in _integer_dtypes)): - raise TypeError("max must be floating-point when x is floating-point") if min is max is None: - # Note: NumPy disallows min = max = None return x + for argname, arg in ("min", min), ("max", max): + if isinstance(arg, Array): + if x.device != arg.device: + raise ValueError( + f"Arrays from two different devices ({x.device} and {arg.device}) " + "can not be combined." + ) + # Disallow subclasses of Python scalars, e.g. np.float64 + elif type(arg) not in (int, float, NoneType): + raise TypeError( + f"{argname} must be None, int, float, or Array; got {type(arg)}" + ) + + # Mixed dtype kinds is implementation defined + if (x.dtype in _integer_dtypes + and (isinstance(arg, float) or + isinstance(arg, Array) and arg.dtype in _real_floating_dtypes)): + raise TypeError(f"{argname} must be integral when x is integral") + if (x.dtype in _real_floating_dtypes + and (isinstance(arg, int) or + isinstance(arg, Array) and arg.dtype in _integer_dtypes)): + raise TypeError(f"{arg} must be floating-point when x is floating-point") + # Normalize to make the below logic simpler if min is not None: min = asarray(min)._array @@ -370,330 +340,17 @@ def _isscalar(a): return Array._new(out, device=device) -def conj(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.conj `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in conj") - return Array._new(np.conj(x._array), device=x.device) - - -def cos(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.cos `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in cos") - return Array._new(np.cos(x._array), device=x.device) - - -def cosh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.cosh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in cosh") - return Array._new(np.cosh(x._array), device=x.device) - - -def exp(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.exp `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in exp") - return Array._new(np.exp(x._array), device=x.device) - - -def expm1(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.expm1 `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in expm1") - return Array._new(np.expm1(x._array), device=x.device) - - -def floor(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.floor `. - - See its docstring for more information. - """ - if x.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in floor") - if x.dtype in _integer_dtypes: - # Note: The return dtype of floor is the same as the input - return x - return Array._new(np.floor(x._array), device=x.device) - - -def imag(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.imag `. - - See its docstring for more information. - """ - if x.dtype not in _complex_floating_dtypes: - raise TypeError("Only complex floating-point dtypes are allowed in imag") - return Array._new(np.imag(x._array), device=x.device) - - -def isfinite(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.isfinite `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in isfinite") - return Array._new(np.isfinite(x._array), device=x.device) - - -def isinf(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.isinf `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in isinf") - return Array._new(np.isinf(x._array), device=x.device) - - -def isnan(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.isnan `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in isnan") - return Array._new(np.isnan(x._array), device=x.device) - - -def log(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log") - return Array._new(np.log(x._array), device=x.device) - - -def log1p(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log1p `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log1p") - return Array._new(np.log1p(x._array), device=x.device) - - -def log2(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log2 `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log2") - return Array._new(np.log2(x._array), device=x.device) - - -def log10(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log10 `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log10") - return Array._new(np.log10(x._array), device=x.device) - - -def logical_not(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_not `. - - See its docstring for more information. - """ - if x.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_not") - return Array._new(np.logical_not(x._array), device=x.device) - - -def negative(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.negative `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in negative") - return Array._new(np.negative(x._array), device=x.device) - - -def positive(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.positive `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in positive") - return Array._new(np.positive(x._array), device=x.device) - - -def real(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.real `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in real") - return Array._new(np.real(x._array), device=x.device) - - -@requires_api_version('2024.12') -def reciprocal(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.reciprocal `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in reciprocal") - return Array._new(np.reciprocal(x._array), device=x.device) - - -def round(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.round `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in round") - return Array._new(np.round(x._array), device=x.device) - - def sign(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sign `. See its docstring for more information. """ + if not isinstance(x, Array): + raise TypeError(f"Only Array objects are allowed; got {type(x)}") if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sign") + # Special treatment to work around non-compliant NumPy 1.x behaviour if x.dtype in _complex_floating_dtypes: return x/abs(x) return Array._new(np.sign(x._array), device=x.device) - - -@requires_api_version('2023.12') -def signbit(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.signbit `. - - See its docstring for more information. - """ - if x.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in signbit") - return Array._new(np.signbit(x._array), device=x.device) - - -def sin(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sin `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in sin") - return Array._new(np.sin(x._array), device=x.device) - - -def sinh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sinh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in sinh") - return Array._new(np.sinh(x._array), device=x.device) - - -def square(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.square `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in square") - return Array._new(np.square(x._array), device=x.device) - - -def sqrt(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sqrt `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in sqrt") - return Array._new(np.sqrt(x._array), device=x.device) - - -def tan(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.tan `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in tan") - return Array._new(np.tan(x._array), device=x.device) - - -def tanh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.tanh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in tanh") - return Array._new(np.tanh(x._array), device=x.device) - - -def trunc(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.trunc `. - - See its docstring for more information. - """ - if x.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in trunc") - if x.dtype in _integer_dtypes: - # Note: The return dtype of trunc is the same as the input - return x - return Array._new(np.trunc(x._array), device=x.device) diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py index e8c6767..5e4d652 100644 --- a/array_api_strict/_helpers.py +++ b/array_api_strict/_helpers.py @@ -4,7 +4,7 @@ from ._dtypes import _dtype_categories from ._flags import get_array_api_strict_flags -_py_scalars = (bool, int, float, complex) +_PY_SCALARS = (bool, int, float, complex) def _maybe_normalize_py_scalars( @@ -20,20 +20,26 @@ def _maybe_normalize_py_scalars( _allowed_dtypes = _dtype_categories[dtype_category] - if isinstance(x1, _py_scalars): - if isinstance(x2, _py_scalars): + # Disallow subclasses, e.g. np.float64 and np.complex128 + if type(x1) in _PY_SCALARS: + if type(x2) in _PY_SCALARS: raise TypeError(f"Two scalars not allowed, got {type(x1) = } and {type(x2) =}") - # x2 must be an array + if not isinstance(x2, Array): + raise TypeError(f"Argument is neither an Array nor a Python scalar: {type(x2)=} ") if x2.dtype not in _allowed_dtypes: raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x2.dtype}.") x1 = x2._promote_scalar(x1) - elif isinstance(x2, _py_scalars): - # x1 must be an array + elif type(x2) in _PY_SCALARS: + if not isinstance(x1, Array): + raise TypeError(f"Argument is neither an Array nor a Python scalar: {type(x2)=} ") if x1.dtype not in _allowed_dtypes: raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x1.dtype}.") x2 = x1._promote_scalar(x2) else: + if not isinstance(x1, Array) or not isinstance(x2, Array): + raise TypeError(f"Argument(s) are neither Array nor Python scalars: {type(x1)=} and {type(x2)=}") + if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes: raise TypeError(f"Only {dtype_category} dtypes are allowed in {func_name}(...). " f"Got {x1.dtype} and {x2.dtype}.") diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index c42ccc7..48135e9 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -103,6 +103,9 @@ def where( See its docstring for more information. """ + if not isinstance(condition, Array): + raise TypeError(f"`condition` must be an Array; got {type(condition)}") + x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where") # Call result type here just to raise on disallowed type combinations diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index e950be5..15d88a9 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -412,10 +412,12 @@ def _matmul_array_vals(): @pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) -def test_binary_operators_vs_numpy_generics(op, dtypes): - """Test that np.bool_, np.int64, np.float32, np.float64, np.complex64, np.complex128 - are disallowed in binary operators. - np.float64 and np.complex128 are subclasses of float and complex, so they need +def test_binary_operators_numpy_scalars(op, dtypes): + """ + Test that NumPy scalars (np.generic) are explicitly disallowed. + + This must notably include np.float64 and np.complex128, which are + subclasses of float and complex respectively, so they need special treatment in order to be rejected. """ match = "Expected Array or Python scalar" diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 99596b4..0f740d3 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,24 +1,26 @@ from inspect import signature, getmodule -from pytest import raises as assert_raises +import numpy as np +import pytest from numpy.testing import suppress_warnings from .. import asarray, _elementwise_functions +from .._array_object import ALL_DEVICES, CPU_DEVICE, Device from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift from .._dtypes import ( _dtype_categories, _boolean_dtypes, _floating_dtypes, _integer_dtypes, + bool as xp_bool, + float64, int8, int16, int32, int64, uint64, ) -from .._flags import set_array_api_strict_flags - from .test_array_object import _check_op_array_scalar, BIG_INT import array_api_strict @@ -104,6 +106,13 @@ def nargs(func): } +elementwise_binary_function_names = [ + func_name + for func_name in elementwise_function_input_types + if nargs(getattr(_elementwise_functions, func_name)) == 2 +] + + def test_nargs(): # Explicitly check number of arguments for a few functions assert nargs(array_api_strict.logaddexp) == 2 @@ -126,33 +135,81 @@ def test_missing_functions(): assert set(mod_funcs) == set(elementwise_function_input_types) -def test_function_device_persists(): - # Test that the device of the input and output array are the same +@pytest.mark.parametrize("device", ALL_DEVICES) +@pytest.mark.parametrize("func_name,types", elementwise_function_input_types.items()) +def test_elementwise_function_device_persists(func_name, types, device): + """Test that the device of the input and output array are the same""" def _array_vals(dtypes): - for d in dtypes: - yield asarray(1., dtype=d) - - # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2024.12") - - for func_name, types in elementwise_function_input_types.items(): - dtypes = _dtype_categories[types] - func = getattr(_elementwise_functions, func_name) + for dtype in dtypes: + yield asarray(1., dtype=dtype, device=device) + + dtypes = _dtype_categories[types] + func = getattr(_elementwise_functions, func_name) + + for x in _array_vals(dtypes): + if nargs(func) == 2: + # This way we don't have to deal with incompatible + # types of the two arguments. + r = func(x, x) + assert r.device == x.device + + else: + # `atanh` needs a slightly different input value from + # everyone else + if func_name == "atanh": + x -= 0.1 + r = func(x) + assert r.device == x.device + + +@pytest.mark.parametrize("func_name", elementwise_binary_function_names) +def test_elementwise_function_device_mismatch(func_name): + func = getattr(_elementwise_functions, func_name) + dtypes = elementwise_function_input_types[func_name] + if dtypes in ("floating-point", "real floating-point"): + dtype = float64 + elif dtypes == "boolean": + dtype = xp_bool + else: + dtype = int64 + + a = asarray(1, dtype=dtype, device=CPU_DEVICE) + b = asarray(1, dtype=dtype, device=Device("device1")) + _ = func(a, a) + with pytest.raises(ValueError, match="different devices"): + func(a, b) + + +@pytest.mark.parametrize("func_name", elementwise_function_input_types) +def test_elementwise_function_numpy_scalars(func_name): + """ + Test that NumPy scalars (np.generic) are explicitly disallowed. + + This must notably include np.float64 and np.complex128, which are + subclasses of float and complex respectively, so they need + special treatment in order to be rejected. + """ + func = getattr(_elementwise_functions, func_name) + dtypes = elementwise_function_input_types[func_name] + xp_dtypes = _dtype_categories[dtypes] + np_dtypes = [dtype._np_dtype for dtype in xp_dtypes] + + value = 0.5 if func_name == "atanh" else 1 + for xp_dtype in xp_dtypes: + for np_dtype in np_dtypes: + a = asarray(value, dtype=xp_dtype, device=CPU_DEVICE) + b = np.asarray(value, dtype=np_dtype)[()] - for x in _array_vals(dtypes): if nargs(func) == 2: - # This way we don't have to deal with incompatible - # types of the two arguments. - r = func(x, x) - assert r.device == x.device - + _ = func(a, a) + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + func(a, b) + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + func(b, a) else: - # `atanh` needs a slightly different input value from - # everyone else - if func_name == "atanh": - x -= 0.1 - r = func(x) - assert r.device == x.device + _ = func(a) + with pytest.raises(TypeError, match="allowed"): + func(b) def test_function_types(): @@ -168,9 +225,6 @@ def _array_vals(): for d in _floating_dtypes: yield asarray(1.0, dtype=d) - # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2024.12") - for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] @@ -187,23 +241,23 @@ def _array_vals(): or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes ): - assert_raises(TypeError, func, x, y) + with pytest.raises(TypeError): + func(x, y) if x.dtype not in dtypes or y.dtype not in dtypes: - assert_raises(TypeError, func, x, y) + with pytest.raises(TypeError): + func(x, y) else: if x.dtype not in dtypes: - assert_raises(TypeError, func, x) + with pytest.raises(TypeError): + func(x) def test_bitwise_shift_error(): # bitwise shift functions should raise when the second argument is negative - assert_raises( - ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1])) - ) - assert_raises( - ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) - ) - + with pytest.raises(ValueError): + bitwise_left_shift(asarray([1, 1]), asarray([1, -1])) + with pytest.raises(ValueError): + bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) def test_scalars(): @@ -212,9 +266,6 @@ def test_scalars(): # Also check that binary functions accept (array, scalar) and (scalar, array) # arguments, and reject (scalar, scalar) arguments. - # Use the latest version of the standard so that scalars are actually allowed - set_array_api_strict_flags(api_version="2024.12") - def _array_vals(): for d in _integer_dtypes: yield asarray(1, dtype=d) @@ -256,7 +307,5 @@ def _array_vals(): assert func(s, a) == func(conv_scalar, a) assert func(a, s) == func(a, conv_scalar) - with assert_raises(TypeError): + with pytest.raises(TypeError): func(s, s) - - diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 2a3a79e..abe1949 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -3,6 +3,8 @@ import array_api_strict as xp from array_api_strict import ArrayAPIStrictFlags +from .._array_object import ALL_DEVICES, CPU_DEVICE, Device +from .._dtypes import _all_dtypes def test_where_with_scalars(): @@ -23,6 +25,10 @@ def test_where_with_scalars(): with pytest.raises(TypeError, match="Two scalars"): xp.where(x == 1, 42, 44) + # The spec does not allow for condition to be scalar + with pytest.raises(TypeError, match="Array"): + xp.where(True, x, x) + def test_where_mixed_dtypes(): # https://github.com/data-apis/array-api-strict/issues/131 @@ -42,3 +48,55 @@ def test_where_f32(): res = xp.where(xp.asarray([True, False]), 1., xp.asarray([2, 2], dtype=xp.float32)) assert res.dtype == xp.float32 + +@pytest.mark.parametrize("device", ALL_DEVICES) +def test_where_device_persists(device): + """Test that the device of the input and output array are the same""" + + cond = xp.asarray([True, False], device=device) + x1 = xp.asarray([1, 2], device=device) + x2 = xp.asarray([3, 4], device=device) + res = xp.where(cond, x1, x2) + assert res.device == device + res = xp.where(cond, 1, x2) + assert res.device == device + res = xp.where(cond, x1, 2) + assert res.device == device + + +@pytest.mark.parametrize( + "cond_device,x1_device,x2_device", + [ + (CPU_DEVICE, CPU_DEVICE, Device("device1")), + (CPU_DEVICE, Device("device1"), CPU_DEVICE), + (Device("device1"), CPU_DEVICE, CPU_DEVICE), + ] +) +def test_where_device_mismatch(cond_device, x1_device, x2_device): + cond = xp.asarray([True, False], device=cond_device) + x1 = xp.asarray([1, 2], device=x1_device) + x2 = xp.asarray([3, 4], device=x2_device) + with pytest.raises(ValueError, match="device"): + xp.where(cond, x1, x2) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_where_numpy_scalars(dtype): + """ + Test that NumPy scalars (np.generic) are explicitly disallowed. + + This must notably include np.float64 and np.complex128, which are + subclasses of float and complex respectively, so they need + special treatment in order to be rejected. + """ + cond = xp.asarray(True) + x1 = xp.asarray(1, dtype=dtype) + x2 = xp.asarray(1, dtype=dtype) + _ = xp.where(cond, x1, x2) + + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + xp.where(cond, x1, x2._array[()]) + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + xp.where(cond, x1._array[()], x2) + with pytest.raises(TypeError, match="must be an Array"): + xp.where(cond._array[()], x1, x2)