Skip to content

Commit b014e57

Browse files
committed
Refactor TcpStream::connect into resolving loop and TcpStream::connect_to
1 parent d89b384 commit b014e57

File tree

1 file changed

+73
-52
lines changed

1 file changed

+73
-52
lines changed

src/net/tcp/stream.rs

Lines changed: 73 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -78,61 +78,10 @@ impl TcpStream {
7878
/// # Ok(()) }) }
7979
/// ```
8080
pub async fn connect<A: ToSocketAddrs>(addrs: A) -> io::Result<TcpStream> {
81-
enum State {
82-
Waiting(TcpStream),
83-
Error(io::Error),
84-
Done,
85-
}
86-
8781
let mut last_err = None;
8882

8983
for addr in addrs.to_socket_addrs()? {
90-
let mut state = {
91-
match mio::net::TcpStream::connect(&addr) {
92-
Ok(mio_stream) => {
93-
#[cfg(unix)]
94-
let stream = TcpStream {
95-
raw_fd: mio_stream.as_raw_fd(),
96-
io_handle: IoHandle::new(mio_stream),
97-
};
98-
99-
#[cfg(windows)]
100-
let stream = TcpStream {
101-
// raw_socket: mio_stream.as_raw_socket(),
102-
io_handle: IoHandle::new(mio_stream),
103-
};
104-
105-
State::Waiting(stream)
106-
}
107-
Err(err) => State::Error(err),
108-
}
109-
};
110-
111-
let res = future::poll_fn(|cx| {
112-
match mem::replace(&mut state, State::Done) {
113-
State::Waiting(stream) => {
114-
// Once we've connected, wait for the stream to be writable as that's when
115-
// the actual connection has been initiated. Once we're writable we check
116-
// for `take_socket_error` to see if the connect actually hit an error or
117-
// not.
118-
//
119-
// If all that succeeded then we ship everything on up.
120-
if let Poll::Pending = stream.io_handle.poll_writable(cx)? {
121-
state = State::Waiting(stream);
122-
return Poll::Pending;
123-
}
124-
125-
if let Some(err) = stream.io_handle.get_ref().take_error()? {
126-
return Poll::Ready(Err(err));
127-
}
128-
129-
Poll::Ready(Ok(stream))
130-
}
131-
State::Error(err) => Poll::Ready(Err(err)),
132-
State::Done => panic!("`TcpStream::connect()` future polled after completion"),
133-
}
134-
})
135-
.await;
84+
let res = Self::connect_to(addr).await;
13685

13786
match res {
13887
Ok(stream) => return Ok(stream),
@@ -148,6 +97,78 @@ impl TcpStream {
14897
}))
14998
}
15099

100+
/// Creates a new TCP stream connected to the specified address.
101+
///
102+
/// This method will create a new TCP socket and attempt to connect it to the `addr`
103+
/// provided. The returned future will be resolved once the stream has successfully
104+
/// connected, or it will return an error if one occurs.
105+
///
106+
/// # Examples
107+
///
108+
/// ```no_run
109+
/// # fn main() -> std::io::Result<()> { async_std::task::block_on(async {
110+
/// #
111+
/// use async_std::net::TcpStream;
112+
///
113+
/// let addr = "127.0.0.1".parse().unwrap();
114+
/// let stream = TcpStream::connect_to(addr).await?;
115+
/// #
116+
/// # Ok(()) }) }
117+
/// ```
118+
pub async fn connect_to(addr: SocketAddr) -> io::Result<TcpStream> {
119+
let stream = mio::net::TcpStream::connect(&addr).map(|mio_stream| {
120+
#[cfg(unix)]
121+
let stream = TcpStream {
122+
raw_fd: mio_stream.as_raw_fd(),
123+
io_handle: IoHandle::new(mio_stream),
124+
};
125+
126+
#[cfg(windows)]
127+
let stream = TcpStream {
128+
// raw_socket: mio_stream.as_raw_socket(),
129+
io_handle: IoHandle::new(mio_stream),
130+
};
131+
132+
stream
133+
});
134+
135+
enum State {
136+
Waiting(TcpStream),
137+
Error(io::Error),
138+
Done,
139+
}
140+
let mut state = match stream {
141+
// TODO replace with .map_or_else(State::Error, State::Waiting)
142+
Ok(stream) => State::Waiting(stream),
143+
Err(err) => State::Error(err),
144+
};
145+
future::poll_fn(|cx| {
146+
match mem::replace(&mut state, State::Done) {
147+
State::Waiting(stream) => {
148+
// Once we've connected, wait for the stream to be writable as that's when
149+
// the actual connection has been initiated. Once we're writable we check
150+
// for `take_socket_error` to see if the connect actually hit an error or
151+
// not.
152+
//
153+
// If all that succeeded then we ship everything on up.
154+
if let Poll::Pending = stream.io_handle.poll_writable(cx)? {
155+
state = State::Waiting(stream);
156+
return Poll::Pending;
157+
}
158+
159+
if let Some(err) = stream.io_handle.get_ref().take_error()? {
160+
return Poll::Ready(Err(err));
161+
}
162+
163+
Poll::Ready(Ok(stream))
164+
}
165+
State::Error(err) => Poll::Ready(Err(err)),
166+
State::Done => panic!("`TcpStream::connect_to()` future polled after completion"),
167+
}
168+
})
169+
.await
170+
}
171+
151172
/// Returns the local address that this stream is connected to.
152173
///
153174
/// ## Examples

0 commit comments

Comments
 (0)