Skip to content

Add a SockAddr type #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
//! # Examples
//!
//! ```no_run
//! use std::net::SocketAddr;
//! use socket2::{Socket, Domain, Type};
//!
//! // create a TCP listener bound to two addresses
//! let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
//!
//! socket.bind(&"127.0.0.1:12345".parse().unwrap()).unwrap();
//! socket.bind(&"127.0.0.1:12346".parse().unwrap()).unwrap();
//! socket.bind(&"127.0.0.1:12345".parse::<SocketAddr>().unwrap().into()).unwrap();
//! socket.bind(&"127.0.0.1:12346".parse::<SocketAddr>().unwrap().into()).unwrap();
//! socket.listen(128).unwrap();
//!
//! let listener = socket.into_tcp_listener();
Expand All @@ -45,6 +46,10 @@

use utils::NetInt;

#[cfg(unix)] use libc::{sockaddr_storage, socklen_t};
#[cfg(windows)] use winapi::{SOCKADDR_STORAGE as sockaddr_storage, socklen_t};

mod sockaddr;
mod socket;
mod utils;

Expand All @@ -63,13 +68,14 @@ mod utils;
/// # Examples
///
/// ```no_run
/// use socket2::{Socket, Domain, Type};
/// use std::net::SocketAddr;
/// use socket2::{Socket, Domain, Type, SockAddr};
///
/// // create a TCP listener bound to two addresses
/// let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
///
/// socket.bind(&"127.0.0.1:12345".parse().unwrap()).unwrap();
/// socket.bind(&"127.0.0.1:12346".parse().unwrap()).unwrap();
/// socket.bind(&"127.0.0.1:12345".parse::<SocketAddr>().unwrap().into()).unwrap();
/// socket.bind(&"127.0.0.1:12346".parse::<SocketAddr>().unwrap().into()).unwrap();
/// socket.listen(128).unwrap();
///
/// let listener = socket.into_tcp_listener();
Expand All @@ -79,6 +85,15 @@ pub struct Socket {
inner: sys::Socket,
}

/// The address of a socket.
///
/// `SockAddr`s may be constructed directly to and from the standard library
/// `SocketAddr`, `SocketAddrV4`, and `SocketAddrV6` types.
pub struct SockAddr {
storage: sockaddr_storage,
len: socklen_t,
}

/// Specification of the communication domain for a socket.
///
/// This is a newtype wrapper around an integer which provides a nicer API in
Expand Down Expand Up @@ -111,5 +126,3 @@ pub struct Type(i32);
pub struct Protocol(i32);

fn hton<I: NetInt>(i: I) -> I { i.to_be() }

fn ntoh<I: NetInt>(i: I) -> I { I::from_be(i) }
139 changes: 139 additions & 0 deletions src/sockaddr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use std::net::{SocketAddrV4, SocketAddrV6, SocketAddr};
use std::mem;
use std::ptr;
use std::fmt;

#[cfg(unix)]
use libc::{sockaddr, sockaddr_storage, sockaddr_in, sockaddr_in6, sa_family_t, socklen_t, AF_INET,
AF_INET6};
#[cfg(windows)]
use winapi::{SOCKADDR as sockaddr, SOCKADDR_STORAGE as sockaddr_storage,
SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6,
ADDRESS_FAMILY as sa_family_t, socklen_t, AF_INET, AF_INET6};

use SockAddr;

impl fmt::Debug for SockAddr {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
let mut builder = fmt.debug_struct("SockAddr");
builder.field("family", &self.family());
if let Some(addr) = self.as_inet() {
builder.field("inet", &addr);
} else if let Some(addr) = self.as_inet6() {
builder.field("inet6", &addr);
}
builder.finish()
}
}

impl SockAddr {
/// Constructs a `SockAddr` from its raw components.
pub unsafe fn from_raw_parts(addr: *const sockaddr, len: socklen_t) -> SockAddr {
let mut storage = mem::uninitialized::<sockaddr_storage>();
ptr::copy_nonoverlapping(addr as *const _ as *const u8,
&mut storage as *mut _ as *mut u8,
len as usize);

SockAddr {
storage: storage,
len: len,
}
}

unsafe fn as_<T>(&self, family: sa_family_t) -> Option<T> {
if self.storage.ss_family != family {
return None;
}

Some(mem::transmute_copy(&self.storage))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One day this may bite us, but let's hope it's way far away!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was somewhat terrified when I saw that this was how the addr2raw function was implemented as well, but it works.

}

/// Returns this address as a `SocketAddrV4` if it is in the `AF_INET`
/// family.
pub fn as_inet(&self) -> Option<SocketAddrV4> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this could return &SocketAddrV4 instead? (a relatively large struct to copy as we go down)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That will lock us into guaranteeing that SocketAddrV4 has the same representation as sockaddr_in, but if that's okay then sure.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait good point, let's not do that.

unsafe { self.as_(AF_INET as sa_family_t) }
}

/// Returns this address as a `SocketAddrV4` if it is in the `AF_INET6`
/// family.
pub fn as_inet6(&self) -> Option<SocketAddrV6> {
unsafe { self.as_(AF_INET6 as sa_family_t) }
}

/// Returns this address's family.
pub fn family(&self) -> sa_family_t {
self.storage.ss_family
}

/// Returns the size of this address in bytes.
pub fn len(&self) -> socklen_t {
self.len
}

/// Returns a raw pointer to the address.
pub fn as_ptr(&self) -> *const sockaddr {
&self.storage as *const _ as *const _
}
}

