Skip to content

TST: fix failures in expand_dims test #162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 18, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import math
import warnings
from types import ModuleType
Expand All @@ -24,7 +23,7 @@
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra._lib._utils._typing import Device
from array_api_extra.testing import lazy_xp_function

# some xp backends are untyped
Expand Down Expand Up @@ -291,22 +290,12 @@ def test_xp(self, xp: ModuleType):

class TestExpandDims:
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="tuple index out of range")
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="tuple index out of range")
def test_functionality(self, xp: ModuleType):
def _squeeze_all(b: Array) -> Array:
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
for axis in range(b.ndim):
with contextlib.suppress(ValueError):
b = xp.squeeze(b, axis=axis)
return b

s = (2, 3, 4, 5)
a = xp.empty(s)
def test_single_axis(self, xp: ModuleType):
"""Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
a = xp.empty((2, 3, 4, 5))
for axis in range(-5, 4):
b = expand_dims(a, axis=axis)
assert b.shape[axis] == 1
assert _squeeze_all(b).shape == s
xp_assert_equal(b, xp.expand_dims(a, axis=axis))

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
def test_axis_tuple(self, xp: ModuleType):
Expand All @@ -317,8 +306,7 @@ def test_axis_tuple(self, xp: ModuleType):
assert expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3)

def test_axis_out_of_range(self, xp: ModuleType):
s = (2, 3, 4, 5)
a = xp.empty(s)
a = xp.empty((2, 3, 4, 5))
with pytest.raises(IndexError, match="out of bounds"):
_ = expand_dims(a, axis=-6)
with pytest.raises(IndexError, match="out of bounds"):
Expand Down