From db15070769a34d9507045318f0dc5f32931fcb72 Mon Sep 17 00:00:00 2001 From: Luke Manley Date: Mon, 5 Sep 2022 20:58:00 -0400 Subject: [PATCH 1/3] perf improvement in MultiIndex.argsort and MultiIndex.sort_values --- asv_bench/benchmarks/multiindex_object.py | 10 ++++++++++ pandas/core/indexes/multi.py | 4 ++++ pandas/tests/indexes/multi/test_setops.py | 5 ----- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/asv_bench/benchmarks/multiindex_object.py b/asv_bench/benchmarks/multiindex_object.py index ca1d71f50bd5f..56034a8c4029c 100644 --- a/asv_bench/benchmarks/multiindex_object.py +++ b/asv_bench/benchmarks/multiindex_object.py @@ -176,6 +176,16 @@ def time_sortlevel_one(self): self.mi.sortlevel(1) +class SortValues: + def setup(self): + a = np.tile(np.arange(100), 1000) + b = np.tile(np.arange(1000), 100) + self.mi = MultiIndex.from_arrays([a, b]) + + def time_sort_values(self): + self.mi.sort_values() + + class Values: def setup_cache(self): diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 6bf00bc5bc366..0ebd1a3d718d8 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -2225,6 +2225,10 @@ def append(self, other): return Index._with_infer(new_tuples) def argsort(self, *args, **kwargs) -> npt.NDArray[np.intp]: + if len(args) == 0 and len(kwargs) == 0: + # np.lexsort is significantly faster than self._values.argsort() + values = [self._get_level_values(i) for i in reversed(range(self.nlevels))] + return np.lexsort(values) return self._values.argsort(*args, **kwargs) @Appender(_index_shared_docs["repeat"] % _index_doc_kwargs) diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index d29abc1bc6a9f..e15809a751b73 100644 --- a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -530,11 +530,6 @@ def test_union_duplicates(index, request): if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)): # No duplicates in empty indexes return - if index.dtype.kind == "c": - mark = pytest.mark.xfail( - reason="sort_values() call raises bc complex objects are not comparable" - ) - request.node.add_marker(mark) values = index.unique().values.tolist() mi1 = MultiIndex.from_arrays([values, [1] * len(values)]) From 8366ab45a9750300c98ddba0ef0b393519eeabed Mon Sep 17 00:00:00 2001 From: Luke Manley Date: Mon, 5 Sep 2022 21:08:48 -0400 Subject: [PATCH 2/3] whatsnew --- doc/source/whatsnew/v1.6.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v1.6.0.rst b/doc/source/whatsnew/v1.6.0.rst index 848e87f0bc029..eebdaa0d6a322 100644 --- a/doc/source/whatsnew/v1.6.0.rst +++ b/doc/source/whatsnew/v1.6.0.rst @@ -100,6 +100,7 @@ Deprecations Performance improvements ~~~~~~~~~~~~~~~~~~~~~~~~ +- Performance improvement in :meth:`MultiIndex.argsort` and :meth:`MultiIndex.sort_values` (:issue:`48406`) - Performance improvement in :meth:`.GroupBy.mean` and :meth:`.GroupBy.var` for extension array dtypes (:issue:`37493`) - Performance improvement for :meth:`MultiIndex.unique` (:issue:`48335`) - From b2fea756b8919907994d51d025cb97d2b8f7ffae Mon Sep 17 00:00:00 2001 From: Luke Manley Date: Tue, 6 Sep 2022 18:05:49 -0400 Subject: [PATCH 3/3] add asv for Int64 --- asv_bench/benchmarks/multiindex_object.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/asv_bench/benchmarks/multiindex_object.py b/asv_bench/benchmarks/multiindex_object.py index 56034a8c4029c..89d74e784b64c 100644 --- a/asv_bench/benchmarks/multiindex_object.py +++ b/asv_bench/benchmarks/multiindex_object.py @@ -8,6 +8,7 @@ MultiIndex, RangeIndex, Series, + array, date_range, ) @@ -177,12 +178,16 @@ def time_sortlevel_one(self): class SortValues: - def setup(self): - a = np.tile(np.arange(100), 1000) - b = np.tile(np.arange(1000), 100) + + params = ["int64", "Int64"] + param_names = ["dtype"] + + def setup(self, dtype): + a = array(np.tile(np.arange(100), 1000), dtype=dtype) + b = array(np.tile(np.arange(1000), 100), dtype=dtype) self.mi = MultiIndex.from_arrays([a, b]) - def time_sort_values(self): + def time_sort_values(self, dtype): self.mi.sort_values()