diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index 751db2b88069d..f2bb20746741d 100644 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -223,7 +223,7 @@ Categorical - Bug where :func:`merge` was unable to join on categorical and extension dtype columns (:issue:`28668`) - :meth:`Categorical.searchsorted` and :meth:`CategoricalIndex.searchsorted` now work on unordered categoricals also (:issue:`21667`) - Added test to assert roundtripping to parquet with :func:`DataFrame.to_parquet` or :func:`read_parquet` will preserve Categorical dtypes for string types (:issue:`27955`) -- +- Changed the error message in :meth:`Categorical.remove_categories` to always show the invalid removals as a set (:issue:`28669`) Datetimelike diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index bab1127e6e539..a14b91d78212d 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -1120,7 +1120,7 @@ def remove_categories(self, removals, inplace=False): # GH 10156 if any(isna(removals)): - not_included = [x for x in not_included if notna(x)] + not_included = {x for x in not_included if notna(x)} new_categories = [x for x in new_categories if notna(x)] if len(not_included) != 0: diff --git a/pandas/tests/arrays/categorical/test_api.py b/pandas/tests/arrays/categorical/test_api.py index ab07b3c96a1db..42087b89a19b5 100644 --- a/pandas/tests/arrays/categorical/test_api.py +++ b/pandas/tests/arrays/categorical/test_api.py @@ -1,3 +1,5 @@ +import re + import numpy as np import pytest @@ -339,9 +341,13 @@ def test_remove_categories(self): tm.assert_categorical_equal(cat, new) assert res is None - # removal is not in categories - with pytest.raises(ValueError): - cat.remove_categories(["c"]) + @pytest.mark.parametrize("removals", [["c"], ["c", np.nan], "c", ["c", "c"]]) + def test_remove_categories_raises(self, removals): + cat = Categorical(["a", "b", "a"]) + message = re.escape("removals must all be in old categories: {'c'}") + + with pytest.raises(ValueError, match=message): + cat.remove_categories(removals) def test_remove_unused_categories(self): c = Categorical(["a", "b", "c", "d", "a"], categories=["a", "b", "c", "d", "e"])