Skip to content

Commit

Permalink
RUST-1420 Cache AWS credentials received from endpoints (#905)
Browse files Browse the repository at this point in the history
  • Loading branch information
isabelatkinson authored Jun 27, 2023
1 parent a018a87 commit 79fa2e1
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 16 deletions.
14 changes: 8 additions & 6 deletions .evergreen/MSRV-Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions .evergreen/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ functions:
working_dir: "src"
script: |
${PREPARE_SHELL}
ASYNC_RUNTIME=${ASYNC_RUNTIME} .evergreen/run-aws-tests.sh
ASYNC_RUNTIME=${ASYNC_RUNTIME} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh
"run aws auth test with assume role credentials":
- command: shell.exec
Expand Down Expand Up @@ -205,7 +205,7 @@ functions:
working_dir: "src"
script: |
${PREPARE_SHELL}
ASYNC_RUNTIME=${ASYNC_RUNTIME} .evergreen/run-aws-tests.sh
ASYNC_RUNTIME=${ASYNC_RUNTIME} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh
"run aws auth test with aws EC2 credentials":
- command: shell.exec
Expand Down Expand Up @@ -245,7 +245,7 @@ functions:
working_dir: "src"
script: |
${PREPARE_SHELL}
ASYNC_RUNTIME=${ASYNC_RUNTIME} PROJECT_DIRECTORY=${PROJECT_DIRECTORY} .evergreen/run-aws-tests.sh
ASYNC_RUNTIME=${ASYNC_RUNTIME} PROJECT_DIRECTORY=${PROJECT_DIRECTORY} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh
"run aws auth test with aws credentials and session token as environment variables":
- command: shell.exec
Expand All @@ -267,7 +267,7 @@ functions:
working_dir: "src"
script: |
${PREPARE_SHELL}
ASYNC_RUNTIME=${ASYNC_RUNTIME} .evergreen/run-aws-tests.sh
ASYNC_RUNTIME=${ASYNC_RUNTIME} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh
"run aws ECS auth test":
- command: shell.exec
Expand Down
2 changes: 1 addition & 1 deletion .evergreen/run-aws-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ set -o errexit

source ./.evergreen/configure-rust.sh

RUST_BACKTRACE=1 cargo test --features aws-auth auth_aws::auth_aws
RUST_BACKTRACE=1 cargo test --features aws-auth auth_aws
RUST_BACKTRACE=1 cargo test --features aws-auth lambda_examples::auth::test_handler
11 changes: 11 additions & 0 deletions src/bson_util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,17 @@ pub(crate) fn serialize_result_error_as_string<S: Serializer, T: Serialize>(
.serialize(serializer)
}

#[cfg(feature = "aws-auth")]
pub(crate) fn deserialize_datetime_option_from_double<'de, D>(
deserializer: D,
) -> std::result::Result<Option<bson::DateTime>, D::Error>
where
D: Deserializer<'de>,
{
let millis = f64::deserialize(deserializer)? * 1000.0;
Ok(Some(bson::DateTime::from_millis(millis as i64)))
}

