Skip to content

Commit 752ae2c

Browse files
committed
change the lexsort logic, first by mask/~mask then by values to ensure
to keep nans at front or at the end
1 parent c19970b commit 752ae2c

File tree

4 files changed

+91
-51
lines changed

4 files changed

+91
-51
lines changed

pandas/_libs/algos.pyx

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,17 @@ class Infinity(object):
6868
__le__ = lambda self, other: isinstance(other, Infinity)
6969
__eq__ = lambda self, other: isinstance(other, Infinity)
7070
__ne__ = lambda self, other: not isinstance(other, Infinity)
71-
__gt__ = lambda self, other: not isinstance(other, Infinity)
72-
__ge__ = lambda self, other: True
71+
__gt__ = lambda self, other: (not isinstance(other, Infinity) and
72+
not lib.checknull(other))
73+
__ge__ = lambda self, other: not lib.checknull(other)
7374

7475

7576
class NegInfinity(object):
7677
""" provide a negative Infinity comparision method for ranking """
7778

78-
__lt__ = lambda self, other: not isinstance(other, NegInfinity)
79-
__le__ = lambda self, other: True
79+
__lt__ = lambda self, other: (not isinstance(other, NegInfinity) and
80+
not lib.checknull(other))
81+
__le__ = lambda self, other: not lib.checknull(other)
8082
__eq__ = lambda self, other: isinstance(other, NegInfinity)
8183
__ne__ = lambda self, other: not isinstance(other, NegInfinity)
8284
__gt__ = lambda self, other: False

