diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 4c8333c9..785d3665 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -244,7 +244,7 @@ def unary_assert_against_refimpl( continue try: expected = refimpl(scalar_i) - except Exception: + except OverflowError: continue if res.dtype != xp.bool: if res.dtype in dh.complex_dtypes: @@ -319,7 +319,7 @@ def binary_assert_against_refimpl( continue try: expected = refimpl(scalar_l, scalar_r) - except Exception: + except OverflowError: continue if res.dtype != xp.bool: if res.dtype in dh.complex_dtypes: @@ -394,7 +394,7 @@ def right_scalar_assert_against_refimpl( continue try: expected = refimpl(scalar_l, right) - except Exception: + except OverflowError: continue if left.dtype != xp.bool: if res.dtype in dh.complex_dtypes: @@ -712,9 +712,9 @@ def test_abs(ctx, data): abs, # type: ignore res_stype=float if x.dtype in dh.complex_dtypes else None, expr_template="abs({})={}", - filter_=lambda s: ( - s == float("infinity") or (math.isfinite(s) and not ph.is_neg_zero(s)) - ), + # filter_=lambda s: ( + # s == float("infinity") or (cmath.isfinite(s) and not ph.is_neg_zero(s)) + # ), ) @@ -723,8 +723,10 @@ def test_acos(x): out = xp.acos(x) ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("acos", out_shape=out.shape, expected=x.shape) + refimpl = cmath.acos if x.dtype in dh.complex_dtypes else math.acos + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 unary_assert_against_refimpl( - "acos", x, out, math.acos, filter_=lambda s: default_filter(s) and -1 <= s <= 1 + "acos", x, out, refimpl, filter_=filter_ ) @@ -733,8 +735,10 @@ def test_acosh(x): out = xp.acosh(x) ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("acosh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.acosh if x.dtype in dh.complex_dtypes else math.acosh + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 1 unary_assert_against_refimpl( - "acosh", x, out, math.acosh, filter_=lambda s: default_filter(s) and s >= 1 + "acosh", x, out, refimpl, filter_=filter_ ) @@ -757,8 +761,10 @@ def test_asin(x): out = xp.asin(x) ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("asin", out_shape=out.shape, expected=x.shape) + refimpl = cmath.asin if x.dtype in dh.complex_dtypes else math.asin + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 unary_assert_against_refimpl( - "asin", x, out, math.asin, filter_=lambda s: default_filter(s) and -1 <= s <= 1 + "asin", x, out, refimpl, filter_=filter_ ) @@ -767,7 +773,8 @@ def test_asinh(x): out = xp.asinh(x) ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("asinh", x, out, math.asinh) + refimpl = cmath.asinh if x.dtype in dh.complex_dtypes else math.asinh + unary_assert_against_refimpl("asinh", x, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -775,7 +782,8 @@ def test_atan(x): out = xp.atan(x) ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("atan", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("atan", x, out, math.atan) + refimpl = cmath.atan if x.dtype in dh.complex_dtypes else math.atan + unary_assert_against_refimpl("atan", x, out, refimpl) @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) @@ -783,7 +791,8 @@ def test_atan2(x1, x2): out = xp.atan2(x1, x2) ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) ph.assert_result_shape("atan2", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2) + refimpl = cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2 + binary_assert_against_refimpl("atan2", x1, x2, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -791,12 +800,14 @@ def test_atanh(x): out = xp.atanh(x) ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("atanh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.atanh if x.dtype in dh.complex_dtypes else math.atanh + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 < s < 1 unary_assert_against_refimpl( "atanh", x, out, - math.atanh, - filter_=lambda s: default_filter(s) and -1 <= s <= 1, + refimpl, + filter_=filter_, ) @@ -835,7 +846,7 @@ def test_bitwise_left_shift(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - nbits = res.dtype + nbits = dh.dtype_nbits[res.dtype] binary_param_assert_against_refimpl( ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0 ) @@ -1074,7 +1085,8 @@ def test_cos(x): out = xp.cos(x) ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("cos", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("cos", x, out, math.cos) + refimpl = cmath.cos if x.dtype in dh.complex_dtypes else math.cos + unary_assert_against_refimpl("cos", x, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -1082,7 +1094,8 @@ def test_cosh(x): out = xp.cosh(x) ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("cosh", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("cosh", x, out, math.cosh) + refimpl = cmath.cosh if x.dtype in dh.complex_dtypes else math.cosh + unary_assert_against_refimpl("cosh", x, out, refimpl) @pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes)) @@ -1106,7 +1119,7 @@ def test_divide(ctx, data): res, "/", operator.truediv, - filter_=lambda s: math.isfinite(s) and s != 0, + filter_=lambda s: cmath.isfinite(s) and s != 0, ) @@ -1143,7 +1156,8 @@ def test_exp(x): out = xp.exp(x) ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("exp", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("exp", x, out, math.exp) + refimpl = cmath.exp if x.dtype in dh.complex_dtypes else math.exp + unary_assert_against_refimpl("exp", x, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -1151,7 +1165,23 @@ def test_expm1(x): out = xp.expm1(x) ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("expm1", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("expm1", x, out, math.expm1) + if x.dtype in dh.complex_dtypes: + def refimpl(z): + # There's no cmath.expm1. Use + # + # exp(x+yi) - 1 + # = exp(x)exp(yi) - 1 + # = exp(x)(cos(y) + sin(y)i) - 1 + # = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i + # = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i + # + # where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of + # significance near y = 0. + re, im = z.real, z.imag + return math.expm1(re)*math.cos(im) - 2*math.sin(im/2)**2 + 1j*math.exp(re)*math.sin(im) + else: + refimpl = math.expm1 + unary_assert_against_refimpl("expm1", x, out, refimpl) @given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes())) @@ -1159,7 +1189,12 @@ def test_floor(x): out = xp.floor(x) ph.assert_dtype("floor", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("floor", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True) + if x.dtype in dh.complex_dtypes: + def refimpl(z): + return complex(math.floor(z.real), math.floor(z.imag)) + else: + refimpl = math.floor + unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True) @pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes)) @@ -1245,7 +1280,8 @@ def test_isfinite(x): out = xp.isfinite(x) ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) ph.assert_shape("isfinite", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("isfinite", x, out, math.isfinite, res_stype=bool) + refimpl = cmath.isfinite if x.dtype in dh.complex_dtypes else math.isfinite + unary_assert_against_refimpl("isfinite", x, out, refimpl, res_stype=bool) @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) @@ -1253,7 +1289,8 @@ def test_isinf(x): out = xp.isinf(x) ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) ph.assert_shape("isinf", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("isinf", x, out, math.isinf, res_stype=bool) + refimpl = cmath.isinf if x.dtype in dh.complex_dtypes else math.isinf + unary_assert_against_refimpl("isinf", x, out, refimpl, res_stype=bool) @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) @@ -1261,7 +1298,8 @@ def test_isnan(x): out = xp.isnan(x) ph.assert_dtype("isnan", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) ph.assert_shape("isnan", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool) + refimpl = cmath.isnan if x.dtype in dh.complex_dtypes else math.isnan + unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool) @pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes)) @@ -1309,8 +1347,10 @@ def test_log(x): out = xp.log(x) ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("log", out_shape=out.shape, expected=x.shape) + refimpl = cmath.log if x.dtype in dh.complex_dtypes else math.log + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 unary_assert_against_refimpl( - "log", x, out, math.log, filter_=lambda s: default_filter(s) and s >= 1 + "log", x, out, refimpl, filter_=filter_ ) @@ -1319,8 +1359,19 @@ def test_log1p(x): out = xp.log1p(x) ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("log1p", out_shape=out.shape, expected=x.shape) + # There isn't a cmath.log1p, and implementing one isn't straightforward + # (see + # https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy). + # For now, just use log(1+p) for complex inputs, which should hopefully be + # fine given the very loose numerical tolerances we use. If it isn't, we + # can try using something like a series expansion for small p. + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(1+z) + else: + refimpl = math.log1p + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > -1 unary_assert_against_refimpl( - "log1p", x, out, math.log1p, filter_=lambda s: default_filter(s) and s >= 1 + "log1p", x, out, refimpl, filter_=filter_ ) @@ -1329,8 +1380,13 @@ def test_log2(x): out = xp.log2(x) ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("log2", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(z)/math.log(2) + else: + refimpl = math.log2 + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 unary_assert_against_refimpl( - "log2", x, out, math.log2, filter_=lambda s: default_filter(s) and s > 1 + "log2", x, out, refimpl, filter_=filter_ ) @@ -1339,13 +1395,21 @@ def test_log10(x): out = xp.log10(x) ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("log10", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(z)/math.log(10) + else: + refimpl = math.log10 + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 unary_assert_against_refimpl( - "log10", x, out, math.log10, filter_=lambda s: default_filter(s) and s > 0 + "log10", x, out, refimpl, filter_=filter_ ) -def logaddexp(l: float, r: float) -> float: - return math.log(math.exp(l) + math.exp(r)) +def logaddexp_refimpl(l: float, r: float) -> float: + try: + return math.log(math.exp(l) + math.exp(r)) + except ValueError: # raised for log(0.) + raise OverflowError @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) @@ -1353,7 +1417,7 @@ def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) ph.assert_result_shape("logaddexp", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp) + binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp_refimpl) @given(*hh.two_mutual_arrays([xp.bool])) @@ -1530,7 +1594,11 @@ def test_round(x): out = xp.round(x) ph.assert_dtype("round", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("round", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("round", x, out, round, strict_check=True) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: complex(round(z.real), round(z.imag)) + else: + refimpl = round + unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True) @pytest.mark.min_version("2023.12") @@ -1548,13 +1616,12 @@ def test_sign(x): out = xp.sign(x) ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("sign", out_shape=out.shape, expected=x.shape) - refimpl = lambda x: x / math.abs(x) if x != 0 else 0 + refimpl = lambda x: x / abs(x) if x != 0 else 0 unary_assert_against_refimpl( "sign", x, out, refimpl, - filter_=lambda s: s != 0, strict_check=True, ) @@ -1564,7 +1631,8 @@ def test_sin(x): out = xp.sin(x) ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("sin", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("sin", x, out, math.sin) + refimpl = cmath.sin if x.dtype in dh.complex_dtypes else math.sin + unary_assert_against_refimpl("sin", x, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -1572,7 +1640,8 @@ def test_sinh(x): out = xp.sinh(x) ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("sinh", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("sinh", x, out, math.sinh) + refimpl = cmath.sinh if x.dtype in dh.complex_dtypes else math.sinh + unary_assert_against_refimpl("sinh", x, out, refimpl) @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) @@ -1590,8 +1659,10 @@ def test_sqrt(x): out = xp.sqrt(x) ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("sqrt", out_shape=out.shape, expected=x.shape) + refimpl = cmath.sqrt if x.dtype in dh.complex_dtypes else math.sqrt + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 0 unary_assert_against_refimpl( - "sqrt", x, out, math.sqrt, filter_=lambda s: default_filter(s) and s >= 0 + "sqrt", x, out, refimpl, filter_=filter_ ) @@ -1614,7 +1685,8 @@ def test_tan(x): out = xp.tan(x) ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("tan", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("tan", x, out, math.tan) + refimpl = cmath.tan if x.dtype in dh.complex_dtypes else math.tan + unary_assert_against_refimpl("tan", x, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -1622,7 +1694,8 @@ def test_tanh(x): out = xp.tanh(x) ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("tanh", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("tanh", x, out, math.tanh) + refimpl = cmath.tanh if x.dtype in dh.complex_dtypes else math.tanh + unary_assert_against_refimpl("tanh", x, out, refimpl) @given(hh.arrays(dtype=hh.real_dtypes, shape=xps.array_shapes()))