diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index bbecf3fee01f3..ff21a68d31f92 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -158,6 +158,7 @@ Other enhancements - Added ``name`` parameter to :meth:`IntervalIndex.from_breaks`, :meth:`IntervalIndex.from_arrays` and :meth:`IntervalIndex.from_tuples` (:issue:`48911`) - Improve exception message when using :func:`assert_frame_equal` on a :class:`DataFrame` to include the column that is compared (:issue:`50323`) - Improved error message for :func:`merge_asof` when join-columns were duplicated (:issue:`50102`) +- Added support for extension array dtypes to :func:`get_dummies` (:issue:`32430`) - Added :meth:`Index.infer_objects` analogous to :meth:`Series.infer_objects` (:issue:`50034`) - Added ``copy`` parameter to :meth:`Series.infer_objects` and :meth:`DataFrame.infer_objects`, passing ``False`` will avoid making copies for series or columns that are already non-object or where no better dtype can be inferred (:issue:`50096`) - :meth:`DataFrame.plot.hist` now recognizes ``xlabel`` and ``ylabel`` arguments (:issue:`49793`) diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index 7e45e587ca84a..2aa1a3001fb6b 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -16,6 +16,7 @@ is_integer_dtype, is_list_like, is_object_dtype, + pandas_dtype, ) from pandas.core.arrays import SparseArray @@ -240,9 +241,9 @@ def _get_dummies_1d( if dtype is None: dtype = np.dtype(bool) - dtype = np.dtype(dtype) + _dtype = pandas_dtype(dtype) - if is_object_dtype(dtype): + if is_object_dtype(_dtype): raise ValueError("dtype=object is not a valid dtype for get_dummies") def get_empty_frame(data) -> DataFrame: @@ -317,7 +318,12 @@ def get_empty_frame(data) -> DataFrame: else: # take on axis=1 + transpose to ensure ndarray layout is column-major - dummy_mat = np.eye(number_of_cols, dtype=dtype).take(codes, axis=1).T + 
eye_dtype: NpDtype + if isinstance(_dtype, np.dtype): + eye_dtype = _dtype + else: + eye_dtype = np.bool_ + dummy_mat = np.eye(number_of_cols, dtype=eye_dtype).take(codes, axis=1).T if not dummy_na: # reset NaN GH4446 @@ -327,7 +333,7 @@ def get_empty_frame(data) -> DataFrame: # remove first GH12042 dummy_mat = dummy_mat[:, 1:] dummy_cols = dummy_cols[1:] - return DataFrame(dummy_mat, index=index, columns=dummy_cols) + return DataFrame(dummy_mat, index=index, columns=dummy_cols, dtype=_dtype) def from_dummies( diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py index 8a7985280eff4..ed4da9562aeee 100644 --- a/pandas/tests/reshape/test_get_dummies.py +++ b/pandas/tests/reshape/test_get_dummies.py @@ -657,3 +657,23 @@ def test_get_dummies_with_string_values(self, values): with pytest.raises(TypeError, match=msg): get_dummies(df, columns=values) + + def test_get_dummies_ea_dtype_series(self, any_numeric_ea_dtype): + # GH#32430 + ser = Series(list("abca")) + result = get_dummies(ser, dtype=any_numeric_ea_dtype) + expected = DataFrame( + {"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}, + dtype=any_numeric_ea_dtype, + ) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_dtype): + # GH#32430 + df = DataFrame({"x": list("abca")}) + result = get_dummies(df, dtype=any_numeric_ea_dtype) + expected = DataFrame( + {"x_a": [1, 0, 0, 1], "x_b": [0, 1, 0, 0], "x_c": [0, 0, 1, 0]}, + dtype=any_numeric_ea_dtype, + ) + tm.assert_frame_equal(result, expected)