Skip to content

Commit 3a80166

Browse files
authored
BUG: support median function for custom BaseIndexer rolling windows (#33626)
1 parent f16179a commit 3a80166

File tree

4 files changed

+24
-4
lines changed

4 files changed

+24
-4
lines changed

doc/source/whatsnew/v1.1.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ Other API changes
175175
- :meth:`Groupby.groups` now returns an abbreviated representation when called on large dataframes (:issue:`1135`)
176176
- ``loc`` lookups with an object-dtype :class:`Index` and an integer key will now raise ``KeyError`` instead of ``TypeError`` when key is missing (:issue:`31905`)
177177
- Using a :func:`pandas.api.indexers.BaseIndexer` with ``skew``, ``cov``, ``corr`` will now raise a ``NotImplementedError`` (:issue:`32865`)
178-
- Using a :func:`pandas.api.indexers.BaseIndexer` with ``count``, ``min``, ``max`` will now return correct results for any monotonic :func:`pandas.api.indexers.BaseIndexer` descendant (:issue:`32865`)
178+
- Using a :func:`pandas.api.indexers.BaseIndexer` with ``count``, ``min``, ``max``, and ``median`` will now return correct results for any monotonic :func:`pandas.api.indexers.BaseIndexer` descendant (:issue:`32865`)
179179
- Added a :func:`pandas.api.indexers.FixedForwardWindowIndexer` class to support forward-looking windows during ``rolling`` operations.
180180
-
181181

pandas/_libs/window/aggregations.pyx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,8 @@ def roll_kurt_variable(ndarray[float64_t] values, ndarray[int64_t] start,
843843

844844

845845
def roll_median_c(ndarray[float64_t] values, ndarray[int64_t] start,
846-
ndarray[int64_t] end, int64_t minp, int64_t win):
846+
ndarray[int64_t] end, int64_t minp, int64_t win=0):
847+
# GH 32865. win argument kept for compatibility
847848
cdef:
848849
float64_t val, res, prev
849850
bint err = False
@@ -858,7 +859,7 @@ def roll_median_c(ndarray[float64_t] values, ndarray[int64_t] start,
858859
# actual skiplist ops outweigh any window computation costs
859860
output = np.empty(N, dtype=float)
860861

861-
if win == 0 or (end - start).max() == 0:
862+
if (end - start).max() == 0:
862863
output[:] = NaN
863864
return output
864865
win = (end - start).max()

pandas/core/window/rolling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1431,7 +1431,8 @@ def mean(self, *args, **kwargs):
14311431

14321432
def median(self, **kwargs):
14331433
window_func = self._get_roll_func("roll_median_c")
1434-
window_func = partial(window_func, win=self._get_window())
1434+
# GH 32865. Move max window size calculation to
1435+
# the median function implementation
14351436
return self._apply(window_func, center=self.center, name="median", **kwargs)
14361437

14371438
def std(self, ddof=1, *args, **kwargs):

pandas/tests/window/test_base_indexer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ def get_window_bounds(self, num_values, min_periods, center, closed):
141141
],
142142
{"ddof": 1},
143143
),
144+
(
145+
"median",
146+
np.median,
147+
[1.0, 2.0, 3.0, 4.0, 6.0, 7.0, 7.0, 8.0, 8.5, np.nan],
148+
{},
149+
),
144150
],
145151
)
146152
def test_rolling_forward_window(constructor, func, np_func, expected, np_kwargs):
@@ -162,7 +168,19 @@ def test_rolling_forward_window(constructor, func, np_func, expected, np_kwargs)
162168

163169
rolling = constructor(values).rolling(window=indexer, min_periods=2)
164170
result = getattr(rolling, func)()
171+
172+
# Check that the function output matches the explicitly provided array
165173
expected = constructor(expected)
166174
tm.assert_equal(result, expected)
175+
176+
# Check that the rolling function output matches applying an alternative
177+
# function to the rolling window object
167178
expected2 = constructor(rolling.apply(lambda x: np_func(x, **np_kwargs)))
168179
tm.assert_equal(result, expected2)
180+
181+
# Check that the function output matches applying an alternative function
182+
# if min_periods isn't specified
183+
rolling3 = constructor(values).rolling(window=indexer)
184+
result3 = getattr(rolling3, func)()
185+
expected3 = constructor(rolling3.apply(lambda x: np_func(x, **np_kwargs)))
186+
tm.assert_equal(result3, expected3)

0 commit comments

Comments
 (0)