14
14
15
15
16
16
@pytest .mark .parametrize ("case" , [0.5 , "xxx" ])
17
- @pytest .mark .parametrize (
18
- "method" , ["intersection" , "union" , "difference" , "symmetric_difference" ]
19
- )
17
+ @pytest .mark .parametrize ("method" , ["intersection" , "union" , "difference" , "symmetric_difference" ])
20
18
def test_set_ops_error_cases (idx , case , sort , method ):
21
19
# non-iterable input
22
20
msg = "Input must be Index or array-like"
@@ -299,9 +297,7 @@ def test_intersection(idx, sort):
299
297
assert result .equals (idx )
300
298
301
299
302
- @pytest .mark .parametrize (
303
- "method" , ["intersection" , "union" , "difference" , "symmetric_difference" ]
304
- )
300
+ @pytest .mark .parametrize ("method" , ["intersection" , "union" , "difference" , "symmetric_difference" ])
305
301
def test_setop_with_categorical (idx , sort , method ):
306
302
other = idx .to_flat_index ().astype ("category" )
307
303
res_names = [None ] * idx .nlevels
@@ -428,9 +424,7 @@ def test_union_multiindex_empty_rangeindex():
428
424
tm .assert_index_equal (mi , result_right , check_names = False )
429
425
430
426
431
- @pytest .mark .parametrize (
432
- "method" , ["union" , "intersection" , "difference" , "symmetric_difference" ]
433
- )
427
+ @pytest .mark .parametrize ("method" , ["union" , "intersection" , "difference" , "symmetric_difference" ])
434
428
def test_setops_disallow_true (method ):
435
429
idx1 = MultiIndex .from_product ([["a" , "b" ], [1 , 2 ]])
436
430
idx2 = MultiIndex .from_product ([["b" , "c" ], [1 , 2 ]])
@@ -442,12 +436,8 @@ def test_setops_disallow_true(method):
442
436
@pytest .mark .parametrize ("val" , [pd .NA , 100 ])
443
437
def test_difference_keep_ea_dtypes (any_numeric_ea_dtype , val ):
444
438
# GH#48606
445
- midx = MultiIndex .from_arrays (
446
- [Series ([1 , 2 ], dtype = any_numeric_ea_dtype ), [2 , 1 ]], names = ["a" , None ]
447
- )
448
- midx2 = MultiIndex .from_arrays (
449
- [Series ([1 , 2 , val ], dtype = any_numeric_ea_dtype ), [1 , 1 , 3 ]]
450
- )
439
+ midx = MultiIndex .from_arrays ([Series ([1 , 2 ], dtype = any_numeric_ea_dtype ), [2 , 1 ]], names = ["a" , None ])
440
+ midx2 = MultiIndex .from_arrays ([Series ([1 , 2 , val ], dtype = any_numeric_ea_dtype ), [1 , 1 , 3 ]])
451
441
result = midx .difference (midx2 )
452
442
expected = MultiIndex .from_arrays ([Series ([1 ], dtype = any_numeric_ea_dtype ), [2 ]])
453
443
tm .assert_index_equal (result , expected )
@@ -463,16 +453,10 @@ def test_difference_keep_ea_dtypes(any_numeric_ea_dtype, val):
463
453
@pytest .mark .parametrize ("val" , [pd .NA , 5 ])
464
454
def test_symmetric_difference_keeping_ea_dtype (any_numeric_ea_dtype , val ):
465
455
# GH#48607
466
- midx = MultiIndex .from_arrays (
467
- [Series ([1 , 2 ], dtype = any_numeric_ea_dtype ), [2 , 1 ]], names = ["a" , None ]
468
- )
469
- midx2 = MultiIndex .from_arrays (
470
- [Series ([1 , 2 , val ], dtype = any_numeric_ea_dtype ), [1 , 1 , 3 ]]
471
- )
456
+ midx = MultiIndex .from_arrays ([Series ([1 , 2 ], dtype = any_numeric_ea_dtype ), [2 , 1 ]], names = ["a" , None ])
457
+ midx2 = MultiIndex .from_arrays ([Series ([1 , 2 , val ], dtype = any_numeric_ea_dtype ), [1 , 1 , 3 ]])
472
458
result = midx .symmetric_difference (midx2 )
473
- expected = MultiIndex .from_arrays (
474
- [Series ([1 , 1 , val ], dtype = any_numeric_ea_dtype ), [1 , 2 , 3 ]]
475
- )
459
+ expected = MultiIndex .from_arrays ([Series ([1 , 1 , val ], dtype = any_numeric_ea_dtype ), [1 , 2 , 3 ]])
476
460
tm .assert_index_equal (result , expected )
477
461
478
462
@@ -566,9 +550,7 @@ def test_union_nan_got_duplicated(dtype, sort):
566
550
mi2 = MultiIndex .from_arrays ([pd .array ([1.0 , np .nan , 3.0 ], dtype = dtype ), [2 , 3 , 4 ]])
567
551
result = mi1 .union (mi2 , sort = sort )
568
552
if sort is None :
569
- expected = MultiIndex .from_arrays (
570
- [pd .array ([1.0 , 3.0 , np .nan ], dtype = dtype ), [2 , 4 , 3 ]]
571
- )
553
+ expected = MultiIndex .from_arrays ([pd .array ([1.0 , 3.0 , np .nan ], dtype = dtype ), [2 , 4 , 3 ]])
572
554
else :
573
555
expected = mi2
574
556
tm .assert_index_equal (result , expected )
@@ -584,13 +566,9 @@ def test_union_keep_ea_dtype(any_numeric_ea_dtype, val):
584
566
midx2 = MultiIndex .from_arrays ([arr2 , [2 , 1 ]])
585
567
result = midx .union (midx2 )
586
568
if val == 4 :
587
- expected = MultiIndex .from_arrays (
588
- [Series ([1 , 2 , 4 ], dtype = any_numeric_ea_dtype ), [1 , 2 , 1 ]]
589
- )
569
+ expected = MultiIndex .from_arrays ([Series ([1 , 2 , 4 ], dtype = any_numeric_ea_dtype ), [1 , 2 , 1 ]])
590
570
else :
591
- expected = MultiIndex .from_arrays (
592
- [Series ([1 , 2 ], dtype = any_numeric_ea_dtype ), [1 , 2 ]]
593
- )
571
+ expected = MultiIndex .from_arrays ([Series ([1 , 2 ], dtype = any_numeric_ea_dtype ), [1 , 2 ]])
594
572
tm .assert_index_equal (result , expected )
595
573
596
574
@@ -637,9 +615,7 @@ def test_union_duplicates(index, request):
637
615
# and loses type information. Result is then unsigned only when values are
638
616
# sufficiently large to require unsigned dtype. This happens only if other
639
617
# has dups or one of both have missing values
640
- expected = expected .set_levels (
641
- [expected .levels [0 ].astype (int ), expected .levels [1 ]]
642
- )
618
+ expected = expected .set_levels ([expected .levels [0 ].astype (int ), expected .levels [1 ]])
643
619
result = mi1 .union (mi2 )
644
620
tm .assert_index_equal (result , expected )
645
621
@@ -666,9 +642,7 @@ def test_union_keep_ea_dtype_with_na(any_numeric_ea_dtype):
666
642
midx = MultiIndex .from_arrays ([arr1 , [2 , 1 ]], names = ["a" , None ])
667
643
midx2 = MultiIndex .from_arrays ([arr2 , [1 , 2 ]])
668
644
result = midx .union (midx2 )
669
- expected = MultiIndex .from_arrays (
670
- [Series ([1 , 4 , pd .NA , pd .NA ], dtype = any_numeric_ea_dtype ), [1 , 2 , 1 , 2 ]]
671
- )
645
+ expected = MultiIndex .from_arrays ([Series ([1 , 4 , pd .NA , pd .NA ], dtype = any_numeric_ea_dtype ), [1 , 2 , 1 , 2 ]])
672
646
tm .assert_index_equal (result , expected )
673
647
674
648
@@ -692,15 +666,41 @@ def test_intersection_lexsort_depth(levels1, levels2, codes1, codes2, names):
692
666
assert mi_int ._lexsort_depth == 2
693
667
694
668
669
@pytest.mark.parametrize(
    "a",
    [pd.Categorical(["a", "b"], categories=["a", "b"]), ["a", "b"]],
)
@pytest.mark.parametrize(
    "b",
    [
        # The two cases must differ: one ordered, one unordered categorical,
        # both with categories listed in non-lexical order ("b" before "a").
        pd.Categorical(["a", "b"], categories=["b", "a"], ordered=True),
        pd.Categorical(["a", "b"], categories=["b", "a"]),
    ],
)
def test_intersection_with_non_lex_sorted_categories(a, b):
    """Intersect MultiIndexes whose first level has non-lexically ordered categories.

    ``a`` and ``b`` hold the same values ("a", "b") for the first level; ``b``
    is always a Categorical whose categories are declared as ``["b", "a"]``.
    Every combination of sorted / unsorted operands must give back the
    left-hand values paired with ``other``, with level names ``x`` and ``y``.
    """
    # GH#49974
    other = ["1", "2"]

    df1 = pd.DataFrame({"x": a, "y": other})
    df2 = pd.DataFrame({"x": b, "y": other})

    expected = pd.MultiIndex.from_arrays([a, other], names=["x", "y"])

    # Build each operand once, in both frame order and sorted-by-value order.
    # Sorting df2 by "x" follows the category order, i.e. "b" rows first.
    left = pd.MultiIndex.from_frame(df1)
    left_sorted = pd.MultiIndex.from_frame(df1.sort_values(["x", "y"]))
    right = pd.MultiIndex.from_frame(df2)
    right_sorted = pd.MultiIndex.from_frame(df2.sort_values(["x", "y"]))

    res1 = left.intersection(right_sorted)
    res2 = left.intersection(right)
    res3 = left_sorted.intersection(right)
    res4 = left_sorted.intersection(right_sorted)

    tm.assert_index_equal(res1, expected)
    tm.assert_index_equal(res2, expected)
    tm.assert_index_equal(res3, expected)
    tm.assert_index_equal(res4, expected)
697
+
698
+
695
699
@pytest .mark .parametrize ("val" , [pd .NA , 100 ])
696
700
def test_intersection_keep_ea_dtypes (val , any_numeric_ea_dtype ):
697
701
# GH#48604
698
- midx = MultiIndex .from_arrays (
699
- [Series ([1 , 2 ], dtype = any_numeric_ea_dtype ), [2 , 1 ]], names = ["a" , None ]
700
- )
701
- midx2 = MultiIndex .from_arrays (
702
- [Series ([1 , 2 , val ], dtype = any_numeric_ea_dtype ), [1 , 1 , 3 ]]
703
- )
702
+ midx = MultiIndex .from_arrays ([Series ([1 , 2 ], dtype = any_numeric_ea_dtype ), [2 , 1 ]], names = ["a" , None ])
703
+ midx2 = MultiIndex .from_arrays ([Series ([1 , 2 , val ], dtype = any_numeric_ea_dtype ), [1 , 1 , 3 ]])
704
704
result = midx .intersection (midx2 )
705
705
expected = MultiIndex .from_arrays ([Series ([2 ], dtype = any_numeric_ea_dtype ), [1 ]])
706
706
tm .assert_index_equal (result , expected )
0 commit comments