Skip to content

Commit 3f7358d

Browse files
committed
BUG: setdiff1d vs. non-1d arrays
1 parent 2bce614 commit 3f7358d

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,7 @@ def setdiff1d(
585585

586586
if assume_unique:
587587
x1 = xp.reshape(x1, (-1,))
588+
x2 = xp.reshape(x2, (-1,))
588589
else:
589590
x1 = xp.unique_values(x1)
590591
x2 = xp.unique_values(x2)

tests/test_funcs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,21 @@ def test_assume_unique(self, xp: ModuleType):
586586
actual = setdiff1d(x1, x2, assume_unique=True)
587587
xp_assert_equal(actual, expected)
588588

589+
@pytest.mark.parametrize("assume_unique", [True, False])
590+
@pytest.mark.parametrize("shape1", [(), (1,), (1, 1)])
591+
@pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
592+
def test_shapes(
593+
self,
594+
assume_unique: bool,
595+
shape1: tuple[int, ...],
596+
shape2: tuple[int, ...],
597+
xp: ModuleType,
598+
):
599+
x1 = xp.zeros(shape1)
600+
x2 = xp.zeros(shape2)
601+
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
602+
xp_assert_equal(actual, xp.empty((0,)))
603+
589604
def test_device(self, xp: ModuleType, device: Device):
590605
x1 = xp.asarray([3, 8, 20], device=device)
591606
x2 = xp.asarray([2, 3, 4], device=device)

0 commit comments

Comments
 (0)