diff --git a/src/cmap/establish.rs b/src/cmap/establish.rs index fb5291836..ea6e13d07 100644 --- a/src/cmap/establish.rs +++ b/src/cmap/establish.rs @@ -30,12 +30,17 @@ pub(crate) struct ConnectionEstablisher { tls_config: Option, connect_timeout: Duration, + + #[cfg(test)] + test_patch_reply: Option)>, } pub(crate) struct EstablisherOptions { handshake_options: HandshakerOptions, tls_options: Option, connect_timeout: Option, + #[cfg(test)] + pub(crate) test_patch_reply: Option)>, } impl EstablisherOptions { @@ -55,6 +60,8 @@ impl EstablisherOptions { }, tls_options: opts.tls_options(), connect_timeout: opts.connect_timeout, + #[cfg(test)] + test_patch_reply: None, } } } @@ -80,6 +87,8 @@ impl ConnectionEstablisher { handshaker, tls_config, connect_timeout, + #[cfg(test)] + test_patch_reply: options.test_patch_reply, }) } @@ -92,7 +101,7 @@ impl ConnectionEstablisher { } /// Establishes a connection. - pub(super) async fn establish_connection( + pub(crate) async fn establish_connection( &self, pending_connection: PendingConnection, credential: Option<&Credential>, @@ -106,7 +115,13 @@ impl ConnectionEstablisher { .map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?; let mut connection = Connection::new_pooled(pending_connection, stream); - let handshake_result = self.handshaker.handshake(&mut connection, credential).await; + #[allow(unused_mut)] + let mut handshake_result = self.handshaker.handshake(&mut connection, credential).await; + #[cfg(test)] + if let Some(patch) = self.test_patch_reply { + patch(&mut handshake_result); + } + let handshake_result = handshake_result; // If the handshake response had a `serviceId` field, this is a connection to a load // balancer and must derive its generation from the service_generations map. diff --git a/src/test/spec.rs b/src/test/spec.rs index d4f8f0309..bb201dda6 100644 --- a/src/test/spec.rs +++ b/src/test/spec.rs @@ -8,6 +8,7 @@ mod connection_stepdown; mod crud; mod faas; mod gridfs; +mod handshake; mod index_management; #[cfg(feature = "dns-resolver")] mod initial_dns_seedlist_discovery; diff --git a/src/test/spec/handshake.rs b/src/test/spec/handshake.rs new file mode 100644 index 000000000..43212ad5b --- /dev/null +++ b/src/test/spec/handshake.rs @@ -0,0 +1,40 @@ +use std::time::Instant; + +use bson::oid::ObjectId; + +use crate::{ + cmap::{ + conn::PendingConnection, + establish::{ConnectionEstablisher, EstablisherOptions}, + }, + event::cmap::CmapEventEmitter, + test::get_client_options, +}; + +// Prose test 1: Test that the driver accepts an arbitrary auth mechanism +#[tokio::test] +async fn arbitrary_auth_mechanism() { + let client_options = get_client_options().await; + let mut options = EstablisherOptions::from_client_options(client_options); + options.test_patch_reply = Some(|reply| { + reply + .as_mut() + .unwrap() + .command_response + .sasl_supported_mechs + .get_or_insert_with(Vec::new) + .push("ArBiTrArY!".to_string()); + }); + let establisher = ConnectionEstablisher::new(options).unwrap(); + let pending = PendingConnection { + id: 0, + address: client_options.hosts[0].clone(), + generation: crate::cmap::PoolGeneration::normal(), + event_emitter: CmapEventEmitter::new(None, ObjectId::new()), + time_created: Instant::now(), + }; + establisher + .establish_connection(pending, None) + .await + .unwrap(); +} diff --git a/src/test/spec/unified_runner/test_runner.rs b/src/test/spec/unified_runner/test_runner.rs index de7b16020..1e71a89dc 100644 --- a/src/test/spec/unified_runner/test_runner.rs +++ b/src/test/spec/unified_runner/test_runner.rs @@ -746,7 +746,7 @@ impl TestRunner { fn fill_kms_placeholders( kms_provider_map: HashMap, ) -> crate::test::csfle::KmsProviderList { - use crate::{bson::doc, test::csfle::ALL_KMS_PROVIDERS}; + use crate::test::csfle::ALL_KMS_PROVIDERS; let placeholder = doc! { "$$placeholder": 1 }; let all_kms_providers = ALL_KMS_PROVIDERS.clone();