Skip to content

Commit

Permalink
Async write support
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Sep 10, 2024
1 parent f279962 commit d02fd8e
Show file tree
Hide file tree
Showing 12 changed files with 243 additions and 196 deletions.
40 changes: 27 additions & 13 deletions ntex-async-std/src/io.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::{any, cell::RefCell, future::poll_fn, io, pin::Pin, task::Context, task::Poll};
use std::{
any, cell::RefCell, future::poll_fn, io, pin::Pin, task::ready, task::Context,
task::Poll,
};

use async_std::io::{Read as ARead, Write as AWrite};
use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext};
use ntex_util::{future::lazy, ready};
use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf};

use crate::TcpStream;

Expand Down Expand Up @@ -51,11 +53,17 @@ struct Write(RefCell<TcpStream>);

impl ntex_io::AsyncWrite for Write {
#[inline]
async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) {
match lazy(|cx| flush_io(&mut self.0.borrow_mut().0, &mut buf, cx)).await {
Poll::Ready(res) => (buf, res),
Poll::Pending => (buf, Ok(())),
}
async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
poll_fn(|cx| {
if let Some(mut b) = buf.take() {
let result = flush_io(&mut self.0.borrow_mut().0, &mut b, cx);
buf.set(b);
result
} else {
Poll::Ready(Ok(()))
}
})
.await
}

#[inline]
Expand Down Expand Up @@ -186,11 +194,17 @@ mod unixstream {

impl ntex_io::AsyncWrite for Write {
#[inline]
async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) {
match lazy(|cx| flush_io(&mut self.0.borrow_mut().0, &mut buf, cx)).await {
Poll::Ready(res) => (buf, res),
Poll::Pending => (buf, Ok(())),
}
async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
poll_fn(|cx| {
if let Some(mut b) = buf.take() {
let result = flush_io(&mut self.0.borrow_mut().0, &mut b, cx);
buf.set(b);
result
} else {
Poll::Ready(Ok(()))
}
})
.await
}

#[inline]
Expand Down
4 changes: 4 additions & 0 deletions ntex-compio/CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.1.2] - 2024-09-10

* Use new io api

## [0.1.1] - 2024-09-05

* Tune write task
Expand Down
46 changes: 24 additions & 22 deletions ntex-compio/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use compio::buf::{BufResult, IoBuf, IoBufMut, SetBufInit};
use compio::io::{AsyncRead, AsyncWrite};
use compio::net::TcpStream;
use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext};
use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf};

