Skip to content

Commit 0b0fa9b

Browse files
authored
TYP: resample (#51178)
1 parent bae0bf0 commit 0b0fa9b

File tree

2 files changed

+58
-34
lines changed

2 files changed

+58
-34
lines changed

pandas/core/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8783,7 +8783,7 @@ def resample(
87838783

87848784
axis = self._get_axis_number(axis)
87858785
return get_resampler(
8786-
self,
8786+
cast("Series | DataFrame", self),
87878787
freq=rule,
87888788
label=label,
87898789
closed=closed,

pandas/core/resample.py

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class Resampler(BaseGroupBy, PandasObject):
131131

132132
grouper: BinGrouper
133133
_timegrouper: TimeGrouper
134+
binner: DatetimeIndex | TimedeltaIndex | PeriodIndex # depends on subclass
134135
exclusions: frozenset[Hashable] = frozenset() # for SelectionMixin compat
135136

136137
# to the groupby descriptor
@@ -147,15 +148,15 @@ class Resampler(BaseGroupBy, PandasObject):
147148

148149
def __init__(
149150
self,
150-
obj: DataFrame | Series,
151-
groupby: TimeGrouper,
151+
obj: NDFrame,
152+
timegrouper: TimeGrouper,
152153
axis: Axis = 0,
153154
kind=None,
154155
*,
155156
group_keys: bool | lib.NoDefault = lib.no_default,
156157
selection=None,
157158
) -> None:
158-
self._timegrouper = groupby
159+
self._timegrouper = timegrouper
159160
self.keys = None
160161
self.sort = True
161162
# error: Incompatible types in assignment (expression has type "Union
@@ -466,7 +467,7 @@ def _groupby_and_aggregate(self, how, *args, **kwargs):
466467

467468
return self._wrap_result(result)
468469

469-
def _get_resampler_for_grouping(self, groupby, key=None):
470+
def _get_resampler_for_grouping(self, groupby: GroupBy, key):
470471
"""
471472
Return the correct class for resampling with groupby.
472473
"""
@@ -1158,18 +1159,29 @@ class _GroupByMixin(PandasObject):
11581159

11591160
_attributes: list[str] # in practice the same as Resampler._attributes
11601161
_selection: IndexLabel | None = None
1162+
_groupby: GroupBy
1163+
_timegrouper: TimeGrouper
11611164

1162-
def __init__(self, *, parent: Resampler, groupby=None, key=None, **kwargs) -> None:
1165+
def __init__(
1166+
self,
1167+
*,
1168+
parent: Resampler,
1169+
groupby: GroupBy,
1170+
key=None,
1171+
selection: IndexLabel | None = None,
1172+
) -> None:
11631173
# reached via ._gotitem and _get_resampler_for_grouping
11641174

1175+
assert isinstance(groupby, GroupBy), type(groupby)
1176+
11651177
# parent is always a Resampler, sometimes a _GroupByMixin
11661178
assert isinstance(parent, Resampler), type(parent)
11671179

11681180
# initialize our GroupByMixin object with
11691181
# the resampler attributes
11701182
for attr in self._attributes:
1171-
setattr(self, attr, kwargs.get(attr, getattr(parent, attr)))
1172-
self._selection = kwargs.get("selection")
1183+
setattr(self, attr, getattr(parent, attr))
1184+
self._selection = selection
11731185

11741186
self.binner = parent.binner
11751187
self.key = key
@@ -1185,7 +1197,7 @@ def _apply(self, f, *args, **kwargs):
11851197
"""
11861198

11871199
def func(x):
1188-
x = self._resampler_cls(x, groupby=self._timegrouper)
1200+
x = self._resampler_cls(x, timegrouper=self._timegrouper)
11891201

11901202
if isinstance(f, str):
11911203
return getattr(x, f)(**kwargs)
@@ -1217,10 +1229,6 @@ def _gotitem(self, key, ndim, subset=None):
12171229
# error: "GotItemMixin" has no attribute "obj"
12181230
subset = self.obj # type: ignore[attr-defined]
12191231

1220-
# we need to make a shallow copy of ourselves
1221-
# with the same groupby
1222-
kwargs = {attr: getattr(self, attr) for attr in self._attributes}
1223-
12241232
# Try to select from a DataFrame, falling back to a Series
12251233
try:
12261234
if isinstance(key, list) and self.key not in key and self.key is not None:
@@ -1239,7 +1247,6 @@ def _gotitem(self, key, ndim, subset=None):
12391247
groupby=groupby,
12401248
parent=cast(Resampler, self),
12411249
selection=selection,
1242-
**kwargs,
12431250
)
12441251
return new_rs
12451252

@@ -1515,9 +1522,7 @@ def _resampler_cls(self):
15151522
return TimedeltaIndexResampler
15161523

15171524

1518-
def get_resampler(
1519-
obj, kind=None, **kwds
1520-
) -> DatetimeIndexResampler | PeriodIndexResampler | TimedeltaIndexResampler:
1525+
def get_resampler(obj: Series | DataFrame, kind=None, **kwds) -> Resampler:
15211526
"""
15221527
Create a TimeGrouper and return our resampler.
15231528
"""
@@ -1529,8 +1534,15 @@ def get_resampler(
15291534

15301535

15311536
def get_resampler_for_grouping(
1532-
groupby, rule, how=None, fill_method=None, limit=None, kind=None, on=None, **kwargs
1533-
):
1537+
groupby: GroupBy,
1538+
rule,
1539+
how=None,
1540+
fill_method=None,
1541+
limit=None,
1542+
kind=None,
1543+
on=None,
1544+
**kwargs,
1545+
) -> Resampler:
15341546
"""
15351547
Return our appropriate resampler when grouping as well.
15361548
"""
@@ -1657,19 +1669,19 @@ def __init__(
16571669

16581670
super().__init__(freq=freq, axis=axis, **kwargs)
16591671

1660-
def _get_resampler(self, obj, kind=None):
1672+
def _get_resampler(self, obj: NDFrame, kind=None) -> Resampler:
16611673
"""
16621674
Return my resampler or raise if we have an invalid axis.
16631675
16641676
Parameters
16651677
----------
1666-
obj : input object
1678+
obj : Series or DataFrame
16671679
kind : string, optional
16681680
'period','timestamp','timedelta' are valid
16691681
16701682
Returns
16711683
-------
1672-
a Resampler
1684+
Resampler
16731685
16741686
Raises
16751687
------
@@ -1681,15 +1693,23 @@ def _get_resampler(self, obj, kind=None):
16811693
ax = self.ax
16821694
if isinstance(ax, DatetimeIndex):
16831695
return DatetimeIndexResampler(
1684-
obj, groupby=self, kind=kind, axis=self.axis, group_keys=self.group_keys
1696+
obj,
1697+
timegrouper=self,
1698+
kind=kind,
1699+
axis=self.axis,
1700+
group_keys=self.group_keys,
16851701
)
16861702
elif isinstance(ax, PeriodIndex) or kind == "period":
16871703
return PeriodIndexResampler(
1688-
obj, groupby=self, kind=kind, axis=self.axis, group_keys=self.group_keys
1704+
obj,
1705+
timegrouper=self,
1706+
kind=kind,
1707+
axis=self.axis,
1708+
group_keys=self.group_keys,
16891709
)
16901710
elif isinstance(ax, TimedeltaIndex):
16911711
return TimedeltaIndexResampler(
1692-
obj, groupby=self, axis=self.axis, group_keys=self.group_keys
1712+
obj, timegrouper=self, axis=self.axis, group_keys=self.group_keys
16931713
)
16941714

16951715
raise TypeError(
@@ -1698,10 +1718,12 @@ def _get_resampler(self, obj, kind=None):
16981718
f"but got an instance of '{type(ax).__name__}'"
16991719
)
17001720

1701-
def _get_grouper(self, obj, validate: bool = True):
1721+
def _get_grouper(
1722+
self, obj: NDFrameT, validate: bool = True
1723+
) -> tuple[BinGrouper, NDFrameT]:
17021724
# create the resampler and return our binner
17031725
r = self._get_resampler(obj)
1704-
return r.grouper, r.obj
1726+
return r.grouper, cast(NDFrameT, r.obj)
17051727

17061728
def _get_time_bins(self, ax: DatetimeIndex):
17071729
if not isinstance(ax, DatetimeIndex):
@@ -1766,19 +1788,21 @@ def _get_time_bins(self, ax: DatetimeIndex):
17661788

17671789
return binner, bins, labels
17681790

1769-
def _adjust_bin_edges(self, binner, ax_values):
1791+
def _adjust_bin_edges(
1792+
self, binner: DatetimeIndex, ax_values: npt.NDArray[np.int64]
1793+
) -> tuple[DatetimeIndex, npt.NDArray[np.int64]]:
17701794
# Some hacks for > daily data, see #1471, #1458, #1483
17711795

17721796
if self.freq != "D" and is_superperiod(self.freq, "D"):
17731797
if self.closed == "right":
17741798
# GH 21459, GH 9119: Adjust the bins relative to the wall time
1775-
bin_edges = binner.tz_localize(None)
1776-
bin_edges = (
1777-
bin_edges
1778-
+ Timedelta(days=1, unit=bin_edges.unit).as_unit(bin_edges.unit)
1779-
- Timedelta(1, unit=bin_edges.unit).as_unit(bin_edges.unit)
1799+
edges_dti = binner.tz_localize(None)
1800+
edges_dti = (
1801+
edges_dti
1802+
+ Timedelta(days=1, unit=edges_dti.unit).as_unit(edges_dti.unit)
1803+
- Timedelta(1, unit=edges_dti.unit).as_unit(edges_dti.unit)
17801804
)
1781-
bin_edges = bin_edges.tz_localize(binner.tz).asi8
1805+
bin_edges = edges_dti.tz_localize(binner.tz).asi8
17821806
else:
17831807
bin_edges = binner.asi8
17841808

0 commit comments

Comments
 (0)