Skip to content

Commit

Permalink
Adds a circuit breaker layer. (#5134)
Browse files Browse the repository at this point in the history
The piece that estimates whether the next request is likely to fail is extremely simplistic for the moment.
It simply counter the number of errors (not taking in account successes) that happened in a given time window.

The reason is that for the moment, we want to use it for persist requests when the WAL is full.
On airmail, the aggressive retry logic of the client was causing a massive grpc storm on the faulty indexer node,
taking all of its CPU and preventing it from getting out of that state.

In this case, the error estimation logic is very simple, a full WAL guarantees that no further persist request will be successful for a little while.
  • Loading branch information
fulmicoton authored Jul 5, 2024
1 parent b373552 commit 1fba1d1
Show file tree
Hide file tree
Showing 5 changed files with 429 additions and 9 deletions.
3 changes: 2 additions & 1 deletion quickwit/quickwit-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ siphasher = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-metrics ={ workspace = true }
tokio-metrics = { workspace = true }
tokio-stream = { workspace = true }
tonic = { workspace = true }
tower = { workspace = true }
Expand All @@ -51,3 +51,4 @@ named_tasks = ["tokio/tracing"]
serde_json = { workspace = true }
tempfile = { workspace = true }
proptest = { workspace = true }
tokio = { workspace = true, features = ["test-util"] }
372 changes: 372 additions & 0 deletions quickwit/quickwit-common/src/tower/circuit_breaker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,372 @@
// Copyright (C) 2024 Quickwit, Inc.
//
// Quickwit is offered under the AGPL v3.0 and as commercial software.
// For commercial licensing, contact us at [email protected].
//
// AGPL:
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;

use pin_project::pin_project;
use prometheus::IntCounter;
use tokio::time::Instant;
use tower::{Layer, Service};

/// The circuit breaker layer implements the [circuit breaker pattern](https://martinfowler.com/bliki/CircuitBreaker.html).
///
/// It counts the errors emitted by the inner service, and if the number of errors exceeds a certain
/// threshold within a certain time window, it will "open" the circuit.
///
/// Requests will then be rejected for a given timeout.
/// After this timeout, the circuit breaker ends up in a HalfOpen state. It will allow a single
/// request to pass through. Depending on the result of this request, the circuit breaker will
/// either close the circuit again or open it again.
///
/// Implementation detail:
///
/// A circuit breaker needs to have some logic to estimate the chances for the next request
/// to fail. In this implementation, we use a simple heuristic that does not take in account
/// successes. We simply count the number or errors which happened in the last window.
///
/// The circuit breaker does not attempt to measure accurately the error rate.
/// Instead, it counts errors, and check for the time window in which these errors occurred.
/// This approach is accurate enough, robust, very easy to code and avoids calling the
/// `Instant::now()` at every error in the open state.
#[derive(Debug, Clone)]
pub struct CircuitBreakerLayer<Evaluator> {
max_error_count_per_time_window: u32,
time_window: Duration,
timeout: Duration,
evaluator: Evaluator,
circuit_break_total: prometheus::IntCounter,
}

pub trait CircuitBreakerEvaluator: Clone {
type Response;
type Error;
fn is_circuit_breaker_error(&self, output: &Result<Self::Response, Self::Error>) -> bool;
fn make_circuit_breaker_output(&self) -> Self::Error;
fn make_layer(
self,
max_num_errors_per_secs: u32,
timeout: Duration,
circuit_break_total: prometheus::IntCounter,
) -> CircuitBreakerLayer<Self> {
CircuitBreakerLayer {
max_error_count_per_time_window: max_num_errors_per_secs,
time_window: Duration::from_secs(1),
timeout,
evaluator: self,
circuit_break_total,
}
}
}

impl<S, Evaluator: CircuitBreakerEvaluator> Layer<S> for CircuitBreakerLayer<Evaluator> {
type Service = CircuitBreaker<S, Evaluator>;

fn layer(&self, service: S) -> CircuitBreaker<S, Evaluator> {
let time_window = Duration::from_millis(self.time_window.as_millis() as u64);
let timeout = Duration::from_millis(self.timeout.as_millis() as u64);
CircuitBreaker {
underlying: service,
circuit_breaker_inner: Arc::new(Mutex::new(CircuitBreakerInner {
max_error_count_per_time_window: self.max_error_count_per_time_window,
time_window,
timeout,
state: CircuitBreakerState::Closed(ClosedState {
error_counter: 0u32,
error_window_end: Instant::now() + time_window,
}),
evaluator: self.evaluator.clone(),
circuit_break_total: self.circuit_break_total.clone(),
})),
}
}
}

struct CircuitBreakerInner<Evaluator> {
max_error_count_per_time_window: u32,
time_window: Duration,
timeout: Duration,
evaluator: Evaluator,
state: CircuitBreakerState,
circuit_break_total: IntCounter,
}

impl<Evaluator> CircuitBreakerInner<Evaluator> {
fn get_state(&mut self) -> CircuitBreakerState {
let new_state = match self.state {
CircuitBreakerState::Open { until } => {
let now = Instant::now();
if now < until {
CircuitBreakerState::Open { until }
} else {
CircuitBreakerState::HalfOpen
}
}
other => other,
};
self.state = new_state;
new_state
}

fn receive_error(&mut self) {
match self.state {
CircuitBreakerState::HalfOpen => {
self.circuit_break_total.inc();
self.state = CircuitBreakerState::Open {
until: Instant::now() + self.timeout,
}
}
CircuitBreakerState::Open { .. } => {}
CircuitBreakerState::Closed(ClosedState {
error_counter,
error_window_end,
}) => {
if error_counter < self.max_error_count_per_time_window {
self.state = CircuitBreakerState::Closed(ClosedState {
error_counter: error_counter + 1,
error_window_end,
});
return;
}
let now = Instant::now();
if now < error_window_end {
self.circuit_break_total.inc();
self.state = CircuitBreakerState::Open {
until: now + self.timeout,
};
} else {
self.state = CircuitBreakerState::Closed(ClosedState {
error_counter: 0u32,
error_window_end: now + self.time_window,
});
}
}
}
}

fn receive_success(&mut self) {
match self.state {
CircuitBreakerState::HalfOpen | CircuitBreakerState::Open { .. } => {
self.state = CircuitBreakerState::Closed(ClosedState {
error_counter: 0u32,
error_window_end: Instant::now() + self.time_window,
});
}
CircuitBreakerState::Closed { .. } => {
// We could actually take that as a signal.
}
}
}
}

#[derive(Clone)]
pub struct CircuitBreaker<S, Evaluator> {
underlying: S,
circuit_breaker_inner: Arc<Mutex<CircuitBreakerInner<Evaluator>>>,
}

impl<S, Evaluator> std::fmt::Debug for CircuitBreaker<S, Evaluator> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("CircuitBreaker").finish()
}
}

