From 5730e3bba21331c0f516500ee116fc91289f4848 Mon Sep 17 00:00:00 2001 From: Dries Schaumont Date: Wed, 14 Apr 2021 16:02:06 +0200 Subject: [PATCH 1/4] Add test and attempted fix. --- pandas/core/generic.py | 2 +- pandas/tests/frame/indexing/test_mask.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index bb8de35d22462..ae7570754f084 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -8983,7 +8983,7 @@ def _where( join="left", axis=axis, level=level, - fill_value=np.nan, + fill_value=None, copy=False, ) diff --git a/pandas/tests/frame/indexing/test_mask.py b/pandas/tests/frame/indexing/test_mask.py index afa8c757c23e4..a7d2236ac9262 100644 --- a/pandas/tests/frame/indexing/test_mask.py +++ b/pandas/tests/frame/indexing/test_mask.py @@ -5,7 +5,10 @@ import numpy as np from pandas import ( + NA, DataFrame, + Series, + StringDtype, isna, ) import pandas._testing as tm @@ -99,3 +102,16 @@ def test_mask_try_cast_deprecated(frame_or_series): with tm.assert_produces_warning(FutureWarning): # try_cast keyword deprecated obj.mask(mask, -1, try_cast=True) + + +def test_mask_stringdtype(): + df = DataFrame( + {"A": ["foo", "bar", "baz", NA]}, + index=["id1", "id2", "id3", "id4"], + dtype=StringDtype(), + ) + filtered_df = DataFrame( + {"A": ["this", "that"]}, index=["id2", "id3"], dtype=StringDtype() + ) + filter = Series([False, True, True, False]) + df.mask(filter, filtered_df) From e7e0b6510ef1e9825dbf824bf3336ab14f94ad07 Mon Sep 17 00:00:00 2001 From: Dries Schaumont Date: Thu, 15 Apr 2021 11:13:22 +0200 Subject: [PATCH 2/4] Add whatsnew and issue number. 
--- doc/source/whatsnew/v1.3.0.rst | 2 +- pandas/tests/frame/indexing/test_mask.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 1c7942dfedafa..0ef4cb00d5ae5 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -827,7 +827,7 @@ ExtensionArray - Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with :class:`ExtensionArray` dtype (:issue:`38729`) - Fixed bug where :meth:`Series.idxmax`, :meth:`Series.idxmin` and ``argmax/min`` fail when the underlying data is :class:`ExtensionArray` (:issue:`32749`, :issue:`33719`, :issue:`36566`) - Fixed a bug where some properties of subclasses of :class:`PandasExtensionDtype` where improperly cached (:issue:`40329`) -- +- Bug in :meth:`DataFrame.mask` where masking a :class:`DataFrame` with an :class:`ExtensionArray` dtype raises ``ValueError`` (:issue:`40941`) Styler ^^^^^^ diff --git a/pandas/tests/frame/indexing/test_mask.py b/pandas/tests/frame/indexing/test_mask.py index a7d2236ac9262..af8e21f1fc360 100644 --- a/pandas/tests/frame/indexing/test_mask.py +++ b/pandas/tests/frame/indexing/test_mask.py @@ -105,6 +105,7 @@ def test_mask_try_cast_deprecated(frame_or_series): def test_mask_stringdtype(): + # GH 40824 df = DataFrame( {"A": ["foo", "bar", "baz", NA]}, index=["id1", "id2", "id3", "id4"], @@ -113,5 +114,12 @@ def test_mask_stringdtype(): filtered_df = DataFrame( {"A": ["this", "that"]}, index=["id2", "id3"], dtype=StringDtype() ) - filter = Series([False, True, True, False]) - df.mask(filter, filtered_df) + filter_ser = Series([False, True, True, False]) + result = df.mask(filter_ser, filtered_df) + + expected = DataFrame( + {"A": [NA, "this", "that", NA]}, + index=["id1", "id2", "id3", "id4"], + dtype=StringDtype(), + ) + tm.assert_equal(result, expected) From 9601f25623b97af74bb2449fd5cd9eb60a9bddd0 Mon Sep 17 00:00:00 2001 From: Dries Schaumont Date: Thu, 15 Apr 2021 
22:14:02 +0200 Subject: [PATCH 3/4] Add where test. --- pandas/tests/frame/indexing/test_where.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index 574fa46d10f67..7ffe2fb9ab1ff 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -10,6 +10,7 @@ DataFrame, DatetimeIndex, Series, + StringDtype, Timestamp, date_range, isna, @@ -709,3 +710,22 @@ def test_where_copies_with_noop(frame_or_series): where_res *= 2 tm.assert_equal(result, expected) + + +def test_where_string_dtype(frame_or_series): + # GH40824 + obj = frame_or_series( + ["a", "b", "c", "d"], index=["id1", "id2", "id3", "id4"], dtype=StringDtype() + ) + filtered_obj = frame_or_series( + ["b", "c"], index=["id2", "id3"], dtype=StringDtype() + ) + filter_ser = Series([False, True, True, False]) + + result = obj.where(filter_ser, filtered_obj) + expected = frame_or_series( + [pd.NA, "b", "c", pd.NA], + index=["id1", "id2", "id3", "id4"], + dtype=StringDtype(), + ) + tm.assert_equal(result, expected) From bd7282c8f235a9ddf461814d565d5b4841915102 Mon Sep 17 00:00:00 2001 From: Dries Schaumont Date: Fri, 16 Apr 2021 06:05:36 +0200 Subject: [PATCH 4/4] Add series test, change assert statement. 
--- pandas/tests/frame/indexing/test_mask.py | 2 +- pandas/tests/series/indexing/test_mask.py | 25 ++++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pandas/tests/frame/indexing/test_mask.py b/pandas/tests/frame/indexing/test_mask.py index af8e21f1fc360..364475428e529 100644 --- a/pandas/tests/frame/indexing/test_mask.py +++ b/pandas/tests/frame/indexing/test_mask.py @@ -122,4 +122,4 @@ def test_mask_stringdtype(): index=["id1", "id2", "id3", "id4"], dtype=StringDtype(), ) - tm.assert_equal(result, expected) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/series/indexing/test_mask.py b/pandas/tests/series/indexing/test_mask.py index dc4fb530dbb52..a4dda3a5c0c5b 100644 --- a/pandas/tests/series/indexing/test_mask.py +++ b/pandas/tests/series/indexing/test_mask.py @@ -1,7 +1,11 @@ import numpy as np import pytest -from pandas import Series +from pandas import ( + NA, + Series, + StringDtype, +) import pandas._testing as tm @@ -63,3 +67,22 @@ def test_mask_inplace(): rs = s.copy() rs.mask(cond, -s, inplace=True) tm.assert_series_equal(rs, s.mask(cond, -s)) + + +def test_mask_stringdtype(): + # GH 40824 + ser = Series( + ["foo", "bar", "baz", NA], + index=["id1", "id2", "id3", "id4"], + dtype=StringDtype(), + ) + filtered_ser = Series(["this", "that"], index=["id2", "id3"], dtype=StringDtype()) + filter_ser = Series([False, True, True, False]) + result = ser.mask(filter_ser, filtered_ser) + + expected = Series( + [NA, "this", "that", NA], + index=["id1", "id2", "id3", "id4"], + dtype=StringDtype(), + ) + tm.assert_series_equal(result, expected)