// SocketAddrV4 and SocketAddrV6 are just wrappers around sockaddr_in and sockaddr_in6

// check to make sure that the sizes at least match up
fn _size_checks(v4: SocketAddrV4, v6: SocketAddrV6) {
unsafe {
mem::transmute::<SocketAddrV4, sockaddr_in>(v4);
mem::transmute::<SocketAddrV6, sockaddr_in6>(v6);
}
}

impl From<SocketAddrV4> for SockAddr {
fn from(addr: SocketAddrV4) -> SockAddr {
unsafe {
SockAddr::from_raw_parts(&addr as *const _ as *const _,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this throw in some assert_eq! about the size of SocketAddrV4? (same for v6 below)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done via transmute to check statically.

mem::size_of::<SocketAddrV4>() as socklen_t)
}
}
}

impl From<SocketAddrV6> for SockAddr {
fn from(addr: SocketAddrV6) -> SockAddr {
unsafe {
SockAddr::from_raw_parts(&addr as *const _ as *const _,
mem::size_of::<SocketAddrV6>() as socklen_t)
}
}
}

impl From<SocketAddr> for SockAddr {
fn from(addr: SocketAddr) -> SockAddr {
match addr {
SocketAddr::V4(addr) => addr.into(),
SocketAddr::V6(addr) => addr.into(),
}
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn inet() {
let raw = "127.0.0.1:80".parse::<SocketAddrV4>().unwrap();
let addr = SockAddr::from(raw);
assert!(addr.as_inet6().is_none());
let addr = addr.as_inet().unwrap();
assert_eq!(raw, addr);
}

#[test]
fn inet6() {
let raw = "[2001:db8::ff00:42:8329]:80"
.parse::<SocketAddrV6>()
.unwrap();
let addr = SockAddr::from(raw);
assert!(addr.as_inet().is_none());
let addr = addr.as_inet6().unwrap();
assert_eq!(raw, addr);
}
}
28 changes: 15 additions & 13 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

use std::fmt;
use std::io::{self, Read, Write};
use std::net::{self, SocketAddr, Ipv4Addr, Ipv6Addr, Shutdown};
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
use std::time::Duration;

#[cfg(unix)]
Expand All @@ -19,7 +19,7 @@ use libc as c;
use winapi as c;

use sys;
use {Socket, Protocol, Domain, Type};
use {Socket, Protocol, Domain, Type, SockAddr};

impl Socket {
/// Creates a new socket ready to be configured.
Expand Down Expand Up @@ -58,7 +58,7 @@ impl Socket {
///
/// An error will be returned if `listen` or `connect` has already been
/// called on this builder.
pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> {
pub fn connect(&self, addr: &SockAddr) -> io::Result<()> {
self.inner.connect(addr)
}

Expand All @@ -81,15 +81,15 @@ impl Socket {
///
/// If the connection request times out, it may still be processing in the
/// background - a second call to `connect` or `connect_timeout` may fail.
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
pub fn connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()> {
self.inner.connect_timeout(addr, timeout)
}

/// Binds this socket to the specified address.
///
/// This function directly corresponds to the bind(2) function on Windows
/// and Unix.
pub fn bind(&self, addr: &SocketAddr) -> io::Result<()> {
pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
self.inner.bind(addr)
}

Expand All @@ -110,19 +110,19 @@ impl Socket {
/// This function will block the calling thread until a new connection is
/// established. When established, the corresponding `Socket` and the
/// remote peer's address will be returned.
pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
pub fn accept(&self) -> io::Result<(Socket, SockAddr)> {
self.inner.accept().map(|(socket, addr)| {
(Socket { inner: socket }, addr)
})
}

/// Returns the socket address of the local half of this TCP connection.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
pub fn local_addr(&self) -> io::Result<SockAddr> {
self.inner.local_addr()
}

/// Returns the socket address of the remote peer of this TCP connection.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
pub fn peer_addr(&self) -> io::Result<SockAddr> {
self.inner.peer_addr()
}

Expand Down Expand Up @@ -184,7 +184,7 @@ impl Socket {

/// Receives data from the socket. On success, returns the number of bytes
/// read and the address from whence the data came.
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
self.inner.recv_from(buf)
}

Expand All @@ -195,7 +195,7 @@ impl Socket {
///
/// On success, returns the number of bytes peeked and the address from
/// whence the data came.
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
self.inner.peek_from(buf)
}

Expand All @@ -214,7 +214,7 @@ impl Socket {
///
/// This is typically used on UDP or datagram-oriented sockets. On success
/// returns the number of bytes that were sent.
pub fn send_to(&self, buf: &[u8], addr: &SocketAddr) -> io::Result<usize> {
pub fn send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result<usize> {
self.inner.send_to(buf, addr)
}

Expand Down Expand Up @@ -693,12 +693,14 @@ impl From<Protocol> for i32 {

#[cfg(test)]
mod test {
use std::net::SocketAddr;

use super::*;

#[test]
fn connect_timeout_unrouteable() {
// this IP is unroutable, so connections should always time out
let addr: SocketAddr = "10.255.255.1:80".parse().unwrap();
let addr = "10.255.255.1:80".parse::<SocketAddr>().unwrap().into();

let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
match socket.connect_timeout(&addr, Duration::from_millis(250)) {
Expand All @@ -711,7 +713,7 @@ mod test {
#[test]
fn connect_timeout_valid() {
let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
socket.bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
socket.bind(&"127.0.0.1:0".parse::<SocketAddr>().unwrap().into()).unwrap();
socket.listen(128).unwrap();

let addr = socket.local_addr().unwrap();
Expand Down
Loading