Source code

Revision control

Copy as Markdown

Other Tools

//! Extensions and types for the standard networking primitives.
//!
//! This module contains a number of extension traits for the types in
//! `std::net` for Windows-specific functionality.
use std::cmp;
use std::io;
use std::mem;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket};
use std::os::windows::prelude::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use winapi::ctypes::*;
use winapi::shared::guiddef::*;
use winapi::shared::in6addr::{in6_addr_u, IN6_ADDR};
use winapi::shared::inaddr::{in_addr_S_un, IN_ADDR};
use winapi::shared::minwindef::*;
use winapi::shared::minwindef::{FALSE, TRUE};
use winapi::shared::ntdef::*;
use winapi::shared::ws2def::SOL_SOCKET;
use winapi::shared::ws2def::*;
use winapi::shared::ws2ipdef::*;
use winapi::um::minwinbase::*;
use winapi::um::winsock2::*;
/// A type to represent a buffer in which a socket address will be stored.
///
/// This type is used with the `recv_from_overlapped` function on the
/// `UdpSocketExt` trait to provide space for the overlapped I/O operation to
/// fill in the address upon completion.
///
/// Create one with `SocketAddrBuf::new` and read the result back with
/// `SocketAddrBuf::to_socket_addr`.
#[derive(Clone, Copy)]
pub struct SocketAddrBuf {
/// Raw storage whose address is handed to `WSARecvFrom` as the "from"
/// address buffer.
buf: SOCKADDR_STORAGE,
/// Length value handed to `WSARecvFrom` alongside `buf`; `new` initialises
/// it to the size of `SOCKADDR_STORAGE`.
len: c_int,
}
/// A type to represent a buffer in which an accepted socket's address will be
/// stored.
///
/// This type is used with the `accept_overlapped` method on the
/// `TcpListenerExt` trait to provide space for the overlapped I/O operation to
/// fill in the socket addresses upon completion.
///
/// The struct is `#[repr(C)]` because `args` derives the region sizes passed
/// to `AcceptEx` from the position of the `remote` field inside the struct.
#[repr(C)]
pub struct AcceptAddrsBuf {
// For AcceptEx we've got the restriction that the addresses passed in that
// buffer need to be at least 16 bytes more than the maximum address length
// for the protocol in question, so add some extra here and there
//
// `local` + `_pad1` form the first (local address) region and
// `remote` + `_pad2` form the second (remote address) region.
local: SOCKADDR_STORAGE,
_pad1: [u8; 16],
remote: SOCKADDR_STORAGE,
_pad2: [u8; 16],
}
/// The parsed return value of `AcceptAddrsBuf`.
///
/// Produced by `AcceptAddrsBuf::parse`; use `local` and `remote` to convert
/// the raw addresses into `SocketAddr` values.
pub struct AcceptAddrs<'a> {
/// Raw local address pointer, filled in by `GetAcceptExSockaddrs`.
local: LPSOCKADDR,
/// Length in bytes of the address behind `local`.
local_len: c_int,
/// Raw remote address pointer, filled in by `GetAcceptExSockaddrs`.
remote: LPSOCKADDR,
/// Length in bytes of the address behind `remote`.
remote_len: c_int,
/// Borrow of the buffer that was parsed. It is never read; it ties this
/// value's lifetime to the buffer (presumably because the raw pointers
/// above point into it).
_data: &'a AcceptAddrsBuf,
}
/// A lazily resolved Winsock extension function.
///
/// The function pointer is looked up through `WSAIoctl` with
/// `SIO_GET_EXTENSION_FUNCTION_POINTER` (see `WsaExtension::get`) and cached
/// so later calls can skip the lookup.
struct WsaExtension {
/// GUID identifying the extension function to look up.
guid: GUID,
/// Cached function pointer stored as an integer; `0` means "not resolved
/// yet".
val: AtomicUsize,
}
/// Additional methods for the `TcpStream` type in the standard library.
pub trait TcpStreamExt {
/// Execute an overlapped read I/O operation on this TCP stream.
///
/// This function will issue an overlapped I/O read (via `WSARecv`) on this
/// socket. The provided buffer will be filled in when the operation
/// completes and the given `OVERLAPPED` instance is used to track the
/// overlapped operation.
///
/// If the operation succeeds, `Ok(Some(n))` is returned indicating how
/// many bytes were read. If the operation returns an error indicating that
/// the I/O is currently pending, `Ok(None)` is returned. Otherwise, the
/// error associated with the operation is returned and no overlapped
/// operation is enqueued.
///
/// The number of bytes read will be returned as part of the completion
/// notification when the I/O finishes.
///
/// # Safety
///
/// This function is unsafe because the kernel requires that the `buf` and
/// `overlapped` pointers are valid until the end of the I/O operation. The
/// kernel also requires that `overlapped` is unique for this I/O operation
/// and is not in use for any other I/O.
///
/// To safely use this function callers must ensure that these two input
/// pointers are valid until the I/O operation is completed, typically via
/// completion ports and waiting to receive the completion notification on
/// the port.
unsafe fn read_overlapped(
&self,
buf: &mut [u8],
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>>;
/// Execute an overlapped write I/O operation on this TCP stream.
///
/// This function will issue an overlapped I/O write (via `WSASend`) on this
/// socket. The provided buffer will be written when the operation completes
/// and the given `OVERLAPPED` instance is used to track the overlapped
/// operation.
///
/// If the operation succeeds, `Ok(Some(n))` is returned where `n` is the
/// number of bytes that were written. If the operation returns an error
/// indicating that the I/O is currently pending, `Ok(None)` is returned.
/// Otherwise, the error associated with the operation is returned and no
/// overlapped operation is enqueued.
///
/// The number of bytes written will be returned as part of the completion
/// notification when the I/O finishes.
///
/// # Safety
///
/// This function is unsafe because the kernel requires that the `buf` and
/// `overlapped` pointers are valid until the end of the I/O operation. The
/// kernel also requires that `overlapped` is unique for this I/O operation
/// and is not in use for any other I/O.
///
/// To safely use this function callers must ensure that these two input
/// pointers are valid until the I/O operation is completed, typically via
/// completion ports and waiting to receive the completion notification on
/// the port.
unsafe fn write_overlapped(
&self,
buf: &[u8],
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>>;
/// Execute an overlapped connect operation on this socket.
///
/// This function will issue a connect operation (via `ConnectEx`) to the
/// address specified by `addr`, flagging it as an overlapped operation
/// which will complete asynchronously.
///
/// The `buf` argument provided is an initial buffer of data that should be
/// sent after the connection is initiated. It's acceptable to
/// pass an empty slice here.
///
/// The return value indicates whether the connect immediately succeeded
/// or not. If `Ok(None)` is returned then the I/O operation is still
/// pending and will complete at a later date, and if `Ok(Some(bytes))` is
/// returned then that many bytes were transferred. Otherwise, the error
/// associated with the operation is returned.
///
/// Note that to succeed this requires that the underlying socket has
/// previously been bound via a call to `bind` to a local address.
///
/// # Safety
///
/// This function is unsafe because the kernel requires the `overlapped`
/// and `buf` pointers to be valid until the end of the I/O operation. The
/// kernel also requires that `overlapped` is unique for this I/O operation
/// and is not in use for any other I/O.
///
/// To safely use this function callers must ensure that these pointers
/// are valid until the I/O operation is completed, typically via completion
/// ports and waiting to receive the completion notification on the port.
unsafe fn connect_overlapped(
&self,
addr: &SocketAddr,
buf: &[u8],
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>>;
/// Once a `connect_overlapped` has finished, this function needs to be
/// called to finish the connect operation.
///
/// Currently this just calls `setsockopt` with `SO_UPDATE_CONNECT_CONTEXT`
/// to ensure that further functions like `getpeername` and `getsockname`
/// work correctly.
fn connect_complete(&self) -> io::Result<()>;
/// Calls the `WSAGetOverlappedResult` function to get the result of an
/// overlapped operation for this handle.
///
/// This function takes the `OVERLAPPED` argument which must have been used
/// to initiate an overlapped I/O operation, and returns either the
/// successful number of bytes transferred during the operation or an error
/// if one occurred, along with the results of the `lpFlags` parameter of
/// the relevant operation, if applicable.
///
/// # Safety
///
/// This function is unsafe as `overlapped` must have previously been used
/// to execute an operation for this handle, and it must also be a valid
/// pointer to an `OVERLAPPED` instance.
///
/// # Panics
///
/// The implementation in this module forwards to `WSAGetOverlappedResult`
/// and contains no explicit panic.
unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)>;
}
/// Additional methods for the `UdpSocket` type in the standard library.
pub trait UdpSocketExt {
/// Execute an overlapped receive I/O operation on this UDP socket.
///
/// This function will issue an overlapped I/O read (via `WSARecvFrom`) on
/// this socket. The provided buffer will be filled in when the operation
/// completes, the source from where the data came will be written to
/// `addr`, and the given `OVERLAPPED` instance is used to track the
/// overlapped operation.
///
/// If the operation succeeds, `Ok(Some(n))` is returned where `n` is the
/// number of bytes that were read. If the operation returns an error
/// indicating that the I/O is currently pending, `Ok(None)` is returned.
/// Otherwise, the error associated with the operation is returned and no
/// overlapped operation is enqueued.
///
/// The number of bytes read will be returned as part of the completion
/// notification when the I/O finishes.
///
/// # Safety
///
/// This function is unsafe because the kernel requires that the `buf`,
/// `addr`, and `overlapped` pointers are valid until the end of the I/O
/// operation. The kernel also requires that `overlapped` is unique for this
/// I/O operation and is not in use for any other I/O.
///
/// To safely use this function callers must ensure that these three input
/// pointers are valid until the I/O operation is completed, typically via
/// completion ports and waiting to receive the completion notification on
/// the port.
unsafe fn recv_from_overlapped(
&self,
buf: &mut [u8],
addr: *mut SocketAddrBuf,
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>>;
/// Execute an overlapped receive I/O operation on this UDP socket.
///
/// This function will issue an overlapped I/O read (via `WSARecv`) on
/// this socket. The provided buffer will be filled in when the operation
/// completes and the given `OVERLAPPED` instance is used to track the
/// overlapped operation. Unlike `recv_from_overlapped`, the source address
/// of the data is not reported.
///
/// If the operation succeeds, `Ok(Some(n))` is returned where `n` is the
/// number of bytes that were read. If the operation returns an error
/// indicating that the I/O is currently pending, `Ok(None)` is returned.
/// Otherwise, the error associated with the operation is returned and no
/// overlapped operation is enqueued.
///
/// The number of bytes read will be returned as part of the completion
/// notification when the I/O finishes.
///
/// # Safety
///
/// This function is unsafe because the kernel requires that the `buf`
/// and `overlapped` pointers are valid until the end of the I/O
/// operation. The kernel also requires that `overlapped` is unique for this
/// I/O operation and is not in use for any other I/O.
///
/// To safely use this function callers must ensure that these two input
/// pointers are valid until the I/O operation is completed, typically via
/// completion ports and waiting to receive the completion notification on
/// the port.
unsafe fn recv_overlapped(
&self,
buf: &mut [u8],
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>>;
/// Execute an overlapped send I/O operation on this UDP socket.
///
/// This function will issue an overlapped I/O write (via `WSASendTo`) on
/// this socket to the address specified by `addr`. The provided buffer will
/// be written when the operation completes and the given `OVERLAPPED`
/// instance is used to track the overlapped operation.
///
/// If the operation succeeds, `Ok(Some(n))` is returned where `n` is the
/// number of bytes that were written. If the operation returns an error
/// indicating that the I/O is currently pending, `Ok(None)` is returned.
/// Otherwise, the error associated with the operation is returned and no
/// overlapped operation is enqueued.
///
/// The number of bytes written will be returned as part of the completion
/// notification when the I/O finishes.
///
/// # Safety
///
/// This function is unsafe because the kernel requires that the `buf` and
/// `overlapped` pointers are valid until the end of the I/O operation. The
/// kernel also requires that `overlapped` is unique for this I/O operation
/// and is not in use for any other I/O.
///
/// To safely use this function callers must ensure that these two input
/// pointers are valid until the I/O operation is completed, typically via
/// completion ports and waiting to receive the completion notification on
/// the port.
unsafe fn send_to_overlapped(
&self,
buf: &[u8],
addr: &SocketAddr,
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>>;
/// Execute an overlapped send I/O operation on this UDP socket.
///
/// This function will issue an overlapped I/O write (via `WSASend`) on
/// this socket to the address it was previously connected to. The provided
/// buffer will be written when the operation completes and the given
/// `OVERLAPPED` instance is used to track the overlapped operation.
///
/// If the operation succeeds, `Ok(Some(n))` is returned where `n` is the
/// number of bytes that were written. If the operation returns an error
/// indicating that the I/O is currently pending, `Ok(None)` is returned.
/// Otherwise, the error associated with the operation is returned and no
/// overlapped operation is enqueued.
///
/// The number of bytes written will be returned as part of the completion
/// notification when the I/O finishes.
///
/// # Safety
///
/// This function is unsafe because the kernel requires that the `buf` and
/// `overlapped` pointers are valid until the end of the I/O operation. The
/// kernel also requires that `overlapped` is unique for this I/O operation
/// and is not in use for any other I/O.
///
/// To safely use this function callers must ensure that these two input
/// pointers are valid until the I/O operation is completed, typically via
/// completion ports and waiting to receive the completion notification on
/// the port.
unsafe fn send_overlapped(
&self,
buf: &[u8],
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>>;
/// Calls the `WSAGetOverlappedResult` function to get the result of an
/// overlapped operation for this handle.
///
/// This function takes the `OVERLAPPED` argument which must have been used
/// to initiate an overlapped I/O operation, and returns either the
/// successful number of bytes transferred during the operation or an error
/// if one occurred, along with the results of the `lpFlags` parameter of
/// the relevant operation, if applicable.
///
/// # Safety
///
/// This function is unsafe as `overlapped` must have previously been used
/// to execute an operation for this handle, and it must also be a valid
/// pointer to an `OVERLAPPED` instance.
///
/// # Panics
///
/// The implementation in this module forwards to `WSAGetOverlappedResult`
/// and contains no explicit panic.
unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)>;
}
/// Additional methods for the `TcpListener` type in the standard library.
pub trait TcpListenerExt {
/// Perform an accept operation on this listener, accepting a connection in
/// an overlapped fashion.
///
/// This function will issue an I/O request (via `AcceptEx`) to accept an
/// incoming connection with the specified overlapped instance. The
/// `socket` provided must be a configured but not bound or connected
/// socket; it is handed to `AcceptEx` as the socket on which the incoming
/// connection is accepted.
///
/// The `addrs` buffer provided will be filled in with the local and remote
/// addresses of the connection upon completion.
///
/// If the accept succeeds immediately, `Ok(true)` is returned. If
/// the accept indicates that the I/O is currently pending, `Ok(false)` is
/// returned. Otherwise, the error associated with the operation is
/// returned and no overlapped operation is enqueued.
///
/// # Safety
///
/// This function is unsafe because the kernel requires that the
/// `addrs` and `overlapped` pointers are valid until the end of the I/O
/// operation. The kernel also requires that `overlapped` is unique for this
/// I/O operation and is not in use for any other I/O.
///
/// To safely use this function callers must ensure that the pointers are
/// valid until the I/O operation is completed, typically via completion
/// ports and waiting to receive the completion notification on the port.
unsafe fn accept_overlapped(
&self,
socket: &TcpStream,
addrs: &mut AcceptAddrsBuf,
overlapped: *mut OVERLAPPED,
) -> io::Result<bool>;
/// Once an `accept_overlapped` has finished, this function needs to be
/// called to finish the accept operation.
///
/// Currently this just calls `setsockopt` with `SO_UPDATE_ACCEPT_CONTEXT`
/// to ensure that further functions like `getpeername` and `getsockname`
/// work correctly.
fn accept_complete(&self, socket: &TcpStream) -> io::Result<()>;
/// Calls the `WSAGetOverlappedResult` function to get the result of an
/// overlapped operation for this handle.
///
/// This function takes the `OVERLAPPED` argument which must have been used
/// to initiate an overlapped I/O operation, and returns either the
/// successful number of bytes transferred during the operation or an error
/// if one occurred, along with the results of the `lpFlags` parameter of
/// the relevant operation, if applicable.
///
/// # Safety
///
/// This function is unsafe as `overlapped` must have previously been used
/// to execute an operation for this handle, and it must also be a valid
/// pointer to an `OVERLAPPED` instance.
///
/// # Panics
///
/// The implementation in this module forwards to `WSAGetOverlappedResult`
/// and contains no explicit panic.
unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)>;
}
/// Primitive integers that can be converted between big-endian (network)
/// and host byte order.
#[doc(hidden)]
trait NetInt {
    /// Interprets `i` as big-endian and returns it in host byte order.
    fn from_be(i: Self) -> Self;
    /// Returns `self` converted from host byte order to big-endian.
    fn to_be(&self) -> Self;
}

/// Implements `NetInt` for every listed integer type by delegating to the
/// type's inherent `from_be` / `to_be` methods.
macro_rules! doit {
    ($($int:ident)*) => {
        $(
            impl NetInt for $int {
                fn from_be(i: Self) -> Self {
                    <$int>::from_be(i)
                }

                fn to_be(&self) -> Self {
                    <$int>::to_be(*self)
                }
            }
        )*
    };
}

doit! { i8 i16 i32 i64 isize u8 u16 u32 u64 usize }

// fn hton<I: NetInt>(i: I) -> I { i.to_be() }

/// Converts an integer from network (big-endian) to host byte order.
fn ntoh<I: NetInt>(i: I) -> I {
    <I as NetInt>::from_be(i)
}
/// Translates the calling thread's last Winsock error into this module's
/// result convention.
///
/// `WSA_IO_PENDING` means the overlapped operation was queued and becomes
/// `Ok(None)`; every other code is returned as an `io::Error`.
fn last_err() -> io::Result<Option<usize>> {
    match unsafe { WSAGetLastError() } {
        code if code == WSA_IO_PENDING as i32 => Ok(None),
        code => Err(io::Error::from_raw_os_error(code)),
    }
}
/// Converts the return code of a Winsock call into this module's result
/// convention.
///
/// When `i` is not `SOCKET_ERROR` the call completed immediately and `size`
/// is reported as `Ok(Some(size))`. Otherwise the outcome is decided by
/// `last_err`.
fn cvt(i: c_int, size: DWORD) -> io::Result<Option<usize>> {
    if i != SOCKET_ERROR {
        return Ok(Some(size as usize));
    }
    last_err()
}
/// A type with the same memory layout as `SOCKADDR`. Used in converting Rust level
/// SocketAddr* types into their system representation. The benefit of this specific
/// type over using `SOCKADDR_STORAGE` is that this type is exactly as large as it
/// needs to be and not a lot larger. It can also be initialized more cleanly from Rust.
///
/// Only one of the two fields is ever initialised; the accompanying length
/// returned by `socket_addr_to_ptrs` tells which one.
#[repr(C)]
pub(crate) union SocketAddrCRepr {
/// IPv4 representation.
v4: SOCKADDR_IN,
/// IPv6 representation.
v6: SOCKADDR_IN6_LH,
}
impl SocketAddrCRepr {
    /// Returns a `SOCKADDR` pointer to this value, suitable for passing to
    /// Winsock functions together with the matching length.
    ///
    /// The pointer is only valid for as long as `self` is alive and not moved.
    pub(crate) fn as_ptr(&self) -> *const SOCKADDR {
        let this: *const Self = self;
        this.cast::<SOCKADDR>()
    }
}
/// Converts a standard `SocketAddr` into its Winsock representation.
///
/// Returns the C-layout address together with its length in bytes
/// (`size_of::<SOCKADDR_IN>()` for IPv4, `size_of::<SOCKADDR_IN6_LH>()` for
/// IPv6). The port is stored in network byte order; the IPv6 flow info and
/// scope id are stored exactly as reported by `SocketAddrV6`.
fn socket_addr_to_ptrs(addr: &SocketAddr) -> (SocketAddrCRepr, c_int) {
    match addr {
        SocketAddr::V4(v4) => {
            // The octets are already in network order, so copy them into the
            // 32-bit field without reinterpreting their endianness.
            let sin_addr = unsafe {
                let mut raw = mem::zeroed::<in_addr_S_un>();
                *raw.S_addr_mut() = u32::from_ne_bytes(v4.ip().octets());
                IN_ADDR { S_un: raw }
            };
            let repr = SocketAddrCRepr {
                v4: SOCKADDR_IN {
                    sin_family: AF_INET as ADDRESS_FAMILY,
                    sin_port: v4.port().to_be(),
                    sin_addr,
                    sin_zero: [0; 8],
                },
            };
            (repr, mem::size_of::<SOCKADDR_IN>() as c_int)
        }
        SocketAddr::V6(v6) => {
            let sin6_addr = unsafe {
                let mut raw = mem::zeroed::<in6_addr_u>();
                *raw.Byte_mut() = v6.ip().octets();
                IN6_ADDR { u: raw }
            };
            let scope = unsafe {
                let mut raw = mem::zeroed::<SOCKADDR_IN6_LH_u>();
                *raw.sin6_scope_id_mut() = v6.scope_id();
                raw
            };
            let repr = SocketAddrCRepr {
                v6: SOCKADDR_IN6_LH {
                    sin6_family: AF_INET6 as ADDRESS_FAMILY,
                    sin6_port: v6.port().to_be(),
                    sin6_addr,
                    sin6_flowinfo: v6.flowinfo(),
                    u: scope,
                },
            };
            (repr, mem::size_of::<SOCKADDR_IN6_LH>() as c_int)
        }
    }
}
/// Converts a raw Winsock address of `len` bytes into a standard
/// `SocketAddr`.
///
/// Returns `None` when `len` is negative, when it is too small for the
/// address family stored in the buffer, or when that family is neither
/// `AF_INET` nor `AF_INET6`.
///
/// The port is converted from network byte order. The IPv6 flow info and
/// scope id are copied verbatim: `socket_addr_to_ptrs` stores them without
/// any byte swapping, so swapping them here would corrupt non-zero values on
/// little-endian hosts when an address is round-tripped.
///
/// # Safety
///
/// Unless `len` is smaller than `size_of::<c_int>()`, `ptr` must be valid for
/// reads of `len` bytes.
unsafe fn ptrs_to_socket_addr(ptr: *const SOCKADDR, len: c_int) -> Option<SocketAddr> {
    // A negative length can never describe a valid address; without this
    // check the `as usize` cast below would turn it into a huge value that
    // passes every size test.
    if len < 0 {
        return None;
    }
    let len = len as usize;
    // Make sure at least the family field can be read.
    if len < mem::size_of::<c_int>() {
        return None;
    }
    match (*ptr).sa_family as i32 {
        AF_INET if len >= mem::size_of::<SOCKADDR_IN>() => {
            let raw = &*(ptr as *const SOCKADDR_IN);
            // The in-memory bytes of `S_addr` are the address octets in order.
            let ip = Ipv4Addr::from(raw.sin_addr.S_un.S_addr().to_ne_bytes());
            Some(SocketAddr::V4(SocketAddrV4::new(ip, ntoh(raw.sin_port))))
        }
        AF_INET6 if len >= mem::size_of::<SOCKADDR_IN6_LH>() => {
            let raw = &*(ptr as *const SOCKADDR_IN6_LH);
            let ip = Ipv6Addr::from(*raw.sin6_addr.u.Byte());
            let addr = SocketAddrV6::new(
                ip,
                ntoh(raw.sin6_port),
                raw.sin6_flowinfo,
                *raw.u.sin6_scope_id(),
            );
            Some(SocketAddr::V6(addr))
        }
        _ => None,
    }
}
/// Builds a `WSABUF` descriptor pointing at `slice`.
///
/// The length is clamped to what fits in a `u_long`, so an over-long slice is
/// described only up to that limit. The returned descriptor holds a raw
/// pointer into `slice` and carries no lifetime.
unsafe fn slice2buf(slice: &[u8]) -> WSABUF {
    let clamped = cmp::min(slice.len(), <u_long>::max_value() as usize);
    WSABUF {
        len: clamped as u_long,
        buf: slice.as_ptr() as *mut _,
    }
}
/// Fetches the outcome of an overlapped operation on `socket` via
/// `WSAGetOverlappedResult`.
///
/// The call is made with the wait flag set to `FALSE`. On success returns the
/// number of bytes transferred together with the flags reported by the
/// operation; on failure returns the last OS error.
///
/// # Safety
///
/// `overlapped` must be a valid pointer to an `OVERLAPPED` that was used to
/// start an operation on `socket`.
unsafe fn result(socket: SOCKET, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)> {
    let (mut transferred, mut flags) = (0, 0);
    let ok = WSAGetOverlappedResult(socket, overlapped, &mut transferred, FALSE, &mut flags);
    if ok != 0 {
        return Ok((transferred as usize, flags));
    }
    Err(io::Error::last_os_error())
}
impl TcpStreamExt for TcpStream {
    /// Starts an overlapped `WSARecv` into `buf`.
    unsafe fn read_overlapped(
        &self,
        buf: &mut [u8],
        overlapped: *mut OVERLAPPED,
    ) -> io::Result<Option<usize>> {
        let sock = self.as_raw_socket() as SOCKET;
        let mut wsabuf = slice2buf(buf);
        let mut bytes_read: DWORD = 0;
        let mut recv_flags = 0;
        let rc = WSARecv(
            sock,
            &mut wsabuf,
            1,
            &mut bytes_read,
            &mut recv_flags,
            overlapped,
            None,
        );
        cvt(rc, bytes_read)
    }

    /// Starts an overlapped `WSASend` of `buf`.
    unsafe fn write_overlapped(
        &self,
        buf: &[u8],
        overlapped: *mut OVERLAPPED,
    ) -> io::Result<Option<usize>> {
        let sock = self.as_raw_socket() as SOCKET;
        let mut wsabuf = slice2buf(buf);
        let mut bytes_written = 0;
        // The byte count is captured here even though the MSDN documentation
        // states:
        //
        // > Use NULL for this parameter if the lpOverlapped parameter is not
        // > NULL to avoid potentially erroneous results. This parameter can be
        // > NULL only if the lpOverlapped parameter is not NULL.
        //
        // Learning the count now is faster than calling
        // `WSAGetOverlappedResult` later, and in practice almost all
        // implementations rely on it anyway, so it is reported back to the
        // caller.
        let rc = WSASend(
            sock,
            &mut wsabuf,
            1,
            &mut bytes_written,
            0,
            overlapped,
            None,
        );
        cvt(rc, bytes_written)
    }

    /// Delegates to the free `connect_overlapped` function of this module.
    unsafe fn connect_overlapped(
        &self,
        addr: &SocketAddr,
        buf: &[u8],
        overlapped: *mut OVERLAPPED,
    ) -> io::Result<Option<usize>> {
        let sock = self.as_raw_socket() as SOCKET;
        connect_overlapped(sock, addr, buf, overlapped)
    }

    /// Sets `SO_UPDATE_CONNECT_CONTEXT` on the socket; the option carries no
    /// value, so a null pointer and zero length are passed.
    fn connect_complete(&self) -> io::Result<()> {
        const SO_UPDATE_CONNECT_CONTEXT: c_int = 0x7010;
        let sock = self.as_raw_socket() as SOCKET;
        let rc = unsafe {
            setsockopt(
                sock,
                SOL_SOCKET,
                SO_UPDATE_CONNECT_CONTEXT,
                std::ptr::null(),
                0,
            )
        };
        match rc {
            0 => Ok(()),
            _ => Err(io::Error::last_os_error()),
        }
    }

    /// Delegates to the free `result` function of this module.
    unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)> {
        let sock = self.as_raw_socket() as SOCKET;
        result(sock, overlapped)
    }
}
unsafe fn connect_overlapped(
socket: SOCKET,
addr: &SocketAddr,
buf: &[u8],
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>> {
static CONNECTEX: WsaExtension = WsaExtension {
guid: GUID {
Data1: 0x25a207b9,
Data2: 0xddf3,
Data3: 0x4660,
Data4: [0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e],
},
val: AtomicUsize::new(0),
};
type ConnectEx = unsafe extern "system" fn(
SOCKET,
*const SOCKADDR,
c_int,
PVOID,
DWORD,
LPDWORD,
LPOVERLAPPED,
) -> BOOL;
let ptr = CONNECTEX.get(socket)?;
assert!(ptr != 0);
let connect_ex = mem::transmute::<_, ConnectEx>(ptr);
let (addr_buf, addr_len) = socket_addr_to_ptrs(addr);
let mut bytes_sent: DWORD = 0;
let r = connect_ex(
socket,
addr_buf.as_ptr(),
addr_len,
buf.as_ptr() as *mut _,
buf.len() as u32,
&mut bytes_sent,
overlapped,
);
if r == TRUE {
Ok(Some(bytes_sent as usize))
} else {
last_err()
}
}
impl UdpSocketExt for UdpSocket {
    /// Starts an overlapped `WSARecvFrom` into `buf`, with the sender's
    /// address going to `addr`.
    unsafe fn recv_from_overlapped(
        &self,
        buf: &mut [u8],
        addr: *mut SocketAddrBuf,
        overlapped: *mut OVERLAPPED,
    ) -> io::Result<Option<usize>> {
        let sock = self.as_raw_socket() as SOCKET;
        let mut wsabuf = slice2buf(buf);
        let mut received_bytes: DWORD = 0;
        let mut recv_flags = 0;
        let rc = WSARecvFrom(
            sock,
            &mut wsabuf,
            1,
            &mut received_bytes,
            &mut recv_flags,
            &mut (*addr).buf as *mut _ as *mut _,
            &mut (*addr).len,
            overlapped,
            None,
        );
        cvt(rc, received_bytes)
    }

    /// Starts an overlapped `WSARecv` into `buf`.
    unsafe fn recv_overlapped(
        &self,
        buf: &mut [u8],
        overlapped: *mut OVERLAPPED,
    ) -> io::Result<Option<usize>> {
        let sock = self.as_raw_socket() as SOCKET;
        let mut wsabuf = slice2buf(buf);
        let mut received_bytes: DWORD = 0;
        let mut recv_flags = 0;
        let rc = WSARecv(
            sock,
            &mut wsabuf,
            1,
            &mut received_bytes,
            &mut recv_flags,
            overlapped,
            None,
        );
        cvt(rc, received_bytes)
    }

    /// Starts an overlapped `WSASendTo` of `buf` to `addr`.
    unsafe fn send_to_overlapped(
        &self,
        buf: &[u8],
        addr: &SocketAddr,
        overlapped: *mut OVERLAPPED,
    ) -> io::Result<Option<usize>> {
        let sock = self.as_raw_socket() as SOCKET;
        let (dest, dest_len) = socket_addr_to_ptrs(addr);
        let mut wsabuf = slice2buf(buf);
        let mut sent_bytes = 0;
        let rc = WSASendTo(
            sock,
            &mut wsabuf,
            1,
            &mut sent_bytes,
            0,
            dest.as_ptr() as *const _,
            dest_len,
            overlapped,
            None,
        );
        cvt(rc, sent_bytes)
    }

    /// Starts an overlapped `WSASend` of `buf` on this socket.
    unsafe fn send_overlapped(
        &self,
        buf: &[u8],
        overlapped: *mut OVERLAPPED,
    ) -> io::Result<Option<usize>> {
        let sock = self.as_raw_socket() as SOCKET;
        let mut wsabuf = slice2buf(buf);
        let mut sent_bytes = 0;
        let rc = WSASend(
            sock,
            &mut wsabuf,
            1,
            &mut sent_bytes,
            0,
            overlapped,
            None,
        );
        cvt(rc, sent_bytes)
    }

    /// Delegates to the free `result` function of this module.
    unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)> {
        let sock = self.as_raw_socket() as SOCKET;
        result(sock, overlapped)
    }
}
impl TcpListenerExt for TcpListener {
    /// Starts an overlapped accept through the `AcceptEx` Winsock extension.
    ///
    /// Panics if the extension lookup succeeds but yields a null function
    /// pointer.
    unsafe fn accept_overlapped(
        &self,
        socket: &TcpStream,
        addrs: &mut AcceptAddrsBuf,
        overlapped: *mut OVERLAPPED,
    ) -> io::Result<bool> {
        // GUID identifying the `AcceptEx` extension function.
        static ACCEPTEX: WsaExtension = WsaExtension {
            guid: GUID {
                Data1: 0xb5367df1,
                Data2: 0xcbac,
                Data3: 0x11cf,
                Data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92],
            },
            val: AtomicUsize::new(0),
        };
        type AcceptEx = unsafe extern "system" fn(
            SOCKET,
            SOCKET,
            PVOID,
            DWORD,
            DWORD,
            DWORD,
            LPDWORD,
            LPOVERLAPPED,
        ) -> BOOL;

        let ptr = ACCEPTEX.get(self.as_raw_socket() as SOCKET)?;
        assert!(ptr != 0);
        let accept_ex = mem::transmute::<_, AcceptEx>(ptr);

        let mut bytes = 0;
        let (out_buf, recv_len, local_len, remote_len) = addrs.args();
        let r = accept_ex(
            self.as_raw_socket() as SOCKET,
            socket.as_raw_socket() as SOCKET,
            out_buf,
            recv_len,
            local_len,
            remote_len,
            &mut bytes,
            overlapped,
        );
        if r == TRUE {
            return Ok(true);
        }
        // A pending operation is reported as `Ok(false)`; any other error is
        // returned to the caller.
        last_err().map(|_| false)
    }

    /// Sets `SO_UPDATE_ACCEPT_CONTEXT` on `socket`, passing this listener's
    /// raw socket as the option value.
    fn accept_complete(&self, socket: &TcpStream) -> io::Result<()> {
        const SO_UPDATE_ACCEPT_CONTEXT: c_int = 0x700B;
        let me = self.as_raw_socket();
        let rc = unsafe {
            setsockopt(
                socket.as_raw_socket() as SOCKET,
                SOL_SOCKET,
                SO_UPDATE_ACCEPT_CONTEXT,
                &me as *const _ as *const _,
                mem::size_of_val(&me) as c_int,
            )
        };
        match rc {
            0 => Ok(()),
            _ => Err(io::Error::last_os_error()),
        }
    }

    /// Delegates to the free `result` function of this module.
    unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)> {
        let sock = self.as_raw_socket() as SOCKET;
        result(sock, overlapped)
    }
}
impl SocketAddrBuf {
    /// Creates a new blank socket address buffer.
    ///
    /// The storage is zeroed and the length is set to the full size of
    /// `SOCKADDR_STORAGE`. Create one of these before calling
    /// `recv_from_overlapped` and pass it down.
    pub fn new() -> SocketAddrBuf {
        let capacity = mem::size_of::<SOCKADDR_STORAGE>() as c_int;
        let storage = unsafe { mem::zeroed() };
        SocketAddrBuf {
            buf: storage,
            len: capacity,
        }
    }