impl IoStream for crate::TcpStream {
fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
Expand Down Expand Up @@ -130,31 +130,33 @@ where
T: AsyncWrite,
{
#[inline]
async fn write(&mut self, buf: BytesVec) -> (BytesVec, io::Result<()>) {
let mut buf = CompioBuf(buf);
loop {
let BufResult(result, buf1) = self.0.write(buf).await;
buf = buf1;

return match result {
Ok(0) => (
buf.0,
Err(io::Error::new(
async fn write(&mut self, wbuf: &mut WriteContextBuf) -> io::Result<()> {
if let Some(b) = wbuf.take() {
let mut buf = CompioBuf(b);
loop {
let BufResult(result, buf1) = self.0.write(buf).await;
buf = buf1;

let result = match result {
Ok(0) => Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame to transport",
)),
),
Ok(size) => {
buf.0.advance(size);

if buf.0.is_empty() {
(buf.0, Ok(()))
} else {
continue;
Ok(size) => {
buf.0.advance(size);
if buf.0.is_empty() {
Ok(())
} else {
continue;
}
}
}
Err(e) => (buf.0, Err(e)),
};
Err(e) => Err(e),
};
wbuf.set(buf.0);
return result;
}
} else {
Ok(())
}
}

Expand Down
107 changes: 58 additions & 49 deletions ntex-glommio/src/io.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use std::task::{Context, Poll};
use std::{any, future::poll_fn, future::Future, io, pin::Pin};
use std::{any, future::poll_fn, io, pin::Pin, task::ready, task::Context, task::Poll};

use futures_lite::future::FutureExt;
use futures_lite::io::{AsyncRead, AsyncWrite};
use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteStatus};
use ntex_util::{ready, time::sleep, time::Sleep};
use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf};

use crate::net_impl::{TcpStream, UnixStream};

Expand Down Expand Up @@ -62,11 +59,59 @@ struct Write(TcpStream);

impl ntex_io::AsyncWrite for Write {
#[inline]
async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) {
match lazy(|cx| flush_io(&mut *self.0.borrow_mut(), &mut buf, cx)).await {
Poll::Ready(res) => (buf, res),
Poll::Pending => (buf, Ok(())),
}
async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
poll_fn(|cx| {
if let Some(mut b) = buf.take() {
let result = flush_io(&mut *self.0 .0.borrow_mut(), &mut b, cx);
buf.set(b);
result
} else {
Poll::Ready(Ok(()))
}
})
.await
}

#[inline]
async fn flush(&mut self) -> io::Result<()> {
Ok(())
}

#[inline]
async fn shutdown(&mut self) -> io::Result<()> {
poll_fn(|cx| Pin::new(&mut *self.0 .0.borrow_mut()).poll_close(cx)).await
}
}

struct UnixRead(UnixStream);

impl ntex_io::AsyncRead for UnixRead {
async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
// read data from socket
let result = poll_fn(|cx| {
let mut io = self.0 .0.borrow_mut();
poll_read_buf(Pin::new(&mut *io), cx, &mut buf)
})
.await;
(buf, result)
}
}

struct UnixWrite(UnixStream);

impl ntex_io::AsyncWrite for UnixWrite {
#[inline]
async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
poll_fn(|cx| {
if let Some(mut b) = buf.take() {
let result = flush_io(&mut *self.0 .0.borrow_mut(), &mut b, cx);
buf.set(b);
result
} else {
Poll::Ready(Ok(()))
}
})
.await
}

#[inline]
Expand All @@ -76,7 +121,7 @@ impl ntex_io::AsyncWrite for Write {

#[inline]
async fn shutdown(&mut self) -> io::Result<()> {
poll_fn(|cx| Pin::new(&mut *self.0.borrow_mut()).poll_close(cx)).await
poll_fn(|cx| Pin::new(&mut *self.0 .0.borrow_mut()).poll_close(cx)).await
}
}

Expand Down Expand Up @@ -125,7 +170,7 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
// log::trace!("flushed {} bytes", written);

// flush
return if written > 0 {
if written > 0 {
match Pin::new(&mut *io).poll_flush(cx) {
Poll::Ready(Ok(_)) => result,
Poll::Pending => Poll::Pending,
Expand All @@ -136,7 +181,7 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
}
} else {
result
};
}
} else {
Poll::Ready(Ok(()))
}
Expand All @@ -158,39 +203,3 @@ pub fn poll_read_buf<T: AsyncRead>(

Poll::Ready(Ok(n))
}

struct UnixRead(UnixStream);

impl ntex_io::AsyncRead for UnixRead {
async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
// read data from socket
let result = poll_fn(|cx| {
let mut io = self.0 .0.borrow_mut();
poll_read_buf(Pin::new(&mut *io), cx, &mut buf)
})
.await;
(buf, result)
}
}

struct UnixWrite(UnixStream);

impl ntex_io::AsyncWrite for UnixWrite {
#[inline]
async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) {
match lazy(|cx| flush_io(&mut *self.0.borrow_mut(), &mut buf, cx)).await {
Poll::Ready(res) => (buf, res),
Poll::Pending => (buf, Ok(())),
}
}

#[inline]
async fn flush(&mut self) -> io::Result<()> {
Ok(())
}

