Skip to content

Commit 8b70567

Browse files
PF2100ruimamaral
andauthored
ENH: Add **kwargs to pivot_table to allow the specification of aggfunc keyword arguments #57884 (#58893)
Co-authored-by: Pedro Freitas <pedrogmfreitas@tecnico.ulisboa.pt> Co-authored-by: Rui Amaral <rui.miguel.amaral@tecnico.ulisboa.pt>
1 parent d6b7d5c commit 8b70567

File tree

4 files changed

+109
-16
lines changed

4 files changed

+109
-16
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Other enhancements
4242
- :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`)
4343
- :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`)
4444
- :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`)
45+
- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`)
4546
- :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
4647
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
4748
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)

pandas/core/frame.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9275,6 +9275,11 @@ def pivot(
92759275
92769276
.. versionadded:: 1.3.0
92779277
9278+
**kwargs : dict
9279+
Optional keyword arguments to pass to ``aggfunc``.
9280+
9281+
.. versionadded:: 3.0.0
9282+
92789283
Returns
92799284
-------
92809285
DataFrame
@@ -9382,6 +9387,7 @@ def pivot_table(
93829387
margins_name: Level = "All",
93839388
observed: bool = True,
93849389
sort: bool = True,
9390+
**kwargs,
93859391
) -> DataFrame:
93869392
from pandas.core.reshape.pivot import pivot_table
93879393

@@ -9397,6 +9403,7 @@ def pivot_table(
93979403
margins_name=margins_name,
93989404
observed=observed,
93999405
sort=sort,
9406+
**kwargs,
94009407
)
94019408

94029409
def stack(

pandas/core/reshape/pivot.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def pivot_table(
6161
margins_name: Hashable = "All",
6262
observed: bool = True,
6363
sort: bool = True,
64+
**kwargs,
6465
) -> DataFrame:
6566
"""
6667
Create a spreadsheet-style pivot table as a DataFrame.
@@ -119,6 +120,11 @@ def pivot_table(
119120
120121
.. versionadded:: 1.3.0
121122
123+
**kwargs : dict
124+
Optional keyword arguments to pass to ``aggfunc``.
125+
126+
.. versionadded:: 3.0.0
127+
122128
Returns
123129
-------
124130
DataFrame
@@ -246,6 +252,7 @@ def pivot_table(
246252
margins_name=margins_name,
247253
observed=observed,
248254
sort=sort,
255+
kwargs=kwargs,
249256
)
250257
pieces.append(_table)
251258
keys.append(getattr(func, "__name__", func))
@@ -265,6 +272,7 @@ def pivot_table(
265272
margins_name,
266273
observed,
267274
sort,
275+
kwargs,
268276
)
269277
return table.__finalize__(data, method="pivot_table")
270278

@@ -281,6 +289,7 @@ def __internal_pivot_table(
281289
margins_name: Hashable,
282290
observed: bool,
283291
sort: bool,
292+
kwargs,
284293
) -> DataFrame:
285294
"""
286295
Helper of :func:`pandas.pivot_table` for any non-list ``aggfunc``.
@@ -323,7 +332,7 @@ def __internal_pivot_table(
323332
values = list(values)
324333

325334
grouped = data.groupby(keys, observed=observed, sort=sort, dropna=dropna)
326-
agged = grouped.agg(aggfunc)
335+
agged = grouped.agg(aggfunc, **kwargs)
327336

328337
if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
329338
agged = agged.dropna(how="all")
@@ -378,6 +387,7 @@ def __internal_pivot_table(
378387
rows=index,
379388
cols=columns,
380389
aggfunc=aggfunc,
390+
kwargs=kwargs,
381391
observed=dropna,
382392
margins_name=margins_name,
383393
fill_value=fill_value,
@@ -403,6 +413,7 @@ def _add_margins(
403413
rows,
404414
cols,
405415
aggfunc,
416+
kwargs,
406417
observed: bool,
407418
margins_name: Hashable = "All",
408419
fill_value=None,
@@ -415,7 +426,7 @@ def _add_margins(
415426
if margins_name in table.index.get_level_values(level):
416427
raise ValueError(msg)
417428

418-
grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name)
429+
grand_margin = _compute_grand_margin(data, values, aggfunc, kwargs, margins_name)
419430

420431
if table.ndim == 2:
421432
# i.e. DataFrame
@@ -436,7 +447,15 @@ def _add_margins(
436447

437448
elif values:
438449
marginal_result_set = _generate_marginal_results(
439-
table, data, values, rows, cols, aggfunc, observed, margins_name
450+
table,
451+
data,
452+
values,
453+
rows,
454+
cols,
455+
aggfunc,
456+
kwargs,
457+
observed,
458+
margins_name,
440459
)
441460
if not isinstance(marginal_result_set, tuple):
442461
return marginal_result_set
@@ -445,7 +464,7 @@ def _add_margins(
445464
# no values, and table is a DataFrame
446465
assert isinstance(table, ABCDataFrame)
447466
marginal_result_set = _generate_marginal_results_without_values(
448-
table, data, rows, cols, aggfunc, observed, margins_name
467+
table, data, rows, cols, aggfunc, kwargs, observed, margins_name
449468
)
450469
if not isinstance(marginal_result_set, tuple):
451470
return marginal_result_set
@@ -482,26 +501,26 @@ def _add_margins(
482501

483502

484503
def _compute_grand_margin(
485-
data: DataFrame, values, aggfunc, margins_name: Hashable = "All"
504+
data: DataFrame, values, aggfunc, kwargs, margins_name: Hashable = "All"
486505
):
487506
if values:
488507
grand_margin = {}
489508
for k, v in data[values].items():
490509
try:
491510
if isinstance(aggfunc, str):
492-
grand_margin[k] = getattr(v, aggfunc)()
511+
grand_margin[k] = getattr(v, aggfunc)(**kwargs)
493512
elif isinstance(aggfunc, dict):
494513
if isinstance(aggfunc[k], str):
495-
grand_margin[k] = getattr(v, aggfunc[k])()
514+
grand_margin[k] = getattr(v, aggfunc[k])(**kwargs)
496515
else:
497-
grand_margin[k] = aggfunc[k](v)
516+
grand_margin[k] = aggfunc[k](v, **kwargs)
498517
else:
499-
grand_margin[k] = aggfunc(v)
518+
grand_margin[k] = aggfunc(v, **kwargs)
500519
except TypeError:
501520
pass
502521
return grand_margin
503522
else:
504-
return {margins_name: aggfunc(data.index)}
523+
return {margins_name: aggfunc(data.index, **kwargs)}
505524

506525

507526
def _generate_marginal_results(
@@ -511,6 +530,7 @@ def _generate_marginal_results(
511530
rows,
512531
cols,
513532
aggfunc,
533+
kwargs,
514534
observed: bool,
515535
margins_name: Hashable = "All",
516536
):
@@ -524,7 +544,11 @@ def _all_key(key):
524544
return (key, margins_name) + ("",) * (len(cols) - 1)
525545

526546
if len(rows) > 0:
527-
margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc)
547+
margin = (
548+
data[rows + values]
549+
.groupby(rows, observed=observed)
550+
.agg(aggfunc, **kwargs)
551+
)
528552
cat_axis = 1
529553

530554
for key, piece in table.T.groupby(level=0, observed=observed):
@@ -549,7 +573,7 @@ def _all_key(key):
549573
table_pieces.append(piece)
550574
# GH31016 this is to calculate margin for each group, and assign
551575
# corresponded key as index
552-
transformed_piece = DataFrame(piece.apply(aggfunc)).T
576+
transformed_piece = DataFrame(piece.apply(aggfunc, **kwargs)).T
553577
if isinstance(piece.index, MultiIndex):
554578
# We are adding an empty level
555579
transformed_piece.index = MultiIndex.from_tuples(
@@ -579,7 +603,9 @@ def _all_key(key):
579603
margin_keys = table.columns
580604

581605
if len(cols) > 0:
582-
row_margin = data[cols + values].groupby(cols, observed=observed).agg(aggfunc)
606+
row_margin = (
607+
data[cols + values].groupby(cols, observed=observed).agg(aggfunc, **kwargs)
608+
)
583609
row_margin = row_margin.stack()
584610

585611
# GH#26568. Use names instead of indices in case of numeric names
@@ -598,6 +624,7 @@ def _generate_marginal_results_without_values(
598624
rows,
599625
cols,
600626
aggfunc,
627+
kwargs,
601628
observed: bool,
602629
margins_name: Hashable = "All",
603630
):
@@ -612,14 +639,16 @@ def _all_key():
612639
return (margins_name,) + ("",) * (len(cols) - 1)
613640

614641
if len(rows) > 0:
615-
margin = data.groupby(rows, observed=observed)[rows].apply(aggfunc)
642+
margin = data.groupby(rows, observed=observed)[rows].apply(
643+
aggfunc, **kwargs
644+
)
616645
all_key = _all_key()
617646
table[all_key] = margin
618647
result = table
619648
margin_keys.append(all_key)
620649

621650
else:
622-
margin = data.groupby(level=0, observed=observed).apply(aggfunc)
651+
margin = data.groupby(level=0, observed=observed).apply(aggfunc, **kwargs)
623652
all_key = _all_key()
624653
table[all_key] = margin
625654
result = table
@@ -630,7 +659,9 @@ def _all_key():
630659
margin_keys = table.columns
631660

632661
if len(cols):
633-
row_margin = data.groupby(cols, observed=observed)[cols].apply(aggfunc)
662+
row_margin = data.groupby(cols, observed=observed)[cols].apply(
663+
aggfunc, **kwargs
664+
)
634665
else:
635666
row_margin = Series(np.nan, index=result.columns)
636667

pandas/tests/reshape/test_pivot.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2058,6 +2058,60 @@ def test_pivot_string_as_func(self):
20582058
).rename_axis("A")
20592059
tm.assert_frame_equal(result, expected)
20602060

2061+
@pytest.mark.parametrize("kwargs", [{"a": 2}, {"a": 2, "b": 3}, {"b": 3, "a": 2}])
2062+
def test_pivot_table_kwargs(self, kwargs):
2063+
# GH#57884
2064+
def f(x, a, b=3):
2065+
return x.sum() * a + b
2066+
2067+
def g(x):
2068+
return f(x, **kwargs)
2069+
2070+
df = DataFrame(
2071+
{
2072+
"A": ["good", "bad", "good", "bad", "good"],
2073+
"B": ["one", "two", "one", "three", "two"],
2074+
"X": [2, 5, 4, 20, 10],
2075+
}
2076+
)
2077+
result = pivot_table(
2078+
df, index="A", columns="B", values="X", aggfunc=f, **kwargs
2079+
)
2080+
expected = pivot_table(df, index="A", columns="B", values="X", aggfunc=g)
2081+
tm.assert_frame_equal(result, expected)
2082+
2083+
@pytest.mark.parametrize(
2084+
"kwargs", [{}, {"b": 10}, {"a": 3}, {"a": 3, "b": 10}, {"b": 10, "a": 3}]
2085+
)
2086+
def test_pivot_table_kwargs_margin(self, data, kwargs):
2087+
# GH#57884
2088+
def f(x, a=5, b=7):
2089+
return (x.sum() + b) * a
2090+
2091+
def g(x):
2092+
return f(x, **kwargs)
2093+
2094+
result = data.pivot_table(
2095+
values="D",
2096+
index=["A", "B"],
2097+
columns="C",
2098+
aggfunc=f,
2099+
margins=True,
2100+
fill_value=0,
2101+
**kwargs,
2102+
)
2103+
2104+
expected = data.pivot_table(
2105+
values="D",
2106+
index=["A", "B"],
2107+
columns="C",
2108+
aggfunc=g,
2109+
margins=True,
2110+
fill_value=0,
2111+
)
2112+
2113+
tm.assert_frame_equal(result, expected)
2114+
20612115
@pytest.mark.parametrize(
20622116
"f, f_numpy",
20632117
[

0 commit comments

Comments
 (0)