Skip to content

Commit f9f71cc

Browse files
committed
Fix poor worst case performance of set intersection (and union, somewhat) on asymmetrically sized sets and extend unit tests slightly beyond that
1 parent 01c6ea2 commit f9f71cc

File tree

1 file changed

+56
-6
lines changed
  • src/libstd/collections/hash

1 file changed

+56
-6
lines changed

src/libstd/collections/hash/set.rs

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -420,9 +420,16 @@ impl<T, S> HashSet<T, S>
420420
/// ```
421421
#[stable(feature = "rust1", since = "1.0.0")]
422422
pub fn intersection<'a>(&'a self, other: &'a HashSet<T, S>) -> Intersection<'a, T, S> {
423-
Intersection {
424-
iter: self.iter(),
425-
other,
423+
if self.len() <= other.len() {
424+
Intersection {
425+
iter: self.iter(),
426+
other,
427+
}
428+
} else {
429+
Intersection {
430+
iter: other.iter(),
431+
other: self,
432+
}
426433
}
427434
}
428435

@@ -446,7 +453,15 @@ impl<T, S> HashSet<T, S>
446453
/// ```
447454
#[stable(feature = "rust1", since = "1.0.0")]
448455
pub fn union<'a>(&'a self, other: &'a HashSet<T, S>) -> Union<'a, T, S> {
449-
Union { iter: self.iter().chain(other.difference(self)) }
456+
if self.len() <= other.len() {
457+
Union {
458+
iter: self.iter().chain(other.difference(self)),
459+
}
460+
} else {
461+
Union {
462+
iter: other.iter().chain(self.difference(other)),
463+
}
464+
}
450465
}
451466

452467
/// Returns the number of elements in the set.
@@ -1504,6 +1519,8 @@ mod test_set {
15041519
fn test_intersection() {
15051520
let mut a = HashSet::new();
15061521
let mut b = HashSet::new();
1522+
assert!(a.intersection(&b).next().is_none());
1523+
assert!(b.intersection(&a).next().is_none());
15071524

15081525
assert!(a.insert(11));
15091526
assert!(a.insert(1));
@@ -1528,6 +1545,22 @@ mod test_set {
15281545
i += 1
15291546
}
15301547
assert_eq!(i, expected.len());
1548+
1549+
assert!(a.insert(9)); // make a bigger than b
1550+
1551+
i = 0;
1552+
for x in a.intersection(&b) {
1553+
assert!(expected.contains(x));
1554+
i += 1
1555+
}
1556+
assert_eq!(i, expected.len());
1557+
1558+
i = 0;
1559+
for x in b.intersection(&a) {
1560+
assert!(expected.contains(x));
1561+
i += 1
1562+
}
1563+
assert_eq!(i, expected.len());
15311564
}
15321565

15331566
#[test]
@@ -1583,11 +1616,11 @@ mod test_set {
15831616
fn test_union() {
15841617
let mut a = HashSet::new();
15851618
let mut b = HashSet::new();
1619+
assert!(a.union(&b).next().is_none());
1620+
assert!(b.union(&a).next().is_none());
15861621

15871622
assert!(a.insert(1));
15881623
assert!(a.insert(3));
1589-
assert!(a.insert(5));
1590-
assert!(a.insert(9));
15911624
assert!(a.insert(11));
15921625
assert!(a.insert(16));
15931626
assert!(a.insert(19));
@@ -1607,6 +1640,23 @@ mod test_set {
16071640
i += 1
16081641
}
16091642
assert_eq!(i, expected.len());
1643+
1644+
assert!(a.insert(9)); // make a bigger than b
1645+
assert!(a.insert(5));
1646+
1647+
i = 0;
1648+
for x in a.union(&b) {
1649+
assert!(expected.contains(x));
1650+
i += 1
1651+
}
1652+
assert_eq!(i, expected.len());
1653+
1654+
i = 0;
1655+
for x in b.union(&a) {
1656+
assert!(expected.contains(x));
1657+
i += 1
1658+
}
1659+
assert_eq!(i, expected.len());
16101660
}
16111661

16121662
#[test]

0 commit comments

Comments
 (0)