diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index c8e2cdff7..dc88eb28f 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -268,32 +268,61 @@ class DataFrame(NDFrame, OpsMixin): @overload def to_dict( self, - orient: Literal["dict", "list", "series", "split", "tight", "index"], + orient: Literal["records"], into: Mapping | type[Mapping], + index: Literal[True] = ..., + ) -> list[Mapping[Hashable, Any]]: ... + @overload + def to_dict( + self, + orient: Literal["records"], + into: None = ..., + index: Literal[True] = ..., + ) -> list[dict[Hashable, Any]]: ... + @overload + def to_dict( + self, + orient: Literal["dict", "list", "series", "index"], + into: Mapping | type[Mapping], + index: Literal[True] = ..., ) -> Mapping[Hashable, Any]: ... @overload def to_dict( self, - orient: Literal["dict", "list", "series", "split", "tight", "index"] = ..., - *, + orient: Literal["split", "tight"], into: Mapping | type[Mapping], + index: bool = ..., ) -> Mapping[Hashable, Any]: ... @overload def to_dict( self, - orient: Literal["dict", "list", "series", "split", "tight", "index"] = ..., - into: None = ..., - ) -> dict[Hashable, Any]: ... + orient: Literal["dict", "list", "series", "index"] = ..., + *, + into: Mapping | type[Mapping], + index: Literal[True] = ..., + ) -> Mapping[Hashable, Any]: ... @overload def to_dict( self, - orient: Literal["records"], + orient: Literal["split", "tight"] = ..., + *, into: Mapping | type[Mapping], - ) -> list[Mapping[Hashable, Any]]: ... + index: bool = ..., + ) -> Mapping[Hashable, Any]: ... @overload def to_dict( - self, orient: Literal["records"], into: None = ... - ) -> list[dict[Hashable, Any]]: ... + self, + orient: Literal["dict", "list", "series", "index"] = ..., + into: None = ..., + index: Literal[True] = ..., + ) -> dict[Hashable, Any]: ... + @overload + def to_dict( + self, + orient: Literal["split", "tight"] = ..., + into: None = ..., + index: bool = ..., + ) -> dict[Hashable, Any]: ... def to_gbq( self, destination_table: str, @@ -1400,8 +1429,8 @@ class DataFrame(NDFrame, OpsMixin): level: Level | None = ..., fill_value: float | None = ..., ) -> DataFrame: ... - def add_prefix(self, prefix: _str) -> DataFrame: ... - def add_suffix(self, suffix: _str) -> DataFrame: ... + def add_prefix(self, prefix: _str, axis: Axis | None = None) -> DataFrame: ... + def add_suffix(self, suffix: _str, axis: Axis | None = None) -> DataFrame: ... @overload def all( self, diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index c87bb2af1..227458764 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1024,8 +1024,8 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): def pop(self, item: Hashable) -> S1: ... def squeeze(self, axis: AxisIndex | None = ...) -> Scalar: ... def __abs__(self) -> Series[S1]: ... - def add_prefix(self, prefix: _str) -> Series[S1]: ... - def add_suffix(self, suffix: _str) -> Series[S1]: ... + def add_prefix(self, prefix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ... + def add_suffix(self, suffix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ... def reindex( self, index: Axes | None = ..., diff --git a/tests/test_frame.py b/tests/test_frame.py index 03d8299c4..893eb46f0 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2526,3 +2526,47 @@ def test_loc_returns_series() -> None: df1 = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40]) df2 = df1.loc[10, :] check(assert_type(df2, Union[pd.Series, pd.DataFrame]), pd.Series) + + +def test_to_dict_index() -> None: + df = pd.DataFrame({"a": [1, 2], "b": [9, 10]}) + check( + assert_type( + df.to_dict(orient="records", index=True), List[Dict[Hashable, Any]] + ), + list, + ) + check(assert_type(df.to_dict(orient="dict", index=True), Dict[Hashable, Any]), dict) + check( + assert_type(df.to_dict(orient="series", index=True), Dict[Hashable, Any]), dict + ) + check( + assert_type(df.to_dict(orient="index", index=True), Dict[Hashable, Any]), dict + ) + check( + assert_type(df.to_dict(orient="split", index=True), Dict[Hashable, Any]), dict + ) + check( + assert_type(df.to_dict(orient="tight", index=True), Dict[Hashable, Any]), dict + ) + check( + assert_type(df.to_dict(orient="tight", index=False), Dict[Hashable, Any]), dict + ) + check( + assert_type(df.to_dict(orient="split", index=False), Dict[Hashable, Any]), dict + ) + if TYPE_CHECKING_INVALID_USAGE: + check(assert_type(df.to_dict(orient="records", index=False), List[Dict[Hashable, Any]]), list) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues] + check(assert_type(df.to_dict(orient="dict", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues] + check(assert_type(df.to_dict(orient="series", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues] + check(assert_type(df.to_dict(orient="index", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues] + + +def test_suffix_prefix_index() -> None: + df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]}) + check(assert_type(df.add_suffix("_col", axis=1), pd.DataFrame), pd.DataFrame) + check(assert_type(df.add_suffix("_col", axis="index"), pd.DataFrame), pd.DataFrame) + check(assert_type(df.add_prefix("_col", axis="index"), pd.DataFrame), pd.DataFrame) + check( + assert_type(df.add_prefix("_col", axis="columns"), pd.DataFrame), pd.DataFrame + ) diff --git a/tests/test_series.py b/tests/test_series.py index 543aac5b1..8435c134b 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1813,3 +1813,15 @@ def test_types_apply_set() -> None: {"list1": [1, 2, 3], "list2": ["a", "b", "c"], "list3": [True, False, True]} ) check(assert_type(series_of_lists.apply(lambda x: set(x)), pd.Series), pd.Series) + + +def test_prefix_summix_axis() -> None: + s = pd.Series([1, 2, 3, 4]) + check(assert_type(s.add_suffix("_item", axis=0), pd.Series), pd.Series) + check(assert_type(s.add_suffix("_item", axis="index"), pd.Series), pd.Series) + check(assert_type(s.add_prefix("_item", axis=0), pd.Series), pd.Series) + check(assert_type(s.add_prefix("_item", axis="index"), pd.Series), pd.Series) + + if TYPE_CHECKING_INVALID_USAGE: + check(assert_type(s.add_prefix("_item", axis=1), pd.Series), pd.Series) # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] + check(assert_type(s.add_suffix("_item", axis="columns"), pd.Series), pd.Series) # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]