From 22c3a9cf87da3257fd1d7497a52434fe18f275b0 Mon Sep 17 00:00:00 2001 From: Zahari Dichev Date: Fri, 13 Sep 2024 10:42:58 +0000 Subject: [PATCH] add independent DetectSni middleware Signed-off-by: Zahari Dichev --- linkerd/tls/src/detect_sni.rs | 107 ++++++++++++++++++++++++++++++++++ linkerd/tls/src/lib.rs | 1 + linkerd/tls/src/server.rs | 2 +- 3 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 linkerd/tls/src/detect_sni.rs diff --git a/linkerd/tls/src/detect_sni.rs b/linkerd/tls/src/detect_sni.rs new file mode 100644 index 0000000000..45747d63ed --- /dev/null +++ b/linkerd/tls/src/detect_sni.rs @@ -0,0 +1,107 @@ +use crate::{ + server::{detect_sni, DetectIo, Timeout}, + ServerName, +}; +use linkerd_error::Error; +use linkerd_io as io; +use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Service, ServiceExt}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use thiserror::Error; +use tokio::time; +use tracing::debug; + +#[derive(Clone, Debug, Error)] +#[error("SNI detection timed out")] +pub struct SniDetectionTimeoutError; + +#[derive(Clone, Debug, Error)] +#[error("Could not find SNI")] +pub struct NoSniFoundError; + +#[derive(Clone, Debug)] +pub struct NewDetectSni { + params: P, + inner: N, +} + +#[derive(Clone, Debug)] +pub struct DetectSni { + target: T, + inner: N, + timeout: Timeout, + params: P, +} + +impl NewDetectSni { + pub fn new(params: P, inner: N) -> Self { + Self { inner, params } + } + + pub fn layer(params: P) -> impl layer::Layer + Clone + where + P: Clone, + { + layer::mk(move |inner| Self::new(params.clone(), inner)) + } +} + +impl NewService for NewDetectSni +where + P: ExtractParam + Clone, + N: Clone, +{ + type Service = DetectSni; + + fn new_service(&self, target: T) -> Self::Service { + let timeout = self.params.extract_param(&target); + DetectSni { + target, + timeout, + inner: self.inner.clone(), + params: self.params.clone(), + } + } +} + +impl Service for DetectSni +where + T: Clone + Send + Sync + 'static, + P: InsertParam + Clone + Send + Sync + 'static, + P::Target: Send + 'static, + I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static, + N: NewService + Clone + Send + 'static, + S: Service> + Send, + S::Error: Into, + S::Future: Send, +{ + type Response = S::Response; + type Error = Error; + type Future = Pin> + Send + 'static>>; + + #[inline] + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, io: I) -> Self::Future { + let target = self.target.clone(); + let new_accept = self.inner.clone(); + let params = self.params.clone(); + + // Detect the SNI from a ClientHello (or timeout). + let Timeout(timeout) = self.timeout; + let detect = time::timeout(timeout, detect_sni(io)); + Box::pin(async move { + let (sni, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??; + let sni = sni.ok_or(NoSniFoundError)?; + + debug!("detected SNI: {:?}", sni); + let svc = new_accept.new_service(params.insert_param(sni, target)); + svc.oneshot(io).await.map_err(Into::into) + }) + } +} diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index 0e54d86442..0a281e2b36 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -2,6 +2,7 @@ #![forbid(unsafe_code)] pub mod client; +pub mod detect_sni; pub mod server; pub use self::{ diff --git a/linkerd/tls/src/server.rs b/linkerd/tls/src/server.rs index 04862401f9..1c85c92ee6 100644 --- a/linkerd/tls/src/server.rs +++ b/linkerd/tls/src/server.rs @@ -192,7 +192,7 @@ where } /// Peek or buffer the provided stream to determine an SNI value. -async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> +pub(crate) async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> where I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin, {