@@ -697,7 +697,7 @@ def _convert_level_number(level_num, columns):
697
697
698
698
699
699
def get_dummies (data , prefix = None , prefix_sep = '_' , dummy_na = False ,
700
- columns = None , sparse = False , drop_first = False ):
700
+ columns = None , sparse = False , drop_first = False , dtype = None ):
701
701
"""
702
702
Convert categorical variable into dummy/indicator variables
703
703
@@ -725,6 +725,8 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
725
725
drop_first : bool, default False
726
726
Whether to get k-1 dummies out of k categorical levels by removing the
727
727
first level.
728
+ dtype : dtype, default np.uint8
729
+ Data type to force on a new columns. Only a single dtype is allowed.
728
730
729
731
.. versionadded:: 0.18.0
730
732
@@ -783,13 +785,22 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
783
785
3 0 0
784
786
4 0 0
785
787
788
+ >>> pd.get_dummies(pd.Series(list('abc')), dtype=float)
789
+ a b c
790
+ 0 1.0 0.0 0.0
791
+ 1 0.0 1.0 0.0
792
+ 2 0.0 0.0 1.0
793
+
786
794
See Also
787
795
--------
788
796
Series.str.get_dummies
789
797
"""
790
798
from pandas .core .reshape .concat import concat
791
799
from itertools import cycle
792
800
801
+ if dtype is None :
802
+ dtype = np .uint8
803
+
793
804
if isinstance (data , DataFrame ):
794
805
# determine columns being encoded
795
806
@@ -835,17 +846,17 @@ def check_len(item, name):
835
846
836
847
dummy = _get_dummies_1d (data [col ], prefix = pre , prefix_sep = sep ,
837
848
dummy_na = dummy_na , sparse = sparse ,
838
- drop_first = drop_first )
849
+ drop_first = drop_first , dtype = dtype )
839
850
with_dummies .append (dummy )
840
851
result = concat (with_dummies , axis = 1 )
841
852
else :
842
853
result = _get_dummies_1d (data , prefix , prefix_sep , dummy_na ,
843
- sparse = sparse , drop_first = drop_first )
854
+ sparse = sparse , drop_first = drop_first , dtype = dtype )
844
855
return result
845
856
846
857
847
858
def _get_dummies_1d (data , prefix , prefix_sep = '_' , dummy_na = False ,
848
- sparse = False , drop_first = False ):
859
+ sparse = False , drop_first = False , dtype = np . uint8 ):
849
860
# Series avoids inconsistent NaN handling
850
861
codes , levels = _factorize_from_iterable (Series (data ))
851
862
@@ -903,18 +914,18 @@ def get_empty_Frame(data, sparse):
903
914
sp_indices = sp_indices [1 :]
904
915
dummy_cols = dummy_cols [1 :]
905
916
for col , ixs in zip (dummy_cols , sp_indices ):
906
- sarr = SparseArray (np .ones (len (ixs ), dtype = np . uint8 ),
917
+ sarr = SparseArray (np .ones (len (ixs ), dtype = dtype ),
907
918
sparse_index = IntIndex (N , ixs ), fill_value = 0 ,
908
- dtype = np . uint8 )
919
+ dtype = dtype )
909
920
sparse_series [col ] = SparseSeries (data = sarr , index = index )
910
921
911
922
out = SparseDataFrame (sparse_series , index = index , columns = dummy_cols ,
912
923
default_fill_value = 0 ,
913
- dtype = np . uint8 )
924
+ dtype = dtype )
914
925
return out
915
926
916
927
else :
917
- dummy_mat = np .eye (number_of_cols , dtype = np . uint8 ).take (codes , axis = 0 )
928
+ dummy_mat = np .eye (number_of_cols , dtype = dtype ).take (codes , axis = 0 )
918
929
919
930
if not dummy_na :
920
931
# reset NaN GH4446
0 commit comments