#[derive(Debug, Clone, Copy)]
enum CircuitBreakerState {
Open { until: Instant },
HalfOpen,
Closed(ClosedState),
}

#[derive(Debug, Clone, Copy)]
struct ClosedState {
error_counter: u32,
error_window_end: Instant,
}

impl<S, R, Evaluator> Service<R> for CircuitBreaker<S, Evaluator>
where
S: Service<R>,
Evaluator: CircuitBreakerEvaluator<Response = S::Response, Error = S::Error>,
{
type Response = S::Response;
type Error = S::Error;
type Future = CircuitBreakerFuture<S::Future, Evaluator>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut inner = self.circuit_breaker_inner.lock().unwrap();
let state = inner.get_state();
match state {
CircuitBreakerState::Closed { .. } | CircuitBreakerState::HalfOpen => {
self.underlying.poll_ready(cx)
}
CircuitBreakerState::Open { .. } => {
let circuit_break_error = inner.evaluator.make_circuit_breaker_output();
Poll::Ready(Err(circuit_break_error))
}
}
}

fn call(&mut self, request: R) -> Self::Future {
CircuitBreakerFuture {
underlying_fut: self.underlying.call(request),
circuit_breaker_inner: self.circuit_breaker_inner.clone(),
}
}
}

#[pin_project]
pub struct CircuitBreakerFuture<F, Evaluator> {
#[pin]
underlying_fut: F,
circuit_breaker_inner: Arc<Mutex<CircuitBreakerInner<Evaluator>>>,
}

