Skip to content

Commit 2236346

Browse files
authored
BUG: MultiIndex.putmask losing ea dtype (#49847)
* BUG: MultiIndex.putmask losing ea dtype
* Fix typing
* Add asv
* Simplify and add whatsnew
1 parent b7708f0 commit 2236346

File tree

6 files changed

+101
-3
lines changed

6 files changed

+101
-3
lines changed

asv_bench/benchmarks/multiindex_object.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,4 +379,26 @@ def time_isin_large(self, dtype):
379379
self.midx.isin(self.values_large)
380380

381381

382+
class Putmask:
    """asv benchmarks for ``MultiIndex.putmask`` on a two-level index."""

    def setup(self):
        """Build the target index, two replacement indexes and the mask."""
        N = 10**5
        level1 = range(1_000)

        # 1_000 x (N // 1000) entries -> the product has exactly N rows.
        level2 = date_range(start="1/1/2000", periods=N // 1000)
        self.midx = MultiIndex.from_product([level1, level2])

        # Replacement whose first level differs and second level is shared.
        level1 = range(1_000, 2_000)
        self.midx_values = MultiIndex.from_product([level1, level2])

        # Replacement where both levels differ from ``self.midx``.
        level2 = date_range(start="1/1/2010", periods=N // 1000)
        self.midx_values_different = MultiIndex.from_product([level1, level2])
        # Alternating mask of length N: every other row is replaced.
        self.mask = np.array([True, False] * (N // 2))

    def time_putmask(self):
        """Time putmask when only the first level's values differ."""
        self.midx.putmask(self.mask, self.midx_values)

    def time_putmask_all_different(self):
        """Time putmask when the values of both levels differ."""
        self.midx.putmask(self.mask, self.midx_values_different)
402+
403+
382404
from .pandas_vb_common import setup # noqa: F401 isort:skip

doc/source/whatsnew/v2.0.0.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ Performance improvements
588588
- Performance improvement in :class:`MultiIndex` set operations with sort=None (:issue:`49010`)
589589
- Performance improvement in :meth:`.DataFrameGroupBy.mean`, :meth:`.SeriesGroupBy.mean`, :meth:`.DataFrameGroupBy.var`, and :meth:`.SeriesGroupBy.var` for extension array dtypes (:issue:`37493`)
590590
- Performance improvement in :meth:`MultiIndex.isin` when ``level=None`` (:issue:`48622`, :issue:`49577`)
591+
- Performance improvement in :meth:`MultiIndex.putmask` (:issue:`49830`)
591592
- Performance improvement in :meth:`Index.union` and :meth:`MultiIndex.union` when index contains duplicates (:issue:`48900`)
592593
- Performance improvement in :meth:`Series.fillna` for pyarrow-backed dtypes (:issue:`49722`)
593594
- Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`)
@@ -703,6 +704,7 @@ MultiIndex
703704
- Bug in :meth:`MultiIndex.union` not sorting when sort=None and index contains missing values (:issue:`49010`)
704705
- Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`)
705706
- Bug in :meth:`MultiIndex.symmetric_difference` losing extension array (:issue:`48607`)
707+
- Bug in :meth:`MultiIndex.putmask` losing extension array (:issue:`49830`)
706708
- Bug in :meth:`MultiIndex.value_counts` returning a :class:`Series` indexed by flat index of tuples instead of a :class:`MultiIndex` (:issue:`49558`)
707709
-
708710

pandas/core/array_algos/putmask.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
"""
44
from __future__ import annotations
55

6-
from typing import Any
6+
from typing import (
7+
TYPE_CHECKING,
8+
Any,
9+
)
710

811
import numpy as np
912

@@ -19,6 +22,9 @@
1922

2023
from pandas.core.arrays import ExtensionArray
2124

25+
if TYPE_CHECKING:
26+
from pandas import MultiIndex
27+
2228

2329
def putmask_inplace(values: ArrayLike, mask: npt.NDArray[np.bool_], value: Any) -> None:
2430
"""
@@ -96,7 +102,7 @@ def putmask_without_repeat(
96102

97103

98104
def validate_putmask(
99-
values: ArrayLike, mask: np.ndarray
105+
values: ArrayLike | MultiIndex, mask: np.ndarray
100106
) -> tuple[npt.NDArray[np.bool_], bool]:
101107
"""
102108
Validate mask and check if this putmask operation is a no-op.

pandas/core/indexes/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5144,7 +5144,6 @@ def _concat(self, to_concat: list[Index], name: Hashable) -> Index:
51445144

51455145
return Index._with_infer(result, name=name)
51465146

5147-
@final
51485147
def putmask(self, mask, value) -> Index:
51495148
"""
51505149
Return a new Index of the values set with the mask.

pandas/core/indexes/multi.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
)
8080

8181
import pandas.core.algorithms as algos
82+
from pandas.core.array_algos.putmask import validate_putmask
8283
from pandas.core.arrays import Categorical
8384
from pandas.core.arrays.categorical import factorize_from_iterables
8485
import pandas.core.common as com
@@ -3660,6 +3661,46 @@ def _validate_fill_value(self, item):
36603661
raise ValueError("Item must have length equal to number of levels.")
36613662
return item
36623663

3664+
def putmask(self, mask, value: MultiIndex) -> MultiIndex:
    """
    Return a new MultiIndex of the values set with the mask.

    The replacement is done level by level on the levels/codes, so the
    dtype of each level (including extension array dtypes) is retained.

    Parameters
    ----------
    mask : array like
        Boolean mask; positions that are True are overwritten.
    value : MultiIndex
        Must either be the same length as self or length one

    Returns
    -------
    MultiIndex
        A new object; ``self`` is not modified.
    """
    mask, noop = validate_putmask(self, mask)
    if noop:
        return self.copy()

    if len(mask) == len(value):
        # Only the masked rows of a full-length ``value`` are ever used.
        subset = value[mask].remove_unused_levels()
    else:
        subset = value.remove_unused_levels()

    new_levels = []
    new_codes = []

    for i, (value_level, level, level_codes) in enumerate(
        zip(subset.levels, self.levels, self.codes)
    ):
        # The unmasked codes are reused unchanged below, which relies on the
        # unsorted union keeping ``level``'s entries at their positions.
        new_level = level.union(value_level, sort=False)
        value_codes = new_level.get_indexer_for(subset.get_level_values(i))
        # ``astype`` always returns a fresh array, so the masked assignment
        # can never write into (or fail on) the codes owned by ``self``,
        # even when they are already int64.
        new_code = level_codes.astype("int64")
        new_code[mask] = value_codes
        new_levels.append(new_level)
        new_codes.append(new_code)

    return MultiIndex(
        levels=new_levels, codes=new_codes, names=self.names, verify_integrity=False
    )
3703+
36633704
def insert(self, loc: int, item) -> MultiIndex:
36643705
"""
36653706
Make new MultiIndex inserting new item at location

pandas/tests/indexes/multi/test_indexing.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,34 @@ def test_putmask_multiindex_other(self):
162162
expected = MultiIndex.from_tuples([right[0], right[1], left[2]])
163163
tm.assert_index_equal(result, expected)
164164

165+
def test_putmask_keep_dtype(self, any_numeric_ea_dtype):
    """putmask with a same-length MultiIndex keeps the extension dtype."""
    # GH#49830
    midx = MultiIndex.from_arrays(
        [pd.Series([1, 2, 3], dtype=any_numeric_ea_dtype), [10, 11, 12]]
    )
    midx2 = MultiIndex.from_arrays(
        [pd.Series([5, 6, 7], dtype=any_numeric_ea_dtype), [-1, -2, -3]]
    )
    # Only the first row is masked, so only the first row of midx2 is used.
    result = midx.putmask([True, False, False], midx2)
    expected = MultiIndex.from_arrays(
        [pd.Series([5, 2, 3], dtype=any_numeric_ea_dtype), [-1, 11, 12]]
    )
    tm.assert_index_equal(result, expected)
178+
179+
def test_putmask_keep_dtype_shorter_value(self, any_numeric_ea_dtype):
    """putmask with a length-one MultiIndex keeps the extension dtype."""
    # GH#49830
    midx = MultiIndex.from_arrays(
        [pd.Series([1, 2, 3], dtype=any_numeric_ea_dtype), [10, 11, 12]]
    )
    midx2 = MultiIndex.from_arrays(
        [pd.Series([5], dtype=any_numeric_ea_dtype), [-1]]
    )
    # The single row of midx2 fills the one masked position.
    result = midx.putmask([True, False, False], midx2)
    expected = MultiIndex.from_arrays(
        [pd.Series([5, 2, 3], dtype=any_numeric_ea_dtype), [-1, 11, 12]]
    )
    tm.assert_index_equal(result, expected)
192+
165193

166194
class TestGetIndexer:
167195
def test_get_indexer(self):

0 commit comments

Comments
 (0)