-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
BUG/Perf: Support ExtensionArrays in where #24114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c4604df
56470c3
6f79282
a69dbb3
911a2da
badb5be
edff47e
4715ef6
d90f384
5e14414
e9665b8
033ac9c
1271d3d
9e0d87d
e05a597
796332c
cad0c4c
6edd286
30775f0
4de8bb5
ce04a75
f98a82c
bcfb8f8
8d9b20b
c0351fd
539d3cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -501,10 +501,13 @@ def _can_reindex(self, indexer): | |
|
||
@Appender(_index_shared_docs['where']) | ||
def where(self, cond, other=None): | ||
# TODO: Investigate an alternative implementation with | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
# 1. copy the underyling Categorical | ||
# 2. setitem with `cond` and `other` | ||
# 3. Rebuild CategoricalIndex. | ||
if other is None: | ||
other = self._na_value | ||
values = np.where(cond, self.values, other) | ||
|
||
cat = Categorical(values, dtype=self.dtype) | ||
return self._shallow_copy(cat, **self._get_attributes_dict()) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,8 @@ | |
from pandas.core.dtypes.dtypes import ( | ||
CategoricalDtype, DatetimeTZDtype, ExtensionDtype, PandasExtensionDtype) | ||
from pandas.core.dtypes.generic import ( | ||
ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass, ABCSeries) | ||
ABCDataFrame, ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass, | ||
ABCSeries) | ||
from pandas.core.dtypes.missing import ( | ||
_isna_compat, array_equivalent, is_null_datelike_scalar, isna, notna) | ||
|
||
|
@@ -1886,7 +1887,6 @@ def take_nd(self, indexer, axis=0, new_mgr_locs=None, fill_tuple=None): | |
new_values = self.values.take(indexer, fill_value=fill_value, | ||
allow_fill=True) | ||
|
||
# if we are a 1-dim object, then always place at 0 | ||
if self.ndim == 1 and new_mgr_locs is None: | ||
new_mgr_locs = [0] | ||
else: | ||
|
@@ -1967,6 +1967,57 @@ def shift(self, periods, axis=0): | |
placement=self.mgr_locs, | ||
ndim=self.ndim)] | ||
|
||
def where(self, other, cond, align=True, errors='raise', | ||
try_cast=False, axis=0, transpose=False): | ||
# Extract the underlying arrays. | ||
if isinstance(other, (ABCIndexClass, ABCSeries)): | ||
other = other.array | ||
|
||
elif isinstance(other, ABCDataFrame): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add some comments here |
||
# ExtensionArrays are 1-D, so if we get here then | ||
# `other` should be a DataFrame with a single column. | ||
assert other.shape[1] == 1 | ||
other = other.iloc[:, 0].array | ||
|
||
if isinstance(cond, ABCDataFrame): | ||
assert cond.shape[1] == 1 | ||
cond = cond.iloc[:, 0].array | ||
|
||
elif isinstance(cond, (ABCIndexClass, ABCSeries)): | ||
cond = cond.array | ||
|
||
if lib.is_scalar(other) and isna(other): | ||
# The default `other` for Series / Frame is np.nan | ||
# we want to replace that with the correct NA value | ||
# for the type | ||
other = self.dtype.na_value | ||
|
||
if is_sparse(self.values): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without this, we fail in the result = self._holder._from_sequence(
np.where(cond, self.values, other),
dtype=dtype, since the Implementing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be an overriding method in Sparse then, not here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't have a SparseBlock anymore. I can add one back if you want, but I figured it'd be easier not to since implementing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is pretty hacky. This was why we had originally a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, we don't need this. I think we can just re-infer the dtype from the output of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so is this changing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changing from master? Yes, in the sense that it'll return a SparseArray. But it still densifies when If you mean "is this changing in the future", yes it'll be removed when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh ok, can you add a TODO comment |
||
# TODO(SparseArray.__setitem__): remove this if condition | ||
# We need to re-infer the type of the data after doing the | ||
# where, for cases where the subtypes don't match | ||
dtype = None | ||
else: | ||
dtype = self.dtype | ||
|
||
try: | ||
result = self.values.copy() | ||
icond = ~cond | ||
if lib.is_scalar(other): | ||
result[icond] = other | ||
else: | ||
result[icond] = other[icond] | ||
except (NotImplementedError, TypeError): | ||
# NotImplementedError for class not implementing `__setitem__` | ||
# TypeError for SparseArray, which implements just to raise | ||
# a TypeError | ||
result = self._holder._from_sequence( | ||
np.where(cond, self.values, other), | ||
dtype=dtype, | ||
) | ||
|
||
return self.make_block_same_class(result, placement=self.mgr_locs) | ||
|
||
@property | ||
def _ftype(self): | ||
return getattr(self.values, '_pandas_ftype', Block._ftype) | ||
|
@@ -2658,6 +2709,33 @@ def concat_same_type(self, to_concat, placement=None): | |
values, placement=placement or slice(0, len(values), 1), | ||
ndim=self.ndim) | ||
|
||
def where(self, other, cond, align=True, errors='raise', | ||
try_cast=False, axis=0, transpose=False): | ||
# TODO(CategoricalBlock.where): | ||
# This can all be deleted in favor of ExtensionBlock.where once | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add TODO(EA) or someting here so we know to remove this |
||
# we enforce the deprecation. | ||
object_msg = ( | ||
"Implicitly converting categorical to object-dtype ndarray. " | ||
"One or more of the values in 'other' are not present in this " | ||
"categorical's categories. A future version of pandas will raise " | ||
"a ValueError when 'other' contains different categories.\n\n" | ||
"To preserve the current behavior, add the new categories to " | ||
"the categorical before calling 'where', or convert the " | ||
"categorical to a different dtype." | ||
) | ||
try: | ||
# Attempt to do preserve categorical dtype. | ||
result = super(CategoricalBlock, self).where( | ||
other, cond, align, errors, try_cast, axis, transpose | ||
) | ||
except (TypeError, ValueError): | ||
warnings.warn(object_msg, FutureWarning, stacklevel=6) | ||
result = self.astype(object).where(other, cond, align=align, | ||
errors=errors, | ||
try_cast=try_cast, | ||
axis=axis, transpose=transpose) | ||
return result | ||
|
||
|
||
class DatetimeBlock(DatetimeLikeBlockMixin, Block): | ||
__slots__ = () | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
import numpy as np | ||
import pytest | ||
|
||
import pandas as pd | ||
from pandas import Categorical, CategoricalIndex, Index, PeriodIndex, Series | ||
import pandas.core.common as com | ||
from pandas.tests.arrays.categorical.common import TestCategorical | ||
|
@@ -43,6 +44,45 @@ def test_setitem(self): | |
|
||
tm.assert_categorical_equal(c, expected) | ||
|
||
@pytest.mark.parametrize('other', [ | ||
pd.Categorical(['b', 'a']), | ||
pd.Categorical(['b', 'a'], categories=['b', 'a']), | ||
]) | ||
def test_setitem_same_but_unordered(self, other): | ||
# GH-24142 | ||
target = pd.Categorical(['a', 'b'], categories=['a', 'b']) | ||
mask = np.array([True, False]) | ||
target[mask] = other[mask] | ||
expected = pd.Categorical(['b', 'b'], categories=['a', 'b']) | ||
tm.assert_categorical_equal(target, expected) | ||
|
||
@pytest.mark.parametrize('other', [ | ||
pd.Categorical(['b', 'a'], categories=['b', 'a', 'c']), | ||
pd.Categorical(['b', 'a'], categories=['a', 'b', 'c']), | ||
pd.Categorical(['a', 'a'], categories=['a']), | ||
pd.Categorical(['b', 'b'], categories=['b']), | ||
]) | ||
def test_setitem_different_unordered_raises(self, other): | ||
# GH-24142 | ||
target = pd.Categorical(['a', 'b'], categories=['a', 'b']) | ||
mask = np.array([True, False]) | ||
with pytest.raises(ValueError): | ||
target[mask] = other[mask] | ||
|
||
@pytest.mark.parametrize('other', [ | ||
pd.Categorical(['b', 'a']), | ||
pd.Categorical(['b', 'a'], categories=['b', 'a'], ordered=True), | ||
pd.Categorical(['b', 'a'], categories=['a', 'b', 'c'], ordered=True), | ||
]) | ||
def test_setitem_same_ordered_rasies(self, other): | ||
# Gh-24142 | ||
target = pd.Categorical(['a', 'b'], categories=['a', 'b'], | ||
ordered=True) | ||
mask = np.array([True, False]) | ||
|
||
with pytest.raises(ValueError): | ||
target[mask] = other[mask] | ||
|
||
|
||
class TestCategoricalIndexing(object): | ||
|
||
|
@@ -122,6 +162,60 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class): | |
tm.assert_numpy_array_equal(expected, result) | ||
tm.assert_numpy_array_equal(exp_miss, res_miss) | ||
|
||
def test_where_unobserved_nan(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is where all of the where tests are? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There weren't any previously (we used to fall back to object). |
||
ser = pd.Series(pd.Categorical(['a', 'b'])) | ||
result = ser.where([True, False]) | ||
expected = pd.Series(pd.Categorical(['a', None], | ||
categories=['a', 'b'])) | ||
tm.assert_series_equal(result, expected) | ||
|
||
# all NA | ||
ser = pd.Series(pd.Categorical(['a', 'b'])) | ||
result = ser.where([False, False]) | ||
expected = pd.Series(pd.Categorical([None, None], | ||
categories=['a', 'b'])) | ||
tm.assert_series_equal(result, expected) | ||
|
||
def test_where_unobserved_categories(self): | ||
ser = pd.Series( | ||
Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a']) | ||
) | ||
result = ser.where([True, True, False], other='b') | ||
expected = pd.Series( | ||
Categorical(['a', 'b', 'b'], categories=ser.cat.categories) | ||
) | ||
tm.assert_series_equal(result, expected) | ||
|
||
def test_where_other_categorical(self): | ||
ser = pd.Series( | ||
Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a']) | ||
) | ||
other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd']) | ||
result = ser.where([True, False, True], other) | ||
expected = pd.Series(Categorical(['a', 'c', 'c'], dtype=ser.dtype)) | ||
tm.assert_series_equal(result, expected) | ||
|
||
def test_where_warns(self): | ||
ser = pd.Series(Categorical(['a', 'b', 'c'])) | ||
with tm.assert_produces_warning(FutureWarning): | ||
result = ser.where([True, False, True], 'd') | ||
|
||
expected = pd.Series(np.array(['a', 'd', 'c'], dtype='object')) | ||
tm.assert_series_equal(result, expected) | ||
|
||
def test_where_ordered_differs_rasies(self): | ||
ser = pd.Series( | ||
Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'], | ||
ordered=True) | ||
) | ||
other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'], | ||
ordered=True) | ||
with tm.assert_produces_warning(FutureWarning): | ||
result = ser.where([True, False, True], other) | ||
|
||
expected = pd.Series(np.array(['a', 'c', 'c'], dtype=object)) | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
@pytest.mark.parametrize("index", [True, False]) | ||
def test_mask_with_boolean(index): | ||
|
Uh oh!
There was an error while loading. Please reload this page.