Skip to content

Commit

Permalink
feat: add msg_control access for UdpSocket::recvmsg
Browse files Browse the repository at this point in the history
  • Loading branch information
OneOfOne committed Apr 13, 2024
1 parent 6895908 commit 2820354
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 18 deletions.
25 changes: 19 additions & 6 deletions src/io/recvmsg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,23 @@ use std::{
{boxed::Box, io, net::SocketAddr},
};

pub(crate) struct RecvMsg<T> {
pub(crate) struct RecvMsg<T, U = Vec<u8>> {
#[allow(dead_code)]
fd: SharedFd,
pub(crate) buf: Vec<T>,
#[allow(dead_code)]
io_slices: Vec<IoSliceMut<'static>>,
pub(crate) socket_addr: Box<SockAddr>,
pub(crate) msg_control: Option<U>,
pub(crate) msghdr: Box<libc::msghdr>,
}

impl<T: BoundedBufMut> Op<RecvMsg<T>> {
pub(crate) fn recvmsg(fd: &SharedFd, mut bufs: Vec<T>) -> io::Result<Op<RecvMsg<T>>> {
impl<T: BoundedBufMut, U: BoundedBufMut> Op<RecvMsg<T, U>> {
pub(crate) fn recvmsg(
fd: &SharedFd,
mut bufs: Vec<T>,
mut msg_control: Option<U>,
) -> io::Result<Op<RecvMsg<T, U>>> {
use io_uring_ooo::{opcode, types};

let mut io_slices = Vec::with_capacity(bufs.len());
Expand All @@ -35,6 +40,10 @@ impl<T: BoundedBufMut> Op<RecvMsg<T>> {
msghdr.msg_iovlen = io_slices.len() as _;
msghdr.msg_name = socket_addr.as_ptr() as *mut libc::c_void;
msghdr.msg_namelen = socket_addr.len();
if let Some(msg_control) = &mut msg_control {
msghdr.msg_control = msg_control.stable_mut_ptr().cast();
msghdr.msg_controllen = msg_control.bytes_total();
}

CONTEXT.with(|x| {
x.handle().expect("Not in a runtime context").submit_op(
Expand All @@ -43,6 +52,7 @@ impl<T: BoundedBufMut> Op<RecvMsg<T>> {
buf: bufs,
io_slices,
socket_addr,
msg_control,
msghdr,
},
|recv_from| {
Expand All @@ -57,11 +67,12 @@ impl<T: BoundedBufMut> Op<RecvMsg<T>> {
}
}

impl<T> Completable for RecvMsg<T>
impl<T, U> Completable for RecvMsg<T, U>
where
T: BoundedBufMut,
U: BoundedBufMut,
{
type Output = BufResult<(usize, SocketAddr), Vec<T>>;
type Output = BufResult<(usize, SocketAddr, Option<U>), Vec<T>>;

fn complete(self, cqe: CqeResult) -> Self::Output {
// Convert the operation result to `usize`
Expand All @@ -71,6 +82,8 @@ where

let socket_addr = (*self.socket_addr).as_socket();

let msg_control = self.msg_control;

let res = res.map(|n| {
let socket_addr: SocketAddr = socket_addr.unwrap();

Expand All @@ -89,7 +102,7 @@ where
break;
}
}
(n, socket_addr)
(n, socket_addr, msg_control)
});

(res, bufs)
Expand Down
7 changes: 4 additions & 3 deletions src/io/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,12 @@ impl Socket {
op.await
}

pub(crate) async fn recvmsg<T: BoundedBufMut>(
pub(crate) async fn recvmsg<T: BoundedBufMut, U: BoundedBufMut>(
&self,
buf: Vec<T>,
) -> crate::BufResult<(usize, SocketAddr), Vec<T>> {
let op = Op::recvmsg(&self.fd, buf).unwrap();
msg_control: Option<U>,
) -> crate::BufResult<(usize, SocketAddr, Option<U>), Vec<T>> {
let op = Op::recvmsg(&self.fd, buf, msg_control).unwrap();
op.await
}

Expand Down
10 changes: 6 additions & 4 deletions src/net/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,14 @@ impl UdpSocket {

/// Receives a single datagram message on the socket, into multiple buffers
///
/// On success, returns the number of bytes read and the origin.
pub async fn recvmsg<T: BoundedBufMut>(
/// On success, returns the number of bytes read, the origin, and the msg_control, as modified
/// by the kernel.
pub async fn recvmsg<T: BoundedBufMut, U: BoundedBufMut>(
&self,
buf: Vec<T>,
) -> crate::BufResult<(usize, SocketAddr), Vec<T>> {
self.inner.recvmsg(buf).await
msg_control: Option<U>,
) -> crate::BufResult<(usize, SocketAddr, Option<U>), Vec<T>> {
self.inner.recvmsg(buf, msg_control).await
}

/// Reads a packet of data from the socket into the buffer.
Expand Down
9 changes: 4 additions & 5 deletions tests/fs_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,18 +318,17 @@ fn basic_fallocate() {
#[test]
fn iopoll_without_sqpoll() {
use std::os::unix::fs::OpenOptionsExt;
let mut builder = tokio_uring::builder();
builder.uring_builder(&tokio_uring::uring_builder().setup_iopoll());
let runtime = tokio_uring::Runtime::new(&builder).unwrap();
let mut builder = tokio_uring_ooo::builder();
builder.uring_builder(tokio_uring_ooo::uring_builder().setup_iopoll());
let runtime = tokio_uring_ooo::Runtime::new(&builder).unwrap();
let tmp = tempfile();
runtime.block_on(async {
let file = std::fs::OpenOptions::new()
.write(true)
.custom_flags(libc::O_DIRECT)
.open(tmp.path())
.unwrap();
let file = tokio_uring::fs::File::from_std(file);

let file = tokio_uring_ooo::fs::File::from_std(file);
let layout = std::alloc::Layout::from_size_align(512, 512).unwrap();
let buf = unsafe {
let raw = std::alloc::alloc(layout);
Expand Down

0 comments on commit 2820354

Please sign in to comment.