Skip to content

Commit 609c3b7

Browse files
authored
PERF: lexsort_indexer (MultiIndex / multi-column sorting) (#54835)
* lexsort_indexer perf * whatsnew * mypy * mypy * use generator * mypy
1 parent d60c7d2 commit 609c3b7

File tree

4 files changed

+37
-98
lines changed

4 files changed

+37
-98
lines changed

asv_bench/benchmarks/frame_methods.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -693,20 +693,24 @@ def time_frame_sort_values(self, ascending):
693693
self.df.sort_values(by="A", ascending=ascending)
694694

695695

696-
class SortIndexByColumns:
696+
class SortMultiKey:
697697
def setup(self):
698698
N = 10000
699699
K = 10
700-
self.df = DataFrame(
700+
self.df_by_columns = DataFrame(
701701
{
702702
"key1": tm.makeStringIndex(N).values.repeat(K),
703703
"key2": tm.makeStringIndex(N).values.repeat(K),
704704
"value": np.random.randn(N * K),
705705
}
706706
)
707+
self.df_by_index = self.df_by_columns.set_index(["key1", "key2"])
708+
709+
def time_sort_values(self):
710+
self.df_by_columns.sort_values(by=["key1", "key2"])
707711

708-
def time_frame_sort_values_by_columns(self):
709-
self.df.sort_values(by=["key1", "key2"])
712+
def time_sort_index(self):
713+
self.df_by_index.sort_index()
710714

711715

712716
class Quantile:

doc/source/whatsnew/v2.2.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ Deprecations
157157

158158
Performance improvements
159159
~~~~~~~~~~~~~~~~~~~~~~~~
160+
- Performance improvement in :meth:`DataFrame.sort_index` and :meth:`Series.sort_index` when indexed by a :class:`MultiIndex` (:issue:`54835`)
160161
- Performance improvement when indexing with more than 4 keys (:issue:`54550`)
161162
-
162163

pandas/core/indexes/multi.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,17 +2208,9 @@ def append(self, other):
22082208
def argsort(
22092209
self, *args, na_position: str = "last", **kwargs
22102210
) -> npt.NDArray[np.intp]:
2211-
if len(args) == 0 and len(kwargs) == 0:
2212-
# lexsort is significantly faster than self._values.argsort()
2213-
target = self._sort_levels_monotonic(raise_if_incomparable=True)
2214-
return lexsort_indexer(
2215-
# error: Argument 1 to "lexsort_indexer" has incompatible type
2216-
# "List[Categorical]"; expected "Union[List[Union[ExtensionArray,
2217-
# ndarray[Any, Any]]], List[Series]]"
2218-
target._get_codes_for_sorting(), # type: ignore[arg-type]
2219-
na_position=na_position,
2220-
)
2221-
return self._values.argsort(*args, **kwargs)
2211+
target = self._sort_levels_monotonic(raise_if_incomparable=True)
2212+
keys = [lev.codes for lev in target._get_codes_for_sorting()]
2213+
return lexsort_indexer(keys, na_position=na_position, codes_given=True)
22222214

22232215
@Appender(_index_shared_docs["repeat"] % _index_doc_kwargs)
22242216
def repeat(self, repeats: int, axis=None) -> MultiIndex:

pandas/core/sorting.py

Lines changed: 25 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,9 @@ def get_indexer_indexer(
9999
na_position=na_position,
100100
)
101101
elif isinstance(target, ABCMultiIndex):
102+
codes = [lev.codes for lev in target._get_codes_for_sorting()]
102103
indexer = lexsort_indexer(
103-
target.codes, orders=ascending, na_position=na_position, codes_given=True
104+
codes, orders=ascending, na_position=na_position, codes_given=True
104105
)
105106
else:
106107
# Check monotonic-ness before sort an index (GH 11080)
@@ -298,22 +299,8 @@ def decons_obs_group_ids(
298299
return [lab[indexer].astype(np.intp, subok=False, copy=True) for lab in labels]
299300

300301

301-
def indexer_from_factorized(
302-
labels, shape: Shape, compress: bool = True
303-
) -> npt.NDArray[np.intp]:
304-
ids = get_group_index(labels, shape, sort=True, xnull=False)
305-
306-
if not compress:
307-
ngroups = (ids.size and ids.max()) + 1
308-
else:
309-
ids, obs = compress_group_index(ids, sort=True)
310-
ngroups = len(obs)
311-
312-
return get_group_index_sorter(ids, ngroups)
313-
314-
315302
def lexsort_indexer(
316-
keys: list[ArrayLike] | list[Series],
303+
keys: Sequence[ArrayLike | Index | Series],
317304
orders=None,
318305
na_position: str = "last",
319306
key: Callable | None = None,
@@ -324,9 +311,9 @@ def lexsort_indexer(
324311
325312
Parameters
326313
----------
327-
keys : list[ArrayLike] | list[Series]
328-
Sequence of ndarrays to be sorted by the indexer
329-
list[Series] is only if key is not None.
314+
keys : Sequence[ArrayLike | Index | Series]
315+
Sequence of arrays to be sorted by the indexer
316+
Sequence[Series] is only if key is not None.
330317
orders : bool or list of booleans, optional
331318
Determines the sorting order for each element in keys. If a list,
332319
it must be the same length as keys. This determines whether the
@@ -346,83 +333,38 @@ def lexsort_indexer(
346333
"""
347334
from pandas.core.arrays import Categorical
348335

349-
labels = []
350-
shape = []
336+
if na_position not in ["last", "first"]:
337+
raise ValueError(f"invalid na_position: {na_position}")
338+
351339
if isinstance(orders, bool):
352340
orders = [orders] * len(keys)
353341
elif orders is None:
354342
orders = [True] * len(keys)
355343

356-
# error: Incompatible types in assignment (expression has type
357-
# "List[Union[ExtensionArray, ndarray[Any, Any], Index, Series]]", variable
358-
# has type "Union[List[Union[ExtensionArray, ndarray[Any, Any]]], List[Series]]")
359-
keys = [ensure_key_mapped(k, key) for k in keys] # type: ignore[assignment]
344+
labels = []
360345

361346
for k, order in zip(keys, orders):
362-
if na_position not in ["last", "first"]:
363-
raise ValueError(f"invalid na_position: {na_position}")
364-
347+
k = ensure_key_mapped(k, key)
365348
if codes_given:
366-
mask = k == -1
367-
codes = k.copy()
368-
n = len(codes)
369-
mask_n = n
370-
# error: Item "ExtensionArray" of "Union[Any, ExtensionArray,
371-
# ndarray[Any, Any]]" has no attribute "any"
372-
if mask.any(): # type: ignore[union-attr]
373-
n -= 1
374-
349+
codes = cast(np.ndarray, k)
350+
n = codes.max() + 1 if len(codes) else 0
375351
else:
376352
cat = Categorical(k, ordered=True)
353+
codes = cat.codes
377354
n = len(cat.categories)
378-
codes = cat.codes.copy()
379-
mask = cat.codes == -1
380-
if mask.any():
381-
mask_n = n + 1
382-
else:
383-
mask_n = n
384-
385-
if order: # ascending
386-
if na_position == "last":
387-
# error: Argument 1 to "where" has incompatible type "Union[Any,
388-
# ExtensionArray, ndarray[Any, Any]]"; expected
389-
# "Union[_SupportsArray[dtype[Any]],
390-
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
391-
# complex, str, bytes, _NestedSequence[Union[bool, int, float,
392-
# complex, str, bytes]]]"
393-
codes = np.where(mask, n, codes) # type: ignore[arg-type]
394-
elif na_position == "first":
395-
# error: Incompatible types in assignment (expression has type
396-
# "Union[Any, int, ndarray[Any, dtype[signedinteger[Any]]]]",
397-
# variable has type "Union[Series, ExtensionArray, ndarray[Any, Any]]")
398-
# error: Unsupported operand types for + ("ExtensionArray" and "int")
399-
codes += 1 # type: ignore[operator,assignment]
400-
else: # not order means descending
401-
if na_position == "last":
402-
# error: Unsupported operand types for - ("int" and "ExtensionArray")
403-
# error: Argument 1 to "where" has incompatible type "Union[Any,
404-
# ExtensionArray, ndarray[Any, Any]]"; expected
405-
# "Union[_SupportsArray[dtype[Any]],
406-
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
407-
# complex, str, bytes, _NestedSequence[Union[bool, int, float,
408-
# complex, str, bytes]]]"
409-
codes = np.where(
410-
mask, n, n - codes - 1 # type: ignore[operator,arg-type]
411-
)
412-
elif na_position == "first":
413-
# error: Unsupported operand types for - ("int" and "ExtensionArray")
414-
# error: Argument 1 to "where" has incompatible type "Union[Any,
415-
# ExtensionArray, ndarray[Any, Any]]"; expected
416-
# "Union[_SupportsArray[dtype[Any]],
417-
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
418-
# complex, str, bytes, _NestedSequence[Union[bool, int, float,
419-
# complex, str, bytes]]]"
420-
codes = np.where(mask, 0, n - codes) # type: ignore[operator,arg-type]
421-
422-
shape.append(mask_n)
355+
356+
mask = codes == -1
357+
358+
if na_position == "last" and mask.any():
359+
codes = np.where(mask, n, codes)
360+
361+
# not order means descending
362+
if not order:
363+
codes = np.where(mask, codes, n - codes - 1)
364+
423365
labels.append(codes)
424366

425-
return indexer_from_factorized(labels, tuple(shape))
367+
return np.lexsort(labels[::-1])
426368

427369

428370
def nargsort(

0 commit comments

Comments
 (0)