diff --git a/Cargo.lock b/Cargo.lock index 21e709129f..d27497b255 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1145,6 +1145,7 @@ dependencies = [ "linkerd-proxy-identity-client", "linkerd-proxy-resolve", "linkerd-proxy-server-policy", + "linkerd-proxy-spire-client", "linkerd-proxy-tap", "linkerd-proxy-tcp", "linkerd-proxy-transport", @@ -1583,6 +1584,7 @@ dependencies = [ "linkerd-tls-test-util", "linkerd-tracing", "pin-project", + "rcgen 0.11.3", "tokio", "tracing", ] @@ -1636,7 +1638,7 @@ version = "0.1.0" dependencies = [ "linkerd-error", "linkerd-identity", - "rcgen", + "rcgen 0.12.0", "tracing", "x509-parser", ] @@ -1903,6 +1905,29 @@ dependencies = [ "thiserror", ] +[[package]] +name = "linkerd-proxy-spire-client" +version = "0.1.0" +dependencies = [ + "futures", + "linkerd-error", + "linkerd-exp-backoff", + "linkerd-identity", + "linkerd-proxy-http", + "linkerd-stack", + "linkerd-tonic-watch", + "rcgen 0.11.3", + "simple_asn1", + "spiffe-proto", + "thiserror", + "tokio", + "tokio-test", + "tonic", + "tower", + "tracing", + "x509-parser", +] + [[package]] name = "linkerd-proxy-tap" version = "0.1.0" @@ -2766,6 +2791,18 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "977b1e897f9d764566891689e642653e5ed90c6895106acd005eb4c1d0203991" +[[package]] +name = "rcgen" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52c4f3084aa3bc7dfbba4eff4fab2a54db4324965d8872ab933565e6fbd83bc6" +dependencies = [ + "pem", + "ring 0.16.20", + "time", + "yasna", +] + [[package]] name = "rcgen" version = "0.12.0" @@ -3059,6 +3096,18 @@ dependencies = [ "libc", ] +[[package]] +name = "simple_asn1" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror", + "time", +] + [[package]] name = "slab" version = "0.4.9" @@ -3094,6 +3143,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "spiffe-proto" +version = "0.1.0" +dependencies = [ + "bytes", + "prost", + "prost-types", + "tonic", + "tonic-build", +] + [[package]] name = "spin" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index f4f8c92a5f..870e4a9552 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ members = [ "linkerd/proxy/dns-resolve", "linkerd/proxy/http", "linkerd/proxy/identity-client", + "linkerd/proxy/spire-client", "linkerd/proxy/resolve", "linkerd/proxy/server-policy", "linkerd/proxy/tap", @@ -73,6 +74,7 @@ members = [ "linkerd/transport-metrics", "linkerd2-proxy", "opencensus-proto", + "spiffe-proto", "tools", ] diff --git a/linkerd/app/core/Cargo.toml b/linkerd/app/core/Cargo.toml index 2a7090ae0b..b11e5151c7 100644 --- a/linkerd/app/core/Cargo.toml +++ b/linkerd/app/core/Cargo.toml @@ -43,6 +43,7 @@ linkerd-proxy-client-policy = { path = "../../proxy/client-policy" } linkerd-proxy-dns-resolve = { path = "../../proxy/dns-resolve" } linkerd-proxy-http = { path = "../../proxy/http" } linkerd-proxy-identity-client = { path = "../../proxy/identity-client" } +linkerd-proxy-spire-client = { path = "../../proxy/spire-client" } linkerd-proxy-resolve = { path = "../../proxy/resolve" } linkerd-proxy-server-policy = { path = "../../proxy/server-policy" } linkerd-proxy-tap = { path = "../../proxy/tap" } diff --git a/linkerd/app/core/src/lib.rs b/linkerd/app/core/src/lib.rs index b7c1779699..668e04de9c 100644 --- a/linkerd/app/core/src/lib.rs +++ b/linkerd/app/core/src/lib.rs @@ -50,7 +50,10 @@ pub use linkerd_transport_header as transport_header; pub mod identity { pub use linkerd_identity::*; pub use linkerd_meshtls::*; - pub use linkerd_proxy_identity_client as client; + pub mod client { + pub use linkerd_proxy_identity_client as linkerd; + pub use linkerd_proxy_spire_client as spire; + } } pub const CANONICAL_DST_HEADER: &str = "l5d-dst-canonical"; diff --git a/linkerd/app/src/env.rs b/linkerd/app/src/env.rs index d9a21703dd..cd7ccc4327 100644 --- a/linkerd/app/src/env.rs +++ b/linkerd/app/src/env.rs @@ -793,7 +793,7 @@ pub fn parse_config(strings: &S) -> Result .unwrap_or(super::tap::Config::Disabled); let identity = { - let (addr, certify, params) = identity_config?; + let (addr, certify, tls) = identity_config?; // If the address doesn't have a server identity, then we're on localhost. let connect = if addr.addr.is_loopback() { inbound.proxy.connect.clone() @@ -805,9 +805,10 @@ pub fn parse_config(strings: &S) -> Result } else { outbound.http_request_queue.failfast_timeout }; - identity::Config { + identity::Config::Linkerd { certify, - control: ControlConfig { + tls, + client: ControlConfig { addr, connect, buffer: QueueConfig { @@ -815,7 +816,6 @@ pub fn parse_config(strings: &S) -> Result failfast_timeout, }, }, - params, } }; @@ -1215,7 +1215,14 @@ pub fn parse_control_addr( pub fn parse_identity_config( strings: &S, -) -> Result<(ControlAddr, identity::certify::Config, identity::TlsParams), EnvError> { +) -> Result< + ( + ControlAddr, + identity::client::linkerd::Config, + identity::TlsParams, + ), + EnvError, +> { let control = parse_control_addr(strings, ENV_IDENTITY_SVC_BASE); let ta = parse(strings, ENV_IDENTITY_TRUST_ANCHORS, |s| { if s.is_empty() { @@ -1225,7 +1232,7 @@ pub fn parse_identity_config( }); let dir = parse(strings, ENV_IDENTITY_DIR, |ref s| Ok(PathBuf::from(s))); let tok = parse(strings, ENV_IDENTITY_TOKEN_FILE, |ref s| { - identity::TokenSource::if_nonempty_file(s.to_string()).map_err(|e| { + identity::client::linkerd::TokenSource::if_nonempty_file(s.to_string()).map_err(|e| { error!("Could not read {ENV_IDENTITY_TOKEN_FILE}: {e}"); ParseError::InvalidTokenSource }) @@ -1253,17 +1260,19 @@ pub fn parse_identity_config( min_refresh, max_refresh, ) => { - let certify = identity::certify::Config { + let certify = identity::client::linkerd::Config { token, min_refresh: min_refresh.unwrap_or(DEFAULT_IDENTITY_MIN_REFRESH), max_refresh: max_refresh.unwrap_or(DEFAULT_IDENTITY_MAX_REFRESH), - documents: identity::certify::Documents::load(dir).map_err(|error| { - error!(%error, "Failed to read identity documents"); - EnvError::InvalidEnvVar - })?, + documents: identity::client::linkerd::certify::Documents::load(dir).map_err( + |error| { + error!(%error, "Failed to read identity documents"); + EnvError::InvalidEnvVar + }, + )?, }; let params = identity::TlsParams { - server_id: identity::Id::Dns(local_name.clone()), + id: identity::Id::Dns(local_name.clone()), server_name: local_name, trust_anchors_pem, }; diff --git a/linkerd/app/src/identity.rs b/linkerd/app/src/identity.rs index 7b231e54a1..6bab7dd0cd 100644 --- a/linkerd/app/src/identity.rs +++ b/linkerd/app/src/identity.rs @@ -1,11 +1,12 @@ -pub use linkerd_app_core::identity::{ - client::{certify, TokenSource}, - Id, -}; +use crate::spire; + +pub use linkerd_app_core::identity::{client, Id}; use linkerd_app_core::{ control, dns, exp_backoff::{ExponentialBackoff, ExponentialBackoffStream}, - identity::{client::Certify, creds, CertMetrics, Credentials, DerX509, Mode, WithCertMetrics}, + identity::{ + client::linkerd::Certify, creds, CertMetrics, Credentials, DerX509, Mode, WithCertMetrics, + }, metrics::{prom, ControlHttp as ClientMetrics}, Error, Result, }; @@ -13,22 +14,32 @@ use std::{future::Future, pin::Pin, time::SystemTime}; use tokio::sync::watch; use tracing::Instrument; +#[derive(Debug, thiserror::Error)] +#[error("linkerd identity requires a TLS Id and server name to be the same")] +pub struct TlsIdAndServerNameNotMatching(()); + #[derive(Clone, Debug)] -pub struct Config { - pub control: control::Config, - pub certify: certify::Config, - pub params: TlsParams, +#[allow(clippy::large_enum_variant)] +pub enum Config { + Linkerd { + client: control::Config, + certify: client::linkerd::Config, + tls: TlsParams, + }, + Spire { + client: spire::Config, + tls: TlsParams, + }, } #[derive(Clone, Debug)] pub struct TlsParams { - pub server_id: Id, + pub id: Id, pub server_name: dns::Name, pub trust_anchors_pem: String, } pub struct Identity { - addr: control::ControlAddr, receiver: creds::Receiver, ready: watch::Receiver, task: Task, @@ -55,47 +66,82 @@ impl Config { client_metrics: ClientMetrics, registry: &mut prom::Registry, ) -> Result { - let name = self.params.server_name.clone(); - let (store, receiver) = Mode::default().watch( - name.clone().into(), - name.clone(), - &self.params.trust_anchors_pem, - )?; - - let certify = Certify::from(self.certify); - - let addr = self.control.addr.clone(); - - let (tx, ready) = watch::channel(false); - - // Save to be spawned on an auxiliary runtime. - let task = Box::pin({ - let addr = addr.clone(); - let svc = self.control.build( - dns, - client_metrics, - registry.sub_registry_with_prefix("control_identity"), - receiver.new_client(), - ); - - let cert_metrics = - CertMetrics::register(registry.sub_registry_with_prefix("identity_cert")); - let cred = WithCertMetrics::new(cert_metrics, NotifyReady { store, tx }); - - certify - .run(name, cred, svc) - .instrument(tracing::debug_span!("identity", server.addr = %addr).or_current()) - }); - - Ok(Identity { - addr, - receiver, - ready, - task, + let cert_metrics = + CertMetrics::register(registry.sub_registry_with_prefix("identity_cert")); + + Ok(match self { + Self::Linkerd { + client, + certify, + tls, + } => { + // TODO: move this validation into env.rs + let name = match (&tls.id, &tls.server_name) { + (Id::Dns(id), sni) if id == sni => id.clone(), + (_id, _sni) => { + return Err(TlsIdAndServerNameNotMatching(()).into()); + } + }; + + let certify = Certify::from(certify); + let (store, receiver, ready) = watch(tls, cert_metrics)?; + + let task = { + let addr = client.addr.clone(); + let svc = client.build( + dns, + client_metrics, + registry.sub_registry_with_prefix("control_identity"), + receiver.new_client(), + ); + + Box::pin(certify.run(name, store, svc).instrument( + tracing::info_span!("identity", server.addr = %addr).or_current(), + )) + }; + Identity { + receiver, + ready, + task, + } + } + Self::Spire { client, tls } => { + let addr = client.socket_addr.clone(); + let spire = spire::client::Spire::new(tls.id.clone()); + + let (store, receiver, ready) = watch(tls, cert_metrics)?; + let task = + Box::pin(spire.run(store, spire::Client::from(client)).instrument( + tracing::info_span!("spire", server.addr = %addr).or_current(), + )); + + Identity { + receiver, + ready, + task, + } + } }) } } +fn watch( + tls: TlsParams, + metrics: CertMetrics, +) -> Result<( + WithCertMetrics, + creds::Receiver, + watch::Receiver, +)> { + let (tx, ready) = watch::channel(false); + let (store, receiver) = + Mode::default().watch(tls.id, tls.server_name, &tls.trust_anchors_pem)?; + let cred = WithCertMetrics::new(metrics, NotifyReady { store, tx }); + Ok((cred, receiver, ready)) +} + +// === impl NotifyReady === + impl Credentials for NotifyReady { fn set_certificate( &mut self, @@ -113,10 +159,6 @@ impl Credentials for NotifyReady { // === impl Identity === impl Identity { - pub fn addr(&self) -> control::ControlAddr { - self.addr.clone() - } - /// Returns a future that is satisfied once certificates have been provisioned. pub fn ready(&self) -> Pin + Send + 'static>> { let mut ready = self.ready.clone(); diff --git a/linkerd/app/src/lib.rs b/linkerd/app/src/lib.rs index 8ce98b8340..bc559263cf 100644 --- a/linkerd/app/src/lib.rs +++ b/linkerd/app/src/lib.rs @@ -9,6 +9,7 @@ pub mod env; pub mod identity; pub mod oc_collector; pub mod policy; +pub mod spire; pub mod tap; pub use self::metrics::Metrics; @@ -340,8 +341,8 @@ impl App { self.identity.receiver().server_name().clone() } - pub fn identity_addr(&self) -> ControlAddr { - self.identity.addr() + pub fn local_tls_id(&self) -> identity::Id { + self.identity.receiver().local_id().clone() } pub fn opencensus_addr(&self) -> Option<&ControlAddr> { @@ -389,7 +390,7 @@ impl App { // Kick off the identity so that the process can become ready. let local = identity.receiver(); - let local_name = local.server_name().clone(); + let local_id = local.local_id().clone(); let ready = identity.ready(); tokio::spawn( identity @@ -402,7 +403,7 @@ impl App { ready .map(move |()| { latch.release(); - info!(id = %local_name, "Certified identity"); + info!(id = %local_id, "Certified identity"); }) .instrument(info_span!("identity").or_current()), ); diff --git a/linkerd/app/src/spire.rs b/linkerd/app/src/spire.rs new file mode 100644 index 0000000000..f1f449c2e6 --- /dev/null +++ b/linkerd/app/src/spire.rs @@ -0,0 +1,68 @@ +use linkerd_app_core::{exp_backoff::ExponentialBackoff, Error}; +use std::sync::Arc; +use tokio::net::UnixStream; +use tokio::sync::watch; +use tonic::transport::{Endpoint, Uri}; + +pub use linkerd_app_core::identity::client::spire as client; + +const UNIX_PREFIX: &str = "unix:"; +const TONIC_DEFAULT_URI: &str = "http://[::]:50051"; + +#[derive(Clone, Debug)] +pub struct Config { + pub(crate) socket_addr: Arc, + pub(crate) backoff: ExponentialBackoff, +} + +// Connects to SPIRE workload API via Unix Domain Socket +pub struct Client { + config: Config, +} + +// === impl Client === + +impl From for Client { + fn from(config: Config) -> Self { + Self { config } + } +} + +impl tower::Service<()> for Client { + type Response = tonic::Response>; + type Error = Error; + type Future = futures::future::BoxFuture<'static, Result>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: ()) -> Self::Future { + let socket = self.config.socket_addr.clone(); + let backoff = self.config.backoff; + Box::pin(async move { + // Strip the 'unix:' prefix for tonic compatibility. + let stripped_path = socket + .strip_prefix(UNIX_PREFIX) + .unwrap_or(socket.as_str()) + .to_string(); + + // We will ignore this uri because uds do not use it + // if your connector does use the uri it will be provided + // as the request to the `MakeConnection`. + let chan = Endpoint::try_from(TONIC_DEFAULT_URI)? + .connect_with_connector(tower::util::service_fn(move |_: Uri| { + UnixStream::connect(stripped_path.clone()) + })) + .await?; + + let api = client::Api::watch(chan, backoff); + let receiver = api.spawn_watch(()).await?; + + Ok(receiver) + }) + } +} diff --git a/linkerd/meshtls/Cargo.toml b/linkerd/meshtls/Cargo.toml index a087bd1841..41b7495324 100644 --- a/linkerd/meshtls/Cargo.toml +++ b/linkerd/meshtls/Cargo.toml @@ -29,6 +29,7 @@ linkerd-tls = { path = "../tls" } [dev-dependencies] tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] } tracing = "0.1" +rcgen = "0.11.3" linkerd-conditional = { path = "../conditional" } linkerd-proxy-transport = { path = "../proxy/transport" } diff --git a/linkerd/meshtls/tests/boring.rs b/linkerd/meshtls/tests/boring.rs index 839e3dc87a..0aecc52ae9 100644 --- a/linkerd/meshtls/tests/boring.rs +++ b/linkerd/meshtls/tests/boring.rs @@ -6,6 +6,11 @@ mod util; use linkerd_meshtls::Mode; +#[test] +fn fails_processing_cert_when_wrong_id_configured() { + util::fails_processing_cert_when_wrong_id_configured(Mode::Boring); +} + #[tokio::test(flavor = "current_thread")] async fn plaintext() { util::plaintext(Mode::Boring).await; diff --git a/linkerd/meshtls/tests/rustls.rs b/linkerd/meshtls/tests/rustls.rs index bd9eed32c7..ac8eff9177 100644 --- a/linkerd/meshtls/tests/rustls.rs +++ b/linkerd/meshtls/tests/rustls.rs @@ -6,6 +6,11 @@ mod util; use linkerd_meshtls::Mode; +#[test] +fn fails_processing_cert_when_wrong_id_configured() { + util::fails_processing_cert_when_wrong_id_configured(Mode::Rustls); +} + #[tokio::test(flavor = "current_thread")] async fn plaintext() { util::plaintext(Mode::Rustls).await; diff --git a/linkerd/meshtls/tests/util.rs b/linkerd/meshtls/tests/util.rs index 5380af1145..65fbecf4a0 100644 --- a/linkerd/meshtls/tests/util.rs +++ b/linkerd/meshtls/tests/util.rs @@ -3,8 +3,9 @@ use futures::prelude::*; use linkerd_conditional::Conditional; +use linkerd_dns_name::Name; use linkerd_error::Infallible; -use linkerd_identity::{Credentials, DerX509}; +use linkerd_identity::{Credentials, DerX509, Id}; use linkerd_io::{self as io, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use linkerd_meshtls as meshtls; use linkerd_proxy_transport::{ @@ -17,6 +18,8 @@ use linkerd_stack::{ }; use linkerd_tls as tls; use linkerd_tls_test_util as test_util; +use rcgen::{BasicConstraints, Certificate, CertificateParams, IsCa, SanType}; +use std::str::FromStr; use std::{ future::Future, net::SocketAddr, @@ -26,6 +29,44 @@ use std::{ use tokio::net::TcpStream; use tracing::Instrument; +fn generate_cert_with_name(subject_alt_names: Vec) -> (Vec, Vec, String) { + let mut root_params = CertificateParams::default(); + root_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); + let root_cert = Certificate::from_params(root_params).expect("should generate root"); + + let mut params = CertificateParams::default(); + params.subject_alt_names = subject_alt_names; + + let cert = Certificate::from_params(params).expect("should generate cert"); + + ( + cert.serialize_der_with_signer(&root_cert) + .expect("should serialize"), + cert.serialize_private_key_der(), + root_cert.serialize_pem().expect("should serialize"), + ) +} + +pub fn fails_processing_cert_when_wrong_id_configured(mode: meshtls::Mode) { + let server_name = Name::from_str("system.local").expect("should parse"); + let id = Id::Dns(server_name.clone()); + + let (cert, key, roots) = + generate_cert_with_name(vec![SanType::URI("spiffe://system/local".into())]); + let (mut store, _) = mode + .watch(id, server_name.clone(), &roots) + .expect("should construct"); + + let err = store + .set_certificate(DerX509(cert), vec![], key, SystemTime::now()) + .expect_err("should error"); + + assert_eq!( + "certificate does not match TLS identity", + format!("{}", err), + ); +} + pub async fn plaintext(mode: meshtls::Mode) { let (_foo, _, server_tls) = load(mode, &test_util::FOO_NS1); let (_bar, client_tls, _) = load(mode, &test_util::BAR_NS1); diff --git a/linkerd/proxy/identity-client/src/lib.rs b/linkerd/proxy/identity-client/src/lib.rs index b2dfce139a..c1c95ac1af 100644 --- a/linkerd/proxy/identity-client/src/lib.rs +++ b/linkerd/proxy/identity-client/src/lib.rs @@ -4,4 +4,7 @@ pub mod certify; mod token; -pub use self::{certify::Certify, token::TokenSource}; +pub use self::{ + certify::{Certify, Config}, + token::TokenSource, +}; diff --git a/linkerd/proxy/spire-client/Cargo.toml b/linkerd/proxy/spire-client/Cargo.toml new file mode 100644 index 0000000000..76c8c6fd41 --- /dev/null +++ b/linkerd/proxy/spire-client/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "linkerd-proxy-spire-client" +version = "0.1.0" +authors = ["Linkerd Developers "] +license = "Apache-2.0" +edition = "2021" +publish = false + +[dependencies] +futures = { version = "0.3", default-features = false } +linkerd-error = { path = "../../error" } +linkerd-proxy-http = { path = "../../proxy/http" } +linkerd-identity = { path = "../../identity" } +spiffe-proto = { path = "../../../spiffe-proto" } +linkerd-tonic-watch = { path = "../../tonic-watch" } +linkerd-exp-backoff = { path = "../../exp-backoff" } +linkerd-stack = { path = "../../stack" } +tokio = { version = "1", features = ["time", "sync"] } +tonic = "0.10" +tower = "0.4" +tracing = "0.1" +x509-parser = "0.15.1" +asn1 = { version = "0.6", package = "simple_asn1" } +thiserror = "1" + +[dev-dependencies] +rcgen = "0.11.3" +tokio-test = "0.4" diff --git a/linkerd/proxy/spire-client/src/api.rs b/linkerd/proxy/spire-client/src/api.rs new file mode 100644 index 0000000000..3c3ec16d4f --- /dev/null +++ b/linkerd/proxy/spire-client/src/api.rs @@ -0,0 +1,268 @@ +use futures::prelude::*; +use linkerd_error::{Error, Recover, Result}; +use linkerd_exp_backoff::{ExponentialBackoff, ExponentialBackoffStream}; +use linkerd_identity::DerX509; +use linkerd_identity::{Credentials, Id}; +use linkerd_proxy_http as http; +use linkerd_tonic_watch::StreamWatch; +use spiffe_proto::client::{ + self as api, spiffe_workload_api_client::SpiffeWorkloadApiClient as Client, +}; +use std::collections::HashMap; +use std::time::{Duration, UNIX_EPOCH}; +use tower::Service; +use tracing::error; + +const SPIFFE_HEADER_KEY: &str = "workload.spiffe.io"; +const SPIFFE_HEADER_VALUE: &str = "true"; + +#[derive(Debug, thiserror::Error)] +#[error("no matching SVID found")] +pub struct NoMatchingSVIDFound(()); + +#[derive(Clone)] +pub struct Svid { + pub(super) spiffe_id: Id, + leaf: DerX509, + private_key: Vec, + intermediates: Vec, +} + +#[derive(Clone)] +pub struct SvidUpdate { + svids: HashMap, +} + +#[derive(Clone, Debug)] +pub struct Api { + client: Client, +} + +#[derive(Clone)] +pub struct GrpcRecover(ExponentialBackoff); + +pub type Watch = StreamWatch>; + +// === impl Svid === + +impl SvidUpdate { + pub(super) fn new(svids: Vec) -> Self { + let mut svids_map = HashMap::default(); + for svid in svids.into_iter() { + svids_map.insert(svid.spiffe_id.clone(), svid); + } + + SvidUpdate { svids: svids_map } + } +} + +// === impl Svid === + +impl Svid { + #[cfg(test)] + pub(super) fn new( + spiffe_id: Id, + leaf: DerX509, + private_key: Vec, + intermediates: Vec, + ) -> Self { + Self { + spiffe_id, + leaf, + private_key, + intermediates, + } + } +} + +impl TryFrom for Svid { + // TODO: Use bundles from response to compare against + // what is provided at bootstrap time + + type Error = Error; + fn try_from(proto: api::X509svid) -> Result { + if proto.x509_svid_key.is_empty() { + return Err("empty private key".into()); + } + + let cert_der_blocks = asn1::from_der(&proto.x509_svid)?; + let (leaf, intermediates) = match cert_der_blocks.split_first() { + None => return Err("empty cert chain".into()), + Some((leaf_block, intermediates_block)) => { + let leaf = DerX509(asn1::to_der(leaf_block)?); + let mut intermediates = vec![]; + for block in intermediates_block.iter() { + let cert_der = asn1::to_der(block)?; + intermediates.push(DerX509(cert_der)); + } + (leaf, intermediates) + } + }; + + let spiffe_id = Id::parse_uri(&proto.spiffe_id)?; + + Ok(Svid { + spiffe_id, + leaf, + private_key: proto.x509_svid_key, + intermediates: intermediates.to_vec(), + }) + } +} + +// === impl Api === + +impl Api +where + S: tonic::client::GrpcService + Clone, + S::Error: Into, + S::ResponseBody: Default + http::HttpBody + Send + 'static, + ::Error: Into + Send, +{ + pub fn watch(client: S, backoff: ExponentialBackoff) -> Watch { + let client = Client::new(client); + StreamWatch::new(GrpcRecover(backoff), Self { client }) + } +} + +impl Service<()> for Api +where + S: tonic::client::GrpcService + Clone, + S: Clone + Send + Sync + 'static, + S::ResponseBody: Default + http::HttpBody + Send + 'static, + ::Error: Into + Send, + S::Future: Send + 'static, +{ + type Response = + tonic::Response>>; + type Error = tonic::Status; + type Future = futures::future::BoxFuture<'static, Result>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ()) -> Self::Future { + let req = api::X509svidRequest {}; + let mut client = self.client.clone(); + Box::pin(async move { + let parsed_header = SPIFFE_HEADER_VALUE + .parse() + .map_err(|e| tonic::Status::internal(format!("Failed to parse header: {}", e)))?; + + let mut req = tonic::Request::new(req); + req.metadata_mut().insert(SPIFFE_HEADER_KEY, parsed_header); + + let rsp = client.fetch_x509svid(req).await?; + Ok(rsp.map(|svids| { + svids + .map_ok(move |s| { + let svids = s + .svids + .into_iter() + .filter_map(|proto| { + proto + .try_into() + .map_err(|err| error!("could not parse SVID: {}", err)) + .ok() + }) + .collect(); + + SvidUpdate::new(svids) + }) + .boxed() + })) + }) + } +} + +// === impl GrpcRecover === + +impl Recover for GrpcRecover { + type Backoff = ExponentialBackoffStream; + + fn recover(&self, status: tonic::Status) -> Result { + // Non retriable conditions described in: + // https://github.com/spiffe/spiffe/blob/a5b6456ff1bcdb6935f61ed7f83e8ee533a325a3/standards/SPIFFE_Workload_API.md#client-state-machine + if status.code() == tonic::Code::InvalidArgument { + return Err(status); + } + + tracing::warn!( + grpc.status = %status.code(), + grpc.message = status.message(), + "Unexpected SPIRE Workload API response; retrying with a backoff", + ); + Ok(self.0.stream()) + } +} + +pub(super) fn process_svid(credentials: &mut C, mut update: SvidUpdate, id: &Id) -> Result<()> +where + C: Credentials, +{ + if let Some(svid) = update.svids.remove(id) { + use x509_parser::prelude::*; + + let (_, parsed_cert) = X509Certificate::from_der(&svid.leaf.0)?; + let exp: u64 = parsed_cert.validity().not_after.timestamp().try_into()?; + let exp = UNIX_EPOCH + Duration::from_secs(exp); + + return credentials.set_certificate(svid.leaf, svid.intermediates, svid.private_key, exp); + } + + Err(NoMatchingSVIDFound(()).into()) +} + +#[cfg(test)] +mod tests { + use crate::api::Svid; + use rcgen::{Certificate, CertificateParams, SanType}; + use spiffe_proto::client as api; + + fn gen_svid_pb(id: String, subject_alt_names: Vec) -> api::X509svid { + let mut params = CertificateParams::default(); + params.subject_alt_names = subject_alt_names; + let cert = Certificate::from_params(params).expect("should generate cert"); + + api::X509svid { + spiffe_id: id, + x509_svid: cert.serialize_der().expect("should serialize"), + x509_svid_key: cert.serialize_private_key_der(), + bundle: Vec::default(), + } + } + + #[test] + fn can_parse_valid_proto() { + let id = "spiffe://some-domain/some-workload"; + let svid_pb = gen_svid_pb(id.into(), vec![SanType::URI(id.into())]); + assert!(Svid::try_from(svid_pb).is_ok()); + } + + #[test] + fn cannot_parse_non_spiffe_id() { + let id = "some-domain.some-workload"; + let svid_pb = gen_svid_pb(id.into(), vec![SanType::DnsName(id.into())]); + assert!(Svid::try_from(svid_pb).is_err()); + } + + #[test] + fn cannot_parse_empty_cert() { + let id = "spiffe://some-domain/some-workload"; + let mut svid_pb = gen_svid_pb(id.into(), vec![SanType::URI(id.into())]); + svid_pb.x509_svid = Vec::default(); + assert!(Svid::try_from(svid_pb).is_err()); + } + + #[test] + fn cannot_parse_empty_key() { + let id = "spiffe://some-domain/some-workload"; + let mut svid_pb = gen_svid_pb(id.into(), vec![SanType::URI(id.into())]); + svid_pb.x509_svid_key = Vec::default(); + assert!(Svid::try_from(svid_pb).is_err()); + } +} diff --git a/linkerd/proxy/spire-client/src/lib.rs b/linkerd/proxy/spire-client/src/lib.rs new file mode 100644 index 0000000000..4987209147 --- /dev/null +++ b/linkerd/proxy/spire-client/src/lib.rs @@ -0,0 +1,257 @@ +#![deny(rust_2018_idioms, clippy::disallowed_methods, clippy::disallowed_types)] +#![forbid(unsafe_code)] + +mod api; + +pub use api::{Api, SvidUpdate}; +use linkerd_error::Error; +use linkerd_identity::Credentials; +use linkerd_identity::Id; +use std::fmt::{Debug, Display}; +use tokio::sync::watch; +use tower::{util::ServiceExt, Service}; + +pub struct Spire { + id: Id, +} + +// === impl Spire === + +impl Spire { + pub fn new(id: Id) -> Self { + Self { id } + } + + pub async fn run(self, credentials: C, mut client: S) + where + C: Credentials, + S: Service<(), Response = tonic::Response>>, + S::Error: Into + Display + Debug, + { + let client = client.ready().await.expect("should be ready"); + let rsp = client + .call(()) + .await + .expect("spire client must gracefully handle errors"); + consume_updates(&self.id, rsp.into_inner(), credentials).await + } +} + +async fn consume_updates( + id: &Id, + mut updates: watch::Receiver, + mut credentials: C, +) where + C: Credentials, +{ + loop { + let svid_update = updates.borrow_and_update().clone(); + if let Err(error) = api::process_svid(&mut credentials, svid_update, id) { + tracing::error!(%error, "Error processing SVID update"); + } + if updates.changed().await.is_err() { + tracing::debug!("SVID watch closed; terminating"); + return; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::api::Svid; + use linkerd_error::Result; + use linkerd_identity::{Credentials, DerX509, Id}; + use rcgen::{Certificate, CertificateParams, SanType, SerialNumber}; + use std::time::SystemTime; + use tokio::sync::watch; + + fn gen_svid(id: Id, subject_alt_names: Vec, serial: SerialNumber) -> Svid { + let mut params = CertificateParams::default(); + params.subject_alt_names = subject_alt_names; + params.serial_number = Some(serial); + + Svid::new( + id, + DerX509( + Certificate::from_params(params) + .expect("should generate cert") + .serialize_der() + .expect("should serialize"), + ), + Vec::default(), + Vec::default(), + ) + } + + struct MockClient { + rx: watch::Receiver, + } + + impl MockClient { + fn new(init: SvidUpdate) -> (Self, watch::Sender) { + let (tx, rx) = watch::channel(init); + (Self { rx }, tx) + } + } + + impl tower::Service<()> for MockClient { + type Response = tonic::Response>; + type Error = Error; + // type Future = futures::future::BoxFuture<'static, Result>; + type Future = futures::future::BoxFuture<'static, Result>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: ()) -> Self::Future { + let rsp = tonic::Response::new(self.rx.clone()); + Box::pin(futures::future::ready(Ok(rsp))) + } + } + + struct MockCredentials { + tx: watch::Sender>, + } + + impl MockCredentials { + fn new() -> (Self, watch::Receiver>) { + let (tx, rx) = watch::channel(None); + (Self { tx }, rx) + } + } + + impl Credentials for MockCredentials { + fn set_certificate( + &mut self, + leaf: DerX509, + _: Vec, + _: Vec, + _: SystemTime, + ) -> Result<()> { + let (_, cert) = x509_parser::parse_x509_certificate(&leaf.0).unwrap(); + let serial = SerialNumber::from_slice(&cert.serial.to_bytes_be()); + self.tx.send(Some(serial)).unwrap(); + Ok(()) + } + } + + #[tokio::test(flavor = "current_thread")] + async fn valid_updates() { + let spiffe_san = "spiffe://some-domain/some-workload"; + let spiffe_id = Id::parse_uri("spiffe://some-domain/some-workload").expect("should parse"); + + let (creds, mut creds_rx) = MockCredentials::new(); + + let spire = Spire::new(spiffe_id.clone()); + + let serial_1 = SerialNumber::from_slice("some-serial-1".as_bytes()); + let update_1 = SvidUpdate::new(vec![gen_svid( + spiffe_id.clone(), + vec![SanType::URI(spiffe_san.into())], + serial_1.clone(), + )]); + + let (client, svid_tx) = MockClient::new(update_1); + tokio::spawn(spire.run(creds, client)); + + creds_rx.changed().await.unwrap(); + assert!(*creds_rx.borrow_and_update() == Some(serial_1)); + + let serial_2 = SerialNumber::from_slice("some-serial-2".as_bytes()); + let update_2 = SvidUpdate::new(vec![gen_svid( + spiffe_id.clone(), + vec![SanType::URI(spiffe_san.into())], + serial_2.clone(), + )]); + + svid_tx.send(update_2).expect("should send"); + + creds_rx.changed().await.unwrap(); + assert!(*creds_rx.borrow_and_update() == Some(serial_2)); + } + + #[tokio::test(flavor = "current_thread")] + async fn invalid_update_empty_cert() { + let spiffe_san = "spiffe://some-domain/some-workload"; + let spiffe_id = Id::parse_uri("spiffe://some-domain/some-workload").expect("should parse"); + + let (creds, mut creds_rx) = MockCredentials::new(); + + let spire = Spire::new(spiffe_id.clone()); + + let serial_1 = SerialNumber::from_slice("some-serial-1".as_bytes()); + let update_1 = SvidUpdate::new(vec![gen_svid( + spiffe_id.clone(), + vec![SanType::URI(spiffe_san.into())], + serial_1.clone(), + )]); + + let (client, svid_tx) = MockClient::new(update_1); + tokio::spawn(spire.run(creds, client)); + + creds_rx.changed().await.unwrap(); + assert!(*creds_rx.borrow_and_update() == Some(serial_1.clone())); + + let invalid_svid = Svid::new( + spiffe_id.clone(), + DerX509(Vec::default()), + Vec::default(), + Vec::default(), + ); + + let mut update_sent = svid_tx.subscribe(); + let update_2 = SvidUpdate::new(vec![invalid_svid]); + svid_tx.send(update_2).expect("should send"); + + update_sent.changed().await.unwrap(); + + assert!(!creds_rx.has_changed().unwrap()); + assert!(*creds_rx.borrow_and_update() == Some(serial_1)); + } + + #[tokio::test(flavor = "current_thread")] + async fn invalid_valid_update_non_matching_id() { + let spiffe_san = "spiffe://some-domain/some-workload"; + let spiffe_san_wrong = "spiffe://some-domain/wrong"; + + let spiffe_id = Id::parse_uri("spiffe://some-domain/some-workload").expect("should parse"); + let spiffe_id_wrong = Id::parse_uri("spiffe://some-domain/wrong").expect("should parse"); + + let (creds, mut creds_rx) = MockCredentials::new(); + + let spire = Spire::new(spiffe_id.clone()); + + let serial_1 = SerialNumber::from_slice("some-serial-1".as_bytes()); + let update_1 = SvidUpdate::new(vec![gen_svid( + spiffe_id.clone(), + vec![SanType::URI(spiffe_san.into())], + serial_1.clone(), + )]); + + let (client, svid_tx) = MockClient::new(update_1); + tokio::spawn(spire.run(creds, client)); + + creds_rx.changed().await.unwrap(); + assert!(*creds_rx.borrow_and_update() == Some(serial_1.clone())); + + let serial_2 = SerialNumber::from_slice("some-serial-2".as_bytes()); + let mut update_sent = svid_tx.subscribe(); + let update_2 = SvidUpdate::new(vec![gen_svid( + spiffe_id_wrong, + vec![SanType::URI(spiffe_san_wrong.into())], + serial_2.clone(), + )]); + + svid_tx.send(update_2).expect("should send"); + + update_sent.changed().await.unwrap(); + + assert!(!creds_rx.has_changed().unwrap()); + assert!(*creds_rx.borrow_and_update() == Some(serial_1)); + } +} diff --git a/linkerd2-proxy/src/main.rs b/linkerd2-proxy/src/main.rs index ea3836829e..51c60b0b8d 100644 --- a/linkerd2-proxy/src/main.rs +++ b/linkerd2-proxy/src/main.rs @@ -80,14 +80,8 @@ fn main() { } // TODO distinguish ServerName and Identity. - info!("Local identity is {}", app.local_server_name()); - let addr = app.identity_addr(); - match addr.identity.value() { - None => info!("Identity verified via {}", addr.addr), - Some(tls) => { - info!("Identity verified via {} ({})", addr.addr, tls.server_id); - } - } + info!("SNI is {}", app.local_server_name()); + info!("Local identity is {}", app.local_tls_id()); let dst_addr = app.dst_addr(); match dst_addr.identity.value() { diff --git a/spiffe-proto/Cargo.toml b/spiffe-proto/Cargo.toml new file mode 100644 index 0000000000..9e1790e63c --- /dev/null +++ b/spiffe-proto/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "spiffe-proto" +version = "0.1.0" +authors = ["Linkerd Developers "] +license = "Apache-2.0" +edition = "2021" +publish = false + +[dependencies] +bytes = "1" +prost = "0.12" +prost-types = "0.12" + +[dependencies.tonic] +version = "0.10" +default-features = false +features = ["prost", "codegen"] + +[dev-dependencies.tonic-build] +version = "0.10" +default-features = false +features = ["prost"] + +[lib] +doctest = false diff --git a/spiffe-proto/spiffe/proto/workload.proto b/spiffe-proto/spiffe/proto/workload.proto new file mode 100644 index 0000000000..75a4b0922f --- /dev/null +++ b/spiffe-proto/spiffe/proto/workload.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +package spiffe.workloadapi; + +service SpiffeWorkloadAPI { + // Fetch X.509-SVIDs for all SPIFFE identities the workload is entitled to, + // as well as related information like trust bundles and CRLs. As this + // information changes, subsequent messages will be streamed from the + // server. + rpc FetchX509SVID(X509SVIDRequest) returns (stream X509SVIDResponse); +} + +// The X509SVIDRequest message conveys parameters for requesting an X.509-SVID. +// There are currently no request parameters. +message X509SVIDRequest { } + +// The X509SVIDResponse message carries X.509-SVIDs and related information, +// including a set of global CRLs and a list of bundles the workload may use +// for federating with foreign trust domains. +message X509SVIDResponse { + // Required. A list of X509SVID messages, each of which includes a single + // X.509-SVID, its private key, and the bundle for the trust domain. + repeated X509SVID svids = 1; + + // Optional. ASN.1 DER encoded certificate revocation lists. + repeated bytes crl = 2; + + // Optional. CA certificate bundles belonging to foreign trust domains that + // the workload should trust, keyed by the SPIFFE ID of the foreign trust + // domain. Bundles are ASN.1 DER encoded. + map federated_bundles = 3; +} + +// The X509SVID message carries a single SVID and all associated information, +// including the X.509 bundle for the trust domain. +message X509SVID { + // Required. The SPIFFE ID of the SVID in this entry + string spiffe_id = 1; + + // Required. ASN.1 DER encoded certificate chain. MAY include + // intermediates, the leaf certificate (or SVID itself) MUST come first. + bytes x509_svid = 2; + + // Required. ASN.1 DER encoded PKCS#8 private key. MUST be unencrypted. + bytes x509_svid_key = 3; + + // Required. ASN.1 DER encoded X.509 bundle for the trust domain. + bytes bundle = 4; +} diff --git a/spiffe-proto/src/gen/spiffe.workloadapi.rs b/spiffe-proto/src/gen/spiffe.workloadapi.rs new file mode 100644 index 0000000000..b052c5308d --- /dev/null +++ b/spiffe-proto/src/gen/spiffe.workloadapi.rs @@ -0,0 +1,151 @@ +/// The X509SVIDRequest message conveys parameters for requesting an X.509-SVID. +/// There are currently no request parameters. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct X509svidRequest {} +/// The X509SVIDResponse message carries X.509-SVIDs and related information, +/// including a set of global CRLs and a list of bundles the workload may use +/// for federating with foreign trust domains. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct X509svidResponse { + /// Required. A list of X509SVID messages, each of which includes a single + /// X.509-SVID, its private key, and the bundle for the trust domain. + #[prost(message, repeated, tag = "1")] + pub svids: ::prost::alloc::vec::Vec, + /// Optional. ASN.1 DER encoded certificate revocation lists. + #[prost(bytes = "vec", repeated, tag = "2")] + pub crl: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec>, + /// Optional. CA certificate bundles belonging to foreign trust domains that + /// the workload should trust, keyed by the SPIFFE ID of the foreign trust + /// domain. Bundles are ASN.1 DER encoded. + #[prost(map = "string, bytes", tag = "3")] + pub federated_bundles: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::vec::Vec, + >, +} +/// The X509SVID message carries a single SVID and all associated information, +/// including the X.509 bundle for the trust domain. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct X509svid { + /// Required. The SPIFFE ID of the SVID in this entry + #[prost(string, tag = "1")] + pub spiffe_id: ::prost::alloc::string::String, + /// Required. ASN.1 DER encoded certificate chain. MAY include + /// intermediates, the leaf certificate (or SVID itself) MUST come first. + #[prost(bytes = "vec", tag = "2")] + pub x509_svid: ::prost::alloc::vec::Vec, + /// Required. ASN.1 DER encoded PKCS#8 private key. MUST be unencrypted. + #[prost(bytes = "vec", tag = "3")] + pub x509_svid_key: ::prost::alloc::vec::Vec, + /// Required. ASN.1 DER encoded X.509 bundle for the trust domain. + #[prost(bytes = "vec", tag = "4")] + pub bundle: ::prost::alloc::vec::Vec, +} +/// Generated client implementations. +pub mod spiffe_workload_api_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + #[derive(Debug, Clone)] + pub struct SpiffeWorkloadApiClient { + inner: tonic::client::Grpc, + } + impl SpiffeWorkloadApiClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> SpiffeWorkloadApiClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + Send + Sync, + { + SpiffeWorkloadApiClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + /// Fetch X.509-SVIDs for all SPIFFE identities the workload is entitled to, + /// as well as related information like trust bundles and CRLs. As this + /// information changes, subsequent messages will be streamed from the + /// server. + pub async fn fetch_x509svid( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/SpiffeWorkloadAPI/FetchX509SVID", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("SpiffeWorkloadAPI", "FetchX509SVID")); + self.inner.server_streaming(req, path, codec).await + } + } +} diff --git a/spiffe-proto/src/lib.rs b/spiffe-proto/src/lib.rs new file mode 100644 index 0000000000..1223fc17f0 --- /dev/null +++ b/spiffe-proto/src/lib.rs @@ -0,0 +1,11 @@ +//! gRPC bindings for SPIFFE workload api. +//! +//! Vendored from . + +#![deny(rust_2018_idioms, clippy::disallowed_methods, clippy::disallowed_types)] +#![allow(clippy::derive_partial_eq_without_eq)] +#![forbid(unsafe_code)] + +pub mod client { + include!("gen/spiffe.workloadapi.rs"); +} diff --git a/spiffe-proto/tests/bootstrap.rs b/spiffe-proto/tests/bootstrap.rs new file mode 100644 index 0000000000..3aa90b3d3c --- /dev/null +++ b/spiffe-proto/tests/bootstrap.rs @@ -0,0 +1,48 @@ +//! A test that regenerates the Rust protobuf bindings. +//! +//! It can be run via: +//! +//! ```no_run +//! cargo test -p spiffe-proto --test=bootstrap +//! ``` + +/// Generates protobuf bindings into src/gen and fails if the generated files do +/// not match those that are already checked into git +#[test] +fn bootstrap() { + let out_dir = std::path::PathBuf::from(std::env!("CARGO_MANIFEST_DIR")) + .join("src") + .join("gen"); + generate(&out_dir); + if changed(&out_dir) { + panic!("protobuf interfaces do not match generated sources"); + } +} + +/// Generates protobuf bindings into the given directory +fn generate(out_dir: &std::path::Path) { + let iface_files = &["spiffe/proto/workload.proto"]; + if let Err(error) = tonic_build::configure() + .build_client(true) + .build_server(false) + .emit_rerun_if_changed(false) + .disable_package_emission() + .out_dir(out_dir) + .compile(iface_files, &["."]) + { + panic!("failed to compile protobuf: {error}") + } +} + +/// Returns true if the given path contains files that have changed since the +/// last Git commit +fn changed(path: &std::path::Path) -> bool { + let status = std::process::Command::new("git") + .arg("diff") + .arg("--exit-code") + .arg("--") + .arg(path) + .status() + .expect("failed to run git"); + !status.success() +}