diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py
index eb25566e7983e..045254d2041fc 100644
--- a/pandas/_testing/__init__.py
+++ b/pandas/_testing/__init__.py
@@ -42,6 +42,7 @@
 
 import pandas as pd
 from pandas import (
+    ArrowDtype,
     Categorical,
     CategoricalIndex,
     DataFrame,
@@ -198,10 +199,16 @@
     UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
     SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
     ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES
+    ALL_INT_PYARROW_DTYPES_STR_REPR = [
+        str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES
+    ]
 
     # pa.float16 doesn't seem supported
     # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
     FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
+    FLOAT_PYARROW_DTYPES_STR_REPR = [
+        str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
+    ]
     STRING_PYARROW_DTYPES = [pa.string()]
     BINARY_PYARROW_DTYPES = [pa.binary()]
 
@@ -234,6 +241,9 @@
         + TIMEDELTA_PYARROW_DTYPES
         + BOOL_PYARROW_DTYPES
     )
+else:
+    FLOAT_PYARROW_DTYPES_STR_REPR = []
+    ALL_INT_PYARROW_DTYPES_STR_REPR = []
 
 
 EMPTY_STRING_PATTERN = re.compile("^$")
diff --git a/pandas/conftest.py b/pandas/conftest.py
index 2e9638036eec5..b49dfeb92e2af 100644
--- a/pandas/conftest.py
+++ b/pandas/conftest.py
@@ -1527,6 +1527,43 @@ def any_numeric_ea_dtype(request):
     return request.param
 
 
+# Unsupported operand types for + ("List[Union[str, ExtensionDtype, dtype[Any],
+# Type[object]]]" and "List[str]")
+@pytest.fixture(
+    params=tm.ALL_INT_EA_DTYPES
+    + tm.FLOAT_EA_DTYPES
+    + tm.ALL_INT_PYARROW_DTYPES_STR_REPR
+    + tm.FLOAT_PYARROW_DTYPES_STR_REPR  # type: ignore[operator]
+)
+def any_numeric_ea_and_arrow_dtype(request):
+    """
+    Parameterized fixture for any nullable integer and float EA dtypes,
+    plus the string names of the pyarrow-backed integer and float dtypes.
+
+    * 'UInt8'
+    * 'Int8'
+    * 'UInt16'
+    * 'Int16'
+    * 'UInt32'
+    * 'Int32'
+    * 'UInt64'
+    * 'Int64'
+    * 'Float32'
+    * 'Float64'
+    * 'uint8[pyarrow]'
+    * 'int8[pyarrow]'
+    * 'uint16[pyarrow]'
+    * 'int16[pyarrow]'
+    * 'uint32[pyarrow]'
+    * 'int32[pyarrow]'
+    * 'uint64[pyarrow]'
+    * 'int64[pyarrow]'
+    * 'float32[pyarrow]'
+    * 'float64[pyarrow]'
+    """
+    return request.param
+
+
 @pytest.fixture(params=tm.SIGNED_INT_EA_DTYPES)
 def any_signed_int_ea_dtype(request):
     """
diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py
index ed4da9562aeee..daac5a0c9dac2 100644
--- a/pandas/tests/reshape/test_get_dummies.py
+++ b/pandas/tests/reshape/test_get_dummies.py
@@ -658,22 +658,22 @@ def test_get_dummies_with_string_values(self, values):
         with pytest.raises(TypeError, match=msg):
             get_dummies(df, columns=values)
 
-    def test_get_dummies_ea_dtype_series(self, any_numeric_ea_dtype):
+    def test_get_dummies_ea_dtype_series(self, any_numeric_ea_and_arrow_dtype):
         # GH#32430
         ser = Series(list("abca"))
-        result = get_dummies(ser, dtype=any_numeric_ea_dtype)
+        result = get_dummies(ser, dtype=any_numeric_ea_and_arrow_dtype)
         expected = DataFrame(
             {"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]},
-            dtype=any_numeric_ea_dtype,
+            dtype=any_numeric_ea_and_arrow_dtype,
         )
         tm.assert_frame_equal(result, expected)
 
-    def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_dtype):
+    def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype):
         # GH#32430
         df = DataFrame({"x": list("abca")})
-        result = get_dummies(df, dtype=any_numeric_ea_dtype)
+        result = get_dummies(df, dtype=any_numeric_ea_and_arrow_dtype)
         expected = DataFrame(
             {"x_a": [1, 0, 0, 1], "x_b": [0, 1, 0, 0], "x_c": [0, 0, 1, 0]},
-            dtype=any_numeric_ea_dtype,
+            dtype=any_numeric_ea_and_arrow_dtype,
         )
         tm.assert_frame_equal(result, expected)