diff --git a/ntex-async-std/src/io.rs b/ntex-async-std/src/io.rs index 6ef23b0f1..c64761881 100644 --- a/ntex-async-std/src/io.rs +++ b/ntex-async-std/src/io.rs @@ -1,9 +1,11 @@ -use std::{any, cell::RefCell, future::poll_fn, io, pin::Pin, task::Context, task::Poll}; +use std::{ + any, cell::RefCell, future::poll_fn, io, pin::Pin, task::ready, task::Context, + task::Poll, +}; use async_std::io::{Read as ARead, Write as AWrite}; use ntex_bytes::{Buf, BufMut, BytesVec}; use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext}; -use ntex_util::{future::lazy, ready}; use crate::TcpStream; @@ -52,10 +54,8 @@ struct Write(RefCell); impl ntex_io::AsyncWrite for Write { #[inline] async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) { - match lazy(|cx| flush_io(&mut self.0.borrow_mut().0, &mut buf, cx)).await { - Poll::Ready(res) => (buf, res), - Poll::Pending => (buf, Ok(())), - } + let result = poll_fn(|cx| flush_io(&mut self.0.borrow_mut().0, &mut buf, cx)).await; + (buf, result) } #[inline] @@ -187,10 +187,9 @@ mod unixstream { impl ntex_io::AsyncWrite for Write { #[inline] async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) { - match lazy(|cx| flush_io(&mut self.0.borrow_mut().0, &mut buf, cx)).await { - Poll::Ready(res) => (buf, res), - Poll::Pending => (buf, Ok(())), - } + let result = + poll_fn(|cx| flush_io(&mut self.0.borrow_mut().0, &mut buf, cx)).await; + (buf, result) } #[inline] diff --git a/ntex-glommio/src/io.rs b/ntex-glommio/src/io.rs index 53166de7f..c39567ded 100644 --- a/ntex-glommio/src/io.rs +++ b/ntex-glommio/src/io.rs @@ -1,11 +1,8 @@ -use std::task::{Context, Poll}; -use std::{any, future::poll_fn, future::Future, io, pin::Pin}; +use std::{any, future::poll_fn, io, pin::Pin, task::ready, task::Context, task::Poll}; -use futures_lite::future::FutureExt; use futures_lite::io::{AsyncRead, AsyncWrite}; use ntex_bytes::{Buf, BufMut, BytesVec}; -use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteStatus}; -use ntex_util::{ready, time::sleep, time::Sleep}; +use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext}; use crate::net_impl::{TcpStream, UnixStream}; @@ -63,10 +60,9 @@ struct Write(TcpStream); impl ntex_io::AsyncWrite for Write { #[inline] async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) { - match lazy(|cx| flush_io(&mut *self.0.borrow_mut(), &mut buf, cx)).await { - Poll::Ready(res) => (buf, res), - Poll::Pending => (buf, Ok(())), - } + let result = + poll_fn(|cx| flush_io(&mut *self.0 .0.borrow_mut(), &mut buf, cx)).await; + (buf, result) } #[inline] @@ -76,7 +72,7 @@ impl ntex_io::AsyncWrite for Write { #[inline] async fn shutdown(&mut self) -> io::Result<()> { - poll_fn(|cx| Pin::new(&mut *self.0.borrow_mut()).poll_close(cx)).await + poll_fn(|cx| Pin::new(&mut *self.0 .0.borrow_mut()).poll_close(cx)).await } } @@ -125,7 +121,7 @@ pub(super) fn flush_io( // log::trace!("flushed {} bytes", written); // flush - return if written > 0 { + if written > 0 { match Pin::new(&mut *io).poll_flush(cx) { Poll::Ready(Ok(_)) => result, Poll::Pending => Poll::Pending, @@ -136,7 +132,7 @@ pub(super) fn flush_io( } } else { result - }; + } } else { Poll::Ready(Ok(())) } @@ -178,10 +174,9 @@ struct UnixWrite(UnixStream); impl ntex_io::AsyncWrite for UnixWrite { #[inline] async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) { - match lazy(|cx| flush_io(&mut *self.0.borrow_mut(), &mut buf, cx)).await { - Poll::Ready(res) => (buf, res), - Poll::Pending => (buf, Ok(())), - } + let result = + poll_fn(|cx| flush_io(&mut *self.0 .0.borrow_mut(), &mut buf, cx)).await; + (buf, result) } #[inline] @@ -191,6 +186,6 @@ impl ntex_io::AsyncWrite for UnixWrite { #[inline] async fn shutdown(&mut self) -> io::Result<()> { - poll_fn(|cx| Pin::new(&mut *self.0.borrow_mut()).poll_close(cx)).await + poll_fn(|cx| Pin::new(&mut *self.0 .0.borrow_mut()).poll_close(cx)).await } } diff --git a/ntex-io/src/filter.rs b/ntex-io/src/filter.rs index f74e057c2..b452429b6 100644 --- a/ntex-io/src/filter.rs +++ b/ntex-io/src/filter.rs @@ -95,7 +95,7 @@ impl Filter for Base { fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { let mut flags = self.0.flags(); - if flags.contains(Flags::IO_STOPPED) { + if flags.is_stopped() { Poll::Ready(WriteStatus::Terminate) } else { self.0 .0.write_task.register(cx.waker()); diff --git a/ntex-io/src/flags.rs b/ntex-io/src/flags.rs index 82056ff54..029d891b7 100644 --- a/ntex-io/src/flags.rs +++ b/ntex-io/src/flags.rs @@ -36,6 +36,10 @@ bitflags::bitflags! { } impl Flags { + pub(crate) fn is_stopped(&self) -> bool { + self.intersects(Flags::IO_STOPPED) + } + pub(crate) fn is_waiting_for_write(&self) -> bool { self.intersects(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE) } diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index ded121418..29c59dec7 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -421,7 +421,7 @@ impl Io { let st = self.st(); let mut flags = st.flags.get(); - if flags.contains(Flags::IO_STOPPED) { + if flags.is_stopped() { Poll::Ready(self.error().map(Err).unwrap_or(Ok(None))) } else { st.dispatch_task.register(cx.waker()); @@ -531,7 +531,7 @@ impl Io { } else { let st = self.st(); let flags = st.flags.get(); - if flags.contains(Flags::IO_STOPPED) { + if flags.is_stopped() { Err(RecvError::PeerGone(self.error())) } else if flags.contains(Flags::DSP_STOP) { st.remove_flags(Flags::DSP_STOP); @@ -568,7 +568,7 @@ impl Io { pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll> { let flags = self.flags(); - if flags.contains(Flags::IO_STOPPED) { + if flags.is_stopped() { Poll::Ready(self.error().map(Err).unwrap_or(Ok(()))) } else { let st = self.st(); @@ -595,7 +595,7 @@ impl Io { let st = self.st(); let flags = st.flags.get(); - if flags.intersects(Flags::IO_STOPPED) { + if flags.is_stopped() { if let Some(err) = self.error() { Poll::Ready(Err(err)) } else { @@ -700,7 +700,7 @@ impl Drop for Io { if st.filter.is_set() { // filter is unsafe and must be dropped explicitly, // and wont be dropped without special attention - if !st.flags.get().contains(Flags::IO_STOPPED) { + if !st.flags.get().is_stopped() { log::trace!( "{}: Io is dropped, force stopping io streams {:?}", st.tag.get(), @@ -884,7 +884,7 @@ pub struct OnDisconnect { impl OnDisconnect { pub(super) fn new(inner: Rc) -> Self { - Self::new_inner(inner.flags.get().contains(Flags::IO_STOPPED), inner) + Self::new_inner(inner.flags.get().is_stopped(), inner) } fn new_inner(disconnected: bool, inner: Rc) -> Self { @@ -909,7 +909,7 @@ impl OnDisconnect { #[inline] /// Check if connection is disconnected pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> { - if self.token == usize::MAX || self.inner.flags.get().contains(Flags::IO_STOPPED) { + if self.token == usize::MAX || self.inner.flags.get().is_stopped() { Poll::Ready(()) } else if let Some(on_disconnect) = self.inner.on_disconnect.take() { on_disconnect[self.token].register(cx.waker()); diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index 281faaafe..4010aeace 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -187,6 +187,19 @@ impl WriteContext { self.0 .0.io_stopped(err); } + /// Check if io is closed + async fn when_stopped(&self) { + poll_fn(|cx| { + if self.0.flags().is_stopped() { + Poll::Ready(()) + } else { + self.0 .0.write_task.register(cx.waker()); + Poll::Pending + } + }) + .await + } + /// Handle write io operations pub async fn handle(&self, io: &mut T) where @@ -194,7 +207,7 @@ impl WriteContext { { let inner = &self.0 .0; let mut delay = None; - let mut buf = None; + let mut buf: Option = None; loop { // check readiness @@ -226,10 +239,10 @@ impl WriteContext { match result { WriteStatus::Ready => { // write io stream - let (buf_result, result) = if let Some(b) = buf.take() { - io.write(b).await + let result = if let Some(b) = buf.take() { + select(io.write(b), self.when_stopped()).await } else if let Some(b) = inner.buffer.get_write_destination() { - io.write(b).await + select(io.write(b), self.when_stopped()).await } else { // nothing to write, wait for wakeup if flags.is_waiting_for_write() { @@ -248,6 +261,10 @@ impl WriteContext { continue; }; + let (buf_result, result) = match result { + Either::Left(res) => res, + Either::Right(_) => return, + }; match result { Ok(_) => { diff --git a/ntex-io/src/testing.rs b/ntex-io/src/testing.rs index 496d143cc..a8120d4e9 100644 --- a/ntex-io/src/testing.rs +++ b/ntex-io/src/testing.rs @@ -202,6 +202,11 @@ impl IoTest { self.remote.lock().unwrap().borrow().waker.wake(); } + /// Get available data cap + pub fn get_remote_buffer_cap(&self) -> usize { + self.local.lock().unwrap().borrow_mut().buf_cap + } + /// Read any available data pub fn read_any(&self) -> Bytes { self.local.lock().unwrap().borrow_mut().buf.split().freeze() @@ -397,7 +402,6 @@ struct Write(Rc); impl crate::AsyncWrite for Write { async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) { let result = poll_fn(|cx| write_io(&self.0, &mut buf, cx)).await; - (buf, result) } diff --git a/ntex-tokio/src/io.rs b/ntex-tokio/src/io.rs index a894b9cf7..5011015fc 100644 --- a/ntex-tokio/src/io.rs +++ b/ntex-tokio/src/io.rs @@ -3,7 +3,7 @@ use std::{any, cell::RefCell, cmp, future::poll_fn, io, mem, pin::Pin, rc::Rc, r use ntex_bytes::{Buf, BufMut, BytesVec}; use ntex_io::{types, Filter, Handle, Io, IoBoxed, IoStream, ReadContext, WriteContext}; -use ntex_util::{future::lazy, ready, time::Millis}; +use ntex_util::{ready, time::Millis}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::TcpStream; @@ -59,10 +59,8 @@ struct Write(Rc>); impl ntex_io::AsyncWrite for Write { #[inline] async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) { - match lazy(|cx| flush_io(&mut *self.0.borrow_mut(), &mut buf, cx)).await { - Poll::Ready(res) => (buf, res), - Poll::Pending => (buf, Ok(())), - } + let result = poll_fn(|cx| flush_io(&mut *self.0.borrow_mut(), &mut buf, cx)).await; + (buf, result) } #[inline] @@ -258,10 +256,9 @@ mod unixstream { impl ntex_io::AsyncWrite for Write { #[inline] async fn write(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<()>) { - match lazy(|cx| flush_io(&mut *self.0.borrow_mut(), &mut buf, cx)).await { - Poll::Ready(res) => (buf, res), - Poll::Pending => (buf, Ok(())), - } + let result = + poll_fn(|cx| flush_io(&mut *self.0.borrow_mut(), &mut buf, cx)).await; + (buf, result) } #[inline] diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index 244853b87..96116fa43 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1119,7 +1119,7 @@ mod tests { client.remote_buffer_cap(65536); sleep(Millis(50)).await; - assert_eq!(state.with_write_buf(|buf| buf.len()).unwrap(), 93); + assert_eq!(client.get_remote_buffer_cap(), 0); assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); assert_eq!(num.load(Ordering::Relaxed), 65_536 * 2);