diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py
index 740103eec185a..73a99f2dd4402 100644
--- a/pandas/tests/groupby/test_transform.py
+++ b/pandas/tests/groupby/test_transform.py
@@ -1,4 +1,5 @@
 """ test with the .transform """
+from datetime import timedelta
 from io import StringIO
 
 import numpy as np
@@ -23,6 +24,17 @@
 from pandas.core.groupby.groupby import DataError
 
 
+@pytest.fixture
+def df_for_transformation_func():
+    return DataFrame(
+        {
+            "A": [121, 121, 121, 121, 231, 231, 676],
+            "B": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0],
+            "C": pd.date_range("2013-11-03", periods=7, freq="D"),
+        }
+    )
+
+
 def assert_fp_equal(a, b):
     assert (np.abs(a - b) < 1e-12).all()
 
@@ -318,7 +330,7 @@ def test_dispatch_transform(tsframe):
     tm.assert_frame_equal(filled, expected)
 
 
-def test_transform_transformation_func(transformation_func):
+def test_transform_transformation_func(transformation_func, df_for_transformation_func):
     # GH 30918
     df = DataFrame(
         {
@@ -346,6 +358,58 @@ def test_transform_transformation_func(transformation_func):
     tm.assert_frame_equal(result, expected)
 
 
+def test_groupby_transform_corrwith(df_for_transformation_func):
+
+    # GH 27905
+    df = df_for_transformation_func
+    g = df.groupby("A")
+
+    result = g.corrwith(df)
+    expected = pd.DataFrame(dict(B=[1, np.nan, np.nan], A=[np.nan] * 3))
+    expected.index = pd.Index([121, 231, 676], name="A")
+    tm.assert_frame_equal(result, expected)
+
+    msg = "'Series' object has no attribute 'corrwith'"
+
+    with pytest.raises(AttributeError, match=msg):
+        g.transform("corrwith", df)
+
+
+def test_groupby_tshift(df_for_transformation_func):
+
+    # GH 27905
+    df = df_for_transformation_func
+    g = df.set_index("C").groupby("A")
+
+    result = g.tshift(2, "D")
+    expected = df.assign(C=df["C"] + timedelta(days=2))
+    tm.assert_frame_equal(
+        result.reset_index().reindex(columns=["A", "B", "C"]), expected
+    )
+
+
+@pytest.mark.xfail(
+    reason="The output of groupby.transform with tshift is wrong, see GH 32344"
+)
+@pytest.mark.parametrize(
+    "func, args", [(lambda x: x.tshift(2, "D"), ()), ("tshift", (2, "D"))]
+)
+def test_groupby_transform_tshift(df_for_transformation_func, func, args):
+
+    # GH 27905
+    # Marked with a decorator rather than calling pytest.xfail() in the body:
+    # the imperative call aborts the test on the spot, so the assertion below
+    # would never run and a fix for GH 32344 could never show up as XPASS.
+    df = df_for_transformation_func
+    g = df.set_index("C").groupby("A")
+
+    result = g.transform(func, *args)
+    expected = df.assign(C=df["C"] + timedelta(days=2))
+    tm.assert_frame_equal(
+        result.reset_index().reindex(columns=["A", "B", "C"]), expected
+    )
+
+
 def test_transform_select_columns(df):
     f = lambda x: x.mean()
     result = df.groupby("A")[["C", "D"]].transform(f)