|
4 | 4 | import pandas as pd
|
5 | 5 | from pandas import DataFrame, Index, MultiIndex, Series, concat, merge
|
6 | 6 | import pandas._testing as tm
|
| 7 | +from pandas.errors import MergeError |
7 | 8 | from pandas.tests.reshape.merge.test_merge import NGROUPS, N, get_test_data
|
8 | 9 |
|
9 | 10 | a_ = np.array
|
@@ -803,3 +804,24 @@ def test_join_inner_multiindex_deterministic_order():
|
803 | 804 | index=MultiIndex.from_tuples([(2, 1, 4, 3)], names=("b", "a", "d", "c")),
|
804 | 805 | )
|
805 | 806 | tm.assert_frame_equal(result, expected)
|
| 807 | + |
| 808 | + |
| 809 | +@pytest.mark.parametrize( |
| 810 | + ("input_col", "output_cols"), [("b", ["a", "b"]), ("a", ["a_x", "a_y"])] |
| 811 | +) |
| 812 | +def test_join_cross(input_col, output_cols): |
| 813 | + # GH#5401 |
| 814 | + left = DataFrame({"a": [1, 3]}) |
| 815 | + right = DataFrame({input_col: [3, 4]}) |
| 816 | + result = left.join(right, how="cross", lsuffix="_x", rsuffix="_y") |
| 817 | + expected = DataFrame({output_cols[0]: [1, 1, 3, 3], output_cols[1]: [3, 4, 3, 4]}) |
| 818 | + tm.assert_frame_equal(result, expected) |
| 819 | + |
| 820 | + |
| 821 | +def test_join_cross_error_reporting(): |
| 822 | + # GH#5401 |
| 823 | + left = DataFrame({"a": [1, 3]}) |
| 824 | + right = DataFrame({"a": [3, 4]}) |
| 825 | + msg = "Can not pass any merge columns when using cross as merge method" |
| 826 | + with pytest.raises(MergeError, match=msg): |
| 827 | + left.join(right, how="cross", on="a") |
0 commit comments