Skip to content

Commit

Permalink
Suggestions for #3160 (#3272)
Browse files Browse the repository at this point in the history
* tls: Avoid InsertParam parameter.

We don't actually use InsertParam all that much--only in the TLS server (which
is obviously why it was included here). This change removes the InsertParam in
favor of using a tuple, generally reducing boilerplate.

It turns out that the TLS stack already has a map_target to handle turning the
tuple-target into a Tls type, so it shouldn't be needed.

* tls: Remove ExtractParam from detect_sni

Similarly, we don't actually care about extracting a timeout from the target.
Using an ExtractParam causes needless boilerplate.

This change updates the stack module to simply take a timeout parameter at
construction time.

* tls: Make the detect_tls module private

We now only need to export the NewDetectSni type. The module reexport is not
necessary.

* tls: Reorganize NewDetectRequiredSni under the server module

Because the DetectTls and DetectSni types are so similar -- and implemented in
the context of a server inspecting a provided connection (and not a client
establishing a TLS connection), this change reorganizes the module:

* The DetectSni types are renamed to DetectRequiredSni to better reflect their
  purpose and difference from the DetectTls type.
* The detect_sni module is renamed and moved to server::required_sni. This
  module is private and the relevant types are reexported from the server
  module.
  • Loading branch information
olix0r authored Oct 10, 2024
1 parent ff9b5c1 commit 6cb613d
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 160 deletions.
29 changes: 4 additions & 25 deletions linkerd/app/outbound/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use linkerd_app_core::{
core::Resolve,
},
svc,
tls::{self, detect_sni::NewDetectSni, server::Timeout, ServerName},
tls::{NewDetectRequiredSni, ServerName},
transport::addrs::*,
Error,
};
Expand All @@ -25,9 +25,6 @@ struct Tls<T> {
parent: T,
}

#[derive(Clone)]
struct DetectParams(Timeout);

pub fn spawn_routes<T>(
mut route_rx: watch::Receiver<T>,
init: Routes,
Expand Down Expand Up @@ -97,13 +94,13 @@ impl<C> Outbound<C> {
.push_tls_concrete(resolve)
.push_tls_logical()
.map_stack(|config, _rt, stk| {
let detect_timeout = Timeout(config.proxy.detect_protocol_timeout);

stk.push_new_idle_cached(config.discovery_idle_timeout)
// Use a dedicated target type to configure parameters for
// the TLS stack. It also helps narrow the cache key.
.push_map_target(|(sni, parent): (ServerName, T)| Tls { sni, parent })
.push(NewDetectSni::layer(DetectParams(detect_timeout)))
.push(NewDetectRequiredSni::layer(
config.proxy.detect_protocol_timeout,
))
.arc_new_clone_tcp()
})
}
Expand All @@ -126,24 +123,6 @@ where
}
}

// === impl DetectParams ===

impl<T> svc::ExtractParam<tls::server::Timeout, T> for DetectParams {
#[inline]
fn extract_param(&self, _: &T) -> tls::server::Timeout {
self.0
}
}

impl<T> svc::InsertParam<ServerName, T> for DetectParams {
type Target = (ServerName, T);

#[inline]
fn insert_param(&self, sni: ServerName, target: T) -> Self::Target {
(sni, target)
}
}

// === impl TlsMetrics ===