    /// Parses this buffer to return a standard socket address.
    ///
    /// Call this once a `recv_from_overlapped` operation that was given this
    /// buffer has completed. It interprets the address that was filled in and
    /// returns it as the standard socket address type.
    ///
    /// If an error is encountered then `None` is returned.
    pub fn to_socket_addr(&self) -> Option<SocketAddr> {
        let len = self.len;
        unsafe { ptrs_to_socket_addr(&self.buf as *const _ as *const _, len) }
    }
}
// Lazily resolved pointer to the Winsock extension function used by
// `AcceptAddrsBuf::parse`. The GUID below is, going by the static's name, the
// identifier of `GetAcceptExSockaddrs` (`WSAID_GETACCEPTEXSOCKADDRS` in the
// Windows SDK — confirm against the SDK headers).
static GETACCEPTEXSOCKADDRS: WsaExtension = WsaExtension {
guid: GUID {
Data1: 0xb5367df2,
Data2: 0xcbac,
Data3: 0x11cf,
Data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92],
},
// 0 = not resolved yet; filled in by `WsaExtension::get`.
val: AtomicUsize::new(0),
};
// Signature that the resolved `GETACCEPTEXSOCKADDRS` pointer is transmuted to.
// As called from `AcceptAddrsBuf::parse`, the arguments are, in order: the
// output buffer, the receive-data length, the local-address region length,
// the remote-address region length, then out-parameters for the local address
// pointer and length and the remote address pointer and length.
type GetAcceptExSockaddrs = unsafe extern "system" fn(
PVOID,
DWORD,
DWORD,
DWORD,
*mut LPSOCKADDR,
LPINT,
*mut LPSOCKADDR,
LPINT,
);
impl AcceptAddrsBuf {
    /// Creates a new blank (zeroed) buffer ready to be passed to a call to
    /// `accept_overlapped`.
    pub fn new() -> AcceptAddrsBuf {
        unsafe { mem::zeroed() }
    }

