refactor: add auto reconnect to curp client #972

Open · wants to merge 12 commits into master
70 changes: 16 additions & 54 deletions crates/curp/src/client/mod.rs
@@ -23,7 +23,7 @@ mod tests;

#[cfg(madsim)]
use std::sync::atomic::AtomicU64;
use std::{collections::HashMap, fmt::Debug, ops::Deref, sync::Arc, time::Duration};
use std::{collections::HashMap, fmt::Debug, ops::Deref, sync::Arc};

use async_trait::async_trait;
use curp_external_api::cmd::Command;
@@ -163,7 +163,7 @@ impl Drop for ProposeIdGuard<'_> {
#[async_trait]
trait RepeatableClientApi: ClientApi {
/// Generate a unique propose id during the retry process.
fn gen_propose_id(&self) -> Result<ProposeIdGuard<'_>, Self::Error>;
async fn gen_propose_id(&self) -> Result<ProposeIdGuard<'_>, Self::Error>;

/// Send propose to the whole cluster; set `use_fast_path` to `false` to fall back to ordered
/// requests (even if the requests are commutative).
@@ -392,51 +392,23 @@ impl ClientBuilder {
})
}

/// Wait for client id
async fn wait_for_client_id(&self, state: Arc<state::State>) -> Result<(), tonic::Status> {
/// Max retry count for waiting for a client ID
///
/// TODO: This retry count is set relatively high to avoid test cluster startup timeouts.
/// We should consider setting this to a more reasonable value.
const RETRY_COUNT: usize = 30;
/// The interval for each retry
const RETRY_INTERVAL: Duration = Duration::from_secs(1);

for _ in 0..RETRY_COUNT {
if state.client_id() != 0 {
return Ok(());
}
debug!("waiting for client_id");
tokio::time::sleep(RETRY_INTERVAL).await;
}

Err(tonic::Status::deadline_exceeded(
"timeout waiting for client id",
))
}

/// Build the client
///
/// # Errors
///
/// Return `tonic::transport::Error` for connection failure.
#[inline]
pub async fn build<C: Command>(
pub fn build<C: Command>(
&self,
) -> Result<impl ClientApi<Error = tonic::Status, Cmd = C> + Send + Sync + 'static, tonic::Status>
{
let state = Arc::new(
self.init_state_builder()
.build()
.await
.map_err(|e| tonic::Status::internal(e.to_string()))?,
);
let state = Arc::new(self.init_state_builder().build());
let client = Retry::new(
Unary::new(Arc::clone(&state), self.init_unary_config()),
self.init_retry_config(),
Some(self.spawn_bg_tasks(Arc::clone(&state))),
);
self.wait_for_client_id(state).await?;

Ok(client)
}

@@ -447,31 +419,23 @@ impl ClientBuilder {
///
/// Return `tonic::transport::Error` for connection failure.
#[inline]
pub async fn build_with_client_id<C: Command>(
#[must_use]
pub fn build_with_client_id<C: Command>(
&self,
) -> Result<
(
impl ClientApi<Error = tonic::Status, Cmd = C> + Send + Sync + 'static,
Arc<AtomicU64>,
),
tonic::Status,
> {
let state = Arc::new(
self.init_state_builder()
.build()
.await
.map_err(|e| tonic::Status::internal(e.to_string()))?,
);
) -> (
impl ClientApi<Error = tonic::Status, Cmd = C> + Send + Sync + 'static,
Arc<AtomicU64>,
) {
let state = Arc::new(self.init_state_builder().build());

let client = Retry::new(
Unary::new(Arc::clone(&state), self.init_unary_config()),
self.init_retry_config(),
Some(self.spawn_bg_tasks(Arc::clone(&state))),
);
let client_id = state.clone_client_id();
self.wait_for_client_id(state).await?;

Ok((client, client_id))
(client, client_id)
}
}

@@ -482,22 +446,20 @@ impl<P: Protocol> ClientBuilderWithBypass<P> {
///
/// Return `tonic::transport::Error` for connection failure.
#[inline]
pub async fn build<C: Command>(
pub fn build<C: Command>(
self,
) -> Result<impl ClientApi<Error = tonic::Status, Cmd = C>, tonic::Status> {
let state = self
.inner
.init_state_builder()
.build_bypassed::<P>(self.local_server_id, self.local_server)
.await
.map_err(|e| tonic::Status::internal(e.to_string()))?;
.build_bypassed::<P>(self.local_server_id, self.local_server);
let state = Arc::new(state);
let client = Retry::new(
Unary::new(Arc::clone(&state), self.inner.init_unary_config()),
self.inner.init_retry_config(),
Some(self.inner.spawn_bg_tasks(Arc::clone(&state))),
);
self.inner.wait_for_client_id(state).await?;

Ok(client)
}
}
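
The net effect in this file is that `ClientBuilder::build` and `ClientBuilderWithBypass::build` become synchronous: state construction no longer dials connections, and the wait for a client ID moves out of the builder into the first proposal. A minimal caller sketch, assuming a hypothetical constructor and command type (`ClientBuilder::new(config, true)` and `MyCommand` are illustrative stand-ins, not the crate's exact API):

    // Hypothetical names: `config` and `MyCommand` stand in for the real
    // curp client config and Command implementation.
    let builder = ClientBuilder::new(config, true);

    // `build` is now sync: nothing is awaited, no connection is dialed, and
    // it can no longer fail with a deadline while waiting for a client ID.
    let client = builder.build::<MyCommand>()?;

    // The first proposal calls `gen_propose_id`, which waits (via the
    // bounded poll loop shown in state.rs below) until the background
    // keep-alive task has obtained a non-zero client ID.
    let response = client.propose(&cmd, None, true).await?;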
8 changes: 4 additions & 4 deletions crates/curp/src/client/retry.rs
@@ -231,7 +231,7 @@
use_fast_path: bool,
) -> Result<ProposeResponse<Self::Cmd>, tonic::Status> {
self.retry::<_, _>(|client| async move {
let propose_id = self.inner.gen_propose_id()?;
let propose_id = self.inner.gen_propose_id().await?;
RepeatableClientApi::propose(client, *propose_id, cmd, token, use_fast_path).await
})
.await
@@ -245,7 +245,7 @@
self.retry::<_, _>(|client| {
let changes_c = changes.clone();
async move {
let propose_id = self.inner.gen_propose_id()?;
let propose_id = self.inner.gen_propose_id().await?;
RepeatableClientApi::propose_conf_change(client, *propose_id, changes_c).await
}
})
@@ -255,7 +255,7 @@
/// Send propose to shutdown cluster
async fn propose_shutdown(&self) -> Result<(), tonic::Status> {
self.retry::<_, _>(|client| async move {
let propose_id = self.inner.gen_propose_id()?;
let propose_id = self.inner.gen_propose_id().await?;
RepeatableClientApi::propose_shutdown(client, *propose_id).await
})
.await
@@ -272,7 +272,7 @@
let name_c = node_name.clone();
let node_client_urls_c = node_client_urls.clone();
async move {
let propose_id = self.inner.gen_propose_id()?;
let propose_id = self.inner.gen_propose_id().await?;
RepeatableClientApi::propose_publish(
client,
*propose_id,
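
Each retryable entry point now awaits `gen_propose_id` inside the retried closure, so time spent waiting for a client ID is charged to the retry budget rather than to client construction. A self-contained sketch of that shape, assuming a simplified retry helper (the real `Retry::retry` additionally carries backoff configuration and error classification):

    use std::time::Duration;

    /// Simplified model: the closure is re-created on every attempt, and
    /// the (now async) propose-id generation is awaited inside it.
    /// Assumes `attempts >= 1`.
    async fn retry_sketch<F, Fut, T, E>(
        mut op: F,
        attempts: usize,
        backoff: Duration,
    ) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, E>>,
    {
        let mut last_err = None;
        for _ in 0..attempts {
            match op().await {
                Ok(v) => return Ok(v),
                Err(e) => last_err = Some(e),
            }
            tokio::time::sleep(backoff).await;
        }
        Err(last_err.expect("attempts must be >= 1"))
    }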
55 changes: 38 additions & 17 deletions crates/curp/src/client/state.rs
@@ -95,7 +95,8 @@ impl State {
tls_config,
is_raw_curp: true,
},
client_id: Arc::new(AtomicU64::new(0)),
// Sets the client id to non-zero to avoid waiting for client id in tests
client_id: Arc::new(AtomicU64::new(1)),
})
}

@@ -146,8 +147,8 @@ impl State {
};
let resp = rand_conn
.fetch_cluster(FetchClusterRequest::default(), REFRESH_TIMEOUT)
.await?;
self.check_and_update(&resp.into_inner()).await?;
.await;
self.check_and_update(&resp?.into_inner()).await?;
Ok(())
}

@@ -327,7 +328,7 @@ impl State {
.remove(&diff)
.unwrap_or_else(|| unreachable!("{diff} must be in new member addrs"));
debug!("client connects to a new server({diff}), address({addrs:?})");
let new_conn = rpc::connect(diff, addrs, self.immutable.tls_config.clone()).await?;
let new_conn = rpc::connect(diff, addrs, self.immutable.tls_config.clone());
let _ig = e.insert(new_conn);
} else {
debug!("client removes old server({diff})");
@@ -347,6 +348,30 @@

Ok(())
}

/// Wait for client id
pub(super) async fn wait_for_client_id(&self) -> Result<u64, tonic::Status> {
/// Max retry count for waiting for a client ID
///
/// TODO: This retry count is set relatively high to avoid test cluster startup timeouts.
/// We should consider setting this to a more reasonable value.
const RETRY_COUNT: usize = 30;
/// The interval for each retry
const RETRY_INTERVAL: Duration = Duration::from_secs(1);

for _ in 0..RETRY_COUNT {
let client_id = self.client_id();
if client_id != 0 {
return Ok(client_id);
}
debug!("waiting for client_id");
tokio::time::sleep(RETRY_INTERVAL).await;
}

Err(tonic::Status::deadline_exceeded(
"timeout waiting for client id",
))
}
}

/// Builder for state
@@ -395,24 +420,22 @@ impl StateBuilder {
}

/// Build the state with local server
pub(super) async fn build_bypassed<P: Protocol>(
pub(super) fn build_bypassed<P: Protocol>(
mut self,
local_server_id: ServerId,
local_server: P,
) -> Result<State, tonic::transport::Error> {
) -> State {
debug!("client bypassed server({local_server_id})");

let _ig = self.all_members.remove(&local_server_id);
let mut connects: HashMap<_, _> =
rpc::connects(self.all_members.clone(), self.tls_config.as_ref())
.await?
.collect();
rpc::connects(self.all_members.clone(), self.tls_config.as_ref()).collect();
let __ig = connects.insert(
local_server_id,
Arc::new(BypassedConnect::new(local_server_id, local_server)),
);

Ok(State {
State {
mutable: RwLock::new(StateMut {
leader: self.leader_state.map(|state| state.0),
term: self.leader_state.map_or(0, |state| state.1),
Expand All @@ -426,16 +449,14 @@ impl StateBuilder {
is_raw_curp: self.is_raw_curp,
},
client_id: Arc::new(AtomicU64::new(0)),
})
}
}

/// Build the state
pub(super) async fn build(self) -> Result<State, tonic::transport::Error> {
pub(super) fn build(self) -> State {
let connects: HashMap<_, _> =
rpc::connects(self.all_members.clone(), self.tls_config.as_ref())
.await?
.collect();
Ok(State {
rpc::connects(self.all_members.clone(), self.tls_config.as_ref()).collect();
State {
mutable: RwLock::new(StateMut {
leader: self.leader_state.map(|state| state.0),
term: self.leader_state.map_or(0, |state| state.1),
Expand All @@ -449,6 +470,6 @@ impl StateBuilder {
is_raw_curp: self.is_raw_curp,
},
client_id: Arc::new(AtomicU64::new(0)),
})
}
}
}
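
Both builders can drop `async` because connection setup no longer performs I/O: with auto reconnect, `rpc::connect`/`rpc::connects` can hand back lazily-connecting channels. A minimal sketch of that idea using tonic directly (the crate's real `rpc::connect` wraps more than this, so treat the signature as illustrative):

    use tonic::transport::{Channel, Endpoint};

    // `connect_lazy` performs no I/O here: the returned channel dials on
    // first use and re-dials after a disconnect, which is the property
    // that lets the builders above become synchronous.
    fn connect_sketch(addr: &'static str) -> Channel {
        Endpoint::from_static(addr).connect_lazy()
    }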
4 changes: 2 additions & 2 deletions crates/curp/src/client/tests.rs
@@ -749,7 +749,7 @@ async fn test_stream_client_keep_alive_works() {
Box::pin(async move {
client_id
.compare_exchange(
0,
1,
10,
std::sync::atomic::Ordering::Relaxed,
std::sync::atomic::Ordering::Relaxed,
@@ -773,7 +773,7 @@ async fn test_stream_client_keep_alive_on_redirect() {
Box::pin(async move {
client_id
.compare_exchange(
0,
1,
10,
std::sync::atomic::Ordering::Relaxed,
std::sync::atomic::Ordering::Relaxed,
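
The expected value flips from 0 to 1 because the test-only `State` constructor now seeds the client ID with 1 (see state.rs above); a `compare_exchange` still expecting 0 would fail and the mocked keep-alive would never install the new ID. The semantics, in a standalone sketch:

    use std::sync::atomic::{AtomicU64, Ordering};

    fn main() {
        let client_id = AtomicU64::new(1); // test constructor now seeds 1, not 0
        // Expecting the old initial value fails and leaves the ID untouched...
        assert!(client_id
            .compare_exchange(0, 10, Ordering::Relaxed, Ordering::Relaxed)
            .is_err());
        // ...while expecting the new initial value succeeds.
        assert!(client_id
            .compare_exchange(1, 10, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok());
        assert_eq!(client_id.load(Ordering::Relaxed), 10);
    }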
15 changes: 9 additions & 6 deletions crates/curp/src/client/unary.rs
@@ -200,7 +200,7 @@ impl<C: Command> ClientApi for Unary<C> {
token: Option<&String>,
use_fast_path: bool,
) -> Result<ProposeResponse<C>, CurpError> {
let propose_id = self.gen_propose_id()?;
let propose_id = self.gen_propose_id().await?;
RepeatableClientApi::propose(self, *propose_id, cmd, token, use_fast_path).await
}

@@ -209,13 +209,13 @@ impl<C: Command> ClientApi for Unary<C> {
&self,
changes: Vec<ConfChange>,
) -> Result<Vec<Member>, CurpError> {
let propose_id = self.gen_propose_id()?;
let propose_id = self.gen_propose_id().await?;
RepeatableClientApi::propose_conf_change(self, *propose_id, changes).await
}

/// Send propose to shutdown cluster
async fn propose_shutdown(&self) -> Result<(), CurpError> {
let propose_id = self.gen_propose_id()?;
let propose_id = self.gen_propose_id().await?;
RepeatableClientApi::propose_shutdown(self, *propose_id).await
}

@@ -226,7 +226,7 @@ impl<C: Command> ClientApi for Unary<C> {
node_name: String,
node_client_urls: Vec<String>,
) -> Result<(), Self::Error> {
let propose_id = self.gen_propose_id()?;
let propose_id = self.gen_propose_id().await?;
RepeatableClientApi::propose_publish(
self,
*propose_id,
@@ -372,8 +372,11 @@
#[async_trait]
impl<C: Command> RepeatableClientApi for Unary<C> {
/// Generate a unique propose id during the retry process.
fn gen_propose_id(&self) -> Result<ProposeIdGuard<'_>, Self::Error> {
let client_id = self.state.client_id();
async fn gen_propose_id(&self) -> Result<ProposeIdGuard<'_>, Self::Error> {
let mut client_id = self.state.client_id();
if client_id == 0 {
client_id = self.state.wait_for_client_id().await?;
};
let seq_num = self.new_seq_num();
Ok(ProposeIdGuard::new(
&self.tracker,
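
`gen_propose_id` is the single choke point: a proposal issued before the keep-alive stream has granted an ID no longer fails immediately, it parks on the bounded poll loop from `State::wait_for_client_id`. A standalone model of that behaviour, with the retry count and interval copied from state.rs (`current_or_wait` is an illustrative name, not the crate's API):

    use std::sync::{atomic::{AtomicU64, Ordering}, Arc};
    use std::time::Duration;

    // Returns the current client ID, or polls (up to 30 × 1s, matching
    // `State::wait_for_client_id`) until the background task sets one.
    async fn current_or_wait(id: Arc<AtomicU64>) -> Result<u64, &'static str> {
        for _ in 0..30 {
            let v = id.load(Ordering::Relaxed);
            if v != 0 {
                return Ok(v);
            }
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
        Err("timeout waiting for client id")
    }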
2 changes: 0 additions & 2 deletions crates/curp/src/members.rs
@@ -439,8 +439,6 @@ pub async fn get_cluster_info_from_remote(
let peers = init_cluster_info.peers_addrs();
let self_client_urls = init_cluster_info.self_client_urls();
let connects = rpc::connects(peers, tls_config)
.await
.ok()?
.map(|pair| pair.1)
.collect_vec();
let mut futs = connects