impl TlsMetrics {
Expand Down
22 changes: 0 additions & 22 deletions linkerd/app/outbound/src/tls/logical/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use crate::test_util::*;
use linkerd_app_core::{
io,
svc::{self, NewService},
tls,
transport::addrs::*,
Result,
};
Expand Down Expand Up @@ -42,9 +41,6 @@ struct ConnectTcp {
srvs: Arc<Mutex<HashMap<SocketAddr, MockServer>>>,
}

#[derive(Clone)]
struct DetectParams;

// === impl MockServer ===

impl MockServer {
Expand Down Expand Up @@ -119,24 +115,6 @@ impl<T: svc::Param<Remote<ServerAddr>>> svc::Service<T> for ConnectTcp {
}
}

// === impl DetectParams ===

impl<T> svc::ExtractParam<tls::server::Timeout, T> for DetectParams {
#[inline]
fn extract_param(&self, _: &T) -> tls::server::Timeout {
tls::server::Timeout(Duration::from_secs(1))
}
}

impl<T> svc::InsertParam<tls::ServerName, T> for DetectParams {
type Target = (tls::ServerName, T);

#[inline]
fn insert_param(&self, sni: tls::ServerName, target: T) -> Self::Target {
(sni, target)
}
}

fn spawn_io(
client_hello: Vec<u8>,
) -> (
Expand Down
4 changes: 2 additions & 2 deletions linkerd/app/outbound/src/tls/logical/tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::*;
use crate::tls::Tls;
use linkerd_app_core::{
svc::ServiceExt,
tls::{NewDetectSni, ServerName},
tls::{NewDetectRequiredSni, ServerName},
trace, NameAddr,
};
use linkerd_proxy_client_policy as client_policy;
Expand Down Expand Up @@ -35,7 +35,7 @@ async fn routes() {
.map_stack(|config, _rt, stk| {
stk.push_new_idle_cached(config.discovery_idle_timeout)
.push_map_target(|(sni, parent): (ServerName, _)| Tls { sni, parent })
.push(NewDetectSni::layer(DetectParams))
.push(NewDetectRequiredSni::layer(Duration::from_secs(1)))
.arc_new_clone_tcp()
})
.into_inner();
Expand Down
107 changes: 0 additions & 107 deletions linkerd/tls/src/detect_sni.rs

This file was deleted.

7 changes: 4 additions & 3 deletions linkerd/tls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
#![forbid(unsafe_code)]

pub mod client;
pub mod detect_sni;
pub mod server;

pub use self::{
client::{Client, ClientTls, ConditionalClientTls, ConnectMeta, NoClientTls, ServerId},
detect_sni::NewDetectSni,
server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls},
server::{
ClientId, ConditionalServerTls, NewDetectRequiredSni, NewDetectTls, NoServerTls,
NoSniFoundError, ServerTls, SniDetectionTimeoutError,
},
};

use linkerd_dns_name as dns;
Expand Down
6 changes: 5 additions & 1 deletion linkerd/tls/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod client_hello;
mod required_sni;

use crate::{NegotiatedProtocol, ServerName};
use bytes::BytesMut;
Expand All @@ -18,6 +19,8 @@ use thiserror::Error;
use tokio::time::{self, Duration};
use tracing::{debug, trace, warn};

pub use self::required_sni::{NewDetectRequiredSni, NoSniFoundError, SniDetectionTimeoutError};

/// Describes the authenticated identity of a remote client.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ClientId(pub id::Id);
Expand Down Expand Up @@ -65,6 +68,7 @@ pub struct NewDetectTls<L, P, N> {
_local_identity: std::marker::PhantomData<fn() -> L>,
}

/// A param type used to indicate the timeout after which detection should fail.
#[derive(Copy, Clone, Debug)]
pub struct Timeout(pub Duration);

Expand Down Expand Up @@ -192,7 +196,7 @@ where
}

/// Peek or buffer the provided stream to determine an SNI value.
pub(crate) async fn detect_sni<I>(mut io: I) -> io::Result<(Option<ServerName>, DetectIo<I>)>
async fn detect_sni<I>(mut io: I) -> io::Result<(Option<ServerName>, DetectIo<I>)>
where
I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin,
{
Expand Down
116 changes: 116 additions & 0 deletions linkerd/tls/src/server/required_sni.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use crate::{
server::{detect_sni, DetectIo},
ServerName,
};
use linkerd_error::Error;
use linkerd_io as io;
use linkerd_stack::{layer, NewService, Service, ServiceExt};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use thiserror::Error;
use tokio::time;
use tracing::debug;

#[derive(Clone, Debug, Error)]
#[error("SNI detection timed out")]
pub struct SniDetectionTimeoutError;

#[derive(Clone, Debug, Error)]
#[error("Could not find SNI")]
pub struct NoSniFoundError;

/// A NewService that instruments an inner stack with knowledge of the
/// connection's TLS ServerName (i.e. from an SNI header).
///
/// This differs from the parent module's NewDetectTls in a a few ways:
///
/// - It requires that all connections have an SNI.
/// - It assumes that these connections may not be terminated locally, so there
/// is no concept of a local server name.
/// - There are no special affordances for mutually authenticated TLS, so we
/// make no attempt to detect the client's identity.
/// - The detection timeout is fixed and cannot vary per target (for
/// convenience, to reduce needless boilerplate).
#[derive(Clone, Debug)]
pub struct NewDetectRequiredSni<N> {
inner: N,
timeout: time::Duration,
}

#[derive(Clone, Debug)]
pub struct DetectRequiredSni<T, N> {
target: T,
inner: N,
timeout: time::Duration,
}

impl<N> NewDetectRequiredSni<N> {
fn new(timeout: time::Duration, inner: N) -> Self {
Self { inner, timeout }
}

pub fn layer(timeout: time::Duration) -> impl layer::Layer<N, Service = Self> + Clone {
layer::mk(move |inner| Self::new(timeout, inner))
}
}

impl<T, N> NewService<T> for NewDetectRequiredSni<N>
where
N: Clone,
{
type Service = DetectRequiredSni<T, N>;

fn new_service(&self, target: T) -> Self::Service {
DetectRequiredSni::new(self.timeout, target, self.inner.clone())
}
}

// === impl DetectSni ===

impl<T, N> DetectRequiredSni<T, N> {
fn new(timeout: time::Duration, target: T, inner: N) -> Self {
Self {
target,
inner,
timeout,
}
}
}

impl<T, I, N, S> Service<I> for DetectRequiredSni<T, N>
where
T: Clone + Send + Sync + 'static,
I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static,
N: NewService<(ServerName, T), Service = S> + Clone + Send + 'static,
S: Service<DetectIo<I>> + Send,
S::Error: Into<Error>,
S::Future: Send,
{
type Response = S::Response;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<S::Response, Error>> + Send + 'static>>;

#[inline]
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, io: I) -> Self::Future {
let target = self.target.clone();
let new_accept = self.inner.clone();

// Detect the SNI from a ClientHello (or timeout).
let detect = time::timeout(self.timeout, detect_sni(io));
Box::pin(async move {
let (res, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??;
let sni = res.ok_or(NoSniFoundError)?;
debug!(?sni, "Detected TLS");

let svc = new_accept.new_service((sni, target));
svc.oneshot(io).await.map_err(Into::into)
})
}
}

0 comments on commit 6cb613d

Please sign in to comment.