    /// Parses the data contained in this address buffer, returning the parsed
    /// result if successful.
    ///
    /// This function can be called after a call to `accept_overlapped` has
    /// succeeded to parse out the data that was written in. `socket` is only
    /// used to look up the `GetAcceptExSockaddrs` extension function.
    ///
    /// # Errors
    ///
    /// Returns an error if the extension function cannot be resolved.
    ///
    /// # Panics
    ///
    /// Panics if the lookup succeeds but yields a null function pointer.
    pub fn parse(&self, socket: &TcpListener) -> io::Result<AcceptAddrs> {
        let mut ret = AcceptAddrs {
            local: std::ptr::null_mut(),
            local_len: 0,
            remote: std::ptr::null_mut(),
            remote_len: 0,
            _data: self,
        };
        let ptr = GETACCEPTEXSOCKADDRS.get(socket.as_raw_socket() as SOCKET)?;
        assert!(ptr != 0);
        unsafe {
            let get_sockaddrs = mem::transmute::<_, GetAcceptExSockaddrs>(ptr);
            let (a, b, c, d) = self.args();
            get_sockaddrs(
                a,
                b,
                c,
                d,
                &mut ret.local,
                &mut ret.local_len,
                &mut ret.remote,
                &mut ret.remote_len,
            );
        }
        Ok(ret)
    }

