Skip to content

Commit

Permalink
RUST-1846 Test that saslSupportedMechs can contain arbitrary strings (
Browse files Browse the repository at this point in the history
  • Loading branch information
abr-egn authored Jun 25, 2024
1 parent 08e2923 commit 45736f2
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 3 deletions.
19 changes: 17 additions & 2 deletions src/cmap/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,17 @@ pub(crate) struct ConnectionEstablisher {
tls_config: Option<TlsConfig>,

connect_timeout: Duration,

#[cfg(test)]
test_patch_reply: Option<fn(&mut Result<HelloReply>)>,
}

pub(crate) struct EstablisherOptions {
handshake_options: HandshakerOptions,
tls_options: Option<TlsOptions>,
connect_timeout: Option<Duration>,
#[cfg(test)]
pub(crate) test_patch_reply: Option<fn(&mut Result<HelloReply>)>,
}

impl EstablisherOptions {
Expand All @@ -55,6 +60,8 @@ impl EstablisherOptions {
},
tls_options: opts.tls_options(),
connect_timeout: opts.connect_timeout,
#[cfg(test)]
test_patch_reply: None,
}
}
}
Expand All @@ -80,6 +87,8 @@ impl ConnectionEstablisher {
handshaker,
tls_config,
connect_timeout,
#[cfg(test)]
test_patch_reply: options.test_patch_reply,
})
}

Expand All @@ -92,7 +101,7 @@ impl ConnectionEstablisher {
}

/// Establishes a connection.
pub(super) async fn establish_connection(
pub(crate) async fn establish_connection(
&self,
pending_connection: PendingConnection,
credential: Option<&Credential>,
Expand All @@ -106,7 +115,13 @@ impl ConnectionEstablisher {
.map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?;

let mut connection = Connection::new_pooled(pending_connection, stream);
let handshake_result = self.handshaker.handshake(&mut connection, credential).await;
#[allow(unused_mut)]
let mut handshake_result = self.handshaker.handshake(&mut connection, credential).await;
#[cfg(test)]
if let Some(patch) = self.test_patch_reply {
patch(&mut handshake_result);
}
let handshake_result = handshake_result;

// If the handshake response had a `serviceId` field, this is a connection to a load
// balancer and must derive its generation from the service_generations map.
Expand Down
1 change: 1 addition & 0 deletions src/test/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod connection_stepdown;
mod crud;
mod faas;
mod gridfs;
mod handshake;
mod index_management;
#[cfg(feature = "dns-resolver")]
mod initial_dns_seedlist_discovery;
Expand Down
40 changes: 40 additions & 0 deletions src/test/spec/handshake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use std::time::Instant;

use bson::oid::ObjectId;

use crate::{
cmap::{
conn::PendingConnection,
establish::{ConnectionEstablisher, EstablisherOptions},
},
event::cmap::CmapEventEmitter,
test::get_client_options,
};

// Prose test 1: Test that the driver accepts an arbitrary auth mechanism
#[tokio::test]
async fn arbitrary_auth_mechanism() {
let client_options = get_client_options().await;
let mut options = EstablisherOptions::from_client_options(client_options);
options.test_patch_reply = Some(|reply| {
reply
.as_mut()
.unwrap()
.command_response
.sasl_supported_mechs
.get_or_insert_with(Vec::new)
.push("ArBiTrArY!".to_string());
});
let establisher = ConnectionEstablisher::new(options).unwrap();
let pending = PendingConnection {
id: 0,
address: client_options.hosts[0].clone(),
generation: crate::cmap::PoolGeneration::normal(),
event_emitter: CmapEventEmitter::new(None, ObjectId::new()),
time_created: Instant::now(),
};
establisher
.establish_connection(pending, None)
.await
.unwrap();
}
2 changes: 1 addition & 1 deletion src/test/spec/unified_runner/test_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ impl TestRunner {
fn fill_kms_placeholders(
kms_provider_map: HashMap<mongocrypt::ctx::KmsProvider, Document>,
) -> crate::test::csfle::KmsProviderList {
use crate::{bson::doc, test::csfle::ALL_KMS_PROVIDERS};
use crate::test::csfle::ALL_KMS_PROVIDERS;

let placeholder = doc! { "$$placeholder": 1 };
let all_kms_providers = ALL_KMS_PROVIDERS.clone();
Expand Down

0 comments on commit 45736f2

Please sign in to comment.