#[cfg(test)]
mod test {
use crate::bson_util::num_decimal_digits;
Expand Down
100 changes: 97 additions & 3 deletions src/client/auth/aws.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use std::{fs::File, io::Read};
use std::{fs::File, io::Read, time::Duration};

use chrono::{offset::Utc, DateTime};
use hmac::Hmac;
use lazy_static::lazy_static;
use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize;
use sha2::{Digest, Sha256};
use tokio::sync::Mutex;

use crate::{
bson::{doc, rawdoc, spec::BinarySubtype, Binary, Bson, Document},
bson_util::deserialize_datetime_option_from_double,
client::{
auth::{
self,
Expand All @@ -27,12 +30,31 @@ const AWS_EC2_IP: &str = "169.254.169.254";
const AWS_LONG_DATE_FMT: &str = "%Y%m%dT%H%M%SZ";
const MECH_NAME: &str = "MONGODB-AWS";

lazy_static! {
static ref CACHED_CREDENTIAL: Mutex<Option<AwsCredential>> = Mutex::new(None);
}

/// Performs MONGODB-AWS authentication for a given stream.
pub(super) async fn authenticate_stream(
conn: &mut Connection,
credential: &Credential,
server_api: Option<&ServerApi>,
http_client: &HttpClient,
) -> Result<()> {
match authenticate_stream_inner(conn, credential, server_api, http_client).await {
Ok(()) => Ok(()),
Err(error) => {
*CACHED_CREDENTIAL.lock().await = None;
Err(error)
}
}
}

async fn authenticate_stream_inner(
conn: &mut Connection,
credential: &Credential,
server_api: Option<&ServerApi>,
http_client: &HttpClient,
) -> Result<()> {
let source = match credential.source.as_deref() {
Some("$external") | None => "$external",
Expand Down Expand Up @@ -68,7 +90,23 @@ pub(super) async fn authenticate_stream(
let server_first = ServerFirst::parse(server_first_response.auth_response_body(MECH_NAME)?)?;
server_first.validate(&nonce)?;

let aws_credential = AwsCredential::get(credential, http_client).await?;
let aws_credential = {
// Limit scope of this variable to avoid holding onto the lock for the duration of
// authenticate_stream.
let cached_credential = CACHED_CREDENTIAL.lock().await;
match *cached_credential {
Some(ref aws_credential) if !aws_credential.is_expired() => aws_credential.clone(),
_ => {
// From the spec: the driver MUST not place a lock on making a request.
drop(cached_credential);
let aws_credential = AwsCredential::get(credential, http_client).await?;
if aws_credential.expiration.is_some() {
*CACHED_CREDENTIAL.lock().await = Some(aws_credential.clone());
}
aws_credential
}
}
};

let date = Utc::now();

Expand Down Expand Up @@ -117,7 +155,7 @@ pub(super) async fn authenticate_stream(
}

/// Contains the credentials for MONGODB-AWS authentication.
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub(crate) struct AwsCredential {
access_key_id: String,
Expand All @@ -126,6 +164,9 @@ pub(crate) struct AwsCredential {

#[serde(alias = "Token")]
session_token: Option<String>,

#[serde(default, deserialize_with = "deserialize_datetime_option_from_double")]
expiration: Option<bson::DateTime>,
}

impl AwsCredential {
Expand Down Expand Up @@ -157,6 +198,7 @@ impl AwsCredential {
access_key_id: access_key,
secret_access_key: secret_key,
session_token,
expiration: None,
});
}

Expand Down Expand Up @@ -419,6 +461,16 @@ impl AwsCredential {
pub(crate) fn session_token(&self) -> Option<&str> {
self.session_token.as_deref()
}

fn is_expired(&self) -> bool {
match self.expiration {
Some(expiration) => {
expiration.saturating_duration_since(bson::DateTime::now())
< Duration::from_secs(5 * 60)
}
None => true,
}
}
}

/// The response from the server to the `saslStart` command in a MONGODB-AWS authentication attempt.
Expand Down Expand Up @@ -496,3 +548,45 @@ impl ServerFirst {
}
}
}

#[cfg(test)]
pub(crate) mod test_utils {
use super::{AwsCredential, CACHED_CREDENTIAL};

pub(crate) async fn cached_credential() -> Option<AwsCredential> {
CACHED_CREDENTIAL.lock().await.clone()
}

pub(crate) async fn clear_cached_credential() {
*CACHED_CREDENTIAL.lock().await = None;
}

pub(crate) async fn poison_cached_credential() {
CACHED_CREDENTIAL
.lock()
.await
.as_mut()
.unwrap()
.access_key_id = "bad".into();
}

pub(crate) async fn cached_access_key_id() -> String {
cached_credential().await.unwrap().access_key_id
}

pub(crate) async fn cached_secret_access_key() -> String {
cached_credential().await.unwrap().secret_access_key
}

pub(crate) async fn cached_session_token() -> Option<String> {
cached_credential().await.unwrap().session_token
}

pub(crate) async fn cached_expiration() -> bson::DateTime {
cached_credential().await.unwrap().expiration.unwrap()
}

pub(crate) async fn set_cached_expiration(expiration: bson::DateTime) {
CACHED_CREDENTIAL.lock().await.as_mut().unwrap().expiration = Some(expiration);
}
}
122 changes: 120 additions & 2 deletions src/test/auth_aws.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use bson::Document;
use tokio::sync::RwLockReadGuard;
use std::env::{remove_var, set_var, var};

use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};

use crate::{bson::Document, client::auth::aws::test_utils::*, test::DEFAULT_URI, Client};

use super::{TestClient, LOCK};

Expand All @@ -13,3 +16,118 @@ async fn auth_aws() {

coll.find_one(None, None).await.unwrap();
}

// The TestClient performs operations upon creation that trigger authentication, so the credential
// caching tests use a regular client instead to avoid that noise.
async fn get_client() -> Client {
Client::with_uri_str(DEFAULT_URI.clone()).await.unwrap()
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn credential_caching() {
// This test should only be run when authenticating using AWS endpoints.
if var("SKIP_CREDENTIAL_CACHING_TESTS").is_ok() {
return;
}

let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await;

clear_cached_credential().await;

let client = get_client().await;
let coll = client.database("aws").collection::<Document>("somecoll");
coll.find_one(None, None).await.unwrap();
assert!(cached_credential().await.is_some());

let now = bson::DateTime::now();
set_cached_expiration(now).await;

let client = get_client().await;
let coll = client.database("aws").collection::<Document>("somecoll");
coll.find_one(None, None).await.unwrap();
assert!(cached_credential().await.is_some());
assert!(cached_expiration().await > now);

poison_cached_credential().await;

let client = get_client().await;
let coll = client.database("aws").collection::<Document>("somecoll");
match coll.find_one(None, None).await {
Ok(_) => panic!(
"find one should have failed with authentication error due to poisoned cached \
credential"
),
Err(error) => assert!(error.is_auth_error()),
}
assert!(cached_credential().await.is_none());

coll.find_one(None, None).await.unwrap();
assert!(cached_credential().await.is_some());
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn credential_caching_environment_vars() {
// This test should only be run when authenticating using AWS endpoints.
if var("SKIP_CREDENTIAL_CACHING_TESTS").is_ok() {
return;
}

let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await;

clear_cached_credential().await;

let client = get_client().await;
let coll = client.database("aws").collection::<Document>("somecoll");
coll.find_one(None, None).await.unwrap();
assert!(cached_credential().await.is_some());

set_var("AWS_ACCESS_KEY_ID", cached_access_key_id().await);
set_var("AWS_SECRET_ACCESS_KEY", cached_secret_access_key().await);
if let Some(session_token) = cached_session_token().await {
set_var("AWS_SESSION_TOKEN", session_token);
}
clear_cached_credential().await;

let client = get_client().await;
let coll = client.database("aws").collection::<Document>("somecoll");
coll.find_one(None, None).await.unwrap();
assert!(cached_credential().await.is_none());

set_var("AWS_ACCESS_KEY_ID", "bad");
set_var("AWS_SECRET_ACCESS_KEY", "bad");
set_var("AWS_SESSION_TOKEN", "bad");

let client = get_client().await;
let coll = client.database("aws").collection::<Document>("somecoll");
match coll.find_one(None, None).await {
Ok(_) => panic!(
"find one should have failed with authentication error due to poisoned environment \
variables"
),
Err(error) => assert!(error.is_auth_error()),
}

remove_var("AWS_ACCESS_KEY_ID");
remove_var("AWS_SECRET_ACCESS_KEY");
remove_var("AWS_SESSION_TOKEN");
clear_cached_credential().await;

let client = get_client().await;
let coll = client.database("aws").collection::<Document>("somecoll");
coll.find_one(None, None).await.unwrap();
assert!(cached_credential().await.is_some());

set_var("AWS_ACCESS_KEY_ID", "bad");
set_var("AWS_SECRET_ACCESS_KEY", "bad");
set_var("AWS_SESSION_TOKEN", "bad");

let client = get_client().await;
let coll = client.database("aws").collection::<Document>("somecoll");
coll.find_one(None, None).await.unwrap();

remove_var("AWS_ACCESS_KEY_ID");
remove_var("AWS_SECRET_ACCESS_KEY");
remove_var("AWS_SESSION_TOKEN");
}
1 change: 1 addition & 0 deletions src/test/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[cfg(all(not(feature = "sync"), not(feature = "tokio-sync")))]
mod atlas_connectivity;
mod atlas_planned_maintenance_testing;
#[cfg(feature = "aws-auth")]
mod auth_aws;
mod change_stream;
mod client;
Expand Down

0 comments on commit 79fa2e1

Please sign in to comment.