Skip to content

Commit

Permalink
Move TCP metrics to network crate
Browse files Browse the repository at this point in the history
  • Loading branch information
slowli committed Sep 28, 2023
1 parent 6acfd28 commit e5685d3
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 184 deletions.
94 changes: 92 additions & 2 deletions node/actors/network/src/metrics.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,98 @@
//! General-purpose network metrics.

use crate::state::State;
use std::sync::Weak;
use vise::{Collector, Gauge, Metrics};
use concurrency::{io, metrics::GaugeGuard, net};
use std::{
pin::Pin,
sync::Weak,
task::{ready, Context, Poll},
};
use vise::{Collector, Counter, EncodeLabelSet, EncodeLabelValue, Family, Gauge, Metrics, Unit};

/// Metered TCP stream.
#[pin_project::pin_project]
pub(crate) struct MeteredStream {
#[pin]
stream: net::tcp::Stream,
_active: GaugeGuard,
}

impl MeteredStream {
/// Creates a new stream with the specified `direction`.
pub(crate) fn new(stream: net::tcp::Stream, direction: Direction) -> Self {
TCP_METRICS.established[&direction].inc();
Self {
stream,
_active: GaugeGuard::from(TCP_METRICS.active[&direction].clone()),
}
}
}

impl io::AsyncRead for MeteredStream {
#[inline(always)]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.project();
let before = buf.remaining();
let res = this.stream.poll_read(cx, buf);
let after = buf.remaining();
TCP_METRICS.received.inc_by((before - after) as u64);
res
}
}

impl io::AsyncWrite for MeteredStream {
#[inline(always)]
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.project();
let res = ready!(this.stream.poll_write(cx, buf))?;
TCP_METRICS.sent.inc_by(res as u64);
Poll::Ready(Ok(res))
}

#[inline(always)]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.project().stream.poll_flush(cx)
}

#[inline(always)]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.project().stream.poll_shutdown(cx)
}
}

/// Direction of a TCP connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EncodeLabelSet, EncodeLabelValue)]
#[metrics(label = "direction", rename_all = "snake_case")]
pub(crate) enum Direction {
/// Inbound connection.
Inbound,
/// Outbound connection.
Outbound,
}

/// Metrics reported for TCP connections.
#[derive(Debug, Metrics)]
#[metrics(prefix = "concurrency_net_tcp")]
struct TcpMetrics {
/// Total bytes sent over all TCP connections.
#[metrics(unit = Unit::Bytes)]
sent: Counter,
/// Total bytes received over all TCP connections.
#[metrics(unit = Unit::Bytes)]
received: Counter,
/// TCP connections established since the process started.
established: Family<Direction, Counter>,
/// Number of currently active TCP connections.
active: Family<Direction, Gauge>,
}

/// TCP metrics instance.
#[vise::register]
static TCP_METRICS: vise::Global<TcpMetrics> = vise::Global::new();

/// General-purpose network metrics exposed via a collector.
#[derive(Debug, Metrics)]
Expand Down
154 changes: 81 additions & 73 deletions node/actors/network/src/noise/stream.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! `tokio::io` stream using Noise encryption.
use super::bytes;
use crate::metrics::MeteredStream;
use concurrency::{
ctx, io,
io::{AsyncRead as _, AsyncWrite as _},
net,
};
use crypto::{sha256::Sha256, ByteFmt};
use std::{
Expand Down Expand Up @@ -32,65 +32,6 @@ fn params() -> snow::params::NoiseParams {
}
}

impl Stream {
/// Performs a server-side noise handshake and returns the encrypted stream.
pub(crate) async fn server_handshake(
ctx: &ctx::Ctx,
s: net::tcp::Stream,
) -> anyhow::Result<Stream> {
Self::handshake(ctx, s, snow::Builder::new(params()).build_responder()?).await
}

/// Performs a client-side noise handshake and returns the encrypted stream.
pub(crate) async fn client_handshake(
ctx: &ctx::Ctx,
s: net::tcp::Stream,
) -> anyhow::Result<Stream> {
Self::handshake(ctx, s, snow::Builder::new(params()).build_initiator()?).await
}

/// Performs the noise handshake given the HandshakeState.
async fn handshake(
ctx: &ctx::Ctx,
mut stream: net::tcp::Stream,
mut hs: snow::HandshakeState,
) -> anyhow::Result<Stream> {
let mut buf = vec![0; 65536];
let mut payload = vec![];
loop {
if hs.is_handshake_finished() {
return Ok(Stream {
id: ByteFmt::decode(hs.get_handshake_hash()).unwrap(),
inner: stream,
noise: hs.into_transport_mode()?,
read_buf: Box::default(),
write_buf: Box::default(),
});
}
if hs.is_my_turn() {
let n = hs.write_message(&payload, &mut buf)?;
// TODO(gprusak): writing/reading length field and the frame content could be
// done in a single syscall.
io::write_all(ctx, &mut stream, &u16::to_le_bytes(n as u16)).await??;
io::write_all(ctx, &mut stream, &buf[..n]).await??;
io::flush(ctx, &mut stream).await??;
} else {
let mut msg_size = [0u8, 2];
io::read_exact(ctx, &mut stream, &mut msg_size).await??;
let n = u16::from_le_bytes(msg_size) as usize;
io::read_exact(ctx, &mut stream, &mut buf[..n]).await??;
hs.read_message(&buf[..n], &mut payload)?;
}
}
}

/// Returns the noise session id.
/// See `Stream::id`.
pub(crate) fn id(&self) -> Sha256 {
self.id
}
}

