diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 65d19aaa..66d53b90 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1,4 +1,3 @@ -import contextlib import math import warnings from types import ModuleType @@ -24,7 +23,7 @@ from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal from array_api_extra._lib._utils._compat import device as get_device from array_api_extra._lib._utils._helpers import eager_shape, ndindex -from array_api_extra._lib._utils._typing import Array, Device +from array_api_extra._lib._utils._typing import Device from array_api_extra.testing import lazy_xp_function # some xp backends are untyped @@ -291,22 +290,12 @@ def test_xp(self, xp: ModuleType): class TestExpandDims: @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims") - @pytest.mark.xfail_xp_backend(Backend.DASK, reason="tuple index out of range") - @pytest.mark.xfail_xp_backend(Backend.TORCH, reason="tuple index out of range") - def test_functionality(self, xp: ModuleType): - def _squeeze_all(b: Array) -> Array: - """Mimics `np.squeeze(b)`. `xpx.squeeze`?""" - for axis in range(b.ndim): - with contextlib.suppress(ValueError): - b = xp.squeeze(b, axis=axis) - return b - - s = (2, 3, 4, 5) - a = xp.empty(s) + def test_single_axis(self, xp: ModuleType): + """Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims""" + a = xp.empty((2, 3, 4, 5)) for axis in range(-5, 4): b = expand_dims(a, axis=axis) - assert b.shape[axis] == 1 - assert _squeeze_all(b).shape == s + xp_assert_equal(b, xp.expand_dims(a, axis=axis)) @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims") def test_axis_tuple(self, xp: ModuleType): @@ -317,8 +306,7 @@ def test_axis_tuple(self, xp: ModuleType): assert expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3) def test_axis_out_of_range(self, xp: ModuleType): - s = (2, 3, 4, 5) - a = xp.empty(s) + a = xp.empty((2, 3, 4, 5)) with pytest.raises(IndexError, match="out of bounds"): _ = expand_dims(a, axis=-6) with pytest.raises(IndexError, match="out of bounds"):