Skip to content

Commit

Permalink
tls: add DetectSni middleware (#3199)
Browse files Browse the repository at this point in the history
This change adds a new `DetectSni` middleware to be used in the outbound stack
in order to extract the SNI extension from the `ClientHello` of a TLS session and apply
routing decisions based on it.

In contrast to `DetectTls` this new middleware is concerned with just extracting the SNI
as opposed to terminating the TLS session.

Signed-off-by: Zahari Dichev <[email protected]>
  • Loading branch information
zaharidichev authored Sep 25, 2024
1 parent 25589d4 commit fa4c8dc
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 1 deletion.
107 changes: 107 additions & 0 deletions linkerd/tls/src/detect_sni.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use crate::{
server::{detect_sni, DetectIo, Timeout},
ServerName,
};
use linkerd_error::Error;
use linkerd_io as io;
use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Service, ServiceExt};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use thiserror::Error;
use tokio::time;
use tracing::debug;

#[derive(Clone, Debug, Error)]
#[error("SNI detection timed out")]
pub struct SniDetectionTimeoutError;

#[derive(Clone, Debug, Error)]
#[error("Could not find SNI")]
pub struct NoSniFoundError;

#[derive(Clone, Debug)]
pub struct NewDetectSni<P, N> {
params: P,
inner: N,
}

#[derive(Clone, Debug)]
pub struct DetectSni<T, P, N> {
target: T,
inner: N,
timeout: Timeout,
params: P,
}

impl<P, N> NewDetectSni<P, N> {
pub fn new(params: P, inner: N) -> Self {
Self { inner, params }
}

pub fn layer(params: P) -> impl layer::Layer<N, Service = Self> + Clone
where
P: Clone,
{
layer::mk(move |inner| Self::new(params.clone(), inner))
}
}

impl<T, P, N> NewService<T> for NewDetectSni<P, N>
where
P: ExtractParam<Timeout, T> + Clone,
N: Clone,
{
type Service = DetectSni<T, P, N>;

fn new_service(&self, target: T) -> Self::Service {
let timeout = self.params.extract_param(&target);
DetectSni {
target,
timeout,
inner: self.inner.clone(),
params: self.params.clone(),
}
}
}

impl<T, P, I, N, S> Service<I> for DetectSni<T, P, N>
where
T: Clone + Send + Sync + 'static,
P: InsertParam<ServerName, T> + Clone + Send + Sync + 'static,
P::Target: Send + 'static,
I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static,
N: NewService<P::Target, Service = S> + Clone + Send + 'static,
S: Service<DetectIo<I>> + Send,
S::Error: Into<Error>,
S::Future: Send,
{
type Response = S::Response;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<S::Response, Error>> + Send + 'static>>;

#[inline]
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, io: I) -> Self::Future {
let target = self.target.clone();
let new_accept = self.inner.clone();
let params = self.params.clone();

// Detect the SNI from a ClientHello (or timeout).
let Timeout(timeout) = self.timeout;
let detect = time::timeout(timeout, detect_sni(io));
Box::pin(async move {
let (sni, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??;
let sni = sni.ok_or(NoSniFoundError)?;

debug!("detected SNI: {:?}", sni);
let svc = new_accept.new_service(params.insert_param(sni, target));
svc.oneshot(io).await.map_err(Into::into)
})
}
}
1 change: 1 addition & 0 deletions linkerd/tls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#![forbid(unsafe_code)]

pub mod client;
pub mod detect_sni;
pub mod server;

pub use self::{
Expand Down
2 changes: 1 addition & 1 deletion linkerd/tls/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ where
}

/// Peek or buffer the provided stream to determine an SNI value.
async fn detect_sni<I>(mut io: I) -> io::Result<(Option<ServerName>, DetectIo<I>)>
pub(crate) async fn detect_sni<I>(mut io: I) -> io::Result<(Option<ServerName>, DetectIo<I>)>
where
I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin,
{
Expand Down

0 comments on commit fa4c8dc

Please sign in to comment.