// Constants from the Noise spec.

/// Maximal size of the encrypted frame that Noise may output.
Expand Down Expand Up @@ -130,16 +71,14 @@ impl Default for Buffer {

/// Encrypted stream.
/// It implements tokio::io::AsyncRead/AsyncWrite.
#[pin_project::pin_project(project=StreamProject)]
pub(crate) struct Stream {
#[pin_project::pin_project(project = StreamProject)]
pub(crate) struct Stream<S = MeteredStream> {
/// Hash of the handshake messages.
/// Uniquely identifies the noise session.
id: Sha256,
/// Underlying TCP stream.
/// TODO(gprusak): we can generalize noise::Stream to wrap an arbitrary
/// stream if needed.
#[pin]
inner: net::tcp::Stream,
inner: S,
/// Noise protocol state, used to encrypt/decrypt frames.
noise: snow::TransportState,
/// Buffers used for the read half of the stream.
Expand All @@ -148,12 +87,66 @@ pub(crate) struct Stream {
write_buf: Box<Buffer>,
}

impl Stream {
impl<S> Stream<S>
where
S: io::AsyncRead + io::AsyncWrite + Unpin,
{
/// Performs a server-side noise handshake and returns the encrypted stream.
pub(crate) async fn server_handshake(ctx: &ctx::Ctx, stream: S) -> anyhow::Result<Self> {
Self::handshake(ctx, stream, snow::Builder::new(params()).build_responder()?).await
}

/// Performs a client-side noise handshake and returns the encrypted stream.
pub(crate) async fn client_handshake(ctx: &ctx::Ctx, stream: S) -> anyhow::Result<Self> {
Self::handshake(ctx, stream, snow::Builder::new(params()).build_initiator()?).await
}

/// Performs the noise handshake given the HandshakeState.
async fn handshake(
ctx: &ctx::Ctx,
mut stream: S,
mut hs: snow::HandshakeState,
) -> anyhow::Result<Self> {
let mut buf = vec![0; 65536];
let mut payload = vec![];
loop {
if hs.is_handshake_finished() {
return Ok(Self {
id: ByteFmt::decode(hs.get_handshake_hash()).unwrap(),
inner: stream,
noise: hs.into_transport_mode()?,
read_buf: Box::default(),
write_buf: Box::default(),
});
}
if hs.is_my_turn() {
let n = hs.write_message(&payload, &mut buf)?;
// TODO(gprusak): writing/reading length field and the frame content could be
// done in a single syscall.
io::write_all(ctx, &mut stream, &u16::to_le_bytes(n as u16)).await??;
io::write_all(ctx, &mut stream, &buf[..n]).await??;
io::flush(ctx, &mut stream).await??;
} else {
let mut msg_size = [0u8, 2];
io::read_exact(ctx, &mut stream, &mut msg_size).await??;
let n = u16::from_le_bytes(msg_size) as usize;
io::read_exact(ctx, &mut stream, &mut buf[..n]).await??;
hs.read_message(&buf[..n], &mut payload)?;
}
}
}

/// Returns the noise session id.
/// See `Stream::id`.
pub(crate) fn id(&self) -> Sha256 {
self.id
}

/// Wait until a frame is fully loaded.
/// Returns the size of the frame.
/// Returns None in case EOF is reached before the frame is loaded.
fn poll_read_frame(
this: &mut StreamProject<'_>,
this: &mut StreamProject<'_, S>,
cx: &mut Context<'_>,
) -> Poll<io::Result<Option<usize>>> {
// Fetch frame until complete.
Expand All @@ -179,7 +172,7 @@ impl Stream {

/// Wait until payload is nonempty.
fn poll_read_payload(
this: &mut StreamProject<'_>,
this: &mut StreamProject<'_, S>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
if this.read_buf.payload.len() > 0 {
Expand All @@ -203,7 +196,10 @@ impl Stream {
}
}

impl io::AsyncRead for Stream {
impl<S> io::AsyncRead for Stream<S>
where
S: io::AsyncRead + io::AsyncWrite + Unpin,
{
/// From tokio::io::AsyncRead:
/// * The amount of data read can be determined by the increase
/// in the length of the slice returned by ReadBuf::filled.
Expand All @@ -227,9 +223,15 @@ impl io::AsyncRead for Stream {
}
}

impl Stream {
impl<S> Stream<S>
where
S: io::AsyncRead + io::AsyncWrite + Unpin,
{
/// poll_flush_frame will either flush this.write_buf.frame, or return an error.
fn poll_flush_frame(this: &mut StreamProject, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_flush_frame(
this: &mut StreamProject<'_, S>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
while this.write_buf.frame.len() > 0 {
let n =
ready!(Pin::new(&mut this.inner).poll_write(cx, this.write_buf.frame.as_slice()))?;
Expand All @@ -242,7 +244,10 @@ impl Stream {
}

/// poll_flush_payload will either flush this.write_buf.payload, or return an error.
fn poll_flush_payload(this: &mut StreamProject, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_flush_payload(
this: &mut StreamProject<'_, S>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
if this.write_buf.payload.len() == 0 {
return Poll::Ready(Ok(()));
}
Expand All @@ -266,7 +271,10 @@ impl Stream {
}
}

impl io::AsyncWrite for Stream {
impl<S> io::AsyncWrite for Stream<S>
where
S: io::AsyncRead + io::AsyncWrite + Unpin,
{
/// from futures::io::AsyncWrite:
/// * poll_write must try to make progress by flushing if needed to become writable
/// from std::io::Write:
Expand Down
19 changes: 14 additions & 5 deletions node/actors/network/src/noise/testonly.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
use crate::noise;
use crate::{metrics, noise};
use concurrency::{ctx, net, scope};

pub(crate) async fn pipe(ctx: &ctx::Ctx) -> (noise::Stream, noise::Stream) {
scope::run!(ctx, |ctx, s| async {
let (s1, s2) = net::tcp::testonly::pipe(ctx).await;
let s1 = s.spawn(async { noise::Stream::client_handshake(ctx, s1).await });
let s2 = s.spawn(async { noise::Stream::server_handshake(ctx, s2).await });
Ok((s1.join(ctx).await?, s2.join(ctx).await?))
let (outbound_stream, inbound_stream) = net::tcp::testonly::pipe(ctx).await;
let outbound_stream =
metrics::MeteredStream::new(outbound_stream, metrics::Direction::Outbound);
let inbound_stream =
metrics::MeteredStream::new(inbound_stream, metrics::Direction::Inbound);
let outbound_task =
s.spawn(async { noise::Stream::client_handshake(ctx, outbound_stream).await });
let inbound_task =
s.spawn(async { noise::Stream::server_handshake(ctx, inbound_stream).await });
Ok((
outbound_task.join(ctx).await?,
inbound_task.join(ctx).await?,
))
})
.await
.unwrap()
Expand Down
8 changes: 5 additions & 3 deletions node/actors/network/src/preface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//!
//! Hence, the preface protocol is used to enable encryption
//! and multiplex between mutliple endpoints available on the same TCP port.
use crate::{frame, noise};
use crate::{frame, metrics, noise};
use concurrency::{ctx, net, time};
use schema::{proto::network::preface as proto, required, ProtoFmt};

Expand Down Expand Up @@ -79,7 +79,8 @@ pub(crate) async fn connect(
endpoint: Endpoint,
) -> anyhow::Result<noise::Stream> {
let ctx = &ctx.with_timeout(TIMEOUT);
let mut stream = net::tcp::connect(ctx, addr).await??;
let stream = net::tcp::connect(ctx, addr).await??;
let mut stream = metrics::MeteredStream::new(stream, metrics::Direction::Outbound);
frame::send_proto(ctx, &mut stream, &Encryption::NoiseNN).await?;
let mut stream = noise::Stream::client_handshake(ctx, stream).await?;
frame::send_proto(ctx, &mut stream, &endpoint).await?;
Expand All @@ -89,8 +90,9 @@ pub(crate) async fn connect(
/// Performs a server-side preface protocol.
pub(crate) async fn accept(
ctx: &ctx::Ctx,
mut stream: net::tcp::Stream,
stream: net::tcp::Stream,
) -> anyhow::Result<(noise::Stream, Endpoint)> {
let mut stream = metrics::MeteredStream::new(stream, metrics::Direction::Inbound);
let ctx = &ctx.with_timeout(TIMEOUT);
let _: Encryption = frame::recv_proto(ctx, &mut stream).await?;
let mut stream = noise::Stream::server_handshake(ctx, stream).await?;
Expand Down
Loading

0 comments on commit e5685d3

Please sign in to comment.