Skip to content

Commit 08c3f71

Browse files
committed
ENH: add dtype argument to get_dummies
1 parent 34a03f3 commit 08c3f71

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

pandas/core/reshape/reshape.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ def _convert_level_number(level_num, columns):
697697

698698

699699
def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
700-
columns=None, sparse=False, drop_first=False):
700+
columns=None, sparse=False, drop_first=False, dtype=None):
701701
"""
702702
Convert categorical variable into dummy/indicator variables
703703
@@ -725,6 +725,8 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
725725
drop_first : bool, default False
726726
Whether to get k-1 dummies out of k categorical levels by removing the
727727
first level.
728+
dtype : dtype, default np.uint8
729+
Data type for new columns. Only a single dtype is allowed.
728730
729731
.. versionadded:: 0.18.0
730732
@@ -783,13 +785,22 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
783785
3 0 0
784786
4 0 0
785787
788+
>>> pd.get_dummies(pd.Series(list('abc')), dtype=float)
789+
a b c
790+
0 1.0 0.0 0.0
791+
1 0.0 1.0 0.0
792+
2 0.0 0.0 1.0
793+
786794
See Also
787795
--------
788796
Series.str.get_dummies
789797
"""
790798
from pandas.core.reshape.concat import concat
791799
from itertools import cycle
792800

801+
if dtype is None:
802+
dtype = np.uint8
803+
793804
if isinstance(data, DataFrame):
794805
# determine columns being encoded
795806

@@ -835,17 +846,17 @@ def check_len(item, name):
835846

836847
dummy = _get_dummies_1d(data[col], prefix=pre, prefix_sep=sep,
837848
dummy_na=dummy_na, sparse=sparse,
838-
drop_first=drop_first)
849+
drop_first=drop_first, dtype=dtype)
839850
with_dummies.append(dummy)
840851
result = concat(with_dummies, axis=1)
841852
else:
842853
result = _get_dummies_1d(data, prefix, prefix_sep, dummy_na,
843-
sparse=sparse, drop_first=drop_first)
854+
sparse=sparse, drop_first=drop_first, dtype=dtype)
844855
return result
845856

846857

847858
def _get_dummies_1d(data, prefix, prefix_sep='_', dummy_na=False,
848-
sparse=False, drop_first=False):
859+
sparse=False, drop_first=False, dtype=np.uint8):
849860
# Series avoids inconsistent NaN handling
850861
codes, levels = _factorize_from_iterable(Series(data))
851862

@@ -903,18 +914,18 @@ def get_empty_Frame(data, sparse):
903914
sp_indices = sp_indices[1:]
904915
dummy_cols = dummy_cols[1:]
905916
for col, ixs in zip(dummy_cols, sp_indices):
906-
sarr = SparseArray(np.ones(len(ixs), dtype=np.uint8),
917+
sarr = SparseArray(np.ones(len(ixs), dtype=dtype),
907918
sparse_index=IntIndex(N, ixs), fill_value=0,
908-
dtype=np.uint8)
919+
dtype=dtype)
909920
sparse_series[col] = SparseSeries(data=sarr, index=index)
910921

911922
out = SparseDataFrame(sparse_series, index=index, columns=dummy_cols,
912923
default_fill_value=0,
913-
dtype=np.uint8)
924+
dtype=dtype)
914925
return out
915926

916927
else:
917-
dummy_mat = np.eye(number_of_cols, dtype=np.uint8).take(codes, axis=0)
928+
dummy_mat = np.eye(number_of_cols, dtype=dtype).take(codes, axis=0)
918929

919930
if not dummy_na:
920931
# reset NaN GH4446

0 commit comments

Comments
 (0)