Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 54 additions & 37 deletions quinn-udp/src/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,17 +719,42 @@ fn decode_recv(
len: usize,
) -> io::Result<RecvMeta> {
let name = unsafe { name.assume_init() };
let mut ecn_bits = 0;
let mut dst_ip = None;
let mut interface_index = None;
#[allow(unused_mut)] // only mutable on Linux
let mut stride = len;
let mut ctrl = ControlMetadata {
ecn_bits: 0,
dst_ip: None,
interface_index: None,
stride: len,
};

let cmsg_iter = unsafe { cmsg::Iter::new(hdr) };
for cmsg in cmsg_iter {
ctrl.decode(cmsg);
}

Ok(RecvMeta {
len,
stride: ctrl.stride,
addr: decode_socket_addr(&name)?,
ecn: EcnCodepoint::from_bits(ctrl.ecn_bits),
dst_ip: ctrl.dst_ip,
interface_index: ctrl.interface_index,
})
}

/// Metadata decoded from control messages
struct ControlMetadata {
ecn_bits: u8,
dst_ip: Option<IpAddr>,
interface_index: Option<u32>,
stride: usize,
}

impl ControlMetadata {
/// Decodes a control message and updates the metadata state
fn decode(&mut self, cmsg: &libc::cmsghdr) {
match (cmsg.cmsg_level, cmsg.cmsg_type) {
(libc::IPPROTO_IP, libc::IP_TOS) => unsafe {
ecn_bits = cmsg::decode::<u8, libc::cmsghdr>(cmsg);
self.ecn_bits = cmsg::decode::<u8, libc::cmsghdr>(cmsg);
},
// FreeBSD uses IP_RECVTOS here, and we can be liberal because cmsgs are opt-in.
#[cfg(not(any(
Expand All @@ -739,7 +764,7 @@ fn decode_recv(
solarish
)))]
(libc::IPPROTO_IP, libc::IP_RECVTOS) => unsafe {
ecn_bits = cmsg::decode::<u8, libc::cmsghdr>(cmsg);
self.ecn_bits = cmsg::decode::<u8, libc::cmsghdr>(cmsg);
},
(libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => unsafe {
// Temporary hack around broken macos ABI. Remove once upstream fixes it.
Expand All @@ -748,73 +773,65 @@ fn decode_recv(
if cfg!(apple)
&& cmsg.cmsg_len as usize == libc::CMSG_LEN(mem::size_of::<u8>() as _) as usize
{
ecn_bits = cmsg::decode::<u8, libc::cmsghdr>(cmsg);
self.ecn_bits = cmsg::decode::<u8, libc::cmsghdr>(cmsg);
} else {
ecn_bits = cmsg::decode::<libc::c_int, libc::cmsghdr>(cmsg) as u8;
self.ecn_bits = cmsg::decode::<libc::c_int, libc::cmsghdr>(cmsg) as u8;
}
},
#[cfg(any(target_os = "linux", target_os = "android"))]
(libc::IPPROTO_IP, libc::IP_PKTINFO) => {
let pktinfo = unsafe { cmsg::decode::<libc::in_pktinfo, libc::cmsghdr>(cmsg) };
dst_ip = Some(IpAddr::V4(Ipv4Addr::from(
self.dst_ip = Some(IpAddr::V4(Ipv4Addr::from(
pktinfo.ipi_addr.s_addr.to_ne_bytes(),
)));
interface_index = Some(pktinfo.ipi_ifindex as u32);
self.interface_index = Some(pktinfo.ipi_ifindex as u32);
}
#[cfg(any(bsd, apple))]
(libc::IPPROTO_IP, libc::IP_RECVDSTADDR) => {
let in_addr = unsafe { cmsg::decode::<libc::in_addr, libc::cmsghdr>(cmsg) };
dst_ip = Some(IpAddr::V4(Ipv4Addr::from(in_addr.s_addr.to_ne_bytes())));
self.dst_ip = Some(IpAddr::V4(Ipv4Addr::from(in_addr.s_addr.to_ne_bytes())));
}
(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
let pktinfo = unsafe { cmsg::decode::<libc::in6_pktinfo, libc::cmsghdr>(cmsg) };
dst_ip = Some(IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr)));
interface_index = Some(pktinfo.ipi6_ifindex as u32);
self.dst_ip = Some(IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr)));
self.interface_index = Some(pktinfo.ipi6_ifindex as u32);
}
#[cfg(any(target_os = "linux", target_os = "android"))]
(libc::SOL_UDP, gro::UDP_GRO) => unsafe {
stride = cmsg::decode::<libc::c_int, libc::cmsghdr>(cmsg) as usize;
self.stride = cmsg::decode::<libc::c_int, libc::cmsghdr>(cmsg) as usize;
},
_ => {}
}
}
}

let addr = match libc::c_int::from(name.ss_family) {
/// Decodes a `sockaddr_storage` into a `SocketAddr`
fn decode_socket_addr(name: &libc::sockaddr_storage) -> io::Result<SocketAddr> {
match libc::c_int::from(name.ss_family) {
libc::AF_INET => {
// Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in.
let addr: &libc::sockaddr_in =
unsafe { &*(&name as *const _ as *const libc::sockaddr_in) };
SocketAddr::V4(SocketAddrV4::new(
unsafe { &*(name as *const _ as *const libc::sockaddr_in) };
Ok(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()),
u16::from_be(addr.sin_port),
))
)))
}
libc::AF_INET6 => {
// Safety: if the ss_family field is AF_INET6 then storage must be a sockaddr_in6.
let addr: &libc::sockaddr_in6 =
unsafe { &*(&name as *const _ as *const libc::sockaddr_in6) };
SocketAddr::V6(SocketAddrV6::new(
unsafe { &*(name as *const _ as *const libc::sockaddr_in6) };
Ok(SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::from(addr.sin6_addr.s6_addr),
u16::from_be(addr.sin6_port),
addr.sin6_flowinfo,
addr.sin6_scope_id,
))
}
f => {
return Err(io::Error::other(format!(
"expected AF_INET or AF_INET6, got {f} in decode_recv"
)));
)))
}
};

Ok(RecvMeta {
len,
stride,
addr,
ecn: EcnCodepoint::from_bits(ecn_bits),
dst_ip,
interface_index,
})
f => Err(io::Error::other(format!(
"expected AF_INET or AF_INET6, got {f}"
))),
}
}

#[cfg(not(apple_slow))]
Expand Down
Loading