diff --git a/scylla-cql/src/frame/mod.rs b/scylla-cql/src/frame/mod.rs index cec55d140f..715ba43984 100644 --- a/scylla-cql/src/frame/mod.rs +++ b/scylla-cql/src/frame/mod.rs @@ -48,15 +48,21 @@ pub enum Compression { Snappy, } -impl Display for Compression { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Compression { + pub fn as_str(&self) -> &'static str { match self { - Compression::Lz4 => f.write_str("lz4"), - Compression::Snappy => f.write_str("snappy"), + Compression::Lz4 => "lz4", + Compression::Snappy => "snappy", } } } +impl Display for Compression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + pub struct SerializedRequest { data: Vec, } diff --git a/scylla-cql/src/frame/protocol_features.rs b/scylla-cql/src/frame/protocol_features.rs index a1687485c3..fd142ac647 100644 --- a/scylla-cql/src/frame/protocol_features.rs +++ b/scylla-cql/src/frame/protocol_features.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::collections::HashMap; const RATE_LIMIT_ERROR_EXTENSION: &str = "SCYLLA_RATE_LIMIT_ERROR"; @@ -51,19 +52,19 @@ impl ProtocolFeatures { .find_map(|v| v.as_str().strip_prefix(key)?.strip_prefix('=')) } - pub fn add_startup_options(&self, options: &mut HashMap) { + pub fn add_startup_options(&self, options: &mut HashMap, Cow<'_, str>>) { if self.rate_limit_error.is_some() { - options.insert(RATE_LIMIT_ERROR_EXTENSION.to_string(), String::new()); + options.insert(Cow::Borrowed(RATE_LIMIT_ERROR_EXTENSION), Cow::Borrowed("")); } if let Some(mask) = self.lwt_optimization_meta_bit_mask { options.insert( - SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION.to_string(), - format!("{}={}", LWT_OPTIMIZATION_META_BIT_MASK_KEY, mask), + Cow::Borrowed(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION), + Cow::Owned(format!("{}={}", LWT_OPTIMIZATION_META_BIT_MASK_KEY, mask)), ); } if self.tablets_v1_supported { - options.insert(TABLETS_ROUTING_V1_KEY.to_string(), String::new()); + options.insert(Cow::Borrowed(TABLETS_ROUTING_V1_KEY), Cow::Borrowed("")); } } diff --git a/scylla-cql/src/frame/request/options.rs b/scylla-cql/src/frame/request/options.rs index 5efdada0c6..6ea6517ce6 100644 --- a/scylla-cql/src/frame/request/options.rs +++ b/scylla-cql/src/frame/request/options.rs @@ -11,3 +11,20 @@ impl SerializableRequest for Options { Ok(()) } } + +/* Key names for options in SUPPORTED/STARTUP */ +pub const SCYLLA_SHARD_AWARE_PORT: &str = "SCYLLA_SHARD_AWARE_PORT"; +pub const SCYLLA_SHARD_AWARE_PORT_SSL: &str = "SCYLLA_SHARD_AWARE_PORT_SSL"; + +pub const COMPRESSION: &str = "COMPRESSION"; +pub const CQL_VERSION: &str = "CQL_VERSION"; +pub const DRIVER_NAME: &str = "DRIVER_NAME"; +pub const DRIVER_VERSION: &str = "DRIVER_VERSION"; +pub const APPLICATION_NAME: &str = "APPLICATION_NAME"; +pub const APPLICATION_VERSION: &str = "APPLICATION_VERSION"; +pub const CLIENT_ID: &str = "CLIENT_ID"; + +/* Value names for options in SUPPORTED/STARTUP */ +pub const DEFAULT_CQL_PROTOCOL_VERSION: &str = "4.0.0"; +pub const DEFAULT_DRIVER_NAME: &str = "ScyllaDB Rust Driver"; +pub const DEFAULT_DRIVER_VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/scylla-cql/src/frame/request/startup.rs b/scylla-cql/src/frame/request/startup.rs index a1d41df5c4..6759d0cfce 100644 --- a/scylla-cql/src/frame/request/startup.rs +++ b/scylla-cql/src/frame/request/startup.rs @@ -1,17 +1,17 @@ use crate::frame::frame_errors::ParseError; -use std::collections::HashMap; +use std::{borrow::Cow, collections::HashMap}; use crate::{ frame::request::{RequestOpcode, SerializableRequest}, frame::types, }; -pub struct Startup { - pub options: HashMap, +pub struct Startup<'a> { + pub options: HashMap, Cow<'a, str>>, } -impl SerializableRequest for Startup { +impl SerializableRequest for Startup<'_> { const OPCODE: RequestOpcode = RequestOpcode::Startup; fn serialize(&self, buf: &mut Vec) -> Result<(), ParseError> { diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 77497cd37a..0cc791be31 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -428,14 +428,14 @@ pub fn read_string_map( } pub fn write_string_map( - v: &HashMap, + v: &HashMap, impl AsRef>, buf: &mut impl BufMut, ) -> Result<(), std::num::TryFromIntError> { let len = v.len(); write_short_length(len, buf)?; for (key, val) in v.iter() { - write_string(key, buf)?; - write_string(val, buf)?; + write_string(key.as_ref(), buf)?; + write_string(val.as_ref(), buf)?; } Ok(()) } diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index cdc6e730ea..52795c30d3 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -1,7 +1,7 @@ use bytes::Bytes; use futures::{future::RemoteHandle, FutureExt}; use scylla_cql::errors::TranslationError; -use scylla_cql::frame::request::options::Options; +use scylla_cql::frame::request::options::{self, Options}; use scylla_cql::frame::response::result::{ResultMetadata, TableSpec}; use scylla_cql::frame::response::Error; use scylla_cql::frame::types::SerialConsistency; @@ -350,6 +350,188 @@ mod ssl_config { } } +/// Driver and application self-identifying information, +/// to be sent in STARTUP message. +#[derive(Debug, Clone, Default)] +pub struct SelfIdentity<'id> { + // Custom driver identity can be set if a custom driver build is running, + // or an entirely different driver is operating on top of Rust driver + // (e.g. cpp-rust-driver). + custom_driver_name: Option>, + custom_driver_version: Option>, + + // ### Q: Where do APPLICATION_NAME, APPLICATION_VERSION and CLIENT_ID come from? + // - there are no columns in system.clients dedicated to those attributes, + // - APPLICATION_NAME / APPLICATION_VERSION are not present in Scylla's source code at all, + // - only 2 results in Cassandra source is some example in docs: + // https://github.com/apache/cassandra/blob/d3cbf9c1f72057d2a5da9df8ed567d20cd272931/doc/modules/cassandra/pages/managing/operating/virtualtables.adoc?plain=1#L218. + // APPLICATION_NAME and APPLICATION_VERSION appears in client_options which + // is an arbitrary dict where client can send any keys. + // - driver variables are mentioned in protocol v5 + // (https://github.com/apache/cassandra/blob/d3cbf9c1f72057d2a5da9df8ed567d20cd272931/doc/native_protocol_v5.spec#L480), + // application variables are not. + // + // ### A: + // The following options are not exposed anywhere in Scylla tables. + // They come directly from CPP driver, and they are supported in Cassandra + // + // See https://github.com/scylladb/cpp-driver/blob/fa0f27069a625057984d1fa58f434ea99b86c83f/include/cassandra.h#L2916. + // As we want to support as big subset of its API as possible in cpp-rust-driver, I decided to expose API for setting + // those particular key-value pairs, similarly to what cpp-driver does, and not an API to set arbitrary key-value pairs. + // + // Allowing users to set arbitrary options could break the driver by overwriting options that bear special meaning, + // e.g. the shard-aware port. Therefore, I'm against such liberal API. OTOH, we need to expose APPLICATION_NAME, + // APPLICATION_VERSION and CLIENT_ID for cpp-rust-driver. + + // Application identity can be set to distinguish different applications + // connected to the same cluster. + application_name: Option>, + application_version: Option>, + + // A (unique) client ID can be set to distinguish different instances + // of the same application connected to the same cluster. + client_id: Option>, +} + +impl<'id> SelfIdentity<'id> { + pub fn new() -> Self { + Self::default() + } + + /// Advertises a custom driver name, which can be used if a custom driver build is running, + /// or an entirely different driver is operating on top of Rust driver + /// (e.g. cpp-rust-driver). + pub fn set_custom_driver_name(&mut self, custom_driver_name: impl Into>) { + self.custom_driver_name = Some(custom_driver_name.into()); + } + + /// Advertises a custom driver name. See [Self::set_custom_driver_name] for use cases. + pub fn with_custom_driver_name(mut self, custom_driver_name: impl Into>) -> Self { + self.custom_driver_name = Some(custom_driver_name.into()); + self + } + + /// Custom driver name to be advertised. See [Self::set_custom_driver_name] for use cases. + pub fn get_custom_driver_name(&self) -> Option<&str> { + self.custom_driver_name.as_deref() + } + + /// Advertises a custom driver version. See [Self::set_custom_driver_name] for use cases. + pub fn set_custom_driver_version(&mut self, custom_driver_version: impl Into>) { + self.custom_driver_version = Some(custom_driver_version.into()); + } + + /// Advertises a custom driver version. See [Self::set_custom_driver_name] for use cases. + pub fn with_custom_driver_version( + mut self, + custom_driver_version: impl Into>, + ) -> Self { + self.custom_driver_version = Some(custom_driver_version.into()); + self + } + + /// Custom driver version to be advertised. See [Self::set_custom_driver_version] for use cases. + pub fn get_custom_driver_version(&self) -> Option<&str> { + self.custom_driver_version.as_deref() + } + + /// Advertises an application name, which can be used to distinguish different applications + /// connected to the same cluster. + pub fn set_application_name(&mut self, application_name: impl Into>) { + self.application_name = Some(application_name.into()); + } + + /// Advertises an application name. See [Self::set_application_name] for use cases. + pub fn with_application_name(mut self, application_name: impl Into>) -> Self { + self.application_name = Some(application_name.into()); + self + } + + /// Application name to be advertised. See [Self::set_application_name] for use cases. + pub fn get_application_name(&self) -> Option<&str> { + self.application_name.as_deref() + } + + /// Advertises an application version. See [Self::set_application_name] for use cases. + pub fn set_application_version(&mut self, application_version: impl Into>) { + self.application_version = Some(application_version.into()); + } + + /// Advertises an application version. See [Self::set_application_name] for use cases. + pub fn with_application_version( + mut self, + application_version: impl Into>, + ) -> Self { + self.application_version = Some(application_version.into()); + self + } + + /// Application version to be advertised. See [Self::set_application_version] for use cases. + pub fn get_application_version(&self) -> Option<&str> { + self.application_version.as_deref() + } + + /// Advertises a client ID, which can be set to distinguish different instances + /// of the same application connected to the same cluster. + pub fn set_client_id(&mut self, client_id: impl Into>) { + self.client_id = Some(client_id.into()); + } + + /// Advertises a client ID. See [Self::set_client_id] for use cases. + pub fn with_client_id(mut self, client_id: impl Into>) -> Self { + self.client_id = Some(client_id.into()); + self + } + + /// Client ID to be advertised. See [Self::set_client_id] for use cases. + pub fn get_client_id(&self) -> Option<&str> { + self.client_id.as_deref() + } +} + +impl<'id: 'map, 'map> SelfIdentity<'id> { + fn add_startup_options(&'id self, options: &'map mut HashMap, Cow<'id, str>>) { + /* Driver identity. */ + let driver_name = self + .custom_driver_name + .as_deref() + .unwrap_or(options::DEFAULT_DRIVER_NAME); + options.insert( + Cow::Borrowed(options::DRIVER_NAME), + Cow::Borrowed(driver_name), + ); + + let driver_version = self + .custom_driver_version + .as_deref() + .unwrap_or(options::DEFAULT_DRIVER_VERSION); + options.insert( + Cow::Borrowed(options::DRIVER_VERSION), + Cow::Borrowed(driver_version), + ); + + /* Application identity. */ + if let Some(application_name) = self.application_name.as_deref() { + options.insert( + Cow::Borrowed(options::APPLICATION_NAME), + Cow::Borrowed(application_name), + ); + } + + if let Some(application_version) = self.application_version.as_deref() { + options.insert( + Cow::Borrowed(options::APPLICATION_VERSION), + Cow::Borrowed(application_version), + ); + } + + /* Client identity. */ + if let Some(client_id) = self.client_id.as_deref() { + options.insert(Cow::Borrowed(options::CLIENT_ID), Cow::Borrowed(client_id)); + } + } +} + #[derive(Clone)] pub(crate) struct ConnectionConfig { pub(crate) compression: Option, @@ -370,6 +552,8 @@ pub(crate) struct ConnectionConfig { pub(crate) keepalive_interval: Option, pub(crate) keepalive_timeout: Option, pub(crate) tablet_sender: Option, RawTablet)>>, + + pub(crate) identity: SelfIdentity<'static>, } impl Default for ConnectionConfig { @@ -394,6 +578,8 @@ impl Default for ConnectionConfig { keepalive_timeout: None, tablet_sender: None, + + identity: SelfIdentity::default(), } } } @@ -419,6 +605,8 @@ pub(crate) type ErrorReceiver = tokio::sync::oneshot::Receiver; impl Connection { // Returns new connection and ErrorReceiver which can be used to wait for a fatal error + /// Opens a connection and makes it ready to send/receive CQL frames on it, + /// but does not yet send any frames (no OPTIONS/STARTUP handshake nor REGISTER requests). pub(crate) async fn new( addr: SocketAddr, source_port: Option, @@ -440,52 +628,7 @@ impl Connection { stream.set_nodelay(config.tcp_nodelay)?; if let Some(tcp_keepalive_interval) = config.tcp_keepalive_interval { - // It may be surprising why we call `with_time()` with `tcp_keepalive_interval` - // and `with_interval() with some other value. This is due to inconsistent naming: - // our interval means time after connection becomes idle until keepalives - // begin to be sent (they call it "time"), and their interval is time between - // sending keepalives. - // We insist on our naming due to other drivers following the same convention. - let mut tcp_keepalive = TcpKeepalive::new().with_time(tcp_keepalive_interval); - - // These cfg values are taken from socket2 library, which uses the same constraints. - #[cfg(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "ios", - target_os = "linux", - target_os = "macos", - target_os = "netbsd", - target_os = "tvos", - target_os = "watchos", - target_os = "windows", - ))] - { - tcp_keepalive = tcp_keepalive.with_interval(Duration::from_secs(1)); - } - - #[cfg(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "ios", - target_os = "linux", - target_os = "macos", - target_os = "netbsd", - target_os = "tvos", - target_os = "watchos", - ))] - { - tcp_keepalive = tcp_keepalive.with_retries(10); - } - - let sf = SockRef::from(&stream); - sf.set_tcp_keepalive(&tcp_keepalive)?; + Self::setup_tcp_keepalive(&stream, tcp_keepalive_interval)?; } // TODO: What should be the size of the channel? @@ -522,9 +665,61 @@ impl Connection { Ok((connection, error_receiver)) } + fn setup_tcp_keepalive( + stream: &TcpStream, + tcp_keepalive_interval: Duration, + ) -> std::io::Result<()> { + // It may be surprising why we call `with_time()` with `tcp_keepalive_interval` + // and `with_interval() with some other value. This is due to inconsistent naming: + // our interval means time after connection becomes idle until keepalives + // begin to be sent (they call it "time"), and their interval is time between + // sending keepalives. + // We insist on our naming due to other drivers following the same convention. + let mut tcp_keepalive = TcpKeepalive::new().with_time(tcp_keepalive_interval); + + // These cfg values are taken from socket2 library, which uses the same constraints. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "windows", + ))] + { + tcp_keepalive = tcp_keepalive.with_interval(Duration::from_secs(1)); + } + + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + ))] + { + tcp_keepalive = tcp_keepalive.with_retries(10); + } + + let sf = SockRef::from(&stream); + sf.set_tcp_keepalive(&tcp_keepalive) + } + pub(crate) async fn startup( &self, - options: HashMap, + options: HashMap, Cow<'_, str>>, ) -> Result { Ok(self .send_request(&request::Startup { options }, false, false, None) @@ -1507,38 +1702,31 @@ async fn maybe_translated_addr( } } +/// Opens a connection and performs its setup on CQL level: +/// - performs OPTIONS/STARTUP handshake (chooses desired connections options); +/// - registers for all event types using REGISTER request (if this is control connection). +/// +/// At the beginning, translates node's address, if it is subject to address translation. pub(crate) async fn open_connection( endpoint: UntranslatedEndpoint, source_port: Option, config: &ConnectionConfig, ) -> Result<(Connection, ErrorReceiver), QueryError> { + /* Translate the address, if applicable. */ let addr = maybe_translated_addr(endpoint, config.address_translator.as_deref()).await?; - open_named_connection( - addr, - source_port, - config, - Some("scylla-rust-driver".to_string()), - option_env!("CARGO_PKG_VERSION").map(|v| v.to_string()), - ) - .await -} -pub(crate) async fn open_named_connection( - addr: SocketAddr, - source_port: Option, - config: &ConnectionConfig, - driver_name: Option, - driver_version: Option, -) -> Result<(Connection, ErrorReceiver), QueryError> { - // TODO: shouldn't all this logic be in Connection::new? + /* Setup connection on TCP level and prepare for sending/receiving CQL frames. */ let (mut connection, error_receiver) = Connection::new(addr, source_port, config.clone()).await?; + /* Perform OPTIONS/SUPPORTED/STARTUP handshake. */ + + // Get OPTIONS SUPPORTED by the cluster. let options_result = connection.get_options().await?; let shard_aware_port_key = match config.is_ssl() { - true => "SCYLLA_SHARD_AWARE_PORT_SSL", - false => "SCYLLA_SHARD_AWARE_PORT", + true => options::SCYLLA_SHARD_AWARE_PORT_SSL, + false => options::SCYLLA_SHARD_AWARE_PORT, }; let mut supported = match options_result { @@ -1551,21 +1739,24 @@ pub(crate) async fn open_named_connection( } }; + // If this is ScyllaDB that we connected to, we received sharding information. let shard_info = ShardInfo::try_from(&supported.options).ok(); - let supported_compression = supported.options.remove("COMPRESSION").unwrap_or_default(); + let supported_compression = supported + .options + .remove(options::COMPRESSION) + .unwrap_or_default(); let shard_aware_port = supported .options .remove(shard_aware_port_key) .unwrap_or_default() - .into_iter() - .next() + .first() .and_then(|p| p.parse::().ok()); + // Parse nonstandard protocol extensions. let protocol_features = ProtocolFeatures::parse_from_supported(&supported.options); - let mut options = HashMap::new(); - protocol_features.add_startup_options(&mut options); - + // At the beginning, Connection assumes no sharding and no protocol extensions; + // now that we know them, let's turn them on in the driver. let features = ConnectionFeatures { shard_info, shard_aware_port, @@ -1573,24 +1764,40 @@ pub(crate) async fn open_named_connection( }; connection.set_features(features); - options.insert("CQL_VERSION".to_string(), "4.0.0".to_string()); // FIXME: hardcoded values - if let Some(name) = driver_name { - options.insert("DRIVER_NAME".to_string(), name); - } - if let Some(version) = driver_version { - options.insert("DRIVER_VERSION".to_string(), version); - } + /* Prepare options that the driver opts-in in STARTUP frame. */ + let mut options = HashMap::new(); + protocol_features.add_startup_options(&mut options); + + // The only CQL protocol version supported by the driver. + options.insert( + Cow::Borrowed(options::CQL_VERSION), + Cow::Borrowed(options::DEFAULT_CQL_PROTOCOL_VERSION), + ); + + // Application & driver's identity. + config.identity.add_startup_options(&mut options); + + // Optional compression. if let Some(compression) = &config.compression { - let compression_str = compression.to_string(); - if supported_compression.iter().any(|c| c == &compression_str) { + let compression_str = compression.as_str(); + if supported_compression.iter().any(|c| c == compression_str) { // Compression is reported to be supported by the server, // request it from the server - options.insert("COMPRESSION".to_string(), compression.to_string()); + options.insert( + Cow::Borrowed(options::COMPRESSION), + Cow::Borrowed(compression_str), + ); } else { // Fall back to no compression + tracing::warn!( + "Requested compression <{}> is not supported by the cluster. Falling back to no compression", + compression_str + ); connection.config.compression = None; } } + + /* Send the STARTUP frame with all the requested options. */ let result = connection.startup(options).await?; match result { Response::Ready => {} @@ -1605,6 +1812,7 @@ pub(crate) async fn open_named_connection( } } + /* If this is a control connection, REGISTER to receive all event types. */ if connection.config.event_sender.is_some() { let all_event_types = vec![ EventType::TopologyChange, diff --git a/scylla/src/transport/mod.rs b/scylla/src/transport/mod.rs index a33943645d..620c1fafb7 100644 --- a/scylla/src/transport/mod.rs +++ b/scylla/src/transport/mod.rs @@ -19,6 +19,7 @@ pub mod speculative_execution; pub mod topology; pub use crate::frame::{Authenticator, Compression}; +pub use connection::SelfIdentity; pub use execution_profile::ExecutionProfile; pub use scylla_cql::errors; diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index c86d326336..80a23dace8 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -46,7 +46,7 @@ use super::node::KnownNode; use super::partitioner::PartitionerName; use super::query_result::MaybeFirstRowTypedError; use super::topology::UntranslatedPeer; -use super::NodeRef; +use super::{NodeRef, SelfIdentity}; use crate::cql_to_rust::FromRow; use crate::frame::response::cql_to_rust::FromRowError; use crate::frame::response::result; @@ -289,6 +289,10 @@ pub struct SessionConfig { /// for e.g: if they do not want unexpected traffic /// or they expect the topology to change frequently. pub cluster_metadata_refresh_interval: Duration, + + /// Driver and application self-identifying information, + /// to be sent to server in STARTUP message. + pub identity: SelfIdentity<'static>, } impl SessionConfig { @@ -335,6 +339,7 @@ impl SessionConfig { tracing_info_fetch_interval: Duration::from_millis(3), tracing_info_fetch_consistency: Consistency::One, cluster_metadata_refresh_interval: Duration::from_secs(60), + identity: SelfIdentity::default(), } } @@ -515,6 +520,7 @@ impl Session { keepalive_interval: config.keepalive_interval, keepalive_timeout: config.keepalive_timeout, tablet_sender: Some(tablet_sender), + identity: config.identity, }; let pool_config = PoolConfig { diff --git a/scylla/src/transport/session_builder.rs b/scylla/src/transport/session_builder.rs index e697bb6154..998803793f 100644 --- a/scylla/src/transport/session_builder.rs +++ b/scylla/src/transport/session_builder.rs @@ -1,5 +1,6 @@ //! SessionBuilder provides an easy way to create new Sessions +use super::connection::SelfIdentity; use super::errors::NewSessionError; use super::execution_profile::ExecutionProfileHandle; use super::session::{AddressTranslator, Session, SessionConfig}; @@ -942,6 +943,38 @@ impl GenericSessionBuilder { self.config.cluster_metadata_refresh_interval = interval; self } + + /// Set the custom identity of the driver/application/instance, + /// to be sent as options in STARTUP message. + /// + /// By default driver name and version are sent; + /// application name and version and client id are not sent. + /// + /// # Example + /// ``` + /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::transport::SelfIdentity; + /// # async fn example() -> Result<(), Box> { + /// let (app_major, app_minor, app_patch) = (2, 1, 3); + /// let app_version = format!("{app_major}.{app_minor}.{app_patch}"); + /// + /// let session: Session = SessionBuilder::new() + /// .known_node("127.0.0.1:9042") + /// .custom_identity( + /// SelfIdentity::new() + /// .with_custom_driver_version("0.13.0-custom_build_17") + /// .with_application_name("my-app") + /// .with_application_version(app_version) + /// ) + /// .build() + /// .await?; + /// # Ok(()) + /// # } + /// ``` + pub fn custom_identity(mut self, identity: SelfIdentity<'static>) -> Self { + self.config.identity = identity; + self + } } /// Creates a [`SessionBuilder`] with default configuration, same as [`SessionBuilder::new`] diff --git a/scylla/tests/integration/main.rs b/scylla/tests/integration/main.rs index 06a2ab429a..ef190f1237 100644 --- a/scylla/tests/integration/main.rs +++ b/scylla/tests/integration/main.rs @@ -4,6 +4,7 @@ mod hygiene; mod lwt_optimisation; mod new_session; mod retries; +mod self_identity; mod shards; mod silent_prepare_query; mod skip_metadata_optimization; diff --git a/scylla/tests/integration/self_identity.rs b/scylla/tests/integration/self_identity.rs new file mode 100644 index 0000000000..cba46f7171 --- /dev/null +++ b/scylla/tests/integration/self_identity.rs @@ -0,0 +1,109 @@ +use crate::utils::{setup_tracing, test_with_3_node_cluster}; +use scylla::{Session, SessionBuilder}; +use scylla_cql::frame::request::options; +use scylla_cql::frame::types; +use std::sync::Arc; +use tokio::sync::mpsc; + +use scylla::transport::SelfIdentity; +use scylla_proxy::{ + Condition, ProxyError, Reaction, RequestOpcode, RequestReaction, RequestRule, ShardAwareness, + WorkerError, +}; + +#[tokio::test] +#[ntest::timeout(20000)] +#[cfg(not(scylla_cloud_tests))] +async fn self_identity_is_set_properly_in_startup_message() { + setup_tracing(); + + let application_name = "test_self_identity"; + let application_version = "42.2137.0"; + let client_id = "blue18"; + let custom_driver_name = "ScyllaDB Rust Driver - test run"; + let custom_driver_version = "2137.42.0"; + + let default_self_identity = SelfIdentity::new(); + + let full_self_identity = SelfIdentity::new() + .with_application_name(application_name) + .with_application_version(application_version) + .with_client_id(client_id) + .with_custom_driver_name(custom_driver_name) + .with_custom_driver_version(custom_driver_version); + + test_given_self_identity(default_self_identity).await; + test_given_self_identity(full_self_identity).await; +} + +async fn test_given_self_identity(self_identity: SelfIdentity<'static>) { + let res = test_with_3_node_cluster( + ShardAwareness::QueryNode, + |proxy_uris, translation_map, mut running_proxy| async move { + // We set up proxy, so that it informs us (via startup_rx) about driver's Startup message contents. + + let (startup_tx, mut startup_rx) = mpsc::unbounded_channel(); + + running_proxy.running_nodes[0].change_request_rules(Some(vec![RequestRule( + Condition::RequestOpcode(RequestOpcode::Startup), + RequestReaction::noop().with_feedback_when_performed(startup_tx), + )])); + + // DB preparation phase + let _session: Session = SessionBuilder::new() + .known_node(proxy_uris[0].as_str()) + .address_translator(Arc::new(translation_map)) + .custom_identity(self_identity.clone()) + .build() + .await + .unwrap(); + + let (startup_frame, _shard) = startup_rx.recv().await.unwrap(); + let startup_options = types::read_string_map(&mut &*startup_frame.body).unwrap(); + + for (option_key, facultative_option) in [ + ( + options::APPLICATION_NAME, + self_identity.get_application_name(), + ), + ( + options::APPLICATION_VERSION, + self_identity.get_application_version(), + ), + (options::CLIENT_ID, self_identity.get_client_id()), + ] { + assert_eq!( + startup_options.get(option_key).map(String::as_str), + facultative_option + ); + } + + for (option_key, default_mandatory_option, custom_mandatory_option) in [ + ( + options::DRIVER_NAME, + options::DEFAULT_DRIVER_NAME, + self_identity.get_custom_driver_name(), + ), + ( + options::DRIVER_VERSION, + options::DEFAULT_DRIVER_VERSION, + self_identity.get_custom_driver_version(), + ), + ] { + assert_eq!( + startup_options.get(option_key).map(String::as_str), + Some(custom_mandatory_option.unwrap_or(default_mandatory_option)) + ); + } + + running_proxy + }, + ) + .await; + + match res { + Ok(()) => (), + Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (), + Err(err) => panic!("{}", err), + } +}