Commit 13210b7

Merge pull request #4781 from jreback/where
ENH: Add axis and level keywords to where, so that the other argument can now be an alignable pandas object.
2 parents 859aada + 558a594 commit 13210b7

7 files changed: +144 -27 lines changed

doc/source/indexing.rst

Lines changed: 25 additions & 7 deletions
@@ -625,6 +625,18 @@ This can be done intuitively like so:
    df2[df2 < 0] = 0
    df2

+By default, ``where`` returns a modified copy of the data. There is an
+optional parameter ``inplace`` so that the original data can be modified
+without creating a copy:
+
+.. ipython:: python
+
+   df_orig = df.copy()
+   df_orig.where(df > 0, -df, inplace=True);
+   df_orig
+
+**alignment**
+
 Furthermore, ``where`` aligns the input boolean condition (ndarray or DataFrame),
 such that partial selection with setting is possible. This is analagous to
 partial setting via ``.ix`` (but on the contents rather than the axis labels)
@@ -635,24 +647,30 @@ partial setting via ``.ix`` (but on the contents rather than the axis labels)
    df2[ df2[1:4] > 0 ] = 3
    df2

-By default, ``where`` returns a modified copy of the data. There is an
-optional parameter ``inplace`` so that the original data can be modified
-without creating a copy:
+.. versionadded:: 0.13
+
+Where can also accept ``axis`` and ``level`` parameters to align the input when
+performing the ``where``.

 .. ipython:: python

-   df_orig = df.copy()
+   df2 = df.copy()
+   df2.where(df2>0,df2['A'],axis='index')

-   df_orig.where(df > 0, -df, inplace=True);
+This is equivalent (but faster than) the following.

-   df_orig
+.. ipython:: python
+
+   df2 = df.copy()
+   df.apply(lambda x, y: x.where(x>0,y), y=df['A'])
+
+**mask**

 ``mask`` is the inverse boolean operation of ``where``.

 .. ipython:: python

    s.mask(s >= 0)
-
    df.mask(df >= 0)

 Take Methods
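
The documentation change above claims that ``where`` with ``axis='index'`` matches a column-by-column ``apply``, and that ``mask`` is the inverse of ``where``. The following is a minimal sketch that checks both claims on a small random frame. The frame, the seed and the variable names are illustrative assumptions, not part of the commit, and the code targets the current public pandas API rather than the 0.13 internals.

```python
import numpy as np
import pandas as pd

# Illustrative data (assumption): a small random frame with named columns.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((6, 3)), columns=["A", "B", "C"])

# Where the condition is False, take the value of column "A" from the same
# row; axis="index" aligns the Series `other` along the row labels.
aligned = df.where(df > 0, df["A"], axis="index")

# Column-by-column equivalent: each column is a Series, so Series.where
# aligns `other` on the index without needing an axis argument.
via_apply = df.apply(lambda col: col.where(col > 0, df["A"]))

# mask replaces values where the condition is True, so inverting the
# condition gives the same result as where.
masked = df.mask(df <= 0, df["A"], axis="index")
```

All three results hold the same values; the ``where`` form avoids the per-column Python call that ``apply`` makes.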

doc/source/missing_data.rst

Lines changed: 27 additions & 0 deletions
@@ -205,6 +205,33 @@ To remind you, these are the available filling methods:
 With time series data, using pad/ffill is extremely common so that the "last
 known value" is available at every time point.

+Filling with a PandasObject
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 0.12
+
+You can also fill using a direct assignment with an alignable object. The
+use case of this is to fill a DataFrame with the mean of that column.
+
+.. ipython:: python
+
+   df = DataFrame(np.random.randn(10,3))
+   df.iloc[3:5,0] = np.nan
+   df.iloc[4:6,1] = np.nan
+   df.iloc[5:8,2] = np.nan
+   df
+
+   df.fillna(df.mean())
+
+.. versionadded:: 0.13
+
+Same result as above, but is aligning the 'fill' value which is
+a Series in this case.
+
+.. ipython:: python
+
+   df.where(pd.notnull(df),df.mean(),axis='columns')
+
 .. _missing_data.dropna:

 Dropping axis labels with missing data: dropna
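
The two documented ways of filling with column means can be compared directly. This sketch rebuilds the frame from the documentation diff (with a seeded generator, which is an assumption made for reproducibility) and computes both results so they can be checked against each other.

```python
import numpy as np
import pandas as pd

# Same shape and NaN layout as the documentation example; the seeded
# generator is an assumption so the sketch is reproducible.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((10, 3)))
df.iloc[3:5, 0] = np.nan
df.iloc[4:6, 1] = np.nan
df.iloc[5:8, 2] = np.nan

# Series of per-column means, indexed by the column labels.
col_means = df.mean()

# fillna aligns the Series on the column labels.
filled = df.fillna(col_means)

# where needs axis="columns" so that the Series is aligned on the column
# labels and then broadcast down the rows.
via_where = df.where(pd.notnull(df), col_means, axis="columns")
```

Here ``axis='columns'`` says which labels the Series index should be matched against; with ``axis='index'`` the same Series would instead be matched against the row labels.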

doc/source/release.rst

Lines changed: 2 additions & 0 deletions
@@ -103,6 +103,8 @@ Improvements to existing features
     tests/test_frame, tests/test_multilevel (:issue:`4732`).
   - Performance improvement of timesesies plotting with PeriodIndex and added
     test to vbench (:issue:`4705` and :issue:`4722`)
+  - Add ``axis`` and ``level`` keywords to ``where``, so that the ``other`` argument
+    can now be an alignable pandas object.

 API Changes
 ~~~~~~~~~~~

pandas/core/generic.py

Lines changed: 13 additions & 6 deletions
@@ -2173,6 +2173,8 @@ def align(self, other, join='outer', axis=None, level=None, copy=True,
         from pandas import DataFrame, Series
         method = com._clean_fill_method(method)

+        if axis is not None:
+            axis = self._get_axis_number(axis)
         if isinstance(other, DataFrame):
             return self._align_frame(other, join=join, axis=axis, level=level,
                                      copy=copy, fill_value=fill_value,
@@ -2262,7 +2264,8 @@ def _align_series(self, other, join='outer', axis=None, level=None,
         else:
             return left_result, right_result

-    def where(self, cond, other=np.nan, inplace=False, try_cast=False, raise_on_error=True):
+    def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
+              try_cast=False, raise_on_error=True):
         """
         Return an object of same shape as self and whose corresponding
         entries are from self where cond is True and otherwise are from other.
@@ -2273,6 +2276,8 @@ def where(self, cond, other=np.nan, inplace=False, try_cast=False, raise_on_erro
         other : scalar or DataFrame
         inplace : boolean, default False
             Whether to perform the operation in place on the data
+        axis : alignment axis if needed, default None
+        level : alignment level if needed, default None
         try_cast : boolean, default False
             try to cast the result back to the input type (if possible),
         raise_on_error : boolean, default True
@@ -2306,15 +2311,17 @@ def where(self, cond, other=np.nan, inplace=False, try_cast=False, raise_on_erro
             # align with me
             if other.ndim <= self.ndim:

-                _, other = self.align(other, join='left', fill_value=np.nan)
+                _, other = self.align(other, join='left',
+                                      axis=axis, level=level,
+                                      fill_value=np.nan)

                 # if we are NOT aligned, raise as we cannot where index
-                if not all([ other._get_axis(i).equals(ax) for i, ax in enumerate(self.axes) ]):
+                if axis is None and not all([ other._get_axis(i).equals(ax) for i, ax in enumerate(self.axes) ]):
                     raise InvalidIndexError

             # slice me out of the other
             else:
-                raise NotImplemented
+                raise NotImplemented("cannot align with a bigger dimensional PandasObject")

         elif is_list_like(other):

@@ -2386,11 +2393,11 @@ def where(self, cond, other=np.nan, inplace=False, try_cast=False, raise_on_erro
         if inplace:
             # we may have different type blocks come out of putmask, so
             # reconstruct the block manager
-            self._data = self._data.putmask(cond, other, inplace=True)
+            self._data = self._data.putmask(cond, other, align=axis is None, inplace=True)

         else:
             new_data = self._data.where(
-                other, cond, raise_on_error=raise_on_error, try_cast=try_cast)
+                other, cond, align=axis is None, raise_on_error=raise_on_error, try_cast=try_cast)

             return self._constructor(new_data)

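
The change to ``where`` in this file forwards ``axis`` and ``level`` to ``align`` with ``join='left'``, so that a lower-dimensional ``other`` is reindexed to the caller's labels before the values are combined. A minimal sketch of that alignment step, using only the public ``DataFrame.align`` and ``DataFrame.where`` calls; the frame, the Series and the labels are illustrative assumptions.

```python
import numpy as np
import pandas as pd

# Illustrative data (assumption): `other` covers only two of the three rows.
df = pd.DataFrame(
    {"A": [1.0, 2.0, 3.0], "B": [4.0, 5.0, 6.0]},
    index=["x", "y", "z"],
)
other = pd.Series([10.0, 30.0], index=["x", "z"])

# A left join along axis 0 reindexes `other` to df's row labels; the
# label missing from `other` ("y") becomes NaN.
_, aligned_other = df.align(other, join="left", axis=0)

# where performs the same alignment when axis is given, then takes the
# aligned value wherever the condition is False.
result = df.where(df > 2.5, other, axis="index")
```

In ``result``, row ``x`` of column ``A`` takes the value 10.0 from ``other``, row ``y`` becomes NaN because ``other`` has no label ``y``, and all of column ``B`` is kept because it satisfies the condition.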

pandas/core/internals.py

Lines changed: 46 additions & 13 deletions
@@ -593,29 +593,52 @@ def setitem(self, indexer, value):

         return [ self ]

-    def putmask(self, mask, new, inplace=False):
+    def putmask(self, mask, new, align=True, inplace=False):
         """ putmask the data to the block; it is possible that we may create a new dtype of block
-            return the resulting block(s) """
+            return the resulting block(s)
+
+        Parameters
+        ----------
+        mask : the condition to respect
+        new : a ndarray/object
+        align : boolean, perform alignment on other/cond, default is True
+        inplace : perform inplace modification, default is False
+
+        Returns
+        -------
+        a new block(s), the result of the putmask
+        """

         new_values = self.values if inplace else self.values.copy()

         # may need to align the new
         if hasattr(new, 'reindex_axis'):
-            axis = getattr(new, '_info_axis_number', 0)
-            new = new.reindex_axis(self.items, axis=axis, copy=False).values.T
+            if align:
+                axis = getattr(new, '_info_axis_number', 0)
+                new = new.reindex_axis(self.items, axis=axis, copy=False).values.T
+            else:
+                new = new.values.T

         # may need to align the mask
         if hasattr(mask, 'reindex_axis'):
-            axis = getattr(mask, '_info_axis_number', 0)
-            mask = mask.reindex_axis(
-                self.items, axis=axis, copy=False).values.T
+            if align:
+                axis = getattr(mask, '_info_axis_number', 0)
+                mask = mask.reindex_axis(
+                    self.items, axis=axis, copy=False).values.T
+            else:
+                mask = mask.values.T

         # if we are passed a scalar None, convert it here
         if not is_list_like(new) and isnull(new):
             new = np.nan

         if self._can_hold_element(new):
             new = self._try_cast(new)
+
+            # pseudo-broadcast
+            if isinstance(new,np.ndarray) and new.ndim == self.ndim-1:
+                new = np.repeat(new,self.shape[-1]).reshape(self.shape)
+
             np.putmask(new_values, mask, new)

         # maybe upcast me
@@ -842,14 +865,15 @@ def handle_error():

         return [make_block(result, self.items, self.ref_items, ndim=self.ndim, fastpath=True)]

-    def where(self, other, cond, raise_on_error=True, try_cast=False):
+    def where(self, other, cond, align=True, raise_on_error=True, try_cast=False):
         """
         evaluate the block; return result block(s) from the result

         Parameters
         ----------
         other : a ndarray/object
         cond : the condition to respect
+        align : boolean, perform alignment on other/cond
         raise_on_error : if True, raise when I can't perform the function, False by default (and just return
             the data that we had coming in)

@@ -862,21 +886,30 @@ def where(self, other, cond, raise_on_error=True, try_cast=False):

         # see if we can align other
         if hasattr(other, 'reindex_axis'):
-            axis = getattr(other, '_info_axis_number', 0)
-            other = other.reindex_axis(self.items, axis=axis, copy=True).values
+            if align:
+                axis = getattr(other, '_info_axis_number', 0)
+                other = other.reindex_axis(self.items, axis=axis, copy=True).values
+            else:
+                other = other.values

         # make sure that we can broadcast
         is_transposed = False
         if hasattr(other, 'ndim') and hasattr(values, 'ndim'):
             if values.ndim != other.ndim or values.shape == other.shape[::-1]:
-                values = values.T
-                is_transposed = True
+
+                # pseodo broadcast (its a 2d vs 1d say and where needs it in a specific direction)
+                if other.ndim >= 1 and values.ndim-1 == other.ndim and values.shape[0] != other.shape[0]:
+                    other = _block_shape(other).T
+                else:
+                    values = values.T
+                    is_transposed = True

         # see if we can align cond
         if not hasattr(cond, 'shape'):
             raise ValueError(
                 "where must have a condition that is ndarray like")
-        if hasattr(cond, 'reindex_axis'):
+
+        if align and hasattr(cond, 'reindex_axis'):
             axis = getattr(cond, '_info_axis_number', 0)
             cond = cond.reindex_axis(self.items, axis=axis, copy=True).values
         else:
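
The "pseudo-broadcast" step added to ``putmask`` in this file expands a 1-d replacement array to the block's 2-d shape before calling ``np.putmask``. The sketch below reproduces that step with plain NumPy, under the assumption that the 1-d array has one entry per row of the 2-d values; the array contents and names are illustrative. The expansion matters because ``np.putmask`` cycles a shorter ``values`` argument by flat position instead of broadcasting it.

```python
import numpy as np

# Illustrative 2-d block (assumption): 3 rows, 4 columns.
values = np.arange(12, dtype=float).reshape(3, 4)

# One replacement value per row of the block.
new = np.array([100.0, 200.0, 300.0])

# Replace the even entries.
mask = values % 2 == 0

# Repeat each entry of `new` across the last dimension, then reshape so
# that row i of `expanded` is filled with new[i].
expanded = np.repeat(new, values.shape[-1]).reshape(values.shape)

# With a same-shaped replacement, putmask picks new[i] for every masked
# position in row i.
result = values.copy()
np.putmask(result, mask, expanded)
```

For this shape the expansion gives the same array as ``np.broadcast_to(new[:, None], values.shape)``.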

pandas/core/series.py

Lines changed: 2 additions & 1 deletion
@@ -2725,7 +2725,7 @@ def apply(self, func, convert_dtype=True, args=(), **kwds):
         else:
             return self._constructor(mapped, index=self.index, name=self.name)

-    def align(self, other, join='outer', level=None, copy=True,
+    def align(self, other, join='outer', axis=None, level=None, copy=True,
               fill_value=None, method=None, limit=None):
         """
         Align two Series object with the specified join method
@@ -2734,6 +2734,7 @@ def align(self, other, join='outer', level=None, copy=True,
         ----------
         other : Series
         join : {'outer', 'inner', 'left', 'right'}, default 'outer'
+        axis : None, alignment axis (is 0 for Series)
         level : int or name
             Broadcast across a level, matching Index values on the
             passed MultiIndex level

pandas/tests/test_frame.py

Lines changed: 29 additions & 0 deletions
@@ -7931,6 +7931,35 @@ def test_where_none(self):
         expected = DataFrame({'series': Series([0,1,2,3,4,5,6,7,np.nan,np.nan]) })
         assert_frame_equal(df, expected)

+    def test_where_align(self):
+
+        def create():
+            df = DataFrame(np.random.randn(10,3))
+            df.iloc[3:5,0] = np.nan
+            df.iloc[4:6,1] = np.nan
+            df.iloc[5:8,2] = np.nan
+            return df
+
+        # series
+        df = create()
+        expected = df.fillna(df.mean())
+        result = df.where(pd.notnull(df),df.mean(),axis='columns')
+        assert_frame_equal(result, expected)
+
+        df.where(pd.notnull(df),df.mean(),inplace=True,axis='columns')
+        assert_frame_equal(df, expected)
+
+        df = create().fillna(0)
+        expected = df.apply(lambda x, y: x.where(x>0,y), y=df[0])
+        result = df.where(df>0,df[0],axis='index')
+        assert_frame_equal(result, expected)
+
+        # frame
+        df = create()
+        expected = df.fillna(1)
+        result = df.where(pd.notnull(df),DataFrame(1,index=df.index,columns=df.columns))
+        assert_frame_equal(result, expected)
+
     def test_mask(self):
         df = DataFrame(np.random.randn(5, 3))
         cond = df > 0