#[inline]
async fn shutdown(&mut self) -> io::Result<()> {
poll_fn(|cx| Pin::new(&mut *self.0.borrow_mut()).poll_close(cx)).await
}
}
2 changes: 1 addition & 1 deletion ntex-io/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl Filter for Base {
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
let mut flags = self.0.flags();

if flags.contains(Flags::IO_STOPPED) {
if flags.is_stopped() {
Poll::Ready(WriteStatus::Terminate)
} else {
self.0 .0.write_task.register(cx.waker());
Expand Down
4 changes: 4 additions & 0 deletions ntex-io/src/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ bitflags::bitflags! {
}

impl Flags {
pub(crate) fn is_stopped(&self) -> bool {
self.intersects(Flags::IO_STOPPED)
}

pub(crate) fn is_waiting_for_write(&self) -> bool {
self.intersects(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE)
}
Expand Down
16 changes: 8 additions & 8 deletions ntex-io/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ impl Io {
let inner = Rc::new(IoState {
filter: FilterPtr::null(),
pool: Cell::new(pool),
flags: Cell::new(Flags::empty()),
flags: Cell::new(Flags::WR_PAUSED),
error: Cell::new(None),
dispatch_task: LocalWaker::new(),
read_task: LocalWaker::new(),
Expand Down Expand Up @@ -421,7 +421,7 @@ impl<F> Io<F> {
let st = self.st();
let mut flags = st.flags.get();

if flags.contains(Flags::IO_STOPPED) {
if flags.is_stopped() {
Poll::Ready(self.error().map(Err).unwrap_or(Ok(None)))
} else {
st.dispatch_task.register(cx.waker());
Expand Down Expand Up @@ -531,7 +531,7 @@ impl<F> Io<F> {
} else {
let st = self.st();
let flags = st.flags.get();
if flags.contains(Flags::IO_STOPPED) {
if flags.is_stopped() {
Err(RecvError::PeerGone(self.error()))
} else if flags.contains(Flags::DSP_STOP) {
st.remove_flags(Flags::DSP_STOP);
Expand Down Expand Up @@ -568,7 +568,7 @@ impl<F> Io<F> {
pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
let flags = self.flags();

if flags.contains(Flags::IO_STOPPED) {
if flags.is_stopped() {
Poll::Ready(self.error().map(Err).unwrap_or(Ok(())))
} else {
let st = self.st();
Expand All @@ -595,7 +595,7 @@ impl<F> Io<F> {
let st = self.st();
let flags = st.flags.get();

if flags.intersects(Flags::IO_STOPPED) {
if flags.is_stopped() {
if let Some(err) = self.error() {
Poll::Ready(Err(err))
} else {
Expand Down Expand Up @@ -700,7 +700,7 @@ impl<F> Drop for Io<F> {
if st.filter.is_set() {
// filter is unsafe and must be dropped explicitly,
// and wont be dropped without special attention
if !st.flags.get().contains(Flags::IO_STOPPED) {
if !st.flags.get().is_stopped() {
log::trace!(
"{}: Io is dropped, force stopping io streams {:?}",
st.tag.get(),
Expand Down Expand Up @@ -884,7 +884,7 @@ pub struct OnDisconnect {

impl OnDisconnect {
pub(super) fn new(inner: Rc<IoState>) -> Self {
Self::new_inner(inner.flags.get().contains(Flags::IO_STOPPED), inner)
Self::new_inner(inner.flags.get().is_stopped(), inner)
}

fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
Expand All @@ -909,7 +909,7 @@ impl OnDisconnect {
#[inline]
/// Check if connection is disconnected
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
if self.token == usize::MAX || self.inner.flags.get().contains(Flags::IO_STOPPED) {
if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
Poll::Ready(())
} else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
on_disconnect[self.token].register(cx.waker());
Expand Down
4 changes: 2 additions & 2 deletions ntex-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub use self::filter::{Base, Filter, Layer};
pub use self::framed::Framed;
pub use self::io::{Io, IoRef, OnDisconnect};
pub use self::seal::{IoBoxed, Sealed};
pub use self::tasks::{ReadContext, WriteContext};
pub use self::tasks::{ReadContext, WriteContext, WriteContextBuf};
pub use self::timer::TimerHandle;
pub use self::utils::{seal, Decoded};

Expand All @@ -45,7 +45,7 @@ pub trait AsyncRead {

#[doc(hidden)]
pub trait AsyncWrite {
async fn write(&mut self, buf: BytesVec) -> (BytesVec, sio::Result<()>);
async fn write(&mut self, buf: &mut WriteContextBuf) -> sio::Result<()>;

async fn flush(&mut self) -> sio::Result<()>;

Expand Down
Loading

0 comments on commit d02fd8e

Please sign in to comment.