Use shard from query plan during execution
Co-authored-by: Wojciech Przytuła <[email protected]>
Lorak-mmk and wprzytula committed Mar 12, 2024
1 parent ae069e8 commit 48aa642
Showing 4 changed files with 33 additions and 79 deletions.
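In short: the load balancing policy's query plan already yields (node, shard) pairs, and after this commit execution consumes that shard directly instead of re-deriving it from the statement's token on every node. A minimal before/after sketch of the idea, assuming simplified signatures (the types Node, Token, Shard, Connection and QueryError are the driver's own; the real methods are crate-private):

    // Before: execution paths carried a closure that re-derived the target
    // shard from the statement's token, or fell back to a random connection.
    async fn pick_connection_before(
        node: &Node,
        token: Option<Token>,
    ) -> Result<Arc<Connection>, QueryError> {
        match token {
            Some(token) => node.connection_for_token(token).await,
            None => node.random_connection().await,
        }
    }

    // After: the (node, shard) pair comes straight from the query plan,
    // so execution simply asks the node for a connection to that shard.
    async fn pick_connection_after(
        node: &Node,
        shard: Shard,
    ) -> Result<Arc<Connection>, QueryError> {
        node.connection_for_shard(shard).await
    }
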
27 changes: 15 additions & 12 deletions scylla/src/transport/connection_pool.rs
@@ -1,7 +1,7 @@
#[cfg(feature = "cloud")]
use crate::cloud::set_ssl_config_for_scylla_cloud_host;

use crate::routing::{Shard, ShardCount, Sharder, Token};
use crate::routing::{Shard, ShardCount, Sharder};
use crate::transport::errors::QueryError;
use crate::transport::{
connection,
@@ -28,7 +28,7 @@ use std::time::Duration;

use tokio::sync::{broadcast, mpsc, Notify};
use tracing::instrument::WithSubscriber;
use tracing::{debug, trace, warn};
use tracing::{debug, error, trace, warn};

/// The target size of a per-node connection pool.
#[derive(Debug, Clone, Copy)]
@@ -235,22 +235,25 @@ impl NodeConnectionPool
.unwrap_or(None)
}

pub(crate) fn connection_for_token(&self, token: Token) -> Result<Arc<Connection>, QueryError> {
trace!(token = token.value, "Selecting connection for token");
pub(crate) fn connection_for_shard(&self, shard: Shard) -> Result<Arc<Connection>, QueryError> {
trace!(shard = shard, "Selecting connection for shard");
self.with_connections(|pool_conns| match pool_conns {
PoolConnections::NotSharded(conns) => {
Self::choose_random_connection_from_slice(conns).unwrap()
}
PoolConnections::Sharded {
sharder,
connections,
sharder
} => {
let shard: u16 = sharder
.shard_of(token)
let shard = shard
.try_into()
.expect("Shard number doesn't fit in u16");
trace!(shard = shard, "Selecting connection for token");
Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice())
// It's safer to use 0 rather than panic here, as shards are returned by `LoadBalancingPolicy`
// now, which can be implemented by a user in an arbitrary way.
.unwrap_or_else(|_| {
error!("The provided shard number: {} does not fit u16! Using 0 as the shard number. Check your LoadBalancingPolicy implementation.", shard);
0
});
Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
}
})
}
@@ -266,13 +269,13 @@ impl NodeConnectionPool
connections,
} => {
let shard: u16 = rand::thread_rng().gen_range(0..sharder.nr_shards.get());
Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice())
Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
}
})
}

// Tries to get a connection to the given shard; if it's broken, returns any working connection
fn connection_for_shard(
fn connection_for_shard_helper(
shard: u16,
nr_shards: ShardCount,
shard_conns: &[Vec<Arc<Connection>>],
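The noteworthy detail in the hunk above is the defensive narrowing: the shard supplied by the (possibly user-implemented) LoadBalancingPolicy is converted to the u16 used to index the per-shard connection lists, and an out-of-range value now logs an error and falls back to shard 0 instead of panicking. A standalone sketch of that conversion, assuming Shard is a u32 alias and using eprintln! as a stand-in for tracing::error!:

    type Shard = u32; // stand-in for the driver's shard type

    fn clamp_shard_to_u16(shard: Shard) -> u16 {
        u16::try_from(shard).unwrap_or_else(|_| {
            // A custom LoadBalancingPolicy may return an arbitrary value,
            // so defaulting to shard 0 is safer than panicking inside the pool.
            eprintln!(
                "The provided shard number: {} does not fit u16! Using 0 as the shard number.",
                shard
            );
            0
        })
    }
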
29 changes: 6 additions & 23 deletions scylla/src/transport/iterator.rs
@@ -35,7 +35,7 @@ use crate::transport::connection::{Connection, NonErrorQueryResponse, QueryRespo
use crate::transport::load_balancing::{self, RoutingInfo};
use crate::transport::metrics::Metrics;
use crate::transport::retry_policy::{QueryInfo, RetryDecision, RetrySession};
use crate::transport::{Node, NodeRef};
use crate::transport::NodeRef;
use tracing::{trace, trace_span, warn, Instrument};
use uuid::Uuid;

@@ -160,8 +160,6 @@ impl RowIterator {
let worker_task = async move {
let query_ref = &query;

let choose_connection = |node: Arc<Node>| async move { node.random_connection().await };

let page_query = |connection: Arc<Connection>,
consistency: Consistency,
paging_state: Option<Bytes>| {
@@ -187,7 +185,6 @@

let worker = RowIteratorWorker {
sender: sender.into(),
choose_connection,
page_query,
statement_info: routing_info,
query_is_idempotent: query.config.is_idempotent,
@@ -259,13 +256,6 @@ impl RowIterator {
is_confirmed_lwt: config.prepared.is_confirmed_lwt(),
};

let choose_connection = |node: Arc<Node>| async move {
match token {
Some(token) => node.connection_for_token(token).await,
None => node.random_connection().await,
}
};

let page_query = |connection: Arc<Connection>,
consistency: Consistency,
paging_state: Option<Bytes>| async move {
@@ -311,7 +301,6 @@ impl RowIterator {

let worker = RowIteratorWorker {
sender: sender.into(),
choose_connection,
page_query,
statement_info,
query_is_idempotent: config.prepared.config.is_idempotent,
@@ -496,13 +485,9 @@ type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, QueryError

// RowIteratorWorker works in the background to fetch pages
// RowIterator receives them through a channel
struct RowIteratorWorker<'a, ConnFunc, QueryFunc, SpanCreatorFunc> {
struct RowIteratorWorker<'a, QueryFunc, SpanCreatorFunc> {
sender: ProvingSender<Result<ReceivedPage, QueryError>>,

// Closure used to choose a connection from a node
// AsyncFn(Arc<Node>) -> Result<Arc<Connection>, QueryError>
choose_connection: ConnFunc,

// Closure used to perform a single page query
// AsyncFn(Arc<Connection>, Option<Bytes>) -> Result<QueryResponse, QueryError>
page_query: QueryFunc,
@@ -524,11 +509,8 @@ struct RowIteratorWorker<'a, ConnFunc, QueryFunc, SpanCreatorFunc> {
span_creator: SpanCreatorFunc,
}

impl<ConnFunc, ConnFut, QueryFunc, QueryFut, SpanCreator>
RowIteratorWorker<'_, ConnFunc, QueryFunc, SpanCreator>
impl<QueryFunc, QueryFut, SpanCreator> RowIteratorWorker<'_, QueryFunc, SpanCreator>
where
ConnFunc: Fn(Arc<Node>) -> ConnFut,
ConnFut: Future<Output = Result<Arc<Connection>, QueryError>>,
QueryFunc: Fn(Arc<Connection>, Consistency, Option<Bytes>) -> QueryFut,
QueryFut: Future<Output = Result<QueryResponse, QueryError>>,
SpanCreator: Fn() -> RequestSpan,
@@ -546,12 +528,13 @@ where

self.log_query_start();

'nodes_in_plan: for (node, _shard) in query_plan {
'nodes_in_plan: for (node, shard) in query_plan {
let span =
trace_span!(parent: &self.parent_span, "Executing query", node = %node.address);
// For each node in the plan choose a connection to use
// This connection will be reused for same node retries to preserve paging cache on the shard
let connection: Arc<Connection> = match (self.choose_connection)(node.clone())
let connection: Arc<Connection> = match node
.connection_for_shard(shard)
.instrument(span.clone())
.await
{
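The net effect on the worker type is one generic parameter fewer: the connection-choosing closure and its ConnFunc/ConnFut bounds disappear, leaving only the page-query and span-creation closures. Roughly, with the unrelated fields elided (a sketch, not the full struct):

    // Before: the worker owned a closure for picking a connection per node,
    // bounded as Fn(Arc<Node>) -> ConnFut where ConnFut resolves to
    // Result<Arc<Connection>, QueryError>.
    struct WorkerBefore<ConnFunc, QueryFunc, SpanCreatorFunc> {
        choose_connection: ConnFunc,
        page_query: QueryFunc,
        span_creator: SpanCreatorFunc,
    }

    // After: the shard from the query plan makes that closure unnecessary;
    // the worker calls node.connection_for_shard(shard) directly in its loop.
    struct WorkerAfter<QueryFunc, SpanCreatorFunc> {
        page_query: QueryFunc,
        span_creator: SpanCreatorFunc,
    }
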
17 changes: 6 additions & 11 deletions scylla/src/transport/node.rs
@@ -3,7 +3,7 @@ use tracing::warn;
use uuid::Uuid;

/// Node represents a cluster node along with its data and connections
use crate::routing::{Sharder, Token};
use crate::routing::{Shard, Sharder};
use crate::transport::connection::Connection;
use crate::transport::connection::VerifiedKeyspaceName;
use crate::transport::connection_pool::{NodeConnectionPool, PoolConfig};
@@ -152,18 +152,13 @@ impl Node {
self.pool.as_ref()?.sharder()
}

/// Get connection which should be used to connect using given token
/// If this connection is broken get any random connection to this Node
pub(crate) async fn connection_for_token(
/// Get a connection targeting the given shard
/// If such a connection is broken, get any random connection to this `Node`
pub(crate) async fn connection_for_shard(
&self,
token: Token,
shard: Shard,
) -> Result<Arc<Connection>, QueryError> {
self.get_pool()?.connection_for_token(token)
}

/// Get random connection
pub(crate) async fn random_connection(&self) -> Result<Arc<Connection>, QueryError> {
self.get_pool()?.random_connection()
self.get_pool()?.connection_for_shard(shard)
}

pub fn is_down(&self) -> bool {
39 changes: 6 additions & 33 deletions scylla/src/transport/session.rs
@@ -655,7 +655,6 @@ impl Session {
statement_info,
&query.config,
execution_profile,
|node: Arc<Node>| async move { node.random_connection().await },
|connection: Arc<Connection>,
consistency: Consistency,
execution_profile: &ExecutionProfileInner| {
@@ -1024,12 +1023,6 @@ impl Session {
statement_info,
&prepared.config,
execution_profile,
|node: Arc<Node>| async move {
match token {
Some(token) => node.connection_for_token(token).await,
None => node.random_connection().await,
}
},
|connection: Arc<Connection>,
consistency: Consistency,
execution_profile: &ExecutionProfileInner| {
@@ -1236,14 +1229,6 @@ impl Session {
statement_info,
&batch.config,
execution_profile,
|node: Arc<Node>| async move {
match first_value_token {
Some(first_value_token) => {
node.connection_for_token(first_value_token).await
}
None => node.random_connection().await,
}
},
|connection: Arc<Connection>,
consistency: Consistency,
execution_profile: &ExecutionProfileInner| {
@@ -1507,28 +1492,23 @@
}

// This method allows to easily run a query using load balancing, retry policy etc.
// Requires some information about the query and two closures
// First closure is used to choose a connection
// - query will use node.random_connection()
// - execute will use node.connection_for_token()
// The second closure is used to do the query itself on a connection
// Requires some information about the query and a closure.
// The closure is used to do the query itself on a connection.
// - query will use connection.query()
// - execute will use connection.execute()
// If this query closure fails with some errors retry policy is used to perform retries
// On success this query's result is returned
// I tried to make these closures take a reference instead of an Arc but failed
// maybe once async closures get stabilized this can be fixed
async fn run_query<'a, ConnFut, QueryFut, ResT>(
async fn run_query<'a, QueryFut, ResT>(
&'a self,
statement_info: RoutingInfo<'a>,
statement_config: &'a StatementConfig,
execution_profile: Arc<ExecutionProfileInner>,
choose_connection: impl Fn(Arc<Node>) -> ConnFut,
do_query: impl Fn(Arc<Connection>, Consistency, &ExecutionProfileInner) -> QueryFut,
request_span: &'a RequestSpan,
) -> Result<RunQueryResult<ResT>, QueryError>
where
ConnFut: Future<Output = Result<Arc<Connection>, QueryError>>,
QueryFut: Future<Output = Result<ResT, QueryError>>,
ResT: AllowedRunQueryResTType,
{
@@ -1602,7 +1582,6 @@

self.execute_query(
&shared_query_plan,
&choose_connection,
&do_query,
&execution_profile,
ExecuteQueryContext {
@@ -1638,7 +1617,6 @@
});
self.execute_query(
query_plan,
&choose_connection,
&do_query,
&execution_profile,
ExecuteQueryContext {
@@ -1684,16 +1662,14 @@
result
}

async fn execute_query<'a, ConnFut, QueryFut, ResT>(
async fn execute_query<'a, QueryFut, ResT>(
&'a self,
query_plan: impl Iterator<Item = (NodeRef<'a>, Shard)>,
choose_connection: impl Fn(Arc<Node>) -> ConnFut,
do_query: impl Fn(Arc<Connection>, Consistency, &ExecutionProfileInner) -> QueryFut,
execution_profile: &ExecutionProfileInner,
mut context: ExecuteQueryContext<'a>,
) -> Option<Result<RunQueryResult<ResT>, QueryError>>
where
ConnFut: Future<Output = Result<Arc<Connection>, QueryError>>,
QueryFut: Future<Output = Result<ResT, QueryError>>,
ResT: AllowedRunQueryResTType,
{
@@ -1702,14 +1678,11 @@
.consistency_set_on_statement
.unwrap_or(execution_profile.consistency);

'nodes_in_plan: for (node, _shard) in query_plan {
'nodes_in_plan: for (node, shard) in query_plan {
let span = trace_span!("Executing query", node = %node.address);
'same_node_retries: loop {
trace!(parent: &span, "Execution started");
let connection: Arc<Connection> = match choose_connection(node.clone())
.instrument(span.clone())
.await
{
let connection = match node.connection_for_shard(shard).await {
Ok(connection) => connection,
Err(e) => {
trace!(
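The session path mirrors the iterator path: a connection for the planned shard is obtained once per plan entry and reused across same-node retries. A simplified sketch of the resulting control flow, with the consistency handling, retry policy, and do_query closure elided:

    // Sketch of the loop structure only; the real method also threads through
    // the execution profile, retry session, tracing spans and metrics.
    async fn walk_plan<'a>(query_plan: impl Iterator<Item = (&'a Node, Shard)>) {
        'nodes_in_plan: for (node, shard) in query_plan {
            'same_node_retries: loop {
                let connection = match node.connection_for_shard(shard).await {
                    Ok(connection) => connection,
                    // No usable connection on this node: move on in the plan.
                    Err(_) => continue 'nodes_in_plan,
                };
                // ... run do_query(connection, consistency, execution_profile) here;
                // a retryable failure would `continue 'same_node_retries`,
                // success would return the query result ...
                let _ = connection;
                break 'same_node_retries;
            }
        }
    }
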
