Skip to content

Commit ada3f71

Browse files
committed
BUG: Fix pivot_table margins to include NaN groups when dropna=False
1 parent cfe54bd commit ada3f71

File tree

3 files changed

+47
-7
lines changed

3 files changed

+47
-7
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,7 @@ Reshaping
844844
- Bug in :meth:`DataFrame.merge` when merging two :class:`DataFrame` on ``intc`` or ``uintc`` types on Windows (:issue:`60091`, :issue:`58713`)
845845
- Bug in :meth:`DataFrame.pivot_table` incorrectly subaggregating results when called without an ``index`` argument (:issue:`58722`)
846846
- Bug in :meth:`DataFrame.pivot_table` incorrectly ignoring the ``values`` argument when also supplied to the ``index`` or ``columns`` parameters (:issue:`57876`, :issue:`61292`)
847+
- Bug in :meth:`DataFrame.pivot_table` where ``margins=True`` did not correctly include groups with ``NaN`` values in the index or columns when ``dropna=False`` was explicitly passed. (:issue:`61509`)
847848
- Bug in :meth:`DataFrame.stack` with the new implementation where ``ValueError`` is raised when ``level=[]`` (:issue:`60740`)
848849
- Bug in :meth:`DataFrame.unstack` producing incorrect results when manipulating empty :class:`DataFrame` with an :class:`ExtentionDtype` (:issue:`59123`)
849850
- Bug in :meth:`concat` where concatenating DataFrame and Series with ``ignore_index = True`` drops the series name (:issue:`60723`, :issue:`56257`)

pandas/core/reshape/pivot.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ def __internal_pivot_table(
396396
observed=dropna,
397397
margins_name=margins_name,
398398
fill_value=fill_value,
399+
dropna=dropna,
399400
)
400401

401402
# discard the top level
@@ -422,6 +423,7 @@ def _add_margins(
422423
observed: bool,
423424
margins_name: Hashable = "All",
424425
fill_value=None,
426+
dropna: bool = True,
425427
):
426428
if not isinstance(margins_name, str):
427429
raise ValueError("margins_name argument must be a string")
@@ -461,6 +463,7 @@ def _add_margins(
461463
kwargs,
462464
observed,
463465
margins_name,
466+
dropna,
464467
)
465468
if not isinstance(marginal_result_set, tuple):
466469
return marginal_result_set
@@ -469,7 +472,7 @@ def _add_margins(
469472
# no values, and table is a DataFrame
470473
assert isinstance(table, ABCDataFrame)
471474
marginal_result_set = _generate_marginal_results_without_values(
472-
table, data, rows, cols, aggfunc, kwargs, observed, margins_name
475+
table, data, rows, cols, aggfunc, kwargs, observed, margins_name, dropna
473476
)
474477
if not isinstance(marginal_result_set, tuple):
475478
return marginal_result_set
@@ -538,6 +541,7 @@ def _generate_marginal_results(
538541
kwargs,
539542
observed: bool,
540543
margins_name: Hashable = "All",
544+
dropna: bool = True,
541545
):
542546
margin_keys: list | Index
543547
if len(cols) > 0:
@@ -551,7 +555,7 @@ def _all_key(key):
551555
if len(rows) > 0:
552556
margin = (
553557
data[rows + values]
554-
.groupby(rows, observed=observed)
558+
.groupby(rows, observed=observed, dropna=dropna)
555559
.agg(aggfunc, **kwargs)
556560
)
557561
cat_axis = 1
@@ -567,7 +571,7 @@ def _all_key(key):
567571
else:
568572
margin = (
569573
data[cols[:1] + values]
570-
.groupby(cols[:1], observed=observed)
574+
.groupby(cols[:1], observed=observed, dropna=dropna)
571575
.agg(aggfunc, **kwargs)
572576
.T
573577
)
@@ -610,7 +614,9 @@ def _all_key(key):
610614

611615
if len(cols) > 0:
612616
row_margin = (
613-
data[cols + values].groupby(cols, observed=observed).agg(aggfunc, **kwargs)
617+
data[cols + values]
618+
.groupby(cols, observed=observed, dropna=dropna)
619+
.agg(aggfunc, **kwargs)
614620
)
615621
row_margin = row_margin.stack()
616622

@@ -633,6 +639,7 @@ def _generate_marginal_results_without_values(
633639
kwargs,
634640
observed: bool,
635641
margins_name: Hashable = "All",
642+
dropna: bool = True,
636643
):
637644
margin_keys: list | Index
638645
if len(cols) > 0:
@@ -645,7 +652,7 @@ def _all_key():
645652
return (margins_name,) + ("",) * (len(cols) - 1)
646653

647654
if len(rows) > 0:
648-
margin = data.groupby(rows, observed=observed)[rows].apply(
655+
margin = data.groupby(rows, observed=observed, dropna=dropna)[rows].apply(
649656
aggfunc, **kwargs
650657
)
651658
all_key = _all_key()
@@ -654,7 +661,9 @@ def _all_key():
654661
margin_keys.append(all_key)
655662

656663
else:
657-
margin = data.groupby(level=0, observed=observed).apply(aggfunc, **kwargs)
664+
margin = data.groupby(level=0, observed=observed, dropna=dropna).apply(
665+
aggfunc, **kwargs
666+
)
658667
all_key = _all_key()
659668
table[all_key] = margin
660669
result = table
@@ -665,7 +674,7 @@ def _all_key():
665674
margin_keys = table.columns
666675

667676
if len(cols):
668-
row_margin = data.groupby(cols, observed=observed)[cols].apply(
677+
row_margin = data.groupby(cols, observed=observed, dropna=dropna)[cols].apply(
669678
aggfunc, **kwargs
670679
)
671680
else:

pandas/tests/reshape/test_pivot.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,6 +2594,36 @@ def test_pivot_table_values_as_two_params(
25942594
expected = DataFrame(data=e_data, index=e_index, columns=e_cols)
25952595
tm.assert_frame_equal(result, expected)
25962596

2597+
def test_pivot_table_margins_include_nan_groups(self):
2598+
# GH#61509
2599+
df = DataFrame(
2600+
{
2601+
"i": [1, 2, 3],
2602+
"g1": ["a", "b", "b"],
2603+
"g2": ["x", None, None],
2604+
}
2605+
)
2606+
2607+
result = df.pivot_table(
2608+
index="g1",
2609+
columns="g2",
2610+
values="i",
2611+
aggfunc="count",
2612+
dropna=False,
2613+
margins=True,
2614+
)
2615+
2616+
expected = DataFrame(
2617+
{
2618+
"x": {"a": 1.0, "b": np.nan, "All": 1.0},
2619+
np.nan: {"a": np.nan, "b": 2.0, "All": 2.0},
2620+
"All": {"a": 1.0, "b": 2.0, "All": 3.0},
2621+
}
2622+
)
2623+
expected.index.name = "g1"
2624+
expected.columns.name = "g2"
2625+
tm.assert_frame_equal(result, expected, check_dtype=False)
2626+
25972627

25982628
class TestPivot:
25992629
def test_pivot(self):

0 commit comments

Comments
 (0)