Skip to content

Commit b8ae9bb

Browse files
authored
BUG: masked mean unnecessarily overflowing (#48378)
1 parent ce143a2 commit b8ae9bb

File tree

5 files changed

+55
-19
lines changed

5 files changed

+55
-19
lines changed

doc/source/whatsnew/v1.6.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ Sparse
200200

201201
ExtensionArray
202202
^^^^^^^^^^^^^^
203-
-
203+
- Bug in :meth:`Series.mean` overflowing unnecessarily with nullable integers (:issue:`48378`)
204204
-
205205

206206
Styler

pandas/core/array_algos/masked_reductions.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pandas.core.nanops import check_below_min_count
1515

1616

17-
def _sumprod(
17+
def _reductions(
1818
func: Callable,
1919
values: np.ndarray,
2020
mask: npt.NDArray[np.bool_],
@@ -24,7 +24,7 @@ def _sumprod(
2424
axis: int | None = None,
2525
):
2626
"""
27-
Sum or product for 1D masked array.
27+
Sum, mean or product for 1D masked array.
2828
2929
Parameters
3030
----------
@@ -63,7 +63,7 @@ def sum(
6363
min_count: int = 0,
6464
axis: int | None = None,
6565
):
66-
return _sumprod(
66+
return _reductions(
6767
np.sum, values=values, mask=mask, skipna=skipna, min_count=min_count, axis=axis
6868
)
6969

@@ -76,7 +76,7 @@ def prod(
7676
min_count: int = 0,
7777
axis: int | None = None,
7878
):
79-
return _sumprod(
79+
return _reductions(
8080
np.prod, values=values, mask=mask, skipna=skipna, min_count=min_count, axis=axis
8181
)
8282

@@ -139,11 +139,13 @@ def max(
139139
return _minmax(np.max, values=values, mask=mask, skipna=skipna, axis=axis)
140140

141141

142-
# TODO: axis kwarg
143-
def mean(values: np.ndarray, mask: npt.NDArray[np.bool_], skipna: bool = True):
142+
def mean(
143+
values: np.ndarray,
144+
mask: npt.NDArray[np.bool_],
145+
*,
146+
skipna: bool = True,
147+
axis: int | None = None,
148+
):
144149
if not values.size or mask.all():
145150
return libmissing.NA
146-
_sum = _sumprod(np.sum, values=values, mask=mask, skipna=skipna)
147-
count = np.count_nonzero(~mask)
148-
mean_value = _sum / count
149-
return mean_value
151+
return _reductions(np.mean, values=values, mask=mask, skipna=skipna, axis=axis)

pandas/core/arrays/masked.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,17 +1036,12 @@ def _quantile(
10361036
# Reductions
10371037

10381038
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
1039-
if name in {"any", "all", "min", "max", "sum", "prod"}:
1039+
if name in {"any", "all", "min", "max", "sum", "prod", "mean"}:
10401040
return getattr(self, name)(skipna=skipna, **kwargs)
10411041

10421042
data = self._data
10431043
mask = self._mask
10441044

1045-
if name in {"mean"}:
1046-
op = getattr(masked_reductions, name)
1047-
result = op(data, mask, skipna=skipna, **kwargs)
1048-
return result
1049-
10501045
# coerce to a nan-aware float if needed
10511046
# (we explicitly use NaN within reductions)
10521047
if self._hasna:
@@ -1107,6 +1102,18 @@ def prod(self, *, skipna=True, min_count=0, axis: int | None = 0, **kwargs):
11071102
"prod", result, skipna=skipna, axis=axis, **kwargs
11081103
)
11091104

1105+
def mean(self, *, skipna=True, axis: int | None = 0, **kwargs):
1106+
nv.validate_mean((), kwargs)
1107+
result = masked_reductions.mean(
1108+
self._data,
1109+
self._mask,
1110+
skipna=skipna,
1111+
axis=axis,
1112+
)
1113+
return self._wrap_reduction_result(
1114+
"mean", result, skipna=skipna, axis=axis, **kwargs
1115+
)
1116+
11101117
def min(self, *, skipna=True, axis: int | None = 0, **kwargs):
11111118
nv.validate_min((), kwargs)
11121119
return masked_reductions.min(

pandas/tests/extension/base/dim2.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66

77
from pandas._libs.missing import is_matching_na
88

9+
from pandas.core.dtypes.common import (
10+
is_bool_dtype,
11+
is_integer_dtype,
12+
)
13+
914
import pandas as pd
15+
import pandas._testing as tm
1016
from pandas.core.arrays.integer import INT_STR_TO_DTYPE
1117
from pandas.tests.extension.base.base import BaseExtensionTests
1218

@@ -191,7 +197,12 @@ def test_reductions_2d_axis0(self, data, method):
191197
kwargs["ddof"] = 0
192198

193199
try:
194-
result = getattr(arr2d, method)(axis=0, **kwargs)
200+
if method == "mean" and hasattr(data, "_mask"):
201+
# Empty slices produced by the mask cause RuntimeWarnings by numpy
202+
with tm.assert_produces_warning(RuntimeWarning, check_stacklevel=False):
203+
result = getattr(arr2d, method)(axis=0, **kwargs)
204+
else:
205+
result = getattr(arr2d, method)(axis=0, **kwargs)
195206
except Exception as err:
196207
try:
197208
getattr(data, method)()
@@ -212,7 +223,7 @@ def get_reduction_result_dtype(dtype):
212223
# i.e. dtype.kind == "u"
213224
return INT_STR_TO_DTYPE[np.dtype(np.uint).name]
214225

215-
if method in ["mean", "median", "sum", "prod"]:
226+
if method in ["median", "sum", "prod"]:
216227
# std and var are not dtype-preserving
217228
expected = data
218229
if method in ["sum", "prod"] and data.dtype.kind in "iub":
@@ -229,6 +240,10 @@ def get_reduction_result_dtype(dtype):
229240
self.assert_extension_array_equal(result, expected)
230241
elif method == "std":
231242
self.assert_extension_array_equal(result, data - data)
243+
elif method == "mean":
244+
if is_integer_dtype(data) or is_bool_dtype(data):
245+
data = data.astype("Float64")
246+
self.assert_extension_array_equal(result, data)
232247
# punt on method == "var"
233248

234249
@pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])

pandas/tests/reductions/test_reductions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,18 @@ def test_sum_overflow_float(self, use_bottleneck, dtype):
775775
result = s.max(skipna=False)
776776
assert np.allclose(float(result), v[-1])
777777

778+
def test_mean_masked_overflow(self):
779+
# GH#48378
780+
val = 100_000_000_000_000_000
781+
n_elements = 100
782+
na = np.array([val] * n_elements)
783+
ser = Series([val] * n_elements, dtype="Int64")
784+
785+
result_numpy = np.mean(na)
786+
result_masked = ser.mean()
787+
assert result_masked - result_numpy == 0
788+
assert result_masked == 1e17
789+
778790
@pytest.mark.parametrize("dtype", ("m8[ns]", "m8[ns]", "M8[ns]", "M8[ns, UTC]"))
779791
@pytest.mark.parametrize("skipna", [True, False])
780792
def test_empty_timeseries_reductions_return_nat(self, dtype, skipna):

0 commit comments

Comments
 (0)