Skip to content

Commit 7e7986a

Browse files
authored
RUST-1994 Implement happy eyeballs for TCP connection (#1183)
1 parent 06eda2d commit 7e7986a

File tree

1 file changed

+65
-10
lines changed

1 file changed

+65
-10
lines changed

src/runtime/stream.rs

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::{
99
use tokio::{io::AsyncWrite, net::TcpStream};
1010

1111
use crate::{
12-
error::{ErrorKind, Result},
12+
error::{Error, ErrorKind, Result},
1313
options::ServerAddress,
1414
runtime,
1515
};
@@ -78,7 +78,12 @@ async fn tcp_try_connect(address: &SocketAddr) -> Result<TcpStream> {
7878
}
7979

8080
async fn tcp_connect(address: &ServerAddress) -> Result<TcpStream> {
81-
let mut socket_addrs: Vec<_> = runtime::resolve_address(address).await?.collect();
81+
// "Happy Eyeballs": try addresses in parallel, interleaving IPv6 and IPv4, preferring IPv6.
82+
// Based on the implementation in https://codeberg.org/KMK/happy-eyeballs.
83+
let (addrs_v6, addrs_v4): (Vec<_>, Vec<_>) = runtime::resolve_address(address)
84+
.await?
85+
.partition(|a| matches!(a, SocketAddr::V6(_)));
86+
let socket_addrs = interleave(addrs_v6, addrs_v4);
8287

8388
if socket_addrs.is_empty() {
8489
return Err(ErrorKind::DnsResolve {
@@ -87,19 +92,58 @@ async fn tcp_connect(address: &ServerAddress) -> Result<TcpStream> {
8792
.into());
8893
}
8994

90-
// After considering various approaches, we decided to do what other drivers do, namely try
91-
// each of the addresses in sequence with a preference for IPv4.
92-
socket_addrs.sort_by_key(|addr| if addr.is_ipv4() { 0 } else { 1 });
95+
fn handle_join(
96+
result: std::result::Result<Result<TcpStream>, tokio::task::JoinError>,
97+
) -> Result<TcpStream> {
98+
match result {
99+
Ok(r) => r,
100+
// JoinError indicates the task was cancelled or paniced, which should never happen
101+
// here.
102+
Err(e) => Err(Error::internal(format!("TCP connect task failure: {}", e))),
103+
}
104+
}
105+
106+
static CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_millis(250);
93107

108+
// Race connections
109+
let mut attempts = tokio::task::JoinSet::new();
94110
let mut connect_error = None;
111+
'spawn: for a in socket_addrs {
112+
attempts.spawn(async move { tcp_try_connect(&a).await });
113+
let sleep = tokio::time::sleep(CONNECTION_ATTEMPT_DELAY);
114+
tokio::pin!(sleep); // required for select!
115+
while !attempts.is_empty() {
116+
tokio::select! {
117+
biased;
118+
connect_res = attempts.join_next() => {
119+
match connect_res.map(handle_join) {
120+
// The gating `while !attempts.is_empty()` should mean this never happens.
121+
None => return Err(Error::internal("empty TCP connect task set")),
122+
// A connection succeeded, return it. The JoinSet will cancel remaining tasks on drop.
123+
Some(Ok(cnx)) => return Ok(cnx),
124+
// A connection failed. Remember the error and wait for any other remaining attempts.
125+
Some(Err(e)) => {
126+
connect_error.get_or_insert(e);
127+
},
128+
}
129+
}
130+
// CONNECTION_ATTEMPT_DELAY expired, spawn a new connection attempt.
131+
_ = &mut sleep => continue 'spawn
132+
}
133+
}
134+
}
95135

96-
for address in &socket_addrs {
97-
connect_error = match tcp_try_connect(address).await {
98-
Ok(stream) => return Ok(stream),
99-
Err(err) => Some(err),
100-
};
136+
// No more address to try. Drain the attempts until one succeeds.
137+
while let Some(result) = attempts.join_next().await {
138+
match handle_join(result) {
139+
Ok(cnx) => return Ok(cnx),
140+
Err(e) => {
141+
connect_error.get_or_insert(e);
142+
}
143+
}
101144
}
102145

146+
// All attempts failed. Return the first error.
103147
Err(connect_error.unwrap_or_else(|| {
104148
ErrorKind::Internal {
105149
message: "connecting to all DNS results failed but no error reported".to_string(),
@@ -108,6 +152,17 @@ async fn tcp_connect(address: &ServerAddress) -> Result<TcpStream> {
108152
}))
109153
}
110154

155+
fn interleave<T>(left: Vec<T>, right: Vec<T>) -> Vec<T> {
156+
let mut out = Vec::with_capacity(left.len() + right.len());
157+
let (mut left, mut right) = (left.into_iter(), right.into_iter());
158+
while let Some(a) = left.next() {
159+
out.push(a);
160+
std::mem::swap(&mut left, &mut right);
161+
}
162+
out.extend(right);
163+
out
164+
}
165+
111166
impl tokio::io::AsyncRead for AsyncStream {
112167
fn poll_read(
113168
mut self: Pin<&mut Self>,

0 commit comments

Comments
 (0)