diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index 9c64105fb0464..a5520d9e5ca2f 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -636,6 +636,7 @@ Reshaping ^^^^^^^^^ - Bug in :func:`concat` coercing to ``object`` dtype when one column has ``pa.null()`` dtype (:issue:`53702`) - Bug in :func:`crosstab` when ``dropna=False`` would not keep ``np.nan`` in the result (:issue:`10772`) +- Bug in :func:`melt` where the ``variable`` column would lose extension dtypes (:issue:`54297`) - Bug in :func:`merge_asof` raising ``KeyError`` for extension dtypes (:issue:`52904`) - Bug in :func:`merge_asof` raising ``ValueError`` for data backed by read-only ndarrays (:issue:`53513`) - Bug in :func:`merge_asof` with ``left_index=True`` or ``right_index=True`` with mismatched index dtypes giving incorrect results in some cases instead of raising ``MergeError`` (:issue:`53870`) diff --git a/pandas/core/reshape/melt.py b/pandas/core/reshape/melt.py index cd0dc313a7fb3..74e6a6a28ccb0 100644 --- a/pandas/core/reshape/melt.py +++ b/pandas/core/reshape/melt.py @@ -141,8 +141,7 @@ def melt( else: mdata[value_name] = frame._values.ravel("F") for i, col in enumerate(var_name): - # asanyarray will keep the columns as an Index - mdata[col] = np.asanyarray(frame.columns._get_level_values(i)).repeat(N) + mdata[col] = frame.columns._get_level_values(i).repeat(N) result = frame._constructor(mdata, columns=mcolumns) diff --git a/pandas/tests/reshape/test_melt.py b/pandas/tests/reshape/test_melt.py index ea9da4e87240b..bf03aa56a1da2 100644 --- a/pandas/tests/reshape/test_melt.py +++ b/pandas/tests/reshape/test_melt.py @@ -437,6 +437,26 @@ def test_melt_ea_dtype(self, dtype): ) tm.assert_frame_equal(result, expected) + def test_melt_ea_columns(self): + # GH 54297 + df = DataFrame( + { + "A": {0: "a", 1: "b", 2: "c"}, + "B": {0: 1, 1: 3, 2: 5}, + "C": {0: 2, 1: 4, 2: 6}, + } + ) + df.columns = df.columns.astype("string[python]") + result = df.melt(id_vars=["A"], value_vars=["B"]) + expected = DataFrame( + { + "A": list("abc"), + "variable": pd.Series(["B"] * 3, dtype="string[python]"), + "value": [1, 3, 5], + } + ) + tm.assert_frame_equal(result, expected) + class TestLreshape: def test_pairs(self):