diff --git a/src/libnative/io/c_win32.rs b/src/libnative/io/c_win32.rs index 8d75a6739146d..dbbb39b3b7b52 100644 --- a/src/libnative/io/c_win32.rs +++ b/src/libnative/io/c_win32.rs @@ -50,9 +50,9 @@ extern "system" { pub fn ioctlsocket(s: libc::SOCKET, cmd: libc::c_long, argp: *mut libc::c_ulong) -> libc::c_int; pub fn select(nfds: libc::c_int, - readfds: *mut fd_set, - writefds: *mut fd_set, - exceptfds: *mut fd_set, + readfds: *fd_set, + writefds: *fd_set, + exceptfds: *fd_set, timeout: *libc::timeval) -> libc::c_int; pub fn getsockopt(sockfd: libc::SOCKET, level: libc::c_int, diff --git a/src/libnative/io/net.rs b/src/libnative/io/net.rs index be597761b1a8f..93ec23e32ad42 100644 --- a/src/libnative/io/net.rs +++ b/src/libnative/io/net.rs @@ -13,6 +13,7 @@ use std::cast; use std::io::net::ip; use std::io; use std::mem; +use std::os; use std::ptr; use std::rt::rtio; use std::sync::arc::UnsafeArc; @@ -144,6 +145,21 @@ fn last_error() -> io::IoError { super::last_error() } +fn ms_to_timeval(ms: u64) -> libc::timeval { + libc::timeval { + tv_sec: (ms / 1000) as libc::time_t, + tv_usec: ((ms % 1000) * 1000) as libc::suseconds_t, + } +} + +fn timeout(desc: &'static str) -> io::IoError { + io::IoError { + kind: io::TimedOut, + desc: desc, + detail: None, + } +} + #[cfg(windows)] unsafe fn close(sock: sock_t) { let _ = libc::closesocket(sock); } #[cfg(unix)] unsafe fn close(sock: sock_t) { let _ = libc::close(sock); } @@ -271,8 +287,7 @@ impl TcpStream { fn connect_timeout(fd: sock_t, addrp: *libc::sockaddr, len: libc::socklen_t, - timeout: u64) -> IoResult<()> { - use std::os; + timeout_ms: u64) -> IoResult<()> { #[cfg(unix)] use INPROGRESS = libc::EINPROGRESS; #[cfg(windows)] use INPROGRESS = libc::WSAEINPROGRESS; #[cfg(unix)] use WOULDBLOCK = libc::EWOULDBLOCK; @@ -289,12 +304,8 @@ impl TcpStream { os::errno() as int == WOULDBLOCK as int => { let mut set: c::fd_set = unsafe { mem::init() }; c::fd_set(&mut set, fd); - match await(fd, &mut set, timeout) { - 0 => Err(io::IoError { - kind: io::TimedOut, - desc: "connection timed out", - detail: None, - }), + match await(fd, &mut set, timeout_ms) { + 0 => Err(timeout("connection timed out")), -1 => Err(last_error()), _ => { let err: libc::c_int = try!( @@ -338,22 +349,14 @@ impl TcpStream { // Recalculate the timeout each iteration (it is generally // undefined what the value of the 'tv' is after select // returns EINTR). - let timeout = timeout - (::io::timer::now() - start); - let tv = libc::timeval { - tv_sec: (timeout / 1000) as libc::time_t, - tv_usec: ((timeout % 1000) * 1000) as libc::suseconds_t, - }; - c::select(fd + 1, ptr::null(), set as *mut _ as *_, - ptr::null(), &tv) + let tv = ms_to_timeval(timeout - (::io::timer::now() - start)); + c::select(fd + 1, ptr::null(), &*set, ptr::null(), &tv) }) } #[cfg(windows)] fn await(_fd: sock_t, set: &mut c::fd_set, timeout: u64) -> libc::c_int { - let tv = libc::timeval { - tv_sec: (timeout / 1000) as libc::time_t, - tv_usec: ((timeout % 1000) * 1000) as libc::suseconds_t, - }; - unsafe { c::select(1, ptr::mut_null(), set, ptr::mut_null(), &tv) } + let tv = ms_to_timeval(timeout); + unsafe { c::select(1, ptr::null(), &*set, ptr::null(), &tv) } } } @@ -467,7 +470,7 @@ impl Drop for Inner { //////////////////////////////////////////////////////////////////////////////// pub struct TcpListener { - inner: UnsafeArc, + inner: Inner, } impl TcpListener { @@ -477,7 +480,7 @@ impl TcpListener { let (addr, len) = addr_to_sockaddr(addr); let addrp = &addr as *libc::sockaddr_storage; let inner = Inner { fd: fd }; - let ret = TcpListener { inner: UnsafeArc::new(inner) }; + let ret = TcpListener { inner: inner }; // On platforms with Berkeley-derived sockets, this allows // to quickly rebind a socket, without needing to wait for // the OS to clean up the previous one. @@ -498,15 +501,12 @@ impl TcpListener { } } - pub fn fd(&self) -> sock_t { - // This is just a read-only arc so the unsafety is fine - unsafe { (*self.inner.get()).fd } - } + pub fn fd(&self) -> sock_t { self.inner.fd } pub fn native_listen(self, backlog: int) -> IoResult { match unsafe { libc::listen(self.fd(), backlog as libc::c_int) } { -1 => Err(last_error()), - _ => Ok(TcpAcceptor { listener: self }) + _ => Ok(TcpAcceptor { listener: self, deadline: 0 }) } } } @@ -525,12 +525,16 @@ impl rtio::RtioSocket for TcpListener { pub struct TcpAcceptor { listener: TcpListener, + deadline: u64, } impl TcpAcceptor { pub fn fd(&self) -> sock_t { self.listener.fd() } pub fn native_accept(&mut self) -> IoResult { + if self.deadline != 0 { + try!(self.accept_deadline()); + } unsafe { let mut storage: libc::sockaddr_storage = mem::init(); let storagep = &mut storage as *mut libc::sockaddr_storage; @@ -546,6 +550,25 @@ impl TcpAcceptor { } } } + + fn accept_deadline(&mut self) -> IoResult<()> { + let mut set: c::fd_set = unsafe { mem::init() }; + c::fd_set(&mut set, self.fd()); + + match retry(|| { + // If we're past the deadline, then pass a 0 timeout to select() so + // we can poll the status of the socket. + let now = ::io::timer::now(); + let ms = if self.deadline > now {0} else {self.deadline - now}; + let tv = ms_to_timeval(ms); + let n = if cfg!(windows) {1} else {self.fd() as libc::c_int + 1}; + unsafe { c::select(n, &set, ptr::null(), ptr::null(), &tv) } + }) { + -1 => Err(last_error()), + 0 => Err(timeout("accept timed out")), + _ => return Ok(()), + } + } } impl rtio::RtioSocket for TcpAcceptor { @@ -561,6 +584,12 @@ impl rtio::RtioTcpAcceptor for TcpAcceptor { fn accept_simultaneously(&mut self) -> IoResult<()> { Ok(()) } fn dont_accept_simultaneously(&mut self) -> IoResult<()> { Ok(()) } + fn set_timeout(&mut self, timeout: Option) { + self.deadline = match timeout { + None => 0, + Some(t) => ::io::timer::now() + t, + }; + } } //////////////////////////////////////////////////////////////////////////////// diff --git a/src/libnative/io/timer_win32.rs b/src/libnative/io/timer_win32.rs index a15898feb92b7..588ec367d8176 100644 --- a/src/libnative/io/timer_win32.rs +++ b/src/libnative/io/timer_win32.rs @@ -89,6 +89,17 @@ fn helper(input: libc::HANDLE, messages: Receiver) { } } +// returns the current time (in milliseconds) +pub fn now() -> u64 { + let mut ticks_per_s = 0; + assert_eq!(unsafe { libc::QueryPerformanceFrequency(&mut ticks_per_s) }, 1); + let ticks_per_s = if ticks_per_s == 0 {1} else {ticks_per_s}; + let mut ticks = 0; + assert_eq!(unsafe { libc::QueryPerformanceCounter(&mut ticks) }, 1); + + return (ticks as u64 * 1000) / (ticks_per_s as u64); +} + impl Timer { pub fn new() -> IoResult { timer_helper::boot(helper); diff --git a/src/librustuv/net.rs b/src/librustuv/net.rs index 69d978b24334f..f8df9263be1db 100644 --- a/src/librustuv/net.rs +++ b/src/librustuv/net.rs @@ -174,6 +174,9 @@ pub struct TcpListener { pub struct TcpAcceptor { listener: ~TcpListener, + timer: Option, + timeout_tx: Option>, + timeout_rx: Option>, } // TCP watchers (clients/streams) @@ -459,7 +462,12 @@ impl rtio::RtioSocket for TcpListener { impl rtio::RtioTcpListener for TcpListener { fn listen(~self) -> Result<~rtio::RtioTcpAcceptor:Send, IoError> { // create the acceptor object from ourselves - let mut acceptor = ~TcpAcceptor { listener: self }; + let mut acceptor = ~TcpAcceptor { + listener: self, + timer: None, + timeout_tx: None, + timeout_rx: None, + }; let _m = acceptor.fire_homing_missile(); // FIXME: the 128 backlog should be configurable @@ -509,7 +517,37 @@ impl rtio::RtioSocket for TcpAcceptor { impl rtio::RtioTcpAcceptor for TcpAcceptor { fn accept(&mut self) -> Result<~rtio::RtioTcpStream:Send, IoError> { - self.listener.incoming.recv() + match self.timeout_rx { + None => self.listener.incoming.recv(), + Some(ref rx) => { + use std::comm::Select; + + // Poll the incoming channel first (don't rely on the order of + // select just yet). If someone's pending then we should return + // them immediately. + match self.listener.incoming.try_recv() { + Ok(data) => return data, + Err(..) => {} + } + + // Use select to figure out which channel gets ready first. We + // do some custom handling of select to ensure that we never + // actually drain the timeout channel (we'll keep seeing the + // timeout message in the future). + let s = Select::new(); + let mut timeout = s.handle(rx); + let mut data = s.handle(&self.listener.incoming); + unsafe { + timeout.add(); + data.add(); + } + if s.wait() == timeout.id() { + Err(uv_error_to_io_error(UvError(uvll::ECANCELED))) + } else { + self.listener.incoming.recv() + } + } + } } fn accept_simultaneously(&mut self) -> Result<(), IoError> { @@ -525,6 +563,52 @@ impl rtio::RtioTcpAcceptor for TcpAcceptor { uvll::uv_tcp_simultaneous_accepts(self.listener.handle, 0) }) } + + fn set_timeout(&mut self, ms: Option) { + // First, if the timeout is none, clear any previous timeout by dropping + // the timer and transmission channels + let ms = match ms { + None => { + return drop((self.timer.take(), + self.timeout_tx.take(), + self.timeout_rx.take())) + } + Some(ms) => ms, + }; + + // If we have a timeout, lazily initialize the timer which will be used + // to fire when the timeout runs out. + if self.timer.is_none() { + let _m = self.fire_homing_missile(); + let loop_ = Loop::wrap(unsafe { + uvll::get_loop_for_uv_handle(self.listener.handle) + }); + let mut timer = TimerWatcher::new_home(&loop_, self.home().clone()); + unsafe { + timer.set_data(self as *mut _ as *TcpAcceptor); + } + self.timer = Some(timer); + } + + // Once we've got a timer, stop any previous timeout, reset it for the + // current one, and install some new channels to send/receive data on + let timer = self.timer.get_mut_ref(); + timer.stop(); + timer.start(timer_cb, ms, 0); + let (tx, rx) = channel(); + self.timeout_tx = Some(tx); + self.timeout_rx = Some(rx); + + extern fn timer_cb(timer: *uvll::uv_timer_t, status: c_int) { + assert_eq!(status, 0); + let acceptor: &mut TcpAcceptor = unsafe { + &mut *(uvll::get_data_for_uv_handle(timer) as *mut TcpAcceptor) + }; + // This send can never fail because if this timer is active then the + // receiving channel is guaranteed to be alive + acceptor.timeout_tx.get_ref().send(()); + } + } } //////////////////////////////////////////////////////////////////////////////// diff --git a/src/librustuv/timer.rs b/src/librustuv/timer.rs index 3710d97827f28..65ab32c6965a7 100644 --- a/src/librustuv/timer.rs +++ b/src/librustuv/timer.rs @@ -14,7 +14,7 @@ use std::rt::rtio::RtioTimer; use std::rt::task::BlockedTask; use homing::{HomeHandle, HomingIO}; -use super::{UvHandle, ForbidUnwind, ForbidSwitch, wait_until_woken_after}; +use super::{UvHandle, ForbidUnwind, ForbidSwitch, wait_until_woken_after, Loop}; use uvio::UvIoFactory; use uvll; @@ -34,18 +34,21 @@ pub enum NextAction { impl TimerWatcher { pub fn new(io: &mut UvIoFactory) -> ~TimerWatcher { + let handle = io.make_handle(); + let me = ~TimerWatcher::new_home(&io.loop_, handle); + me.install() + } + + pub fn new_home(loop_: &Loop, home: HomeHandle) -> TimerWatcher { let handle = UvHandle::alloc(None::, uvll::UV_TIMER); - assert_eq!(unsafe { - uvll::uv_timer_init(io.uv_loop(), handle) - }, 0); - let me = ~TimerWatcher { + assert_eq!(unsafe { uvll::uv_timer_init(loop_.handle, handle) }, 0); + TimerWatcher { handle: handle, action: None, blocker: None, - home: io.make_handle(), + home: home, id: 0, - }; - return me.install(); + } } pub fn start(&mut self, f: uvll::uv_timer_cb, msecs: u64, period: u64) { diff --git a/src/libstd/io/net/tcp.rs b/src/libstd/io/net/tcp.rs index 4f1e6bd741817..0619c89aac1c4 100644 --- a/src/libstd/io/net/tcp.rs +++ b/src/libstd/io/net/tcp.rs @@ -22,7 +22,7 @@ use io::IoResult; use io::net::ip::SocketAddr; use io::{Reader, Writer, Listener, Acceptor}; use kinds::Send; -use option::{None, Some}; +use option::{None, Some, Option}; use rt::rtio::{IoFactory, LocalIo, RtioSocket, RtioTcpListener}; use rt::rtio::{RtioTcpAcceptor, RtioTcpStream}; @@ -184,6 +184,56 @@ pub struct TcpAcceptor { obj: ~RtioTcpAcceptor:Send } +impl TcpAcceptor { + /// Prevents blocking on all future accepts after `ms` milliseconds have + /// elapsed. + /// + /// This function is used to set a deadline after which this acceptor will + /// time out accepting any connections. The argument is the relative + /// distance, in milliseconds, to a point in the future after which all + /// accepts will fail. + /// + /// If the argument specified is `None`, then any previously registered + /// timeout is cleared. + /// + /// A timeout of `0` can be used to "poll" this acceptor to see if it has + /// any pending connections. All pending connections will be accepted, + /// regardless of whether the timeout has expired or not (the accept will + /// not block in this case). + /// + /// # Example + /// + /// ```no_run + /// # #![allow(experimental)] + /// use std::io::net::tcp::TcpListener; + /// use std::io::net::ip::{SocketAddr, Ipv4Addr}; + /// use std::io::{Listener, Acceptor, TimedOut}; + /// + /// let addr = SocketAddr { ip: Ipv4Addr(127, 0, 0, 1), port: 8482 }; + /// let mut a = TcpListener::bind(addr).listen().unwrap(); + /// + /// // After 100ms have passed, all accepts will fail + /// a.set_timeout(Some(100)); + /// + /// match a.accept() { + /// Ok(..) => println!("accepted a socket"), + /// Err(ref e) if e.kind == TimedOut => { println!("timed out!"); } + /// Err(e) => println!("err: {}", e), + /// } + /// + /// // Reset the timeout and try again + /// a.set_timeout(Some(100)); + /// let socket = a.accept(); + /// + /// // Clear the timeout and block indefinitely waiting for a connection + /// a.set_timeout(None); + /// let socket = a.accept(); + /// ``` + #[experimental = "the type of the argument and name of this function are \ + subject to change"] + pub fn set_timeout(&mut self, ms: Option) { self.obj.set_timeout(ms); } +} + impl Acceptor for TcpAcceptor { fn accept(&mut self) -> IoResult { self.obj.accept().map(TcpStream::new) @@ -191,6 +241,7 @@ impl Acceptor for TcpAcceptor { } #[cfg(test)] +#[allow(experimental)] mod test { use super::*; use io::net::ip::SocketAddr; @@ -749,4 +800,37 @@ mod test { assert!(s.write([1]).is_err()); assert_eq!(s.read_to_end(), Ok(vec!(1))); }) + + iotest!(fn accept_timeout() { + let addr = next_test_ip4(); + let mut a = TcpListener::bind(addr).unwrap().listen().unwrap(); + + a.set_timeout(Some(10)); + + // Make sure we time out once and future invocations also time out + let err = a.accept().err().unwrap(); + assert_eq!(err.kind, TimedOut); + let err = a.accept().err().unwrap(); + assert_eq!(err.kind, TimedOut); + + // Also make sure that even though the timeout is expired that we will + // continue to receive any pending connections. + let l = TcpStream::connect(addr).unwrap(); + for i in range(0, 1001) { + match a.accept() { + Ok(..) => break, + Err(ref e) if e.kind == TimedOut => {} + Err(e) => fail!("error: {}", e), + } + if i == 1000 { fail!("should have a pending connection") } + } + drop(l); + + // Unset the timeout and make sure that this always blocks. + a.set_timeout(None); + spawn(proc() { + drop(TcpStream::connect(addr)); + }); + a.accept().unwrap(); + }) } diff --git a/src/libstd/rt/rtio.rs b/src/libstd/rt/rtio.rs index 0f3fc9c21ced0..5dd148346695d 100644 --- a/src/libstd/rt/rtio.rs +++ b/src/libstd/rt/rtio.rs @@ -200,6 +200,7 @@ pub trait RtioTcpAcceptor : RtioSocket { fn accept(&mut self) -> IoResult<~RtioTcpStream:Send>; fn accept_simultaneously(&mut self) -> IoResult<()>; fn dont_accept_simultaneously(&mut self) -> IoResult<()>; + fn set_timeout(&mut self, timeout: Option); } pub trait RtioTcpStream : RtioSocket {