-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
Use ea interface to calculate accumulator functions for datetimelike #50297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
79f9cd4
6fd82e3
4aa5019
d9653c0
4ffa497
8305cf3
9487b86
eb55b3c
47c3730
ed46798
d1fa417
f72cc65
177284a
589218d
93215ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
""" | ||
datetimelike_accumulations.py is for accumulations of datetimelike extension arrays | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Callable | ||
|
||
import numpy as np | ||
|
||
from pandas._libs import iNaT | ||
|
||
from pandas.core.dtypes.missing import isna | ||
|
||
|
||
def _cum_func( | ||
func: Callable, | ||
values: np.ndarray, | ||
*, | ||
skipna: bool = True, | ||
): | ||
""" | ||
Accumulations for 1D datetimelike arrays. | ||
|
||
Parameters | ||
---------- | ||
func : np.cumsum, np.maximum.accumulate, np.minimum.accumulate | ||
values : np.ndarray | ||
Numpy array with the values (can be of any dtype that support the | ||
operation). Values is changed is modified inplace. | ||
skipna : bool, default True | ||
Whether to skip NA. | ||
""" | ||
try: | ||
fill_value = { | ||
np.maximum.accumulate: np.iinfo(np.int64).min, | ||
np.cumsum: 0, | ||
mroeschke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
np.minimum.accumulate: np.iinfo(np.int64).max, | ||
}[func] | ||
except KeyError: | ||
raise ValueError(f"No accumulation for {func} implemented on BaseMaskedArray") | ||
|
||
mask = isna(values) | ||
y = values.view("i8") | ||
y[mask] = fill_value | ||
|
||
if not skipna: | ||
mask = np.maximum.accumulate(mask) | ||
|
||
result = func(y) | ||
result[mask] = iNaT | ||
|
||
if values.dtype.kind in ["m", "M"]: | ||
return result.view(values.dtype.base) | ||
return result | ||
|
||
|
||
def cumsum(values: np.ndarray, *, skipna: bool = True) -> np.ndarray:
    """
    Cumulative sum, delegating to ``_cum_func`` with ``np.cumsum``.

    Parameters
    ----------
    values : np.ndarray
        Array to accumulate; forwarded unchanged to ``_cum_func``.
    skipna : bool, default True
        Forwarded to ``_cum_func``.
    """
    accumulator = np.cumsum
    return _cum_func(accumulator, values, skipna=skipna)
|
||
|
||
def cummin(values: np.ndarray, *, skipna: bool = True) -> np.ndarray:
    """
    Cumulative minimum, delegating to ``_cum_func`` with
    ``np.minimum.accumulate``.

    Parameters
    ----------
    values : np.ndarray
        Array to accumulate; forwarded unchanged to ``_cum_func``.
    skipna : bool, default True
        Forwarded to ``_cum_func``.
    """
    return _cum_func(np.minimum.accumulate, values, skipna=skipna)
|
||
|
||
def cummax(values: np.ndarray, *, skipna: bool = True) -> np.ndarray:
    """
    Cumulative maximum, delegating to ``_cum_func`` with
    ``np.maximum.accumulate``.

    Parameters
    ----------
    values : np.ndarray
        Array to accumulate; forwarded unchanged to ``_cum_func``.
    skipna : bool, default True
        Forwarded to ``_cum_func``.
    """
    return _cum_func(np.maximum.accumulate, values, skipna=skipna)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1712,52 +1712,7 @@ def na_accum_func(values: ArrayLike, accum_func, *, skipna: bool) -> ArrayLike: | |
}[accum_func] | ||
|
||
# We will be applying this function to block values | ||
if values.dtype.kind in ["m", "M"]: | ||
# GH#30460, GH#29058 | ||
# numpy 1.18 started sorting NaTs at the end instead of beginning, | ||
# so we need to work around to maintain backwards-consistency. | ||
orig_dtype = values.dtype | ||
|
||
# We need to define mask before masking NaTs | ||
mask = isna(values) | ||
|
||
y = values.view("i8") | ||
# Note: the accum_func comparison fails as an "is" comparison | ||
changed = accum_func == np.minimum.accumulate | ||
|
||
try: | ||
if changed: | ||
y[mask] = lib.i8max | ||
|
||
result = accum_func(y, axis=0) | ||
finally: | ||
if changed: | ||
# restore NaT elements | ||
y[mask] = iNaT | ||
|
||
if skipna: | ||
result[mask] = iNaT | ||
elif accum_func == np.minimum.accumulate: | ||
# Restore NaTs that we masked previously | ||
nz = (~np.asarray(mask)).nonzero()[0] | ||
if len(nz): | ||
# everything up to the first non-na entry stays NaT | ||
result[: nz[0]] = iNaT | ||
|
||
if isinstance(values.dtype, np.dtype): | ||
result = result.view(orig_dtype) | ||
else: | ||
# DatetimeArray/TimedeltaArray | ||
# TODO: have this case go through a DTA method? | ||
# For DatetimeTZDtype, view result as M8[ns] | ||
npdtype = orig_dtype if isinstance(orig_dtype, np.dtype) else "M8[ns]" | ||
# Item "type" of "Union[Type[ExtensionArray], Type[ndarray[Any, Any]]]" | ||
# has no attribute "_simple_new" | ||
result = type(values)._simple_new( # type: ignore[union-attr] | ||
result.view(npdtype), dtype=orig_dtype | ||
) | ||
|
||
elif skipna and not issubclass(values.dtype.type, (np.integer, np.bool_)): | ||
if skipna and not issubclass(values.dtype.type, (np.integer, np.bool_)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe a comment/check that "mM" cases should not get here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (OK for follow-up, i can add this into my next "assorted" branch) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added an assert |
||
vals = values.copy() | ||
mask = isna(vals) | ||
vals[mask] = mask_a | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import pytest | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you call this test_cumulative.py to match tests/series/ and tests/frame/ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
import pandas._testing as tm | ||
from pandas.core.arrays import DatetimeArray | ||
|
||
|
||
class TestAccumulator:
    """Tests for ``DatetimeArray._accumulate``."""

    def test_accumulators_freq(self):
        # GH#50297
        days = ["2000-01-01", "2000-01-02", "2000-01-03"]
        arr = DatetimeArray._from_sequence_not_strict(days, freq="D")

        # cummin: every element equals the first day; expected has freq=None
        cummin_result = arr._accumulate("cummin")
        cummin_expected = DatetimeArray._from_sequence_not_strict(
            ["2000-01-01"] * 3, freq=None
        )
        tm.assert_datetime_array_equal(cummin_result, cummin_expected)

        # cummax: same values as the input, but expected has freq=None
        cummax_result = arr._accumulate("cummax")
        cummax_expected = DatetimeArray._from_sequence_not_strict(days, freq=None)
        tm.assert_datetime_array_equal(cummax_result, cummax_expected)

    @pytest.mark.parametrize("func", ["cumsum", "cumprod"])
    def test_accumulators_disallowed(self, func):
        # GH#50297
        days = ["2000-01-01", "2000-01-02"]
        arr = DatetimeArray._from_sequence_not_strict(days, freq="D")
        pattern = f"Accumulation {func}"
        with pytest.raises(TypeError, match=pattern):
            arr._accumulate(func)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import pytest | ||
|
||
import pandas._testing as tm | ||
from pandas.core.arrays import TimedeltaArray | ||
|
||
|
||
class TestAccumulator:
    """Tests for ``TimedeltaArray._accumulate``."""

    def test_accumulators_disallowed(self):
        # GH#50297
        deltas = ["1D", "2D"]
        arr = TimedeltaArray._from_sequence_not_strict(deltas)
        with pytest.raises(TypeError, match="cumprod not supported"):
            arr._accumulate("cumprod")

    def test_cumsum(self):
        # GH#50297
        deltas = ["1D", "2D"]
        running_totals = ["1D", "3D"]
        arr = TimedeltaArray._from_sequence_not_strict(deltas)
        expected = TimedeltaArray._from_sequence_not_strict(running_totals)
        result = arr._accumulate("cumsum")
        tm.assert_timedelta_array_equal(result, expected)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,12 +70,12 @@ def test_cummin_cummax(self, datetime_series, method): | |
[ | ||
"cummax", | ||
False, | ||
["NaT", "2 days", "2 days", "2 days", "2 days", "3 days"], | ||
["NaT", "NaT", "NaT", "NaT", "NaT", "NaT"], | ||
], | ||
[ | ||
"cummin", | ||
False, | ||
["NaT", "2 days", "2 days", "1 days", "1 days", "1 days"], | ||
["NaT", "NaT", "NaT", "NaT", "NaT", "NaT"], | ||
], | ||
], | ||
) | ||
|
@@ -91,6 +91,26 @@ def test_cummin_cummax_datetimelike(self, ts, method, skipna, exp_tdi): | |
result = getattr(ser, method)(skipna=skipna) | ||
tm.assert_series_equal(expected, result) | ||
|
||
@pytest.mark.parametrize( | ||
"func, exp", | ||
[ | ||
("cummin", pd.Period("2012-1-1", freq="D")), | ||
("cummax", pd.Period("2012-1-2", freq="D")), | ||
], | ||
) | ||
def test_cummin_cummax_period(self, func, exp): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for checking this |
||
# GH#28385 | ||
ser = pd.Series( | ||
[pd.Period("2012-1-1", freq="D"), pd.NaT, pd.Period("2012-1-2", freq="D")] | ||
) | ||
result = getattr(ser, func)(skipna=False) | ||
expected = pd.Series([pd.Period("2012-1-1", freq="D"), pd.NaT, pd.NaT]) | ||
tm.assert_series_equal(result, expected) | ||
|
||
result = getattr(ser, func)(skipna=True) | ||
expected = pd.Series([pd.Period("2012-1-1", freq="D"), pd.NaT, exp]) | ||
tm.assert_series_equal(result, expected) | ||
|
||
@pytest.mark.parametrize( | ||
"arg", | ||
[ | ||
|
Uh oh!
There was an error while loading. Please reload this page.