Skip to content

Commit 116832d

Browse files
authored
TYP: Persist typing information for pipe args and kwargs (#56760)
* Type generic pipe with function params * Type common pipe with function params * Type resample pipe with function params * Type groupby pipe with function params * Type style pipe function params and tuple func
1 parent 82449b9 commit 116832d

File tree

6 files changed

+152
-17
lines changed

6 files changed

+152
-17
lines changed

pandas/_typing.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,29 @@
9090
from typing import SupportsIndex
9191

9292
if sys.version_info >= (3, 10):
93+
from typing import Concatenate # pyright: ignore[reportUnusedImport]
94+
from typing import ParamSpec
9395
from typing import TypeGuard # pyright: ignore[reportUnusedImport]
9496
else:
95-
from typing_extensions import TypeGuard # pyright: ignore[reportUnusedImport]
97+
from typing_extensions import ( # pyright: ignore[reportUnusedImport]
98+
Concatenate,
99+
ParamSpec,
100+
TypeGuard,
101+
)
102+
103+
P = ParamSpec("P")
96104

97105
if sys.version_info >= (3, 11):
98106
from typing import Self # pyright: ignore[reportUnusedImport]
99107
else:
100108
from typing_extensions import Self # pyright: ignore[reportUnusedImport]
109+
101110
else:
102111
npt: Any = None
112+
ParamSpec: Any = None
103113
Self: Any = None
104114
TypeGuard: Any = None
115+
Concatenate: Any = None
105116

106117
HashableT = TypeVar("HashableT", bound=Hashable)
107118
MutableMappingT = TypeVar("MutableMappingT", bound=MutableMapping)

pandas/core/common.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
TYPE_CHECKING,
2525
Any,
2626
Callable,
27+
TypeVar,
2728
cast,
2829
overload,
2930
)
@@ -51,7 +52,9 @@
5152
from pandas._typing import (
5253
AnyArrayLike,
5354
ArrayLike,
55+
Concatenate,
5456
NpDtype,
57+
P,
5558
RandomState,
5659
T,
5760
)
@@ -463,8 +466,34 @@ def random_state(state: RandomState | None = None):
463466
)
464467

465468

469+
_T = TypeVar("_T") # Secondary TypeVar for use in pipe's type hints
470+
471+
472+
@overload
473+
def pipe(
474+
obj: _T,
475+
func: Callable[Concatenate[_T, P], T],
476+
*args: P.args,
477+
**kwargs: P.kwargs,
478+
) -> T:
479+
...
480+
481+
482+
@overload
483+
def pipe(
484+
obj: Any,
485+
func: tuple[Callable[..., T], str],
486+
*args: Any,
487+
**kwargs: Any,
488+
) -> T:
489+
...
490+
491+
466492
def pipe(
467-
obj, func: Callable[..., T] | tuple[Callable[..., T], str], *args, **kwargs
493+
obj: _T,
494+
func: Callable[Concatenate[_T, P], T] | tuple[Callable[..., T], str],
495+
*args: Any,
496+
**kwargs: Any,
468497
) -> T:
469498
"""
470499
Apply a function ``func`` to object ``obj`` either by passing obj as the
@@ -490,12 +519,13 @@ def pipe(
490519
object : the return type of ``func``.
491520
"""
492521
if isinstance(func, tuple):
493-
func, target = func
522+
# Assigning to func_ so pyright understands that it's a callable
523+
func_, target = func
494524
if target in kwargs:
495525
msg = f"{target} is both the pipe target and a keyword argument"
496526
raise ValueError(msg)
497527
kwargs[target] = obj
498-
return func(*args, **kwargs)
528+
return func_(*args, **kwargs)
499529
else:
500530
return func(obj, *args, **kwargs)
501531

pandas/core/generic.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
Axis,
5151
AxisInt,
5252
CompressionOptions,
53+
Concatenate,
5354
DtypeArg,
5455
DtypeBackend,
5556
DtypeObj,
@@ -213,6 +214,7 @@
213214
)
214215

215216
from pandas._libs.tslibs import BaseOffset
217+
from pandas._typing import P
216218

