From e9173856f0fb42a76cd7cea95be763c1ec700a51 Mon Sep 17 00:00:00 2001 From: aram-cinnamon Date: Mon, 18 Dec 2023 06:39:21 +0100 Subject: [PATCH 1/3] add test for groupby on complex numbers --- pandas/tests/groupby/test_groupby.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index fce7caa90cce4..1926796b06802 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -136,6 +136,19 @@ def test_basic_aggregations(dtype): grouped.aggregate(lambda x: x * 2) +def test_groupby_complex_numbers(): + df = DataFrame([ + {"a": 2, "b": 1 + 2j}, + {"a": 1, "b": 1 + 1j}, + {"a": 1, "b": 1 + 2j}, + ]) + assert df.groupby("b").groups == {(1 + 1j): [1], (1 + 2j): [0, 2]} + tm.assert_frame_equal( + df.groupby("b").mean(), + DataFrame({"a": {(1 + 1j): 1.0, (1 + 2j): 1.5}}), + ) + + def test_groupby_nonobject_dtype(multiindex_dataframe_random_data): key = multiindex_dataframe_random_data.index.codes[0] grouped = multiindex_dataframe_random_data.groupby(key) From 808bff452f90aa5d778f86b7253b6aced3a42471 Mon Sep 17 00:00:00 2001 From: aram-cinnamon Date: Tue, 19 Dec 2023 21:23:40 +0100 Subject: [PATCH 2/3] move and change name --- pandas/tests/groupby/test_groupby.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 1926796b06802..3e0d36646c1ef 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -136,19 +136,6 @@ def test_basic_aggregations(dtype): grouped.aggregate(lambda x: x * 2) -def test_groupby_complex_numbers(): - df = DataFrame([ - {"a": 2, "b": 1 + 2j}, - {"a": 1, "b": 1 + 1j}, - {"a": 1, "b": 1 + 2j}, - ]) - assert df.groupby("b").groups == {(1 + 1j): [1], (1 + 2j): [0, 2]} - tm.assert_frame_equal( - df.groupby("b").mean(), - DataFrame({"a": {(1 + 1j): 1.0, (1 + 2j): 1.5}}), - ) - - def test_groupby_nonobject_dtype(multiindex_dataframe_random_data): key = multiindex_dataframe_random_data.index.codes[0] grouped = multiindex_dataframe_random_data.groupby(key) @@ -1209,6 +1196,21 @@ def test_groupby_complex(): tm.assert_series_equal(result, expected) +def test_groupby_complex_2(): + df = DataFrame( + [ + {"a": 2, "b": 1 + 2j}, + {"a": 1, "b": 1 + 1j}, + {"a": 1, "b": 1 + 2j}, + ] + ) + assert df.groupby("b").groups == {(1 + 1j): [1], (1 + 2j): [0, 2]} + tm.assert_frame_equal( + df.groupby("b").mean(), + DataFrame({"a": {(1 + 1j): 1.0, (1 + 2j): 1.5}}), + ) + + def test_groupby_complex_numbers(using_infer_string): # GH 17927 df = DataFrame( From 8bad6978b5036476a842359cd696bd2ad5c61ae4 Mon Sep 17 00:00:00 2001 From: aram-cinnamon Date: Thu, 21 Dec 2023 04:04:31 +0100 Subject: [PATCH 3/3] fix per comments --- pandas/tests/groupby/test_groupby.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 3e0d36646c1ef..4c903e691add1 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -1196,7 +1196,8 @@ def test_groupby_complex(): tm.assert_series_equal(result, expected) -def test_groupby_complex_2(): +def test_groupby_complex_mean(): + # GH 26475 df = DataFrame( [ {"a": 2, "b": 1 + 2j}, @@ -1204,11 +1205,13 @@ def test_groupby_complex_2(): {"a": 1, "b": 1 + 2j}, ] ) - assert df.groupby("b").groups == {(1 + 1j): [1], (1 + 2j): [0, 2]} - tm.assert_frame_equal( - df.groupby("b").mean(), - DataFrame({"a": {(1 + 1j): 1.0, (1 + 2j): 1.5}}), + result = df.groupby("b").mean() + expected = DataFrame( + [[1.0], [1.5]], + index=Index([(1 + 1j), (1 + 2j)], name="b"), + columns=Index(["a"]), ) + tm.assert_frame_equal(result, expected) def test_groupby_complex_numbers(using_infer_string):