Skip to content

Commit 2ccf1cb

Browse files
committed
ENH: NDFrame.mask supports same kwds as where
1 parent 0222024 commit 2ccf1cb

File tree

4 files changed

+101
-25
lines changed

4 files changed

+101
-25
lines changed

doc/source/whatsnew/v0.16.1.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Enhancements
1919

2020
- Added ``StringMethods.capitalize()`` and ``swapcase`` which behave as the same as standard ``str`` (:issue:`9766`)
2121

22-
22+
- ``DataFrame.mask()`` and ``Series.mask()`` now support same keywords as ``where`` (:issue:`8801`)
2323

2424

2525

pandas/core/generic.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3250,16 +3250,14 @@ def _align_series(self, other, join='outer', axis=None, level=None,
32503250
return (left_result.__finalize__(self),
32513251
right_result.__finalize__(other))
32523252

3253-
def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
3254-
try_cast=False, raise_on_error=True):
3255-
"""
3253+
_shared_docs['where'] = ("""
32563254
Return an object of same shape as self and whose corresponding
3257-
entries are from self where cond is True and otherwise are from other.
3255+
entries are from self where cond is %(cond)s and otherwise are from other.
32583256
32593257
Parameters
32603258
----------
3261-
cond : boolean NDFrame or array
3262-
other : scalar or NDFrame
3259+
cond : boolean %(klass)s or array
3260+
other : scalar or %(klass)s
32633261
inplace : boolean, default False
32643262
Whether to perform the operation in place on the data
32653263
axis : alignment axis if needed, default None
@@ -3273,7 +3271,11 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
32733271
Returns
32743272
-------
32753273
wh : same type as caller
3276-
"""
3274+
""")
3275+
@Appender(_shared_docs['where'] % dict(_shared_doc_kwargs, cond="True"))
3276+
def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
3277+
try_cast=False, raise_on_error=True):
3278+
32773279
if isinstance(cond, NDFrame):
32783280
cond = cond.reindex(**self._construct_axes_dict())
32793281
else:
@@ -3400,20 +3402,11 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
34003402

34013403
return self._constructor(new_data).__finalize__(self)
34023404

3403-
def mask(self, cond):
3404-
"""
3405-
Returns copy whose values are replaced with nan if the
3406-
inverted condition is True
3407-
3408-
Parameters
3409-
----------
3410-
cond : boolean NDFrame or array
3411-
3412-
Returns
3413-
-------
3414-
wh: same as input
3415-
"""
3416-
return self.where(~cond, np.nan)
3405+
@Appender(_shared_docs['where'] % dict(_shared_doc_kwargs, cond="False"))
3406+
def mask(self, cond, other=np.nan, inplace=False, axis=None, level=None,
3407+
try_cast=False, raise_on_error=True):
3408+
return self.where(~cond, other=other, inplace=inplace, axis=axis,
3409+
level=level, try_cast=try_cast, raise_on_error=raise_on_error)
34173410