217219
from pandas import (
218220
DataFrame,
@@ -6118,13 +6120,31 @@ def sample(
61186120

61196121
return result
61206122

6123+
@overload
6124+
def pipe(
6125+
self,
6126+
func: Callable[Concatenate[Self, P], T],
6127+
*args: P.args,
6128+
**kwargs: P.kwargs,
6129+
) -> T:
6130+
...
6131+
6132+
@overload
6133+
def pipe(
6134+
self,
6135+
func: tuple[Callable[..., T], str],
6136+
*args: Any,
6137+
**kwargs: Any,
6138+
) -> T:
6139+
...
6140+
61216141
@final
61226142
@doc(klass=_shared_doc_kwargs["klass"])
61236143
def pipe(
61246144
self,
6125-
func: Callable[..., T] | tuple[Callable[..., T], str],
6126-
*args,
6127-
**kwargs,
6145+
func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str],
6146+
*args: Any,
6147+
**kwargs: Any,
61286148
) -> T:
61296149
r"""
61306150
Apply chainable functions that expect Series or DataFrames.

pandas/core/groupby/groupby.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class providing the base-class of operations.
2929
Union,
3030
cast,
3131
final,
32+
overload,
3233
)
3334
import warnings
3435

@@ -55,7 +56,6 @@ class providing the base-class of operations.
5556
PositionalIndexer,
5657
RandomState,
5758
Scalar,
58-
T,
5959
npt,
6060
)
6161
from pandas.compat.numpy import function as nv
@@ -147,7 +147,13 @@ class providing the base-class of operations.
147147
)
148148

149149
if TYPE_CHECKING:
150-
from typing import Any
150+
from pandas._typing import (
151+
Any,
152+
Concatenate,
153+
P,
154+
Self,
155+
T,
156+
)
151157

152158
from pandas.core.resample import Resampler
153159
from pandas.core.window import (
@@ -989,6 +995,24 @@ def _selected_obj(self):
989995
def _dir_additions(self) -> set[str]:
990996
return self.obj._dir_additions()
991997

998+
@overload
999+
def pipe(
1000+
self,
1001+
func: Callable[Concatenate[Self, P], T],
1002+
*args: P.args,
1003+
**kwargs: P.kwargs,
1004+
) -> T:
1005+
...
1006+
1007+
@overload
1008+
def pipe(
1009+
self,
1010+
func: tuple[Callable[..., T], str],
1011+
*args: Any,
1012+
**kwargs: Any,
1013+
) -> T:
1014+
...
1015+
9921016
@Substitution(
9931017
klass="GroupBy",
9941018
examples=dedent(
@@ -1014,9 +1038,9 @@ def _dir_additions(self) -> set[str]:
10141038
@Appender(_pipe_template)
10151039
def pipe(
10161040
self,
1017-
func: Callable[..., T] | tuple[Callable[..., T], str],
1018-
*args,
1019-
**kwargs,
1041+
func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str],
1042+
*args: Any,
1043+
**kwargs: Any,
10201044
) -> T:
10211045
return com.pipe(self, func, *args, **kwargs)
10221046

pandas/core/resample.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
cast,
1010
final,
1111
no_type_check,
12+
overload,
1213
)
1314
import warnings
1415

@@ -97,12 +98,16 @@
9798
from collections.abc import Hashable
9899

99100
from pandas._typing import (
101+
Any,
100102
AnyArrayLike,
101103
Axis,
102104
AxisInt,
105+
Concatenate,
103106
Frequency,
104107
IndexLabel,
105108
InterpolateOptions,
109+
P,
110+
Self,
106111
T,
107112
TimedeltaConvertibleTypes,
108113
TimeGrouperOrigin,
@@ -254,6 +259,24 @@ def _get_binner(self):
254259
bin_grouper = BinGrouper(bins, binlabels, indexer=self._indexer)
255260
return binner, bin_grouper
256261

262+
@overload
263+
def pipe(
264+
self,
265+
func: Callable[Concatenate[Self, P], T],
266+
*args: P.args,
267+
**kwargs: P.kwargs,
268+
) -> T:
269+
...
270+
271+
@overload
272+
def pipe(
273+
self,
274+
func: tuple[Callable[..., T], str],
275+
*args: Any,
276+
**kwargs: Any,
277+
) -> T:
278+
...
279+
257280
@final
258281
@Substitution(
259282
klass="Resampler",
@@ -278,9 +301,9 @@ def _get_binner(self):
278301
@Appender(_pipe_template)
279302
def pipe(
280303
self,
281-
func: Callable[..., T] | tuple[Callable[..., T], str],
282-
*args,
283-
**kwargs,
304+
func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str],
305+
*args: Any,
306+
**kwargs: Any,
284307
) -> T:
285308
return super().pipe(func, *args, **kwargs)
286309

pandas/io/formats/style.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import operator
1010
from typing import (
1111
TYPE_CHECKING,
12-
Any,
1312
Callable,
1413
overload,
1514
)
@@ -66,15 +65,20 @@
6665
from matplotlib.colors import Colormap
6766

6867
from pandas._typing import (
68+
Any,
6969
Axis,
7070
AxisInt,
71+
Concatenate,
7172
FilePath,
7273
IndexLabel,
7374
IntervalClosedType,
7475
Level,
76+
P,
7577
QuantileInterpolation,
7678
Scalar,
79+
Self,
7780
StorageOptions,
81+
T,
7882
WriteBuffer,
7983
WriteExcelBuffer,
8084
)
@@ -3614,7 +3618,30 @@ class MyStyler(cls): # type: ignore[valid-type,misc]
36143618

36153619
return MyStyler
36163620

3617-
def pipe(self, func: Callable, *args, **kwargs):
3621+
@overload
3622+
def pipe(
3623+
self,
3624+
func: Callable[Concatenate[Self, P], T],
3625+
*args: P.args,
3626+
**kwargs: P.kwargs,
3627+
) -> T:
3628+
...
3629+
3630+
@overload
3631+
def pipe(
3632+
self,
3633+
func: tuple[Callable[..., T], str],
3634+
*args: Any,
3635+
**kwargs: Any,
3636+
) -> T:
3637+
...
3638+
3639+
def pipe(
3640+
self,
3641+
func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str],
3642+
*args: Any,
3643+
**kwargs: Any,
3644+
) -> T:
36183645
"""
36193646
Apply ``func(self, *args, **kwargs)``, and return the result.
36203647

0 commit comments

Comments
 (0)