Skip to content

Commit 903d6d7

Browse files
authored
make Series.apply() return Series or DataFrame based on callable (#343)
* make Series.apply() return Series or DataFrame based on callable * allow Hashable as result of Callable to return Series
1 parent f7b2207 commit 903d6d7

File tree

2 files changed

+41
-6
lines changed

2 files changed

+41
-6
lines changed

pandas-stubs/core/series.pyi

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,9 +663,22 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
663663
*args,
664664
**kwargs,
665665
) -> DataFrame: ...
666+
@overload
666667
def apply(
667-
self, func: Callable, convertDType: _bool = ..., args: tuple = ..., **kwds
668-
) -> Series | DataFrame: ...
668+
self,
669+
func: Callable[..., Hashable],
670+
convertDType: _bool = ...,
671+
args: tuple = ...,
672+
**kwds,
673+
) -> Series: ...
674+
@overload
675+
def apply(
676+
self,
677+
func: Callable[..., Series],
678+
convertDType: _bool = ...,
679+
args: tuple = ...,
680+
**kwds,
681+
) -> DataFrame: ...
669682
def align(
670683
self,
671684
other: DataFrame | Series,

tests/test_series.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,32 @@ def test_types_unique() -> None:
374374

375375

376376
def test_types_apply() -> None:
377-
s = pd.Series([-10, 2, 2, 3, 10, 10])
378-
s.apply(lambda x: x**2)
379-
s.apply(np.exp)
380-
s.apply(str)
377+
s = pd.Series([-10, 2, 2, 3.4, 10, 10])
378+
379+
def square(x: float) -> float:
380+
return x**2
381+
382+
check(assert_type(s.apply(square), pd.Series), pd.Series, float)
383+
check(assert_type(s.apply(np.exp), pd.Series), pd.Series, float)
384+
check(assert_type(s.apply(str), pd.Series), pd.Series, str)
385+
386+
def makeseries(x: float) -> pd.Series:
387+
return pd.Series([x, 2 * x])
388+
389+
check(assert_type(s.apply(makeseries), pd.DataFrame), pd.DataFrame)
390+
391+
# GH 293
392+
393+
def retseries(x: float) -> float:
394+
return x
395+
396+
check(assert_type(s.apply(retseries).tolist(), list), list)
397+
398+
def get_depth(url: str) -> int:
399+
return len(url)
400+
401+
ss = s.astype(str)
402+
check(assert_type(ss.apply(get_depth), pd.Series), pd.Series, int)
381403

382404

383405
def test_types_element_wise_arithmetic() -> None:

0 commit comments

Comments
 (0)