diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index 528f417ea1039..ab4f2bd5e529b 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -1046,6 +1046,26 @@ def test_groupby_transform_with_datetimes(func, values): tm.assert_series_equal(result, expected) +def test_groupby_transform_dtype(): + # GH 22243 + df = DataFrame({"a": [1], "val": [1.35]}) + + result = df["val"].transform(lambda x: x.map(lambda y: f"+{y}")) + expected1 = Series(["+1.35"], name="val", dtype="object") + tm.assert_series_equal(result, expected1) + + result = df.groupby("a")["val"].transform(lambda x: x.map(lambda y: f"+{y}")) + tm.assert_series_equal(result, expected1) + + result = df.groupby("a")["val"].transform(lambda x: x.map(lambda y: f"+({y})")) + expected2 = Series(["+(1.35)"], name="val", dtype="object") + tm.assert_series_equal(result, expected2) + + df["val"] = df["val"].astype(object) + result = df.groupby("a")["val"].transform(lambda x: x.map(lambda y: f"+{y}")) + tm.assert_series_equal(result, expected1) + + @pytest.mark.parametrize("func", ["cumsum", "cumprod", "cummin", "cummax"]) def test_transform_absent_categories(func): # GH 16771