Skip to content

Commit

Permalink
Update buffer service (#452)
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 authored Nov 3, 2024
1 parent a301471 commit c303d02
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 48 deletions.
24 changes: 11 additions & 13 deletions ntex-service/src/and_then.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ mod tests {

use crate::{chain, chain_factory, fn_factory, Service, ServiceCtx};

#[derive(Clone)]
#[derive(Debug, Clone)]
struct Srv1(Rc<Cell<usize>>, Rc<Cell<usize>>);

impl Service<&'static str> for Srv1 {
Expand Down Expand Up @@ -123,7 +123,7 @@ mod tests {
}
}

#[derive(Clone)]
#[derive(Debug, Clone)]
struct Srv2(Rc<Cell<usize>>, Rc<Cell<usize>>);

impl Service<&'static str> for Srv2 {
Expand Down Expand Up @@ -157,12 +157,10 @@ mod tests {
async fn test_ready() {
let cnt = Rc::new(Cell::new(0));
let cnt_sht = Rc::new(Cell::new(0));
let srv = Box::new(
chain(Srv1(cnt.clone(), cnt_sht.clone()))
.and_then(Srv2(cnt.clone(), cnt_sht.clone()))
.clone(),
)
.into_pipeline();
let srv = chain(Box::new(Srv1(cnt.clone(), cnt_sht.clone())))
.clone()
.and_then(crate::boxed::service(Srv2(cnt.clone(), cnt_sht.clone())))
.into_pipeline();
let res = srv.ready().await;
assert_eq!(res, Ok(()));
assert_eq!(cnt.get(), 2);
Expand All @@ -176,6 +174,8 @@ mod tests {

srv.shutdown().await;
assert_eq!(cnt_sht.get(), 2);

assert!(format!("{:?}", srv).contains("AndThen"));
}

#[ntex::test]
Expand All @@ -194,11 +194,9 @@ mod tests {
#[ntex::test]
async fn test_call() {
let cnt = Rc::new(Cell::new(0));
let srv = Box::new(
chain(Srv1(cnt.clone(), Rc::new(Cell::new(0))))
.and_then(Srv2(cnt, Rc::new(Cell::new(0)))),
)
.into_pipeline();
let srv = chain(Box::new(Srv1(cnt.clone(), Rc::new(Cell::new(0)))))
.and_then(Srv2(cnt, Rc::new(Cell::new(0))))
.into_pipeline();
let res = srv.call("srv1").await;
assert!(res.is_ok());
assert_eq!(res.unwrap(), ("srv1", "srv2"));
Expand Down
2 changes: 2 additions & 0 deletions ntex-util/src/future/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ where
poll_fn(|cx| Pin::new(&mut *stream).poll_next(cx)).await
}

#[doc(hidden)]
#[deprecated]
/// A future that completes after the given item has been fully processed
/// into the sink, including flushing.
pub async fn sink_write<S, I>(sink: &mut S, item: I) -> Result<(), S::Error>
Expand Down
108 changes: 76 additions & 32 deletions ntex-util/src/services/buffer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Service that buffers incomming requests.
use std::cell::{Cell, RefCell};
use std::task::{ready, Poll};
use std::task::{ready, Poll, Waker};
use std::{collections::VecDeque, fmt, future::poll_fn, marker::PhantomData};

use ntex_service::{Middleware, Pipeline, PipelineBinding, Service, ServiceCtx};
Expand Down Expand Up @@ -70,11 +70,13 @@ where
fn create(&self, service: S) -> Self::Service {
BufferService {
service: Pipeline::new(service).bind(),
service_pending: Cell::new(true),
size: self.buf_size,
ready: Cell::new(false),
buf: RefCell::new(VecDeque::with_capacity(self.buf_size)),
next_call: RefCell::default(),
cancel_on_shutdown: self.cancel_on_shutdown,
readiness: Cell::new(None),
_t: PhantomData,
}
}
Expand Down Expand Up @@ -111,10 +113,12 @@ impl<E: std::fmt::Display + std::fmt::Debug> std::error::Error for BufferService
pub struct BufferService<R, S: Service<R>> {
size: usize,
ready: Cell<bool>,
service_pending: Cell<bool>,
service: PipelineBinding<S, R>,
buf: RefCell<VecDeque<oneshot::Sender<oneshot::Sender<()>>>>,
next_call: RefCell<Option<oneshot::Receiver<()>>>,
cancel_on_shutdown: bool,
readiness: Cell<Option<Waker>>,
_t: PhantomData<R>,
}

Expand All @@ -127,10 +131,12 @@ where
Self {
size,
service: Pipeline::new(service).bind(),
service_pending: Cell::new(true),
ready: Cell::new(false),
buf: RefCell::new(VecDeque::with_capacity(size)),
next_call: RefCell::default(),
cancel_on_shutdown: false,
readiness: Cell::new(None),
_t: PhantomData,
}
}
Expand All @@ -152,9 +158,11 @@ where
size: self.size,
ready: Cell::new(false),
service: self.service.clone(),
service_pending: Cell::new(false),
buf: RefCell::new(VecDeque::with_capacity(self.size)),
next_call: RefCell::default(),
cancel_on_shutdown: self.cancel_on_shutdown,
readiness: Cell::new(None),
_t: PhantomData,
}
}
Expand All @@ -170,6 +178,7 @@ where
.field("cancel_on_shutdown", &self.cancel_on_shutdown)
.field("ready", &self.ready)
.field("service", &self.service)
.field("service_pending", &self.service_pending)
.field("buf", &self.buf)
.field("next_call", &self.next_call)
.finish()
Expand All @@ -185,58 +194,79 @@ where
type Error = BufferServiceError<S::Error>;

async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
// hold advancement until the last released task either makes a call or is dropped
let next_call = self.next_call.borrow_mut().take();
if let Some(next_call) = next_call {
let _ = next_call.recv().await;
}

poll_fn(|cx| {
let mut buffer = self.buf.borrow_mut();
let mut next_call = self.next_call.borrow_mut();
if let Some(next_call) = &*next_call {
// hold advancement until the last released task either makes a call or is dropped
let _ = ready!(next_call.poll_recv(cx));
}
next_call.take();

// handle inner service readiness
if self.service.poll_ready(cx)?.is_pending() {
if buffer.len() < self.size {
// buffer next request
self.ready.set(false);
return Poll::Ready(Ok(()));
self.service_pending.set(false);
Poll::Ready(Ok(()))
} else {
log::trace!("Buffer limit exceeded");
return Poll::Pending;
// service is not ready
self.service_pending.set(true);
let _ = self.readiness.take().map(|w| w.wake());
Poll::Pending
}
}
} else {
self.service_pending.set(false);

while let Some(sender) = buffer.pop_front() {
let (next_call_tx, next_call_rx) = oneshot::channel();
if sender.send(next_call_tx).is_err()
|| next_call_rx.poll_recv(cx).is_ready()
{
// the task is gone
continue;
while let Some(sender) = buffer.pop_front() {
let (next_call_tx, next_call_rx) = oneshot::channel();
if sender.send(next_call_tx).is_err()
|| next_call_rx.poll_recv(cx).is_ready()
{
// the task is gone
continue;
}
self.next_call.borrow_mut().replace(next_call_rx);
self.ready.set(false);
return Poll::Ready(Ok(()));
}
next_call.replace(next_call_rx);
self.ready.set(false);
return Poll::Ready(Ok(()));
}

self.ready.set(true);
Poll::Ready(Ok(()))
self.ready.set(true);
Poll::Ready(Ok(()))
}
})
.await
}

