@@ -459,13 +459,18 @@ def test_sort_values_invalid_na_position(
459
459
460
460
@pytest .mark .filterwarnings (r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning" )
461
461
@pytest .mark .parametrize ("na_position" , ["first" , "last" ])
462
+ @pytest .mark .parametrize ("box_in_series" , [False , True ])
462
463
@pytest .mark .xfail (
463
- reason = "Sorting fails due to heterogeneous types in index (int vs str)"
464
+ reason = "Sorting fails due to heterogeneous types in index (int vs str)" ,
465
+ strict = False ,
464
466
)
465
- def test_sort_values_with_missing (index_with_missing , na_position , request ):
467
+ def test_sort_values_with_missing (index_with_missing , na_position , request , box_in_series ):
466
468
# GH 35584. Test that sort_values works with missing values,
467
469
# sort non-missing and place missing according to na_position
468
470
471
+ if box_in_series :
472
+ index_with_missing = pd .Series (index_with_missing )
473
+
469
474
non_na_values = [x for x in index_with_missing if pd .notna (x )]
470
475
if len ({type (x ) for x in non_na_values }) > 1 :
471
476
index_with_missing = index_with_missing .map (str )
@@ -478,18 +483,30 @@ def test_sort_values_with_missing(index_with_missing, na_position, request):
478
483
)
479
484
480
485
missing_count = np .sum (index_with_missing .isna ())
481
- not_na_vals = index_with_missing [index_with_missing .notna ()].values
486
+
487
+ if isinstance (index_with_missing , pd .Series ):
488
+ not_na_vals = index_with_missing [index_with_missing .notna ()].values
489
+ else :
490
+ not_na_vals = index_with_missing [index_with_missing .notna ()].values
491
+
492
+
482
493
sorted_values = np .sort (not_na_vals )
483
494
if na_position == "first" :
484
495
sorted_values = np .concatenate ([[None ] * missing_count , sorted_values ])
485
496
else :
486
497
sorted_values = np .concatenate ([sorted_values , [None ] * missing_count ])
487
498
488
499
# Explicitly pass dtype needed for Index backed by EA e.g. IntegerArray
489
- expected = type (index_with_missing )(sorted_values , dtype = index_with_missing .dtype )
500
+ if isinstance (index_with_missing , pd .Series ):
501
+ expected = pd .Series (sorted_values , dtype = index_with_missing .dtype )
502
+ else :
503
+ expected = type (index_with_missing )(sorted_values , dtype = index_with_missing .dtype )
490
504
491
505
result = index_with_missing .sort_values (na_position = na_position )
492
- tm .assert_index_equal (result , expected )
506
+ if isinstance (index_with_missing , pd .Series ):
507
+ tm .assert_series_equal (result , expected )
508
+ else :
509
+ tm .assert_index_equal (result , expected )
493
510
494
511
495
512
def test_sort_values_natsort_key ():
0 commit comments