From 1d529fab3c07dab925558941a65047847863853b Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 11 Sep 2024 18:18:45 +0500 Subject: [PATCH] Refactor async io support (#417) --- ntex-async-std/Cargo.toml | 4 +- ntex-async-std/src/io.rs | 611 +++++++----------------------- ntex-compio/CHANGES.md | 4 + ntex-compio/Cargo.toml | 4 +- ntex-compio/src/io.rs | 180 +++------ ntex-glommio/Cargo.toml | 4 +- ntex-glommio/src/io.rs | 631 +++++++------------------------ ntex-io/CHANGES.md | 4 + ntex-io/Cargo.toml | 2 +- ntex-io/src/buf.rs | 32 +- ntex-io/src/filter.rs | 31 +- ntex-io/src/flags.rs | 6 +- ntex-io/src/io.rs | 16 +- ntex-io/src/ioref.rs | 10 - ntex-io/src/lib.rs | 25 +- ntex-io/src/tasks.rs | 526 +++++++++++--------------- ntex-io/src/testing.rs | 347 +++++------------ ntex-net/Cargo.toml | 12 +- ntex-tokio/CHANGES.md | 4 + ntex-tokio/Cargo.toml | 4 +- ntex-tokio/src/io.rs | 763 +++++++++----------------------------- ntex-tokio/src/lib.rs | 2 - ntex-tokio/src/signals.rs | 138 ------- ntex/Cargo.toml | 2 +- 24 files changed, 863 insertions(+), 2499 deletions(-) delete mode 100644 ntex-tokio/src/signals.rs diff --git a/ntex-async-std/Cargo.toml b/ntex-async-std/Cargo.toml index 2cf5111c2..21d9b93bb 100644 --- a/ntex-async-std/Cargo.toml +++ b/ntex-async-std/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-async-std" -version = "0.5.0" +version = "0.5.1" authors = ["ntex contributors "] description = "async-std intergration for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -17,7 +17,7 @@ path = "src/lib.rs" [dependencies] ntex-bytes = "0.1" -ntex-io = "2.0" +ntex-io = "2.5" ntex-util = "2.0" log = "0.4" async-std = { version = "1", features = ["unstable"] } diff --git a/ntex-async-std/src/io.rs b/ntex-async-std/src/io.rs index 8c9a2aa7c..7180aeae5 100644 --- a/ntex-async-std/src/io.rs +++ b/ntex-async-std/src/io.rs @@ -1,18 +1,24 @@ -use std::{any, cell::RefCell, future::Future, io, pin::Pin, task::Context, task::Poll}; +use std::{ + any, cell::RefCell, future::poll_fn, io, pin::Pin, task::ready, task::Context, + task::Poll, +}; -use async_std::io::{Read, Write}; +use async_std::io::{Read as ARead, Write as AWrite}; use ntex_bytes::{Buf, BufMut, BytesVec}; -use ntex_io::{ - types, Handle, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus, -}; -use ntex_util::{ready, time::sleep, time::Sleep}; +use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf}; use crate::TcpStream; impl IoStream for TcpStream { fn start(self, read: ReadContext, write: WriteContext) -> Option> { - async_std::task::spawn_local(ReadTask::new(self.clone(), read)); - async_std::task::spawn_local(WriteTask::new(self.clone(), write)); + let mut rio = Read(RefCell::new(self.clone())); + async_std::task::spawn_local(async move { + read.handle(&mut rio).await; + }); + let mut wio = Write(RefCell::new(self.clone())); + async_std::task::spawn_local(async move { + write.handle(&mut wio).await; + }); Some(Box::new(self)) } } @@ -29,296 +35,111 @@ impl Handle for TcpStream { } /// Read io task -struct ReadTask { - io: RefCell, - state: ReadContext, -} - -impl ReadTask { - /// Create new read io task - fn new(io: TcpStream, state: ReadContext) -> Self { - Self { - state, - io: RefCell::new(io), - } +struct Read(RefCell); + +impl ntex_io::AsyncRead for Read { + async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { + // read data from socket + let result = poll_fn(|cx| { + let mut io = self.0.borrow_mut(); + poll_read_buf(Pin::new(&mut io.0), cx, &mut buf) + }) + .await; + (buf, result) } } -impl Future for ReadTask { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_ref(); +struct Write(RefCell); - match ready!(this.state.poll_ready(cx)) { - ReadStatus::Ready => { - this.state.with_buf(|buf, hw, lw| { - // read data from socket - let mut io = self.io.borrow_mut(); - loop { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - return match poll_read_buf(Pin::new(&mut io.0), cx, buf) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("async-std stream is disconnected"); - Poll::Ready(Ok(())) - } else if buf.len() < hw { - continue; - } else { - Poll::Pending - } - } - Poll::Ready(Err(err)) => { - log::trace!("async-std read task failed on io {:?}", err); - Poll::Ready(Err(err)) - } - }; - } - }) - } - ReadStatus::Terminate => { - log::trace!("read task is instructed to shutdown"); - Poll::Ready(()) +impl ntex_io::AsyncWrite for Write { + #[inline] + async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { + poll_fn(|cx| { + if let Some(mut b) = buf.take() { + let result = flush_io(&mut self.0.borrow_mut().0, &mut b, cx); + buf.set(b); + result + } else { + Poll::Ready(Ok(())) } - } + }) + .await } -} - -#[derive(Debug)] -enum IoWriteState { - Processing(Option), - Shutdown(Sleep, Shutdown), -} - -#[derive(Debug)] -enum Shutdown { - None, - Stopping(u16), -} -/// Write io task -struct WriteTask { - st: IoWriteState, - io: TcpStream, - state: WriteContext, -} - -impl WriteTask { - /// Create new write io task - fn new(io: TcpStream, state: WriteContext) -> Self { - Self { - io, - state, - st: IoWriteState::Processing(None), - } + #[inline] + async fn flush(&mut self) -> io::Result<()> { + Ok(()) } -} - -impl Future for WriteTask { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().get_mut(); - match this.st { - IoWriteState::Processing(ref mut delay) => { - match this.state.poll_ready(cx) { - Poll::Ready(WriteStatus::Ready) => { - if let Some(delay) = delay { - if delay.poll_elapsed(cx).is_ready() { - this.state.close(Some(io::Error::new( - io::ErrorKind::TimedOut, - "Operation timedout", - ))); - return Poll::Ready(()); - } - } - - // flush io stream - let io = &mut this.io.0; - match ready!(this.state.with_buf(|buf| flush_io(io, buf, cx))) { - Ok(()) => Poll::Pending, - Err(e) => { - this.state.close(Some(e)); - Poll::Ready(()) - } - } - } - Poll::Ready(WriteStatus::Timeout(time)) => { - log::trace!("initiate timeout delay for {:?}", time); - if delay.is_none() { - *delay = Some(sleep(time)); - } - self.poll(cx) - } - Poll::Ready(WriteStatus::Shutdown(time)) => { - log::trace!("write task is instructed to shutdown"); - - let timeout = if let Some(delay) = delay.take() { - delay - } else { - sleep(time) - }; - - this.st = IoWriteState::Shutdown(timeout, Shutdown::None); - self.poll(cx) - } - Poll::Ready(WriteStatus::Terminate) => { - log::trace!("write task is instructed to terminate"); - - let _ = Pin::new(&mut this.io.0).poll_close(cx); - this.state.close(None); - Poll::Ready(()) - } - Poll::Pending => Poll::Pending, - } - } - IoWriteState::Shutdown(ref mut delay, ref mut st) => { - // close WRITE side and wait for disconnect on read side. - // use disconnect timeout, otherwise it could hang forever. - loop { - match st { - Shutdown::None => { - // flush write buffer - let io = &mut this.io.0; - match this.state.with_buf(|buf| flush_io(io, buf, cx)) { - Poll::Ready(Ok(())) => { - if let Err(e) = - this.io.0.shutdown(std::net::Shutdown::Write) - { - this.state.close(Some(e)); - return Poll::Ready(()); - } - *st = Shutdown::Stopping(0); - continue; - } - Poll::Ready(Err(err)) => { - log::trace!( - "write task is closed with err during flush, {:?}", - err - ); - this.state.close(Some(err)); - return Poll::Ready(()); - } - Poll::Pending => (), - } - } - Shutdown::Stopping(ref mut count) => { - // read until 0 or err - let mut buf = [0u8; 512]; - let io = &mut this.io; - loop { - match Pin::new(&mut io.0).poll_read(cx, &mut buf) { - Poll::Ready(Err(e)) => { - log::trace!("write task is stopped"); - this.state.close(Some(e)); - return Poll::Ready(()); - } - Poll::Ready(Ok(0)) => { - log::trace!("async-std socket is disconnected"); - this.state.close(None); - return Poll::Ready(()); - } - Poll::Ready(Ok(n)) => { - *count += n as u16; - if *count > 4096 { - log::trace!( - "write task is stopped, too much input" - ); - this.state.close(None); - return Poll::Ready(()); - } - } - Poll::Pending => break, - } - } - } - } - - // disconnect timeout - if delay.poll_elapsed(cx).is_pending() { - return Poll::Pending; - } - log::trace!("write task is stopped after delay"); - this.state.close(None); - let _ = Pin::new(&mut this.io.0).poll_close(cx); - return Poll::Ready(()); - } - } - } + #[inline] + async fn shutdown(&mut self) -> io::Result<()> { + self.0.borrow().0.shutdown(std::net::Shutdown::Both) } } /// Flush write buffer to underlying I/O stream. -pub(super) fn flush_io( +pub(super) fn flush_io( io: &mut T, - buf: &mut Option, + buf: &mut BytesVec, cx: &mut Context<'_>, ) -> Poll> { - if let Some(buf) = buf { - let len = buf.len(); - - if len != 0 { - // log::trace!("flushing framed transport: {:?}", buf.len()); - - let mut written = 0; - let result = loop { - break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("Disconnected during flush, written {}", written); - Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ))) + let len = buf.len(); + + if len != 0 { + // log::trace!("flushing framed transport: {:?}", buf.len()); + + let mut written = 0; + let result = loop { + break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!("Disconnected during flush, written {}", written); + Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write frame to transport", + ))) + } else { + written += n; + if written == len { + buf.clear(); + Poll::Ready(Ok(())) } else { - written += n; - if written == len { - buf.clear(); - Poll::Ready(Ok(())) - } else { - continue; - } + continue; } } - Poll::Pending => { - // remove written data - buf.advance(written); - Poll::Pending - } - Poll::Ready(Err(e)) => { - log::trace!("Error during flush: {}", e); - Poll::Ready(Err(e)) - } - }; - }; - // log::trace!("flushed {} bytes", written); - - // flush - return if written > 0 { - match Pin::new(&mut *io).poll_flush(cx) { - Poll::Ready(Ok(_)) => result, - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - log::trace!("error during flush: {}", e); - Poll::Ready(Err(e)) - } } - } else { - result + Poll::Pending => { + // remove written data + buf.advance(written); + Poll::Pending + } + Poll::Ready(Err(e)) => { + log::trace!("Error during flush: {}", e); + Poll::Ready(Err(e)) + } }; + }; + // log::trace!("flushed {} bytes", written); + + // flush + if written > 0 { + match Pin::new(&mut *io).poll_flush(cx) { + Poll::Ready(Ok(_)) => result, + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { + log::trace!("error during flush: {}", e); + Poll::Ready(Err(e)) + } + } + } else { + result } + } else { + Poll::Ready(Ok(())) } - Poll::Ready(Ok(())) } -pub fn poll_read_buf( +pub fn poll_read_buf( io: Pin<&mut T>, cx: &mut Context<'_>, buf: &mut BytesVec, @@ -342,226 +163,58 @@ mod unixstream { impl IoStream for UnixStream { fn start(self, read: ReadContext, write: WriteContext) -> Option> { - async_std::task::spawn_local(ReadTask::new(self.clone(), read)); - async_std::task::spawn_local(WriteTask::new(self, write)); + let mut rio = Read(RefCell::new(self.clone())); + async_std::task::spawn_local(async move { + read.handle(&mut rio).await; + }); + let mut wio = Write(RefCell::new(self)); + async_std::task::spawn_local(async move { + write.handle(&mut wio).await; + }); None } } /// Read io task - struct ReadTask { - io: RefCell, - state: ReadContext, - } - - impl ReadTask { - /// Create new read io task - fn new(io: UnixStream, state: ReadContext) -> Self { - Self { - state, - io: RefCell::new(io), - } + struct Read(RefCell); + + impl ntex_io::AsyncRead for Read { + async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { + // read data from socket + let result = poll_fn(|cx| { + let mut io = self.0.borrow_mut(); + poll_read_buf(Pin::new(&mut io.0), cx, &mut buf) + }) + .await; + (buf, result) } } - impl Future for ReadTask { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_ref(); - - this.state.with_buf(|buf, hw, lw| { - match ready!(this.state.poll_ready(cx)) { - ReadStatus::Ready => { - // read data from socket - let mut io = this.io.borrow_mut(); - loop { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - return match poll_read_buf(Pin::new(&mut io.0), cx, buf) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("async-std stream is disconnected"); - Poll::Ready(Ok(())) - } else if buf.len() < hw { - continue; - } else { - Poll::Pending - } - } - Poll::Ready(Err(err)) => { - log::trace!("read task failed on io {:?}", err); - Poll::Ready(Err(err)) - } - }; - } - } - ReadStatus::Terminate => { - log::trace!("read task is instructed to shutdown"); - Poll::Ready(Ok(())) - } + struct Write(RefCell); + + impl ntex_io::AsyncWrite for Write { + #[inline] + async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { + poll_fn(|cx| { + if let Some(mut b) = buf.take() { + let result = flush_io(&mut self.0.borrow_mut().0, &mut b, cx); + buf.set(b); + result + } else { + Poll::Ready(Ok(())) } }) + .await } - } - - /// Write io task - struct WriteTask { - st: IoWriteState, - io: UnixStream, - state: WriteContext, - } - impl WriteTask { - /// Create new write io task - fn new(io: UnixStream, state: WriteContext) -> Self { - Self { - io, - state, - st: IoWriteState::Processing(None), - } + #[inline] + async fn flush(&mut self) -> io::Result<()> { + Ok(()) } - } - - impl Future for WriteTask { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().get_mut(); - - match this.st { - IoWriteState::Processing(ref mut delay) => { - match this.state.poll_ready(cx) { - Poll::Ready(WriteStatus::Ready) => { - if let Some(delay) = delay { - if delay.poll_elapsed(cx).is_ready() { - this.state.close(Some(io::Error::new( - io::ErrorKind::TimedOut, - "Operation timedout", - ))); - return Poll::Ready(()); - } - } - - // flush io stream - let io = &mut this.io.0; - match ready!(this.state.with_buf(|buf| flush_io(io, buf, cx))) { - Ok(()) => Poll::Pending, - Err(e) => { - this.state.close(Some(e)); - Poll::Ready(()) - } - } - } - Poll::Ready(WriteStatus::Timeout(time)) => { - log::trace!("initiate timeout delay for {:?}", time); - if delay.is_none() { - *delay = Some(sleep(time)); - } - self.poll(cx) - } - Poll::Ready(WriteStatus::Shutdown(time)) => { - log::trace!("write task is instructed to shutdown"); - - let timeout = if let Some(delay) = delay.take() { - delay - } else { - sleep(time) - }; - this.st = IoWriteState::Shutdown(timeout, Shutdown::None); - self.poll(cx) - } - Poll::Ready(WriteStatus::Terminate) => { - log::trace!("write task is instructed to terminate"); - - let _ = Pin::new(&mut this.io.0).poll_close(cx); - this.state.close(None); - Poll::Ready(()) - } - Poll::Pending => Poll::Pending, - } - } - IoWriteState::Shutdown(ref mut delay, ref mut st) => { - // close WRITE side and wait for disconnect on read side. - // use disconnect timeout, otherwise it could hang forever. - loop { - match st { - Shutdown::None => { - // flush write buffer - let io = &mut this.io.0; - match this.state.with_buf(|buf| flush_io(io, buf, cx)) { - Poll::Ready(Ok(())) => { - if let Err(e) = - this.io.0.shutdown(std::net::Shutdown::Write) - { - this.state.close(Some(e)); - return Poll::Ready(()); - } - *st = Shutdown::Stopping(0); - continue; - } - Poll::Ready(Err(err)) => { - log::trace!( - "write task is closed with err during flush, {:?}", - err - ); - this.state.close(Some(err)); - return Poll::Ready(()); - } - Poll::Pending => (), - } - } - Shutdown::Stopping(ref mut count) => { - // read until 0 or err - let mut buf = [0u8; 512]; - let io = &mut this.io; - loop { - match Pin::new(&mut io.0).poll_read(cx, &mut buf) { - Poll::Ready(Err(e)) => { - log::trace!("write task is stopped"); - this.state.close(Some(e)); - return Poll::Ready(()); - } - Poll::Ready(Ok(0)) => { - log::trace!( - "async-std unix socket is disconnected" - ); - this.state.close(None); - return Poll::Ready(()); - } - Poll::Ready(Ok(n)) => { - *count += n as u16; - if *count > 4096 { - log::trace!( - "write task is stopped, too much input" - ); - this.state.close(None); - return Poll::Ready(()); - } - } - Poll::Pending => break, - } - } - } - } - - // disconnect timeout - if delay.poll_elapsed(cx).is_pending() { - return Poll::Pending; - } - log::trace!("write task is stopped after delay"); - this.state.close(None); - let _ = Pin::new(&mut this.io.0).poll_close(cx); - return Poll::Ready(()); - } - } - } + #[inline] + async fn shutdown(&mut self) -> io::Result<()> { + self.0.borrow().0.shutdown(std::net::Shutdown::Both) } } } diff --git a/ntex-compio/CHANGES.md b/ntex-compio/CHANGES.md index 46374448e..e4b381bcb 100644 --- a/ntex-compio/CHANGES.md +++ b/ntex-compio/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.1.2] - 2024-09-11 + +* Use new io api + ## [0.1.1] - 2024-09-05 * Tune write task diff --git a/ntex-compio/Cargo.toml b/ntex-compio/Cargo.toml index 533311a7f..1b6360705 100644 --- a/ntex-compio/Cargo.toml +++ b/ntex-compio/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-compio" -version = "0.1.1" +version = "0.1.2" authors = ["ntex contributors "] description = "compio runtime intergration for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -18,7 +18,7 @@ path = "src/lib.rs" [dependencies] ntex-bytes = "0.1" -ntex-io = "2.3" +ntex-io = "2.5" ntex-util = "2" log = "0.4" compio-net = "0.4.1" diff --git a/ntex-compio/src/io.rs b/ntex-compio/src/io.rs index aaabbe7db..c27e67595 100644 --- a/ntex-compio/src/io.rs +++ b/ntex-compio/src/io.rs @@ -4,17 +4,13 @@ use compio::buf::{BufResult, IoBuf, IoBufMut, SetBufInit}; use compio::io::{AsyncRead, AsyncWrite}; use compio::net::TcpStream; use ntex_bytes::{Buf, BufMut, BytesVec}; -use ntex_io::{ - types, Handle, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus, -}; -use ntex_util::{future::select, future::Either, time::sleep}; +use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf}; impl IoStream for crate::TcpStream { fn start(self, read: ReadContext, write: WriteContext) -> Option> { - let mut io = self.0.clone(); + let io = self.0.clone(); compio::runtime::spawn(async move { - run(&mut io, &read, write).await; - + run(io.clone(), &read, write).await; match io.close().await { Ok(_) => log::debug!("{} Stream is closed", read.tag()), Err(e) => log::error!("{} Stream is closed, {:?}", read.tag(), e), @@ -29,11 +25,9 @@ impl IoStream for crate::TcpStream { #[cfg(unix)] impl IoStream for crate::UnixStream { fn start(self, read: ReadContext, write: WriteContext) -> Option> { - let mut io = self.0; compio::runtime::spawn(async move { - run(&mut io, &read, write).await; - - match io.close().await { + run(self.0.clone(), &read, write).await; + match self.0.close().await { Ok(_) => log::debug!("{} Unix stream is closed", read.tag()), Err(e) => log::error!("{} Unix stream is closed, {:?}", read.tag(), e), } @@ -89,17 +83,18 @@ impl SetBufInit for CompioBuf { } async fn run( - io: &mut T, + io: T, read: &ReadContext, write: WriteContext, ) { - let mut wr_io = io.clone(); + let mut wr_io = WriteIo(io.clone()); let wr_task = compio::runtime::spawn(async move { - write_task(&mut wr_io, &write).await; + write.handle(&mut wr_io).await; log::debug!("{} Write task is stopped", write.tag()); }); + let mut io = ReadIo(io); - read_task(io, read).await; + read.handle(&mut io).await; log::debug!("{} Read task is stopped", read.tag()); if !wr_task.is_finished() { @@ -107,142 +102,63 @@ async fn run( } } -/// Read io task -async fn read_task(io: &mut T, state: &ReadContext) { - loop { - match state.ready().await { - ReadStatus::Ready => { - let result = state - .with_buf_async(|buf| async { - let BufResult(result, buf) = - match select(io.read(CompioBuf(buf)), state.wait_for_close()) - .await - { - Either::Left(res) => res, - Either::Right(_) => return (Default::default(), Ok(1)), - }; - - match result { - Ok(n) => { - if n == 0 { - log::trace!( - "{}: Tcp stream is disconnected", - state.tag() - ); - } - (buf.0, Ok(n)) - } - Err(err) => { - log::trace!( - "{}: Read task failed on io {:?}", - state.tag(), - err - ); - (buf.0, Err(err)) - } - } - }) - .await; +struct ReadIo(T); - if result.is_ready() { - break; - } - } - ReadStatus::Terminate => { - log::trace!("{}: Read task is instructed to shutdown", state.tag()); - break; - } - } +impl ntex_io::AsyncRead for ReadIo +where + T: AsyncRead, +{ + #[inline] + async fn read(&mut self, buf: BytesVec) -> (BytesVec, io::Result) { + let BufResult(result, buf) = self.0.read(CompioBuf(buf)).await; + (buf.0, result) } } -/// Write io task -async fn write_task(mut io: T, state: &WriteContext) { - let mut delay = None; - - loop { - let result = if let Some(ref mut sleep) = delay { - let result = match select(sleep, state.ready()).await { - Either::Left(_) => { - state.close(Some(io::Error::new( - io::ErrorKind::TimedOut, - "Operation timedout", - ))); - return; - } - Either::Right(res) => res, - }; - delay = None; - result - } else { - state.ready().await - }; - - match result { - WriteStatus::Ready => { - // write io stream - match write(&mut io, state).await { - Ok(()) => continue, - Err(e) => { - state.close(Some(e)); - } - } - } - WriteStatus::Timeout(time) => { - log::trace!("{}: Initiate timeout delay for {:?}", state.tag(), time); - delay = Some(sleep(time)); - continue; - } - WriteStatus::Shutdown(time) => { - log::trace!("{}: Write task is instructed to shutdown", state.tag()); - - let fut = async { - write(&mut io, state).await?; - io.flush().await?; - io.shutdown().await?; - Ok(()) - }; - match select(sleep(time), fut).await { - Either::Left(_) => state.close(None), - Either::Right(res) => state.close(res.err()), - } - } - WriteStatus::Terminate => { - log::trace!("{}: Write task is instructed to terminate", state.tag()); - state.close(io.shutdown().await.err()); - } - } - break; - } -} +struct WriteIo(T); -// write to io stream -async fn write(io: &mut T, state: &WriteContext) -> io::Result<()> { - state - .with_buf_async(|buf| async { - let mut buf = CompioBuf(buf); +impl ntex_io::AsyncWrite for WriteIo +where + T: AsyncWrite, +{ + #[inline] + async fn write(&mut self, wbuf: &mut WriteContextBuf) -> io::Result<()> { + if let Some(b) = wbuf.take() { + let mut buf = CompioBuf(b); loop { - let BufResult(result, buf1) = io.write(buf).await; + let BufResult(result, buf1) = self.0.write(buf).await; buf = buf1; - return match result { + let result = match result { Ok(0) => Err(io::Error::new( io::ErrorKind::WriteZero, "failed to write frame to transport", )), Ok(size) => { - if buf.0.len() == size { - // return io.flush().await; - state.memory_pool().release_write_buf(buf.0); + buf.0.advance(size); + if buf.0.is_empty() { Ok(()) } else { - buf.0.advance(size); continue; } } Err(e) => Err(e), }; + wbuf.set(buf.0); + return result; } - }) - .await + } else { + Ok(()) + } + } + + #[inline] + async fn flush(&mut self) -> io::Result<()> { + self.0.flush().await + } + + #[inline] + async fn shutdown(&mut self) -> io::Result<()> { + self.0.shutdown().await + } } diff --git a/ntex-glommio/Cargo.toml b/ntex-glommio/Cargo.toml index edcd7febf..cbb9b5bf9 100644 --- a/ntex-glommio/Cargo.toml +++ b/ntex-glommio/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-glommio" -version = "0.5.0" +version = "0.5.1" authors = ["ntex contributors "] description = "glommio intergration for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -17,7 +17,7 @@ path = "src/lib.rs" [dependencies] ntex-bytes = "0.1" -ntex-io = "2.0" +ntex-io = "2.5" ntex-util = "2.0" futures-lite = "2.2" log = "0.4" diff --git a/ntex-glommio/src/io.rs b/ntex-glommio/src/io.rs index 7d60c0de3..09fc0616d 100644 --- a/ntex-glommio/src/io.rs +++ b/ntex-glommio/src/io.rs @@ -1,28 +1,30 @@ -use std::task::{Context, Poll}; -use std::{any, future::Future, io, pin::Pin}; +use std::{any, future::poll_fn, io, pin::Pin, task::ready, task::Context, task::Poll}; -use futures_lite::future::FutureExt; use futures_lite::io::{AsyncRead, AsyncWrite}; use ntex_bytes::{Buf, BufMut, BytesVec}; -use ntex_io::{ - types, Handle, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus, -}; -use ntex_util::{ready, time::sleep, time::Sleep}; +use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf}; use crate::net_impl::{TcpStream, UnixStream}; impl IoStream for TcpStream { fn start(self, read: ReadContext, write: WriteContext) -> Option> { - glommio::spawn_local(ReadTask::new(self.clone(), read)).detach(); - glommio::spawn_local(WriteTask::new(self.clone(), write)).detach(); + let mut rio = Read(self.clone()); + glommio::spawn_local(async move { read.handle(&mut rio).await }).detach(); + let mut wio = Write(self.clone()); + glommio::spawn_local(async move { write.handle(&mut wio).await }).detach(); Some(Box::new(self)) } } impl IoStream for UnixStream { fn start(self, read: ReadContext, write: WriteContext) -> Option> { - glommio::spawn_local(UnixReadTask::new(self.clone(), read)).detach(); - glommio::spawn_local(UnixWriteTask::new(self, write)).detach(); + let mut rio = UnixRead(self.clone()); + glommio::spawn_local(async move { + read.handle(&mut rio).await; + }) + .detach(); + let mut wio = UnixWrite(self); + glommio::spawn_local(async move { write.handle(&mut wio).await }).detach(); None } } @@ -39,306 +41,150 @@ impl Handle for TcpStream { } /// Read io task -struct ReadTask { - io: TcpStream, - state: ReadContext, -} - -impl ReadTask { - /// Create new read io task - fn new(io: TcpStream, state: ReadContext) -> Self { - Self { io, state } +struct Read(TcpStream); + +impl ntex_io::AsyncRead for Read { + async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { + // read data from socket + let result = poll_fn(|cx| { + let mut io = self.0 .0.borrow_mut(); + poll_read_buf(Pin::new(&mut *io), cx, &mut buf) + }) + .await; + (buf, result) } } -impl Future for ReadTask { - type Output = (); +struct Write(TcpStream); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut(); - - this.state.with_buf(|buf, hw, lw| { - match ready!(this.state.poll_ready(cx)) { - ReadStatus::Ready => { - // read data from socket - loop { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - return match poll_read_buf( - Pin::new(&mut *this.io.0.borrow_mut()), - cx, - buf, - ) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("glommio stream is disconnected"); - Poll::Ready(Ok(())) - } else if buf.len() < hw { - continue; - } else { - Poll::Pending - } - } - Poll::Ready(Err(err)) => { - log::trace!("read task failed on io {:?}", err); - Poll::Ready(Err(err)) - } - }; - } - } - ReadStatus::Terminate => { - log::trace!("read task is instructed to shutdown"); - Poll::Ready(Ok(())) - } +impl ntex_io::AsyncWrite for Write { + #[inline] + async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { + poll_fn(|cx| { + if let Some(mut b) = buf.take() { + let result = flush_io(&mut *self.0 .0.borrow_mut(), &mut b, cx); + buf.set(b); + result + } else { + Poll::Ready(Ok(())) } }) + .await } -} -enum IoWriteState { - Processing(Option), - Shutdown(Sleep, Shutdown), -} + #[inline] + async fn flush(&mut self) -> io::Result<()> { + Ok(()) + } -enum Shutdown { - Flush, - Close(Pin>>>), - Stopping(u16), + #[inline] + async fn shutdown(&mut self) -> io::Result<()> { + poll_fn(|cx| Pin::new(&mut *self.0 .0.borrow_mut()).poll_close(cx)).await + } } -/// Write io task -struct WriteTask { - st: IoWriteState, - io: TcpStream, - state: WriteContext, -} +struct UnixRead(UnixStream); -impl WriteTask { - /// Create new write io task - fn new(io: TcpStream, state: WriteContext) -> Self { - Self { - io, - state, - st: IoWriteState::Processing(None), - } +impl ntex_io::AsyncRead for UnixRead { + async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { + // read data from socket + let result = poll_fn(|cx| { + let mut io = self.0 .0.borrow_mut(); + poll_read_buf(Pin::new(&mut *io), cx, &mut buf) + }) + .await; + (buf, result) } } -impl Future for WriteTask { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().get_mut(); - - match this.st { - IoWriteState::Processing(ref mut delay) => { - match this.state.poll_ready(cx) { - Poll::Ready(WriteStatus::Ready) => { - if let Some(delay) = delay { - if delay.poll_elapsed(cx).is_ready() { - this.state.close(Some(io::Error::new( - io::ErrorKind::TimedOut, - "Operation timedout", - ))); - return Poll::Ready(()); - } - } - - // flush io stream - match ready!(this.state.with_buf(|buf| flush_io( - &mut *this.io.0.borrow_mut(), - buf, - cx - ))) { - Ok(()) => Poll::Pending, - Err(e) => { - this.state.close(Some(e)); - Poll::Ready(()) - } - } - } - Poll::Ready(WriteStatus::Timeout(time)) => { - log::trace!("initiate timeout delay for {:?}", time); - if delay.is_none() { - *delay = Some(sleep(time)); - } - self.poll(cx) - } - Poll::Ready(WriteStatus::Shutdown(time)) => { - log::trace!("write task is instructed to shutdown"); - - let timeout = if let Some(delay) = delay.take() { - delay - } else { - sleep(time) - }; - - this.st = IoWriteState::Shutdown(timeout, Shutdown::Flush); - self.poll(cx) - } - Poll::Ready(WriteStatus::Terminate) => { - log::trace!("write task is instructed to terminate"); +struct UnixWrite(UnixStream); - let _ = Pin::new(&mut *this.io.0.borrow_mut()).poll_close(cx); - this.state.close(None); - Poll::Ready(()) - } - Poll::Pending => Poll::Pending, - } +impl ntex_io::AsyncWrite for UnixWrite { + #[inline] + async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { + poll_fn(|cx| { + if let Some(mut b) = buf.take() { + let result = flush_io(&mut *self.0 .0.borrow_mut(), &mut b, cx); + buf.set(b); + result + } else { + Poll::Ready(Ok(())) } - IoWriteState::Shutdown(ref mut delay, ref mut st) => { - // close WRITE side and wait for disconnect on read side. - // use disconnect timeout, otherwise it could hang forever. - loop { - match st { - Shutdown::Flush => { - // flush write buffer - let mut io = this.io.0.borrow_mut(); - match this.state.with_buf(|buf| flush_io(&mut *io, buf, cx)) { - Poll::Ready(Ok(())) => { - let io = this.io.clone(); - #[allow(clippy::await_holding_refcell_ref)] - let fut = Box::pin(async move { - io.0.borrow() - .shutdown(std::net::Shutdown::Write) - .await - }); - *st = Shutdown::Close(fut); - continue; - } - Poll::Ready(Err(err)) => { - log::trace!( - "write task is closed with err during flush, {:?}", - err - ); - this.state.close(Some(err)); - return Poll::Ready(()); - } - Poll::Pending => (), - } - } - Shutdown::Close(ref mut fut) => { - if ready!(fut.poll(cx)).is_err() { - this.state.close(None); - return Poll::Ready(()); - } - *st = Shutdown::Stopping(0); - continue; - } - Shutdown::Stopping(ref mut count) => { - // read until 0 or err - let mut buf = [0u8; 512]; - let io = &mut this.io; - loop { - match Pin::new(&mut *io.0.borrow_mut()) - .poll_read(cx, &mut buf) - { - Poll::Ready(Err(e)) => { - log::trace!("write task is stopped"); - this.state.close(Some(e)); - return Poll::Ready(()); - } - Poll::Ready(Ok(0)) => { - log::trace!("glommio socket is disconnected"); - this.state.close(None); - return Poll::Ready(()); - } - Poll::Ready(Ok(n)) => { - *count += n as u16; - if *count > 4096 { - log::trace!( - "write task is stopped, too much input" - ); - this.state.close(None); - return Poll::Ready(()); - } - } - Poll::Pending => break, - } - } - } - } + }) + .await + } - // disconnect timeout - if delay.poll_elapsed(cx).is_pending() { - return Poll::Pending; - } - log::trace!("write task is stopped after delay"); - this.state.close(None); - let _ = Pin::new(&mut *this.io.0.borrow_mut()).poll_close(cx); - return Poll::Ready(()); - } - } - } + #[inline] + async fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + + #[inline] + async fn shutdown(&mut self) -> io::Result<()> { + poll_fn(|cx| Pin::new(&mut *self.0 .0.borrow_mut()).poll_close(cx)).await } } /// Flush write buffer to underlying I/O stream. pub(super) fn flush_io( io: &mut T, - buf: &mut Option, + buf: &mut BytesVec, cx: &mut Context<'_>, ) -> Poll> { - if let Some(buf) = buf { - let len = buf.len(); - - if len != 0 { - // log::trace!("flushing framed transport: {:?}", buf.len()); - - let mut written = 0; - let result = loop { - break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("Disconnected during flush, written {}", written); - Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ))) + let len = buf.len(); + + if len != 0 { + // log::trace!("flushing framed transport: {:?}", buf.len()); + + let mut written = 0; + let result = loop { + break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!("Disconnected during flush, written {}", written); + Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write frame to transport", + ))) + } else { + written += n; + if written == len { + buf.clear(); + Poll::Ready(Ok(())) } else { - written += n; - if written == len { - buf.clear(); - Poll::Ready(Ok(())) - } else { - continue; - } + continue; } } - Poll::Pending => { - // remove written data - buf.advance(written); - Poll::Pending - } - Poll::Ready(Err(e)) => { - log::trace!("Error during flush: {}", e); - Poll::Ready(Err(e)) - } - }; - }; - log::trace!("flushed {} bytes", written); - - // flush - return if written > 0 { - match Pin::new(&mut *io).poll_flush(cx) { - Poll::Ready(Ok(_)) => result, - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - log::trace!("error during flush: {}", e); - Poll::Ready(Err(e)) - } } - } else { - result + Poll::Pending => { + // remove written data + buf.advance(written); + Poll::Pending + } + Poll::Ready(Err(e)) => { + log::trace!("Error during flush: {}", e); + Poll::Ready(Err(e)) + } }; + }; + // log::trace!("flushed {} bytes", written); + + // flush + if written > 0 { + match Pin::new(&mut *io).poll_flush(cx) { + Poll::Ready(Ok(_)) => result, + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { + log::trace!("error during flush: {}", e); + Poll::Ready(Err(e)) + } + } + } else { + result } + } else { + Poll::Ready(Ok(())) } - Poll::Ready(Ok(())) } pub fn poll_read_buf( @@ -357,232 +203,3 @@ pub fn poll_read_buf( Poll::Ready(Ok(n)) } - -/// Read io task -struct UnixReadTask { - io: UnixStream, - state: ReadContext, -} - -impl UnixReadTask { - /// Create new read io task - fn new(io: UnixStream, state: ReadContext) -> Self { - Self { io, state } - } -} - -impl Future for UnixReadTask { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut(); - - this.state.with_buf(|buf, hw, lw| { - match ready!(this.state.poll_ready(cx)) { - ReadStatus::Ready => { - // read data from socket - loop { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - return match poll_read_buf( - Pin::new(&mut *this.io.0.borrow_mut()), - cx, - buf, - ) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("glommio stream is disconnected"); - Poll::Ready(Ok(())) - } else if buf.len() < hw { - continue; - } else { - Poll::Pending - } - } - Poll::Ready(Err(err)) => { - log::trace!("read task failed on io {:?}", err); - Poll::Ready(Err(err)) - } - }; - } - } - ReadStatus::Terminate => { - log::trace!("read task is instructed to shutdown"); - Poll::Ready(Ok(())) - } - } - }) - } -} - -/// Write io task -struct UnixWriteTask { - st: IoWriteState, - io: UnixStream, - state: WriteContext, -} - -impl UnixWriteTask { - /// Create new write io task - fn new(io: UnixStream, state: WriteContext) -> Self { - Self { - io, - state, - st: IoWriteState::Processing(None), - } - } -} - -impl Future for UnixWriteTask { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().get_mut(); - - match this.st { - IoWriteState::Processing(ref mut delay) => { - match this.state.poll_ready(cx) { - Poll::Ready(WriteStatus::Ready) => { - if let Some(delay) = delay { - if delay.poll_elapsed(cx).is_ready() { - this.state.close(Some(io::Error::new( - io::ErrorKind::TimedOut, - "Operation timedout", - ))); - return Poll::Ready(()); - } - } - - // flush io stream - match ready!(this.state.with_buf(|buf| flush_io( - &mut *this.io.0.borrow_mut(), - buf, - cx - ))) { - Ok(()) => Poll::Pending, - Err(e) => { - this.state.close(Some(e)); - Poll::Ready(()) - } - } - } - Poll::Ready(WriteStatus::Timeout(time)) => { - log::trace!("initiate timeout delay for {:?}", time); - if delay.is_none() { - *delay = Some(sleep(time)); - } - self.poll(cx) - } - Poll::Ready(WriteStatus::Shutdown(time)) => { - log::trace!("write task is instructed to shutdown"); - - let timeout = if let Some(delay) = delay.take() { - delay - } else { - sleep(time) - }; - - this.st = IoWriteState::Shutdown(timeout, Shutdown::Flush); - self.poll(cx) - } - Poll::Ready(WriteStatus::Terminate) => { - log::trace!("write task is instructed to terminate"); - - let _ = Pin::new(&mut *this.io.0.borrow_mut()).poll_close(cx); - this.state.close(None); - Poll::Ready(()) - } - Poll::Pending => Poll::Pending, - } - } - IoWriteState::Shutdown(ref mut delay, ref mut st) => { - // close WRITE side and wait for disconnect on read side. - // use disconnect timeout, otherwise it could hang forever. - loop { - match st { - Shutdown::Flush => { - // flush write buffer - let mut io = this.io.0.borrow_mut(); - match this.state.with_buf(|buf| flush_io(&mut *io, buf, cx)) { - Poll::Ready(Ok(())) => { - let io = this.io.clone(); - #[allow(clippy::await_holding_refcell_ref)] - let fut = Box::pin(async move { - io.0.borrow() - .shutdown(std::net::Shutdown::Write) - .await - }); - *st = Shutdown::Close(fut); - continue; - } - Poll::Ready(Err(err)) => { - log::trace!( - "write task is closed with err during flush, {:?}", - err - ); - this.state.close(Some(err)); - return Poll::Ready(()); - } - Poll::Pending => (), - } - } - Shutdown::Close(ref mut fut) => { - if ready!(fut.poll(cx)).is_err() { - this.state.close(None); - return Poll::Ready(()); - } - *st = Shutdown::Stopping(0); - continue; - } - Shutdown::Stopping(ref mut count) => { - // read until 0 or err - let mut buf = [0u8; 512]; - let io = &mut this.io; - loop { - match Pin::new(&mut *io.0.borrow_mut()) - .poll_read(cx, &mut buf) - { - Poll::Ready(Err(e)) => { - log::trace!("write task is stopped"); - this.state.close(Some(e)); - return Poll::Ready(()); - } - Poll::Ready(Ok(0)) => { - log::trace!("glommio unix socket is disconnected"); - this.state.close(None); - return Poll::Ready(()); - } - Poll::Ready(Ok(n)) => { - *count += n as u16; - if *count > 4096 { - log::trace!( - "write task is stopped, too much input" - ); - this.state.close(None); - return Poll::Ready(()); - } - } - Poll::Pending => break, - } - } - } - } - - // disconnect timeout - if delay.poll_elapsed(cx).is_pending() { - return Poll::Pending; - } - log::trace!("write task is stopped after delay"); - this.state.close(None); - let _ = Pin::new(&mut *this.io.0.borrow_mut()).poll_close(cx); - return Poll::Ready(()); - } - } - } - } -} diff --git a/ntex-io/CHANGES.md b/ntex-io/CHANGES.md index 4f296504c..9c32c3aca 100644 --- a/ntex-io/CHANGES.md +++ b/ntex-io/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [2.5.0] - 2024-09-10 + +* Refactor async io support + ## [2.3.1] - 2024-09-05 * Tune async io tasks support diff --git a/ntex-io/Cargo.toml b/ntex-io/Cargo.toml index 798cd3396..5c088194a 100644 --- a/ntex-io/Cargo.toml +++ b/ntex-io/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-io" -version = "2.4.0" +version = "2.5.0" authors = ["ntex contributors "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] diff --git a/ntex-io/src/buf.rs b/ntex-io/src/buf.rs index 478442efb..7d4624f05 100644 --- a/ntex-io/src/buf.rs +++ b/ntex-io/src/buf.rs @@ -152,27 +152,6 @@ impl Stack { } } - pub(crate) fn with_read_source(&self, io: &IoRef, f: F) -> R - where - F: FnOnce(&mut BytesVec) -> R, - { - let item = self.get_last_level(); - let mut rb = item.0.take(); - if rb.is_none() { - rb = Some(io.memory_pool().get_read_buf()); - } - - let result = f(rb.as_mut().unwrap()); - if let Some(b) = rb { - if b.is_empty() { - io.memory_pool().release_read_buf(b); - } else { - item.0.set(Some(b)); - } - } - result - } - pub(crate) fn with_read_destination(&self, io: &IoRef, f: F) -> R where F: FnOnce(&mut BytesVec) -> R, @@ -226,6 +205,17 @@ impl Stack { self.get_last_level().1.take() } + pub(crate) fn set_write_destination(&self, buf: BytesVec) -> Option { + let b = self.get_last_level().1.take(); + if b.is_some() { + self.get_last_level().1.set(b); + Some(buf) + } else { + self.get_last_level().1.set(Some(buf)); + None + } + } + pub(crate) fn with_write_destination(&self, io: &IoRef, f: F) -> R where F: FnOnce(&mut Option) -> R, diff --git a/ntex-io/src/filter.rs b/ntex-io/src/filter.rs index f74e057c2..78ed7520a 100644 --- a/ntex-io/src/filter.rs +++ b/ntex-io/src/filter.rs @@ -93,26 +93,16 @@ impl Filter for Base { #[inline] fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { - let mut flags = self.0.flags(); + let flags = self.0.flags(); - if flags.contains(Flags::IO_STOPPED) { + if flags.is_stopped() { Poll::Ready(WriteStatus::Terminate) } else { self.0 .0.write_task.register(cx.waker()); - if flags.intersects(Flags::IO_STOPPING) { - Poll::Ready(WriteStatus::Shutdown( - self.0 .0.disconnect_timeout.get().into(), - )) - } else if flags.contains(Flags::IO_STOPPING_FILTERS) - && !flags.contains(Flags::IO_FILTERS_TIMEOUT) - { - flags.insert(Flags::IO_FILTERS_TIMEOUT); - self.0.set_flags(flags); - Poll::Ready(WriteStatus::Timeout( - self.0 .0.disconnect_timeout.get().into(), - )) - } else if flags.intersects(Flags::WR_PAUSED) { + if flags.contains(Flags::IO_STOPPING) { + Poll::Ready(WriteStatus::Shutdown) + } else if flags.contains(Flags::WR_PAUSED) { Poll::Pending } else { Poll::Ready(WriteStatus::Ready) @@ -242,20 +232,13 @@ where Poll::Pending => Poll::Pending, Poll::Ready(WriteStatus::Ready) => res2, Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate), - Poll::Ready(WriteStatus::Shutdown(t)) => { + Poll::Ready(WriteStatus::Shutdown) => { if res2 == Poll::Ready(WriteStatus::Terminate) { Poll::Ready(WriteStatus::Terminate) } else { - Poll::Ready(WriteStatus::Shutdown(t)) + Poll::Ready(WriteStatus::Shutdown) } } - Poll::Ready(WriteStatus::Timeout(t)) => match res2 { - Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate), - Poll::Ready(WriteStatus::Shutdown(t)) => { - Poll::Ready(WriteStatus::Shutdown(t)) - } - _ => Poll::Ready(WriteStatus::Timeout(t)), - }, } } } diff --git a/ntex-io/src/flags.rs b/ntex-io/src/flags.rs index 82056ff54..bc9b5aacf 100644 --- a/ntex-io/src/flags.rs +++ b/ntex-io/src/flags.rs @@ -7,8 +7,6 @@ bitflags::bitflags! { const IO_STOPPING = 0b0000_0000_0000_0010; /// shuting down filters const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100; - /// initiate filters shutdown timeout in write task - const IO_FILTERS_TIMEOUT = 0b0000_0000_0000_1000; /// pause io read const RD_PAUSED = 0b0000_0000_0001_0000; @@ -36,6 +34,10 @@ bitflags::bitflags! { } impl Flags { + pub(crate) fn is_stopped(&self) -> bool { + self.intersects(Flags::IO_STOPPED) + } + pub(crate) fn is_waiting_for_write(&self) -> bool { self.intersects(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE) } diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index ded121418..8bedd1df7 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -165,7 +165,7 @@ impl Io { let inner = Rc::new(IoState { filter: FilterPtr::null(), pool: Cell::new(pool), - flags: Cell::new(Flags::empty()), + flags: Cell::new(Flags::WR_PAUSED), error: Cell::new(None), dispatch_task: LocalWaker::new(), read_task: LocalWaker::new(), @@ -421,7 +421,7 @@ impl Io { let st = self.st(); let mut flags = st.flags.get(); - if flags.contains(Flags::IO_STOPPED) { + if flags.is_stopped() { Poll::Ready(self.error().map(Err).unwrap_or(Ok(None))) } else { st.dispatch_task.register(cx.waker()); @@ -531,7 +531,7 @@ impl Io { } else { let st = self.st(); let flags = st.flags.get(); - if flags.contains(Flags::IO_STOPPED) { + if flags.is_stopped() { Err(RecvError::PeerGone(self.error())) } else if flags.contains(Flags::DSP_STOP) { st.remove_flags(Flags::DSP_STOP); @@ -568,7 +568,7 @@ impl Io { pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll> { let flags = self.flags(); - if flags.contains(Flags::IO_STOPPED) { + if flags.is_stopped() { Poll::Ready(self.error().map(Err).unwrap_or(Ok(()))) } else { let st = self.st(); @@ -595,7 +595,7 @@ impl Io { let st = self.st(); let flags = st.flags.get(); - if flags.intersects(Flags::IO_STOPPED) { + if flags.is_stopped() { if let Some(err) = self.error() { Poll::Ready(Err(err)) } else { @@ -700,7 +700,7 @@ impl Drop for Io { if st.filter.is_set() { // filter is unsafe and must be dropped explicitly, // and wont be dropped without special attention - if !st.flags.get().contains(Flags::IO_STOPPED) { + if !st.flags.get().is_stopped() { log::trace!( "{}: Io is dropped, force stopping io streams {:?}", st.tag.get(), @@ -884,7 +884,7 @@ pub struct OnDisconnect { impl OnDisconnect { pub(super) fn new(inner: Rc) -> Self { - Self::new_inner(inner.flags.get().contains(Flags::IO_STOPPED), inner) + Self::new_inner(inner.flags.get().is_stopped(), inner) } fn new_inner(disconnected: bool, inner: Rc) -> Self { @@ -909,7 +909,7 @@ impl OnDisconnect { #[inline] /// Check if connection is disconnected pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> { - if self.token == usize::MAX || self.inner.flags.get().contains(Flags::IO_STOPPED) { + if self.token == usize::MAX || self.inner.flags.get().is_stopped() { Poll::Ready(()) } else if let Some(on_disconnect) = self.inner.on_disconnect.take() { on_disconnect[self.token].register(cx.waker()); diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index 340c03f54..c905a8cfb 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -14,12 +14,6 @@ impl IoRef { self.0.flags.get() } - #[inline] - /// Set flags - pub(crate) fn set_flags(&self, flags: Flags) { - self.0.flags.set(flags) - } - #[inline] /// Get current filter pub(crate) fn filter(&self) -> &dyn Filter { @@ -41,10 +35,6 @@ impl IoRef { .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) } - pub(crate) fn is_io_closed(&self) -> bool { - self.0.flags.get().intersects(Flags::IO_STOPPED) - } - #[inline] /// Check if write back-pressure is enabled pub fn is_wr_backpressure(&self) -> bool { diff --git a/ntex-io/src/lib.rs b/ntex-io/src/lib.rs index 7c034ce75..fb7a8867f 100644 --- a/ntex-io/src/lib.rs +++ b/ntex-io/src/lib.rs @@ -1,5 +1,6 @@ //! Utilities for abstructing io streams #![deny(rust_2018_idioms, unreachable_pub, missing_debug_implementations)] +#![allow(async_fn_in_trait)] use std::{ any::Any, any::TypeId, fmt, io as sio, io::Error as IoError, task::Context, task::Poll, @@ -20,8 +21,8 @@ mod tasks; mod timer; mod utils; +use ntex_bytes::BytesVec; use ntex_codec::{Decoder, Encoder}; -use ntex_util::time::Millis; pub use self::buf::{ReadBuf, WriteBuf}; pub use self::dispatcher::{Dispatcher, DispatcherConfig}; @@ -29,13 +30,27 @@ pub use self::filter::{Base, Filter, Layer}; pub use self::framed::Framed; pub use self::io::{Io, IoRef, OnDisconnect}; pub use self::seal::{IoBoxed, Sealed}; -pub use self::tasks::{ReadContext, WriteContext}; +pub use self::tasks::{ReadContext, WriteContext, WriteContextBuf}; pub use self::timer::TimerHandle; pub use self::utils::{seal, Decoded}; #[doc(hidden)] pub use self::flags::Flags; +#[doc(hidden)] +pub trait AsyncRead { + async fn read(&mut self, buf: BytesVec) -> (BytesVec, sio::Result); +} + +#[doc(hidden)] +pub trait AsyncWrite { + async fn write(&mut self, buf: &mut WriteContextBuf) -> sio::Result<()>; + + async fn flush(&mut self) -> sio::Result<()>; + + async fn shutdown(&mut self) -> sio::Result<()>; +} + /// Status for read task #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub enum ReadStatus { @@ -48,10 +63,8 @@ pub enum ReadStatus { pub enum WriteStatus { /// Write task is clear to proceed with write operation Ready, - /// Initiate timeout for normal write operations, shutdown connection after timeout - Timeout(Millis), - /// Initiate graceful io shutdown operation with timeout - Shutdown(Millis), + /// Initiate graceful io shutdown operation + Shutdown, /// Immediately terminate connection Terminate, } diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index e8a14c2eb..497e1f6c6 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -1,16 +1,22 @@ -use std::{future::poll_fn, future::Future, io, task::Context, task::Poll}; +use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll}; -use ntex_bytes::{BufMut, BytesVec, PoolRef}; +use ntex_bytes::{BufMut, BytesVec}; +use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep}; -use crate::{Flags, IoRef, ReadStatus, WriteStatus}; +use crate::{AsyncRead, AsyncWrite, Flags, IoRef, ReadStatus, WriteStatus}; -#[derive(Debug)] /// Context for io read task -pub struct ReadContext(IoRef); +pub struct ReadContext(IoRef, Cell>); + +impl fmt::Debug for ReadContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadContext").field("io", &self.0).finish() + } +} impl ReadContext { pub(crate) fn new(io: &IoRef) -> Self { - Self(io.clone()) + Self(io.clone(), Cell::new(None)) } #[inline] @@ -19,15 +25,8 @@ impl ReadContext { self.0.tag() } - #[inline] - /// Check readiness for read operations - pub async fn ready(&self) -> ReadStatus { - poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await - } - - #[inline] /// Wait when io get closed or preparing for close - pub async fn wait_for_close(&self) { + async fn wait_for_close(&self) { poll_fn(|cx| { let flags = self.0.flags(); @@ -36,7 +35,7 @@ impl ReadContext { } else { self.0 .0.read_task.register(cx.waker()); if flags.contains(Flags::IO_STOPPING_FILTERS) { - shutdown_filters(&self.0); + self.shutdown_filters(cx); } Poll::Pending } @@ -44,222 +43,169 @@ impl ReadContext { .await } - #[inline] - /// Check readiness for read operations - pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll { - self.0.filter().poll_read_ready(cx) - } - - /// Get read buffer - pub fn with_buf(&self, f: F) -> Poll<()> + /// Handle read io operations + pub async fn handle(&self, io: &mut T) where - F: FnOnce(&mut BytesVec, usize, usize) -> Poll>, + T: AsyncRead, { let inner = &self.0 .0; - let (hw, lw) = self.0.memory_pool().read_params().unpack(); - let (result, nbytes, total) = inner.buffer.with_read_source(&self.0, |buf| { + + loop { + let result = poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await; + if result == ReadStatus::Terminate { + log::trace!("{}: Read task is instructed to shutdown", self.tag()); + break; + } + + let mut buf = if inner.flags.get().is_read_buf_ready() { + // read buffer is still not read by dispatcher + // we cannot touch it + inner.pool.get().get_read_buf() + } else { + inner + .buffer + .get_read_source() + .unwrap_or_else(|| inner.pool.get().get_read_buf()) + }; + + // make sure we've got room + let (hw, lw) = self.0.memory_pool().read_params().unpack(); + let remaining = buf.remaining_mut(); + if remaining <= lw { + buf.reserve(hw - remaining); + } let total = buf.len(); // call provided callback - let result = f(buf, hw, lw); + let (buf, result) = match select(io.read(buf), self.wait_for_close()).await { + Either::Left(res) => res, + Either::Right(_) => { + log::trace!("{}: Read io is closed, stop read task", self.tag()); + break; + } + }; + + // handle incoming data let total2 = buf.len(); let nbytes = if total2 > total { total2 - total } else { 0 }; - (result, nbytes, total2) - }); - - // handle buffer changes - if nbytes > 0 { - let filter = self.0.filter(); - let _ = filter - .process_read_buf(&self.0, &inner.buffer, 0, nbytes) - .and_then(|status| { - if status.nbytes > 0 { - // dest buffer has new data, wake up dispatcher - if inner.buffer.read_destination_size() >= hw { - log::trace!( + let total = total2; + + if let Some(mut first_buf) = inner.buffer.get_read_source() { + first_buf.extend_from_slice(&buf); + inner.buffer.set_read_source(&self.0, first_buf); + } else { + inner.buffer.set_read_source(&self.0, buf); + } + + // handle buffer changes + if nbytes > 0 { + let filter = self.0.filter(); + let res = match filter.process_read_buf(&self.0, &inner.buffer, 0, nbytes) { + Ok(status) => { + if status.nbytes > 0 { + // check read back-pressure + if hw < inner.buffer.read_destination_size() { + log::trace!( "{}: Io read buffer is too large {}, enable read back-pressure", self.0.tag(), total ); - inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL); - } else { - inner.insert_flags(Flags::BUF_R_READY); - - if nbytes >= hw { - // read task is paused because of read back-pressure - // but there is no new data in top most read buffer - // so we need to wake up read task to read more data - // otherwise read task would sleep forever - inner.read_task.wake(); + inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL); + } else { + inner.insert_flags(Flags::BUF_R_READY); } - } - log::trace!( - "{}: New {} bytes available, wakeup dispatcher", - self.0.tag(), - nbytes - ); - inner.dispatch_task.wake(); - } else { - if nbytes >= hw { - // read task is paused because of read back-pressure - // but there is no new data in top most read buffer - // so we need to wake up read task to read more data - // otherwise read task would sleep forever - inner.read_task.wake(); - } - if inner.flags.get().contains(Flags::RD_NOTIFY) { + log::trace!( + "{}: New {} bytes available, wakeup dispatcher", + self.0.tag(), + nbytes + ); + // dest buffer has new data, wake up dispatcher + inner.dispatch_task.wake(); + } else if inner.flags.get().contains(Flags::RD_NOTIFY) { // in case of "notify" we must wake up dispatch task // if we read any data from source inner.dispatch_task.wake(); } - } - // while reading, filter wrote some data - // in that case filters need to process write buffers - // and potentialy wake write task - if status.need_write { - filter.process_write_buf(&self.0, &inner.buffer, 0) - } else { - Ok(()) + // while reading, filter wrote some data + // in that case filters need to process write buffers + // and potentialy wake write task + if status.need_write { + filter.process_write_buf(&self.0, &inner.buffer, 0) + } else { + Ok(()) + } } - }) - .map_err(|err| { + Err(err) => Err(err), + }; + + if let Err(err) = res { inner.dispatch_task.wake(); inner.io_stopped(Some(err)); inner.insert_flags(Flags::BUF_R_READY); - }); - } - - match result { - Poll::Ready(Ok(())) => { - inner.io_stopped(None); - Poll::Ready(()) - } - Poll::Ready(Err(e)) => { - inner.io_stopped(Some(e)); - Poll::Ready(()) + } } - Poll::Pending => { - if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) { - shutdown_filters(&self.0); + + match result { + Ok(0) => { + log::trace!("{}: Tcp stream is disconnected", self.tag()); + inner.io_stopped(None); + break; + } + Ok(_) => { + if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) { + lazy(|cx| self.shutdown_filters(cx)).await; + } + } + Err(err) => { + log::trace!("{}: Read task failed on io {:?}", self.tag(), err); + inner.io_stopped(Some(err)); + break; } - Poll::Pending } } } - /// Get read buffer (async) - pub async fn with_buf_async(&self, f: F) -> Poll<()> - where - F: FnOnce(BytesVec) -> R, - R: Future)>, - { - let inner = &self.0 .0; - - // // we already pushed new data to read buffer, - // // we have to wait for dispatcher to read data from buffer - // if inner.flags.get().is_read_buf_ready() { - // ntex_util::task::yield_to().await; - // } - - let mut buf = if inner.flags.get().is_read_buf_ready() { - // read buffer is still not read by dispatcher - // we cannot touch it - inner.pool.get().get_read_buf() - } else { - inner - .buffer - .get_read_source() - .unwrap_or_else(|| inner.pool.get().get_read_buf()) - }; - - // make sure we've got room - let (hw, lw) = self.0.memory_pool().read_params().unpack(); - let remaining = buf.remaining_mut(); - if remaining <= lw { - buf.reserve(hw - remaining); - } - let total = buf.len(); - - // call provided callback - let (buf, result) = f(buf).await; - let total2 = buf.len(); - let nbytes = if total2 > total { total2 - total } else { 0 }; - let total = total2; - - if let Some(mut first_buf) = inner.buffer.get_read_source() { - first_buf.extend_from_slice(&buf); - inner.buffer.set_read_source(&self.0, first_buf); - } else { - inner.buffer.set_read_source(&self.0, buf); - } + fn shutdown_filters(&self, cx: &mut Context<'_>) { + let st = &self.0 .0; + let filter = self.0.filter(); - // handle buffer changes - if nbytes > 0 { - let filter = self.0.filter(); - let res = match filter.process_read_buf(&self.0, &inner.buffer, 0, nbytes) { - Ok(status) => { - if status.nbytes > 0 { - // check read back-pressure - if hw < inner.buffer.read_destination_size() { - log::trace!( - "{}: Io read buffer is too large {}, enable read back-pressure", - self.0.tag(), - total - ); - inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL); - } else { - inner.insert_flags(Flags::BUF_R_READY); - } - log::trace!( - "{}: New {} bytes available, wakeup dispatcher", - self.0.tag(), - nbytes - ); - // dest buffer has new data, wake up dispatcher - inner.dispatch_task.wake(); - } else if inner.flags.get().contains(Flags::RD_NOTIFY) { - // in case of "notify" we must wake up dispatch task - // if we read any data from source - inner.dispatch_task.wake(); - } - - // while reading, filter wrote some data - // in that case filters need to process write buffers - // and potentialy wake write task - if status.need_write { - filter.process_write_buf(&self.0, &inner.buffer, 0) - } else { - Ok(()) - } - } - Err(err) => Err(err), - }; - - if let Err(err) = res { - inner.dispatch_task.wake(); - inner.io_stopped(Some(err)); - inner.insert_flags(Flags::BUF_R_READY); + match filter.shutdown(&self.0, &st.buffer, 0) { + Ok(Poll::Ready(())) => { + st.dispatch_task.wake(); + st.insert_flags(Flags::IO_STOPPING); } - } + Ok(Poll::Pending) => { + let flags = st.flags.get(); - match result { - Ok(n) => { - if n == 0 { - inner.io_stopped(None); - Poll::Ready(()) + // check read buffer, if buffer is not consumed it is unlikely + // that filter will properly complete shutdown + if flags.contains(Flags::RD_PAUSED) + || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY) + { + st.dispatch_task.wake(); + st.insert_flags(Flags::IO_STOPPING); } else { - if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) { - shutdown_filters(&self.0); + // filter shutdown timeout + let timeout = self + .1 + .take() + .unwrap_or_else(|| sleep(st.disconnect_timeout.get())); + if timeout.poll_elapsed(cx).is_ready() { + st.dispatch_task.wake(); + st.insert_flags(Flags::IO_STOPPING); + } else { + self.1.set(Some(timeout)); } - Poll::Pending } } - Err(e) => { - inner.io_stopped(Some(e)); - Poll::Ready(()) + Err(err) => { + st.io_stopped(Some(err)); } } + if let Err(err) = filter.process_write_buf(&self.0, &st.buffer, 0) { + st.io_stopped(Some(err)); + } } } @@ -267,6 +213,13 @@ impl ReadContext { /// Context for io write task pub struct WriteContext(IoRef); +#[derive(Debug)] +/// Context buf for io write task +pub struct WriteContextBuf { + io: IoRef, + buf: Option, +} + impl WriteContext { pub(crate) fn new(io: &IoRef) -> Self { Self(io.clone()) @@ -278,104 +231,92 @@ impl WriteContext { self.0.tag() } - #[inline] - /// Return memory pool for this context - pub fn memory_pool(&self) -> PoolRef { - self.0.memory_pool() - } - - #[inline] /// Check readiness for write operations - pub async fn ready(&self) -> WriteStatus { + async fn ready(&self) -> WriteStatus { poll_fn(|cx| self.0.filter().poll_write_ready(cx)).await } - #[inline] - /// Check readiness for write operations - pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll { - self.0.filter().poll_write_ready(cx) + /// Indicate that write io task is stopped + fn close(&self, err: Option) { + self.0 .0.io_stopped(err); } - #[inline] /// Check if io is closed - pub fn poll_close(&self, cx: &mut Context<'_>) -> Poll<()> { - if self.0.is_io_closed() { - Poll::Ready(()) - } else { - self.0 .0.write_task.register(cx.waker()); - Poll::Pending - } + async fn when_stopped(&self) { + poll_fn(|cx| { + if self.0.flags().is_stopped() { + Poll::Ready(()) + } else { + self.0 .0.write_task.register(cx.waker()); + Poll::Pending + } + }) + .await } - /// Get write buffer - pub fn with_buf(&self, f: F) -> Poll> + /// Handle write io operations + pub async fn handle(&self, io: &mut T) where - F: FnOnce(&mut Option) -> Poll>, + T: AsyncWrite, { - let inner = &self.0 .0; - - // call provided callback - let (result, len) = inner.buffer.with_write_destination(&self.0, |buf| { - let result = f(buf); - (result, buf.as_ref().map(|b| b.len()).unwrap_or(0)) - }); + let mut buf = WriteContextBuf { + io: self.0.clone(), + buf: None, + }; - // if write buffer is smaller than high watermark value, turn off back-pressure - let mut flags = inner.flags.get(); - if len == 0 { - if flags.is_waiting_for_write() { - flags.waiting_for_write_is_done(); - inner.dispatch_task.wake(); + loop { + match self.ready().await { + WriteStatus::Ready => { + // write io stream + match select(io.write(&mut buf), self.when_stopped()).await { + Either::Left(Ok(_)) => continue, + Either::Left(Err(e)) => self.close(Some(e)), + Either::Right(_) => return, + } + } + WriteStatus::Shutdown => { + log::trace!("{}: Write task is instructed to shutdown", self.tag()); + + let fut = async { + // write io stream + io.write(&mut buf).await?; + io.flush().await?; + io.shutdown().await?; + Ok(()) + }; + match select(sleep(self.0 .0.disconnect_timeout.get()), fut).await { + Either::Left(_) => self.close(None), + Either::Right(res) => self.close(res.err()), + } + } + WriteStatus::Terminate => { + log::trace!("{}: Write task is instructed to terminate", self.tag()); + self.close(io.shutdown().await.err()); + } } - } else if flags.contains(Flags::BUF_W_BACKPRESSURE) - && len < inner.pool.get().write_params_high() << 1 - { - flags.remove(Flags::BUF_W_BACKPRESSURE); - inner.dispatch_task.wake(); + return; } - - match result { - Poll::Pending => flags.remove(Flags::WR_PAUSED), - Poll::Ready(Ok(())) => flags.insert(Flags::WR_PAUSED), - Poll::Ready(Err(_)) => {} - } - - inner.flags.set(flags); - result } +} - /// Get write buffer (async) - pub async fn with_buf_async(&self, f: F) -> io::Result<()> - where - F: FnOnce(BytesVec) -> R, - R: Future>, - { - let inner = &self.0 .0; - - // running - let mut flags = inner.flags.get(); - if flags.contains(Flags::WR_PAUSED) { - flags.remove(Flags::WR_PAUSED); - inner.flags.set(flags); +impl WriteContextBuf { + pub fn set(&mut self, mut buf: BytesVec) { + if buf.is_empty() { + self.io.memory_pool().release_write_buf(buf); + } else if let Some(b) = self.buf.take() { + buf.extend_from_slice(&b); + self.io.memory_pool().release_write_buf(b); + self.buf = Some(buf); + } else if let Some(b) = self.io.0.buffer.set_write_destination(buf) { + // write buffer is already set + self.buf = Some(b); } - // buffer - let buf = inner.buffer.get_write_destination(); - - // call provided callback - let result = if let Some(buf) = buf { - if !buf.is_empty() { - f(buf).await - } else { - Ok(()) - } - } else { - Ok(()) - }; - // if write buffer is smaller than high watermark value, turn off back-pressure + let inner = &self.io.0; + let len = self.buf.as_ref().map(|b| b.len()).unwrap_or_default() + + inner.buffer.write_destination_size(); let mut flags = inner.flags.get(); - let len = inner.buffer.write_destination_size(); if len == 0 { if flags.is_waiting_for_write() { @@ -391,44 +332,13 @@ impl WriteContext { inner.flags.set(flags); inner.dispatch_task.wake(); } - - result } - #[inline] - /// Indicate that write io task is stopped - pub fn close(&self, err: Option) { - self.0 .0.io_stopped(err); - } -} - -fn shutdown_filters(io: &IoRef) { - let st = &io.0; - let flags = st.flags.get(); - - if !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) { - let filter = io.filter(); - match filter.shutdown(io, &st.buffer, 0) { - Ok(Poll::Ready(())) => { - st.dispatch_task.wake(); - st.insert_flags(Flags::IO_STOPPING); - } - Ok(Poll::Pending) => { - // check read buffer, if buffer is not consumed it is unlikely - // that filter will properly complete shutdown - if flags.contains(Flags::RD_PAUSED) - || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY) - { - st.dispatch_task.wake(); - st.insert_flags(Flags::IO_STOPPING); - } - } - Err(err) => { - st.io_stopped(Some(err)); - } - } - if let Err(err) = filter.process_write_buf(io, &st.buffer, 0) { - st.io_stopped(Some(err)); + pub fn take(&mut self) -> Option { + if let Some(buf) = self.buf.take() { + Some(buf) + } else { + self.io.0.buffer.get_write_destination() } } } diff --git a/ntex-io/src/testing.rs b/ntex-io/src/testing.rs index 406c42c6f..63c3f7593 100644 --- a/ntex-io/src/testing.rs +++ b/ntex-io/src/testing.rs @@ -1,14 +1,13 @@ //! utilities and helpers for testing #![allow(clippy::let_underscore_future)] -use std::future::{poll_fn, Future}; use std::sync::{Arc, Mutex}; -use std::task::{ready, Context, Poll, Waker}; -use std::{any, cell::RefCell, cmp, fmt, io, mem, net, pin::Pin, rc::Rc}; +use std::task::{Context, Poll, Waker}; +use std::{any, cell::RefCell, cmp, fmt, future::poll_fn, io, mem, net, rc::Rc}; use ntex_bytes::{Buf, BufMut, Bytes, BytesVec}; -use ntex_util::time::{sleep, Millis, Sleep}; +use ntex_util::time::{sleep, Millis}; -use crate::{types, Handle, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus}; +use crate::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf}; #[derive(Default)] struct AtomicWaker(Arc>>>); @@ -356,14 +355,14 @@ impl IoStream for IoTest { fn start(self, read: ReadContext, write: WriteContext) -> Option> { let io = Rc::new(self); - let _ = ntex_util::spawn(ReadTask { - io: io.clone(), - state: read, + let mut rio = Read(io.clone()); + let _ = ntex_util::spawn(async move { + read.handle(&mut rio).await; }); - let _ = ntex_util::spawn(WriteTask { - io: io.clone(), - state: write, - st: IoWriteState::Processing(None), + + let mut wio = Write(io.clone()); + let _ = ntex_util::spawn(async move { + write.handle(&mut wio).await; }); Some(Box::new(io)) @@ -382,271 +381,97 @@ impl Handle for Rc { } /// Read io task -struct ReadTask { - io: Rc, - state: ReadContext, -} - -impl Future for ReadTask { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_ref(); +struct Read(Rc); - this.state.with_buf(|buf, hw, lw| { - match this.state.poll_ready(cx) { - Poll::Ready(ReadStatus::Terminate) => { - log::trace!("read task is instructed to terminate"); - Poll::Ready(Ok(())) - } - Poll::Ready(ReadStatus::Ready) => { - let io = &this.io; - - // read data from socket - let mut new_bytes = 0; - loop { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); - } - match io.poll_read_buf(cx, buf) { - Poll::Pending => { - log::trace!( - "no more data in io stream, read: {:?}", - new_bytes - ); - break; - } - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("io stream is disconnected"); - return Poll::Ready(Ok(())); - } else { - new_bytes += n; - if buf.len() >= hw { - log::trace!( - "high water mark pause reading, read: {:?}", - new_bytes - ); - break; - } - } - } - Poll::Ready(Err(err)) => { - log::trace!("read task failed on io {:?}", err); - return Poll::Ready(Err(err)); - } - } - } +impl crate::AsyncRead for Read { + async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { + // read data from socket + let result = poll_fn(|cx| self.0.poll_read_buf(cx, &mut buf)).await; + (buf, result) + } +} - Poll::Pending - } - Poll::Pending => Poll::Pending, +/// Write +struct Write(Rc); + +impl crate::AsyncWrite for Write { + async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { + poll_fn(|cx| { + if let Some(mut b) = buf.take() { + let result = write_io(&self.0, &mut b, cx); + buf.set(b); + result + } else { + Poll::Ready(Ok(())) } }) + .await } -} - -#[derive(Debug)] -enum IoWriteState { - Processing(Option), - Shutdown(Option, Shutdown), -} - -#[derive(Debug)] -enum Shutdown { - None, - Flushed, - Stopping, -} -/// Write io task -struct WriteTask { - st: IoWriteState, - io: Rc, - state: WriteContext, -} - -impl Future for WriteTask { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().get_mut(); - - match this.st { - IoWriteState::Processing(ref mut delay) => { - match this.state.poll_ready(cx) { - Poll::Ready(WriteStatus::Ready) => { - // flush framed instance - match ready!(flush_io(&this.io, &this.state, cx)) { - Ok(()) => Poll::Pending, - Err(e) => { - this.state.close(Some(e)); - Poll::Ready(()) - } - } - } - Poll::Ready(WriteStatus::Timeout(time)) => { - if delay.is_none() { - *delay = Some(sleep(time)); - } - self.poll(cx) - } - Poll::Ready(WriteStatus::Shutdown(time)) => { - log::trace!("write task is instructed to shutdown"); - - let timeout = if let Some(delay) = delay.take() { - delay - } else { - sleep(time) - }; - - this.st = IoWriteState::Shutdown(Some(timeout), Shutdown::None); - self.poll(cx) - } - Poll::Ready(WriteStatus::Terminate) => { - log::trace!("write task is instructed to terminate"); - // shutdown WRITE side - this.io - .local - .lock() - .unwrap() - .borrow_mut() - .flags - .insert(IoTestFlags::CLOSED); - this.state.close(None); - Poll::Ready(()) - } - Poll::Pending => Poll::Pending, - } - } - IoWriteState::Shutdown(ref mut delay, ref mut st) => { - // close WRITE side and wait for disconnect on read side. - // use disconnect timeout, otherwise it could hang forever. - loop { - match st { - Shutdown::None => { - // flush write buffer - match flush_io(&this.io, &this.state, cx) { - Poll::Ready(Ok(())) => { - *st = Shutdown::Flushed; - continue; - } - Poll::Ready(Err(err)) => { - log::trace!( - "write task is closed with err during flush {:?}", - err - ); - this.state.close(Some(err)); - return Poll::Ready(()); - } - Poll::Pending => (), - } - } - Shutdown::Flushed => { - // shutdown WRITE side - this.io - .local - .lock() - .unwrap() - .borrow_mut() - .flags - .insert(IoTestFlags::CLOSED); - *st = Shutdown::Stopping; - continue; - } - Shutdown::Stopping => { - // read until 0 or err - let io = &this.io; - loop { - let mut buf = BytesVec::new(); - match io.poll_read_buf(cx, &mut buf) { - Poll::Ready(Err(e)) => { - this.state.close(Some(e)); - log::trace!("write task is stopped"); - return Poll::Ready(()); - } - Poll::Ready(Ok(0)) => { - this.state.close(None); - log::trace!("write task is stopped"); - return Poll::Ready(()); - } - Poll::Pending => break, - _ => (), - } - } - } - } + async fn flush(&mut self) -> io::Result<()> { + Ok(()) + } - // disconnect timeout - if let Some(ref delay) = delay { - if delay.poll_elapsed(cx).is_pending() { - return Poll::Pending; - } - } - log::trace!("write task is stopped after delay"); - this.state.close(None); - return Poll::Ready(()); - } - } - } + async fn shutdown(&mut self) -> io::Result<()> { + // shutdown WRITE side + self.0 + .local + .lock() + .unwrap() + .borrow_mut() + .flags + .insert(IoTestFlags::CLOSED); + Ok(()) } } /// Flush write buffer to underlying I/O stream. -pub(super) fn flush_io( +pub(super) fn write_io( io: &IoTest, - state: &WriteContext, + buf: &mut BytesVec, cx: &mut Context<'_>, ) -> Poll> { - state.with_buf(|buf| { - if let Some(buf) = buf { - let len = buf.len(); - - if len != 0 { - log::trace!("flushing framed transport: {}", len); - - let mut written = 0; - let result = loop { - break match io.poll_write_buf(cx, &buf[written..]) { - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!( - "disconnected during flush, written {}", - written - ); - Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ))) - } else { - written += n; - if written == len { - buf.clear(); - Poll::Ready(Ok(())) - } else { - continue; - } - } - } - Poll::Pending => { - // remove written data - buf.advance(written); - Poll::Pending - } - Poll::Ready(Err(e)) => { - log::trace!("error during flush: {}", e); - Poll::Ready(Err(e)) + let len = buf.len(); + + if len != 0 { + log::trace!("flushing framed transport: {}", len); + + let mut written = 0; + let result = loop { + break match io.poll_write_buf(cx, &buf[written..]) { + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!("disconnected during flush, written {}", written); + Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write frame to transport", + ))) + } else { + written += n; + if written == len { + buf.clear(); + Poll::Ready(Ok(())) + } else { + continue; } - }; - }; - log::trace!("flushed {} bytes", written); - return result; - } - } + } + } + Poll::Pending => { + // remove written data + buf.advance(written); + Poll::Pending + } + Poll::Ready(Err(e)) => { + log::trace!("error during flush: {}", e); + Poll::Ready(Err(e)) + } + }; + }; + log::trace!("flushed {} bytes", written); + result + } else { Poll::Ready(Ok(())) - }) + } } #[cfg(test)] diff --git a/ntex-net/Cargo.toml b/ntex-net/Cargo.toml index c53b555eb..4f3e51172 100644 --- a/ntex-net/Cargo.toml +++ b/ntex-net/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-net" -version = "2.1.0" +version = "2.2.0" authors = ["ntex contributors "] description = "ntexwork utils for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -34,14 +34,14 @@ async-std = ["ntex-rt/async-std", "ntex-async-std"] ntex-service = "3" ntex-bytes = "0.1" ntex-http = "0.1" -ntex-io = "2.4" +ntex-io = "2.5" ntex-rt = "0.4.14" ntex-util = "2" -ntex-tokio = { version = "0.5.1", optional = true } -ntex-compio = { version = "0.1", optional = true } -ntex-glommio = { version = "0.5", optional = true } -ntex-async-std = { version = "0.5", optional = true } +ntex-tokio = { version = "0.5.2", optional = true } +ntex-compio = { version = "0.1.2", optional = true } +ntex-glommio = { version = "0.5.1", optional = true } +ntex-async-std = { version = "0.5.1", optional = true } log = "0.4" thiserror = "1" diff --git a/ntex-tokio/CHANGES.md b/ntex-tokio/CHANGES.md index ceaa64a87..08b453b71 100644 --- a/ntex-tokio/CHANGES.md +++ b/ntex-tokio/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.5.2] - 2024-09-11 + +* Use new io api + ## [0.5.1] - 2024-09-06 * Stop write task if io is closed diff --git a/ntex-tokio/Cargo.toml b/ntex-tokio/Cargo.toml index b741cff97..06dc53817 100644 --- a/ntex-tokio/Cargo.toml +++ b/ntex-tokio/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-tokio" -version = "0.5.1" +version = "0.5.2" authors = ["ntex contributors "] description = "tokio intergration for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -17,7 +17,7 @@ path = "src/lib.rs" [dependencies] ntex-bytes = "0.1" -ntex-io = "2.4" +ntex-io = "2.5" ntex-util = "2" log = "0.4" tokio = { version = "1", default-features = false, features = ["rt", "net", "sync", "signal"] } diff --git a/ntex-tokio/src/io.rs b/ntex-tokio/src/io.rs index 06f1656ad..6c2a1f539 100644 --- a/ntex-tokio/src/io.rs +++ b/ntex-tokio/src/io.rs @@ -1,12 +1,12 @@ use std::task::{Context, Poll}; -use std::{any, cell::RefCell, cmp, future::Future, io, mem, pin::Pin, rc::Rc, rc::Weak}; +use std::{any, cell::RefCell, cmp, future::poll_fn, io, mem, pin::Pin, rc::Rc, rc::Weak}; use ntex_bytes::{Buf, BufMut, BytesVec}; use ntex_io::{ - types, Filter, Handle, Io, IoBoxed, IoStream, ReadContext, ReadStatus, WriteContext, - WriteStatus, + types, Filter, Handle, Io, IoBoxed, IoStream, ReadContext, WriteContext, + WriteContextBuf, }; -use ntex_util::{ready, time::sleep, time::Millis, time::Sleep}; +use ntex_util::{ready, time::Millis}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::TcpStream; @@ -14,8 +14,14 @@ impl IoStream for crate::TcpStream { fn start(self, read: ReadContext, write: WriteContext) -> Option> { let io = Rc::new(RefCell::new(self.0)); - tokio::task::spawn_local(ReadTask::new(io.clone(), read)); - tokio::task::spawn_local(WriteTask::new(io.clone(), write)); + let mut rio = Read(io.clone()); + tokio::task::spawn_local(async move { + read.handle(&mut rio).await; + }); + let mut wio = Write(io.clone()); + tokio::task::spawn_local(async move { + write.handle(&mut wio).await; + }); Some(Box::new(HandleWrapper(io))) } } @@ -36,345 +42,149 @@ impl Handle for HandleWrapper { } /// Read io task -struct ReadTask { - io: Rc>, - state: ReadContext, -} - -impl ReadTask { - /// Create new read io task - fn new(io: Rc>, state: ReadContext) -> Self { - Self { io, state } - } -} +struct Read(Rc>); -impl Future for ReadTask { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_ref(); - - match ready!(this.state.poll_ready(cx)) { - ReadStatus::Ready => { - this.state.with_buf(|buf, hw, lw| { - // read data from socket - let mut io = this.io.borrow_mut(); - loop { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); +impl ntex_io::AsyncRead for Read { + #[inline] + async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { + // read data from socket + let result = poll_fn(|cx| { + let mut n = 0; + let mut io = self.0.borrow_mut(); + loop { + return match poll_read_buf(Pin::new(&mut *io), cx, &mut buf)? { + Poll::Pending => { + if n > 0 { + Poll::Ready(Ok(n)) + } else { + Poll::Pending } - return match poll_read_buf(Pin::new(&mut *io), cx, buf) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!( - "{}: Tcp stream is disconnected", - this.state.tag() - ); - Poll::Ready(Ok(())) - } else if buf.len() < hw { - continue; - } else { - Poll::Pending - } - } - Poll::Ready(Err(err)) => { - log::trace!( - "{}: Read task failed on io {:?}", - this.state.tag(), - err - ); - Poll::Ready(Err(err)) - } - }; } - }) - } - ReadStatus::Terminate => { - log::trace!("{}: Read task is instructed to shutdown", this.state.tag()); - Poll::Ready(()) + Poll::Ready(size) => { + n += size; + if n > 0 && buf.remaining_mut() > 0 { + continue; + } + Poll::Ready(Ok(n)) + } + }; } - } + }) + .await; + + (buf, result) } } -#[derive(Debug)] -enum IoWriteState { - Processing(Option), - Shutdown(Sleep, Shutdown), -} +struct Write(Rc>); -#[derive(Debug)] -enum Shutdown { - None, - Flushed, - Stopping(u16), -} +impl ntex_io::AsyncWrite for Write { + #[inline] + async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { + poll_fn(|cx| { + if let Some(mut b) = buf.take() { + let result = flush_io(&mut *self.0.borrow_mut(), &mut b, cx); + buf.set(b); + result + } else { + Poll::Ready(Ok(())) + } + }) + .await + } -/// Write io task -struct WriteTask { - st: IoWriteState, - io: Rc>, - state: WriteContext, -} + #[inline] + async fn flush(&mut self) -> io::Result<()> { + Ok(()) + } -impl WriteTask { - /// Create new write io task - fn new(io: Rc>, state: WriteContext) -> Self { - Self { - io, - state, - st: IoWriteState::Processing(None), - } + #[inline] + async fn shutdown(&mut self) -> io::Result<()> { + poll_fn(|cx| Pin::new(&mut *self.0.borrow_mut()).poll_shutdown(cx)).await } } -impl Future for WriteTask { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().get_mut(); - - if this.state.poll_close(cx).is_ready() { - return Poll::Ready(()); +pub fn poll_read_buf( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut BytesVec, +) -> Poll> { + let n = { + let dst = + unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + if io.poll_read(cx, &mut buf)?.is_pending() { + return Poll::Pending; } - match this.st { - IoWriteState::Processing(ref mut delay) => { - match ready!(this.state.poll_ready(cx)) { - WriteStatus::Ready => { - if let Some(delay) = delay { - if delay.poll_elapsed(cx).is_ready() { - this.state.close(Some(io::Error::new( - io::ErrorKind::TimedOut, - "Operation timedout", - ))); - return Poll::Ready(()); - } - } - - // flush io stream - match ready!(this.state.with_buf(|buf| flush_io( - &mut *this.io.borrow_mut(), - buf, - cx, - &this.state - ))) { - Ok(()) => Poll::Pending, - Err(e) => { - this.state.close(Some(e)); - Poll::Ready(()) - } - } - } - WriteStatus::Timeout(time) => { - log::trace!( - "{}: Initiate timeout delay for {:?}", - this.state.tag(), - time - ); - if delay.is_none() { - *delay = Some(sleep(time)); - } - self.poll(cx) - } - WriteStatus::Shutdown(time) => { - log::trace!( - "{}: Write task is instructed to shutdown", - this.state.tag() - ); - - let timeout = if let Some(delay) = delay.take() { - delay - } else { - sleep(time) - }; - - this.st = IoWriteState::Shutdown(timeout, Shutdown::None); - self.poll(cx) - } - WriteStatus::Terminate => { - log::trace!( - "{}: Write task is instructed to terminate", - this.state.tag() - ); - - if !matches!( - this.io.borrow().linger(), - Ok(Some(std::time::Duration::ZERO)) - ) { - // call shutdown to prevent flushing data on terminated Io. when - // linger is set to zero, closing will reset the connection, so - // shutdown is not neccessary. - let _ = Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx); - } - this.state.close(None); - Poll::Ready(()) - } - } - } - IoWriteState::Shutdown(ref mut delay, ref mut st) => { - // close WRITE side and wait for disconnect on read side. - // use disconnect timeout, otherwise it could hang forever. - loop { - if this.state.poll_close(cx).is_ready() { - return Poll::Ready(()); - } - match st { - Shutdown::None => { - // flush write buffer - let mut io = this.io.borrow_mut(); - match this - .state - .with_buf(|buf| flush_io(&mut *io, buf, cx, &this.state)) - { - Poll::Ready(Ok(())) => { - *st = Shutdown::Flushed; - continue; - } - Poll::Ready(Err(err)) => { - log::trace!( - "{}: Write task is closed with err during flush, {:?}", this.state.tag(), - err - ); - this.state.close(Some(err)); - return Poll::Ready(()); - } - Poll::Pending => (), - } - } - Shutdown::Flushed => { - // shutdown WRITE side - match Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx) { - Poll::Ready(Ok(_)) => { - *st = Shutdown::Stopping(0); - continue; - } - Poll::Ready(Err(e)) => { - log::trace!( - "{}: Write task is closed with err during shutdown", - this.state.tag() - ); - this.state.close(Some(e)); - return Poll::Ready(()); - } - _ => (), - } - } - Shutdown::Stopping(ref mut count) => { - // read until 0 or err - let mut buf = [0u8; 512]; - loop { - let mut read_buf = ReadBuf::new(&mut buf); - match Pin::new(&mut *this.io.borrow_mut()) - .poll_read(cx, &mut read_buf) - { - Poll::Ready(Err(_)) | Poll::Ready(Ok(_)) - if read_buf.filled().is_empty() => - { - this.state.close(None); - log::trace!( - "{}: Tokio write task is stopped", - this.state.tag() - ); - return Poll::Ready(()); - } - Poll::Pending => { - *count += read_buf.filled().len() as u16; - if *count > 4096 { - log::trace!("{}: Tokio write task is stopped, too much input", this.state.tag()); - this.state.close(None); - return Poll::Ready(()); - } - break; - } - _ => (), - } - } - } - } + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; - // disconnect timeout - if delay.poll_elapsed(cx).is_pending() { - return Poll::Pending; - } - log::trace!("{}: Write task is stopped after delay", this.state.tag()); - this.state.close(None); - return Poll::Ready(()); - } - } - } + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + buf.advance_mut(n); } + + Poll::Ready(Ok(n)) } /// Flush write buffer to underlying I/O stream. pub(super) fn flush_io( io: &mut T, - buf: &mut Option, + buf: &mut BytesVec, cx: &mut Context<'_>, - st: &WriteContext, ) -> Poll> { - if let Some(buf) = buf { - let len = buf.len(); - - if len != 0 { - // log::trace!("{}: Flushing framed transport: {:?}", st.tag(), buf.len()); - - let mut written = 0; - let result = loop { - break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!( - "{}: Disconnected during flush, written {}", - st.tag(), - written - ); - Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ))) + let len = buf.len(); + + if len != 0 { + // log::trace!("{}: Flushing framed transport: {:?}", st.tag(), buf.len()); + + let mut written = 0; + let result = loop { + break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { + Poll::Ready(Ok(n)) => { + if n == 0 { + Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write frame to transport", + ))) + } else { + written += n; + if written == len { + buf.clear(); + Poll::Ready(Ok(())) } else { - written += n; - if written == len { - buf.clear(); - Poll::Ready(Ok(())) - } else { - continue; - } + continue; } } - Poll::Pending => { - // remove written data - buf.advance(written); - Poll::Pending - } - Poll::Ready(Err(e)) => { - log::trace!("{}: Error during flush: {}", st.tag(), e); - Poll::Ready(Err(e)) - } - }; - }; - // log::trace!("{}: flushed {} bytes", st.tag(), written); - - // flush - return if written > 0 { - match Pin::new(&mut *io).poll_flush(cx) { - Poll::Ready(Ok(_)) => result, - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - log::trace!("{}: Error during flush: {}", st.tag(), e); - Poll::Ready(Err(e)) - } } - } else { - result + Poll::Pending => { + // remove written data + buf.advance(written); + Poll::Pending + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), }; + }; + // log::trace!("{}: flushed {} bytes", st.tag(), written); + + // flush + if written > 0 { + match Pin::new(&mut *io).poll_flush(cx) { + Poll::Ready(Ok(_)) => result, + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + } + } else { + result } + } else { + Poll::Ready(Ok(())) } - Poll::Ready(Ok(())) } pub struct TokioIoBoxed(IoBoxed); @@ -472,294 +282,77 @@ mod unixstream { fn start(self, read: ReadContext, write: WriteContext) -> Option> { let io = Rc::new(RefCell::new(self.0)); - tokio::task::spawn_local(ReadTask::new(io.clone(), read)); - tokio::task::spawn_local(WriteTask::new(io, write)); + let mut rio = Read(io.clone()); + tokio::task::spawn_local(async move { + read.handle(&mut rio).await; + }); + let mut wio = Write(io.clone()); + tokio::task::spawn_local(async move { + write.handle(&mut wio).await; + }); None } } - /// Read io task - struct ReadTask { - io: Rc>, - state: ReadContext, - } - - impl ReadTask { - /// Create new read io task - fn new(io: Rc>, state: ReadContext) -> Self { - Self { io, state } - } - } + struct Read(Rc>); - impl Future for ReadTask { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_ref(); - - this.state.with_buf(|buf, hw, lw| { - match ready!(this.state.poll_ready(cx)) { - ReadStatus::Ready => { - // read data from socket - let mut io = this.io.borrow_mut(); - loop { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); + impl ntex_io::AsyncRead for Read { + #[inline] + async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { + // read data from socket + let result = poll_fn(|cx| { + let mut n = 0; + let mut io = self.0.borrow_mut(); + loop { + return match poll_read_buf(Pin::new(&mut *io), cx, &mut buf)? { + Poll::Pending => { + if n > 0 { + Poll::Ready(Ok(n)) + } else { + Poll::Pending } - - return match poll_read_buf(Pin::new(&mut *io), cx, buf) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!( - "{}: Tokio unix stream is disconnected", - this.state.tag() - ); - Poll::Ready(Ok(())) - } else if buf.len() < hw { - continue; - } else { - Poll::Pending - } - } - Poll::Ready(Err(err)) => { - log::trace!( - "{}: Unix stream read task failed {:?}", - this.state.tag(), - err - ); - Poll::Ready(Err(err)) - } - }; } - } - ReadStatus::Terminate => { - log::trace!( - "{}: Read task is instructed to shutdown", - this.state.tag() - ); - Poll::Ready(Ok(())) - } + Poll::Ready(size) => { + n += size; + if n > 0 && buf.remaining_mut() > 0 { + continue; + } + Poll::Ready(Ok(n)) + } + }; } }) - } - } - - /// Write io task - struct WriteTask { - st: IoWriteState, - io: Rc>, - state: WriteContext, - } + .await; - impl WriteTask { - /// Create new write io task - fn new(io: Rc>, state: WriteContext) -> Self { - Self { - io, - state, - st: IoWriteState::Processing(None), - } + (buf, result) } } - impl Future for WriteTask { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().get_mut(); - - if this.state.poll_close(cx).is_ready() { - return Poll::Ready(()); - } - - match this.st { - IoWriteState::Processing(ref mut delay) => { - match this.state.poll_ready(cx) { - Poll::Ready(WriteStatus::Ready) => { - if let Some(delay) = delay { - if delay.poll_elapsed(cx).is_ready() { - this.state.close(Some(io::Error::new( - io::ErrorKind::TimedOut, - "Operation timedout", - ))); - return Poll::Ready(()); - } - } - - // flush io stream - match ready!(this.state.with_buf(|buf| flush_io( - &mut *this.io.borrow_mut(), - buf, - cx, - &this.state - ))) { - Ok(()) => Poll::Pending, - Err(e) => { - this.state.close(Some(e)); - Poll::Ready(()) - } - } - } - Poll::Ready(WriteStatus::Timeout(time)) => { - if delay.is_none() { - *delay = Some(sleep(time)); - } - self.poll(cx) - } - Poll::Ready(WriteStatus::Shutdown(time)) => { - log::trace!( - "{}: Write task is instructed to shutdown", - this.state.tag() - ); - - let timeout = if let Some(delay) = delay.take() { - delay - } else { - sleep(time) - }; - - this.st = IoWriteState::Shutdown(timeout, Shutdown::None); - self.poll(cx) - } - Poll::Ready(WriteStatus::Terminate) => { - log::trace!( - "{}: Write task is instructed to terminate", - this.state.tag() - ); - - let _ = Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx); - this.state.close(None); - Poll::Ready(()) - } - Poll::Pending => Poll::Pending, - } + struct Write(Rc>); + + impl ntex_io::AsyncWrite for Write { + #[inline] + async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { + poll_fn(|cx| { + if let Some(mut b) = buf.take() { + let result = flush_io(&mut *self.0.borrow_mut(), &mut b, cx); + buf.set(b); + result + } else { + Poll::Ready(Ok(())) } - IoWriteState::Shutdown(ref mut delay, ref mut st) => { - // close WRITE side and wait for disconnect on read side. - // use disconnect timeout, otherwise it could hang forever. - loop { - if this.state.poll_close(cx).is_ready() { - return Poll::Ready(()); - } - match st { - Shutdown::None => { - // flush write buffer - let mut io = this.io.borrow_mut(); - match this.state.with_buf(|buf| { - flush_io(&mut *io, buf, cx, &this.state) - }) { - Poll::Ready(Ok(())) => { - *st = Shutdown::Flushed; - continue; - } - Poll::Ready(Err(err)) => { - log::trace!( - "{}: Write task is closed with err during flush, {:?}", this.state.tag(), - err - ); - this.state.close(Some(err)); - return Poll::Ready(()); - } - Poll::Pending => (), - } - } - Shutdown::Flushed => { - // shutdown WRITE side - match Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx) - { - Poll::Ready(Ok(_)) => { - *st = Shutdown::Stopping(0); - continue; - } - Poll::Ready(Err(e)) => { - log::trace!( - "{}: Write task is closed with err during shutdown", this.state.tag() - ); - this.state.close(Some(e)); - return Poll::Ready(()); - } - _ => (), - } - } - Shutdown::Stopping(ref mut count) => { - // read until 0 or err - let mut buf = [0u8; 512]; - loop { - let mut read_buf = ReadBuf::new(&mut buf); - match Pin::new(&mut *this.io.borrow_mut()) - .poll_read(cx, &mut read_buf) - { - Poll::Ready(Err(_)) | Poll::Ready(Ok(_)) - if read_buf.filled().is_empty() => - { - this.state.close(None); - log::trace!( - "{}: Write task is stopped", - this.state.tag() - ); - return Poll::Ready(()); - } - Poll::Pending => { - *count += read_buf.filled().len() as u16; - if *count > 4096 { - log::trace!( - "{}: Write task is stopped, too much input", this.state.tag() - ); - this.state.close(None); - return Poll::Ready(()); - } - break; - } - _ => (), - } - } - } - } - - // disconnect timeout - if delay.poll_elapsed(cx).is_pending() { - return Poll::Pending; - } - log::trace!( - "{}: Write task is stopped after delay", - this.state.tag() - ); - this.state.close(None); - return Poll::Ready(()); - } - } - } + }) + .await } - } -} -pub fn poll_read_buf( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut BytesVec, -) -> Poll> { - let n = { - let dst = - unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit]) }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - if io.poll_read(cx, &mut buf)?.is_pending() { - return Poll::Pending; + #[inline] + async fn flush(&mut self) -> io::Result<()> { + Ok(()) } - // Ensure the pointer does not change from under us - assert_eq!(ptr, buf.filled().as_ptr()); - buf.filled().len() - }; - - // Safety: This is guaranteed to be the number of initialized (and read) - // bytes due to the invariants provided by `ReadBuf::filled`. - unsafe { - buf.advance_mut(n); + #[inline] + async fn shutdown(&mut self) -> io::Result<()> { + poll_fn(|cx| Pin::new(&mut *self.0.borrow_mut()).poll_shutdown(cx)).await + } } - - Poll::Ready(Ok(n)) } diff --git a/ntex-tokio/src/lib.rs b/ntex-tokio/src/lib.rs index 79d538e2b..8916e200a 100644 --- a/ntex-tokio/src/lib.rs +++ b/ntex-tokio/src/lib.rs @@ -4,10 +4,8 @@ use ntex_bytes::PoolRef; use ntex_io::Io; mod io; -mod signals; pub use self::io::{SocketOptions, TokioIoBoxed}; -pub use self::signals::{signal, Signal}; struct TcpStream(tokio::net::TcpStream); diff --git a/ntex-tokio/src/signals.rs b/ntex-tokio/src/signals.rs deleted file mode 100644 index 87ff229da..000000000 --- a/ntex-tokio/src/signals.rs +++ /dev/null @@ -1,138 +0,0 @@ -use std::{ - cell::RefCell, future::Future, mem, pin::Pin, rc::Rc, task::Context, task::Poll, -}; - -use tokio::sync::oneshot; -use tokio::task::spawn_local; - -thread_local! { - static SRUN: RefCell = const { RefCell::new(false) }; - static SHANDLERS: Rc>>> = Default::default(); -} - -/// Different types of process signals -#[derive(PartialEq, Eq, Clone, Copy, Debug)] -pub enum Signal { - /// SIGHUP - Hup, - /// SIGINT - Int, - /// SIGTERM - Term, - /// SIGQUIT - Quit, -} - -/// Register signal handler. -/// -/// Signals are handled by oneshots, you have to re-register -/// after each signal. -pub fn signal() -> Option> { - if !SRUN.with(|v| *v.borrow()) { - spawn_local(Signals::new()); - } - SHANDLERS.with(|handlers| { - let (tx, rx) = oneshot::channel(); - handlers.borrow_mut().push(tx); - Some(rx) - }) -} - -struct Signals { - #[cfg(not(unix))] - signal: Pin>>>, - #[cfg(unix)] - signals: Vec<( - Signal, - tokio::signal::unix::Signal, - tokio::signal::unix::SignalKind, - )>, -} - -impl Signals { - fn new() -> Signals { - SRUN.with(|h| *h.borrow_mut() = true); - - #[cfg(not(unix))] - { - Signals { - signal: Box::pin(tokio::signal::ctrl_c()), - } - } - - #[cfg(unix)] - { - use tokio::signal::unix; - - let sig_map = [ - (unix::SignalKind::interrupt(), Signal::Int), - (unix::SignalKind::hangup(), Signal::Hup), - (unix::SignalKind::terminate(), Signal::Term), - (unix::SignalKind::quit(), Signal::Quit), - ]; - - let mut signals = Vec::new(); - for (kind, sig) in sig_map.iter() { - match unix::signal(*kind) { - Ok(stream) => signals.push((*sig, stream, *kind)), - Err(e) => log::error!( - "Cannot initialize stream handler for {:?} err: {}", - sig, - e - ), - } - } - - Signals { signals } - } - } -} - -impl Drop for Signals { - fn drop(&mut self) { - SRUN.with(|h| *h.borrow_mut() = false); - } -} - -impl Future for Signals { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - #[cfg(not(unix))] - { - if self.signal.as_mut().poll(cx).is_ready() { - let handlers = SHANDLERS.with(|h| mem::take(&mut *h.borrow_mut())); - for sender in handlers { - let _ = sender.send(Signal::Int); - } - } - Poll::Pending - } - #[cfg(unix)] - { - for (sig, stream, kind) in self.signals.iter_mut() { - loop { - if Pin::new(&mut *stream).poll_recv(cx).is_ready() { - let handlers = SHANDLERS.with(|h| mem::take(&mut *h.borrow_mut())); - for sender in handlers { - let _ = sender.send(*sig); - } - match tokio::signal::unix::signal(*kind) { - Ok(s) => { - *stream = s; - continue; - } - Err(e) => log::error!( - "Cannot initialize stream handler for {:?} err: {}", - sig, - e - ), - } - } - break; - } - } - Poll::Pending - } - } -} diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 46e88c430..28881625b 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -71,7 +71,7 @@ ntex-bytes = "0.1.27" ntex-server = "2.3" ntex-h2 = "1.1" ntex-rt = "0.4.15" -ntex-io = "2.4" +ntex-io = "2.5" ntex-net = "2.1" ntex-tls = "2.1"