Skip to content

Commit 67453a5

Browse files
committed
Add get method to DataFrame
1 parent f67e4fa commit 67453a5

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ from typing import (
1313
Any,
1414
ClassVar,
1515
Literal,
16+
TypeVar,
1617
overload,
1718
)
1819

@@ -127,6 +128,8 @@ from pandas.plotting import PlotAccessor
127128
_str = str
128129
_bool = bool
129130

131+
_T = TypeVar("_T")
132+
130133
class _iLocIndexerFrame(_iLocIndexer):
131134
@overload
132135
def __getitem__(self, idx: tuple[int, int]) -> Scalar: ...
@@ -1683,7 +1686,14 @@ class DataFrame(NDFrame, OpsMixin):
16831686
# def from_dict
16841687
# def from_records
16851688
def ge(self, other, axis: Axis = ..., level: Level | None = ...) -> DataFrame: ...
1686-
# def get
1689+
@overload
1690+
def get(self, key: Hashable, default: None = ...) -> Series | None: ...
1691+
@overload
1692+
def get(self, key: Hashable, default: _T = ...) -> Series | _T: ...
1693+
@overload
1694+
def get(self, key: list[Hashable], default: None = ...) -> DataFrame | None: ...
1695+
@overload
1696+
def get(self, key: list[Hashable], default: _T = ...) -> DataFrame | _T: ...
16871697
def gt(self, other, axis: Axis = ..., level: Level | None = ...) -> DataFrame: ...
16881698
def head(self, n: int = ...) -> DataFrame: ...
16891699
def infer_objects(self) -> DataFrame: ...

tests/test_frame.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3107,3 +3107,34 @@ def test_itertuples() -> None:
31073107
for item in df.itertuples():
31083108
check(assert_type(item, _PandasNamedTuple), tuple)
31093109
assert_type(item.a, Scalar)
3110+
3111+
3112+
def test_get() -> None:
3113+
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
3114+
3115+
# Get single column
3116+
check(assert_type(df.get("a"), Union[pd.Series, None]), pd.Series, np.int64)
3117+
check(assert_type(df.get("z"), Union[pd.Series, None]), type(None))
3118+
check(
3119+
assert_type(df.get("a", default=None), Union[pd.Series, None]),
3120+
pd.Series,
3121+
np.int64,
3122+
)
3123+
check(
3124+
assert_type(df.get("a", default=1), Union[pd.Series, int]), pd.Series, np.int64
3125+
)
3126+
check(assert_type(df.get("z", default=1), Union[pd.Series, int]), int)
3127+
3128+
# Get multiple columns
3129+
check(assert_type(df.get(["a"]), Union[pd.DataFrame, None]), pd.DataFrame)
3130+
check(assert_type(df.get(["a", "b"]), Union[pd.DataFrame, None]), pd.DataFrame)
3131+
check(assert_type(df.get(["z"]), Union[pd.DataFrame, None]), type(None))
3132+
check(
3133+
assert_type(df.get(["a", "b"], default=None), Union[pd.DataFrame, None]),
3134+
pd.DataFrame,
3135+
)
3136+
check(
3137+
assert_type(df.get(["a", "b"], default=1), Union[pd.DataFrame, int]),
3138+
pd.DataFrame,
3139+
)
3140+
check(assert_type(df.get(["z"], default=1), Union[pd.DataFrame, int]), int)

0 commit comments

Comments
 (0)