Skip to content

Commit 5ad3c9a

Browse files
committed
Add back keepalives config handling
Also fix connection timeouts to be per-address
1 parent bbf3169 commit 5ad3c9a

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

tokio-postgres/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pin-project-lite = "0.1"
4949
phf = "0.8"
5050
postgres-protocol = { version = "0.5.0", path = "../postgres-protocol" }
5151
postgres-types = { version = "0.1.2", path = "../postgres-types" }
52+
socket2 = "0.3"
5253
tokio = { version = "0.3", features = ["io-util"] }
5354
tokio-util = { version = "0.4", features = ["codec"] }
5455

tokio-postgres/src/connect_socket.rs

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,67 @@
11
use crate::config::Host;
22
use crate::{Error, Socket};
3+
use socket2::{Domain, Protocol, Type};
34
use std::future::Future;
45
use std::io;
6+
use std::net::SocketAddr;
7+
#[cfg(unix)]
8+
use std::os::unix::io::{FromRawFd, IntoRawFd};
9+
#[cfg(windows)]
10+
use std::os::windows::io::{FromRawSocket, IntoRawSocket};
511
use std::time::Duration;
6-
use tokio::net::TcpStream;
712
#[cfg(unix)]
813
use tokio::net::UnixStream;
14+
use tokio::net::{self, TcpSocket};
915
use tokio::time;
1016

1117
pub(crate) async fn connect_socket(
1218
host: &Host,
1319
port: u16,
1420
connect_timeout: Option<Duration>,
15-
_keepalives: bool,
16-
_keepalives_idle: Duration,
21+
keepalives: bool,
22+
keepalives_idle: Duration,
1723
) -> Result<Socket, Error> {
1824
match host {
1925
Host::Tcp(host) => {
20-
let socket =
21-
connect_with_timeout(TcpStream::connect((&**host, port)), connect_timeout).await?;
22-
socket.set_nodelay(true).map_err(Error::connect)?;
23-
// FIXME support keepalives?
26+
let addrs = net::lookup_host((&**host, port))
27+
.await
28+
.map_err(Error::connect)?;
29+
30+
let mut last_err = None;
31+
32+
for addr in addrs {
33+
let domain = match addr {
34+
SocketAddr::V4(_) => Domain::ipv4(),
35+
SocketAddr::V6(_) => Domain::ipv6(),
36+
};
37+
38+
let socket = socket2::Socket::new(domain, Type::stream(), Some(Protocol::tcp()))
39+
.map_err(Error::connect)?;
40+
socket.set_nonblocking(true).map_err(Error::connect)?;
41+
socket.set_nodelay(true).map_err(Error::connect)?;
42+
if keepalives {
43+
socket
44+
.set_keepalive(Some(keepalives_idle))
45+
.map_err(Error::connect)?;
46+
}
47+
48+
#[cfg(unix)]
49+
let socket = unsafe { TcpSocket::from_raw_fd(socket.into_raw_fd()) };
50+
#[cfg(windows)]
51+
let socket = unsafe { TcpSocket::from_raw_socket(socket.into_raw_socket()) };
52+
53+
match connect_with_timeout(socket.connect(addr), connect_timeout).await {
54+
Ok(socket) => return Ok(Socket::new_tcp(socket)),
55+
Err(e) => last_err = Some(e),
56+
}
57+
}
2458

25-
Ok(Socket::new_tcp(socket))
59+
Err(last_err.unwrap_or_else(|| {
60+
Error::connect(io::Error::new(
61+
io::ErrorKind::InvalidInput,
62+
"could not resolve any addresses",
63+
))
64+
}))
2665
}
2766
#[cfg(unix)]
2867
Host::Unix(path) => {

0 commit comments

Comments
 (0)