From 0e91b5efab03d5e54e44e97de0c57b4245e44b1e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 22:33:43 +0200 Subject: [PATCH] MAINT: simplify API version guards --- ...est_operators_and_elementwise_functions.py | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 2158c163..2298eab5 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1061,14 +1061,13 @@ def refimpl(_x, _min, _max): ) -if api_version >= "2022.12": - - @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) - def test_conj(x): - out = xp.conj(x) - ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("conj", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) +@pytest.mark.min_version("2022.12") +@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) +def test_conj(x): + out = xp.conj(x) + ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("conj", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) @pytest.mark.min_version("2023.12") @@ -1263,14 +1262,14 @@ def test_hypot(x1, x2): binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot) -if api_version >= "2022.12": - @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) - def test_imag(x): - out = xp.imag(x) - ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) - ph.assert_shape("imag", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) +@pytest.mark.min_version("2022.12") +@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) +def test_imag(x): + out = xp.imag(x) + ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) + ph.assert_shape("imag", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) @@ -1559,14 +1558,13 @@ def test_pow(ctx, data): # Values testing pow is too finicky -if api_version >= "2022.12": - - @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) - def test_real(x): - out = xp.real(x) - ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) - ph.assert_shape("real", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) +@pytest.mark.min_version("2022.12") +@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) +def test_real(x): + out = xp.real(x) + ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) + ph.assert_shape("real", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) @pytest.mark.skip(reason="flaky")