@@ -61,6 +61,7 @@ def pivot_table(
61
61
margins_name : Hashable = "All" ,
62
62
observed : bool = True ,
63
63
sort : bool = True ,
64
+ ** kwargs ,
64
65
) -> DataFrame :
65
66
"""
66
67
Create a spreadsheet-style pivot table as a DataFrame.
@@ -119,6 +120,11 @@ def pivot_table(
119
120
120
121
.. versionadded:: 1.3.0
121
122
123
+ **kwargs : dict
124
+ Optional keyword arguments to pass to ``aggfunc``.
125
+
126
+ .. versionadded:: 3.0.0
127
+
122
128
Returns
123
129
-------
124
130
DataFrame
@@ -246,6 +252,7 @@ def pivot_table(
246
252
margins_name = margins_name ,
247
253
observed = observed ,
248
254
sort = sort ,
255
+ kwargs = kwargs ,
249
256
)
250
257
pieces .append (_table )
251
258
keys .append (getattr (func , "__name__" , func ))
@@ -265,6 +272,7 @@ def pivot_table(
265
272
margins_name ,
266
273
observed ,
267
274
sort ,
275
+ kwargs ,
268
276
)
269
277
return table .__finalize__ (data , method = "pivot_table" )
270
278
@@ -281,6 +289,7 @@ def __internal_pivot_table(
281
289
margins_name : Hashable ,
282
290
observed : bool ,
283
291
sort : bool ,
292
+ kwargs ,
284
293
) -> DataFrame :
285
294
"""
286
295
Helper of :func:`pandas.pivot_table` for any non-list ``aggfunc``.
@@ -323,7 +332,7 @@ def __internal_pivot_table(
323
332
values = list (values )
324
333
325
334
grouped = data .groupby (keys , observed = observed , sort = sort , dropna = dropna )
326
- agged = grouped .agg (aggfunc )
335
+ agged = grouped .agg (aggfunc , ** kwargs )
327
336
328
337
if dropna and isinstance (agged , ABCDataFrame ) and len (agged .columns ):
329
338
agged = agged .dropna (how = "all" )
@@ -378,6 +387,7 @@ def __internal_pivot_table(
378
387
rows = index ,
379
388
cols = columns ,
380
389
aggfunc = aggfunc ,
390
+ kwargs = kwargs ,
381
391
observed = dropna ,
382
392
margins_name = margins_name ,
383
393
fill_value = fill_value ,
@@ -403,6 +413,7 @@ def _add_margins(
403
413
rows ,
404
414
cols ,
405
415
aggfunc ,
416
+ kwargs ,
406
417
observed : bool ,
407
418
margins_name : Hashable = "All" ,
408
419
fill_value = None ,
@@ -415,7 +426,7 @@ def _add_margins(
415
426
if margins_name in table .index .get_level_values (level ):
416
427
raise ValueError (msg )
417
428
418
- grand_margin = _compute_grand_margin (data , values , aggfunc , margins_name )
429
+ grand_margin = _compute_grand_margin (data , values , aggfunc , kwargs , margins_name )
419
430
420
431
if table .ndim == 2 :
421
432
# i.e. DataFrame
@@ -436,7 +447,15 @@ def _add_margins(
436
447
437
448
elif values :
438
449
marginal_result_set = _generate_marginal_results (
439
- table , data , values , rows , cols , aggfunc , observed , margins_name
450
+ table ,
451
+ data ,
452
+ values ,
453
+ rows ,
454
+ cols ,
455
+ aggfunc ,
456
+ kwargs ,
457
+ observed ,
458
+ margins_name ,
440
459
)
441
460
if not isinstance (marginal_result_set , tuple ):
442
461
return marginal_result_set
@@ -445,7 +464,7 @@ def _add_margins(
445
464
# no values, and table is a DataFrame
446
465
assert isinstance (table , ABCDataFrame )
447
466
marginal_result_set = _generate_marginal_results_without_values (
448
- table , data , rows , cols , aggfunc , observed , margins_name
467
+ table , data , rows , cols , aggfunc , kwargs , observed , margins_name
449
468
)
450
469
if not isinstance (marginal_result_set , tuple ):
451
470
return marginal_result_set
@@ -482,26 +501,26 @@ def _add_margins(
482
501
483
502
484
503
def _compute_grand_margin (
485
- data : DataFrame , values , aggfunc , margins_name : Hashable = "All"
504
+ data : DataFrame , values , aggfunc , kwargs , margins_name : Hashable = "All"
486
505
):
487
506
if values :
488
507
grand_margin = {}
489
508
for k , v in data [values ].items ():
490
509
try :
491
510
if isinstance (aggfunc , str ):
492
- grand_margin [k ] = getattr (v , aggfunc )()
511
+ grand_margin [k ] = getattr (v , aggfunc )(** kwargs )
493
512
elif isinstance (aggfunc , dict ):
494
513
if isinstance (aggfunc [k ], str ):
495
- grand_margin [k ] = getattr (v , aggfunc [k ])()
514
+ grand_margin [k ] = getattr (v , aggfunc [k ])(** kwargs )
496
515
else :
497
- grand_margin [k ] = aggfunc [k ](v )
516
+ grand_margin [k ] = aggfunc [k ](v , ** kwargs )
498
517
else :
499
- grand_margin [k ] = aggfunc (v )
518
+ grand_margin [k ] = aggfunc (v , ** kwargs )
500
519
except TypeError :
501
520
pass
502
521
return grand_margin
503
522
else :
504
- return {margins_name : aggfunc (data .index )}
523
+ return {margins_name : aggfunc (data .index , ** kwargs )}
505
524
506
525
507
526
def _generate_marginal_results (
@@ -511,6 +530,7 @@ def _generate_marginal_results(
511
530
rows ,
512
531
cols ,
513
532
aggfunc ,
533
+ kwargs ,
514
534
observed : bool ,
515
535
margins_name : Hashable = "All" ,
516
536
):
@@ -524,7 +544,11 @@ def _all_key(key):
524
544
return (key , margins_name ) + ("" ,) * (len (cols ) - 1 )
525
545
526
546
if len (rows ) > 0 :
527
- margin = data [rows + values ].groupby (rows , observed = observed ).agg (aggfunc )
547
+ margin = (
548
+ data [rows + values ]
549
+ .groupby (rows , observed = observed )
550
+ .agg (aggfunc , ** kwargs )
551
+ )
528
552
cat_axis = 1
529
553
530
554
for key , piece in table .T .groupby (level = 0 , observed = observed ):
@@ -549,7 +573,7 @@ def _all_key(key):
549
573
table_pieces .append (piece )
550
574
# GH31016 this is to calculate margin for each group, and assign
551
575
# corresponded key as index
552
- transformed_piece = DataFrame (piece .apply (aggfunc )).T
576
+ transformed_piece = DataFrame (piece .apply (aggfunc , ** kwargs )).T
553
577
if isinstance (piece .index , MultiIndex ):
554
578
# We are adding an empty level
555
579
transformed_piece .index = MultiIndex .from_tuples (
@@ -579,7 +603,9 @@ def _all_key(key):
579
603
margin_keys = table .columns
580
604
581
605
if len (cols ) > 0 :
582
- row_margin = data [cols + values ].groupby (cols , observed = observed ).agg (aggfunc )
606
+ row_margin = (
607
+ data [cols + values ].groupby (cols , observed = observed ).agg (aggfunc , ** kwargs )
608
+ )
583
609
row_margin = row_margin .stack ()
584
610
585
611
# GH#26568. Use names instead of indices in case of numeric names
@@ -598,6 +624,7 @@ def _generate_marginal_results_without_values(
598
624
rows ,
599
625
cols ,
600
626
aggfunc ,
627
+ kwargs ,
601
628
observed : bool ,
602
629
margins_name : Hashable = "All" ,
603
630
):
@@ -612,14 +639,16 @@ def _all_key():
612
639
return (margins_name ,) + ("" ,) * (len (cols ) - 1 )
613
640
614
641
if len (rows ) > 0 :
615
- margin = data .groupby (rows , observed = observed )[rows ].apply (aggfunc )
642
+ margin = data .groupby (rows , observed = observed )[rows ].apply (
643
+ aggfunc , ** kwargs
644
+ )
616
645
all_key = _all_key ()
617
646
table [all_key ] = margin
618
647
result = table
619
648
margin_keys .append (all_key )
620
649
621
650
else :
622
- margin = data .groupby (level = 0 , observed = observed ).apply (aggfunc )
651
+ margin = data .groupby (level = 0 , observed = observed ).apply (aggfunc , ** kwargs )
623
652
all_key = _all_key ()
624
653
table [all_key ] = margin
625
654
result = table
@@ -630,7 +659,9 @@ def _all_key():
630
659
margin_keys = table .columns
631
660
632
661
if len (cols ):
633
- row_margin = data .groupby (cols , observed = observed )[cols ].apply (aggfunc )
662
+ row_margin = data .groupby (cols , observed = observed )[cols ].apply (
663
+ aggfunc , ** kwargs
664
+ )
634
665
else :
635
666
row_margin = Series (np .nan , index = result .columns )
636
667
0 commit comments