diff --git a/services/cronback-api-srv/auth_middleware.rs b/services/cronback-api-srv/auth_middleware.rs index 00d9dec..25070bc 100644 --- a/services/cronback-api-srv/auth_middleware.rs +++ b/services/cronback-api-srv/auth_middleware.rs @@ -6,15 +6,33 @@ use axum::middleware::Next; use axum::response::IntoResponse; use lib::model::{ModelId, ValidShardedId}; use lib::types::ProjectId; -use tracing::error; -use crate::auth::SecretApiKey; +use crate::auth::{AuthError, SecretApiKey}; use crate::errors::ApiError; use crate::AppState; +const ON_BEHALF_OF_HEADER_NAME: &str = "X-On-Behalf-Of"; + +enum AuthenticationStatus { + Unauthenticated, + Authenticated(ValidShardedId), + Admin(Option>), +} + +impl AuthenticationStatus { + fn project_id(&self) -> Option> { + match self { + | AuthenticationStatus::Authenticated(p) => Some(p.clone()), + | AuthenticationStatus::Admin(Some(p)) => Some(p.clone()), + | _ => None, + } + } +} + +/// Parses the AUTHORIZATION header to extract the user provided secret key. fn get_auth_key( header_map: &HeaderMap, -) -> Result { +) -> Result, ApiError> { let auth_header = header_map .get(http::header::AUTHORIZATION) .and_then(|header| header.to_str().ok()); @@ -22,15 +40,17 @@ fn get_auth_key( let auth_header = if let Some(auth_header) = auth_header { auth_header } else { - return Err(ApiError::Unauthorized); + return Ok(None); }; if auth_header.is_empty() { - return Err(ApiError::Unauthorized); + return Ok(None); } match auth_header.split_once(' ') { - | Some((name, content)) if name == "Bearer" => Ok(content.to_string()), + | Some((name, content)) if name == "Bearer" => { + Ok(Some(content.to_string())) + } | _ => { Err(ApiError::BadRequest( "Authentication header is malformed, please use \ @@ -41,96 +61,142 @@ fn get_auth_key( } } -/// Ensures that the caller is authenticated with an admin key AND acting on -/// behalf of a project. The `ProjectId` is then injected in the request -/// extensions. -pub async fn admin_only_auth_for_project( - State(state): State>, - mut req: Request, - next: Next, -) -> Result { +async fn get_auth_status( + state: &AppState, + req: &Request, +) -> Result { let auth_key = get_auth_key(req.headers())?; + let Some(auth_key) = auth_key else { + return Ok(AuthenticationStatus::Unauthenticated); + }; let admin_keys = &state.config.api.admin_api_keys; if admin_keys.contains(&auth_key) { - let project = extract_project_from_request(&req)?; - req.extensions_mut().insert(project.clone()); - Ok(next.run(req).await) - } else { - Err(ApiError::Forbidden) + let project: Option> = req + .headers() + .get(ON_BEHALF_OF_HEADER_NAME) + .map(HeaderValue::to_str) + .transpose() + .map_err(|_| { + ApiError::BadRequest(format!( + "{ON_BEHALF_OF_HEADER_NAME} header is not a valid UTF-8 \ + string" + )) + })? + .map(|p| ProjectId::from(p.to_owned()).validated()) + .transpose() + .map_err(|_| { + ApiError::BadRequest(format!( + "Invalid project id in {ON_BEHALF_OF_HEADER_NAME} header" + )) + })?; + + return Ok(AuthenticationStatus::Admin(project)); + } + + let Ok(user_provided_secret) = auth_key.to_string().parse::() + else { + return Ok(AuthenticationStatus::Unauthenticated); + }; + + let project = state.authenicator.authenticate(&user_provided_secret).await; + match project { + | Ok(project_id) => Ok(AuthenticationStatus::Authenticated(project_id)), + | Err(AuthError::AuthFailed(_)) => { + Ok(AuthenticationStatus::Unauthenticated) + } + | Err(e) => { + tracing::error!("{}", e); + Err(ApiError::ServiceUnavailable) + } } } -/// Ensures that the caller is authenticated with an admin key. No project is -/// required. Handlers using this middleware shouldn't rely on a `ProjectId` -/// being set in the request extensions. -pub async fn admin_only_auth( - State(state): State>, +/// Ensures that the caller is authenticated with a project id. +pub async fn ensure_authenticated( req: Request, next: Next, ) -> Result { - let auth_key = get_auth_key(req.headers())?; - let admin_keys = &state.config.api.admin_api_keys; - if admin_keys.contains(&auth_key) { - Ok(next.run(req).await) - } else { - Err(ApiError::Forbidden) + let auth = req.extensions().get::().expect( + "All endpoints should have passed by the authentication middleware", + ); + match auth { + | AuthenticationStatus::Admin(Some(_)) + | AuthenticationStatus::Admin(None) => { + Err(ApiError::BadRequest( + "Super privilege header(s) missing!".to_owned(), + )) + } + | AuthenticationStatus::Authenticated(_) => Ok(next.run(req).await), + | AuthenticationStatus::Unauthenticated => Err(ApiError::Unauthorized), } } -fn extract_project_from_request( - req: &Request, -) -> Result, ApiError> { - // This is an admin user which is acting on behalf of some project. - const ON_BEHALF_OF_HEADER_NAME: &str = "X-On-Behalf-Of"; - if let Some(project) = req.headers().get(ON_BEHALF_OF_HEADER_NAME) { - let project = project.to_str().map_err(|_| { - ApiError::BadRequest(format!( - "{ON_BEHALF_OF_HEADER_NAME} header is not a valid UTF-8 string" +/// Ensures that the caller is authenticated with an admin key AND acting on +/// behalf of a project. +pub async fn ensure_admin_for_project( + req: Request, + next: Next, +) -> Result { + let auth = req.extensions().get::().expect( + "All endpoints should have passed by the authentication middleware", + ); + + match auth { + | AuthenticationStatus::Admin(Some(_)) => Ok(next.run(req).await), + | AuthenticationStatus::Admin(None) => { + Err(ApiError::BadRequest( + "Super privilege header(s) missing!".to_owned(), )) - })?; - let validated_project = ProjectId::from(project.to_owned()) - .validated() - .map_err(|_| { - ApiError::BadRequest(format!( - "Invalid project id in {ON_BEHALF_OF_HEADER_NAME} header" - )) - }); - return validated_project; + } + | AuthenticationStatus::Authenticated(_) => Err(ApiError::Forbidden), + | AuthenticationStatus::Unauthenticated => Err(ApiError::Unauthorized), } +} - error!("Admin user didn't set {} header", ON_BEHALF_OF_HEADER_NAME); +/// Ensures that the caller is authenticated with an admin key. No project is +/// required. Handlers using this middleware shouldn't rely on a `ProjectId` +/// being set in the request extensions. +pub async fn ensure_admin( + req: Request, + next: Next, +) -> Result { + let auth = req.extensions().get::().expect( + "All endpoints should have passed by the authentication middleware", + ); - Err(ApiError::BadRequest( - "Super privilege header(s) missing!".to_owned(), - )) + match auth { + | AuthenticationStatus::Admin(_) => Ok(next.run(req).await), + | AuthenticationStatus::Authenticated(_) => Err(ApiError::Forbidden), + | AuthenticationStatus::Unauthenticated => Err(ApiError::Unauthorized), + } } -pub async fn auth( +/// Parses the request headers to extract authentication information. The +/// AuthenticationStatus is then injected in the request/response extensions +/// along with the authenticated ProjectId if found. This middleware only fails +/// if the user passes malformed authentication headers. It's the responsibility +/// of the other "ensure_*" middlewares in this module to enforce the expected +/// AuthenticationStatus for a certain route. +pub async fn authenticate( State(state): State>, mut req: Request, next: Next, ) -> Result { - let auth_key = get_auth_key(req.headers())?; - let admin_keys = &state.config.api.admin_api_keys; - if admin_keys.contains(&auth_key) { - let project = extract_project_from_request(&req)?; - req.extensions_mut().insert(project.clone()); - return Ok(next.run(req).await); - } + let auth_status = get_auth_status(state.as_ref(), &req).await?; - let Ok(user_provided_secret) = auth_key.to_string().parse::() - else { - return Err(ApiError::Unauthorized); - }; + let project_id = auth_status.project_id(); + req.extensions_mut().insert(auth_status); - let project = state - .authenicator - .authenticate(&user_provided_secret) - .await?; + if let Some(project_id) = &project_id { + req.extensions_mut().insert(project_id.clone()); + } - req.extensions_mut().insert(project.clone()); let mut resp = next.run(req).await; - // Inject project_id in the response extensions as well. - resp.extensions_mut().insert(project); + + if let Some(project_id) = &project_id { + // Inject project_id in the response extensions as well. + resp.extensions_mut().insert(project_id.clone()); + } + Ok(resp) } diff --git a/services/cronback-api-srv/handlers/admin/mod.rs b/services/cronback-api-srv/handlers/admin/mod.rs index 5871cc4..529dab1 100644 --- a/services/cronback-api-srv/handlers/admin/mod.rs +++ b/services/cronback-api-srv/handlers/admin/mod.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use axum::{middleware, Router}; -use crate::auth_middleware::{admin_only_auth, admin_only_auth_for_project}; +use crate::auth_middleware::{ensure_admin, ensure_admin_for_project}; use crate::AppState; pub(crate) fn routes(shared_state: Arc) -> Router { @@ -17,10 +17,7 @@ pub(crate) fn routes(shared_state: Arc) -> Router { .route("/", axum::routing::get(api_keys::list)) .route("/:id", axum::routing::delete(api_keys::revoke)) .with_state(Arc::clone(&shared_state)) - .route_layer(middleware::from_fn_with_state( - Arc::clone(&shared_state), - admin_only_auth_for_project, - )), + .route_layer(middleware::from_fn(ensure_admin_for_project)), ) .nest( "/projects", @@ -29,9 +26,6 @@ pub(crate) fn routes(shared_state: Arc) -> Router { .route("/:id/disable", axum::routing::post(projects::disable)) .route("/:id/enable", axum::routing::post(projects::enable)) .with_state(Arc::clone(&shared_state)) - .route_layer(middleware::from_fn_with_state( - Arc::clone(&shared_state), - admin_only_auth, - )), + .route_layer(middleware::from_fn(ensure_admin)), ) } diff --git a/services/cronback-api-srv/handlers/mod.rs b/services/cronback-api-srv/handlers/mod.rs index 54a5e59..93e824e 100644 --- a/services/cronback-api-srv/handlers/mod.rs +++ b/services/cronback-api-srv/handlers/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use axum::{middleware, Router}; -use crate::auth_middleware::auth as auth_middleware; +use crate::auth_middleware::ensure_authenticated; use crate::AppState; pub(crate) mod admin; @@ -13,11 +13,7 @@ pub(crate) fn routes(shared_state: Arc) -> Router { .nest("/admin", admin::routes(Arc::clone(&shared_state))) .nest( "/triggers", - triggers::routes(Arc::clone(&shared_state)).route_layer( - middleware::from_fn_with_state( - Arc::clone(&shared_state), - auth_middleware, - ), - ), + triggers::routes(Arc::clone(&shared_state)) + .route_layer(middleware::from_fn(ensure_authenticated)), ) } diff --git a/services/cronback-api-srv/lib.rs b/services/cronback-api-srv/lib.rs index 8053ef7..7ee226a 100644 --- a/services/cronback-api-srv/lib.rs +++ b/services/cronback-api-srv/lib.rs @@ -109,6 +109,10 @@ pub async fn start_api_server( TraceLayer::new_for_http() .make_span_with(ApiMakeSpan::new(service_name)), ) + .layer(middleware::from_fn_with_state( + Arc::clone(&shared_state), + auth_middleware::authenticate, + )) .route_layer(middleware::from_fn(inject_request_id)) .route_layer(middleware::from_fn(track_metrics)) .fallback(fallback); diff --git a/services/cronback-api-srv/logging.rs b/services/cronback-api-srv/logging.rs index e4491f4..9e7c75e 100644 --- a/services/cronback-api-srv/logging.rs +++ b/services/cronback-api-srv/logging.rs @@ -9,7 +9,8 @@ use axum::middleware::Next; use axum::response::{IntoResponse, Response}; use hyper::header::USER_AGENT; use lib::config::Config; -use lib::types::RequestId; +use lib::prelude::ValidShardedId; +use lib::types::{ProjectId, RequestId}; use tower_http::trace::MakeSpan; use tracing::{error_span, info}; @@ -41,6 +42,11 @@ impl MakeSpan for ApiMakeSpan { .get::>() .map(|a| a.ip().to_string()); + let project_id = request + .extensions() + .get::>() + .map(|p| p.to_string()); + error_span!( target: "request_response_tracing_metadata", "http_request", @@ -52,6 +58,7 @@ impl MakeSpan for ApiMakeSpan { version = ?request.version(), user_agent = ?user_agent, ip = %ip.unwrap_or_default(), + project_id = %project_id.unwrap_or_default(), ) } }