async fn not_ready(&self) {
let fut = poll_fn(|cx| {
if self.service_pending.get() {
Poll::Ready(())
} else {
self.readiness.set(Some(cx.waker().clone()));
Poll::Pending
}
});

crate::future::select(fut, self.service.get_ref().not_ready()).await;
}

async fn shutdown(&self) {
// hold advancement until the last released task either makes a call or is dropped
let next_call = self.next_call.borrow_mut().take();
if let Some(next_call) = next_call {
let _ = next_call.recv().await;
}

poll_fn(|cx| {
let mut buffer = self.buf.borrow_mut();
if self.cancel_on_shutdown {
buffer.clear();
} else if !buffer.is_empty() {
let mut next_call = self.next_call.borrow_mut();
if let Some(next_call) = &*next_call {
// hold advancement until the last released task either makes a call or is dropped
let _ = ready!(next_call.poll_recv(cx));
}
next_call.take();
}

if !buffer.is_empty() {
if ready!(self.service.poll_ready(cx)).is_err() {
log::error!(
"Buffered inner service failed while buffer flushing on shutdown"
Expand All @@ -252,7 +282,7 @@ where
// the task is gone
continue;
}
next_call.replace(next_call_rx);
self.next_call.borrow_mut().replace(next_call_rx);
if buffer.is_empty() {
break;
}
Expand Down Expand Up @@ -299,9 +329,10 @@ mod tests {
use crate::future::lazy;
use crate::task::LocalWaker;

#[derive(Clone)]
#[derive(Debug, Clone)]
struct TestService(Rc<Inner>);

#[derive(Debug)]
struct Inner {
ready: Cell<bool>,
waker: LocalWaker,
Expand Down Expand Up @@ -342,6 +373,7 @@ mod tests {
let srv =
Pipeline::new(BufferService::new(2, TestService(inner.clone())).clone()).bind();
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending);

let srv1 = srv.clone();
ntex::rt::spawn(async move {
Expand All @@ -350,6 +382,7 @@ mod tests {
crate::time::sleep(Duration::from_millis(25)).await;
assert_eq!(inner.count.get(), 0);
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending);

let srv1 = srv.clone();
ntex::rt::spawn(async move {
Expand All @@ -358,17 +391,20 @@ mod tests {
crate::time::sleep(Duration::from_millis(25)).await;
assert_eq!(inner.count.get(), 0);
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);
assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Ready(()));

inner.ready.set(true);
inner.waker.wake();
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending);

crate::time::sleep(Duration::from_millis(25)).await;
assert_eq!(inner.count.get(), 1);

inner.ready.set(true);
inner.waker.wake();
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending);

crate::time::sleep(Duration::from_millis(25)).await;
assert_eq!(inner.count.get(), 2);
Expand All @@ -381,10 +417,18 @@ mod tests {

let srv = Pipeline::new(BufferService::new(2, TestService(inner.clone()))).bind();
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending);

let _ = srv.call(()).await;
assert_eq!(inner.count.get(), 1);
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending);
assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());

let err = BufferServiceError::from("test");
assert!(format!("{}", err).contains("test"));
assert!(format!("{:?}", srv).contains("BufferService"));
assert!(format!("{:?}", Buffer::<TestService>::default()).contains("Buffer"));
}

#[ntex_macros::rt_test2]
Expand Down
6 changes: 3 additions & 3 deletions ntex-util/src/services/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ mod tests {

use super::*;

#[derive(Clone)]
#[derive(Debug, Clone)]
struct Srv1;

impl Service<()> for Srv1 {
Expand All @@ -275,7 +275,7 @@ mod tests {
}
}

#[derive(Clone)]
#[derive(Debug, Clone)]
struct Srv2;

impl Service<()> for Srv2 {
Expand Down Expand Up @@ -303,9 +303,9 @@ mod tests {
.clone()
.v3(fn_factory(|| async { Ok::<_, ()>(Srv2) }))
.clone();
assert!(format!("{:?}", factory).contains("Variant"));

let service = factory.pipeline(&()).await.unwrap().clone();
assert!(format!("{:?}", service).contains("Variant"));

let mut f = pin::pin!(service.not_ready());
let _ = poll_fn(|cx| {
Expand Down

0 comments on commit c303d02

Please sign in to comment.