Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: made TLS certificate optional for the debug page #198

Merged
merged 4 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions node/actors/executor/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! Library files for the executor. We have it separate from the binary so that we can use these files in the tools crate.
use crate::io::Dispatcher;
use anyhow::Context as _;
use network::http;
pub use network::{gossip::attestation, RpcConfig};
use std::{
collections::{HashMap, HashSet},
Expand Down Expand Up @@ -63,7 +62,7 @@ pub struct Config {

/// Http debug page configuration.
/// If None, debug page is disabled
pub debug_page: Option<http::DebugPageConfig>,
pub debug_page: Option<network::debug_page::Config>,

/// How often to poll the database looking for the batch commitment.
pub batch_poll_interval: time::Duration,
Expand Down Expand Up @@ -144,12 +143,12 @@ impl Executor {
net.register_metrics();
s.spawn(async { runner.run(ctx).await.context("Network stopped") });

if let Some(debug_config) = self.config.debug_page {
if let Some(cfg) = self.config.debug_page {
s.spawn(async {
http::DebugPageServer::new(debug_config, net)
network::debug_page::Server::new(cfg, net)
.run(ctx)
.await
.context("Http Server stopped")
.context("Debug page server stopped")
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,56 +24,131 @@ use tokio_rustls::{
pki_types::{CertificateDer, PrivateKeyDer},
ServerConfig,
},
server::TlsStream,
TlsAcceptor,
};
use zksync_concurrency::{ctx, scope};
use zksync_consensus_crypto::TextFmt as _;
use zksync_consensus_utils::debug_page;

const STYLE: &str = include_str!("style.css");

/// TLS certificate chain with a private key.
#[derive(Debug, PartialEq)]
pub struct TlsConfig {
/// TLS certificate chain.
pub cert_chain: Vec<CertificateDer<'static>>,
/// Private key for the leaf cert.
pub private_key: PrivateKeyDer<'static>,
}

/// Credentials.
#[derive(PartialEq, Clone)]
pub struct Credentials {
/// User for debug page
pub user: String,
/// Password for debug page
/// TODO: it should be treated as a secret: zeroize, etc.
pub password: String,
}

impl Credentials {
fn parse(value: String) -> anyhow::Result<Self> {
let [user, password] = value
.split(':')
.collect::<Vec<_>>()
.try_into()
.ok()
.context("bad format")?;
Ok(Self {
user: user.to_string(),
password: password.to_string(),
})
}
}

impl std::fmt::Debug for Credentials {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Credentials").finish_non_exhaustive()
}
}

/// Http debug page configuration.
#[derive(Debug, PartialEq)]
pub struct DebugPageConfig {
pub struct Config {
/// Public Http address to listen incoming http requests.
pub addr: SocketAddr,
/// Debug page credentials.
pub credentials: Option<debug_page::Credentials>,
/// Cert file path
pub certs: Vec<CertificateDer<'static>>,
/// Key file path
pub private_key: PrivateKeyDer<'static>,
pub credentials: Option<Credentials>,
/// TLS certificate to terminate the connections with.
pub tls: Option<TlsConfig>,
}

/// Http Server for debug page.
pub struct DebugPageServer {
config: DebugPageConfig,
pub struct Server {
config: Config,
network: Arc<Network>,
}

impl DebugPageServer {
#[async_trait::async_trait]
trait Listener: 'static + Send {
type Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin;
async fn accept(&mut self) -> anyhow::Result<Self::Stream>;
}

#[async_trait::async_trait]
impl Listener for TcpListener {
type Stream = tokio::net::TcpStream;
async fn accept(&mut self) -> anyhow::Result<Self::Stream> {
Ok(TcpListener::accept(self).await?.0)
}
}

#[async_trait::async_trait]
impl Listener for TlsListener<TcpListener, TlsAcceptor> {
type Stream = TlsStream<tokio::net::TcpStream>;
async fn accept(&mut self) -> anyhow::Result<Self::Stream> {
Ok(TlsListener::accept(self).await?.0)
}
}

impl Server {
/// Creates a new Server
pub fn new(config: DebugPageConfig, network: Arc<Network>) -> DebugPageServer {
DebugPageServer { config, network }
pub fn new(config: Config, network: Arc<Network>) -> Self {
Self { config, network }
}

/// Runs the Server.
pub async fn run(&self, ctx: &ctx::Ctx) -> anyhow::Result<()> {
let listener = TcpListener::bind(self.config.addr)
.await
.context("TcpListener::bind()")?;
if let Some(tls) = &self.config.tls {
let cfg = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(tls.cert_chain.clone(), tls.private_key.clone_key())
.context("with_single_cert()")?;
self.run_with_listener(ctx, TlsListener::new(Arc::new(cfg).into(), listener))
.await
} else {
self.run_with_listener(ctx, listener).await
}
}

async fn run_with_listener<L: Listener>(
&self,
ctx: &ctx::Ctx,
mut listener: L,
) -> anyhow::Result<()> {
// Start a watcher to shut down the server whenever ctx gets cancelled
let graceful = hyper_util::server::graceful::GracefulShutdown::new();

scope::run!(ctx, |ctx, s| async {
let mut listener = TlsListener::new(
self.tls_acceptor(),
TcpListener::bind(self.config.addr).await?,
);

let http = http1::Builder::new();

// Start a loop to accept incoming connections
while let Ok(res) = ctx.wait(listener.accept()).await {
match res {
Ok((stream, _)) => {
Ok(stream) => {
let io = TokioIo::new(stream);
let conn = http.serve_connection(io, service_fn(|req| self.handle(req)));
// watch this connection
Expand All @@ -86,10 +161,6 @@ impl DebugPageServer {
});
}
Err(err) => {
if let Some(remote_addr) = err.peer_addr() {
tracing::error!("[client {remote_addr}] ");
}

tracing::error!("Error accepting connection: {}", err);
continue;
}
Expand All @@ -106,46 +177,41 @@ impl DebugPageServer {
request: Request<hyper::body::Incoming>,
) -> anyhow::Result<Response<Full<Bytes>>> {
let mut response = Response::new(Full::default());
match self.basic_authentication(request.headers()) {
Ok(_) => *response.body_mut() = self.serve(request),
Err(e) => {
*response.status_mut() = StatusCode::UNAUTHORIZED;
*response.body_mut() = Full::new(Bytes::from(e.to_string()));
let header_value = HeaderValue::from_str(r#"Basic realm="debug""#).unwrap();
response
.headers_mut()
.insert(header::WWW_AUTHENTICATE, header_value);
}
if let Err(err) = self.authenticate(request.headers()) {
*response.status_mut() = StatusCode::UNAUTHORIZED;
*response.body_mut() = Full::new(Bytes::from(err.to_string()));
let header_value = HeaderValue::from_str(r#"Basic realm="debug""#).unwrap();
response
.headers_mut()
.insert(header::WWW_AUTHENTICATE, header_value);
}
*response.body_mut() = self.serve(request);
Ok(response)
}

fn basic_authentication(&self, headers: &HeaderMap) -> anyhow::Result<()> {
self.config
.credentials
.clone()
.map_or(Ok(()), |credentials| {
// The header value, if present, must be a valid UTF8 string
let header_value = headers
.get("Authorization")
.context("The 'Authorization' header was missing")?
.to_str()
.context("The 'Authorization' header was not a valid UTF8 string.")?;
let base64encoded_segment = header_value
.strip_prefix("Basic ")
.context("The authorization scheme was not 'Basic'.")?;
let decoded_bytes = base64::engine::general_purpose::STANDARD
.decode(base64encoded_segment)
.context("Failed to base64-decode 'Basic' credentials.")?;
let incoming_credentials = debug_page::Credentials::try_from(
String::from_utf8(decoded_bytes)
.context("The decoded credential string is not valid UTF8.")?,
)?;
if credentials != incoming_credentials {
anyhow::bail!("Invalid password.")
}
Ok(())
})
fn authenticate(&self, headers: &HeaderMap) -> anyhow::Result<()> {
let Some(want) = self.config.credentials.as_ref() else {
return Ok(());
};

// The header value, if present, must be a valid UTF8 string
let header_value = headers
.get("Authorization")
.context("The 'Authorization' header was missing")?
.to_str()
.context("The 'Authorization' header was not a valid UTF8 string.")?;
let base64encoded_segment = header_value
.strip_prefix("Basic ")
.context("Unsupported authorization scheme.")?;
let decoded_bytes = base64::engine::general_purpose::STANDARD
.decode(base64encoded_segment)
.context("Failed to base64-decode 'Basic' credentials.")?;
let got = Credentials::parse(
String::from_utf8(decoded_bytes)
.context("The decoded credential string is not valid UTF8.")?,
)?;
anyhow::ensure!(want == &got, "Invalid credentials.");
Ok(())
}

fn serve(&self, _request: Request<hyper::body::Incoming>) -> Full<Bytes> {
Expand Down Expand Up @@ -262,16 +328,4 @@ impl DebugPageServer {
format!("{}...{}", &key[..10], &key[len - 11..len])
})
}

fn tls_acceptor(&self) -> TlsAcceptor {
let cert_der = self.config.certs.clone();
let key_der = self.config.private_key.clone_key();
Arc::new(
ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_der, key_der)
.unwrap(),
)
.into()
}
}
2 changes: 1 addition & 1 deletion node/actors/network/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ use zksync_consensus_utils::pipe::ActorPipe;

mod config;
pub mod consensus;
pub mod debug_page;
mod frame;
pub mod gossip;
pub mod http;
pub mod io;
mod metrics;
mod mux;
Expand Down
29 changes: 12 additions & 17 deletions node/libs/concurrency/src/net/tcp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,23 @@ pub type Listener = tokio::net::TcpListener;

/// Accepts an INBOUND listener connection.
pub async fn accept(ctx: &ctx::Ctx, this: &mut Listener) -> ctx::OrCanceled<io::Result<Stream>> {
Ok(ctx.wait(this.accept()).await?.map(|(stream, _)| {
// We are the only owner of the correctly opened
// socket at this point so `set_nodelay` should
// always succeed.
stream.set_nodelay(true).unwrap();
stream
}))
ctx.wait(async {
let stream = this.accept().await?.0;
stream.set_nodelay(true)?;
Ok(stream)
})
.await
}

/// Opens a TCP connection to a remote host.
pub async fn connect(
ctx: &ctx::Ctx,
addr: std::net::SocketAddr,
) -> ctx::OrCanceled<io::Result<Stream>> {
Ok(ctx
.wait(tokio::net::TcpStream::connect(addr))
.await?
.map(|stream| {
// We are the only owner of the correctly opened
// socket at this point so `set_nodelay` should
// always succeed.
stream.set_nodelay(true).unwrap();
stream
}))
ctx.wait(async {
let stream = tokio::net::TcpStream::connect(addr).await?;
stream.set_nodelay(true)?;
Ok(stream)
})
.await
}
2 changes: 1 addition & 1 deletion node/libs/protobuf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub mod build;
pub mod proto;
mod proto_fmt;
pub mod repr;
pub use repr::{read_required_repr, ProtoRepr};
pub use repr::{read_optional_repr, read_required_repr, ProtoRepr};
pub mod serde;
mod std_conv;
pub mod testonly;
Expand Down
5 changes: 5 additions & 0 deletions node/libs/protobuf/src/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ pub fn read_required_repr<P: ProtoRepr>(field: &Option<P>) -> anyhow::Result<P::
field.as_ref().context("missing field")?.read()
}

/// Parses an optional proto field.
pub fn read_optional_repr<P: ProtoRepr>(field: &Option<P>) -> anyhow::Result<Option<P::Type>> {
field.as_ref().map(ProtoRepr::read).transpose()
}

/// Encodes a proto message.
/// Currently it outputs a canonical encoding, but `decode` accepts
/// non-canonical encoding as well.
Expand Down
Loading
Loading