Skip to content

Commit b0f0440

Browse files
author
harisbal
committed
ENH: Multi-level merge on multi-indexes
Allow for merging on multiple levels of multi-indexes
1 parent d50b162 commit b0f0440

File tree

2 files changed

+243
-27
lines changed

2 files changed

+243
-27
lines changed

pandas/core/indexes/base.py

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3028,27 +3028,84 @@ def join(self, other, how='left', level=None, return_indexers=False,
30283028

30293029
def _join_multi(self, other, how, return_indexers=True):
30303030
from .multi import MultiIndex
3031-
self_is_mi = isinstance(self, MultiIndex)
3032-
other_is_mi = isinstance(other, MultiIndex)
30333031

3032+
def _complete_join():
3033+
new_lvls = join_index.levels
3034+
new_lbls = join_index.labels
3035+
new_nms = join_index.names
3036+
3037+
for n in not_overlap:
3038+
if n in self_names:
3039+
idx = lidx
3040+
lvls = self.levels[self_names.index(n)].values
3041+
lbls = self.labels[self_names.index(n)]
3042+
else:
3043+
idx = ridx
3044+
lvls = other.levels[other_names.index(n)].values
3045+
lbls = other.labels[other_names.index(n)]
3046+
3047+
new_lvls = new_lvls.__add__([lvls])
3048+
new_nms = new_nms.__add__([n])
3049+
3050+
# Return the label on match else -1
3051+
l = [lbls[i] if i!=-1 else -1 for i in idx]
3052+
new_lbls = new_lbls.__add__([l])
3053+
3054+
return new_lvls, new_lbls, new_nms
3055+
30343056
# figure out join names
30353057
self_names = [n for n in self.names if n is not None]
30363058
other_names = [n for n in other.names if n is not None]
30373059
overlap = list(set(self_names) & set(other_names))
30383060

3061+
# Names of the non-overlapping levels (dropped before joining)
3062+
ldrop_levels = [l for l in self_names if l not in overlap]
3063+
rdrop_levels = [l for l in other_names if l not in overlap]
3064+
3065+
self_is_mi = isinstance(self, MultiIndex)
3066+
other_is_mi = isinstance(other, MultiIndex)
3067+
30393068
# need at least 1 level name in common
30403069
if not len(overlap):
3041-
raise ValueError("cannot join with no level specified and no "
3042-
"overlapping names")
3043-
if len(overlap) > 1:
3044-
raise NotImplementedError("merging with more than one level "
3045-
"overlap on a multi-index is not "
3046-
"implemented")
3047-
jl = overlap[0]
3070+
raise ValueError("cannot join with no overlapping index names")
3071+
3072+
if self_is_mi and other_is_mi:
3073+
self_tmp = self.droplevel(ldrop_levels)
3074+
other_tmp = other.droplevel(rdrop_levels)
3075+
3076+
if not (other_tmp.is_unique and self_tmp.is_unique):
3077+
raise TypeError(" The index resulting from the overlapping "
3078+
"levels is not unique")
3079+
3080+
join_index, lidx, ridx = self_tmp.join(other_tmp, how=how,
3081+
return_indexers=True)
3082+
3083+
# Append to the returned Index the non-overlapping levels
3084+
not_overlap = ldrop_levels + rdrop_levels
3085+
3086+
if how == 'left':
3087+
join_index = self
3088+
elif how == 'right':
3089+
join_index = other
3090+
else:
3091+
join_index = join_index
3092+
3093+
if how == 'outer':
3094+
new_levels, new_labels, new_names = _complete_join()
3095+
else:
3096+
new_levels = join_index.levels
3097+
new_labels = join_index.labels
3098+
new_names = join_index.names
3099+
3100+
join_index = MultiIndex(levels=new_levels, labels=new_labels,
3101+
names=new_names, verify_integrity=False)
3102+
3103+
return join_index, lidx, ridx
30483104

3049-
# make the indices into mi's that match
3050-
if not (self_is_mi and other_is_mi):
3105+
else:
3106+
jl = overlap[0]
30513107

3108+
# make the indices into mi's that match
30523109
flip_order = False
30533110
if self_is_mi:
30543111
self, other = other, self
@@ -3065,10 +3122,6 @@ def _join_multi(self, other, how, return_indexers=True):
30653122
return result[0], result[2], result[1]
30663123
return result
30673124

3068-
# 2 multi-indexes
3069-
raise NotImplementedError("merging with both multi-indexes is not "
3070-
"implemented")
3071-
30723125
def _join_non_unique(self, other, how='left', return_indexers=False):
30733126
from pandas.core.reshape.merge import _get_join_indexers
30743127

pandas/tests/reshape/test_merge.py

Lines changed: 175 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,14 +1136,14 @@ def test_join_multi_levels(self):
11361136

11371137
def f():
11381138
household.join(portfolio, how='inner')
1139-
pytest.raises(ValueError, f)
1139+
self.assertRaises(ValueError, f)
11401140

11411141
portfolio2 = portfolio.copy()
11421142
portfolio2.index.set_names(['household_id', 'foo'])
11431143

11441144
def f():
11451145
portfolio2.join(portfolio, how='inner')
1146-
pytest.raises(ValueError, f)
1146+
self.assertRaises(ValueError, f)
11471147

11481148
def test_join_multi_levels2(self):
11491149

@@ -1182,11 +1182,7 @@ def test_join_multi_levels2(self):
11821182
.set_index(["household_id", "asset_id", "t"])
11831183
.reindex(columns=['share', 'log_return']))
11841184

