diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index a797090a83444..1ccd82a894234 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -82,6 +82,7 @@ Other enhancements - :class:`Styler` may now render CSS more efficiently where multiple cells have the same styling (:issue:`30876`) - :meth:`Styler.highlight_null` now accepts ``subset`` argument (:issue:`31345`) - When writing directly to a sqlite connection :func:`to_sql` now supports the ``multi`` method (:issue:`29921`) +- Added :func:`from_dummies`, the inverse transformation of :func:`get_dummies` (:issue:`8745`) - `OptionError` is now exposed in `pandas.errors` (:issue:`27553`) - :func:`timedelta_range` will now infer a frequency when passed ``start``, ``stop``, and ``periods`` (:issue:`32377`) - Positional slicing on a :class:`IntervalIndex` now supports slices with ``step > 1`` (:issue:`31658`) diff --git a/pandas/__init__.py b/pandas/__init__.py index 2b9a461e0e95d..85d9a452c4cc9 100644 --- a/pandas/__init__.py +++ b/pandas/__init__.py @@ -136,6 +136,7 @@ get_dummies, cut, qcut, + from_dummies, ) import pandas.api diff --git a/pandas/core/reshape/api.py b/pandas/core/reshape/api.py index 3c76eef809c7a..7054926b9c0c4 100644 --- a/pandas/core/reshape/api.py +++ b/pandas/core/reshape/api.py @@ -4,5 +4,5 @@ from pandas.core.reshape.melt import lreshape, melt, wide_to_long from pandas.core.reshape.merge import merge, merge_asof, merge_ordered from pandas.core.reshape.pivot import crosstab, pivot, pivot_table -from pandas.core.reshape.reshape import get_dummies +from pandas.core.reshape.reshape import from_dummies, get_dummies from pandas.core.reshape.tile import cut, qcut diff --git a/pandas/core/reshape/reshape.py b/pandas/core/reshape/reshape.py index 882e3e0a649cc..23f14cc84ad62 100644 --- a/pandas/core/reshape/reshape.py +++ b/pandas/core/reshape/reshape.py @@ -1,11 +1,12 @@ import itertools -from typing import List, 
def from_dummies(
    data: "DataFrame",
    prefix: Optional[Union[str, List[str]]] = None,
    prefix_sep: str = "_",
    dtype: Dtype = "category",
) -> "DataFrame":
    """
    The inverse transformation of ``pandas.get_dummies``.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    data : DataFrame
        Data which contains dummy indicators.
    prefix : str or list of str, default None
        How to name the decoded groups of columns. If there are columns
        containing `prefix_sep`, then the part of their name preceding
        the first `prefix_sep` will be used (see examples below). If no
        column contains `prefix_sep`, a single name must be given and all
        columns are decoded into one column with that name.
    prefix_sep : str, default '_'
        Separator between original column name and dummy variable.
    dtype : dtype, default 'category'
        Data dtype for new columns - only a single data type is allowed.

    Returns
    -------
    DataFrame
        Decoded data. Columns which are not dummy indicators are kept, and
        the index of `data` is preserved. A row whose indicators are all 0
        is decoded as a missing value.

    Raises
    ------
    ValueError
        If no column contains `prefix_sep` and `prefix` is not a single
        name, or if a group of indicators contains a value other than
        0 and 1, or has more than one 1 in a row.

    See Also
    --------
    get_dummies : The inverse operation.

    Examples
    --------
    Say we have a dataframe where some variables have been dummified:

    >>> df = pd.DataFrame(
    ...     {
    ...         "baboon": [0, 0, 1],
    ...         "lemur": [0, 1, 0],
    ...         "zebra": [1, 0, 0],
    ...     }
    ... )
    >>> df
       baboon  lemur  zebra
    0       0      0      1
    1       0      1      0
    2       1      0      0

    We can recover the original dataframe using `from_dummies`:

    >>> pd.from_dummies(df, prefix='animal')
       animal
    0   zebra
    1   lemur
    2  baboon

    If our dataframe already has columns with `prefix_sep` in them,
    we don't need to pass in the `prefix` argument:

    >>> df = pd.DataFrame(
    ...     {
    ...         "animal_baboon": [0, 0, 1],
    ...         "animal_lemur": [0, 1, 0],
    ...         "animal_zebra": [1, 0, 0],
    ...         "other": ['a', 'b', 'c'],
    ...     }
    ... )
    >>> df
       animal_baboon  animal_lemur  animal_zebra other
    0              0             0             1     a
    1              0             1             0     b
    2              1             0             0     c

    >>> pd.from_dummies(df)
      other  animal
    0     a   zebra
    1     b   lemur
    2     c  baboon
    """
    if dtype is None:
        dtype = "category"

    # Only string labels can contain `prefix_sep`; the isinstance guard keeps
    # frames with non-string column labels from raising TypeError.
    columns_to_decode = [
        col for col in data.columns if isinstance(col, str) and prefix_sep in col
    ]
    if columns_to_decode:
        out = data.copy()
    else:
        if prefix is None:
            raise ValueError(
                "If no columns contain `prefix_sep`, you must "
                "pass a value to `prefix` with which to name "
                "the decoded columns."
            )
        group_names = [prefix] if isinstance(prefix, str) else list(prefix)
        if len(group_names) != 1:
            # Every column goes into the same group here, so exactly one name
            # is meaningful; formatting a whole list into the labels would
            # silently yield undecodable column names.
            raise ValueError(
                "If no columns contain `prefix_sep`, `prefix` must be a "
                "single name."
            )
        group_name = group_names[0]
        # If no column contains `prefix_sep`, we prepend `prefix` and
        # `prefix_sep` to each column (`rename` already returns a new frame).
        out = data.rename(columns=lambda x: f"{group_name}{prefix_sep}{x}")
        columns_to_decode = list(out.columns)

    data_to_decode = out[columns_to_decode]

    if prefix is None:
        # If no prefix has been passed, extract it from columns containing
        # `prefix_sep`, keeping the order of first appearance.
        prefixes: List[str] = []
        seen: Set[str] = set()
        for col in columns_to_decode:
            name = col.split(prefix_sep)[0]
            if name not in seen:
                seen.add(name)
                prefixes.append(name)
    elif isinstance(prefix, str):
        prefixes = [prefix]
    else:
        prefixes = list(prefix)

    def _validate_values(values):
        # Each entry must be 0 or 1 and each row may hold at most one 1.
        # Checking the entries (not only the row sum) rejects rows such as
        # [2, -1] or rows containing NaN, which would otherwise be decoded
        # to an arbitrary label by argmax.
        is_binary = (values == 0) | (values == 1)
        if not is_binary.all() or (values.sum(axis=1) > 1).any():
            raise ValueError(
                "Data cannot be decoded! Each row must contain only 0s and "
                "1s, and each row may have at most one 1."
            )

    for prefix_ in prefixes:
        stem = prefix_ + prefix_sep
        # Anchor the match at the start of the label: a substring test would
        # let prefix "a" capture a column such as "ba_y".
        cols = [col for col in columns_to_decode if col.startswith(stem)]
        if not cols:
            continue
        # Strip only the leading stem so labels that repeat it stay intact.
        labels = np.array([col[len(stem) :] for col in cols], dtype=object)
        values = data_to_decode[cols].to_numpy()
        _validate_values(values)
        decoded = labels[values.argmax(axis=1)]
        # argmax returns 0 for an all-zero row; such rows carry no label.
        decoded[values.sum(axis=1) == 0] = np.nan
        out = out.drop(cols, axis=1)
        # Reuse the frame's index so the assignment does not realign the
        # decoded values against a default RangeIndex.
        out[prefix_] = Series(decoded, index=out.index, dtype=dtype)
    return out
