Skip to content

Commit 27e18c8

Browse files
Change default dtype for get_dummies to bool (#13174)
This PR changes the default dtype for get_dummies to bool from uint8 to match pandas-2.0: pandas-dev/pandas#48022
1 parent b772017 commit 27e18c8

File tree

3 files changed

+40
-49
lines changed

3 files changed

+40
-49
lines changed

python/cudf/cudf/_lib/transform.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def one_hot_encode(Column input_column, Column categories):
163163
move(c_result.second),
164164
owner=owner,
165165
column_names=[
166-
x if x is not None else 'null' for x in pylist_categories
166+
x if x is not None else '<NA>' for x in pylist_categories
167167
]
168168
)
169169
return encodings

python/cudf/cudf/core/reshape.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def get_dummies(
609609
cats=None,
610610
sparse=False,
611611
drop_first=False,
612-
dtype="uint8",
612+
dtype="bool",
613613
):
614614
"""Returns a dataframe whose columns are the one hot encodings of all
615615
columns in `df`
@@ -640,23 +640,23 @@ def get_dummies(
640640
columns. Note this is different from pandas default behavior, which
641641
encodes all columns with dtype object or categorical
642642
dtype : str, optional
643-
Output dtype, default 'uint8'
643+
Output dtype, default 'bool'
644644
645645
Examples
646646
--------
647647
>>> import cudf
648648
>>> df = cudf.DataFrame({"a": ["value1", "value2", None], "b": [0, 0, 0]})
649649
>>> cudf.get_dummies(df)
650650
b a_value1 a_value2
651-
0 0 1 0
652-
1 0 0 1
653-
2 0 0 0
651+
0 0 True False
652+
1 0 False True
653+
2 0 False False
654654
655655
>>> cudf.get_dummies(df, dummy_na=True)
656-
b a_None a_value1 a_value2
657-
0 0 0 1 0
658-
1 0 0 0 1
659-
2 0 1 0 0
656+
b a_<NA> a_value1 a_value2
657+
0 0 False True False
658+
1 0 False False True
659+
2 0 True False False
660660
661661
>>> import numpy as np
662662
>>> df = cudf.DataFrame({"a":cudf.Series([1, 2, np.nan, None],
@@ -669,11 +669,11 @@ def get_dummies(
669669
3 <NA>
670670
671671
>>> cudf.get_dummies(df, dummy_na=True, columns=["a"])
672-
a_1.0 a_2.0 a_nan a_null
673-
0 1 0 0 0
674-
1 0 1 0 0
675-
2 0 0 1 0
676-
3 0 0 0 1
672+
a_<NA> a_1.0 a_2.0 a_nan
673+
0 False True False False
674+
1 False False True False
675+
2 False False False True
676+
3 True False False False
677677
678678
>>> series = cudf.Series([1, 2, None, 2, 4])
679679
>>> series
@@ -684,12 +684,12 @@ def get_dummies(
684684
4 4
685685
dtype: int64
686686
>>> cudf.get_dummies(series, dummy_na=True)
687-
null 1 2 4
688-
0 0 1 0 0
689-
1 0 0 1 0
690-
2 1 0 0 0
691-
3 0 0 1 0
692-
4 0 0 0 1
687+
<NA> 1 2 4
688+
0 False True False False
689+
1 False False True False
690+
2 True False False False
691+
3 False False True False
692+
4 False False False True
693693
"""
694694

695695
if cats is None:

python/cudf/cudf/tests/test_onehot.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018-2022, NVIDIA CORPORATION.
1+
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
22

33
from string import ascii_lowercase
44

@@ -23,19 +23,13 @@
2323
(range(10), [1, 2, 3, 4, 5] * 2),
2424
],
2525
)
26-
def test_get_dummies(data, index):
26+
@pytest.mark.parametrize("dtype", ["bool", "uint8"])
27+
def test_get_dummies(data, index, dtype):
2728
gdf = DataFrame({"x": data}, index=index)
2829
pdf = pd.DataFrame({"x": data}, index=index)
2930

30-
encoded_expected = pd.get_dummies(pdf, prefix="test")
31-
encoded_actual = cudf.get_dummies(gdf, prefix="test")
32-
33-
utils.assert_eq(
34-
encoded_expected,
35-
encoded_actual,
36-
check_dtype=len(data) != 0,
37-
)
38-
encoded_actual = cudf.get_dummies(gdf, prefix="test", dtype=np.uint8)
31+
encoded_expected = pd.get_dummies(pdf, prefix="test", dtype=dtype)
32+
encoded_actual = cudf.get_dummies(gdf, prefix="test", dtype=dtype)
3933

4034
utils.assert_eq(
4135
encoded_expected,
@@ -63,16 +57,13 @@ def test_onehot_get_dummies_multicol(n_cols):
6357
@pytest.mark.parametrize("nan_as_null", [True, False])
6458
@pytest.mark.parametrize("dummy_na", [True, False])
6559
def test_onehost_get_dummies_dummy_na(nan_as_null, dummy_na):
66-
pdf = pd.DataFrame({"a": [0, 1, np.nan]})
67-
df = DataFrame.from_pandas(pdf, nan_as_null=nan_as_null)
60+
df = cudf.DataFrame({"a": [0, 1, np.nan]}, nan_as_null=nan_as_null)
61+
pdf = df.to_pandas(nullable=nan_as_null)
6862

6963
expected = pd.get_dummies(pdf, dummy_na=dummy_na, columns=["a"])
7064
got = cudf.get_dummies(df, dummy_na=dummy_na, columns=["a"])
7165

72-
if dummy_na and nan_as_null:
73-
got = got.rename(columns={"a_null": "a_nan"})[expected.columns]
74-
75-
utils.assert_eq(expected, got)
66+
utils.assert_eq(expected, got, check_like=True)
7667

7768

7869
@pytest.mark.parametrize(
@@ -120,12 +111,12 @@ def test_get_dummies_with_nan():
120111
)
121112
expected = cudf.DataFrame(
122113
{
123-
"a_null": [0, 0, 0, 1],
124-
"a_1.0": [1, 0, 0, 0],
125-
"a_2.0": [0, 1, 0, 0],
126-
"a_nan": [0, 0, 1, 0],
114+
"a_<NA>": [False, False, False, True],
115+
"a_1.0": [True, False, False, False],
116+
"a_2.0": [False, True, False, False],
117+
"a_nan": [False, False, True, False],
127118
},
128-
dtype="uint8",
119+
dtype="bool",
129120
)
130121
actual = cudf.get_dummies(df, dummy_na=True, columns=["a"])
131122

@@ -163,13 +154,13 @@ def test_get_dummies_array_like_with_nan():
163154
ser = cudf.Series([0.1, 2, 3, None, np.nan], nan_as_null=False)
164155
expected = cudf.DataFrame(
165156
{
166-
"a_null": [0, 0, 0, 1, 0],
167-
"a_0.1": [1, 0, 0, 0, 0],
168-
"a_2.0": [0, 1, 0, 0, 0],
169-
"a_3.0": [0, 0, 1, 0, 0],
170-
"a_nan": [0, 0, 0, 0, 1],
157+
"a_<NA>": [False, False, False, True, False],
158+
"a_0.1": [True, False, False, False, False],
159+
"a_2.0": [False, True, False, False, False],
160+
"a_3.0": [False, False, True, False, False],
161+
"a_nan": [False, False, False, False, True],
171162
},
172-
dtype="uint8",
163+
dtype="bool",
173164
)
174165
actual = cudf.get_dummies(ser, dummy_na=True, prefix="a", prefix_sep="_")
175166

0 commit comments

Comments
 (0)