diff --git a/Cargo.toml b/Cargo.toml index 4f9adfd..943c710 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,9 +100,11 @@ aws-sigv4 = "1.2.0" criterion = { version = "0.5", features = ["async_tokio", "html_reports"] } dotenv = "0.15" env_logger = "0.11" +macro_rules_attribute = "0.2.0" once_cell = "1" pretty_assertions = "1.3" reqwest = { version = "0.12", features = ["blocking", "json"] } temp-env = "0.3" tempfile = "3.8" +test-case = "3.3.1" tokio = { version = "1", features = ["full"] } diff --git a/src/aws/v4.rs b/src/aws/v4.rs index 5718152..97c7928 100644 --- a/src/aws/v4.rs +++ b/src/aws/v4.rs @@ -390,6 +390,7 @@ mod tests { use aws_sigv4::http_request::SigningSettings; use aws_sigv4::sign::v4; use http::header; + use macro_rules_attribute::apply; use reqwest::Client; use super::super::AwsDefaultLoader; @@ -520,17 +521,18 @@ mod tests { req } - fn test_cases() -> &'static [fn() -> http::Request<&'static str>] { - &[ - test_get_request, - test_get_request_with_sse, - test_get_request_with_query, - test_get_request_virtual_host, - test_get_request_with_query_virtual_host, - test_put_request, - test_put_request_virtual_host, - test_put_request_with_body_digest, - ] + macro_rules! test_cases { + ($($tt:tt)*) => { + #[test_case::test_case(test_get_request)] + #[test_case::test_case(test_get_request_with_sse)] + #[test_case::test_case(test_get_request_with_query)] + #[test_case::test_case(test_get_request_virtual_host)] + #[test_case::test_case(test_get_request_with_query_virtual_host)] + #[test_case::test_case(test_put_request)] + #[test_case::test_case(test_put_request_virtual_host)] + #[test_case::test_case(test_put_request_with_body_digest)] + $($tt)* + }; } fn compare_request(name: &str, l: &http::Request<&str>, r: &http::Request<&str>) { @@ -568,326 +570,324 @@ mod tests { assert_eq!(format_query(l), format_query(r), "{name} query mismatch"); } + #[apply(test_cases)] #[tokio::test] - async fn test_calculate() -> Result<()> { + async fn test_calculate(req_fn: fn() -> http::Request<&'static str>) -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); - for req_fn in test_cases() { - let mut req = req_fn(); - let name = format!( - "{} {} {:?}", - req.method(), - req.uri().path(), - req.uri().query(), - ); - let now = now(); - - let mut ss = SigningSettings::default(); - ss.percent_encoding_mode = PercentEncodingMode::Double; - ss.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; - let id = Credentials::new( - "access_key_id", - "secret_access_key", - None, - None, - "hardcoded-credentials", - ) - .into(); - let sp = v4::SigningParams::builder() - .identity(&id) - .region("test") - .name("s3") - .time(SystemTime::from(now)) - .settings(ss) - .build() - .expect("signing params must be valid"); - - let mut body = SignableBody::UnsignedPayload; - if req.headers().get(X_AMZ_CONTENT_SHA_256).is_some() { - body = SignableBody::Bytes(req.body().as_bytes()); - } + let mut req = req_fn(); + let name = format!( + "{} {} {:?}", + req.method(), + req.uri().path(), + req.uri().query(), + ); + let now = now(); + + let mut ss = SigningSettings::default(); + ss.percent_encoding_mode = PercentEncodingMode::Double; + ss.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; + let id = Credentials::new( + "access_key_id", + "secret_access_key", + None, + None, + "hardcoded-credentials", + ) + .into(); + let sp = v4::SigningParams::builder() + .identity(&id) + .region("test") + .name("s3") + .time(SystemTime::from(now)) + .settings(ss) + .build() + .expect("signing params must be valid"); + + let mut body = SignableBody::UnsignedPayload; + if req.headers().get(X_AMZ_CONTENT_SHA_256).is_some() { + body = SignableBody::Bytes(req.body().as_bytes()); + } - let output = aws_sigv4::http_request::sign( - SignableRequest::new( - req.method().as_str(), - req.uri().to_string(), - req.headers() - .iter() - .map(|(k, v)| (k.as_str(), std::str::from_utf8(v.as_bytes()).unwrap())), - body, - ) - .unwrap(), - &sp.into(), + let output = aws_sigv4::http_request::sign( + SignableRequest::new( + req.method().as_str(), + req.uri().to_string(), + req.headers() + .iter() + .map(|(k, v)| (k.as_str(), std::str::from_utf8(v.as_bytes()).unwrap())), + body, ) - .expect("signing must succeed"); - let (aws_sig, _) = output.into_parts(); - aws_sig.apply_to_request_http1x(&mut req); - let expected_req = req; - - let mut req = req_fn(); - - let loader = AwsDefaultLoader::new( - Client::new(), - AwsConfig { - access_key_id: Some("access_key_id".to_string()), - secret_access_key: Some("secret_access_key".to_string()), - ..Default::default() - }, - ); - let cred = loader.load().await?.unwrap(); + .unwrap(), + &sp.into(), + ) + .expect("signing must succeed"); + let (aws_sig, _) = output.into_parts(); + aws_sig.apply_to_request_http1x(&mut req); + let expected_req = req; + + let mut req = req_fn(); + + let loader = AwsDefaultLoader::new( + Client::new(), + AwsConfig { + access_key_id: Some("access_key_id".to_string()), + secret_access_key: Some("secret_access_key".to_string()), + ..Default::default() + }, + ); + let cred = loader.load().await?.unwrap(); - let signer = Signer::new("s3", "test").time(now); - signer.sign(&mut req, &cred).expect("must apply success"); + let signer = Signer::new("s3", "test").time(now); + signer.sign(&mut req, &cred).expect("must apply success"); - let actual_req = req; + let actual_req = req; - compare_request(&name, &expected_req, &actual_req); - } + compare_request(&name, &expected_req, &actual_req); Ok(()) } + #[apply(test_cases)] #[tokio::test] - async fn test_calculate_in_query() -> Result<()> { + async fn test_calculate_in_query(req_fn: fn() -> http::Request<&'static str>) -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); - for req_fn in test_cases() { - let mut req = req_fn(); - let name = format!( - "{} {} {:?}", - req.method(), - req.uri().path(), - req.uri().query(), - ); - let now = now(); - - let mut ss = SigningSettings::default(); - ss.percent_encoding_mode = PercentEncodingMode::Double; - ss.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; - ss.signature_location = SignatureLocation::QueryParams; - ss.expires_in = Some(std::time::Duration::from_secs(3600)); - let id = Credentials::new( - "access_key_id", - "secret_access_key", - None, - None, - "hardcoded-credentials", - ) - .into(); - let sp = v4::SigningParams::builder() - .identity(&id) - .region("test") - .name("s3") - .time(SystemTime::from(now)) - .settings(ss) - .build() - .expect("signing params must be valid"); - - let mut body = SignableBody::UnsignedPayload; - if req.headers().get(X_AMZ_CONTENT_SHA_256).is_some() { - body = SignableBody::Bytes(req.body().as_bytes()); - } + let mut req = req_fn(); + let name = format!( + "{} {} {:?}", + req.method(), + req.uri().path(), + req.uri().query(), + ); + let now = now(); + + let mut ss = SigningSettings::default(); + ss.percent_encoding_mode = PercentEncodingMode::Double; + ss.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; + ss.signature_location = SignatureLocation::QueryParams; + ss.expires_in = Some(std::time::Duration::from_secs(3600)); + let id = Credentials::new( + "access_key_id", + "secret_access_key", + None, + None, + "hardcoded-credentials", + ) + .into(); + let sp = v4::SigningParams::builder() + .identity(&id) + .region("test") + .name("s3") + .time(SystemTime::from(now)) + .settings(ss) + .build() + .expect("signing params must be valid"); + + let mut body = SignableBody::UnsignedPayload; + if req.headers().get(X_AMZ_CONTENT_SHA_256).is_some() { + body = SignableBody::Bytes(req.body().as_bytes()); + } - let output = aws_sigv4::http_request::sign( - SignableRequest::new( - req.method().as_str(), - req.uri().to_string(), - req.headers() - .iter() - .map(|(k, v)| (k.as_str(), std::str::from_utf8(v.as_bytes()).unwrap())), - body, - ) - .unwrap(), - &sp.into(), + let output = aws_sigv4::http_request::sign( + SignableRequest::new( + req.method().as_str(), + req.uri().to_string(), + req.headers() + .iter() + .map(|(k, v)| (k.as_str(), std::str::from_utf8(v.as_bytes()).unwrap())), + body, ) - .expect("signing must succeed"); - let (aws_sig, _) = output.into_parts(); - aws_sig.apply_to_request_http1x(&mut req); - let expected_req = req; - - let mut req = req_fn(); - - let loader = AwsDefaultLoader::new( - Client::new(), - AwsConfig { - access_key_id: Some("access_key_id".to_string()), - secret_access_key: Some("secret_access_key".to_string()), - ..Default::default() - }, - ); - let cred = loader.load().await?.unwrap(); + .unwrap(), + &sp.into(), + ) + .expect("signing must succeed"); + let (aws_sig, _) = output.into_parts(); + aws_sig.apply_to_request_http1x(&mut req); + let expected_req = req; + + let mut req = req_fn(); + + let loader = AwsDefaultLoader::new( + Client::new(), + AwsConfig { + access_key_id: Some("access_key_id".to_string()), + secret_access_key: Some("secret_access_key".to_string()), + ..Default::default() + }, + ); + let cred = loader.load().await?.unwrap(); - let signer = Signer::new("s3", "test").time(now); + let signer = Signer::new("s3", "test").time(now); - signer.sign_query(&mut req, Duration::from_secs(3600), &cred)?; - let actual_req = req; + signer.sign_query(&mut req, Duration::from_secs(3600), &cred)?; + let actual_req = req; - compare_request(&name, &expected_req, &actual_req); - } + compare_request(&name, &expected_req, &actual_req); Ok(()) } + #[apply(test_cases)] #[tokio::test] - async fn test_calculate_with_token() -> Result<()> { + async fn test_calculate_with_token(req_fn: fn() -> http::Request<&'static str>) -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); - for req_fn in test_cases() { - let mut req = req_fn(); - let name = format!( - "{} {} {:?}", - req.method(), - req.uri().path(), - req.uri().query(), - ); - let now = now(); - - let mut ss = SigningSettings::default(); - ss.percent_encoding_mode = PercentEncodingMode::Double; - ss.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; - let id = Credentials::new( - "access_key_id", - "secret_access_key", - Some("security_token".to_string()), - None, - "hardcoded-credentials", - ) - .into(); - let sp = v4::SigningParams::builder() - .identity(&id) - .region("test") - .name("s3") - .time(SystemTime::from(now)) - .settings(ss) - .build() - .expect("signing params must be valid"); - - let mut body = SignableBody::UnsignedPayload; - if req.headers().get(X_AMZ_CONTENT_SHA_256).is_some() { - body = SignableBody::Bytes(req.body().as_bytes()); - } + let mut req = req_fn(); + let name = format!( + "{} {} {:?}", + req.method(), + req.uri().path(), + req.uri().query(), + ); + let now = now(); + + let mut ss = SigningSettings::default(); + ss.percent_encoding_mode = PercentEncodingMode::Double; + ss.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; + let id = Credentials::new( + "access_key_id", + "secret_access_key", + Some("security_token".to_string()), + None, + "hardcoded-credentials", + ) + .into(); + let sp = v4::SigningParams::builder() + .identity(&id) + .region("test") + .name("s3") + .time(SystemTime::from(now)) + .settings(ss) + .build() + .expect("signing params must be valid"); + + let mut body = SignableBody::UnsignedPayload; + if req.headers().get(X_AMZ_CONTENT_SHA_256).is_some() { + body = SignableBody::Bytes(req.body().as_bytes()); + } - let output = aws_sigv4::http_request::sign( - SignableRequest::new( - req.method().as_str(), - req.uri().to_string(), - req.headers() - .iter() - .map(|(k, v)| (k.as_str(), std::str::from_utf8(v.as_bytes()).unwrap())), - body, - ) - .unwrap(), - &sp.into(), + let output = aws_sigv4::http_request::sign( + SignableRequest::new( + req.method().as_str(), + req.uri().to_string(), + req.headers() + .iter() + .map(|(k, v)| (k.as_str(), std::str::from_utf8(v.as_bytes()).unwrap())), + body, ) - .expect("signing must succeed"); - let (aws_sig, _) = output.into_parts(); - aws_sig.apply_to_request_http1x(&mut req); - let expected_req = req; - - let mut req = req_fn(); - - let loader = AwsDefaultLoader::new( - Client::new(), - AwsConfig { - access_key_id: Some("access_key_id".to_string()), - secret_access_key: Some("secret_access_key".to_string()), - session_token: Some("security_token".to_string()), - ..Default::default() - }, - ); - let cred = loader.load().await?.unwrap(); + .unwrap(), + &sp.into(), + ) + .expect("signing must succeed"); + let (aws_sig, _) = output.into_parts(); + aws_sig.apply_to_request_http1x(&mut req); + let expected_req = req; + + let mut req = req_fn(); + + let loader = AwsDefaultLoader::new( + Client::new(), + AwsConfig { + access_key_id: Some("access_key_id".to_string()), + secret_access_key: Some("secret_access_key".to_string()), + session_token: Some("security_token".to_string()), + ..Default::default() + }, + ); + let cred = loader.load().await?.unwrap(); - let signer = Signer::new("s3", "test").time(now); + let signer = Signer::new("s3", "test").time(now); - signer.sign(&mut req, &cred).expect("must apply success"); - let actual_req = req; + signer.sign(&mut req, &cred).expect("must apply success"); + let actual_req = req; - compare_request(&name, &expected_req, &actual_req); - } + compare_request(&name, &expected_req, &actual_req); Ok(()) } + #[apply(test_cases)] #[tokio::test] - async fn test_calculate_with_token_in_query() -> Result<()> { + async fn test_calculate_with_token_in_query( + req_fn: fn() -> http::Request<&'static str>, + ) -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); - for req_fn in test_cases() { - let mut req = req_fn(); - let name = format!( - "{} {} {:?}", - req.method(), - req.uri().path(), - req.uri().query(), - ); - let now = now(); - - let mut ss = SigningSettings::default(); - ss.percent_encoding_mode = PercentEncodingMode::Double; - ss.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; - ss.signature_location = SignatureLocation::QueryParams; - ss.expires_in = Some(std::time::Duration::from_secs(3600)); - let id = Credentials::new( - "access_key_id", - "secret_access_key", - Some("security_token".to_string()), - None, - "hardcoded-credentials", - ) - .into(); - let sp = v4::SigningParams::builder() - .identity(&id) - .region("test") - // .security_token("security_token") - .name("s3") - .time(SystemTime::from(now)) - .settings(ss) - .build() - .expect("signing params must be valid"); - - let mut body = SignableBody::UnsignedPayload; - if req.headers().get(X_AMZ_CONTENT_SHA_256).is_some() { - body = SignableBody::Bytes(req.body().as_bytes()); - } + let mut req = req_fn(); + let name = format!( + "{} {} {:?}", + req.method(), + req.uri().path(), + req.uri().query(), + ); + let now = now(); + + let mut ss = SigningSettings::default(); + ss.percent_encoding_mode = PercentEncodingMode::Double; + ss.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; + ss.signature_location = SignatureLocation::QueryParams; + ss.expires_in = Some(std::time::Duration::from_secs(3600)); + let id = Credentials::new( + "access_key_id", + "secret_access_key", + Some("security_token".to_string()), + None, + "hardcoded-credentials", + ) + .into(); + let sp = v4::SigningParams::builder() + .identity(&id) + .region("test") + // .security_token("security_token") + .name("s3") + .time(SystemTime::from(now)) + .settings(ss) + .build() + .expect("signing params must be valid"); + + let mut body = SignableBody::UnsignedPayload; + if req.headers().get(X_AMZ_CONTENT_SHA_256).is_some() { + body = SignableBody::Bytes(req.body().as_bytes()); + } - let output = aws_sigv4::http_request::sign( - SignableRequest::new( - req.method().as_str(), - req.uri().to_string(), - req.headers() - .iter() - .map(|(k, v)| (k.as_str(), std::str::from_utf8(v.as_bytes()).unwrap())), - body, - ) - .unwrap(), - &sp.into(), + let output = aws_sigv4::http_request::sign( + SignableRequest::new( + req.method().as_str(), + req.uri().to_string(), + req.headers() + .iter() + .map(|(k, v)| (k.as_str(), std::str::from_utf8(v.as_bytes()).unwrap())), + body, ) - .expect("signing must succeed"); - let (aws_sig, _) = output.into_parts(); - aws_sig.apply_to_request_http1x(&mut req); - let expected_req = req; - - let mut req = req_fn(); - - let loader = AwsDefaultLoader::new( - Client::new(), - AwsConfig { - access_key_id: Some("access_key_id".to_string()), - secret_access_key: Some("secret_access_key".to_string()), - session_token: Some("security_token".to_string()), - ..Default::default() - }, - ); - let cred = loader.load().await?.unwrap(); + .unwrap(), + &sp.into(), + ) + .expect("signing must succeed"); + let (aws_sig, _) = output.into_parts(); + aws_sig.apply_to_request_http1x(&mut req); + let expected_req = req; + + let mut req = req_fn(); + + let loader = AwsDefaultLoader::new( + Client::new(), + AwsConfig { + access_key_id: Some("access_key_id".to_string()), + secret_access_key: Some("secret_access_key".to_string()), + session_token: Some("security_token".to_string()), + ..Default::default() + }, + ); + let cred = loader.load().await?.unwrap(); - let signer = Signer::new("s3", "test").time(now); - signer - .sign_query(&mut req, Duration::from_secs(3600), &cred) - .expect("must apply success"); - let actual_req = req; + let signer = Signer::new("s3", "test").time(now); + signer + .sign_query(&mut req, Duration::from_secs(3600), &cred) + .expect("must apply success"); + let actual_req = req; - compare_request(&name, &expected_req, &actual_req); - } + compare_request(&name, &expected_req, &actual_req); Ok(()) }