From f2c3713f87cbe612bcd36a24f0f37dc333acc293 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 22 Oct 2019 10:01:59 +0200 Subject: [PATCH 1/5] BUG/TST: ensure groupby.agg preserves extension dtype --- pandas/core/groupby/ops.py | 4 +-- .../tests/extension/decimal/test_decimal.py | 34 +++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index fcc646dec89d9..bed7f38c61ca5 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -655,7 +655,7 @@ def agg_series(self, obj, func): return self._aggregate_series_fast(obj, func) except AssertionError: raise - except ValueError as err: + except (ValueError, AttributeError) as err: if "No result." in str(err): # raised in libreduction pass @@ -663,7 +663,7 @@ def agg_series(self, obj, func): # raised in libreduction pass else: - raise + pass return self._aggregate_series_pure_python(obj, func) def _aggregate_series_fast(self, obj, func): diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 3ac9d37ccf4f3..3da136899d83c 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -426,3 +426,37 @@ def test_array_ufunc_series_defer(): tm.assert_series_equal(r1, expected) tm.assert_series_equal(r2, expected) + + +def test_groupby_agg(): + # Ensure that the result of agg is inferred to be decimal dtype + # https://github.com/pandas-dev/pandas/issues/29141 + + data = make_data()[:5] + df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)}) + expected = pd.Series(to_decimal([data[0], data[3]])) + + result = df.groupby("id")["decimals"].agg(lambda x: x.iloc[0]) + tm.assert_series_equal(result, expected, check_names=False) + result = df["decimals"].groupby(df["id"]).agg(lambda x: x.iloc[0]) + tm.assert_series_equal(result, expected, check_names=False) + + +def test_groupby_agg_ea_method(monkeypatch): + # Ensure that the result of agg is inferred to be decimal dtype + # https://github.com/pandas-dev/pandas/issues/29141 + + def DecimalArray__my_sum(self): + return np.sum(np.array(self)) + + monkeypatch.setattr(DecimalArray, "my_sum", DecimalArray__my_sum, raising=False) + + data = make_data()[:5] + df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)}) + expected = pd.Series(to_decimal([data[0] + data[1] + data[2], data[3] + data[4]])) + + result = df.groupby("id")["decimals"].agg(lambda x: x.values.my_sum()) + tm.assert_series_equal(result, expected, check_names=False) + s = pd.Series(DecimalArray(data)) + result = s.groupby(np.array([0, 0, 0, 1, 1])).agg(lambda x: x.values.my_sum()) + tm.assert_series_equal(result, expected, check_names=False) From 617af953c18b59544a1fe2f4f6c050ba8887cc67 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 22 Oct 2019 15:03:03 +0200 Subject: [PATCH 2/5] also catch TypeError --- pandas/core/groupby/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index bed7f38c61ca5..5f96cdac107fa 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -655,7 +655,7 @@ def agg_series(self, obj, func): return self._aggregate_series_fast(obj, func) except AssertionError: raise - except (ValueError, AttributeError) as err: + except (ValueError, AttributeError, TypeError) as err: if "No result." in str(err): # raised in libreduction pass From 7255da1797ee2e750887d64d28b639674b2ff683 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 22 Oct 2019 17:23:35 +0200 Subject: [PATCH 3/5] only catch typeerror --- pandas/core/groupby/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 5f96cdac107fa..4a3d279dbb147 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -655,7 +655,7 @@ def agg_series(self, obj, func): return self._aggregate_series_fast(obj, func) except AssertionError: raise - except (ValueError, AttributeError, TypeError) as err: + except (ValueError, TypeError) as err: if "No result." in str(err): # raised in libreduction pass From e4210a2ce026066301483ef3f1f2b4943d4f389b Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 23 Oct 2019 07:28:36 +0200 Subject: [PATCH 4/5] check for specific error message --- pandas/core/groupby/ops.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 4a3d279dbb147..7c8c7e4f5bcd4 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -655,7 +655,7 @@ def agg_series(self, obj, func): return self._aggregate_series_fast(obj, func) except AssertionError: raise - except (ValueError, TypeError) as err: + except ValueError as err: if "No result." in str(err): # raised in libreduction pass @@ -663,8 +663,14 @@ def agg_series(self, obj, func): # raised in libreduction pass else: + raise + except TypeError as err: + if "ndarray" in str(err): + # raised in libreduction if obj's values is no ndarray pass - return self._aggregate_series_pure_python(obj, func) + else: + raise + return self._aggregate_series_pure_python(obj, func) def _aggregate_series_fast(self, obj, func): func = self._is_builtin_func(func) From 83535c744532d1fd857f94cd143a2137aa4aebb0 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 23 Oct 2019 16:54:00 +0200 Subject: [PATCH 5/5] additional tests with multiple keys/columns --- .../tests/extension/decimal/test_decimal.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 3da136899d83c..86724d4d09819 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -433,14 +433,32 @@ def test_groupby_agg(): # https://github.com/pandas-dev/pandas/issues/29141 data = make_data()[:5] - df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)}) + df = pd.DataFrame( + {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)} + ) + + # single key, selected column expected = pd.Series(to_decimal([data[0], data[3]])) + result = df.groupby("id1")["decimals"].agg(lambda x: x.iloc[0]) + tm.assert_series_equal(result, expected, check_names=False) + result = df["decimals"].groupby(df["id1"]).agg(lambda x: x.iloc[0]) + tm.assert_series_equal(result, expected, check_names=False) - result = df.groupby("id")["decimals"].agg(lambda x: x.iloc[0]) + # multiple keys, selected column + expected = pd.Series( + to_decimal([data[0], data[1], data[3]]), + index=pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 1)]), + ) + result = df.groupby(["id1", "id2"])["decimals"].agg(lambda x: x.iloc[0]) tm.assert_series_equal(result, expected, check_names=False) - result = df["decimals"].groupby(df["id"]).agg(lambda x: x.iloc[0]) + result = df["decimals"].groupby([df["id1"], df["id2"]]).agg(lambda x: x.iloc[0]) tm.assert_series_equal(result, expected, check_names=False) + # multiple columns + expected = pd.DataFrame({"id2": [0, 1], "decimals": to_decimal([data[0], data[3]])}) + result = df.groupby("id1").agg(lambda x: x.iloc[0]) + tm.assert_frame_equal(result, expected, check_names=False) + def test_groupby_agg_ea_method(monkeypatch): # Ensure that the result of agg is inferred to be decimal dtype