Skip to content

Commit 16801a1

Browse files
authored
ENH: Add arrow tests for get_dummies (#50951)
* ENH: Add arrow tests for get_dummies * Fix arrow dep * Fix mypy
1 parent 21399f5 commit 16801a1

File tree

3 files changed

+53
-6
lines changed

3 files changed

+53
-6
lines changed

pandas/_testing/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import pandas as pd
4444
from pandas import (
45+
ArrowDtype,
4546
Categorical,
4647
CategoricalIndex,
4748
DataFrame,
@@ -198,10 +199,16 @@
198199
UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
199200
SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
200201
ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES
202+
ALL_INT_PYARROW_DTYPES_STR_REPR = [
203+
str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES
204+
]
201205

202206
# pa.float16 doesn't seem supported
203207
# https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
204208
FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
209+
FLOAT_PYARROW_DTYPES_STR_REPR = [
210+
str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
211+
]
205212
STRING_PYARROW_DTYPES = [pa.string()]
206213
BINARY_PYARROW_DTYPES = [pa.binary()]
207214

@@ -234,6 +241,9 @@
234241
+ TIMEDELTA_PYARROW_DTYPES
235242
+ BOOL_PYARROW_DTYPES
236243
)
244+
else:
245+
FLOAT_PYARROW_DTYPES_STR_REPR = []
246+
ALL_INT_PYARROW_DTYPES_STR_REPR = []
237247

238248

239249
EMPTY_STRING_PATTERN = re.compile("^$")

pandas/conftest.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1527,6 +1527,43 @@ def any_numeric_ea_dtype(request):
15271527
return request.param
15281528

15291529

1530+
# Unsupported operand types for + ("List[Union[str, ExtensionDtype, dtype[Any],
1531+
# Type[object]]]" and "List[str]")
1532+
@pytest.fixture(
1533+
params=tm.ALL_INT_EA_DTYPES
1534+
+ tm.FLOAT_EA_DTYPES
1535+
+ tm.ALL_INT_PYARROW_DTYPES_STR_REPR
1536+
+ tm.FLOAT_PYARROW_DTYPES_STR_REPR # type: ignore[operator]
1537+
)
1538+
def any_numeric_ea_and_arrow_dtype(request):
1539+
"""
1540+
Parameterized fixture for any nullable integer dtype and
1541+
any float ea dtypes.
1542+
1543+
* 'UInt8'
1544+
* 'Int8'
1545+
* 'UInt16'
1546+
* 'Int16'
1547+
* 'UInt32'
1548+
* 'Int32'
1549+
* 'UInt64'
1550+
* 'Int64'
1551+
* 'Float32'
1552+
* 'Float64'
1553+
* 'uint8[pyarrow]'
1554+
* 'int8[pyarrow]'
1555+
* 'uint16[pyarrow]'
1556+
* 'int16[pyarrow]'
1557+
* 'uint32[pyarrow]'
1558+
* 'int32[pyarrow]'
1559+
* 'uint64[pyarrow]'
1560+
* 'int64[pyarrow]'
1561+
* 'float32[pyarrow]'
1562+
* 'float64[pyarrow]'
1563+
"""
1564+
return request.param
1565+
1566+
15301567
@pytest.fixture(params=tm.SIGNED_INT_EA_DTYPES)
15311568
def any_signed_int_ea_dtype(request):
15321569
"""

pandas/tests/reshape/test_get_dummies.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -658,22 +658,22 @@ def test_get_dummies_with_string_values(self, values):
658658
with pytest.raises(TypeError, match=msg):
659659
get_dummies(df, columns=values)
660660

661-
def test_get_dummies_ea_dtype_series(self, any_numeric_ea_dtype):
661+
def test_get_dummies_ea_dtype_series(self, any_numeric_ea_and_arrow_dtype):
662662
# GH#32430
663663
ser = Series(list("abca"))
664-
result = get_dummies(ser, dtype=any_numeric_ea_dtype)
664+
result = get_dummies(ser, dtype=any_numeric_ea_and_arrow_dtype)
665665
expected = DataFrame(
666666
{"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]},
667-
dtype=any_numeric_ea_dtype,
667+
dtype=any_numeric_ea_and_arrow_dtype,
668668
)
669669
tm.assert_frame_equal(result, expected)
670670

671-
def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_dtype):
671+
def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype):
672672
# GH#32430
673673
df = DataFrame({"x": list("abca")})
674-
result = get_dummies(df, dtype=any_numeric_ea_dtype)
674+
result = get_dummies(df, dtype=any_numeric_ea_and_arrow_dtype)
675675
expected = DataFrame(
676676
{"x_a": [1, 0, 0, 1], "x_b": [0, 1, 0, 0], "x_c": [0, 0, 1, 0]},
677-
dtype=any_numeric_ea_dtype,
677+
dtype=any_numeric_ea_and_arrow_dtype,
678678
)
679679
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)