Skip to content

Commit 6b0011e

Browse files
committed
add multi index with categories test
1 parent 13f758c commit 6b0011e

File tree

1 file changed

+45
-45
lines changed

1 file changed

+45
-45
lines changed

pandas/tests/indexes/multi/test_setops.py

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414

1515

1616
@pytest.mark.parametrize("case", [0.5, "xxx"])
17-
@pytest.mark.parametrize(
18-
"method", ["intersection", "union", "difference", "symmetric_difference"]
19-
)
17+
@pytest.mark.parametrize("method", ["intersection", "union", "difference", "symmetric_difference"])
2018
def test_set_ops_error_cases(idx, case, sort, method):
2119
# non-iterable input
2220
msg = "Input must be Index or array-like"
@@ -299,9 +297,7 @@ def test_intersection(idx, sort):
299297
assert result.equals(idx)
300298

301299

302-
@pytest.mark.parametrize(
303-
"method", ["intersection", "union", "difference", "symmetric_difference"]
304-
)
300+
@pytest.mark.parametrize("method", ["intersection", "union", "difference", "symmetric_difference"])
305301
def test_setop_with_categorical(idx, sort, method):
306302
other = idx.to_flat_index().astype("category")
307303
res_names = [None] * idx.nlevels
@@ -428,9 +424,7 @@ def test_union_multiindex_empty_rangeindex():
428424
tm.assert_index_equal(mi, result_right, check_names=False)
429425

430426

431-
@pytest.mark.parametrize(
432-
"method", ["union", "intersection", "difference", "symmetric_difference"]
433-
)
427+
@pytest.mark.parametrize("method", ["union", "intersection", "difference", "symmetric_difference"])
434428
def test_setops_disallow_true(method):
435429
idx1 = MultiIndex.from_product([["a", "b"], [1, 2]])
436430
idx2 = MultiIndex.from_product([["b", "c"], [1, 2]])
@@ -442,12 +436,8 @@ def test_setops_disallow_true(method):
442436
@pytest.mark.parametrize("val", [pd.NA, 100])
443437
def test_difference_keep_ea_dtypes(any_numeric_ea_dtype, val):
444438
# GH#48606
445-
midx = MultiIndex.from_arrays(
446-
[Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None]
447-
)
448-
midx2 = MultiIndex.from_arrays(
449-
[Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]]
450-
)
439+
midx = MultiIndex.from_arrays([Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None])
440+
midx2 = MultiIndex.from_arrays([Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]])
451441
result = midx.difference(midx2)
452442
expected = MultiIndex.from_arrays([Series([1], dtype=any_numeric_ea_dtype), [2]])
453443
tm.assert_index_equal(result, expected)
@@ -463,16 +453,10 @@ def test_difference_keep_ea_dtypes(any_numeric_ea_dtype, val):
463453
@pytest.mark.parametrize("val", [pd.NA, 5])
464454
def test_symmetric_difference_keeping_ea_dtype(any_numeric_ea_dtype, val):
465455
# GH#48607
466-
midx = MultiIndex.from_arrays(
467-
[Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None]
468-
)
469-
midx2 = MultiIndex.from_arrays(
470-
[Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]]
471-
)
456+
midx = MultiIndex.from_arrays([Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None])
457+
midx2 = MultiIndex.from_arrays([Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]])
472458
result = midx.symmetric_difference(midx2)
473-
expected = MultiIndex.from_arrays(
474-
[Series([1, 1, val], dtype=any_numeric_ea_dtype), [1, 2, 3]]
475-
)
459+
expected = MultiIndex.from_arrays([Series([1, 1, val], dtype=any_numeric_ea_dtype), [1, 2, 3]])
476460
tm.assert_index_equal(result, expected)
477461

478462

@@ -566,9 +550,7 @@ def test_union_nan_got_duplicated(dtype, sort):
566550
mi2 = MultiIndex.from_arrays([pd.array([1.0, np.nan, 3.0], dtype=dtype), [2, 3, 4]])
567551
result = mi1.union(mi2, sort=sort)
568552
if sort is None:
569-
expected = MultiIndex.from_arrays(
570-
[pd.array([1.0, 3.0, np.nan], dtype=dtype), [2, 4, 3]]
571-
)
553+
expected = MultiIndex.from_arrays([pd.array([1.0, 3.0, np.nan], dtype=dtype), [2, 4, 3]])
572554
else:
573555
expected = mi2
574556
tm.assert_index_equal(result, expected)
@@ -584,13 +566,9 @@ def test_union_keep_ea_dtype(any_numeric_ea_dtype, val):
584566
midx2 = MultiIndex.from_arrays([arr2, [2, 1]])
585567
result = midx.union(midx2)
586568
if val == 4:
587-
expected = MultiIndex.from_arrays(
588-
[Series([1, 2, 4], dtype=any_numeric_ea_dtype), [1, 2, 1]]
589-
)
569+
expected = MultiIndex.from_arrays([Series([1, 2, 4], dtype=any_numeric_ea_dtype), [1, 2, 1]])
590570
else:
591-
expected = MultiIndex.from_arrays(
592-
[Series([1, 2], dtype=any_numeric_ea_dtype), [1, 2]]
593-
)
571+
expected = MultiIndex.from_arrays([Series([1, 2], dtype=any_numeric_ea_dtype), [1, 2]])
594572
tm.assert_index_equal(result, expected)
595573