1185-
def f():
1186-
household.join(log_return, how='inner')
1187-
pytest.raises(NotImplementedError, f)
1188-
1189-
# this is the equivalency
1185+
# this is the equivalency
11901186
result = (merge(household.reset_index(), log_return.reset_index(),
11911187
on=['asset_id'], how='inner')
11921188
.set_index(['household_id', 'asset_id', 't']))
@@ -1195,7 +1191,7 @@ def f():
11951191
expected = (
11961192
DataFrame(dict(
11971193
household_id=[1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4],
1198-
asset_id=["nl0000301109", "nl0000289783", "gb00b03mlx29",
1194+
asset_id=["nl0000301109", "nl0000301109", "gb00b03mlx29",
11991195
"gb00b03mlx29", "gb00b03mlx29",
12001196
"gb00b03mlx29", "gb00b03mlx29", "gb00b03mlx29",
12011197
"lu0197800237", "lu0197800237",
@@ -1208,12 +1204,179 @@ def f():
12081204
.09604978, -.06524096, .03532373,
12091205
.03025441, .036997, None, None]
12101206
))
1211-
.set_index(["household_id", "asset_id", "t"]))
1207+
.set_index(["household_id", "asset_id", "t"])
1208+
.reindex(columns=['share', 'log_return']))
12121209

1213-
def f():
1214-
household.join(log_return, how='outer')
1215-
pytest.raises(NotImplementedError, f)
1210+
result = (merge(household.reset_index(), log_return.reset_index(),
1211+
on=['asset_id'], how='outer')
1212+
.set_index(['household_id', 'asset_id', 't']))
12161213

