Skip to content

Commit d2e3f74

Browse files
authored
gh:524 - xs method added to series and frame.pyi (#567)
* xs method added to series and frame.pyi * Update test_frame.py * Update test_frame.py * Update test_frame.py
1 parent 927a2d0 commit d2e3f74

File tree

5 files changed

+37
-11
lines changed

5 files changed

+37
-11
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2106,3 +2106,10 @@ class DataFrame(NDFrame, OpsMixin):
21062106
) -> DataFrame: ...
21072107
# Move from generic because Series is Generic and it returns Series[bool] there
21082108
def __invert__(self) -> DataFrame: ...
2109+
def xs(
2110+
self,
2111+
key: Hashable,
2112+
axis: Axis = ...,
2113+
level: Level | None = ...,
2114+
drop_level: _bool = ...,
2115+
) -> DataFrame | Series: ...

pandas-stubs/core/generic.pyi

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@ from typing import (
1515
)
1616

1717
import numpy as np
18-
from pandas import (
19-
DataFrame,
20-
Index,
21-
)
18+
from pandas import Index
2219
from pandas.core.base import PandasObject
2320
import pandas.core.indexing as indexing
2421
import sqlalchemy.engine
@@ -289,13 +286,6 @@ class NDFrame(PandasObject, indexing.IndexingMixin):
289286
def take(
290287
self, indices, axis=..., is_copy: _bool | None = ..., **kwargs
291288
) -> NDFrame: ...
292-
def xs(
293-
self,
294-
key: Hashable,
295-
axis: AxisIndex = ...,
296-
level: Level | None = ...,
297-
drop_level: _bool = ...,
298-
) -> DataFrame | Series: ...
299289
def __delitem__(self, idx: Hashable): ...
300290
def get(self, key: object, default: Dtype | None = ...) -> Dtype: ...
301291
def reindex_like(

pandas-stubs/core/series.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1950,6 +1950,13 @@ class TimedeltaSeries(Series[Timedelta]):
19501950
numeric_only: _bool = ...,
19511951
**kwargs,
19521952
) -> Timedelta: ...
1953+
def xs(
1954+
self,
1955+
key: Hashable,
1956+
axis: AxisIndex = ...,
1957+
level: Level | None = ...,
1958+
drop_level: _bool = ...,
1959+
) -> Series: ...
19531960

19541961
class PeriodSeries(Series[Period]):
19551962
# ignore needed because of mypy

tests/test_frame.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2467,3 +2467,19 @@ def test_astype() -> None:
24672467

24682468
states = pd.DataFrame({"population": population, "area": area})
24692469
check(assert_type(states.astype(object), pd.DataFrame), pd.DataFrame, object)
2470+
2471+
2472+
def test_xs_frame_new() -> None:
2473+
d = {
2474+
"num_legs": [4, 4, 2, 2],
2475+
"num_wings": [0, 0, 2, 2],
2476+
"class": ["mammal", "mammal", "mammal", "bird"],
2477+
"animal": ["cat", "dog", "bat", "penguin"],
2478+
"locomotion": ["walks", "walks", "flies", "walks"],
2479+
}
2480+
df = pd.DataFrame(data=d)
2481+
df = df.set_index(["class", "animal", "locomotion"])
2482+
s1 = df.xs("mammal", axis=0)
2483+
s2 = df.xs("num_wings", axis=1)
2484+
check(assert_type(s1, Union[pd.Series, pd.DataFrame]), pd.DataFrame)
2485+
check(assert_type(s2, Union[pd.Series, pd.DataFrame]), pd.Series)

tests/test_series.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,3 +1748,9 @@ def test_updated_astype() -> None:
17481748
pd.Series,
17491749
np.integer,
17501750
)
1751+
1752+
1753+
def test_check_xs() -> None:
1754+
s4 = pd.Series([1, 4])
1755+
s4.xs(0, axis=0)
1756+
check(assert_type(s4, pd.Series), pd.Series)

0 commit comments

Comments
 (0)