diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 5f7354578..89d839f9e 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -2106,3 +2106,10 @@ class DataFrame(NDFrame, OpsMixin): ) -> DataFrame: ... # Move from generic because Series is Generic and it returns Series[bool] there def __invert__(self) -> DataFrame: ... + def xs( + self, + key: Hashable, + axis: Axis = ..., + level: Level | None = ..., + drop_level: _bool = ..., + ) -> DataFrame | Series: ... diff --git a/pandas-stubs/core/generic.pyi b/pandas-stubs/core/generic.pyi index 0cbeb7f54..984a65384 100644 --- a/pandas-stubs/core/generic.pyi +++ b/pandas-stubs/core/generic.pyi @@ -15,10 +15,7 @@ from typing import ( ) import numpy as np -from pandas import ( - DataFrame, - Index, -) +from pandas import Index from pandas.core.base import PandasObject import pandas.core.indexing as indexing import sqlalchemy.engine @@ -289,13 +286,6 @@ class NDFrame(PandasObject, indexing.IndexingMixin): def take( self, indices, axis=..., is_copy: _bool | None = ..., **kwargs ) -> NDFrame: ... - def xs( - self, - key: Hashable, - axis: AxisIndex = ..., - level: Level | None = ..., - drop_level: _bool = ..., - ) -> DataFrame | Series: ... def __delitem__(self, idx: Hashable): ... def get(self, key: object, default: Dtype | None = ...) -> Dtype: ... def reindex_like( diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index a514a7195..d95807817 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1950,6 +1950,13 @@ class TimedeltaSeries(Series[Timedelta]): numeric_only: _bool = ..., **kwargs, ) -> Timedelta: ... + def xs( + self, + key: Hashable, + axis: AxisIndex = ..., + level: Level | None = ..., + drop_level: _bool = ..., + ) -> Series: ... class PeriodSeries(Series[Period]): # ignore needed because of mypy diff --git a/tests/test_frame.py b/tests/test_frame.py index d2f7dd030..c0dc5c7f4 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2467,3 +2467,19 @@ def test_astype() -> None: states = pd.DataFrame({"population": population, "area": area}) check(assert_type(states.astype(object), pd.DataFrame), pd.DataFrame, object) + + +def test_xs_frame_new() -> None: + d = { + "num_legs": [4, 4, 2, 2], + "num_wings": [0, 0, 2, 2], + "class": ["mammal", "mammal", "mammal", "bird"], + "animal": ["cat", "dog", "bat", "penguin"], + "locomotion": ["walks", "walks", "flies", "walks"], + } + df = pd.DataFrame(data=d) + df = df.set_index(["class", "animal", "locomotion"]) + s1 = df.xs("mammal", axis=0) + s2 = df.xs("num_wings", axis=1) + check(assert_type(s1, Union[pd.Series, pd.DataFrame]), pd.DataFrame) + check(assert_type(s2, Union[pd.Series, pd.DataFrame]), pd.Series) diff --git a/tests/test_series.py b/tests/test_series.py index 373431c7f..172e389d4 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1748,3 +1748,9 @@ def test_updated_astype() -> None: pd.Series, np.integer, ) + + +def test_check_xs() -> None: + s4 = pd.Series([1, 4]) + s4.xs(0, axis=0) + check(assert_type(s4, pd.Series), pd.Series)