Skip to content

Commit

Permalink
Refactor authentication middleware to include project_id in request t…
Browse files Browse the repository at this point in the history
…racing

Change the main authentication middleware to be a top level middleware that
injects the authentication status into the request extensions. The per route
middleware can then just read the AuthenticationStatus and take decisions based
on it. This change allows us to capture the project id in the request tracing
which wasn't possible before because it get invoked just before the request handlers
(due to how middleware ordering works).


The main change in behavior here is that the middleware will be invoked regardless of
whether the route requires authentication or not. This is ok, unless the customer passes
malformed auth headers in which case the request would currently fail with `BadRequest`
while it would have succeeded before. Note, that we currently don't have any unauthenticated
endpoints, so no "real" change in behavior here.
  • Loading branch information
MohamedBassem committed Jul 24, 2023
1 parent 74544f5 commit 5bf581b
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 88 deletions.
208 changes: 137 additions & 71 deletions services/cronback-api-srv/auth_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,51 @@ use axum::middleware::Next;
use axum::response::IntoResponse;
use lib::model::{ModelId, ValidShardedId};
use lib::types::ProjectId;
use tracing::error;

use crate::auth::SecretApiKey;
use crate::auth::{AuthError, SecretApiKey};
use crate::errors::ApiError;
use crate::AppState;

const ON_BEHALF_OF_HEADER_NAME: &str = "X-On-Behalf-Of";

enum AuthenticationStatus {
Unauthenticated,
Authenticated(ValidShardedId<ProjectId>),
Admin(Option<ValidShardedId<ProjectId>>),
}

impl AuthenticationStatus {
fn project_id(&self) -> Option<ValidShardedId<ProjectId>> {
match self {
| AuthenticationStatus::Authenticated(p) => Some(p.clone()),
| AuthenticationStatus::Admin(Some(p)) => Some(p.clone()),
| _ => None,
}
}
}

/// Parses the AUTHORIZATION header to extract the user provided secret key.
fn get_auth_key(
header_map: &HeaderMap<HeaderValue>,
) -> Result<String, ApiError> {
) -> Result<Option<String>, ApiError> {
let auth_header = header_map
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());

let auth_header = if let Some(auth_header) = auth_header {
auth_header
} else {
return Err(ApiError::Unauthorized);
return Ok(None);
};

if auth_header.is_empty() {
return Err(ApiError::Unauthorized);
return Ok(None);
}

