From 6f6da9aad538f62cdf5ded3b9f8244aef35fda9d Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Sun, 28 Nov 2021 12:00:07 +0100 Subject: [PATCH 1/2] fix kahans summation for the inf case --- pandas/_libs/groupby.pyx | 31 ++++++++++++++++++++------- pandas/_libs/window/aggregations.pyx | 8 +++++++ pandas/tests/groupby/test_function.py | 28 ++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 8 deletions(-) diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index 078cb8e02e824..bdf702b436f7e 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -26,7 +26,7 @@ from numpy cimport ( uint32_t, uint64_t, ) -from numpy.math cimport NAN +from numpy.math cimport NAN, isinf cnp.import_array() @@ -51,7 +51,14 @@ from pandas._libs.missing cimport checknull cdef int64_t NPY_NAT = get_nat() _int64_max = np.iinfo(np.int64).max -cdef float64_t NaN = np.NaN +cdef: + float32_t MINfloat32 = np.NINF + float64_t MINfloat64 = np.NINF + + float32_t MAXfloat32 = np.inf + float64_t MAXfloat64 = np.inf + + float64_t NaN = np.NaN cdef enum InterpolationEnumType: INTERPOLATION_LINEAR, @@ -251,13 +258,18 @@ def group_cumsum(numeric_t[:, ::1] out, # For floats, use Kahan summation to reduce floating-point # error (https://en.wikipedia.org/wiki/Kahan_summation_algorithm) - if numeric_t == float32_t or numeric_t == float64_t: + if numeric_t is float32_t or numeric_t is float64_t: if val == val: - y = val - compensation[lab, j] - t = accum[lab, j] + y - compensation[lab, j] = t - accum[lab, j] - y - accum[lab, j] = t - out[i, j] = t + # if val or accum are inf/-inf don't use kahan + if isinf(val) or isinf(accum[lab, j]): + accum[lab, j] += val + out[i, j] = accum[lab, j] + else: + y = val - compensation[lab, j] + t = accum[lab, j] + y + compensation[lab, j] = t - accum[lab, j] - y + accum[lab, j] = t + out[i, j] = t else: out[i, j] = NaN if not skipna: @@ -556,6 +568,9 @@ def group_add(add_t[:, ::1] out, for j in range(K): val = values[i, j] + if (val == MAXfloat64) or (val == MINfloat64): + sumx[lab, j] = val + break # not nan if val == val: nobs[lab, j] += 1 diff --git a/pandas/_libs/window/aggregations.pyx b/pandas/_libs/window/aggregations.pyx index 98201a6f58499..236022cb04fb8 100644 --- a/pandas/_libs/window/aggregations.pyx +++ b/pandas/_libs/window/aggregations.pyx @@ -100,6 +100,10 @@ cdef inline void add_sum(float64_t val, int64_t *nobs, float64_t *sum_x, t = sum_x[0] + y compensation[0] = t - sum_x[0] - y sum_x[0] = t + if (val == MINfloat64) or (val == MAXfloat64): + sum_x[0] = val + nobs[0] = nobs[0] + 1 + compensation[0] = 0 cdef inline void remove_sum(float64_t val, int64_t *nobs, float64_t *sum_x, @@ -116,6 +120,10 @@ cdef inline void remove_sum(float64_t val, int64_t *nobs, float64_t *sum_x, t = sum_x[0] + y compensation[0] = t - sum_x[0] - y sum_x[0] = t + if (val == MINfloat64) or (val == MAXfloat64): + sum_x[0] = val + nobs[0] = nobs[0] - 1 + compensation[0] = 0 def roll_sum(const float64_t[:] values, ndarray[int64_t] start, diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index c462db526b36d..eb6574981db07 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -1162,3 +1162,31 @@ def test_mean_on_timedelta(): pd.to_timedelta([4, 5]), name="time", index=Index(["A", "B"], name="cat") ) tm.assert_series_equal(result, expected) + + +def test_sum_with_nan_inf(): + df = DataFrame( + {"a": ["hello", "hello", "world", "world"], "b": [np.inf, 10, np.nan, 10]} + ) + gb = df.groupby("a") + result = gb.sum() + expected = DataFrame( + [np.inf, 10], index=Index(["hello", "world"], name="a"), columns=["b"] + ) + tm.assert_frame_equal(result, expected) + + +def test_cumsum_inf(): + ser = Series([np.inf, 1, 1]) + + result = ser.groupby([1, 1, 1]).cumsum() + expected = Series([np.inf, np.inf, np.inf]) + tm.assert_series_equal(result, expected) + + +def test_cumsum_ninf_inf(): + ser = Series([np.inf, 1, 1, -np.inf, 1]) + + result = ser.groupby([1, 1, 1, 1, 1]).cumsum() + expected = Series([np.inf, np.inf, np.inf, np.nan, np.nan]) + tm.assert_series_equal(result, expected) From c0f32ac7f085071513ec55118a3ac35a72d679d1 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Sun, 28 Nov 2021 12:14:13 +0100 Subject: [PATCH 2/2] fix formatting error --- pandas/_libs/groupby.pyx | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index bdf702b436f7e..efdef60f8ae5f 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -26,7 +26,10 @@ from numpy cimport ( uint32_t, uint64_t, ) -from numpy.math cimport NAN, isinf +from numpy.math cimport ( + NAN, + isinf, +) cnp.import_array()