diff --git a/Cargo.toml b/Cargo.toml index 9357a67..5241c0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ hex = "0.4.3" itertools = "0.12.0" nom = "7.1.3" reqwest = { version = "0.11.23", default-features = false } -sec = { version = "1.0.0", features = ["deserialize"] } +sec = { version = "1.0.0", features = [ "deserialize", "serialize" ] } serde = { version = "1.0.193", features = [ "derive" ] } serde_json = "1.0.108" sha2 = "0.10.8" diff --git a/src/container_orchestrator.rs b/src/container_orchestrator.rs index 336a3b8..fbbe177 100644 --- a/src/container_orchestrator.rs +++ b/src/container_orchestrator.rs @@ -46,7 +46,7 @@ pub(crate) struct ContainerOrchestrator { pub(crate) struct PublishedContainer { host_addr: SocketAddr, manifest_reference: ManifestReference, - config: RuntimeConfig, + config: Arc, } impl PublishedContainer { @@ -57,12 +57,16 @@ impl PublishedContainer { pub(crate) fn host_addr(&self) -> SocketAddr { self.host_addr } + + pub(crate) fn config(&self) -> &Arc { + &self.config + } } #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub(crate) struct RuntimeConfig { #[serde(default)] - http_access: HashMap, + pub(crate) http_access: Option>>, } impl IntoResponse for RuntimeConfig { @@ -115,7 +119,12 @@ impl ContainerOrchestrator { self.configs_dir .join(location.repository()) .join(location.image()) - .join(manifest_reference.reference().to_string()) + .join( + manifest_reference + .reference() + .to_string() + .trim_start_matches(':'), + ) } pub(crate) async fn load_config( @@ -196,7 +205,7 @@ impl ContainerOrchestrator { return Ok(None); }; - let config = self.load_config(&manifest_reference).await?; + let config = Arc::new(self.load_config(&manifest_reference).await?); Ok(Some(PublishedContainer { host_addr: port_mapping diff --git a/src/registry/auth.rs b/src/registry/auth.rs index 8f7887a..ff677af 100644 --- a/src/registry/auth.rs +++ b/src/registry/auth.rs @@ -1,4 +1,4 @@ -use std::{str, sync::Arc}; +use std::{collections::HashMap, str, sync::Arc}; use axum::{ async_trait, @@ -96,6 +96,30 @@ impl AuthProvider for bool { } } +#[async_trait] +impl AuthProvider for HashMap> { + async fn check_credentials( + &self, + UnverifiedCredentials { + username: unverified_username, + password: unverified_password, + }: &UnverifiedCredentials, + ) -> bool { + if let Some(correct_password) = self.get(unverified_username) { + // TODO: Use constant-time compare. Maybe add to `sec`? + if correct_password == unverified_password { + return true; + } + } + + false + } + + async fn has_access_to(&self, _username: &str, _namespace: &str, _image: &str) -> bool { + true + } +} + #[async_trait] impl AuthProvider for Box where diff --git a/src/reverse_proxy.rs b/src/reverse_proxy.rs index fff8ce5..b3a00ad 100644 --- a/src/reverse_proxy.rs +++ b/src/reverse_proxy.rs @@ -75,7 +75,10 @@ impl PartialEq for Domain { #[derive(Debug)] enum Destination { - ReverseProxied(Uri), + ReverseProxied { + uri: Uri, + config: Arc, + }, Internal(Uri), NotFound, } @@ -129,9 +132,10 @@ impl RoutingTable { Authority::from_str(&pc.host_addr().to_string()) .expect("SocketAddr should never fail to convert to Authority"), ); - return Destination::ReverseProxied( - Uri::from_parts(parts).expect("should not have invalidated Uri"), - ); + return Destination::ReverseProxied { + uri: Uri::from_parts(parts).expect("should not have invalidated Uri"), + config: pc.config().clone(), + }; } // Matching a domain did not succeed, let's try with a path. @@ -161,7 +165,10 @@ impl RoutingTable { parts.authority = Some(Authority::from_str(&container_addr.to_string()).unwrap()); parts.path_and_query = Some(PathAndQuery::from_str(&dest_path_and_query).unwrap()); - return Destination::ReverseProxied(Uri::from_parts(parts).unwrap()); + return Destination::ReverseProxied { + uri: Uri::from_parts(parts).unwrap(), + config: pc.config().clone(), + }; } } @@ -271,7 +278,7 @@ fn split_path_base_url(uri: &Uri) -> Option<(ImageLocation, String)> { async fn route_request( State(rp): State>, - request: Request, + mut request: Request, ) -> Result { let dest_uri = { let routing_table = rp.routing_table.read().await; @@ -279,9 +286,26 @@ async fn route_request( }; match dest_uri { - Destination::ReverseProxied(dest) => { + Destination::ReverseProxied { uri: dest, config } => { trace!(%dest, "reverse proxying"); + // First, check if http authentication is enabled. + if let Some(ref http_access) = config.http_access { + let creds = request + .extract_parts::() + .await + .map_err(AppError::AuthFailure)?; + + if !http_access.check_credentials(&creds).await { + return Response::builder() + .status(StatusCode::FORBIDDEN) + .body(Body::empty()) + .map_err(|_| { + AppError::AssertionFailed("should not fail to build response") + }); + } + } + // Note: `reqwest` and `axum` currently use different versions of `http` let method = request.method().to_string().parse().map_err(|_| { AppError::AssertionFailed("method http version mismatch workaround failed") @@ -367,7 +391,7 @@ async fn route_request( Ok(config.into_response()) } Method::PUT => { - let raw = dbg!(opt_body.ok_or(AppError::InvalidPayload)?); + let raw = opt_body.ok_or(AppError::InvalidPayload)?; let new_config: RuntimeConfig = toml::from_str(&raw).map_err(|_| AppError::InvalidPayload)?; let stored = orchestrator