Skip to content

Commit a718128

Browse files
committed
Allow index.map() to accept series and dictionary inputs in addition to functional inputs
1 parent b895968 commit a718128

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

pandas/indexes/base.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from pandas.compat.numpy import function as nv
1515
from pandas import compat
1616

17-
1817
from pandas.types.generic import ABCSeries, ABCMultiIndex, ABCPeriodIndex
1918
from pandas.types.missing import isnull, array_equivalent
2019
from pandas.types.common import (_ensure_int64,
@@ -2438,7 +2437,7 @@ def map(self, mapper):
24382437
24392438
Parameters
24402439
----------
2441-
mapper : callable
2440+
mapper : function, dict, or Series
24422441
Function to be applied.
24432442
24442443
Returns
@@ -2450,7 +2449,15 @@ def map(self, mapper):
24502449
24512450
"""
24522451
from .multi import MultiIndex
2453-
mapped_values = self._arrmap(self.values, mapper)
2452+
2453+
if isinstance(mapper, ABCSeries):
2454+
indexer = mapper.index.get_indexer(self._values)
2455+
mapped_values = algos.take_1d(mapper.values, indexer)
2456+
else:
2457+
if isinstance(mapper, dict):
2458+
mapper = mapper.get
2459+
mapped_values = self._arrmap(self._values, mapper)
2460+
24542461
attributes = self._get_attributes_dict()
24552462
if mapped_values.size and isinstance(mapped_values[0], tuple):
24562463
return MultiIndex.from_tuples(mapped_values,

pandas/tests/indexes/test_base.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,52 @@ def test_map_tseries_indices_return_index(self):
809809
exp = Index(range(24), name='hourly')
810810
tm.assert_index_equal(exp, date_index.map(lambda x: x.hour))
811811

812+
def test_map_with_series_all_indices(self):
813+
expected = Index(['foo', 'bar', 'baz'])
814+
mapper = Series(expected.values, index=[0, 1, 2])
815+
self.assert_index_equal(tm.makeIntIndex(3).map(mapper), expected)
816+
817+
# GH 12766
818+
# special = []
819+
special = ['catIndex']
820+
821+
for name in special:
822+
orig_values = ['a', 'B', 1, 'a']
823+
new_values = ['one', 2, 3.0, 'one']
824+
cur_index = CategoricalIndex(orig_values, name='XXX')
825+
mapper = pd.Series(new_values[:-1], index=orig_values[:-1])
826+
expected = CategoricalIndex(new_values, name='XXX')
827+
output = cur_index.map(mapper)
828+
self.assert_numpy_array_equal(expected.values.get_values(), output.values.get_values())
829+
self.assert_equal(expected.name, output.name)
830+
831+
832+
for name in list(set(self.indices.keys()) - set(special)):
833+
cur_index = self.indices[name]
834+
expected = Index(np.arange(len(cur_index), 0, -1))
835+
mapper = pd.Series(expected.values, index=cur_index)
836+
print(name)
837+
output = cur_index.map(mapper)
838+
self.assert_index_equal(expected, cur_index.map(mapper))
839+
840+
def test_map_with_categorical_series(self):
841+
# GH 12756
842+
a = Index([1, 2, 3, 4])
843+
b = Series(["even", "odd", "even", "odd"], dtype="category")
844+
c = Series(["even", "odd", "even", "odd"])
845+
846+
exp = CategoricalIndex(["odd", "even", "odd", np.nan])
847+
self.assert_index_equal(a.map(b), exp)
848+
exp = Index(["odd", "even", "odd", np.nan])
849+
self.assert_index_equal(a.map(c), exp)
850+
851+
def test_map_with_series_missing_values(self):
852+
# GH 12756
853+
expected = Index([2., np.nan, 'foo'])
854+
mapper = Series(['foo', 2., 'baz'], index=[0, 2, -1])
855+
output = Index([2, 1, 0]).map(mapper)
856+
self.assert_index_equal(output, expected)
857+
812858
def test_append_multiple(self):
813859
index = Index(['a', 'b', 'c', 'd', 'e', 'f'])
814860

pandas/tests/indexes/test_category.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ def f(x):
233233
ordered=False)
234234
tm.assert_index_equal(result, exp)
235235

236+
result = ci.map(pd.Series([10, 20, 30], index=['A', 'B', 'C']))
237+
tm.assert_index_equal(result, exp)
238+
239+
result = ci.map({'A': 10, 'B': 20, 'C': 30})
240+
tm.assert_index_equal(result, exp)
241+
236242
def test_where(self):
237243
i = self.create_index()
238244
result = i.where(notnull(i))

0 commit comments

Comments
 (0)