From 060ba8bf32aa81e0cd769b76a640128d60df7eea Mon Sep 17 00:00:00 2001 From: Mohamed Bassem Date: Mon, 31 Jul 2023 21:10:21 +0300 Subject: [PATCH] [Auth] Cover the auth middleware with tests --- cronback-services/src/api/auth.rs | 3 +- cronback-services/src/api/auth_middleware.rs | 291 +++++++++++++++++- .../src/api/handlers/admin/api_keys.rs | 6 +- cronback-services/src/api/mod.rs | 4 +- 4 files changed, 291 insertions(+), 13 deletions(-) diff --git a/cronback-services/src/api/auth.rs b/cronback-services/src/api/auth.rs index 8ae68f2..4283529 100644 --- a/cronback-services/src/api/auth.rs +++ b/cronback-services/src/api/auth.rs @@ -44,6 +44,7 @@ impl From for ApiError { } } +#[derive(Clone)] pub struct Authenticator { store: AuthStore, } @@ -190,7 +191,7 @@ impl FromStr for SecretApiKey { } impl SecretApiKey { - fn generate() -> Self { + pub fn generate() -> Self { Self { key_id: Uuid::new_v4().simple().to_string(), plain_secret: Uuid::new_v4().simple().to_string(), diff --git a/cronback-services/src/api/auth_middleware.rs b/cronback-services/src/api/auth_middleware.rs index a18cd9d..1206dd9 100644 --- a/cronback-services/src/api/auth_middleware.rs +++ b/cronback-services/src/api/auth_middleware.rs @@ -1,17 +1,34 @@ use std::sync::Arc; -use axum::extract::State; +use axum::extract::{FromRef, State}; use axum::http::{self, HeaderMap, HeaderValue, Request}; use axum::middleware::Next; use axum::response::IntoResponse; use lib::prelude::*; -use super::auth::{AuthError, SecretApiKey}; +use super::auth::{AuthError, Authenticator, SecretApiKey}; use super::errors::ApiError; use super::AppState; const ON_BEHALF_OF_HEADER_NAME: &str = "X-On-Behalf-Of"; +// Partial state from the main app state to facilitate writing tests for the +// middleware. +#[derive(Clone)] +pub struct AuthenticationState { + authenticator: Authenticator, + config: super::config::ApiSvcConfig, +} + +impl FromRef> for AuthenticationState { + fn from_ref(input: &Arc) -> Self { + Self { + authenticator: input.authenticator.clone(), + config: input.context.service_config(), + } + } +} + enum AuthenticationStatus { Unauthenticated, Authenticated(ValidShardedId), @@ -61,14 +78,14 @@ fn get_auth_key( } async fn get_auth_status( - state: &AppState, + state: &AuthenticationState, req: &Request, ) -> Result { let auth_key = get_auth_key(req.headers())?; let Some(auth_key) = auth_key else { return Ok(AuthenticationStatus::Unauthenticated); }; - let config = state.context.service_config(); + let config = &state.config; let admin_keys = &config.admin_api_keys; if admin_keys.contains(&auth_key) { let project: Option> = req @@ -98,7 +115,10 @@ async fn get_auth_status( return Ok(AuthenticationStatus::Unauthenticated); }; - let project = state.authenicator.authenticate(&user_provided_secret).await; + let project = state + .authenticator + .authenticate(&user_provided_secret) + .await; match project { | Ok(project_id) => Ok(AuthenticationStatus::Authenticated(project_id)), | Err(AuthError::AuthFailed(_)) => { @@ -178,11 +198,11 @@ pub async fn ensure_admin( /// of the other "ensure_*" middlewares in this module to enforce the expected /// AuthenticationStatus for a certain route. pub async fn authenticate( - State(state): State>, + State(state): State, mut req: Request, next: Next, ) -> Result { - let auth_status = get_auth_status(state.as_ref(), &req).await?; + let auth_status = get_auth_status(&state, &req).await?; let project_id = auth_status.project_id(); req.extensions_mut().insert(auth_status); @@ -200,3 +220,260 @@ pub async fn authenticate( Ok(resp) } + +#[cfg(test)] +mod tests { + + use std::collections::HashSet; + use std::fmt::Debug; + + use axum::routing::get; + use axum::{middleware, Router}; + use cronback_api_model::admin::CreateAPIkeyRequest; + use hyper::{Body, StatusCode}; + use tower::ServiceExt; + + use super::*; + use crate::api::auth_store::AuthStore; + use crate::api::config::ApiSvcConfig; + use crate::api::ApiService; + + async fn make_state() -> AuthenticationState { + let mut set = HashSet::new(); + set.insert("adminkey1".to_string()); + set.insert("adminkey2".to_string()); + + let config = ApiSvcConfig { + address: String::new(), + port: 123, + database_uri: String::new(), + admin_api_keys: set, + log_request_body: false, + log_response_body: false, + }; + + let db = ApiService::in_memory_database().await.unwrap(); + let auth_store = AuthStore::new(db); + let authenticator = Authenticator::new(auth_store); + + AuthenticationState { + authenticator, + config, + } + } + + struct TestInput { + app: Router, + auth_header: Option, + on_behalf_on_header: Option, + expected_status: StatusCode, + } + + impl Debug for TestInput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TestInput") + .field("auth_header", &self.auth_header) + .field("on_behalf_on_header", &self.on_behalf_on_header) + .field("expected_status", &self.expected_status) + .finish() + } + } + + struct TestExpectations { + unauthenticated: StatusCode, + authenticated: StatusCode, + admin_no_project: StatusCode, + admin_with_project: StatusCode, + unknown_secret_key: StatusCode, + } + + async fn run_tests( + app: Router, + state: AuthenticationState, + expectations: TestExpectations, + ) -> anyhow::Result<()> { + // Define one project and generate a key for it. + let prj1 = ProjectId::generate(); + let key = state + .authenticator + .gen_key( + CreateAPIkeyRequest { + key_name: "test".to_string(), + metadata: Default::default(), + }, + &prj1, + ) + .await?; + + let inputs = vec![ + // Unauthenticated user + TestInput { + app: app.clone(), + auth_header: None, + on_behalf_on_header: None, + expected_status: expectations.unauthenticated, + }, + // Authenticated user + TestInput { + app: app.clone(), + auth_header: Some(format!("Bearer {}", key.unsafe_to_string())), + on_behalf_on_header: None, + expected_status: expectations.authenticated, + }, + // Admin without project + TestInput { + app: app.clone(), + auth_header: Some("Bearer adminkey1".to_string()), + on_behalf_on_header: None, + expected_status: expectations.admin_no_project, + }, + // Admin with project + TestInput { + app: app.clone(), + auth_header: Some("Bearer adminkey1".to_string()), + on_behalf_on_header: Some(prj1.to_string()), + expected_status: expectations.admin_with_project, + }, + // Unknown secret key + TestInput { + app: app.clone(), + auth_header: Some(format!( + "Bearer {}", + SecretApiKey::generate().unsafe_to_string() + )), + on_behalf_on_header: Some(prj1.to_string()), + expected_status: expectations.unknown_secret_key, + }, + // Malformed secret key should be treated as an unknown secret key + TestInput { + app: app.clone(), + auth_header: Some("Bearer wrong key".to_string()), + on_behalf_on_header: Some("wrong_project".to_string()), + expected_status: expectations.unknown_secret_key, + }, + // Malformed authorization header + TestInput { + app: app.clone(), + auth_header: Some(format!("Token {}", key.unsafe_to_string())), + on_behalf_on_header: Some(prj1.to_string()), + expected_status: StatusCode::BAD_REQUEST, + }, + // Malformed on-behalf-on project id + TestInput { + app: app.clone(), + auth_header: Some("Bearer adminkey1".to_string()), + on_behalf_on_header: Some("wrong_project".to_string()), + expected_status: StatusCode::BAD_REQUEST, + }, + ]; + + for input in inputs { + let input_str = format!("{:?}", input); + + let mut req = Request::builder(); + if let Some(v) = input.auth_header { + req = req.header("Authorization", v); + } + if let Some(v) = input.on_behalf_on_header { + req = req.header(ON_BEHALF_OF_HEADER_NAME, v); + } + + let resp = input + .app + .oneshot(req.uri("/").body(Body::empty()).unwrap()) + .await?; + + assert_eq!( + resp.status(), + input.expected_status, + "Input: {}", + input_str + ); + } + Ok(()) + } + + #[tokio::test] + async fn test_ensure_authenticated() -> anyhow::Result<()> { + let state = make_state().await; + + let app = Router::new() + .route("/", get(|| async { "Hello, World!" })) + .layer(middleware::from_fn(super::ensure_authenticated)) + .layer(middleware::from_fn_with_state( + state.clone(), + super::authenticate, + )); + + run_tests( + app, + state, + TestExpectations { + unauthenticated: StatusCode::UNAUTHORIZED, + authenticated: StatusCode::OK, + admin_no_project: StatusCode::BAD_REQUEST, + admin_with_project: StatusCode::OK, + unknown_secret_key: StatusCode::UNAUTHORIZED, + }, + ) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn test_ensure_admin() -> anyhow::Result<()> { + let state = make_state().await; + + let app = Router::new() + .route("/", get(|| async { "Hello, World!" })) + .layer(middleware::from_fn(super::ensure_admin)) + .layer(middleware::from_fn_with_state( + state.clone(), + super::authenticate, + )); + + run_tests( + app, + state, + TestExpectations { + unauthenticated: StatusCode::UNAUTHORIZED, + authenticated: StatusCode::FORBIDDEN, + admin_no_project: StatusCode::OK, + admin_with_project: StatusCode::OK, + unknown_secret_key: StatusCode::UNAUTHORIZED, + }, + ) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn test_ensure_admin_for_project() -> anyhow::Result<()> { + let state = make_state().await; + + let app = Router::new() + .route("/", get(|| async { "Hello, World!" })) + .layer(middleware::from_fn(super::ensure_admin_for_project)) + .layer(middleware::from_fn_with_state( + state.clone(), + super::authenticate, + )); + + run_tests( + app, + state, + TestExpectations { + unauthenticated: StatusCode::UNAUTHORIZED, + authenticated: StatusCode::FORBIDDEN, + admin_no_project: StatusCode::BAD_REQUEST, + admin_with_project: StatusCode::OK, + unknown_secret_key: StatusCode::UNAUTHORIZED, + }, + ) + .await?; + + Ok(()) + } +} diff --git a/cronback-services/src/api/handlers/admin/api_keys.rs b/cronback-services/src/api/handlers/admin/api_keys.rs index e3d7e67..0d9e7ce 100644 --- a/cronback-services/src/api/handlers/admin/api_keys.rs +++ b/cronback-services/src/api/handlers/admin/api_keys.rs @@ -24,7 +24,7 @@ pub(crate) async fn create( Extension(project): Extension>, ValidatedJson(req): ValidatedJson, ) -> Result, ApiError> { - let key = state.authenicator.gen_key(req, &project).await?; + let key = state.authenticator.gen_key(req, &project).await?; // This is the only legitimate place where this function should be used. let key_str = key.unsafe_to_string(); @@ -38,7 +38,7 @@ pub(crate) async fn list( Extension(project): Extension>, ) -> Result, ApiError> { let keys = state - .authenicator + .authenticator .list_keys(&project) .await .map_err(|e| AppStateError::DatabaseError(e.to_string()))? @@ -72,7 +72,7 @@ pub(crate) async fn revoke( Extension(project): Extension>, ) -> Result { let deleted = state - .authenicator + .authenticator .revoke_key(&id, &project) .await .map_err(|e| AppStateError::DatabaseError(e.to_string()))?; diff --git a/cronback-services/src/api/mod.rs b/cronback-services/src/api/mod.rs index 81ec1cf..ec4b88e 100644 --- a/cronback-services/src/api/mod.rs +++ b/cronback-services/src/api/mod.rs @@ -85,7 +85,7 @@ impl CronbackService for ApiService { let shared_state = Arc::new(AppState { context: context.clone(), - authenicator: Authenticator::new(AuthStore::new(db)), + authenticator: Authenticator::new(AuthStore::new(db)), scheduler_clients: Box::new(GrpcClientProvider::new( config.clone(), )), @@ -169,7 +169,7 @@ pub enum AppStateError { pub struct AppState { pub context: ServiceContext, - pub authenicator: Authenticator, + pub authenticator: Authenticator, pub scheduler_clients: Box>, pub dispatcher_clients: