From 79fa2e1fa5d746c5791a12c5b6fa7ca5fd93f09f Mon Sep 17 00:00:00 2001 From: Isabel Atkinson Date: Tue, 27 Jun 2023 15:50:10 -0600 Subject: [PATCH] RUST-1420 Cache AWS credentials received from endpoints (#905) --- .evergreen/MSRV-Cargo.lock | 14 +++-- .evergreen/config.yml | 8 +-- .evergreen/run-aws-tests.sh | 2 +- src/bson_util/mod.rs | 11 ++++ src/client/auth/aws.rs | 100 ++++++++++++++++++++++++++++- src/test/auth_aws.rs | 122 +++++++++++++++++++++++++++++++++++- src/test/mod.rs | 1 + 7 files changed, 242 insertions(+), 16 deletions(-) diff --git a/.evergreen/MSRV-Cargo.lock b/.evergreen/MSRV-Cargo.lock index 46481024d..f1ffbeb78 100644 --- a/.evergreen/MSRV-Cargo.lock +++ b/.evergreen/MSRV-Cargo.lock @@ -269,15 +269,17 @@ dependencies = [ [[package]] name = "bson" -version = "2.3.0" -source = "git+https://github.com/mongodb/bson-rust?branch=main#0612667e1344f9aabc3592a2cee02a96ac1b76bc" +version = "2.6.0" +source = "git+https://github.com/mongodb/bson-rust?branch=main#b243db19b74745fcf52eacd9e4d34077186186d5" dependencies = [ "ahash", - "base64", + "base64 0.13.1", + "bitvec", "chrono", "hex", - "indexmap", - "lazy_static", + "indexmap 1.9.3", + "js-sys", + "once_cell 1.13.0", "rand", "serde", "serde_bytes", @@ -285,7 +287,7 @@ dependencies = [ "serde_with", "time", "uuid 0.8.2", - "uuid 1.1.2", + "uuid 1.4.0", ] [[package]] diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 712a69c63..325b67e09 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -167,7 +167,7 @@ functions: working_dir: "src" script: | ${PREPARE_SHELL} - ASYNC_RUNTIME=${ASYNC_RUNTIME} .evergreen/run-aws-tests.sh + ASYNC_RUNTIME=${ASYNC_RUNTIME} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh "run aws auth test with assume role credentials": - command: shell.exec @@ -205,7 +205,7 @@ functions: working_dir: "src" script: | ${PREPARE_SHELL} - ASYNC_RUNTIME=${ASYNC_RUNTIME} .evergreen/run-aws-tests.sh + ASYNC_RUNTIME=${ASYNC_RUNTIME} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh "run aws auth test with aws EC2 credentials": - command: shell.exec @@ -245,7 +245,7 @@ functions: working_dir: "src" script: | ${PREPARE_SHELL} - ASYNC_RUNTIME=${ASYNC_RUNTIME} PROJECT_DIRECTORY=${PROJECT_DIRECTORY} .evergreen/run-aws-tests.sh + ASYNC_RUNTIME=${ASYNC_RUNTIME} PROJECT_DIRECTORY=${PROJECT_DIRECTORY} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh "run aws auth test with aws credentials and session token as environment variables": - command: shell.exec @@ -267,7 +267,7 @@ functions: working_dir: "src" script: | ${PREPARE_SHELL} - ASYNC_RUNTIME=${ASYNC_RUNTIME} .evergreen/run-aws-tests.sh + ASYNC_RUNTIME=${ASYNC_RUNTIME} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh "run aws ECS auth test": - command: shell.exec diff --git a/.evergreen/run-aws-tests.sh b/.evergreen/run-aws-tests.sh index 22a24546f..0d8d03618 100755 --- a/.evergreen/run-aws-tests.sh +++ b/.evergreen/run-aws-tests.sh @@ -44,5 +44,5 @@ set -o errexit source ./.evergreen/configure-rust.sh -RUST_BACKTRACE=1 cargo test --features aws-auth auth_aws::auth_aws +RUST_BACKTRACE=1 cargo test --features aws-auth auth_aws RUST_BACKTRACE=1 cargo test --features aws-auth lambda_examples::auth::test_handler diff --git a/src/bson_util/mod.rs b/src/bson_util/mod.rs index a92969c78..2937bdd8e 100644 --- a/src/bson_util/mod.rs +++ b/src/bson_util/mod.rs @@ -235,6 +235,17 @@ pub(crate) fn serialize_result_error_as_string( .serialize(serializer) } +#[cfg(feature = "aws-auth")] +pub(crate) fn deserialize_datetime_option_from_double<'de, D>( + deserializer: D, +) -> std::result::Result, D::Error> +where + D: Deserializer<'de>, +{ + let millis = f64::deserialize(deserializer)? * 1000.0; + Ok(Some(bson::DateTime::from_millis(millis as i64))) +} + #[cfg(test)] mod test { use crate::bson_util::num_decimal_digits; diff --git a/src/client/auth/aws.rs b/src/client/auth/aws.rs index c4730bf6a..ab49817ef 100644 --- a/src/client/auth/aws.rs +++ b/src/client/auth/aws.rs @@ -1,13 +1,16 @@ -use std::{fs::File, io::Read}; +use std::{fs::File, io::Read, time::Duration}; use chrono::{offset::Utc, DateTime}; use hmac::Hmac; +use lazy_static::lazy_static; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use sha2::{Digest, Sha256}; +use tokio::sync::Mutex; use crate::{ bson::{doc, rawdoc, spec::BinarySubtype, Binary, Bson, Document}, + bson_util::deserialize_datetime_option_from_double, client::{ auth::{ self, @@ -27,12 +30,31 @@ const AWS_EC2_IP: &str = "169.254.169.254"; const AWS_LONG_DATE_FMT: &str = "%Y%m%dT%H%M%SZ"; const MECH_NAME: &str = "MONGODB-AWS"; +lazy_static! { + static ref CACHED_CREDENTIAL: Mutex> = Mutex::new(None); +} + /// Performs MONGODB-AWS authentication for a given stream. pub(super) async fn authenticate_stream( conn: &mut Connection, credential: &Credential, server_api: Option<&ServerApi>, http_client: &HttpClient, +) -> Result<()> { + match authenticate_stream_inner(conn, credential, server_api, http_client).await { + Ok(()) => Ok(()), + Err(error) => { + *CACHED_CREDENTIAL.lock().await = None; + Err(error) + } + } +} + +async fn authenticate_stream_inner( + conn: &mut Connection, + credential: &Credential, + server_api: Option<&ServerApi>, + http_client: &HttpClient, ) -> Result<()> { let source = match credential.source.as_deref() { Some("$external") | None => "$external", @@ -68,7 +90,23 @@ pub(super) async fn authenticate_stream( let server_first = ServerFirst::parse(server_first_response.auth_response_body(MECH_NAME)?)?; server_first.validate(&nonce)?; - let aws_credential = AwsCredential::get(credential, http_client).await?; + let aws_credential = { + // Limit scope of this variable to avoid holding onto the lock for the duration of + // authenticate_stream. + let cached_credential = CACHED_CREDENTIAL.lock().await; + match *cached_credential { + Some(ref aws_credential) if !aws_credential.is_expired() => aws_credential.clone(), + _ => { + // From the spec: the driver MUST not place a lock on making a request. + drop(cached_credential); + let aws_credential = AwsCredential::get(credential, http_client).await?; + if aws_credential.expiration.is_some() { + *CACHED_CREDENTIAL.lock().await = Some(aws_credential.clone()); + } + aws_credential + } + } + }; let date = Utc::now(); @@ -117,7 +155,7 @@ pub(super) async fn authenticate_stream( } /// Contains the credentials for MONGODB-AWS authentication. -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] #[serde(rename_all = "PascalCase")] pub(crate) struct AwsCredential { access_key_id: String, @@ -126,6 +164,9 @@ pub(crate) struct AwsCredential { #[serde(alias = "Token")] session_token: Option, + + #[serde(default, deserialize_with = "deserialize_datetime_option_from_double")] + expiration: Option, } impl AwsCredential { @@ -157,6 +198,7 @@ impl AwsCredential { access_key_id: access_key, secret_access_key: secret_key, session_token, + expiration: None, }); } @@ -419,6 +461,16 @@ impl AwsCredential { pub(crate) fn session_token(&self) -> Option<&str> { self.session_token.as_deref() } + + fn is_expired(&self) -> bool { + match self.expiration { + Some(expiration) => { + expiration.saturating_duration_since(bson::DateTime::now()) + < Duration::from_secs(5 * 60) + } + None => true, + } + } } /// The response from the server to the `saslStart` command in a MONGODB-AWS authentication attempt. @@ -496,3 +548,45 @@ impl ServerFirst { } } } + +#[cfg(test)] +pub(crate) mod test_utils { + use super::{AwsCredential, CACHED_CREDENTIAL}; + + pub(crate) async fn cached_credential() -> Option { + CACHED_CREDENTIAL.lock().await.clone() + } + + pub(crate) async fn clear_cached_credential() { + *CACHED_CREDENTIAL.lock().await = None; + } + + pub(crate) async fn poison_cached_credential() { + CACHED_CREDENTIAL + .lock() + .await + .as_mut() + .unwrap() + .access_key_id = "bad".into(); + } + + pub(crate) async fn cached_access_key_id() -> String { + cached_credential().await.unwrap().access_key_id + } + + pub(crate) async fn cached_secret_access_key() -> String { + cached_credential().await.unwrap().secret_access_key + } + + pub(crate) async fn cached_session_token() -> Option { + cached_credential().await.unwrap().session_token + } + + pub(crate) async fn cached_expiration() -> bson::DateTime { + cached_credential().await.unwrap().expiration.unwrap() + } + + pub(crate) async fn set_cached_expiration(expiration: bson::DateTime) { + CACHED_CREDENTIAL.lock().await.as_mut().unwrap().expiration = Some(expiration); + } +} diff --git a/src/test/auth_aws.rs b/src/test/auth_aws.rs index 95b52ac39..bed583baa 100644 --- a/src/test/auth_aws.rs +++ b/src/test/auth_aws.rs @@ -1,5 +1,8 @@ -use bson::Document; -use tokio::sync::RwLockReadGuard; +use std::env::{remove_var, set_var, var}; + +use tokio::sync::{RwLockReadGuard, RwLockWriteGuard}; + +use crate::{bson::Document, client::auth::aws::test_utils::*, test::DEFAULT_URI, Client}; use super::{TestClient, LOCK}; @@ -13,3 +16,118 @@ async fn auth_aws() { coll.find_one(None, None).await.unwrap(); } + +// The TestClient performs operations upon creation that trigger authentication, so the credential +// caching tests use a regular client instead to avoid that noise. +async fn get_client() -> Client { + Client::with_uri_str(DEFAULT_URI.clone()).await.unwrap() +} + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn credential_caching() { + // This test should only be run when authenticating using AWS endpoints. + if var("SKIP_CREDENTIAL_CACHING_TESTS").is_ok() { + return; + } + + let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await; + + clear_cached_credential().await; + + let client = get_client().await; + let coll = client.database("aws").collection::("somecoll"); + coll.find_one(None, None).await.unwrap(); + assert!(cached_credential().await.is_some()); + + let now = bson::DateTime::now(); + set_cached_expiration(now).await; + + let client = get_client().await; + let coll = client.database("aws").collection::("somecoll"); + coll.find_one(None, None).await.unwrap(); + assert!(cached_credential().await.is_some()); + assert!(cached_expiration().await > now); + + poison_cached_credential().await; + + let client = get_client().await; + let coll = client.database("aws").collection::("somecoll"); + match coll.find_one(None, None).await { + Ok(_) => panic!( + "find one should have failed with authentication error due to poisoned cached \ + credential" + ), + Err(error) => assert!(error.is_auth_error()), + } + assert!(cached_credential().await.is_none()); + + coll.find_one(None, None).await.unwrap(); + assert!(cached_credential().await.is_some()); +} + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn credential_caching_environment_vars() { + // This test should only be run when authenticating using AWS endpoints. + if var("SKIP_CREDENTIAL_CACHING_TESTS").is_ok() { + return; + } + + let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await; + + clear_cached_credential().await; + + let client = get_client().await; + let coll = client.database("aws").collection::("somecoll"); + coll.find_one(None, None).await.unwrap(); + assert!(cached_credential().await.is_some()); + + set_var("AWS_ACCESS_KEY_ID", cached_access_key_id().await); + set_var("AWS_SECRET_ACCESS_KEY", cached_secret_access_key().await); + if let Some(session_token) = cached_session_token().await { + set_var("AWS_SESSION_TOKEN", session_token); + } + clear_cached_credential().await; + + let client = get_client().await; + let coll = client.database("aws").collection::("somecoll"); + coll.find_one(None, None).await.unwrap(); + assert!(cached_credential().await.is_none()); + + set_var("AWS_ACCESS_KEY_ID", "bad"); + set_var("AWS_SECRET_ACCESS_KEY", "bad"); + set_var("AWS_SESSION_TOKEN", "bad"); + + let client = get_client().await; + let coll = client.database("aws").collection::("somecoll"); + match coll.find_one(None, None).await { + Ok(_) => panic!( + "find one should have failed with authentication error due to poisoned environment \ + variables" + ), + Err(error) => assert!(error.is_auth_error()), + } + + remove_var("AWS_ACCESS_KEY_ID"); + remove_var("AWS_SECRET_ACCESS_KEY"); + remove_var("AWS_SESSION_TOKEN"); + clear_cached_credential().await; + + let client = get_client().await; + let coll = client.database("aws").collection::("somecoll"); + coll.find_one(None, None).await.unwrap(); + assert!(cached_credential().await.is_some()); + + set_var("AWS_ACCESS_KEY_ID", "bad"); + set_var("AWS_SECRET_ACCESS_KEY", "bad"); + set_var("AWS_SESSION_TOKEN", "bad"); + + let client = get_client().await; + let coll = client.database("aws").collection::("somecoll"); + coll.find_one(None, None).await.unwrap(); + + remove_var("AWS_ACCESS_KEY_ID"); + remove_var("AWS_SECRET_ACCESS_KEY"); + remove_var("AWS_SESSION_TOKEN"); +} diff --git a/src/test/mod.rs b/src/test/mod.rs index 4049b4171..b17a1b156 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1,6 +1,7 @@ #[cfg(all(not(feature = "sync"), not(feature = "tokio-sync")))] mod atlas_connectivity; mod atlas_planned_maintenance_testing; +#[cfg(feature = "aws-auth")] mod auth_aws; mod change_stream; mod client;