From 28203542a36a9ae093af5ac4d94941c295230bfa Mon Sep 17 00:00:00 2001 From: OneOfOne Date: Sat, 13 Apr 2024 14:40:07 -0500 Subject: [PATCH] feat: add msg_control access for UdpSocket::recvmsg https://github.com/tokio-rs/tokio-uring/pull/275 --- src/io/recvmsg.rs | 25 +++++++++++++++++++------ src/io/socket.rs | 7 ++++--- src/net/udp.rs | 10 ++++++---- tests/fs_file.rs | 9 ++++----- 4 files changed, 33 insertions(+), 18 deletions(-) diff --git a/src/io/recvmsg.rs b/src/io/recvmsg.rs index 81185174..d26cf5d9 100644 --- a/src/io/recvmsg.rs +++ b/src/io/recvmsg.rs @@ -7,18 +7,23 @@ use std::{ {boxed::Box, io, net::SocketAddr}, }; -pub(crate) struct RecvMsg { +pub(crate) struct RecvMsg> { #[allow(dead_code)] fd: SharedFd, pub(crate) buf: Vec, #[allow(dead_code)] io_slices: Vec>, pub(crate) socket_addr: Box, + pub(crate) msg_control: Option, pub(crate) msghdr: Box, } -impl Op> { - pub(crate) fn recvmsg(fd: &SharedFd, mut bufs: Vec) -> io::Result>> { +impl Op> { + pub(crate) fn recvmsg( + fd: &SharedFd, + mut bufs: Vec, + mut msg_control: Option, + ) -> io::Result>> { use io_uring_ooo::{opcode, types}; let mut io_slices = Vec::with_capacity(bufs.len()); @@ -35,6 +40,10 @@ impl Op> { msghdr.msg_iovlen = io_slices.len() as _; msghdr.msg_name = socket_addr.as_ptr() as *mut libc::c_void; msghdr.msg_namelen = socket_addr.len(); + if let Some(msg_control) = &mut msg_control { + msghdr.msg_control = msg_control.stable_mut_ptr().cast(); + msghdr.msg_controllen = msg_control.bytes_total(); + } CONTEXT.with(|x| { x.handle().expect("Not in a runtime context").submit_op( @@ -43,6 +52,7 @@ impl Op> { buf: bufs, io_slices, socket_addr, + msg_control, msghdr, }, |recv_from| { @@ -57,11 +67,12 @@ impl Op> { } } -impl Completable for RecvMsg +impl Completable for RecvMsg where T: BoundedBufMut, + U: BoundedBufMut, { - type Output = BufResult<(usize, SocketAddr), Vec>; + type Output = BufResult<(usize, SocketAddr, Option), Vec>; fn complete(self, cqe: CqeResult) -> Self::Output { // Convert the operation result to `usize` @@ -71,6 +82,8 @@ where let socket_addr = (*self.socket_addr).as_socket(); + let msg_control = self.msg_control; + let res = res.map(|n| { let socket_addr: SocketAddr = socket_addr.unwrap(); @@ -89,7 +102,7 @@ where break; } } - (n, socket_addr) + (n, socket_addr, msg_control) }); (res, bufs) diff --git a/src/io/socket.rs b/src/io/socket.rs index a1720aae..335bc6f7 100644 --- a/src/io/socket.rs +++ b/src/io/socket.rs @@ -190,11 +190,12 @@ impl Socket { op.await } - pub(crate) async fn recvmsg( + pub(crate) async fn recvmsg( &self, buf: Vec, - ) -> crate::BufResult<(usize, SocketAddr), Vec> { - let op = Op::recvmsg(&self.fd, buf).unwrap(); + msg_control: Option, + ) -> crate::BufResult<(usize, SocketAddr, Option), Vec> { + let op = Op::recvmsg(&self.fd, buf, msg_control).unwrap(); op.await } diff --git a/src/net/udp.rs b/src/net/udp.rs index 6ccb3861..8ed3c672 100644 --- a/src/net/udp.rs +++ b/src/net/udp.rs @@ -305,12 +305,14 @@ impl UdpSocket { /// Receives a single datagram message on the socket, into multiple buffers /// - /// On success, returns the number of bytes read and the origin. - pub async fn recvmsg( + /// On success, returns the number of bytes read, the origin, and the msg_control, as modified + /// by the kernel. + pub async fn recvmsg( &self, buf: Vec, - ) -> crate::BufResult<(usize, SocketAddr), Vec> { - self.inner.recvmsg(buf).await + msg_control: Option, + ) -> crate::BufResult<(usize, SocketAddr, Option), Vec> { + self.inner.recvmsg(buf, msg_control).await } /// Reads a packet of data from the socket into the buffer. diff --git a/tests/fs_file.rs b/tests/fs_file.rs index 470531d1..b41e2fd4 100644 --- a/tests/fs_file.rs +++ b/tests/fs_file.rs @@ -318,9 +318,9 @@ fn basic_fallocate() { #[test] fn iopoll_without_sqpoll() { use std::os::unix::fs::OpenOptionsExt; - let mut builder = tokio_uring::builder(); - builder.uring_builder(&tokio_uring::uring_builder().setup_iopoll()); - let runtime = tokio_uring::Runtime::new(&builder).unwrap(); + let mut builder = tokio_uring_ooo::builder(); + builder.uring_builder(tokio_uring_ooo::uring_builder().setup_iopoll()); + let runtime = tokio_uring_ooo::Runtime::new(&builder).unwrap(); let tmp = tempfile(); runtime.block_on(async { let file = std::fs::OpenOptions::new() @@ -328,8 +328,7 @@ fn iopoll_without_sqpoll() { .custom_flags(libc::O_DIRECT) .open(tmp.path()) .unwrap(); - let file = tokio_uring::fs::File::from_std(file); - + let file = tokio_uring_ooo::fs::File::from_std(file); let layout = std::alloc::Layout::from_size_align(512, 512).unwrap(); let buf = unsafe { let raw = std::alloc::alloc(layout);