Skip to content

TYP: EA.isin #56423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ def fillna(

return super().fillna(value=value, method=method, limit=limit, copy=copy)

def isin(self, values) -> npt.NDArray[np.bool_]:
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
# short-circuit to return all False array.
if not len(values):
return np.zeros(len(self), dtype=bool)
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,15 +1355,15 @@ def equals(self, other: object) -> bool:
equal_na = self.isna() & other.isna() # type: ignore[operator]
return bool((equal_values | equal_na).all())

def isin(self, values) -> npt.NDArray[np.bool_]:
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
"""
Pointwise comparison for set containment in the given values.

Roughly equivalent to `np.array([x in values for x in self])`

Parameters
----------
values : Sequence
values : np.ndarray or ExtensionArray

Returns
-------
Expand Down
11 changes: 2 additions & 9 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2570,7 +2570,7 @@ def describe(self) -> DataFrame:

return result

def isin(self, values) -> npt.NDArray[np.bool_]:
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
"""
Check whether `values` are contained in Categorical.

Expand All @@ -2580,7 +2580,7 @@ def isin(self, values) -> npt.NDArray[np.bool_]:

Parameters
----------
values : set or list-like
values : np.ndarray or ExtensionArray
The sequence of values to test. Passing in a single string will
raise a ``TypeError``. Instead, turn a single string into a
list of one element.
Expand Down Expand Up @@ -2611,13 +2611,6 @@ def isin(self, values) -> npt.NDArray[np.bool_]:
>>> s.isin(['lama'])
array([ True, False, True, False, True, False])
"""
if not is_list_like(values):
values_type = type(values).__name__
raise TypeError(
"only list-like objects are allowed to be passed "
f"to isin(), you passed a `{values_type}`"
)
values = sanitize_array(values, None, None)
null_mask = np.asarray(isna(values))
code_values = self.categories.get_indexer_for(values)
code_values = code_values[null_mask | (code_values >= 0)]
Expand Down
20 changes: 12 additions & 8 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,22 +734,19 @@ def map(self, mapper, na_action=None):
else:
return result.array

def isin(self, values) -> npt.NDArray[np.bool_]:
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
"""
Compute boolean array of whether each value is found in the
passed set of values.

Parameters
----------
values : set or sequence of values
values : np.ndarray or ExtensionArray

Returns
-------
ndarray[bool]
"""
if not hasattr(values, "dtype"):
values = np.asarray(values)

if values.dtype.kind in "fiuc":
# TODO: de-duplicate with equals, validate_comparison_value
return np.zeros(self.shape, dtype=bool)
Expand Down Expand Up @@ -781,15 +778,22 @@ def isin(self, values) -> npt.NDArray[np.bool_]:

if self.dtype.kind in "mM":
self = cast("DatetimeArray | TimedeltaArray", self)
values = values.as_unit(self.unit)
# error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]"
# has no attribute "as_unit"
values = values.as_unit(self.unit) # type: ignore[union-attr]

try:
self._check_compatible_with(values)
# error: Argument 1 to "_check_compatible_with" of "DatetimeLikeArrayMixin"
# has incompatible type "ExtensionArray | ndarray[Any, Any]"; expected
# "Period | Timestamp | Timedelta | NaTType"
self._check_compatible_with(values) # type: ignore[arg-type]
except (TypeError, ValueError):
# Includes tzawareness mismatch and IncompatibleFrequencyError
return np.zeros(self.shape, dtype=bool)

return isin(self.asi8, values.asi8)
# error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]"
# has no attribute "asi8"
return isin(self.asi8, values.asi8) # type: ignore[union-attr]

# ------------------------------------------------------------------
# Null Handling
Expand Down
8 changes: 2 additions & 6 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,12 +1789,8 @@ def contains(self, other):
other < self._right if self.open_right else other <= self._right
)

def isin(self, values) -> npt.NDArray[np.bool_]:
if not hasattr(values, "dtype"):
values = np.array(values)
values = extract_array(values, extract_numpy=True)

if isinstance(values.dtype, IntervalDtype):
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
if isinstance(values, IntervalArray):
if self.closed != values.closed:
# not comparable -> no overlap
return np.zeros(self.shape, dtype=bool)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def take(

# error: Return type "BooleanArray" of "isin" incompatible with return type
# "ndarray" in supertype "ExtensionArray"
def isin(self, values) -> BooleanArray: # type: ignore[override]
def isin(self, values: ArrayLike) -> BooleanArray: # type: ignore[override]
from pandas.core.arrays import BooleanArray

# algorithms.isin will eventually convert values to an ndarray, so no extra
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from collections.abc import Sequence

from pandas._typing import (
ArrayLike,
AxisInt,
Dtype,
Scalar,
Expand Down Expand Up @@ -212,7 +213,7 @@ def _maybe_convert_setitem_value(self, value):
raise TypeError("Scalar must be NA or str")
return super()._maybe_convert_setitem_value(value)

def isin(self, values) -> npt.NDArray[np.bool_]:
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
value_set = [
pa_scalar.as_py()
for pa_scalar in [pa.scalar(value, from_pandas=True) for value in values]
Expand Down