From 43d88eb73d7555c352152f36a53fc49dca5bc8e4 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Thu, 19 Jan 2023 00:02:00 +0100 Subject: [PATCH 1/2] ENH: Add ea support to get_dummies --- doc/source/whatsnew/v2.0.0.rst | 1 + pandas/core/reshape/encoding.py | 12 +++++++++--- pandas/tests/reshape/test_get_dummies.py | 20 ++++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index 7054d93457264..5e81ae62bc586 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -158,6 +158,7 @@ Other enhancements - Added ``name`` parameter to :meth:`IntervalIndex.from_breaks`, :meth:`IntervalIndex.from_arrays` and :meth:`IntervalIndex.from_tuples` (:issue:`48911`) - Improve exception message when using :func:`assert_frame_equal` on a :class:`DataFrame` to include the column that is compared (:issue:`50323`) - Improved error message for :func:`merge_asof` when join-columns were duplicated (:issue:`50102`) +- Added support for extension array dtypes to :func:`get_dummies` (:func:`32430`) - Added :meth:`Index.infer_objects` analogous to :meth:`Series.infer_objects` (:issue:`50034`) - Added ``copy`` parameter to :meth:`Series.infer_objects` and :meth:`DataFrame.infer_objects`, passing ``False`` will avoid making copies for series or columns that are already non-object or where no better dtype can be inferred (:issue:`50096`) - :meth:`DataFrame.plot.hist` now recognizes ``xlabel`` and ``ylabel`` arguments (:issue:`49793`) diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index 7e45e587ca84a..b6bfcf75695a3 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -12,10 +12,12 @@ from pandas._libs.sparse import IntIndex from pandas._typing import NpDtype +from pandas.core.dtypes.base import ExtensionDtype from pandas.core.dtypes.common import ( is_integer_dtype, is_list_like, is_object_dtype, + pandas_dtype, ) from pandas.core.arrays import SparseArray @@ -240,7 +242,7 @@ def _get_dummies_1d( if dtype is None: dtype = np.dtype(bool) - dtype = np.dtype(dtype) + dtype = pandas_dtype(dtype) if is_object_dtype(dtype): raise ValueError("dtype=object is not a valid dtype for get_dummies") @@ -317,7 +319,11 @@ def get_empty_frame(data) -> DataFrame: else: # take on axis=1 + transpose to ensure ndarray layout is column-major - dummy_mat = np.eye(number_of_cols, dtype=dtype).take(codes, axis=1).T + if isinstance(dtype, ExtensionDtype): + eye_dtype = np.bool_ + else: + eye_dtype = dtype + dummy_mat = np.eye(number_of_cols, dtype=eye_dtype).take(codes, axis=1).T if not dummy_na: # reset NaN GH4446 @@ -327,7 +333,7 @@ def get_empty_frame(data) -> DataFrame: # remove first GH12042 dummy_mat = dummy_mat[:, 1:] dummy_cols = dummy_cols[1:] - return DataFrame(dummy_mat, index=index, columns=dummy_cols) + return DataFrame(dummy_mat, index=index, columns=dummy_cols, dtype=dtype) def from_dummies( diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py index 8a7985280eff4..ed4da9562aeee 100644 --- a/pandas/tests/reshape/test_get_dummies.py +++ b/pandas/tests/reshape/test_get_dummies.py @@ -657,3 +657,23 @@ def test_get_dummies_with_string_values(self, values): with pytest.raises(TypeError, match=msg): get_dummies(df, columns=values) + + def test_get_dummies_ea_dtype_series(self, any_numeric_ea_dtype): + # GH#32430 + ser = Series(list("abca")) + result = get_dummies(ser, dtype=any_numeric_ea_dtype) + expected = DataFrame( + {"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}, + dtype=any_numeric_ea_dtype, + ) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_dtype): + # GH#32430 + df = DataFrame({"x": list("abca")}) + result = get_dummies(df, dtype=any_numeric_ea_dtype) + expected = DataFrame( + {"x_a": [1, 0, 0, 1], "x_b": [0, 1, 0, 0], "x_c": [0, 0, 1, 0]}, + dtype=any_numeric_ea_dtype, + ) + tm.assert_frame_equal(result, expected) From 6c5ce7e2868f4ac2e88cebff417c848545cdb87d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Thu, 19 Jan 2023 21:47:07 +0100 Subject: [PATCH 2/2] Fix mypy --- pandas/core/reshape/encoding.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index b6bfcf75695a3..2aa1a3001fb6b 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -12,7 +12,6 @@ from pandas._libs.sparse import IntIndex from pandas._typing import NpDtype -from pandas.core.dtypes.base import ExtensionDtype from pandas.core.dtypes.common import ( is_integer_dtype, is_list_like, @@ -242,9 +241,9 @@ def _get_dummies_1d( if dtype is None: dtype = np.dtype(bool) - dtype = pandas_dtype(dtype) + _dtype = pandas_dtype(dtype) - if is_object_dtype(dtype): + if is_object_dtype(_dtype): raise ValueError("dtype=object is not a valid dtype for get_dummies") def get_empty_frame(data) -> DataFrame: @@ -319,10 +318,11 @@ def get_empty_frame(data) -> DataFrame: else: # take on axis=1 + transpose to ensure ndarray layout is column-major - if isinstance(dtype, ExtensionDtype): - eye_dtype = np.bool_ + eye_dtype: NpDtype + if isinstance(_dtype, np.dtype): + eye_dtype = _dtype else: - eye_dtype = dtype + eye_dtype = np.bool_ dummy_mat = np.eye(number_of_cols, dtype=eye_dtype).take(codes, axis=1).T if not dummy_na: @@ -333,7 +333,7 @@ def get_empty_frame(data) -> DataFrame: # remove first GH12042 dummy_mat = dummy_mat[:, 1:] dummy_cols = dummy_cols[1:] - return DataFrame(dummy_mat, index=index, columns=dummy_cols, dtype=dtype) + return DataFrame(dummy_mat, index=index, columns=dummy_cols, dtype=_dtype) def from_dummies(