596574

@@ -637,9 +615,7 @@ def test_union_duplicates(index, request):
637615
# and loses type information. Result is then unsigned only when values are
638616
# sufficiently large to require unsigned dtype. This happens only if other
639617
# has dups or one of both have missing values
640-
expected = expected.set_levels(
641-
[expected.levels[0].astype(int), expected.levels[1]]
642-
)
618+
expected = expected.set_levels([expected.levels[0].astype(int), expected.levels[1]])
643619
result = mi1.union(mi2)
644620
tm.assert_index_equal(result, expected)
645621

@@ -666,9 +642,7 @@ def test_union_keep_ea_dtype_with_na(any_numeric_ea_dtype):
666642
midx = MultiIndex.from_arrays([arr1, [2, 1]], names=["a", None])
667643
midx2 = MultiIndex.from_arrays([arr2, [1, 2]])
668644
result = midx.union(midx2)
669-
expected = MultiIndex.from_arrays(
670-
[Series([1, 4, pd.NA, pd.NA], dtype=any_numeric_ea_dtype), [1, 2, 1, 2]]
671-
)
645+
expected = MultiIndex.from_arrays([Series([1, 4, pd.NA, pd.NA], dtype=any_numeric_ea_dtype), [1, 2, 1, 2]])
672646
tm.assert_index_equal(result, expected)
673647

674648

@@ -692,15 +666,41 @@ def test_intersection_lexsort_depth(levels1, levels2, codes1, codes2, names):
692666
assert mi_int._lexsort_depth == 2
693667

694668

669+
@pytest.mark.parametrize(
670+
"a",
671+
[pd.Categorical(["a", "b"], categories=["a", "b"]), ["a", "b"]],
672+
)
673+
@pytest.mark.parametrize(
674+
"b",
675+
[pd.Categorical(["a", "b"], categories=["b", "a"]), pd.Categorical(["a", "b"], categories=["b", "a"])],
676+
)
677+
def test_intersection_with_non_lex_sorted_categories(a, b):
678+
# GH#49974
679+
other = ["1", "2"]
680+
681+
df1 = pd.DataFrame({"x": a, "y": other})
682+
df2 = pd.DataFrame({"x": b, "y": other})
683+
684+
expected = pd.MultiIndex.from_arrays([a, other], names=["x", "y"])
685+
686+
res1 = pd.MultiIndex.from_frame(df1).intersection(pd.MultiIndex.from_frame(df2.sort_values(["x", "y"])))
687+
res2 = pd.MultiIndex.from_frame(df1).intersection(pd.MultiIndex.from_frame(df2))
688+
res3 = pd.MultiIndex.from_frame(df1.sort_values(["x", "y"])).intersection(pd.MultiIndex.from_frame(df2))
689+
res4 = pd.MultiIndex.from_frame(df1.sort_values(["x", "y"])).intersection(
690+
pd.MultiIndex.from_frame(df2.sort_values(["x", "y"]))
691+
)
692+
693+
tm.assert_index_equal(res1, expected)
694+
tm.assert_index_equal(res2, expected)
695+
tm.assert_index_equal(res3, expected)
696+
tm.assert_index_equal(res4, expected)
697+
698+
695699
@pytest.mark.parametrize("val", [pd.NA, 100])
696700
def test_intersection_keep_ea_dtypes(val, any_numeric_ea_dtype):
697701
# GH#48604
698-
midx = MultiIndex.from_arrays(
699-
[Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None]
700-
)
701-
midx2 = MultiIndex.from_arrays(
702-
[Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]]
703-
)
702+
midx = MultiIndex.from_arrays([Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None])
703+
midx2 = MultiIndex.from_arrays([Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]])
704704
result = midx.intersection(midx2)
705705
expected = MultiIndex.from_arrays([Series([2], dtype=any_numeric_ea_dtype), [1]])
706706
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)