Skip to content

Commit b4c3dfd

Browse files
committed
Working avg tiebreak with nan handling
1 parent 5b295a2 commit b4c3dfd

File tree

4 files changed

+34
-17
lines changed

4 files changed

+34
-17
lines changed

pandas/_libs/algos.pxd

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,11 @@ cdef inline Py_ssize_t swap(numeric *a, numeric *b) nogil:
1111
a[0] = b[0]
1212
b[0] = t
1313
return 0
14+
15+
cdef:
16+
int TIEBREAK_AVERAGE = 0
17+
int TIEBREAK_MIN = 1
18+
int TIEBREAK_MAX = 2
19+
int TIEBREAK_FIRST = 3
20+
int TIEBREAK_FIRST_DESCENDING = 4
21+
int TIEBREAK_DENSE = 5

pandas/_libs/groupby.pyx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ from numpy cimport (ndarray,
1313
int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t,
1414
uint32_t, uint64_t, float32_t, float64_t)
1515

16+
from libc.math cimport isnan
1617
from libc.stdlib cimport malloc, free
1718

1819
from util cimport numeric, get_nat
19-
from algos cimport swap
20-
from algos import take_2d_axis1_float64_float64, groupsort_indexer
20+
from algos cimport (swap, TIEBREAK_AVERAGE, TIEBREAK_MIN, TIEBREAK_MAX,
21+
TIEBREAK_FIRST, TIEBREAK_DENSE)
22+
from algos import take_2d_axis1_float64_float64, groupsort_indexer, tiebreakers
2123

2224
cdef int64_t iNaT = get_nat()
2325

pandas/_libs/groupby_helper.pxi.in

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,6 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
444444
else:
445445
out[i, j] = resx[i, j]
446446

447-
448447
@cython.boundscheck(False)
449448
@cython.wraparound(False)
450449
def group_rank_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
@@ -455,27 +454,35 @@ def group_rank_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
455454
Only transforms on axis=0
456455
"""
457456
cdef:
457+
int tiebreak
458458
Py_ssize_t i, j, N, K
459-
int64_t lab, idx, counter=1
459+
int64_t val_start=0, grp_start=0, dups=0, sum_ranks=0
460460
ndarray[int64_t] _as
461461

462+
tiebreak = tiebreakers[kwargs['ties_method']]
462463
N, K = (<object> values).shape
463464

464465
_as = np.lexsort((values[:, 0], labels))
465466

466467
with nogil:
467468
for i in range(N):
468-
idx = _as[i]
469-
lab = labels[idx]
470-
if i > 0 and lab == labels[_as[i-1]]:
471-
counter += 1
472-
else:
473-
counter = 1
474-
if lab < 0:
475-
continue
476-
477-
for j in range(K):
478-
out[idx, j] = counter
469+
dups += 1
470+
sum_ranks += i - grp_start + 1
471+
472+
if tiebreak == TIEBREAK_AVERAGE:
473+
for j in range(i - dups + 1, i + 1):
474+
out[_as[j], 0] = sum_ranks / dups
475+
476+
if (i == N - 1 or (
477+
(values[_as[i], 0] != values[_as[i+1], 0]) and not
478+
(isnan(values[_as[i], 0]) and
479+
isnan(values[_as[i+1], 0])
480+
))):
481+
dups = sum_ranks = 0
482+
val_start = i
483+
484+
if i == 0 or labels[_as[i]] != labels[_as[i-1]]:
485+
grp_start = i
479486

480487
{{endfor}}
481488

pandas/core/groupby.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,10 +1770,10 @@ def cumcount(self, ascending=True):
17701770

17711771
@Substitution(name='groupby')
17721772
@Appender(_doc_template)
1773-
def rank(self, ties_method='average', ascending=True, na_option='keep',
1773+
def rank(self, method='average', ascending=True, na_option='keep',
17741774
pct=False, axis=0):
17751775
"""Rank within each group"""
1776-
return self._cython_transform('rank', ties_method=ties_method,
1776+
return self._cython_transform('rank', ties_method=method,
17771777
ascending=ascending, na_option=na_option,
17781778
pct=pct, axis=axis)
17791779

0 commit comments

Comments
 (0)