match auth_header.split_once(' ') {
| Some((name, content)) if name == "Bearer" => Ok(content.to_string()),
| Some((name, content)) if name == "Bearer" => {
Ok(Some(content.to_string()))
}
| _ => {
Err(ApiError::BadRequest(
"Authentication header is malformed, please use \
Expand All @@ -41,96 +61,142 @@ fn get_auth_key(
}
}

/// Ensures that the caller is authenticated with an admin key AND acting on
/// behalf of a project. The `ProjectId` is then injected in the request
/// extensions.
pub async fn admin_only_auth_for_project<B>(
State(state): State<Arc<AppState>>,
mut req: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, ApiError> {
async fn get_auth_status<B>(
state: &AppState,
req: &Request<B>,
) -> Result<AuthenticationStatus, ApiError> {
let auth_key = get_auth_key(req.headers())?;
let Some(auth_key) = auth_key else {
return Ok(AuthenticationStatus::Unauthenticated);
};
let admin_keys = &state.config.api.admin_api_keys;
if admin_keys.contains(&auth_key) {
let project = extract_project_from_request(&req)?;
req.extensions_mut().insert(project.clone());
Ok(next.run(req).await)
} else {
Err(ApiError::Forbidden)
let project: Option<ValidShardedId<ProjectId>> = req
.headers()
.get(ON_BEHALF_OF_HEADER_NAME)
.map(HeaderValue::to_str)
.transpose()
.map_err(|_| {
ApiError::BadRequest(format!(
"{ON_BEHALF_OF_HEADER_NAME} header is not a valid UTF-8 \
string"
))
})?
.map(|p| ProjectId::from(p.to_owned()).validated())
.transpose()
.map_err(|_| {
ApiError::BadRequest(format!(
"Invalid project id in {ON_BEHALF_OF_HEADER_NAME} header"
))
})?;

return Ok(AuthenticationStatus::Admin(project));
}

let Ok(user_provided_secret) = auth_key.to_string().parse::<SecretApiKey>()
else {
return Ok(AuthenticationStatus::Unauthenticated);
};

let project = state.authenicator.authenticate(&user_provided_secret).await;
match project {
| Ok(project_id) => Ok(AuthenticationStatus::Authenticated(project_id)),
| Err(AuthError::AuthFailed(_)) => {
Ok(AuthenticationStatus::Unauthenticated)
}
| Err(e) => {
tracing::error!("{}", e);
Err(ApiError::ServiceUnavailable)
}
}
}

/// Ensures that the caller is authenticated with an admin key. No project is
/// required. Handlers using this middleware shouldn't rely on a `ProjectId`
/// being set in the request extensions.
pub async fn admin_only_auth<B>(
State(state): State<Arc<AppState>>,
/// Ensures that the caller is authenticated with a project id.
pub async fn ensure_authenticated<B>(
req: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, ApiError> {
let auth_key = get_auth_key(req.headers())?;
let admin_keys = &state.config.api.admin_api_keys;
if admin_keys.contains(&auth_key) {
Ok(next.run(req).await)
} else {
Err(ApiError::Forbidden)
let auth = req.extensions().get::<AuthenticationStatus>().expect(
"All endpoints should have passed by the authentication middleware",
);
match auth {
| AuthenticationStatus::Admin(Some(_))
| AuthenticationStatus::Admin(None) => {
Err(ApiError::BadRequest(
"Super privilege header(s) missing!".to_owned(),
))
}
| AuthenticationStatus::Authenticated(_) => Ok(next.run(req).await),
| AuthenticationStatus::Unauthenticated => Err(ApiError::Unauthorized),
}
}

fn extract_project_from_request<B>(
req: &Request<B>,
) -> Result<ValidShardedId<ProjectId>, ApiError> {
// This is an admin user which is acting on behalf of some project.
const ON_BEHALF_OF_HEADER_NAME: &str = "X-On-Behalf-Of";
if let Some(project) = req.headers().get(ON_BEHALF_OF_HEADER_NAME) {
let project = project.to_str().map_err(|_| {
ApiError::BadRequest(format!(
"{ON_BEHALF_OF_HEADER_NAME} header is not a valid UTF-8 string"
/// Ensures that the caller is authenticated with an admin key AND acting on
/// behalf of a project.
pub async fn ensure_admin_for_project<B>(
req: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, ApiError> {
let auth = req.extensions().get::<AuthenticationStatus>().expect(
"All endpoints should have passed by the authentication middleware",
);

match auth {
| AuthenticationStatus::Admin(Some(_)) => Ok(next.run(req).await),
| AuthenticationStatus::Admin(None) => {
Err(ApiError::BadRequest(
"Super privilege header(s) missing!".to_owned(),
))
})?;
let validated_project = ProjectId::from(project.to_owned())
.validated()
.map_err(|_| {
ApiError::BadRequest(format!(
"Invalid project id in {ON_BEHALF_OF_HEADER_NAME} header"
))
});
return validated_project;
}
| AuthenticationStatus::Authenticated(_) => Err(ApiError::Forbidden),
| AuthenticationStatus::Unauthenticated => Err(ApiError::Unauthorized),
}
}

error!("Admin user didn't set {} header", ON_BEHALF_OF_HEADER_NAME);
/// Ensures that the caller is authenticated with an admin key. No project is
/// required. Handlers using this middleware shouldn't rely on a `ProjectId`
/// being set in the request extensions.
pub async fn ensure_admin<B>(
req: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, ApiError> {
let auth = req.extensions().get::<AuthenticationStatus>().expect(
"All endpoints should have passed by the authentication middleware",
);

Err(ApiError::BadRequest(
"Super privilege header(s) missing!".to_owned(),
))
match auth {
| AuthenticationStatus::Admin(_) => Ok(next.run(req).await),
| AuthenticationStatus::Authenticated(_) => Err(ApiError::Forbidden),
| AuthenticationStatus::Unauthenticated => Err(ApiError::Unauthorized),
}
}

pub async fn auth<B>(
/// Parses the request headers to extract authentication information. The
/// AuthenticationStatus is then injected in the request/response extensions
/// along with the authenticated ProjectId if found. This middleware only fails
/// if the user passes malformed authentication headers. It's the responsibility
/// of the other "ensure_*" middlewares in this module to enforce the expected
/// AuthenticationStatus for a certain route.
pub async fn authenticate<B>(
State(state): State<Arc<AppState>>,
mut req: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, ApiError> {
let auth_key = get_auth_key(req.headers())?;
let admin_keys = &state.config.api.admin_api_keys;
if admin_keys.contains(&auth_key) {
let project = extract_project_from_request(&req)?;
req.extensions_mut().insert(project.clone());
return Ok(next.run(req).await);
}
let auth_status = get_auth_status(state.as_ref(), &req).await?;

let Ok(user_provided_secret) = auth_key.to_string().parse::<SecretApiKey>()
else {
return Err(ApiError::Unauthorized);
};
let project_id = auth_status.project_id();
req.extensions_mut().insert(auth_status);

let project = state
.authenicator
.authenticate(&user_provided_secret)
.await?;
if let Some(project_id) = &project_id {
req.extensions_mut().insert(project_id.clone());
}

req.extensions_mut().insert(project.clone());
let mut resp = next.run(req).await;
// Inject project_id in the response extensions as well.
resp.extensions_mut().insert(project);

if let Some(project_id) = &project_id {
// Inject project_id in the response extensions as well.
resp.extensions_mut().insert(project_id.clone());
}

Ok(resp)
}
12 changes: 3 additions & 9 deletions services/cronback-api-srv/handlers/admin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::Arc;

use axum::{middleware, Router};

use crate::auth_middleware::{admin_only_auth, admin_only_auth_for_project};
use crate::auth_middleware::{ensure_admin, ensure_admin_for_project};
use crate::AppState;

pub(crate) fn routes(shared_state: Arc<AppState>) -> Router {
Expand All @@ -17,10 +17,7 @@ pub(crate) fn routes(shared_state: Arc<AppState>) -> Router {
.route("/", axum::routing::get(api_keys::list))
.route("/:id", axum::routing::delete(api_keys::revoke))
.with_state(Arc::clone(&shared_state))
.route_layer(middleware::from_fn_with_state(
Arc::clone(&shared_state),
admin_only_auth_for_project,
)),
.route_layer(middleware::from_fn(ensure_admin_for_project)),
)
.nest(
"/projects",
Expand All @@ -29,9 +26,6 @@ pub(crate) fn routes(shared_state: Arc<AppState>) -> Router {
.route("/:id/disable", axum::routing::post(projects::disable))
.route("/:id/enable", axum::routing::post(projects::enable))
.with_state(Arc::clone(&shared_state))
.route_layer(middleware::from_fn_with_state(
Arc::clone(&shared_state),
admin_only_auth,
)),
.route_layer(middleware::from_fn(ensure_admin)),
)
}
10 changes: 3 additions & 7 deletions services/cronback-api-srv/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use axum::{middleware, Router};

use crate::auth_middleware::auth as auth_middleware;
use crate::auth_middleware::ensure_authenticated;
use crate::AppState;

pub(crate) mod admin;
Expand All @@ -13,11 +13,7 @@ pub(crate) fn routes(shared_state: Arc<AppState>) -> Router {
.nest("/admin", admin::routes(Arc::clone(&shared_state)))
.nest(
"/triggers",
triggers::routes(Arc::clone(&shared_state)).route_layer(
middleware::from_fn_with_state(
Arc::clone(&shared_state),
auth_middleware,
),
),
triggers::routes(Arc::clone(&shared_state))
.route_layer(middleware::from_fn(ensure_authenticated)),
)
}
4 changes: 4 additions & 0 deletions services/cronback-api-srv/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ pub async fn start_api_server(
TraceLayer::new_for_http()
.make_span_with(ApiMakeSpan::new(service_name)),
)
.layer(middleware::from_fn_with_state(
Arc::clone(&shared_state),
auth_middleware::authenticate,
))
.route_layer(middleware::from_fn(inject_request_id))
.route_layer(middleware::from_fn(track_metrics))
.fallback(fallback);
Expand Down
9 changes: 8 additions & 1 deletion services/cronback-api-srv/logging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use hyper::header::USER_AGENT;
use lib::config::Config;
use lib::types::RequestId;
use lib::prelude::ValidShardedId;
use lib::types::{ProjectId, RequestId};
use tower_http::trace::MakeSpan;
use tracing::{error_span, info};

Expand Down Expand Up @@ -41,6 +42,11 @@ impl<B> MakeSpan<B> for ApiMakeSpan {
.get::<ConnectInfo<SocketAddr>>()
.map(|a| a.ip().to_string());

let project_id = request
.extensions()
.get::<ValidShardedId<ProjectId>>()
.map(|p| p.to_string());

error_span!(
target: "request_response_tracing_metadata",
"http_request",
Expand All @@ -52,6 +58,7 @@ impl<B> MakeSpan<B> for ApiMakeSpan {
version = ?request.version(),
user_agent = ?user_agent,
ip = %ip.unwrap_or_default(),
project_id = %project_id.unwrap_or_default(),
)
}
}
Expand Down

0 comments on commit 5bf581b

Please sign in to comment.