diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 39905456..d82edc62 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1,6 +1,7 @@ """ Test element-wise functions/operators against reference implementations. """ +import cmath import math import operator from copy import copy @@ -48,7 +49,7 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: def isclose( a: float, b: float, - M: float, + maximum: float, *, rel_tol: float = 0.25, abs_tol: float = 1, @@ -61,12 +62,30 @@ def isclose( if math.isnan(a) or math.isnan(b): raise ValueError(f"{a=} and {b=}, but input must be non-NaN") if math.isinf(a): - return math.isinf(b) or abs(b) > math.log(M) + return math.isinf(b) or abs(b) > math.log(maximum) elif math.isinf(b): - return math.isinf(a) or abs(a) > math.log(M) + return math.isinf(a) or abs(a) > math.log(maximum) return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) +def isclose_complex( + a: complex, + b: complex, + maximum: float, + *, + rel_tol: float = 0.25, + abs_tol: float = 1, +) -> bool: + """Like isclose() but specifically for complex values.""" + if cmath.isnan(a) or cmath.isnan(b): + raise ValueError(f"{a=} and {b=}, but input must be non-NaN") + if cmath.isinf(a): + return cmath.isinf(b) or abs(b) > cmath.log(maximum) + elif cmath.isinf(b): + return cmath.isinf(a) or abs(a) > cmath.log(maximum) + return cmath.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) + + def default_filter(s: Scalar) -> bool: """Returns False when s is a non-finite or a signed zero. @@ -254,8 +273,7 @@ def unary_assert_against_refimpl( f"{f_i}={scalar_i}" ) if res.dtype in dh.complex_dtypes: - assert isclose(scalar_o.real, expected.real, M), msg - assert isclose(scalar_o.imag, expected.imag, M), msg + assert isclose_complex(scalar_o, expected, M), msg else: assert isclose(scalar_o, expected, M), msg else: @@ -330,8 +348,7 @@ def binary_assert_against_refimpl( f"{f_l}={scalar_l}, {f_r}={scalar_r}" ) if res.dtype in dh.complex_dtypes: - assert isclose(scalar_o.real, expected.real, M), msg - assert isclose(scalar_o.imag, expected.imag, M), msg + assert isclose_complex(scalar_o, expected, M), msg else: assert isclose(scalar_o, expected, M), msg else: @@ -403,8 +420,7 @@ def right_scalar_assert_against_refimpl( f"{f_l}={scalar_l}" ) if res.dtype in dh.complex_dtypes: - assert isclose(scalar_o.real, expected.real, M), msg - assert isclose(scalar_o.imag, expected.imag, M), msg + assert isclose_complex(scalar_o, expected, M), msg else: assert isclose(scalar_o, expected, M), msg else: @@ -1394,7 +1410,7 @@ def test_square(x): ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("square", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( - "square", x, out, lambda s: s**2, expr_template="{}²={}" + "square", x, out, lambda s: s*s, expr_template="{}²={}" )