@@ -9,7 +9,7 @@ use std::{
9
9
use tokio:: { io:: AsyncWrite , net:: TcpStream } ;
10
10
11
11
use crate :: {
12
- error:: { ErrorKind , Result } ,
12
+ error:: { Error , ErrorKind , Result } ,
13
13
options:: ServerAddress ,
14
14
runtime,
15
15
} ;
@@ -78,7 +78,12 @@ async fn tcp_try_connect(address: &SocketAddr) -> Result<TcpStream> {
78
78
}
79
79
80
80
async fn tcp_connect ( address : & ServerAddress ) -> Result < TcpStream > {
81
- let mut socket_addrs: Vec < _ > = runtime:: resolve_address ( address) . await ?. collect ( ) ;
81
+ // "Happy Eyeballs": try addresses in parallel, interleaving IPv6 and IPv4, preferring IPv6.
82
+ // Based on the implementation in https://codeberg.org/KMK/happy-eyeballs.
83
+ let ( addrs_v6, addrs_v4) : ( Vec < _ > , Vec < _ > ) = runtime:: resolve_address ( address)
84
+ . await ?
85
+ . partition ( |a| matches ! ( a, SocketAddr :: V6 ( _) ) ) ;
86
+ let socket_addrs = interleave ( addrs_v6, addrs_v4) ;
82
87
83
88
if socket_addrs. is_empty ( ) {
84
89
return Err ( ErrorKind :: DnsResolve {
@@ -87,19 +92,58 @@ async fn tcp_connect(address: &ServerAddress) -> Result<TcpStream> {
87
92
. into ( ) ) ;
88
93
}
89
94
90
- // After considering various approaches, we decided to do what other drivers do, namely try
91
- // each of the addresses in sequence with a preference for IPv4.
92
- socket_addrs. sort_by_key ( |addr| if addr. is_ipv4 ( ) { 0 } else { 1 } ) ;
95
+ fn handle_join (
96
+ result : std:: result:: Result < Result < TcpStream > , tokio:: task:: JoinError > ,
97
+ ) -> Result < TcpStream > {
98
+ match result {
99
+ Ok ( r) => r,
100
+ // JoinError indicates the task was cancelled or paniced, which should never happen
101
+ // here.
102
+ Err ( e) => Err ( Error :: internal ( format ! ( "TCP connect task failure: {}" , e) ) ) ,
103
+ }
104
+ }
105
+
106
+ static CONNECTION_ATTEMPT_DELAY : Duration = Duration :: from_millis ( 250 ) ;
93
107
108
+ // Race connections
109
+ let mut attempts = tokio:: task:: JoinSet :: new ( ) ;
94
110
let mut connect_error = None ;
111
+ ' spawn: for a in socket_addrs {
112
+ attempts. spawn ( async move { tcp_try_connect ( & a) . await } ) ;
113
+ let sleep = tokio:: time:: sleep ( CONNECTION_ATTEMPT_DELAY ) ;
114
+ tokio:: pin!( sleep) ; // required for select!
115
+ while !attempts. is_empty ( ) {
116
+ tokio:: select! {
117
+ biased;
118
+ connect_res = attempts. join_next( ) => {
119
+ match connect_res. map( handle_join) {
120
+ // The gating `while !attempts.is_empty()` should mean this never happens.
121
+ None => return Err ( Error :: internal( "empty TCP connect task set" ) ) ,
122
+ // A connection succeeded, return it. The JoinSet will cancel remaining tasks on drop.
123
+ Some ( Ok ( cnx) ) => return Ok ( cnx) ,
124
+ // A connection failed. Remember the error and wait for any other remaining attempts.
125
+ Some ( Err ( e) ) => {
126
+ connect_error. get_or_insert( e) ;
127
+ } ,
128
+ }
129
+ }
130
+ // CONNECTION_ATTEMPT_DELAY expired, spawn a new connection attempt.
131
+ _ = & mut sleep => continue ' spawn
132
+ }
133
+ }
134
+ }
95
135
96
- for address in & socket_addrs {
97
- connect_error = match tcp_try_connect ( address) . await {
98
- Ok ( stream) => return Ok ( stream) ,
99
- Err ( err) => Some ( err) ,
100
- } ;
136
+ // No more address to try. Drain the attempts until one succeeds.
137
+ while let Some ( result) = attempts. join_next ( ) . await {
138
+ match handle_join ( result) {
139
+ Ok ( cnx) => return Ok ( cnx) ,
140
+ Err ( e) => {
141
+ connect_error. get_or_insert ( e) ;
142
+ }
143
+ }
101
144
}
102
145
146
+ // All attempts failed. Return the first error.
103
147
Err ( connect_error. unwrap_or_else ( || {
104
148
ErrorKind :: Internal {
105
149
message : "connecting to all DNS results failed but no error reported" . to_string ( ) ,
@@ -108,6 +152,17 @@ async fn tcp_connect(address: &ServerAddress) -> Result<TcpStream> {
108
152
} ) )
109
153
}
110
154
155
+ fn interleave < T > ( left : Vec < T > , right : Vec < T > ) -> Vec < T > {
156
+ let mut out = Vec :: with_capacity ( left. len ( ) + right. len ( ) ) ;
157
+ let ( mut left, mut right) = ( left. into_iter ( ) , right. into_iter ( ) ) ;
158
+ while let Some ( a) = left. next ( ) {
159
+ out. push ( a) ;
160
+ std:: mem:: swap ( & mut left, & mut right) ;
161
+ }
162
+ out. extend ( right) ;
163
+ out
164
+ }
165
+
111
166
impl tokio:: io:: AsyncRead for AsyncStream {
112
167
fn poll_read (
113
168
mut self : Pin < & mut Self > ,
0 commit comments