From 568df1cbe99b6dd65a641ee14a066b49db3f58d4 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 11 Sep 2024 00:34:18 +0500 Subject: [PATCH] Refactor filter shutdown --- ntex-io/src/filter.rs | 29 ++-------- ntex-io/src/flags.rs | 2 - ntex-io/src/ioref.rs | 6 -- ntex-io/src/lib.rs | 7 +-- ntex-io/src/tasks.rs | 131 ++++++++++++++++++------------------------ 5 files changed, 65 insertions(+), 110 deletions(-) diff --git a/ntex-io/src/filter.rs b/ntex-io/src/filter.rs index b452429b6..78ed7520a 100644 --- a/ntex-io/src/filter.rs +++ b/ntex-io/src/filter.rs @@ -93,26 +93,16 @@ impl Filter for Base { #[inline] fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { - let mut flags = self.0.flags(); + let flags = self.0.flags(); if flags.is_stopped() { Poll::Ready(WriteStatus::Terminate) } else { self.0 .0.write_task.register(cx.waker()); - if flags.intersects(Flags::IO_STOPPING) { - Poll::Ready(WriteStatus::Shutdown( - self.0 .0.disconnect_timeout.get().into(), - )) - } else if flags.contains(Flags::IO_STOPPING_FILTERS) - && !flags.contains(Flags::IO_FILTERS_TIMEOUT) - { - flags.insert(Flags::IO_FILTERS_TIMEOUT); - self.0.set_flags(flags); - Poll::Ready(WriteStatus::Timeout( - self.0 .0.disconnect_timeout.get().into(), - )) - } else if flags.intersects(Flags::WR_PAUSED) { + if flags.contains(Flags::IO_STOPPING) { + Poll::Ready(WriteStatus::Shutdown) + } else if flags.contains(Flags::WR_PAUSED) { Poll::Pending } else { Poll::Ready(WriteStatus::Ready) @@ -242,20 +232,13 @@ where Poll::Pending => Poll::Pending, Poll::Ready(WriteStatus::Ready) => res2, Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate), - Poll::Ready(WriteStatus::Shutdown(t)) => { + Poll::Ready(WriteStatus::Shutdown) => { if res2 == Poll::Ready(WriteStatus::Terminate) { Poll::Ready(WriteStatus::Terminate) } else { - Poll::Ready(WriteStatus::Shutdown(t)) + Poll::Ready(WriteStatus::Shutdown) } } - Poll::Ready(WriteStatus::Timeout(t)) => match res2 { - Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate), - Poll::Ready(WriteStatus::Shutdown(t)) => { - Poll::Ready(WriteStatus::Shutdown(t)) - } - _ => Poll::Ready(WriteStatus::Timeout(t)), - }, } } } diff --git a/ntex-io/src/flags.rs b/ntex-io/src/flags.rs index 029d891b7..bc9b5aacf 100644 --- a/ntex-io/src/flags.rs +++ b/ntex-io/src/flags.rs @@ -7,8 +7,6 @@ bitflags::bitflags! { const IO_STOPPING = 0b0000_0000_0000_0010; /// shuting down filters const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100; - /// initiate filters shutdown timeout in write task - const IO_FILTERS_TIMEOUT = 0b0000_0000_0000_1000; /// pause io read const RD_PAUSED = 0b0000_0000_0001_0000; diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index 02dc33766..c905a8cfb 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -14,12 +14,6 @@ impl IoRef { self.0.flags.get() } - #[inline] - /// Set flags - pub(crate) fn set_flags(&self, flags: Flags) { - self.0.flags.set(flags) - } - #[inline] /// Get current filter pub(crate) fn filter(&self) -> &dyn Filter { diff --git a/ntex-io/src/lib.rs b/ntex-io/src/lib.rs index ee3ede964..fb7a8867f 100644 --- a/ntex-io/src/lib.rs +++ b/ntex-io/src/lib.rs @@ -23,7 +23,6 @@ mod utils; use ntex_bytes::BytesVec; use ntex_codec::{Decoder, Encoder}; -use ntex_util::time::Millis; pub use self::buf::{ReadBuf, WriteBuf}; pub use self::dispatcher::{Dispatcher, DispatcherConfig}; @@ -64,10 +63,8 @@ pub enum ReadStatus { pub enum WriteStatus { /// Write task is clear to proceed with write operation Ready, - /// Initiate timeout for normal write operations, shutdown connection after timeout - Timeout(Millis), - /// Initiate graceful io shutdown operation with timeout - Shutdown(Millis), + /// Initiate graceful io shutdown operation + Shutdown, /// Immediately terminate connection Terminate, } diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index b6ff33c21..497e1f6c6 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -1,17 +1,22 @@ -use std::{future::poll_fn, io, task::Poll}; +use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll}; use ntex_bytes::{BufMut, BytesVec}; -use ntex_util::{future::select, future::Either, time::sleep}; +use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep}; use crate::{AsyncRead, AsyncWrite, Flags, IoRef, ReadStatus, WriteStatus}; -#[derive(Debug)] /// Context for io read task -pub struct ReadContext(IoRef); +pub struct ReadContext(IoRef, Cell>); + +impl fmt::Debug for ReadContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadContext").field("io", &self.0).finish() + } +} impl ReadContext { pub(crate) fn new(io: &IoRef) -> Self { - Self(io.clone()) + Self(io.clone(), Cell::new(None)) } #[inline] @@ -30,7 +35,7 @@ impl ReadContext { } else { self.0 .0.read_task.register(cx.waker()); if flags.contains(Flags::IO_STOPPING_FILTERS) { - shutdown_filters(&self.0); + self.shutdown_filters(cx); } Poll::Pending } @@ -149,7 +154,7 @@ impl ReadContext { } Ok(_) => { if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) { - shutdown_filters(&self.0); + lazy(|cx| self.shutdown_filters(cx)).await; } } Err(err) => { @@ -160,6 +165,48 @@ impl ReadContext { } } } + + fn shutdown_filters(&self, cx: &mut Context<'_>) { + let st = &self.0 .0; + let filter = self.0.filter(); + + match filter.shutdown(&self.0, &st.buffer, 0) { + Ok(Poll::Ready(())) => { + st.dispatch_task.wake(); + st.insert_flags(Flags::IO_STOPPING); + } + Ok(Poll::Pending) => { + let flags = st.flags.get(); + + // check read buffer, if buffer is not consumed it is unlikely + // that filter will properly complete shutdown + if flags.contains(Flags::RD_PAUSED) + || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY) + { + st.dispatch_task.wake(); + st.insert_flags(Flags::IO_STOPPING); + } else { + // filter shutdown timeout + let timeout = self + .1 + .take() + .unwrap_or_else(|| sleep(st.disconnect_timeout.get())); + if timeout.poll_elapsed(cx).is_ready() { + st.dispatch_task.wake(); + st.insert_flags(Flags::IO_STOPPING); + } else { + self.1.set(Some(timeout)); + } + } + } + Err(err) => { + st.io_stopped(Some(err)); + } + } + if let Err(err) = filter.process_write_buf(&self.0, &st.buffer, 0) { + st.io_stopped(Some(err)); + } + } } #[derive(Debug)] @@ -212,41 +259,13 @@ impl WriteContext { where T: AsyncWrite, { - let inner = &self.0 .0; - let mut delay = None; let mut buf = WriteContextBuf { io: self.0.clone(), buf: None, }; loop { - // check readiness - let result = if let Some(ref mut sleep) = delay { - let result = match select(sleep, self.ready()).await { - Either::Left(_) => { - self.close(Some(io::Error::new( - io::ErrorKind::TimedOut, - "Operation timedout", - ))); - return; - } - Either::Right(res) => res, - }; - delay = None; - result - } else { - self.ready().await - }; - - // running - let mut flags = inner.flags.get(); - if flags.contains(Flags::WR_PAUSED) { - flags.remove(Flags::WR_PAUSED); - inner.flags.set(flags); - } - - // handle write - match result { + match self.ready().await { WriteStatus::Ready => { // write io stream match select(io.write(&mut buf), self.when_stopped()).await { @@ -255,12 +274,7 @@ impl WriteContext { Either::Right(_) => return, } } - WriteStatus::Timeout(time) => { - log::trace!("{}: Initiate timeout delay for {:?}", self.tag(), time); - delay = Some(sleep(time)); - continue; - } - WriteStatus::Shutdown(time) => { + WriteStatus::Shutdown => { log::trace!("{}: Write task is instructed to shutdown", self.tag()); let fut = async { @@ -270,7 +284,7 @@ impl WriteContext { io.shutdown().await?; Ok(()) }; - match select(sleep(time), fut).await { + match select(sleep(self.0 .0.disconnect_timeout.get()), fut).await { Either::Left(_) => self.close(None), Either::Right(res) => self.close(res.err()), } @@ -328,34 +342,3 @@ impl WriteContextBuf { } } } - -fn shutdown_filters(io: &IoRef) { - let st = &io.0; - let flags = st.flags.get(); - - if !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) { - let filter = io.filter(); - match filter.shutdown(io, &st.buffer, 0) { - Ok(Poll::Ready(())) => { - st.dispatch_task.wake(); - st.insert_flags(Flags::IO_STOPPING); - } - Ok(Poll::Pending) => { - // check read buffer, if buffer is not consumed it is unlikely - // that filter will properly complete shutdown - if flags.contains(Flags::RD_PAUSED) - || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY) - { - st.dispatch_task.wake(); - st.insert_flags(Flags::IO_STOPPING); - } - } - Err(err) => { - st.io_stopped(Some(err)); - } - } - if let Err(err) = filter.process_write_buf(io, &st.buffer, 0) { - st.io_stopped(Some(err)); - } - } -}