Skip to content

ENH: Add arrow tests for get_dummies #50951

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

import pandas as pd
from pandas import (
ArrowDtype,
Categorical,
CategoricalIndex,
DataFrame,
Expand Down Expand Up @@ -198,10 +199,16 @@
UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES
ALL_INT_PYARROW_DTYPES_STR_REPR = [
str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES
]

# pa.float16 doesn't seem supported
# https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
FLOAT_PYARROW_DTYPES_STR_REPR = [
str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
]
STRING_PYARROW_DTYPES = [pa.string()]
BINARY_PYARROW_DTYPES = [pa.binary()]

Expand Down Expand Up @@ -234,6 +241,9 @@
+ TIMEDELTA_PYARROW_DTYPES
+ BOOL_PYARROW_DTYPES
)
else:
FLOAT_PYARROW_DTYPES_STR_REPR = []
ALL_INT_PYARROW_DTYPES_STR_REPR = []


EMPTY_STRING_PATTERN = re.compile("^$")
Expand Down
37 changes: 37 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,43 @@ def any_numeric_ea_dtype(request):
return request.param


# Unsupported operand types for + ("List[Union[str, ExtensionDtype, dtype[Any],
# Type[object]]]" and "List[str]")
@pytest.fixture(
params=tm.ALL_INT_EA_DTYPES
+ tm.FLOAT_EA_DTYPES
+ tm.ALL_INT_PYARROW_DTYPES_STR_REPR
+ tm.FLOAT_PYARROW_DTYPES_STR_REPR # type: ignore[operator]
)
def any_numeric_ea_and_arrow_dtype(request):
"""
Parameterized fixture for any nullable integer dtype, any float ea dtype,
and the string representations of the pyarrow-backed integer and float
dtypes.

The pyarrow entries come from ``tm.ALL_INT_PYARROW_DTYPES_STR_REPR`` and
``tm.FLOAT_PYARROW_DTYPES_STR_REPR``; if those lists are empty, only the
nullable ea dtypes are parametrized.

* 'UInt8'
* 'Int8'
* 'UInt16'
* 'Int16'
* 'UInt32'
* 'Int32'
* 'UInt64'
* 'Int64'
* 'Float32'
* 'Float64'
* 'uint8[pyarrow]'
* 'int8[pyarrow]'
* 'uint16[pyarrow]'
* 'int16[pyarrow]'
* 'uint32[pyarrow]'
* 'int32[pyarrow]'
* 'uint64[pyarrow]'
* 'int64[pyarrow]'
* 'float32[pyarrow]'
* 'float64[pyarrow]'
"""
return request.param


@pytest.fixture(params=tm.SIGNED_INT_EA_DTYPES)
def any_signed_int_ea_dtype(request):
"""
Expand Down
12 changes: 6 additions & 6 deletions pandas/tests/reshape/test_get_dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,22 +658,22 @@ def test_get_dummies_with_string_values(self, values):
with pytest.raises(TypeError, match=msg):
get_dummies(df, columns=values)

def test_get_dummies_ea_dtype_series(self, any_numeric_ea_dtype):
def test_get_dummies_ea_dtype_series(self, any_numeric_ea_and_arrow_dtype):
# GH#32430
ser = Series(list("abca"))
result = get_dummies(ser, dtype=any_numeric_ea_dtype)
result = get_dummies(ser, dtype=any_numeric_ea_and_arrow_dtype)
expected = DataFrame(
{"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]},
dtype=any_numeric_ea_dtype,
dtype=any_numeric_ea_and_arrow_dtype,
)
tm.assert_frame_equal(result, expected)

def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_dtype):
def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype):
# GH#32430
df = DataFrame({"x": list("abca")})
result = get_dummies(df, dtype=any_numeric_ea_dtype)
result = get_dummies(df, dtype=any_numeric_ea_and_arrow_dtype)
expected = DataFrame(
{"x_a": [1, 0, 0, 1], "x_b": [0, 1, 0, 0], "x_c": [0, 0, 1, 0]},
dtype=any_numeric_ea_dtype,
dtype=any_numeric_ea_and_arrow_dtype,
)
tm.assert_frame_equal(result, expected)