pandas/_libs/algos_rank_helper.pxi.in

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ dtypes = [('object', 'object', 'Infinity()', 'NegInfinity()'),
2727
{{if dtype == 'object'}}
2828

2929

30-
def rank_1d_{{dtype}}(object in_arr, bint retry=1, ties_method='average',
30+
def rank_1d_{{dtype}}(object in_arr, ties_method='average',
3131
ascending=True, na_option='keep', pct=False):
3232
{{else}}
3333

@@ -40,7 +40,7 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
4040
"""
4141

4242
cdef:
43-
Py_ssize_t i, j, n, dups = 0, total_tie_count = 0
43+
Py_ssize_t i, j, n, dups = 0, total_tie_count = 0, non_na_idx = 0
4444

4545
{{if dtype == 'object'}}
4646
ndarray sorted_data, values
@@ -78,12 +78,6 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
7878

7979
keep_na = na_option == 'keep'
8080

81-
{{if dtype != 'uint64'}}
82-
if ascending ^ (na_option == 'top'):
83-
nan_value = {{pos_nan_value}}
84-
else:
85-
nan_value = {{neg_nan_value}}
86-
8781
{{if dtype == 'object'}}
8882
mask = lib.isnaobj(values)
8983
{{elif dtype == 'float64'}}
@@ -92,49 +86,52 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
9286
mask = values == iNaT
9387
{{endif}}
9488

89+
# doulbe sort first by mask and then by values to ensure nan values are
90+
# either at the beginning or the end. mask/(~mask) controls padding at
91+
# tail or the head
92+
{{if dtype != 'uint64'}}
93+
if ascending ^ (na_option == 'top'):
94+
nan_value = {{pos_nan_value}}
95+
order = (values, mask)
96+
else:
97+
nan_value = {{neg_nan_value}}
98+
order = (values, ~mask)
9599
np.putmask(values, mask, nan_value)
96100
{{else}}
97101
mask = np.zeros(shape=len(values), dtype=bool)
102+
order = (values, mask)
98103
{{endif}}
99104

100105
n = len(values)
101106
ranks = np.empty(n, dtype='f8')
102107

103108
{{if dtype == 'object'}}
109+
104110
try:
105-
# lexsort on object array will raise TypeError for numpy version
106-
# earlier than 1.11.0
107-
_as = np.lexsort(keys=(mask, values))
111+
_as = np.lexsort(keys=order)
108112
except TypeError:
109-
try:
110-
_dt = [('values', 'O'), ('mask', '?')]
111-
_values = np.asarray(list(zip(values, mask)), dtype=_dt)
112-
_as = np.argsort(_values, kind='mergesort', order=('values', 'mask'))
113-
except TypeError:
114-
if not retry:
115-
raise
116-
valid_locs = (~mask).nonzero()[0]
117-
ranks.put(valid_locs, rank_1d_object(values.take(valid_locs), 0,
118-
ties_method=ties_method,
119-
ascending=ascending))
120-
np.putmask(ranks, mask, np.nan)
121-
return ranks
113+
# lexsort on object array will raise TypeError for numpy version
114+
# earlier than 1.11.0. Use argsort with order argument instead.
115+
_dt = [('values', 'O'), ('mask', '?')]
116+
_values = np.asarray(list(zip(order[0], order[1])), dtype=_dt)
117+
_as = np.argsort(_values, kind='mergesort', order=('mask', 'values'))
122118
{{else}}
123119
if tiebreak == TIEBREAK_FIRST:
124120
# need to use a stable sort here
125-
_as = np.lexsort(keys=(mask, values))
121+
_as = np.lexsort(keys=order)
126122
if not ascending:
127123
tiebreak = TIEBREAK_FIRST_DESCENDING
128124
else:
129-
_as = np.lexsort(keys=(mask, values))
125+
_as = np.lexsort(keys=order)
130126
{{endif}}
131127

132128
if not ascending:
133129
_as = _as[::-1]
134130

135131
sorted_data = values.take(_as)
136-
# need to distinguish between pos/neg nan and real nan when keep_na is true
137132
sorted_mask = mask.take(_as)
133+
_indices = order[1].take(_as).nonzero()[0]
134+
non_na_idx = _indices[0] if len(_indices) > 0 else -1
138135
argsorted = _as.astype('i8')
139136

140137
{{if dtype == 'object'}}
@@ -147,12 +144,11 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
147144
if isnan and keep_na:
148145
ranks[argsorted[i]] = nan
149146
continue
150-
151147
count += 1.0
152148

153149
if (i == n - 1 or
154150
are_diff(util.get_value_at(sorted_data, i + 1), val) or
155-
sorted_mask[i + 1]):
151+
i == non_na_idx - 1):
156152
if tiebreak == TIEBREAK_AVERAGE:
157153
for j in range(i - dups + 1, i + 1):
158154
ranks[argsorted[j]] = sum_ranks / dups
@@ -177,7 +173,6 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
177173
for i in range(n):
178174
sum_ranks += i + 1
179175
dups += 1
180-
181176
val = sorted_data[i]
182177

183178
{{if dtype != 'uint64'}}
@@ -189,7 +184,8 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
189184

190185
count += 1.0
191186

192-
if i == n - 1 or sorted_data[i + 1] != val or sorted_mask[i + 1]:
187+
if (i == n - 1 or sorted_data[i + 1] != val or
188+
i == non_na_idx - 1):
193189
if tiebreak == TIEBREAK_AVERAGE:
194190
for j in range(i - dups + 1, i + 1):
195191
ranks[argsorted[j]] = sum_ranks / dups

pandas/tests/series/test_rank.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,18 @@ def test_rank_inf(self, contents, dtype):
224224
'int64': iNaT,
225225
'object': None
226226
}
227+
# Insert nans at random positions if underlying dtype has missing
228+
# value. Then adjust the expected order by adding nans accordingly
229+
# This is for testing whether rank calculation is affected
230+
# when values are interwined with nan values.
227231
values = np.array(contents, dtype=dtype)
228-
nan_indices = np.random.choice(range(len(values)), 3)
229232
exp_order = np.array(range(len(values)), dtype='float64') + 1.0
230233
if dtype in dtype_na_map:
231234
na_value = dtype_na_map[dtype]
235+
nan_indices = np.random.choice(range(len(values)), 5)
232236
values = np.insert(values, nan_indices, na_value)
233237
exp_order = np.insert(exp_order, nan_indices, np.nan)
238+
# shuffle the testing array and expected results in the same way
234239
random_order = np.random.permutation(len(values))
235240
iseries = Series(values[random_order])
236241
exp = Series(exp_order[random_order], dtype='float64')
@@ -254,6 +259,39 @@ def _check(s, expected, method='average'):
254259
series = s if dtype is None else s.astype(dtype)
255260
_check(series, results[method], method=method)
256261

262+
def test_rank_tie_methods_on_infs_nans(self):
263+
dtypes = [('object', None, Infinity(), NegInfinity()),
264+
('float64', np.nan, np.inf, -np.inf)]
265+
chunk = 3
266+
disabled = set([('object', 'first')])
267+
268+
def _check(s, expected, method='average', na_option='keep'):
269+
result = s.rank(method=method, na_option=na_option)
270+
tm.assert_series_equal(result, Series(expected, dtype='float64'))
271+
272+
exp_ranks = {
273+
'average': ([2, 2, 2], [5, 5, 5], [8, 8, 8]),
274+
'min': ([1, 1, 1], [4, 4, 4], [7, 7, 7]),
275+
'max': ([3, 3, 3], [6, 6, 6], [9, 9, 9]),
276+
'first': ([1, 2, 3], [4, 5, 6], [7, 8, 9]),
277+
'dense': ([1, 1, 1], [2, 2, 2], [3, 3, 3])
278+
}
279+
na_options = ('top', 'bottom', 'keep')
280+
for dtype, na_value, pos_inf, neg_inf in dtypes:
281+
in_arr = [neg_inf] * chunk + [na_value] * chunk + [pos_inf] * chunk
282+
iseries = Series(in_arr, dtype=dtype)
283+
for method, na_opt in product(exp_ranks.keys(), na_options):
284+
ranks = exp_ranks[method]
285+
if (dtype, method) in disabled:
286+
continue
287+
if na_opt == 'top':
288+
order = ranks[1] + ranks[0] + ranks[2]
289+
elif na_opt == 'bottom':
290+
order = ranks[0] + ranks[2] + ranks[1]
291+
else:
292+
order = ranks[0] + [np.nan] * chunk + ranks[1]
293+
_check(iseries, order, method, na_opt)
294+
257295
def test_rank_methods_series(self):
258296
pytest.importorskip('scipy.stats.special')
259297
rankdata = pytest.importorskip('scipy.stats.rankdata')

pandas/tests/test_algos.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,44 +1271,48 @@ def test_infinity_sort():
12711271
NegInf = libalgos.NegInfinity()
12721272

12731273
ref_nums = [NegInf, float("-inf"), -1e100, 0, 1e100, float("inf"), Inf]
1274-
ref_objects = [NegInf, 'A', 'Ac', 'B', 'C', 'D', 'a', 'z', Inf]
12751274

12761275
assert all(Inf >= x for x in ref_nums)
12771276
assert all(Inf > x or x is Inf for x in ref_nums)
12781277
assert Inf >= Inf and Inf == Inf
1279-
assert Inf >= libalgos.Infinity() and Inf == libalgos.Infinity()
12801278
assert not Inf < Inf and not Inf > Inf
1281-
assert not Inf < libalgos.Infinity() and not Inf > libalgos.Infinity()
12821279
assert libalgos.Infinity() == libalgos.Infinity()
12831280
assert not libalgos.Infinity() != libalgos.Infinity()
12841281

12851282
assert all(NegInf <= x for x in ref_nums)
12861283
assert all(NegInf < x or x is NegInf for x in ref_nums)
12871284
assert NegInf <= NegInf and NegInf == NegInf
1288-
assert (NegInf <= libalgos.NegInfinity() and
1289-
NegInf == libalgos.NegInfinity())
12901285
assert not NegInf < NegInf and not NegInf > NegInf
1291-
assert (not NegInf < libalgos.NegInfinity() and not
1292-
NegInf > libalgos.NegInfinity())
12931286
assert libalgos.NegInfinity() == libalgos.NegInfinity()
12941287
assert not libalgos.NegInfinity() != libalgos.NegInfinity()
12951288

1296-
assert Inf > np.nan and Inf >= np.nan and Inf != np.nan
1297-
assert not Inf < np.nan and not Inf <= np.nan
1298-
assert NegInf < np.nan and NegInf <= np.nan and NegInf != np.nan
1299-
assert (not NegInf > np.nan and not NegInf >= np.nan and not
1300-
NegInf == np.nan)
1301-
13021289
for perm in permutations(ref_nums):
13031290
assert sorted(perm) == ref_nums
1304-
for perm in permutations(ref_objects):
1305-
assert sorted(perm) == ref_objects
13061291

13071292
# smoke tests
13081293
np.array([libalgos.Infinity()] * 32).argsort()
13091294
np.array([libalgos.NegInfinity()] * 32).argsort()
13101295

13111296

1297+
def test_infinity_against_nan():
1298+
Inf = libalgos.Infinity()
1299+
NegInf = libalgos.NegInfinity()
1300+
1301+
assert not Inf > np.nan
1302+
assert not Inf >= np.nan
1303+
assert not Inf < np.nan
1304+
assert not Inf <= np.nan
1305+
assert not Inf == np.nan
1306+
assert Inf != np.nan
1307+
1308+
assert not NegInf > np.nan
1309+
assert not NegInf >= np.nan
1310+
assert not NegInf < np.nan
1311+
assert not NegInf <= np.nan
1312+
assert not NegInf == np.nan
1313+
assert NegInf != np.nan
1314+
1315+
13121316
def test_ensure_platform_int():
13131317
arr = np.arange(100, dtype=np.intp)
13141318

0 commit comments

Comments
 (0)