Skip to content

Commit 67ca634

Browse files
committed
Allowed kwargs to pass through to Cython func
1 parent a8d9b00 commit 67ca634

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

pandas/_libs/groupby_helper.pxi.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
450450
def group_rank_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
451451
ndarray[{{c_type}}, ndim=2] values,
452452
ndarray[int64_t] labels,
453-
bint is_datetimelike):
453+
bint is_datetimelike, **kwargs):
454454
"""
455455
Only transforms on axis=0
456456
"""

pandas/core/groupby.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -994,15 +994,15 @@ def _transform_should_cast(self, func_nm):
994994
return (self.size().fillna(0) > 0).any() and (func_nm not in
995995
_cython_cast_blacklist)
996996

997-
def _cython_transform(self, how, numeric_only=True):
997+
def _cython_transform(self, how, numeric_only=True, **kwargs):
998998
output = collections.OrderedDict()
999999
for name, obj in self._iterate_slices():
10001000
is_numeric = is_numeric_dtype(obj.dtype)
10011001
if numeric_only and not is_numeric:
10021002
continue
10031003

10041004
try:
1005-
result, names = self.grouper.transform(obj.values, how)
1005+
result, names = self.grouper.transform(obj.values, how, **kwargs)
10061006
except NotImplementedError:
10071007
continue
10081008
except AssertionError as e:
@@ -1770,9 +1770,12 @@ def cumcount(self, ascending=True):
17701770

17711771
@Substitution(name='groupby')
17721772
@Appender(_doc_template)
1773-
def rank(self, axis=0, *args, **kwargs):
1773+
def rank(self, ties_method='average', ascending=True, na_option='keep',
1774+
pct=False, axis=0):
17741775
"""Rank within each group"""
1775-
return self._cython_transform('rank', **kwargs)
1776+
return self._cython_transform('rank', ties_method=ties_method,
1777+
ascending=ascending, na_option=na_option,
1778+
pct=pct, axis=axis)
17761779

17771780
@Substitution(name='groupby')
17781781
@Appender(_doc_template)
@@ -2249,7 +2252,8 @@ def wrapper(*args, **kwargs):
22492252
(how, dtype_str))
22502253
return func, dtype_str
22512254

2252-
def _cython_operation(self, kind, values, how, axis, min_count=-1):
2255+
def _cython_operation(self, kind, values, how, axis, min_count=-1,
2256+
**kwargs):
22532257
assert kind in ['transform', 'aggregate']
22542258

22552259
# can we do this operation with our cython functions
@@ -2341,7 +2345,8 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1):
23412345

23422346
# TODO: min_count
23432347
result = self._transform(
2344-
result, values, labels, func, is_numeric, is_datetimelike)
2348+
result, values, labels, func, is_numeric, is_datetimelike,
2349+
**kwargs)
23452350

23462351
if is_integer_dtype(result):
23472352
mask = result == iNaT
@@ -2380,8 +2385,8 @@ def aggregate(self, values, how, axis=0, min_count=-1):
23802385
return self._cython_operation('aggregate', values, how, axis,
23812386
min_count=min_count)
23822387

2383-
def transform(self, values, how, axis=0):
2384-
return self._cython_operation('transform', values, how, axis)
2388+
def transform(self, values, how, axis=0, **kwargs):
2389+
return self._cython_operation('transform', values, how, axis, **kwargs)
23852390

23862391
def _aggregate(self, result, counts, values, comp_ids, agg_func,
23872392
is_numeric, is_datetimelike, min_count=-1):
@@ -2401,7 +2406,7 @@ def _aggregate(self, result, counts, values, comp_ids, agg_func,
24012406
return result
24022407

24032408
def _transform(self, result, values, comp_ids, transform_func,
2404-
is_numeric, is_datetimelike):
2409+
is_numeric, is_datetimelike, **kwargs):
24052410

24062411
comp_ids, _, ngroups = self.group_info
24072412
if values.ndim > 3:
@@ -2415,7 +2420,7 @@ def _transform(self, result, values, comp_ids, transform_func,
24152420
transform_func(result[:, :, i], values,
24162421
comp_ids, is_datetimelike)
24172422
else:
2418-
transform_func(result, values, comp_ids, is_datetimelike)
2423+
transform_func(result, values, comp_ids, is_datetimelike, **kwargs)
24192424

24202425
return result
24212426

0 commit comments

Comments
 (0)