|
1 | 1 | use crate::config::Host;
|
2 | 2 | use crate::{Error, Socket};
|
| 3 | +use socket2::{Domain, Protocol, Type}; |
3 | 4 | use std::future::Future;
|
4 | 5 | use std::io;
|
| 6 | +use std::net::SocketAddr; |
| 7 | +#[cfg(unix)] |
| 8 | +use std::os::unix::io::{FromRawFd, IntoRawFd}; |
| 9 | +#[cfg(windows)] |
| 10 | +use std::os::windows::io::{FromRawSocket, IntoRawSocket}; |
5 | 11 | use std::time::Duration;
|
6 |
| -use tokio::net::TcpStream; |
7 | 12 | #[cfg(unix)]
|
8 | 13 | use tokio::net::UnixStream;
|
| 14 | +use tokio::net::{self, TcpSocket}; |
9 | 15 | use tokio::time;
|
10 | 16 |
|
11 | 17 | pub(crate) async fn connect_socket(
|
12 | 18 | host: &Host,
|
13 | 19 | port: u16,
|
14 | 20 | connect_timeout: Option<Duration>,
|
15 |
| - _keepalives: bool, |
16 |
| - _keepalives_idle: Duration, |
| 21 | + keepalives: bool, |
| 22 | + keepalives_idle: Duration, |
17 | 23 | ) -> Result<Socket, Error> {
|
18 | 24 | match host {
|
19 | 25 | Host::Tcp(host) => {
|
20 |
| - let socket = |
21 |
| - connect_with_timeout(TcpStream::connect((&**host, port)), connect_timeout).await?; |
22 |
| - socket.set_nodelay(true).map_err(Error::connect)?; |
23 |
| - // FIXME support keepalives? |
| 26 | + let addrs = net::lookup_host((&**host, port)) |
| 27 | + .await |
| 28 | + .map_err(Error::connect)?; |
| 29 | + |
| 30 | + let mut last_err = None; |
| 31 | + |
| 32 | + for addr in addrs { |
| 33 | + let domain = match addr { |
| 34 | + SocketAddr::V4(_) => Domain::ipv4(), |
| 35 | + SocketAddr::V6(_) => Domain::ipv6(), |
| 36 | + }; |
| 37 | + |
| 38 | + let socket = socket2::Socket::new(domain, Type::stream(), Some(Protocol::tcp())) |
| 39 | + .map_err(Error::connect)?; |
| 40 | + socket.set_nonblocking(true).map_err(Error::connect)?; |
| 41 | + socket.set_nodelay(true).map_err(Error::connect)?; |
| 42 | + if keepalives { |
| 43 | + socket |
| 44 | + .set_keepalive(Some(keepalives_idle)) |
| 45 | + .map_err(Error::connect)?; |
| 46 | + } |
| 47 | + |
| 48 | + #[cfg(unix)] |
| 49 | + let socket = unsafe { TcpSocket::from_raw_fd(socket.into_raw_fd()) }; |
| 50 | + #[cfg(windows)] |
| 51 | + let socket = unsafe { TcpSocket::from_raw_socket(socket.into_raw_socket()) }; |
| 52 | + |
| 53 | + match connect_with_timeout(socket.connect(addr), connect_timeout).await { |
| 54 | + Ok(socket) => return Ok(Socket::new_tcp(socket)), |
| 55 | + Err(e) => last_err = Some(e), |
| 56 | + } |
| 57 | + } |
24 | 58 |
|
25 |
| - Ok(Socket::new_tcp(socket)) |
| 59 | + Err(last_err.unwrap_or_else(|| { |
| 60 | + Error::connect(io::Error::new( |
| 61 | + io::ErrorKind::InvalidInput, |
| 62 | + "could not resolve any addresses", |
| 63 | + )) |
| 64 | + })) |
26 | 65 | }
|
27 | 66 | #[cfg(unix)]
|
28 | 67 | Host::Unix(path) => {
|
|
0 commit comments