@@ -319,7 +319,7 @@ def nargsort(
319
319
return indexer
320
320
321
321
322
- def apply_key (index , key , level = None ):
322
+ def ensure_key_mapped_multiindex (index , key , level = None ):
323
323
"""
324
324
Returns a new MultiIndex in which key has been applied
325
325
to all levels specified in level (or all levels if level
@@ -386,22 +386,20 @@ def ensure_key_mapped(values, key: Optional[Callable], levels=None):
386
386
return values .copy ()
387
387
388
388
if isinstance (values , ABCMultiIndex ):
389
- return apply_key (values , key , level = levels )
390
- else :
391
- _class = type (values )
392
- result = key (values .copy ())
393
- if len (result ) != len (values ):
394
- raise ValueError (
395
- "User-provided `key` function much not change the shape of the array."
396
- )
397
-
398
- if not isinstance (result , _class ): # recover from type error
399
- try :
400
- result = _class (result )
401
- except TypeError :
402
- raise TypeError (
403
- "User-provided `key` function returned an invalid type."
404
- )
389
+ return ensure_key_mapped_multiindex (values , key , level = levels )
390
+
391
+ type_of_values = type (values )
392
+ result = key (values .copy ())
393
+ if len (result ) != len (values ):
394
+ raise ValueError (
395
+ "User-provided `key` function much not change the shape of the array."
396
+ )
397
+
398
+ if not isinstance (result , type_of_values ): # recover from type error
399
+ try :
400
+ result = type_of_values (result )
401
+ except TypeError :
402
+ raise TypeError ("User-provided `key` function returned an invalid type." )
405
403
406
404
return result
407
405
0 commit comments