From e61c31fa613ff71025ec9f988c5d0c686978cdbd Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jan 2025 15:25:55 +0000 Subject: [PATCH] BUG: `setdiff1d` vs. non-1d arrays --- src/array_api_extra/_lib/_funcs.py | 1 + tests/test_funcs.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 0af265e2..9986569c 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -580,6 +580,7 @@ def setdiff1d( if assume_unique: x1 = xp.reshape(x1, (-1,)) + x2 = xp.reshape(x2, (-1,)) else: x1 = xp.unique_values(x1) x2 = xp.unique_values(x2) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index ef1a1fc2..b31cdfb8 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -579,6 +579,21 @@ def test_assume_unique(self, xp: ModuleType): actual = setdiff1d(x1, x2, assume_unique=True) xp_assert_equal(actual, expected) + @pytest.mark.parametrize("assume_unique", [True, False]) + @pytest.mark.parametrize("shape1", [(), (1,), (1, 1)]) + @pytest.mark.parametrize("shape2", [(), (1,), (1, 1)]) + def test_shapes( + self, + assume_unique: bool, + shape1: tuple[int, ...], + shape2: tuple[int, ...], + xp: ModuleType, + ): + x1 = xp.zeros(shape1) + x2 = xp.zeros(shape2) + actual = setdiff1d(x1, x2, assume_unique=assume_unique) + xp_assert_equal(actual, xp.empty((0,))) + def test_device(self, xp: ModuleType, device: Device): x1 = xp.asarray([3, 8, 20], device=device) x2 = xp.asarray([2, 3, 4], device=device)