34183411
def shift(self, periods=1, freq=None, axis=0, **kwargs):
34193412
"""

pandas/tests/test_frame.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9775,6 +9775,27 @@ def test_mask(self):
97759775
assert_frame_equal(rs, df.mask(df <= 0))
97769776
assert_frame_equal(rs, df.mask(~cond))
97779777

9778+
other = DataFrame(np.random.randn(5, 3))
9779+
rs = df.where(cond, other)
9780+
assert_frame_equal(rs, df.mask(df <= 0, other))
9781+
assert_frame_equal(rs, df.mask(~cond, other))
9782+
9783+
def test_mask_inplace(self):
9784+
# GH8801
9785+
df = DataFrame(np.random.randn(5, 3))
9786+
cond = df > 0
9787+
9788+
rdf = df.copy()
9789+
9790+
rdf.where(cond, inplace=True)
9791+
assert_frame_equal(rdf, df.where(cond))
9792+
assert_frame_equal(rdf, df.mask(~cond))
9793+
9794+
rdf = df.copy()
9795+
rdf.where(cond, -df, inplace=True)
9796+
assert_frame_equal(rdf, df.where(cond, -df))
9797+
assert_frame_equal(rdf, df.mask(~cond, -df))
9798+
97789799
def test_mask_edge_case_1xN_frame(self):
97799800
# GH4071
97809801
df = DataFrame([[1, 2]])

pandas/tests/test_series.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1821,6 +1821,10 @@ def test_where_broadcast(self):
18211821
for i, use_item in enumerate(selection)])
18221822
assert_series_equal(s, expected)
18231823

1824+
s = Series(data)
1825+
result = s.where(~selection, arr)
1826+
assert_series_equal(result, expected)
1827+
18241828
def test_where_inplace(self):
18251829
s = Series(np.random.randn(5))
18261830
cond = s > 0
@@ -1856,11 +1860,69 @@ def test_where_dups(self):
18561860
assert_series_equal(comb, expected)
18571861

18581862
def test_mask(self):
1863+
# compare with tested results in test_where
1864+
s = Series(np.random.randn(5))
1865+
cond = s > 0
1866+
1867+
rs = s.where(~cond, np.nan)
1868+
assert_series_equal(rs, s.mask(cond))
1869+
1870+
rs = s.where(~cond)
1871+
rs2 = s.mask(cond)
1872+
assert_series_equal(rs, rs2)
1873+
1874+
rs = s.where(~cond, -s)
1875+
rs2 = s.mask(cond, -s)
1876+
assert_series_equal(rs, rs2)
1877+
1878+
cond = Series([True, False, False, True, False], index=s.index)
1879+
s2 = -(s.abs())
1880+
rs = s2.where(~cond[:3])
1881+
rs2 = s2.mask(cond[:3])
1882+
assert_series_equal(rs, rs2)
1883+
1884+
rs = s2.where(~cond[:3], -s2)
1885+
rs2 = s2.mask(cond[:3], -s2)
1886+
assert_series_equal(rs, rs2)
1887+
1888+
self.assertRaises(ValueError, s.mask, 1)
1889+
self.assertRaises(ValueError, s.mask, cond[:3].values, -s)
1890+
1891+
# dtype changes
1892+
s = Series([1,2,3,4])
1893+
result = s.mask(s>2, np.nan)
1894+
expected = Series([1, 2, np.nan, np.nan])
1895+
assert_series_equal(result, expected)
1896+
1897+
def test_mask_broadcast(self):
1898+
# GH 8801
1899+
# copied from test_where_broadcast
1900+
for size in range(2, 6):
1901+
for selection in [np.resize([True, False, False, False, False], size), # First element should be set
1902+
# Set alternating elements]
1903+
np.resize([True, False], size),
1904+
np.resize([False], size)]: # No element should be set
1905+
for item in [2.0, np.nan, np.finfo(np.float).max, np.finfo(np.float).min]:
1906+
for arr in [np.array([item]), [item], (item,)]:
1907+
data = np.arange(size, dtype=float)
1908+
s = Series(data)
1909+
result = s.mask(selection, arr)
1910+
expected = Series([item if use_item else data[i]
1911+
for i, use_item in enumerate(selection)])
1912+
assert_series_equal(result, expected)
1913+
1914+
def test_mask_inplace(self):
18591915
s = Series(np.random.randn(5))
18601916
cond = s > 0
18611917

1862-
rs = s.where(cond, np.nan)
1863-
assert_series_equal(rs, s.mask(~cond))
1918+
rs = s.copy()
1919+
rs.mask(cond, inplace=True)
1920+
assert_series_equal(rs.dropna(), s[~cond])
1921+
assert_series_equal(rs, s.mask(cond))
1922+
1923+
rs = s.copy()
1924+
rs.mask(cond, -s, inplace=True)
1925+
assert_series_equal(rs, s.mask(cond, -s))
18641926

18651927
def test_drop(self):
18661928

@@ -6845,7 +6907,7 @@ def test_repeat(self):
68456907
def test_unique_data_ownership(self):
68466908
# it works! #1807
68476909
Series(Series(["a", "c", "b"]).unique()).sort()
6848-
6910+
68496911
def test_datetime_timedelta_quantiles(self):
68506912
# covers #9694
68516913
self.assertTrue(pd.isnull(Series([],dtype='M8[ns]').quantile(.5)))

0 commit comments

Comments
 (0)