Skip to content

Commit 3ccc917

Browse files
committed
copy codes from pandas-dev#38422
1 parent 1449d3c commit 3ccc917

File tree

4 files changed

+31
-11
lines changed

4 files changed

+31
-11
lines changed

doc/source/reference/extensions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ objects.
4848
api.extensions.ExtensionArray.equals
4949
api.extensions.ExtensionArray.factorize
5050
api.extensions.ExtensionArray.fillna
51+
api.extensions.ExtensionArray.isin
5152
api.extensions.ExtensionArray.isna
5253
api.extensions.ExtensionArray.ravel
5354
api.extensions.ExtensionArray.repeat

pandas/core/algorithms.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -449,10 +449,8 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray:
449449

450450
comps = _ensure_arraylike(comps)
451451
comps = extract_array(comps, extract_numpy=True)
452-
if is_categorical_dtype(comps.dtype):
453-
# TODO(extension)
454-
# handle categoricals
455-
return cast("Categorical", comps).isin(values)
452+
if is_extension_array_dtype(comps.dtype):
453+
return comps.isin(values)
456454

457455
if needs_i8_conversion(comps.dtype):
458456
# Dispatch to DatetimeLikeArrayMixin.isin
@@ -464,11 +462,7 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray:
464462
elif needs_i8_conversion(values.dtype):
465463
return isin(comps, values.astype(object))
466464

467-
elif is_extension_array_dtype(comps.dtype) or is_extension_array_dtype(
468-
values.dtype
469-
):
470-
if type(comps).__name__ == "IntegerArray":
471-
comps = comps._data # type: ignore[attr-defined, assignment]
465+
elif is_extension_array_dtype(values.dtype):
472466
return isin(np.asarray(comps), np.asarray(values))
473467

474468
# GH16012

pandas/core/arrays/base.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from pandas.core.dtypes.missing import isna
4646

4747
from pandas.core import ops
48-
from pandas.core.algorithms import factorize_array, unique
48+
from pandas.core.algorithms import factorize_array, isin, unique
4949
from pandas.core.missing import get_fill_func
5050
from pandas.core.sorting import nargminmax, nargsort
5151

@@ -78,6 +78,7 @@ class ExtensionArray:
7878
factorize
7979
fillna
8080
equals
81+
isin
8182
isna
8283
ravel
8384
repeat
@@ -833,6 +834,22 @@ def equals(self, other: object) -> bool:
833834
equal_na = self.isna() & other.isna()
834835
return bool((equal_values | equal_na).all())
835836

837+
def isin(self, values) -> np.ndarray:
838+
"""
839+
Pointwise comparison for set containment in the given values.
840+
841+
Roughly equivalent to `np.array([x in values for x in self])`
842+
843+
Parameters
844+
----------
845+
values : Sequence
846+
847+
Returns
848+
-------
849+
np.ndarray[bool]
850+
"""
851+
return isin(np.asarray(self), values)
852+
836853
def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
837854
"""
838855
Return an array and missing value suitable for factorization.

pandas/core/arrays/masked.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
from pandas.core.dtypes.missing import isna, notna
2020

2121
from pandas.core import nanops
22-
from pandas.core.algorithms import factorize_array, take
22+
from pandas.core.algorithms import factorize_array, isin, take
2323
from pandas.core.array_algos import masked_reductions
2424
from pandas.core.arraylike import OpsMixin
2525
from pandas.core.arrays import ExtensionArray
2626
from pandas.core.indexers import check_array_indexer
2727

2828
if TYPE_CHECKING:
2929
from pandas import Series
30+
from pandas.core.arrays import BooleanArray
3031

3132

3233
BaseMaskedArrayT = TypeVar("BaseMaskedArrayT", bound="BaseMaskedArray")
@@ -299,6 +300,13 @@ def take(
299300

300301
return type(self)(result, mask, copy=False)
301302

303+
def isin(self, values) -> "BooleanArray":
304+
305+
from pandas.core.arrays import BooleanArray
306+
307+
result = isin(self._data, values)
308+
return BooleanArray(result, self._mask.copy(), copy=False)
309+
302310
def copy(self: BaseMaskedArrayT) -> BaseMaskedArrayT:
303311
data, mask = self._data, self._mask
304312
data = data.copy()

0 commit comments

Comments
 (0)