Skip to content

Commit a87f1f9

Browse files
committed
API: Sum / Prod of all-NA and empty
Changes the sum of empty and all-NA to be 0. Changes the prod of empty and all-NA to be 1.
1 parent 15f6cdb commit a87f1f9

File tree

6 files changed

+178
-51
lines changed

6 files changed

+178
-51
lines changed

pandas/core/generic.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7310,7 +7310,8 @@ def _add_numeric_operations(cls):
73107310
@Substitution(outname='mad',
73117311
desc="Return the mean absolute deviation of the values "
73127312
"for the requested axis",
7313-
name1=name, name2=name2, axis_descr=axis_descr)
7313+
name1=name, name2=name2, axis_descr=axis_descr,
7314+
empty_is_na='')
73147315
@Appender(_num_doc)
73157316
def mad(self, axis=None, skipna=None, level=None):
73167317
if skipna is None:
@@ -7351,7 +7352,7 @@ def mad(self, axis=None, skipna=None, level=None):
73517352
@Substitution(outname='compounded',
73527353
desc="Return the compound percentage of the values for "
73537354
"the requested axis", name1=name, name2=name2,
7354-
axis_descr=axis_descr)
7355+
axis_descr=axis_descr, empty_is_na='')
73557356
@Appender(_num_doc)
73567357
def compound(self, axis=None, skipna=None, level=None):
73577358
if skipna is None:
@@ -7375,10 +7376,11 @@ def compound(self, axis=None, skipna=None, level=None):
73757376
lambda y, axis: np.maximum.accumulate(y, axis), "max",
73767377
-np.inf, np.nan)
73777378

