diff --git a/asv_bench/benchmarks/strings.py b/asv_bench/benchmarks/strings.py index 79ea2a4fba284..0f68d1043b49d 100644 --- a/asv_bench/benchmarks/strings.py +++ b/asv_bench/benchmarks/strings.py @@ -249,10 +249,18 @@ def time_rsplit(self, dtype, expand): class Dummies: - def setup(self): - self.s = Series(tm.makeStringIndex(10 ** 5)).str.join("|") + params = ["str", "string", "arrow_string"] + param_names = ["dtype"] + + def setup(self, dtype): + from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401 + + try: + self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype).str.join("|") + except ImportError: + raise NotImplementedError - def time_get_dummies(self): + def time_get_dummies(self, dtype): self.s.str.get_dummies("|") diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 2646ddfa45b58..3b4549b55d1aa 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -24,6 +24,7 @@ is_categorical_dtype, is_integer, is_list_like, + is_object_dtype, is_re, ) from pandas.core.dtypes.generic import ( @@ -265,7 +266,11 @@ def _wrap_result( # infer from ndim if expand is not specified expand = result.ndim != 1 - elif expand is True and not isinstance(self._orig, ABCIndex): + elif ( + expand is True + and is_object_dtype(result) + and not isinstance(self._orig, ABCIndex) + ): # required when expand=True is explicitly specified # not needed when inferred diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index 5d8a63fe481f8..86c90398d0259 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -301,17 +301,19 @@ def test_isnumeric(any_string_dtype): tm.assert_series_equal(s.str.isdecimal(), Series(decimal_e, dtype=dtype)) -def test_get_dummies(): - s = Series(["a|b", "a|c", np.nan]) +def test_get_dummies(any_string_dtype): + s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype) result = s.str.get_dummies("|") expected = DataFrame([[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc")) tm.assert_frame_equal(result, expected) - s = Series(["a;b", "a", 7]) + s = Series(["a;b", "a", 7], dtype=any_string_dtype) result = s.str.get_dummies(";") expected = DataFrame([[0, 1, 1], [0, 1, 0], [1, 0, 0]], columns=list("7ab")) tm.assert_frame_equal(result, expected) + +def test_get_dummies_index(): # GH9980, GH8028 idx = Index(["a|b", "a|c", "b|c"]) result = idx.str.get_dummies("|") @@ -322,14 +324,18 @@ def test_get_dummies(): tm.assert_index_equal(result, expected) -def test_get_dummies_with_name_dummy(): +def test_get_dummies_with_name_dummy(any_string_dtype): # GH 12180 # Dummies named 'name' should work as expected - s = Series(["a", "b,name", "b"]) + s = Series(["a", "b,name", "b"], dtype=any_string_dtype) result = s.str.get_dummies(",") expected = DataFrame([[1, 0, 0], [0, 1, 1], [0, 1, 0]], columns=["a", "b", "name"]) tm.assert_frame_equal(result, expected) + +def test_get_dummies_with_name_dummy_index(): + # GH 12180 + # Dummies named 'name' should work as expected idx = Index(["a|b", "name|c", "b|name"]) result = idx.str.get_dummies("|")