
Support ExtensionArray types in where #24077

Closed
@TomAugspurger

This is blocking DatetimeArray. It's also a slight regression in 0.24, since things like .where on a DataFrame with period objects used to work (via object dtype).

I think the easiest approach is to define ExtensionBlock.where and restrict it to cases where the dtypes of self and other match (so that the result has the same dtype).
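A minimal sketch of that dtype restriction, written against the ExtensionArray interface rather than a Block. The helper name where_same_dtype is hypothetical (not a pandas API), and it assumes the array implements copy() and boolean-mask __setitem__, which pandas' own arrays do but a third-party EA might not.

```python
import numpy as np
import pandas as pd


def where_same_dtype(values, cond, other):
    """Hypothetical helper (not a pandas API): an ExtensionArray-level where.

    Keeps values where cond is True and takes other elsewhere. Only an NA
    scalar or an array with exactly the same dtype is accepted, so the
    result always has values.dtype.
    """
    cond = np.asarray(cond, dtype=bool)
    if cond.shape != (len(values),):
        raise ValueError("cond must be a 1-D boolean mask of the same length")

    if pd.api.types.is_scalar(other) and pd.isna(other):
        fill = other  # NA scalar; the array's __setitem__ maps it to its own NA
    else:
        other_dtype = getattr(other, "dtype", None)
        # Restrict to matching dtypes so the result dtype cannot change.
        if other_dtype is None or not (values.dtype == other_dtype):
            raise TypeError(f"other must be NA or have dtype {values.dtype}")
        if len(other) != len(values):
            raise ValueError("other must have the same length as values")
        fill = other[~cond]

    # Only copy() and boolean-mask __setitem__ are used; no private storage.
    result = values.copy()
    result[~cond] = fill
    return result
```

With this restriction, a datetime64 array passed as other for a period array raises TypeError instead of silently producing object dtype.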

We can do this fairly easily for our own EAs by performing the .where on _ndarray_values. But _ndarray_values isn't part of the EA interface yet, and I'm not sure we'll have time to properly design and implement a generic .where for any ExtensionArray, since there are a couple of subtleties.
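A minimal sketch of the storage-level approach for periods, using public pandas pieces as stand-ins for the private ones: PeriodArray.asi8 plays the role of _ndarray_values, the PeriodArray constructor plays the role of a hypothetical _from_ndarray_values, and period_where is a hypothetical helper name. It also handles a single-column 2-D mask and short-circuits when nothing is masked out.

```python
import numpy as np
import pandas as pd

# Integer sentinel pandas uses for NaT in int64 storage (iNaT).
iNaT = np.iinfo(np.int64).min


def period_where(arr, cond):
    """Hypothetical sketch: where() on a PeriodArray via its int64 storage.

    Masked-out positions become NaT. asi8 stands in for _ndarray_values,
    and the constructor stands in for a _from_ndarray_values-style builder.
    """
    cond = np.asarray(cond, dtype=bool)
    if cond.ndim == 2:
        # A single-column 2-D mask (e.g. from a one-column frame) is flattened.
        if cond.shape[-1] != 1:
            raise ValueError("expected a single-column mask")
        cond = cond.ravel()
    if cond.shape != (len(arr),):
        raise ValueError("mask must have the same length as the array")

    # Short-circuit: nothing is masked out, so return an unchanged copy.
    if cond.all():
        return arr.copy()

    ordinals = arr.asi8  # int64 period ordinals; NaT is stored as iNaT
    # Fill with the *storage* NA value so the result stays int64.
    result = np.where(cond, ordinals, iNaT)
    return pd.arrays.PeriodArray(result, dtype=arr.dtype)
```

Filling the int64 ordinals with pd.NaT instead of iNaT would likely make np.where return an object array, which is why the storage NA value matters here.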

Here's a start:

diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py
index 1b67c2053..ce5c01359 100644
--- a/pandas/core/internals/blocks.py
+++ b/pandas/core/internals/blocks.py
@@ -1955,6 +1955,37 @@ class ExtensionBlock(NonConsolidatableMixIn, Block):
                                            placement=self.mgr_locs,
                                            ndim=self.ndim)]
 
+    def where(self, other, cond, align=True, errors='raise',
+              try_cast=False, axis=0, transpose=False):
+        import pandas.core.computation.expressions as expressions
+
+        values = self.values._ndarray_values
+
+        if cond.ndim == 2:
+            assert cond.shape[-1] == 1
+            cond = cond._data.blocks[0].values.ravel()
+
+        if hasattr(other, 'ndim') and other.ndim == 2:
+            # TODO: this hasn't been normalized
+            assert other.shape[-1] == 1
+            other = other._data.blocks[0].values
+
+        elif (lib.is_scalar(other) and isna(other)) or other is None:
+            # TODO: we need the storage NA value (e.g. iNaT)
+            other = self.values.dtype.na_value
+            # other = tslibs.iNaT
+
+        # TODO: cond.ravel().all() short-circuit
+
+        if cond.ndim > 1:
+            cond = cond.ravel()
+
+        result = expressions.where(cond, values, other)
+        if not isinstance(result, self._holder):
+            # Need a kind of _from_ndarray_values() (different from
+            # _from_sequence); for now, assume the constructor accepts storage values
+            result = self._holder(result, dtype=self.dtype)
+        return self.make_block_same_class(result)
+
     @property
     def _ftype(self):
         return getattr(self.values, '_pandas_ftype', Block._ftype)

There are a couple of TODOs there; it still needs tests, and I'm sure there are plenty of edge cases.

In [7]: df = pd.DataFrame({"A": pd.period_range("2000", periods=12)})

In [8]: df.where(df.A.dt.day == 2)
Out[8]:
             A
0          NaT
1   2000-01-02
2          NaT
3          NaT
4          NaT
5          NaT
6          NaT
7          NaT
8          NaT
9          NaT
10         NaT
11         NaT
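The TODO note above mentions tests. A pytest-style sketch of one, covering only the period example in this issue, might look like this; the test name is hypothetical, and it uses only public pandas API (pandas.testing.assert_frame_equal).

```python
import pandas as pd


def test_where_keeps_period_dtype():
    # Mirrors the example above: keep only the rows whose day is 2.
    df = pd.DataFrame({"A": pd.period_range("2000", periods=12, freq="D")})
    result = df.where(df.A.dt.day == 2)

    values = [pd.NaT] * 12
    values[1] = pd.Period("2000-01-02", freq="D")
    expected = pd.DataFrame({"A": pd.array(values, dtype="period[D]")})

    # assert_frame_equal also checks that the column is still period[D],
    # i.e. that .where did not fall back to object dtype.
    pd.testing.assert_frame_equal(result, expected)
```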

Labels: ExtensionArray (Extending pandas with custom dtypes or arrays)