7378-
cls.sum = _make_stat_function(
7379+
cls.sum = _make_empty_stat_function(
73797380
cls, 'sum', name, name2, axis_descr,
73807381
'Return the sum of the values for the requested axis',
7381-
nanops.nansum)
7382+
nanops.nansum,
7383+
empty_is_na=False)
73827384
cls.mean = _make_stat_function(
73837385
cls, 'mean', name, name2, axis_descr,
73847386
'Return the mean of the values for the requested axis',
@@ -7394,10 +7396,11 @@ def compound(self, axis=None, skipna=None, level=None):
73947396
"by N-1\n",
73957397
nanops.nankurt)
73967398
cls.kurtosis = cls.kurt
7397-
cls.prod = _make_stat_function(
7399+
cls.prod = _make_empty_stat_function(
73987400
cls, 'prod', name, name2, axis_descr,
73997401
'Return the product of the values for the requested axis',
7400-
nanops.nanprod)
7402+
nanops.nanprod,
7403+
empty_is_na=False)
74017404
cls.product = cls.prod
74027405
cls.median = _make_stat_function(
74037406
cls, 'median', name, name2, axis_descr,
@@ -7520,14 +7523,14 @@ def _doc_parms(cls):
75207523
----------
75217524
axis : %(axis_descr)s
75227525
skipna : boolean, default True
7523-
Exclude NA/null values. If an entire row/column is NA or empty, the result
7524-
will be NA
7526+
Exclude NA/null values before computing the result.
75257527
level : int or level name, default None
75267528
If the axis is a MultiIndex (hierarchical), count along a
75277529
particular level, collapsing into a %(name1)s
75287530
numeric_only : boolean, default None
75297531
Include only float, int, boolean columns. If None, will attempt to use
7530-
everything, then use only numeric data. Not implemented for Series.
7532+
everything, then use only numeric data. Not implemented for
7533+
Series.%(empty_is_na)s
75317534
75327535
Returns
75337536
-------
@@ -7584,7 +7587,7 @@ def _doc_parms(cls):
75847587
axis : %(axis_descr)s
75857588
skipna : boolean, default True
75867589
Exclude NA/null values. If an entire row/column is NA, the result
7587-
will be NA
7590+
will be NA.
75887591
75897592
Returns
75907593
-------
@@ -7598,16 +7601,45 @@ def _doc_parms(cls):
75987601
75997602
"""
76007603

7604+
_empty_is_na_doc = """
7605+
empty_is_na : bool, default False
7606+
Whether the result of operating on an empty array should be NA. The default
7607+
behavior is for the sum of an empty array to be 0, and the product
7608+
of an empty array to be 1.
7609+
7610+
When ``skipna=True``, "empty" refers to whether or not the array
7611+
is empty after removing NAs. So operating on an all-NA array with
7612+
``skipna=True`` will be NA when ``empty_is_na`` is True.
7613+
"""
7614+
7615+
7616+
def _make_empty_stat_function(cls, name, name1, name2, axis_descr, desc, f,
7617+
empty_is_na=False):
7618+
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
7619+
axis_descr=axis_descr, empty_is_na=_empty_is_na_doc)
7620+
@Appender(_num_doc)
7621+
def stat_func(self, axis=None, skipna=True, level=None, numeric_only=None,
7622+
empty_is_na=empty_is_na, **kwargs):
7623+
nv.validate_stat_func(tuple(), kwargs, fname=name)
7624+
if axis is None:
7625+
axis = self._stat_axis_number
7626+
if level is not None:
7627+
return self._agg_by_level(name, axis=axis, level=level,
7628+
skipna=skipna, empty_is_na=empty_is_na)
7629+
return self._reduce(f, name, axis=axis, skipna=skipna,
7630+
numeric_only=numeric_only,
7631+
empty_is_na=empty_is_na)
7632+
7633+
return set_function_name(stat_func, name, cls)
7634+
76017635

76027636
def _make_stat_function(cls, name, name1, name2, axis_descr, desc, f):
76037637
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
7604-
axis_descr=axis_descr)
7638+
axis_descr=axis_descr, empty_is_na='')
76057639
@Appender(_num_doc)
7606-
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
7640+
def stat_func(self, axis=None, skipna=True, level=None, numeric_only=None,
76077641
**kwargs):
76087642
nv.validate_stat_func(tuple(), kwargs, fname=name)
7609-
if skipna is None:
7610-
skipna = True
76117643
if axis is None:
76127644
axis = self._stat_axis_number
76137645
if level is not None:

pandas/core/nanops.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def f(values, axis=None, skipna=True, **kwds):
107107
if k not in kwds:
108108
kwds[k] = v
109109
try:
110-
if values.size == 0:
110+
# TODO: NaT
111+
if values.size == 0 and kwds.get('empty_is_na'):
111112

112113
# we either return np.nan or pd.NaT
113114
if is_numeric_dtype(values):
@@ -155,6 +156,7 @@ def _bn_ok_dtype(dt, name):
155156
# Bottleneck chokes on datetime64
156157
if (not is_object_dtype(dt) and not is_datetime_or_timedelta_dtype(dt)):
157158

159+
# TODO: handle this overflow
158160
# GH 15507
159161
# bottleneck does not properly upcast during the sum
160162
# so can overflow
@@ -163,6 +165,9 @@ def _bn_ok_dtype(dt, name):
163165
# further we also want to preserve NaN when all elements
164166
# are NaN, unlike bottleneck/numpy which consider this
165167
# to be 0
168+
169+
# https://github.com/kwgoodman/bottleneck/issues/180
170+
# No upcast for boolean -> int
166171
if name in ['nansum', 'nanprod']:
167172
return False
168173

@@ -303,22 +308,21 @@ def nanall(values, axis=None, skipna=True):
303308

304309

305310
@disallow('M8')
306-
@bottleneck_switch()
307-
def nansum(values, axis=None, skipna=True):
311+
@bottleneck_switch(empty_is_na=False)
312+
def nansum(values, axis=None, skipna=True, empty_is_na=False):
308313
values, mask, dtype, dtype_max = _get_values(values, skipna, 0)
309314
dtype_sum = dtype_max
310315
if is_float_dtype(dtype):
311316
dtype_sum = dtype
312317
elif is_timedelta64_dtype(dtype):
313318
dtype_sum = np.float64
314319
the_sum = values.sum(axis, dtype=dtype_sum)
315-
the_sum = _maybe_null_out(the_sum, axis, mask)
320+
the_sum = _maybe_null_out(the_sum, axis, mask, empty_is_na)
316321

317322
return _wrap_results(the_sum, dtype)
318323

319324

320325
@disallow('M8')
321-
@bottleneck_switch()
322326
def nanmean(values, axis=None, skipna=True):
323327
values, mask, dtype, dtype_max = _get_values(values, skipna, 0)
324328

@@ -641,13 +645,15 @@ def nankurt(values, axis=None, skipna=True):
641645

642646

643647
@disallow('M8', 'm8')
644-
def nanprod(values, axis=None, skipna=True):
648+
@bottleneck_switch(empty_is_na=False)
649+
def nanprod(values, axis=None, skipna=True, empty_is_na=False):
645650
mask = isna(values)
646651
if skipna and not is_any_int_dtype(values):
647652
values = values.copy()
648653
values[mask] = 1
649654
result = values.prod(axis)
650-
return _maybe_null_out(result, axis, mask)
655+
656+
return _maybe_null_out(result, axis, mask, empty_is_na, unit=1.0)
651657

652658

653659
def _maybe_arg_null_out(result, axis, mask, skipna):
@@ -683,9 +689,13 @@ def _get_counts(mask, axis, dtype=float):
683689
return np.array(count, dtype=dtype)
684690

685691

686-
def _maybe_null_out(result, axis, mask):
692+
def _maybe_null_out(result, axis, mask, empty_is_na=True, unit=0.0):
687693
if axis is not None and getattr(result, 'ndim', False):
688694
null_mask = (mask.shape[axis] - mask.sum(axis)) == 0
695+
696+
if not empty_is_na:
697+
null_mask[result == unit] = False
698+
689699
if np.any(null_mask):
690700
if is_numeric_dtype(result):
691701
if np.iscomplexobj(result):
@@ -698,7 +708,7 @@ def _maybe_null_out(result, axis, mask):
698708
result[null_mask] = None
699709
elif result is not tslib.NaT:
700710
null_mask = mask.size - mask.sum()
701-
if null_mask == 0:
711+
if null_mask == 0.0 and empty_is_na:
702712
result = np.nan
703713

704714
return result

pandas/tests/frame/test_analytics.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,11 @@ def test_nunique(self):
478478
Series({0: 1, 1: 3, 2: 2}))
479479

480480
def test_sum(self):
481-
self._check_stat_op('sum', np.sum, has_numeric_only=True)
481+
self._check_stat_op('sum', np.nansum, has_numeric_only=True,
482+
no_skipna_alternative=np.sum)
482483

483484
# mixed types (with upcasting happening)
484-
self._check_stat_op('sum', np.sum,
485+
self._check_stat_op('sum', np.nansum,
485486
frame=self.mixed_float.astype('float32'),
486487
has_numeric_only=True, check_dtype=False,
487488
check_less_precise=True)
@@ -753,7 +754,8 @@ def alt(x):
753754

754755
def _check_stat_op(self, name, alternative, frame=None, has_skipna=True,
755756
has_numeric_only=False, check_dtype=True,
756-
check_dates=False, check_less_precise=False):
757+
check_dates=False, check_less_precise=False,
758+
no_skipna_alternative=None):
757759
if frame is None:
758760
frame = self.frame
759761
# set some NAs
@@ -774,14 +776,20 @@ def _check_stat_op(self, name, alternative, frame=None, has_skipna=True,
774776
assert len(result)
775777

776778
if has_skipna:
777-
def skipna_wrapper(x):
778-
nona = x.dropna()
779-
if len(nona) == 0:
780-
return np.nan
781-
return alternative(nona)
779+
alt = no_skipna_alternative or alternative # e.g. sum / nansum
780+
781+
if no_skipna_alternative:
782+
def skipna_wrapper(x):
783+
return alternative(x.values)
784+
else:
785+
def skipna_wrapper(x):
786+
nona = x.dropna()
787+
if len(nona) == 0:
788+
return np.nan
789+
return alt(nona)
782790

783791
def wrapper(x):
784-
return alternative(x.values)
792+
return alt(x.values)
785793

786794
result0 = f(axis=0, skipna=False)
787795
result1 = f(axis=1, skipna=False)
@@ -793,7 +801,7 @@ def wrapper(x):
793801
check_dtype=False,
794802
check_less_precise=check_less_precise)
795803
else:
796-
skipna_wrapper = alternative
804+
skipna_wrapper =alternative
797805
wrapper = alternative
798806

799807
result0 = f(axis=0)
@@ -834,6 +842,12 @@ def wrapper(x):
834842
r0 = getattr(all_na, name)(axis=0)
835843
r1 = getattr(all_na, name)(axis=1)
836844
if name in ['sum', 'prod']:
845+
tm.assert_numpy_array_equal(r0.values, np.zeros_like(r0))
846+
tm.assert_numpy_array_equal(r1.values, np.zeros_like(r1))
847+
848+
if name in ['sum', 'prod']:
849+
r0 = getattr(all_na, name)(axis=0, skipna=False)
850+
r1 = getattr(all_na, name)(axis=1, skipna=False)
837851
assert np.isnan(r0).all()
838852
assert np.isnan(r1).all()
839853

pandas/tests/series/test_analytics.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,40 +33,46 @@ class TestSeriesAnalytics(TestData):
3333
@pytest.mark.parametrize("method", ["sum", "prod"])
3434
def test_empty(self, method, use_bottleneck):
3535

36+
if method == "sum":
37+
unit = 0
38+
else:
39+
unit = 1
3640
with pd.option_context("use_bottleneck", use_bottleneck):
37-
# GH 9422
38-
# treat all missing as NaN
41+
# GH 9422 / 18678
42+
# treat all missing as the identity element (0 for sum, 1 for prod)
3943
s = Series([])
4044
result = getattr(s, method)()
41-
assert isna(result)
45+
assert result == unit
4246

4347
result = getattr(s, method)(skipna=True)
44-
assert isna(result)
48+
assert result == unit
4549

4650
s = Series([np.nan])
4751
result = getattr(s, method)()
48-
assert isna(result)
52+
assert result == unit
4953

5054
result = getattr(s, method)(skipna=True)
51-
assert isna(result)
55+
assert result == unit
5256

5357
s = Series([np.nan, 1])
5458
result = getattr(s, method)()
55-
assert result == 1.0
59+
assert result == 1
5660

5761
s = Series([np.nan, 1])
5862
result = getattr(s, method)(skipna=True)
5963
assert result == 1.0
6064

6165
# GH #844 (changed in 9422)
6266
df = DataFrame(np.empty((10, 0)))
63-
assert (df.sum(1).isnull()).all()
67+
result = df.sum(1)
68+
expected = pd.Series(0, index=df.index, dtype='float64')
69+
tm.assert_series_equal(result, expected)
6470

6571
@pytest.mark.parametrize(
66-
"method", ['sum', 'mean', 'median', 'std', 'var'])
72+
"method", ['mean', 'median', 'std', 'var'])
6773
def test_ops_consistency_on_empty(self, method):
6874

69-
# GH 7869
75+
# GH 7869 / 18678
7076
# consistency on empty
7177

7278
# float
@@ -77,6 +83,19 @@ def test_ops_consistency_on_empty(self, method):
7783
result = getattr(Series(dtype='m8[ns]'), method)()
7884
assert result is pd.NaT
7985

86+
@pytest.mark.parametrize('method, unit', [
87+
('sum', 0),
88+
('prod', 1),
89+
])
90+
def test_ops_consistency_on_empty_sum_prod(self, method, unit):
91+
# GH 18678
92+
result = getattr(Series(dtype=float), method)()
93+
assert result == unit
94+
95+
if method == 'sum':
96+
result = getattr(Series(dtype='m8[ns]'), method)()
97+
assert result == pd.Timedelta(0)
98+
8099
def test_nansum_buglet(self):
81100
s = Series([1.0, np.nan], index=[0, 1])
82101
result = np.nansum(s)
@@ -111,7 +130,7 @@ def test_sum_overflow(self, use_bottleneck):
111130
assert np.allclose(float(result), v[-1])
112131

113132
def test_sum(self):
114-
self._check_stat_op('sum', np.sum, check_allna=True)
133+
self._check_stat_op('sum', np.nansum, check_allna=False)
115134

116135
def test_sum_inf(self):
117136
s = Series(np.random.randn(10))

pandas/tests/series/test_quantile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_quantile(self):
3838

3939
# GH7661
4040
result = Series([np.timedelta64('NaT')]).sum()
41-
assert result is pd.NaT
41+
assert result == pd.Timedelta(0)
4242

4343
msg = 'percentiles should all be in the interval \\[0, 1\\]'
4444
for invalid in [-1, 2, [0.5, -1], [0.5, 2]]:

0 commit comments

Comments
 (0)