impl<Response, Error, F, Evaluator> Future for CircuitBreakerFuture<F, Evaluator>
where
F: Future<Output = Result<Response, Error>>,
Evaluator: CircuitBreakerEvaluator<Response = Response, Error = Error>,
{
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let circuit_breaker_inner = self.circuit_breaker_inner.clone();
let poll_res = self.project().underlying_fut.poll(cx);
match poll_res {
Poll::Pending => Poll::Pending,
Poll::Ready(result) => {
let mut circuit_breaker_inner_lock = circuit_breaker_inner.lock().unwrap();
let is_circuit_breaker_error = circuit_breaker_inner_lock
.evaluator
.is_circuit_breaker_error(&result);
if is_circuit_breaker_error {
circuit_breaker_inner_lock.receive_error();
} else {
circuit_breaker_inner_lock.receive_success();
}
Poll::Ready(result)
}
}
}
}

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicBool, Ordering};

use tower::{ServiceBuilder, ServiceExt};

use super::*;

#[derive(Debug)]
enum TestError {
CircuitBreak,
ServiceError,
}

#[derive(Debug, Clone, Copy)]
struct TestCircuitBreakerEvaluator;

impl CircuitBreakerEvaluator for TestCircuitBreakerEvaluator {
type Response = ();
type Error = TestError;

fn is_circuit_breaker_error(&self, output: &Result<Self::Response, Self::Error>) -> bool {
output.is_err()
}

fn make_circuit_breaker_output(&self) -> TestError {
TestError::CircuitBreak
}
}

#[tokio::test]
async fn test_circuit_breaker() {
tokio::time::pause();
let test_switch: Arc<AtomicBool> = Arc::new(AtomicBool::new(true));

const TIMEOUT: Duration = Duration::from_millis(500);

let int_counter: prometheus::IntCounter =
IntCounter::new("circuit_break_total_test", "test circuit breaker counter").unwrap();
let mut service = ServiceBuilder::new()
.layer(TestCircuitBreakerEvaluator.make_layer(10, TIMEOUT, int_counter))
.service_fn(|_| async {
if test_switch.load(Ordering::Relaxed) {
Ok(())
} else {
Err(TestError::ServiceError)
}
});

service.ready().await.unwrap().call(()).await.unwrap();

for _ in 0..1_000 {
service.ready().await.unwrap().call(()).await.unwrap();
}

test_switch.store(false, Ordering::Relaxed);

let mut service_error_count = 0;
let mut circuit_break_count = 0;
for _ in 0..1_000 {
match service.ready().await {
Ok(service) => {
service.call(()).await.unwrap_err();
service_error_count += 1;
}
Err(_circuit_breaker_error) => {
circuit_break_count += 1;
}
}
}

assert_eq!(service_error_count + circuit_break_count, 1_000);
assert_eq!(service_error_count, 11);

tokio::time::advance(TIMEOUT).await;

// The test request at half open fails.
for _ in 0..1_000 {
match service.ready().await {
Ok(service) => {
service.call(()).await.unwrap_err();
service_error_count += 1;
}
Err(_circuit_breaker_error) => {
circuit_break_count += 1;
}
}
}

assert_eq!(service_error_count + circuit_break_count, 2_000);
assert_eq!(service_error_count, 12);

test_switch.store(true, Ordering::Relaxed);
tokio::time::advance(TIMEOUT).await;

// The test request at half open succeeds.
for _ in 0..1_000 {
service.ready().await.unwrap().call(()).await.unwrap();
}
}
}
Loading

0 comments on commit 1fba1d1

Please sign in to comment.