    /// Returns the buffer arguments shared by `AcceptEx` and
    /// `GetAcceptExSockaddrs`: the buffer pointer, the receive-data length
    /// (always 0), the size of the local-address region and the size of the
    /// remote-address region.
    fn args(&self) -> (PVOID, DWORD, DWORD, DWORD) {
        // Compute the offset of `remote` from a live reference. Projecting the
        // field through a null pointer, as was done previously, is undefined
        // behaviour even though no memory is read.
        let base = self as *const Self as usize;
        let remote_offset = &self.remote as *const _ as usize - base;
        (
            self as *const _ as *mut _,
            0,
            remote_offset as DWORD,
            (mem::size_of_val(self) - remote_offset) as DWORD,
        )
    }
}
impl<'a> AcceptAddrs<'a> {
    /// Returns the local socket address contained in this buffer, or `None`
    /// if it cannot be interpreted.
    pub fn local(&self) -> Option<SocketAddr> {
        let (ptr, len) = (self.local, self.local_len);
        unsafe { ptrs_to_socket_addr(ptr, len) }
    }

    /// Returns the remote socket address contained in this buffer, or `None`
    /// if it cannot be interpreted.
    pub fn remote(&self) -> Option<SocketAddr> {
        let (ptr, len) = (self.remote, self.remote_len);
        unsafe { ptrs_to_socket_addr(ptr, len) }
    }
}
impl WsaExtension {
    /// Returns the extension function pointer as an integer, resolving it
    /// through `WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER)` on `socket` when
    /// needed.
    ///
    /// In builds without debug assertions a previously cached non-zero value
    /// is returned directly. With debug assertions enabled the lookup is
    /// always repeated and checked against the cached value.
    fn get(&self, socket: SOCKET) -> io::Result<usize> {
        let cached = self.val.load(Ordering::SeqCst);
        if !(cached == 0 || cfg!(debug_assertions)) {
            return Ok(cached);
        }

        let mut func = 0 as usize;
        let mut bytes = 0;
        let rc = unsafe {
            WSAIoctl(
                socket,
                SIO_GET_EXTENSION_FUNCTION_POINTER,
                &self.guid as *const _ as *mut _,
                mem::size_of_val(&self.guid) as DWORD,
                &mut func as *mut _ as *mut _,
                mem::size_of_val(&func) as DWORD,
                &mut bytes,
                0 as *mut _,
                None,
            )
        };
        cvt(rc, 0)?;

        debug_assert_eq!(bytes as usize, mem::size_of_val(&func));
        debug_assert!(cached == 0 || cached == func);
        self.val.store(func, Ordering::SeqCst);
        Ok(func)
    }
}
#[cfg(test)]
mod tests {
use std::io::prelude::*;
use std::net::{
IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV6, TcpListener, TcpStream, UdpSocket,
};
use std::slice;
use std::thread;
use socket2::{Domain, Socket, Type};
use crate::iocp::CompletionPort;
use crate::net::{AcceptAddrsBuf, TcpListenerExt};
use crate::net::{SocketAddrBuf, TcpStreamExt, UdpSocketExt};
use crate::Overlapped;
fn each_ip(f: &mut dyn FnMut(SocketAddr)) {
f(t!("127.0.0.1:0".parse()));
f(t!("[::1]:0".parse()));
}
#[test]
fn tcp_read() {
each_ip(&mut |addr| {
let l = t!(TcpListener::bind(addr));
let addr = t!(l.local_addr());
let t = thread::spawn(move || {
let mut a = t!(l.accept()).0;
t!(a.write_all(&[1, 2, 3]));
});
let cp = t!(CompletionPort::new(1));
let s = t!(TcpStream::connect(addr));
t!(cp.add_socket(1, &s));
let mut b = [0; 10];
let a = Overlapped::zero();
unsafe {
t!(s.read_overlapped(&mut b, a.raw()));
}
let status = t!(cp.get(None));
assert_eq!(status.bytes_transferred(), 3);
assert_eq!(status.token(), 1);
assert_eq!(status.overlapped(), a.raw());
assert_eq!(&b[0..3], &[1, 2, 3]);
t!(t.join());
})
}
/// Issues an overlapped write on a connected `TcpStream` and checks both the
/// completion status and the bytes observed by the accepting side.
#[test]
fn tcp_write() {
    each_ip(&mut |bind_addr| {
        let listener = t!(TcpListener::bind(bind_addr));
        let listen_addr = t!(listener.local_addr());
        // Peer side: accept a single connection and verify what arrives.
        let reader = thread::spawn(move || {
            let (mut conn, _) = t!(listener.accept());
            let mut received = [0; 10];
            let count = t!(conn.read(&mut received));
            assert_eq!(count, 3);
            assert_eq!(&received[..3], &[1, 2, 3]);
        });

        let port = t!(CompletionPort::new(1));
        let stream = t!(TcpStream::connect(listen_addr));
        t!(port.add_socket(1, &stream));

        let payload = [1, 2, 3];
        let overlapped = Overlapped::zero();
        // `payload` and `overlapped` are locals that stay alive until after
        // the `get` call below.
        unsafe {
            t!(stream.write_overlapped(&payload, overlapped.raw()));
        }

        let status = t!(port.get(None));
        assert_eq!(status.bytes_transferred(), 3);
        assert_eq!(status.token(), 1);
        assert_eq!(status.overlapped(), overlapped.raw());
        t!(reader.join());
    })
}
/// Binds a fresh socket, starts an overlapped connect to a local listener and
/// checks the completion status before finishing with `connect_complete`.
#[test]
fn tcp_connect() {
    each_ip(&mut |bind_addr| {
        let listener = t!(TcpListener::bind(bind_addr));
        let target = t!(listener.local_addr());
        // Peer side: accept the single incoming connection.
        let acceptor = thread::spawn(move || {
            t!(listener.accept());
        });

        let port = t!(CompletionPort::new(1));
        // The socket is bound to the same template address before the
        // overlapped connect is started.
        let sock = t!(Socket::new(
            Domain::for_address(target),
            Type::STREAM,
            None
        ));
        t!(sock.bind(&bind_addr.into()));
        let stream = TcpStream::from(sock);
        t!(port.add_socket(1, &stream));

        let overlapped = Overlapped::zero();
        // No initial data is sent with the connect (empty buffer).
        unsafe {
            t!(stream.connect_overlapped(&target, &[], overlapped.raw()));
        }

        let status = t!(port.get(None));
        assert_eq!(status.bytes_transferred(), 0);
        assert_eq!(status.token(), 1);
        assert_eq!(status.overlapped(), overlapped.raw());
        t!(stream.connect_complete());
        t!(acceptor.join());
    })
}
/// Issues an overlapped `recv_from` on a UDP socket and checks the payload,
/// the completion status and the sender address stored in the buffer.
#[test]
fn udp_recv_from() {
    each_ip(&mut |bind_addr| {
        let sender = t!(UdpSocket::bind(bind_addr));
        let receiver = t!(UdpSocket::bind(bind_addr));
        let sender_addr = t!(sender.local_addr());
        let receiver_addr = t!(receiver.local_addr());
        // Peer side: send one three-byte datagram to the receiving socket.
        let send_thread = thread::spawn(move || {
            t!(sender.send_to(&[1, 2, 3], receiver_addr));
        });

        let port = t!(CompletionPort::new(1));
        t!(port.add_socket(1, &receiver));

        let mut data = [0; 10];
        let overlapped = Overlapped::zero();
        let mut from = SocketAddrBuf::new();
        // `data`, `from` and `overlapped` are locals that stay alive until
        // after the `get` call below.
        unsafe {
            t!(receiver.recv_from_overlapped(&mut data, &mut from, overlapped.raw()));
        }

        let status = t!(port.get(None));
        assert_eq!(status.bytes_transferred(), 3);
        assert_eq!(status.token(), 1);
        assert_eq!(status.overlapped(), overlapped.raw());
        assert_eq!(&data[..3], &[1, 2, 3]);
        assert_eq!(from.to_socket_addr(), Some(sender_addr));
        t!(send_thread.join());
    })
}
/// Connects two UDP sockets to each other, issues an overlapped `recv` on one
/// and checks the payload and completion status.
#[test]
fn udp_recv() {
    each_ip(&mut |bind_addr| {
        let sender = t!(UdpSocket::bind(bind_addr));
        let receiver = t!(UdpSocket::bind(bind_addr));
        let sender_addr = t!(sender.local_addr());
        let receiver_addr = t!(receiver.local_addr());
        // Connect each socket to the other one's address.
        assert!(receiver.connect(sender_addr).is_ok());
        assert!(sender.connect(receiver_addr).is_ok());
        // Peer side: send one three-byte datagram to the receiving socket.
        let send_thread = thread::spawn(move || {
            t!(sender.send_to(&[1, 2, 3], receiver_addr));
        });

        let port = t!(CompletionPort::new(1));
        t!(port.add_socket(1, &receiver));

        let mut data = [0; 10];
        let overlapped = Overlapped::zero();
        // `data` and `overlapped` are locals that stay alive until after the
        // `get` call below.
        unsafe {
            t!(receiver.recv_overlapped(&mut data, overlapped.raw()));
        }

        let status = t!(port.get(None));
        assert_eq!(status.bytes_transferred(), 3);
        assert_eq!(status.token(), 1);
        assert_eq!(status.overlapped(), overlapped.raw());
        assert_eq!(&data[..3], &[1, 2, 3]);
        t!(send_thread.join());
    })
}
/// Issues an overlapped `send_to` on a UDP socket and checks the completion
/// status as well as the datagram and source address seen by the receiver.
#[test]
fn udp_send_to() {
    each_ip(&mut |bind_addr| {
        let receiver = t!(UdpSocket::bind(bind_addr));
        let sender = t!(UdpSocket::bind(bind_addr));
        let receiver_addr = t!(receiver.local_addr());
        let sender_addr = t!(sender.local_addr());
        // Peer side: receive one datagram and verify its size, origin and
        // contents.
        let recv_thread = thread::spawn(move || {
            let mut received = [0; 100];
            let (count, from) = t!(receiver.recv_from(&mut received));
            assert_eq!(count, 3);
            assert_eq!(from, sender_addr);
            assert_eq!(&received[..3], &[1, 2, 3]);
        });

        let port = t!(CompletionPort::new(1));
        t!(port.add_socket(1, &sender));

        let overlapped = Overlapped::zero();
        unsafe {
            t!(sender.send_to_overlapped(&[1, 2, 3], &receiver_addr, overlapped.raw()));
        }

        let status = t!(port.get(None));
        assert_eq!(status.bytes_transferred(), 3);
        assert_eq!(status.token(), 1);
        assert_eq!(status.overlapped(), overlapped.raw());
        t!(recv_thread.join());
    })
}
/// Connects two UDP sockets to each other, issues an overlapped `send` on one
/// and checks the completion status and the datagram seen by the other.
#[test]
fn udp_send() {
    each_ip(&mut |bind_addr| {
        let receiver = t!(UdpSocket::bind(bind_addr));
        let sender = t!(UdpSocket::bind(bind_addr));
        let receiver_addr = t!(receiver.local_addr());
        let sender_addr = t!(sender.local_addr());
        // Connect each socket to the other one's address.
        assert!(sender.connect(receiver_addr).is_ok());
        assert!(receiver.connect(sender_addr).is_ok());
        // Peer side: receive one datagram and verify its size, origin and
        // contents.
        let recv_thread = thread::spawn(move || {
            let mut received = [0; 100];
            let (count, from) = t!(receiver.recv_from(&mut received));
            assert_eq!(count, 3);
            assert_eq!(from, sender_addr);
            assert_eq!(&received[..3], &[1, 2, 3]);
        });

        let port = t!(CompletionPort::new(1));
        t!(port.add_socket(1, &sender));

        let overlapped = Overlapped::zero();
        unsafe {
            t!(sender.send_overlapped(&[1, 2, 3], overlapped.raw()));
        }

        let status = t!(port.get(None));
        assert_eq!(status.bytes_transferred(), 3);
        assert_eq!(status.token(), 1);
        assert_eq!(status.overlapped(), overlapped.raw());
        t!(recv_thread.join());
    })
}
/// Issues an overlapped accept on a listener and checks the completion status
/// and that the parsed address buffer mirrors the client's view of the
/// connection.
#[test]
fn tcp_accept() {
    each_ip(&mut |bind_addr| {
        let listener = t!(TcpListener::bind(bind_addr));
        let listen_addr = t!(listener.local_addr());
        // Client side: connect and report (own local address, peer address).
        let client = thread::spawn(move || {
            let stream = t!(TcpStream::connect(listen_addr));
            (stream.local_addr().unwrap(), stream.peer_addr().unwrap())
        });

        let port = t!(CompletionPort::new(1));
        let domain = Domain::for_address(listen_addr);
        // Socket handed to `accept_overlapped` to become the accepted
        // connection.
        let accepted = TcpStream::from(t!(Socket::new(domain, Type::STREAM, None)));
        t!(port.add_socket(1, &listener));

        let overlapped = Overlapped::zero();
        let mut addr_buf = AcceptAddrsBuf::new();
        // `addr_buf` and `overlapped` are locals that stay alive until after
        // the `get` call below.
        unsafe {
            t!(listener.accept_overlapped(&accepted, &mut addr_buf, overlapped.raw()));
        }

        let status = t!(port.get(None));
        assert_eq!(status.bytes_transferred(), 0);
        assert_eq!(status.token(), 1);
        assert_eq!(status.overlapped(), overlapped.raw());
        t!(listener.accept_complete(&accepted));

        // The client's local address is the listener's remote address, and
        // the client's peer address is the listener's local address.
        let (client_local, client_peer) = t!(client.join());
        let parsed = addr_buf.parse(&listener).unwrap();
        assert_eq!(parsed.local(), Some(client_peer));
        assert_eq!(parsed.remote(), Some(client_local));
    })
}
/// Converts an IPv4 `SocketAddr` with `socket_addr_to_ptrs` and compares the
/// resulting raw bytes against the expected 16-byte layout.
#[test]
fn sockaddr_convert_4() {
    let ip = IpAddr::V4(Ipv4Addr::new(3, 4, 5, 6));
    let addr = SocketAddr::new(ip, 0xabcd);
    let (raw, len) = super::socket_addr_to_ptrs(&addr);
    assert_eq!(len, 16);

    let expected: [u8; 16] = [
        2, 0, // address family (presumably AF_INET)
        0xab, 0xcd, // port 0xabcd, most significant byte first
        3, 4, 5, 6, // IP octets
        0, 0, 0, 0, 0, 0, 0, 0, // trailing zero bytes
    ];
    // View the converted address as `len` raw bytes; `raw` is still in scope.
    let bytes = unsafe { slice::from_raw_parts(raw.as_ptr() as *const u8, len as usize) };
    assert_eq!(bytes, &expected);
}
/// Converts an IPv6 `SocketAddr` (with flowinfo and scope id set) using
/// `socket_addr_to_ptrs` and compares the raw bytes against the expected
/// 28-byte layout.
#[test]
fn sockaddr_convert_v6() {
    let port: u16 = 0xabcd;
    let flowinfo: u32 = 0x12345678;
    let scope_id: u32 = 0x87654321;
    let ip = Ipv6Addr::new(
        0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10,
    );
    let addr = SocketAddr::V6(SocketAddrV6::new(ip, port, flowinfo, scope_id));

    let (raw, len) = super::socket_addr_to_ptrs(&addr);
    assert_eq!(len, 28);

    let expected: [u8; 28] = [
        23, 0, // AF_INET6
        0xab, 0xcd, // Port
        0x78, 0x56, 0x34, 0x12, // flowinfo
        0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
        0x0f, 0x10, // IP
        0x21, 0x43, 0x65, 0x87, // scope_id
    ];
    // View the converted address as `len` raw bytes; `raw` is still in scope.
    let bytes = unsafe { slice::from_raw_parts(raw.as_ptr() as *const u8, len as usize) };
    assert_eq!(bytes, &expected);
}
}