diff --git a/scylla/src/transport/connection_pool.rs b/scylla/src/transport/connection_pool.rs index f26ea36ac..11c2d671e 100644 --- a/scylla/src/transport/connection_pool.rs +++ b/scylla/src/transport/connection_pool.rs @@ -1,7 +1,7 @@ #[cfg(feature = "cloud")] use crate::cloud::set_ssl_config_for_scylla_cloud_host; -use crate::routing::{Shard, ShardCount, Sharder, Token}; +use crate::routing::{Shard, ShardCount, Sharder}; use crate::transport::errors::QueryError; use crate::transport::{ connection, @@ -28,7 +28,7 @@ use std::time::Duration; use tokio::sync::{broadcast, mpsc, Notify}; use tracing::instrument::WithSubscriber; -use tracing::{debug, trace, warn}; +use tracing::{debug, error, trace, warn}; /// The target size of a per-node connection pool. #[derive(Debug, Clone, Copy)] @@ -235,22 +235,25 @@ impl NodeConnectionPool { .unwrap_or(None) } - pub(crate) fn connection_for_token(&self, token: Token) -> Result, QueryError> { - trace!(token = token.value, "Selecting connection for token"); + pub(crate) fn connection_for_shard(&self, shard: Shard) -> Result, QueryError> { + trace!(shard = shard, "Selecting connection for shard"); self.with_connections(|pool_conns| match pool_conns { PoolConnections::NotSharded(conns) => { Self::choose_random_connection_from_slice(conns).unwrap() } PoolConnections::Sharded { - sharder, connections, + sharder } => { - let shard: u16 = sharder - .shard_of(token) + let shard = shard .try_into() - .expect("Shard number doesn't fit in u16"); - trace!(shard = shard, "Selecting connection for token"); - Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice()) + // It's safer to use 0 rather that panic here, as shards are returned by `LoadBalancingPolicy` + // now, which can be implemented by a user in an arbitrary way. + .unwrap_or_else(|_| { + error!("The provided shard number: {} does not fit u16! Using 0 as the shard number. Check your LoadBalancingPolicy implementation.", shard); + 0 + }); + Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice()) } }) } @@ -266,13 +269,13 @@ impl NodeConnectionPool { connections, } => { let shard: u16 = rand::thread_rng().gen_range(0..sharder.nr_shards.get()); - Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice()) + Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice()) } }) } // Tries to get a connection to given shard, if it's broken returns any working connection - fn connection_for_shard( + fn connection_for_shard_helper( shard: u16, nr_shards: ShardCount, shard_conns: &[Vec>], diff --git a/scylla/src/transport/iterator.rs b/scylla/src/transport/iterator.rs index 070c07a78..f277f5b64 100644 --- a/scylla/src/transport/iterator.rs +++ b/scylla/src/transport/iterator.rs @@ -35,7 +35,7 @@ use crate::transport::connection::{Connection, NonErrorQueryResponse, QueryRespo use crate::transport::load_balancing::{self, RoutingInfo}; use crate::transport::metrics::Metrics; use crate::transport::retry_policy::{QueryInfo, RetryDecision, RetrySession}; -use crate::transport::{Node, NodeRef}; +use crate::transport::NodeRef; use tracing::{trace, trace_span, warn, Instrument}; use uuid::Uuid; @@ -160,8 +160,6 @@ impl RowIterator { let worker_task = async move { let query_ref = &query; - let choose_connection = |node: Arc| async move { node.random_connection().await }; - let page_query = |connection: Arc, consistency: Consistency, paging_state: Option| { @@ -187,7 +185,6 @@ impl RowIterator { let worker = RowIteratorWorker { sender: sender.into(), - choose_connection, page_query, statement_info: routing_info, query_is_idempotent: query.config.is_idempotent, @@ -259,13 +256,6 @@ impl RowIterator { is_confirmed_lwt: config.prepared.is_confirmed_lwt(), }; - let choose_connection = |node: Arc| async move { - match token { - Some(token) => node.connection_for_token(token).await, - None => node.random_connection().await, - } - }; - let page_query = |connection: Arc, consistency: Consistency, paging_state: Option| async move { @@ -311,7 +301,6 @@ impl RowIterator { let worker = RowIteratorWorker { sender: sender.into(), - choose_connection, page_query, statement_info, query_is_idempotent: config.prepared.config.is_idempotent, @@ -496,13 +485,9 @@ type PageSendAttemptedProof = SendAttemptedProof { +struct RowIteratorWorker<'a, QueryFunc, SpanCreatorFunc> { sender: ProvingSender>, - // Closure used to choose a connection from a node - // AsyncFn(Arc) -> Result, QueryError> - choose_connection: ConnFunc, - // Closure used to perform a single page query // AsyncFn(Arc, Option) -> Result page_query: QueryFunc, @@ -524,11 +509,8 @@ struct RowIteratorWorker<'a, ConnFunc, QueryFunc, SpanCreatorFunc> { span_creator: SpanCreatorFunc, } -impl - RowIteratorWorker<'_, ConnFunc, QueryFunc, SpanCreator> +impl RowIteratorWorker<'_, QueryFunc, SpanCreator> where - ConnFunc: Fn(Arc) -> ConnFut, - ConnFut: Future, QueryError>>, QueryFunc: Fn(Arc, Consistency, Option) -> QueryFut, QueryFut: Future>, SpanCreator: Fn() -> RequestSpan, @@ -546,12 +528,13 @@ where self.log_query_start(); - 'nodes_in_plan: for (node, _shard) in query_plan { + 'nodes_in_plan: for (node, shard) in query_plan { let span = trace_span!(parent: &self.parent_span, "Executing query", node = %node.address); // For each node in the plan choose a connection to use // This connection will be reused for same node retries to preserve paging cache on the shard - let connection: Arc = match (self.choose_connection)(node.clone()) + let connection: Arc = match node + .connection_for_shard(shard) .instrument(span.clone()) .await { diff --git a/scylla/src/transport/node.rs b/scylla/src/transport/node.rs index 97b267946..07c34e130 100644 --- a/scylla/src/transport/node.rs +++ b/scylla/src/transport/node.rs @@ -3,7 +3,7 @@ use tracing::warn; use uuid::Uuid; /// Node represents a cluster node along with it's data and connections -use crate::routing::{Sharder, Token}; +use crate::routing::{Shard, Sharder}; use crate::transport::connection::Connection; use crate::transport::connection::VerifiedKeyspaceName; use crate::transport::connection_pool::{NodeConnectionPool, PoolConfig}; @@ -152,18 +152,13 @@ impl Node { self.pool.as_ref()?.sharder() } - /// Get connection which should be used to connect using given token - /// If this connection is broken get any random connection to this Node - pub(crate) async fn connection_for_token( + /// Get a connection targetting the given shard + /// If such connection is broken, get any random connection to this `Node` + pub(crate) async fn connection_for_shard( &self, - token: Token, + shard: Shard, ) -> Result, QueryError> { - self.get_pool()?.connection_for_token(token) - } - - /// Get random connection - pub(crate) async fn random_connection(&self) -> Result, QueryError> { - self.get_pool()?.random_connection() + self.get_pool()?.connection_for_shard(shard) } pub fn is_down(&self) -> bool { diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index 090bb9ef8..721b2af1d 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -655,7 +655,6 @@ impl Session { statement_info, &query.config, execution_profile, - |node: Arc| async move { node.random_connection().await }, |connection: Arc, consistency: Consistency, execution_profile: &ExecutionProfileInner| { @@ -1024,12 +1023,6 @@ impl Session { statement_info, &prepared.config, execution_profile, - |node: Arc| async move { - match token { - Some(token) => node.connection_for_token(token).await, - None => node.random_connection().await, - } - }, |connection: Arc, consistency: Consistency, execution_profile: &ExecutionProfileInner| { @@ -1236,14 +1229,6 @@ impl Session { statement_info, &batch.config, execution_profile, - |node: Arc| async move { - match first_value_token { - Some(first_value_token) => { - node.connection_for_token(first_value_token).await - } - None => node.random_connection().await, - } - }, |connection: Arc, consistency: Consistency, execution_profile: &ExecutionProfileInner| { @@ -1507,28 +1492,23 @@ impl Session { } // This method allows to easily run a query using load balancing, retry policy etc. - // Requires some information about the query and two closures - // First closure is used to choose a connection - // - query will use node.random_connection() - // - execute will use node.connection_for_token() - // The second closure is used to do the query itself on a connection + // Requires some information about the query and a closure. + // The closure is used to do the query itself on a connection. // - query will use connection.query() // - execute will use connection.execute() // If this query closure fails with some errors retry policy is used to perform retries // On success this query's result is returned // I tried to make this closures take a reference instead of an Arc but failed // maybe once async closures get stabilized this can be fixed - async fn run_query<'a, ConnFut, QueryFut, ResT>( + async fn run_query<'a, QueryFut, ResT>( &'a self, statement_info: RoutingInfo<'a>, statement_config: &'a StatementConfig, execution_profile: Arc, - choose_connection: impl Fn(Arc) -> ConnFut, do_query: impl Fn(Arc, Consistency, &ExecutionProfileInner) -> QueryFut, request_span: &'a RequestSpan, ) -> Result, QueryError> where - ConnFut: Future, QueryError>>, QueryFut: Future>, ResT: AllowedRunQueryResTType, { @@ -1602,7 +1582,6 @@ impl Session { self.execute_query( &shared_query_plan, - &choose_connection, &do_query, &execution_profile, ExecuteQueryContext { @@ -1638,7 +1617,6 @@ impl Session { }); self.execute_query( query_plan, - &choose_connection, &do_query, &execution_profile, ExecuteQueryContext { @@ -1684,16 +1662,14 @@ impl Session { result } - async fn execute_query<'a, ConnFut, QueryFut, ResT>( + async fn execute_query<'a, QueryFut, ResT>( &'a self, query_plan: impl Iterator, Shard)>, - choose_connection: impl Fn(Arc) -> ConnFut, do_query: impl Fn(Arc, Consistency, &ExecutionProfileInner) -> QueryFut, execution_profile: &ExecutionProfileInner, mut context: ExecuteQueryContext<'a>, ) -> Option, QueryError>> where - ConnFut: Future, QueryError>>, QueryFut: Future>, ResT: AllowedRunQueryResTType, { @@ -1702,14 +1678,11 @@ impl Session { .consistency_set_on_statement .unwrap_or(execution_profile.consistency); - 'nodes_in_plan: for (node, _shard) in query_plan { + 'nodes_in_plan: for (node, shard) in query_plan { let span = trace_span!("Executing query", node = %node.address); 'same_node_retries: loop { trace!(parent: &span, "Execution started"); - let connection: Arc = match choose_connection(node.clone()) - .instrument(span.clone()) - .await - { + let connection = match node.connection_for_shard(shard).await { Ok(connection) => connection, Err(e) => { trace!(