Skip to content

Commit ae37428

Browse files
committed
fix lexsort not working on object dtype for numpy version earlier than 1.11.0
add comment for using argsort with order argument instead of lexsort
1 parent 489f5ea commit ae37428

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

pandas/_libs/algos_rank_helper.pxi.in

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,24 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
4444

4545
{{if dtype == 'object'}}
4646
ndarray sorted_data, values
47-
ndarray[np.uint8_t, cast=True] sorted_namask
48-
{{elif dtype != 'uint64'}}
49-
ndarray[{{ctype}}] sorted_data, values
50-
ndarray[np.uint8_t, cast=True] sorted_namask
5147
{{else}}
5248
ndarray[{{ctype}}] sorted_data, values
5349
{{endif}}
5450

5551
ndarray[float64_t] ranks
5652
ndarray[int64_t] argsorted
53+
ndarray[np.uint8_t, cast=True] sorted_mask
5754

5855
{{if dtype == 'uint64'}}
5956
{{ctype}} val
6057
{{else}}
61-
{{ctype}} val, nan_value, isnan
58+
{{ctype}} val, nan_value
6259
{{endif}}
6360

6461
float64_t sum_ranks = 0
6562
int tiebreak = 0
6663
bint keep_na = 0
64+
bint isnan
6765
float count = 0.0
6866
tiebreak = tiebreakers[ties_method]
6967

@@ -95,14 +93,24 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
9593
{{endif}}
9694

9795
np.putmask(values, mask, nan_value)
96+
{{else}}
97+
mask = np.zeros(shape=len(values), dtype=bool)
9898
{{endif}}
9999

100100
n = len(values)
101101
ranks = np.empty(n, dtype='f8')
102102

103103
{{if dtype == 'object'}}
104104
try:
105-
_as = values.argsort()
105+
# lexsort on object array will raise TypeError for numpy version
106+
# earlier than 1.11.0
107+
_dt = [('values', 'object'), ('mask', 'bool')]
108+
_values = np.asarray(list(zip(mask, values)), dtype=_dt)
109+
if ascending ^ (na_option == 'top'):
110+
_values = np.asarray(list(zip(values, mask)), dtype=_dt)
111+
else:
112+
_values = np.asarray(list(zip(values, (~mask))), dtype=_dt)
113+
_as = np.argsort(_values, kind='mergesort', order=('mask', 'values'))
106114
except TypeError:
107115
if not retry:
108116
raise
@@ -116,40 +124,37 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
116124
{{else}}
117125
if tiebreak == TIEBREAK_FIRST:
118126
# need to use a stable sort here
119-
_as = values.argsort(kind='mergesort')
127+
_as = np.lexsort(keys=(mask, values))
120128
if not ascending:
121129
tiebreak = TIEBREAK_FIRST_DESCENDING
122130
else:
123-
_as = values.argsort()
131+
_as = np.lexsort(keys=(mask, values))
124132
{{endif}}
125133

126134
if not ascending:
127135
_as = _as[::-1]
128136

129137
sorted_data = values.take(_as)
130138
# need to distinguish between pos/neg nan and real nan when keep_na is true
131-
{{if dtype != 'uint64'}}
132-
sorted_namask = mask.take(_as)
133-
sorted_namask = sorted_namask.astype(np.bool)
134-
{{endif}}
139+
sorted_mask = mask.take(_as)
135140
argsorted = _as.astype('i8')
136141

137142
{{if dtype == 'object'}}
138143
for i in range(n):
139144
sum_ranks += i + 1
140145
dups += 1
141-
isnan = sorted_namask[i]
146+
isnan = sorted_mask[i]
142147
val = util.get_value_at(sorted_data, i)
143148

144149
if isnan and keep_na:
145150
ranks[argsorted[i]] = nan
146-
sum_ranks = dups = 0
147151
continue
148152

149153
count += 1.0
150154

151155
if (i == n - 1 or
152-
are_diff(util.get_value_at(sorted_data, i + 1), val)):
156+
are_diff(util.get_value_at(sorted_data, i + 1), val) or
157+
sorted_mask[i + 1]):
153158
if tiebreak == TIEBREAK_AVERAGE:
154159
for j in range(i - dups + 1, i + 1):
155160
ranks[argsorted[j]] = sum_ranks / dups
@@ -178,16 +183,15 @@ def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
178183
val = sorted_data[i]
179184

180185
{{if dtype != 'uint64'}}
181-
isnan = sorted_namask[i]
182-
if isnan and keep_na:
186+
isnan = sorted_mask[i]
187+
if isnan and keep_na:
183188
ranks[argsorted[i]] = nan
184-
sum_ranks = dups = 0
185189
continue
186190
{{endif}}
187191

188192
count += 1.0
189193

190-
if i == n - 1 or sorted_data[i + 1] != val:
194+
if i == n - 1 or sorted_data[i + 1] != val or sorted_mask[i + 1]:
191195
if tiebreak == TIEBREAK_AVERAGE:
192196
for j in range(i - dups + 1, i + 1):
193197
ranks[argsorted[j]] = sum_ranks / dups

0 commit comments

Comments
 (0)