import pytest

import pandas as pd
import pandas._testing as tm

# Labels every single-group fixture below is expected to decode to.
_DECODED_LABELS = ["a", "a", "b"]


@pytest.mark.parametrize(
    "dtype, expected_dict",
    [
        ("str", {"col1": _DECODED_LABELS}),
        (str, {"col1": _DECODED_LABELS}),
        ("category", {"col1": _DECODED_LABELS}),
    ],
)
def test_dtype(dtype, expected_dict):
    """The decoded column is created with the requested ``dtype``."""
    dummies = pd.DataFrame({"col1_a": [1, 1, 0], "col1_b": [0, 0, 1]})
    decoded = pd.from_dummies(dummies, dtype=dtype)
    tm.assert_frame_equal(decoded, pd.DataFrame(expected_dict, dtype=dtype))


def test_malformed():
    """Rows flagged for more than one category are rejected."""
    dummies = pd.DataFrame({"col1_a": [1, 1, 0], "col1_b": [1, 0, 1]})
    msg = (
        "Data cannot be decoded! Each row must contain only 0s and 1s"
        ", and each row may have at most one 1"
    )
    with pytest.raises(ValueError, match=msg):
        pd.from_dummies(dummies)


@pytest.mark.parametrize(
    "prefix_sep, input_dict",
    [
        ("_", {"col1_a": [1, 1, 0], "col1_b": [0, 0, 1]}),
        ("*", {"col1*a": [1, 1, 0], "col1*b": [0, 0, 1]}),
        (".", {"col1.a": [1, 1, 0], "col1.b": [0, 0, 1]}),
    ],
)
def test_prefix_sep(prefix_sep, input_dict):
    """A custom separator is honoured when splitting column names."""
    decoded = pd.from_dummies(pd.DataFrame(input_dict), prefix_sep=prefix_sep)
    wanted = pd.DataFrame({"col1": _DECODED_LABELS}, dtype="category")
    tm.assert_frame_equal(decoded, wanted)


def test_no_prefix():
    """Without a separator in any column, ``prefix`` names the result."""
    dummies = pd.DataFrame({"a": [1, 1, 0], "b": [0, 0, 1]})
    decoded = pd.from_dummies(dummies, prefix="letter")
    wanted = pd.DataFrame({"letter": _DECODED_LABELS}, dtype="category")
    tm.assert_frame_equal(decoded, wanted)


def test_multiple_columns():
    """Each distinct prefix is decoded into its own column."""
    indicator_columns = {
        "col1_a": [1, 0],
        "col1_b": [0, 1],
        "col2_a": [0, 0],
        "col2_c": [1, 1],
    }
    decoded = pd.from_dummies(pd.DataFrame(indicator_columns))
    wanted = pd.DataFrame(
        {"col1": ["a", "b"], "col2": ["c", "c"]}, dtype="category"
    )
    tm.assert_frame_equal(decoded, wanted)