Skip to content

Commit

Permalink
Update mTLS providers to have both a Name and an Id
Browse files Browse the repository at this point in the history
  • Loading branch information
olix0r committed Nov 7, 2023
1 parent 8c35fe7 commit 897d53a
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 10 deletions.
2 changes: 1 addition & 1 deletion linkerd/app/gateway/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl Gateway {
.push_map_target(Target::discard_parent)
// Add headers to prevent loops.
.push(NewHttpGateway::layer(
self.inbound.identity().server_name().clone().into(),
self.inbound.identity().local_id().clone(),
))
.push_on_service(svc::LoadShed::layer())
.lift_new()
Expand Down
1 change: 1 addition & 0 deletions linkerd/app/src/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ impl Config {
pub fn build(self, dns: dns::Resolver, client_metrics: ClientMetrics) -> Result<Identity> {
let name = self.documents.server_name.clone();
let (store, receiver) = Mode::default().watch(
name.clone().into(),
name.clone(),
&self.documents.trust_anchors_pem,
&self.documents.key_pkcs8,
Expand Down
8 changes: 5 additions & 3 deletions linkerd/meshtls/boring/src/creds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ use boring::{
};
use linkerd_dns_name as dns;
use linkerd_error::Result;
use linkerd_identity as id;
use std::sync::Arc;
use tokio::sync::watch;

pub fn watch(
identity: dns::Name,
local_id: id::Id,
server_name: dns::Name,
roots_pem: &str,
key_pkcs8: &[u8],
csr: &[u8],
Expand All @@ -25,8 +27,8 @@ pub fn watch(
};

let (tx, rx) = watch::channel(Creds::from(creds.clone()));
let rx = Receiver::new(identity.clone(), rx);
let store = Store::new(creds, csr, identity, tx);
let rx = Receiver::new(local_id, server_name.clone(), rx);
let store = Store::new(creds, csr, server_name, tx);

Ok((store, rx))
}
Expand Down
11 changes: 9 additions & 2 deletions linkerd/meshtls/boring/src/creds/receiver.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
use super::CredsRx;
use crate::{NewClient, Server};
use linkerd_dns_name as dns;
use linkerd_identity as id;

#[derive(Clone)]
pub struct Receiver {
id: id::Id,
name: dns::Name,
rx: CredsRx,
}

impl Receiver {
pub(crate) fn new(name: dns::Name, rx: CredsRx) -> Self {
Self { name, rx }
pub(crate) fn new(id: id::Id, name: dns::Name, rx: CredsRx) -> Self {
Self { id, name, rx }
}

/// Returns the local identity.
pub fn local_id(&self) -> &id::Id {
&self.id
}

/// Returns the mTLS Server Name.
pub fn server_name(&self) -> &dns::Name {
&self.name
}
Expand Down
5 changes: 4 additions & 1 deletion linkerd/meshtls/rustls/src/creds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod store;
pub use self::{receiver::Receiver, store::Store};
use linkerd_dns_name as dns;
use linkerd_error::Result;
use linkerd_identity as id;
use ring::{error::KeyRejected, signature::EcdsaKeyPair};
use std::sync::Arc;
use thiserror::Error;
Expand All @@ -20,6 +21,7 @@ pub struct InvalidKey(KeyRejected);
pub struct InvalidTrustRoots(());

pub fn watch(
local_id: id::Id,
server_name: dns::Name,
roots_pem: &str,
key_pkcs8: &[u8],
Expand Down Expand Up @@ -80,7 +82,7 @@ pub fn watch(
watch::channel(store::server_config(roots.clone(), empty_resolver))
};

let rx = Receiver::new(server_name.clone(), client_rx, server_rx);
let rx = Receiver::new(local_id, server_name.clone(), client_rx, server_rx);
let store = Store::new(
roots,
server_cert_verifier,
Expand All @@ -97,6 +99,7 @@ pub fn watch(
#[cfg(feature = "test-util")]
pub fn for_test(ent: &linkerd_tls_test_util::Entity) -> (Store, Receiver) {
watch(
ent.name.parse().expect("id must be valid"),
ent.name.parse().expect("name must be valid"),
std::str::from_utf8(ent.trust_anchors).expect("roots must be PEM"),
ent.key,
Expand Down
11 changes: 11 additions & 0 deletions linkerd/meshtls/rustls/src/creds/receiver.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use crate::{NewClient, Server};
use linkerd_dns_name as dns;
use linkerd_identity::Id;
use std::sync::Arc;
use tokio::sync::watch;
use tokio_rustls::rustls;

/// Receives TLS config updates to build `NewClient` and `Server` types.
#[derive(Clone)]
pub struct Receiver {
id: Id,
name: dns::Name,
client_rx: watch::Receiver<Arc<rustls::ClientConfig>>,
server_rx: watch::Receiver<Arc<rustls::ServerConfig>>,
Expand All @@ -16,18 +18,25 @@ pub struct Receiver {

impl Receiver {
pub(super) fn new(
id: Id,
name: dns::Name,
client_rx: watch::Receiver<Arc<rustls::ClientConfig>>,
server_rx: watch::Receiver<Arc<rustls::ServerConfig>>,
) -> Self {
Self {
id,
name,
client_rx,
server_rx,
}
}

/// Returns the local server name (i.e. used in mTLS).
pub fn local_id(&self) -> &Id {
&self.id
}

/// Returns the local server name (i.e. used for SNI).
pub fn server_name(&self) -> &dns::Name {
&self.name
}
Expand Down Expand Up @@ -86,6 +95,7 @@ mod tests {
let (_, client_rx) = watch::channel(Arc::new(empty_client_config()));
let receiver = Receiver {
name: "example".parse().unwrap(),
id: "example".parse().unwrap(),
server_rx,
client_rx,
};
Expand All @@ -108,6 +118,7 @@ mod tests {
let (server_tx, server_rx) = watch::channel(init_config.clone());
let (_, client_rx) = watch::channel(Arc::new(empty_client_config()));
let receiver = Receiver {
id: "example".parse().unwrap(),
name: "example".parse().unwrap(),
server_rx,
client_rx,
Expand Down
1 change: 1 addition & 0 deletions linkerd/meshtls/rustls/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::time::Duration;
fn load(ent: &Entity) -> crate::creds::Store {
let roots_pem = std::str::from_utf8(ent.trust_anchors).expect("valid PEM");
let (store, _) = crate::creds::watch(
ent.name.parse().unwrap(),
ent.name.parse().unwrap(),
roots_pem,
ent.key,
Expand Down
14 changes: 13 additions & 1 deletion linkerd/meshtls/src/creds.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{NewClient, Server};
use linkerd_dns_name as dns;
use linkerd_error::Result;
use linkerd_identity::{Credentials, DerX509};
use linkerd_identity::{Credentials, DerX509, Id};

#[cfg(feature = "boring")]
pub use crate::boring;
Expand Down Expand Up @@ -80,6 +80,18 @@ impl From<rustls::creds::Receiver> for Receiver {
}

impl Receiver {
pub fn local_id(&self) -> &Id {
match self {
#[cfg(feature = "boring")]
Self::Boring(receiver) => receiver.local_id(),

#[cfg(feature = "rustls")]
Self::Rustls(receiver) => receiver.local_id(),
#[cfg(not(feature = "__has_any_tls_impls"))]
_ => crate::no_tls!(),
}
}

pub fn server_name(&self) -> &dns::Name {
match self {
#[cfg(feature = "boring")]
Expand Down
6 changes: 4 additions & 2 deletions linkerd/meshtls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub use self::{
};
use linkerd_dns_name as dns;
use linkerd_error::{Error, Result};
use linkerd_identity as id;
use std::str::FromStr;

#[cfg(feature = "boring")]
Expand Down Expand Up @@ -82,6 +83,7 @@ impl Default for Mode {
impl Mode {
pub fn watch(
self,
local_id: id::Id,
server_name: dns::Name,
roots_pem: &str,
key_pkcs8: &[u8],
Expand All @@ -91,7 +93,7 @@ impl Mode {
#[cfg(feature = "boring")]
Self::Boring => {
let (store, receiver) =
boring::creds::watch(server_name, roots_pem, key_pkcs8, csr)?;
boring::creds::watch(local_id, server_name, roots_pem, key_pkcs8, csr)?;
Ok((
creds::Store::Boring(store),
creds::Receiver::Boring(receiver),
Expand All @@ -101,7 +103,7 @@ impl Mode {
#[cfg(feature = "rustls")]
Self::Rustls => {
let (store, receiver) =
rustls::creds::watch(server_name, roots_pem, key_pkcs8, csr)?;
rustls::creds::watch(local_id, server_name, roots_pem, key_pkcs8, csr)?;
Ok((
creds::Store::Rustls(store),
creds::Receiver::Rustls(receiver),
Expand Down
1 change: 1 addition & 0 deletions linkerd/meshtls/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ fn load(
let roots_pem = std::str::from_utf8(ent.trust_anchors).expect("valid PEM");
let (mut store, rx) = mode
.watch(
ent.name.parse().unwrap(),
ent.name.parse().unwrap(),
roots_pem,
ent.key,
Expand Down

0 comments on commit 897d53a

Please sign in to comment.