From a848787ddb014bb645ded59cc72fbd4a62196b06 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Tue, 16 Aug 2022 18:24:08 -0400 Subject: [PATCH 1/2] fix Series.split with expand=True --- pandas-stubs/core/indexes/base.pyi | 3 ++- pandas-stubs/core/series.pyi | 2 +- pandas-stubs/core/strings.pyi | 22 ++++++++++++++++++++-- tests/test_indexes.py | 7 +++++++ tests/test_series.py | 2 ++ 5 files changed, 32 insertions(+), 4 deletions(-) diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 6f9b6deda..5db495ea6 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -12,6 +12,7 @@ from typing import ( import numpy as np from pandas import ( DataFrame, + MultiIndex, Series, ) from pandas.core.arrays import ExtensionArray @@ -58,7 +59,7 @@ class Index(IndexOpsMixin, PandasObject): tupleize_cols: bool = ..., ): ... @property - def str(self) -> StringMethods[Index]: ... + def str(self) -> StringMethods[Index, MultiIndex]: ... @property def asi8(self) -> np_ndarray_int64: ... def is_(self, other) -> bool: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index e15cffc7c..fc49615ba 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -854,7 +854,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): ) -> Series[S1]: ... def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> DataFrame: ... @property - def str(self) -> StringMethods[Series]: ... + def str(self) -> StringMethods[Series, DataFrame]: ... @property def dt(self) -> CombinedDatetimelikeProperties: ... @property diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index a853b8538..b194726a4 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -5,17 +5,25 @@ from typing import ( Generic, Literal, Sequence, + TypeVar, overload, ) import numpy as np import pandas as pd -from pandas import Series +from pandas import ( + DataFrame, + MultiIndex, + Series, +) from pandas.core.base import NoNewAttributesMixin from pandas._typing import T -class StringMethods(NoNewAttributesMixin, Generic[T]): +# The _TS type is what is used for the result of str.split with expand=True +_TS = TypeVar("_TS", DataFrame, MultiIndex) + +class StringMethods(NoNewAttributesMixin, Generic[T, _TS]): def __init__(self, data: T) -> None: ... def __getitem__(self, key: slice | int) -> T: ... def __iter__(self) -> T: ... @@ -44,9 +52,19 @@ class StringMethods(NoNewAttributesMixin, Generic[T]): na_rep: str | None = ..., join: Literal["left", "right", "outer", "inner"] = ..., ) -> T: ... + @overload + def split( + self, pat: str = ..., n: int = ..., *, expand: Literal[True], regex: bool = ... + ) -> _TS: ... + @overload def split( self, pat: str = ..., n: int = ..., expand: bool = ..., *, regex: bool = ... ) -> T: ... + @overload + def rsplit( + self, pat: str = ..., n: int = ..., *, expand: Literal[True], regex: bool = ... + ) -> T: ... + @overload def rsplit( self, pat: str = ..., n: int = ..., expand: bool = ..., *, regex: bool = ... ) -> T: ... diff --git a/tests/test_indexes.py b/tests/test_indexes.py index e0262d911..eb3797171 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -68,3 +68,10 @@ def test_difference_none() -> None: # https://github.com/pandas-dev/pandas-stubs/issues/17 ind = pd.Index([1, 2, 3]) check(assert_type(ind.difference([1, None]), "pd.Index"), pd.Index, int) + + +def test_str_split() -> None: + # GH 194 + ind = pd.Index(["a-b", "c-d"]) + check(assert_type(ind.str.split("-"), pd.Index), pd.Index) + check(assert_type(ind.str.split("-", expand=True), pd.MultiIndex), pd.MultiIndex) diff --git a/tests/test_series.py b/tests/test_series.py index 9ad01b00e..79e8bd192 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -977,6 +977,8 @@ def test_string_accessors(): check(assert_type(s.str.slice(0, 4, 2), pd.Series), pd.Series) check(assert_type(s.str.slice_replace(0, 2, "XX"), pd.Series), pd.Series) check(assert_type(s.str.split("a"), pd.Series), pd.Series) + # GH 194 + check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame) check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, bool) check(assert_type(s.str.strip(), pd.Series), pd.Series) check(assert_type(s.str.swapcase(), pd.Series), pd.Series) From ffbd5afb778c7abce38e84dd8e92c163849d4ba9 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Tue, 16 Aug 2022 20:13:22 -0400 Subject: [PATCH 2/2] align asterisk in split params --- pandas-stubs/core/strings.pyi | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index b194726a4..e8ead0ec4 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -58,7 +58,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS]): ) -> _TS: ... @overload def split( - self, pat: str = ..., n: int = ..., expand: bool = ..., *, regex: bool = ... + self, pat: str = ..., n: int = ..., *, expand: bool = ..., regex: bool = ... ) -> T: ... @overload def rsplit( @@ -66,7 +66,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS]): ) -> T: ... @overload def rsplit( - self, pat: str = ..., n: int = ..., expand: bool = ..., *, regex: bool = ... + self, pat: str = ..., n: int = ..., *, expand: bool = ..., regex: bool = ... ) -> T: ... @overload def partition(self, sep: str = ...) -> pd.DataFrame: ...