1214+
assert_frame_equal(result, expected)
1215+
1216+
def test_join_multi_levels3(self):
1217+
# Multi-index join tests
1218+
# Self join
1219+
matrix = (
1220+
pd.DataFrame(
1221+
dict(Origin=[1, 1, 2, 2, 3],
1222+
Destination=[1, 2, 1, 3, 1],
1223+
Period=['AM','PM','IP','AM','OP'],
1224+
TripPurp=['hbw', 'nhb', 'hbo', 'nhb', 'hbw'],
1225+
Trips=[1987, 3647, 2470, 4296, 4444]),
1226+
columns=['Origin', 'Destination', 'Period',
1227+
'TripPurp', 'Trips'])
1228+
.set_index(['Origin', 'Destination', 'Period', 'TripPurp']))
1229+
1230+
distances = (
1231+
pd.DataFrame(
1232+
dict(Origin= [1, 1, 2, 2, 3, 3, 5],
1233+
Destination=[1, 2, 1, 2, 1, 2, 6],
1234+
Period=['AM','PM','IP','AM','OP','IP', 'AM'],
1235+
LinkType=['a', 'a', 'c', 'b', 'a', 'b', 'a'],
1236+
Distance=[100, 80, 90, 80, 75, 35, 55]),
1237+
columns=['Origin', 'Destination', 'Period',
1238+
'LinkType', 'Distance'])
1239+
.set_index(['Origin', 'Destination','Period', 'LinkType']))
1240+
1241+
expected = (
1242+
pd.DataFrame(
1243+
dict(Origin=[1, 1, 2, 2, 3],
1244+
Destination=[1, 2, 1, 3, 1],
1245+
Period=['AM','PM','IP','AM','OP'],
1246+
TripPurp=['hbw', 'nhb', 'hbo', 'nhb', 'hbw'],
1247+
Trips=[1987, 3647, 2470, 4296, 4444],
1248+
Trips_joined=[1987, 3647, 2470, 4296, 4444]),
1249+
columns=['Origin', 'Destination', 'Period',
1250+
'TripPurp', 'Trips', 'Trips_joined'])
1251+
.set_index(['Origin', 'Destination', 'Period', 'TripPurp']))
1252+
1253+
result = matrix.join(matrix, how='inner', rsuffix='_joined')
1254+
assert_frame_equal(result, expected)
1255+
1256+
#Left join
1257+
expected = (
1258+
pd.DataFrame(
1259+
dict(Origin= [1, 1, 2, 2, 3],
1260+
Destination=[1, 2, 1, 3, 1],
1261+
Period=['AM','PM','IP', 'AM', 'OP'],
1262+
TripPurp=['hbw', 'nhb', 'hbo', 'nhb', 'hbw'],
1263+
Trips=[1987, 3647, 2470, 4296, 4444],
1264+
Distance=[100, 80, 90, np.nan, 75]),
1265+
columns=['Origin', 'Destination', 'Period', 'TripPurp',
1266+
'Trips', 'Distance'])
1267+
.set_index(['Origin', 'Destination', 'Period', 'TripPurp']))
1268+
1269+
result = matrix.join(distances, how='left')
1270+
assert_frame_equal(result, expected)
1271+
1272+
#Right join
1273+
expected = (
1274+
pd.DataFrame(
1275+
dict(Origin= [1, 1, 2, 2, 3, 3, 5],
1276+
Destination=[1, 2, 1, 2, 1, 2, 6],
1277+
Period=['AM','PM','IP','AM','OP','IP', 'AM'],
1278+
LinkType=['a', 'a', 'c', 'b', 'a', 'b', 'a'],
1279+
Trips=[1987, 3647, 2470, np.nan, 4444, np.nan, np.nan],
1280+
Distance=[100, 80, 90, 80, 75, 35, 55]),
1281+
columns=['Origin', 'Destination', 'Period',
1282+
'LinkType', 'Trips', 'Distance'])
1283+
.set_index(['Origin', 'Destination','Period', 'LinkType']))
1284+
1285+
result = matrix.join(distances, how='right')
1286+
assert_frame_equal(result, expected)
1287+
1288+
#Inner join
1289+
expected = (
1290+
pd.DataFrame(
1291+
dict(Origin= [1, 1, 2, 3],
1292+
Destination=[1, 2, 1, 1],
1293+
Period=['AM','PM','IP', 'OP'],
1294+
Trips=[1987, 3647, 2470, 4444],
1295+
Distance=[100, 80, 90, 75]),
1296+
columns=['Origin', 'Destination', 'Period', 'Trips', 'Distance'])
1297+
.set_index(['Origin', 'Destination', 'Period']))
1298+
1299+
result = matrix.join(distances, how='inner')
1300+
assert_frame_equal(result, expected)
1301+
1302+
#Outer join
1303+
expected = (
1304+
pd.DataFrame(
1305+
dict(Origin= [1, 1, 2, 2, 2, 3, 3, 5],
1306+
Destination=[1, 2, 1, 2, 3, 1, 2, 6],
1307+
Period=['AM','PM','IP', 'AM', 'AM', 'OP', 'IP', 'AM'],
1308+
TripPurp=['hbw', 'nhb', 'hbo', np.nan, 'nhb',
1309+
'hbw', np.nan, np.nan],
1310+
LinkType=['a', 'a', 'c', 'b', np.nan, 'a', 'b', 'a'],
1311+
Trips=[1987, 3647, 2470, np.nan, 4296, 4444, np.nan, np.nan],
1312+
Distance=[100, 80, 90, 80, np.nan, 75, 35, 55]),
1313+
columns=['Origin', 'Destination', 'Period', 'TripPurp', 'LinkType',
1314+
'Trips', 'Distance'])
1315+
.set_index(['Origin', 'Destination', 'Period', 'TripPurp', 'LinkType']))
1316+
1317+
1318+
result = matrix.join(distances, how='outer')
1319+
assert_frame_equal(result, expected)
1320+
1321+
#Non-unique resulting index
1322+
distances2 = (
1323+
pd.DataFrame(
1324+
dict(Origin= [1, 1, 2],
1325+
Destination=[1, 1, 1],
1326+
Period=['AM','AM', 'PM'],
1327+
LinkType=['a', 'b', 'a'],
1328+
Distance=[100, 110, 120]),
1329+
columns=['Origin', 'Destination', 'Period',
1330+
'LinkType', 'Distance'])
1331+
.set_index(['Origin', 'Destination','Period', 'LinkType']))
1332+
1333+
def f():
1334+
matrix.join(distances2, how='left')
1335+
self.assertRaises(TypeError, f)
1336+
1337+
# No overlapping level names
1338+
distances2 = (
1339+
pd.DataFrame(
1340+
dict(Orig= [1, 1, 2, 2, 3, 3, 5],
1341+
Dest=[1, 2, 1, 2, 1, 2, 6],
1342+
Per=['AM','PM','IP','AM','OP','IP', 'AM'],
1343+
LinkTyp=['a', 'a', 'c', 'b', 'a', 'b', 'a'],
1344+
Dist=[100, 80, 90, 80, 75, 35, 55]),
1345+
columns=['Orig', 'Dest', 'Per',
1346+
'LinkTyp', 'Dist'])
1347+
.set_index(['Orig', 'Dest','Per', 'LinkTyp']))
1348+
1349+
def f():
1350+
matrix.join(distances2, how='left')
1351+
self.assertRaises(ValueError, f)
1352+
1353+
# Empty Level
1354+
distances2 = (
1355+
pd.DataFrame(
1356+
dict(Origin=[1, 1, 2, 2, 3, 3, 5],
1357+
Destination=[1, 2, 1, 2, 1, 2, 6],
1358+
Period=[np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan],
1359+
LinkType=['a', 'a', 'c', 'b', 'a', 'b', 'a'],
1360+
Distance=[100, 80, 90, 80, 75, 35, 55]),
1361+
columns=['Origin', 'Destination', 'Period',
1362+
'LinkType', 'Distance'])
1363+
.set_index(['Origin', 'Destination','Period', 'LinkType']))
1364+
1365+
1366+
expected = (
1367+
pd.DataFrame(
1368+
dict(Origin=[1, 1, 2, 2, 3],
1369+
Destination=[1, 2, 1, 3, 1],
1370+
Period=['AM','PM','IP','AM','OP'],
1371+
TripPurp=['hbw', 'nhb', 'hbo', 'nhb', 'hbw'],
1372+
Trips=[1987, 3647, 2470, 4296, 4444],
1373+
Distance=[np.nan, np.nan, np.nan, np.nan, np.nan]),
1374+
columns=['Origin', 'Destination', 'Period',
1375+
'TripPurp', 'Trips', 'Distance'])
1376+
.set_index(['Origin', 'Destination', 'Period', 'TripPurp']))
1377+
1378+
result = matrix.join(distances2, how='left')
1379+
assert_frame_equal(result, expected)
12171380

12181381
@pytest.fixture
12191382
def df():

0 commit comments

Comments
 (0)