Enable dynamic TCP receive window resizing #933

Open
wants to merge 10 commits into
base: main
from
179 changes: 178 additions & 1 deletion src/socket/tcp.rs
@@ -63,6 +63,30 @@ impl Display for ConnectError {
#[cfg(feature = "std")]
impl std::error::Error for ConnectError {}

/// Error returned by set_*
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ArgumentError {
Contributor

I don't think the creation of a new error type that is used only in this API follows our API design principles.

Contributor Author

I would rather not add a new error type for this API either. However, I couldn't find an appropriate existing error type to add the variants to. Do you have a suggestion for which of the existing {Send, Recv, Listen, Connect}Error types I should extend?

InvalidArgs,
InvalidState,
InsufficientResource,
}

impl Display for crate::socket::tcp::ArgumentError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
crate::socket::tcp::ArgumentError::InvalidArgs => write!(f, "invalid arguments by RFC"),
crate::socket::tcp::ArgumentError::InvalidState => write!(f, "invalid state"),
crate::socket::tcp::ArgumentError::InsufficientResource => {
write!(f, "insufficient runtime resource")
}
}
}
}

#[cfg(feature = "std")]
impl std::error::Error for crate::socket::tcp::ArgumentError {}

/// Error returned by [`Socket::send`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
@@ -767,6 +791,41 @@ impl<'a> Socket<'a> {
}
}

/// Return the local receive window scaling factor defined in [RFC 1323].
///
/// The value will become constant after the connection is established.
/// It may be reset to 0 during the handshake if the remote side does not support window scaling.
Contributor

I don't like that this API has a bidirectional data flow, where a value set by the consumer is read by the TCP/IP stack but is then also updated by the TCP/IP stack. Is there any precedent (whether in smoltcp or in BSD TCP/IP) for an option like this? I would prefer there to be two API entry points: one to request an option, and one to see whether it was applied.

Contributor Author

Google's gVisor userspace TCP/IP stack has similar options. I believe the issue is that smoltcp has its own default, and we want to be able to override that default. That is why the data flow is bidirectional: the TCP/IP stack reads the value (after connect/accept), and sets the value itself if no appropriate value has been set (before connect/accept).

As for the "two API entry points", I don't quite get that. Could you give some examples?
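
For readers following the thread, one possible reading of the "two API entry points" suggestion is sketched below. This is only an illustration under assumptions: `WindowScaleState`, `request_recv_win_scale`, `on_handshake_done`, and `applied_recv_win_scale` are hypothetical names, not part of smoltcp or of this PR. The idea is that the user-requested value and the value the stack actually applied live in separate fields, so the setter is never overwritten by the stack.

```rust
/// Hypothetical stand-in for the window-scaling part of a socket.
/// None of these names exist in smoltcp; they only illustrate the split.
#[derive(Debug, Default)]
struct WindowScaleState {
    /// What the user asked for before connect()/listen(), if anything.
    requested_shift: Option<u8>,
    /// What the handshake settled on; `None` until it completes.
    negotiated_shift: Option<u8>,
}

impl WindowScaleState {
    /// Entry point 1: record a request. The shift is capped at 14,
    /// matching the check in the PR's setter.
    fn request_recv_win_scale(&mut self, shift: u8) -> Result<(), &'static str> {
        if shift > 14 {
            return Err("window scale must not exceed 14");
        }
        self.requested_shift = Some(shift);
        Ok(())
    }

    /// Called by the stack (not the user) once the handshake completes.
    /// Falls back to the stack's default when nothing was requested, and
    /// to 0 when the remote side does not support window scaling.
    fn on_handshake_done(&mut self, default_shift: u8, remote_supports_scaling: bool) {
        let shift = if remote_supports_scaling {
            self.requested_shift.unwrap_or(default_shift)
        } else {
            0
        };
        self.negotiated_shift = Some(shift);
    }

    /// Entry point 2: read-only view of what was actually applied.
    fn applied_recv_win_scale(&self) -> Option<u8> {
        self.negotiated_shift
    }
}
```

With this shape, the request is write-only from the user's side and the applied value is read-only, which avoids the stack silently changing a value the user set.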

pub fn local_recv_win_scale(&self) -> u8 {
self.remote_win_shift
}

/// Set the local receive window scaling factor defined in [RFC 1323].
///
/// The value will become constant after the connection is established.
/// It may be reset to 0 during the handshake if remote side does not support window scaling.
///
/// # Errors
/// `Err(ArgumentError::InvalidArgs)` if the scale is greater than 14.
/// `Err(ArgumentError::InvalidState)` if the socket is not in the `Closed` or `Listen` state.
/// `Err(ArgumentError::InsufficientResource)` if the receive buffer is smaller than (1<<scale) bytes.
pub fn set_local_recv_win_scale(&mut self, scale: u8) -> Result<(), ArgumentError> {
if scale > 14 {
return Err(ArgumentError::InvalidArgs);
}

if self.rx_buffer.capacity() < (1 << scale) as usize {
return Err(ArgumentError::InsufficientResource);
}

match self.state {
State::Closed | State::Listen => {
self.remote_win_shift = scale;
Ok(())
}
_ => Err(ArgumentError::InvalidState),
}
}

/// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
///
/// See also the [set_hop_limit](#method.set_hop_limit) method
@@ -815,6 +874,7 @@ impl<'a> Socket<'a> {
fn reset(&mut self) {
let rx_cap_log2 =
mem::size_of::<usize>() * 8 - self.rx_buffer.capacity().leading_zeros() as usize;
let new_rx_win_shift = rx_cap_log2.saturating_sub(16) as u8;

self.state = State::Closed;
self.timer = Timer::new();
@@ -832,7 +892,10 @@ impl<'a> Socket<'a> {
self.remote_last_win = 0;
self.remote_win_len = 0;
self.remote_win_scale = None;
self.remote_win_shift = rx_cap_log2.saturating_sub(16) as u8;
// keep user-specified window scaling across connect()/listen()
if self.remote_win_shift < new_rx_win_shift {
self.remote_win_shift = new_rx_win_shift;
}
self.remote_mss = DEFAULT_MSS;
self.remote_last_ts = None;
self.ack_delay_timer = AckDelayTimer::Idle;
@@ -2280,6 +2343,7 @@ impl<'a> Socket<'a> {
} else if self.timer.should_close(cx.now()) {
// If we have spent enough time in the TIME-WAIT state, close the socket.
tcp_trace!("TIME-WAIT timer expired");
self.remote_win_shift = 0;
self.reset();
return Ok(());
} else {
@@ -2550,6 +2614,53 @@ impl<'a> Socket<'a> {
.unwrap_or(&PollAt::Ingress)
}
}

/// Replace the receive buffer with a new one.
///
/// The requirements for the new buffer are:
/// 1. Its capacity must be at least the length of the data remaining in the current buffer.
/// 2. Its capacity must be a multiple of `1 << self.remote_win_shift`.
///
/// If the new buffer does not meet the requirements, it is returned as the `Err` value;
/// otherwise, the old buffer is returned as the `Ok` value.
///
/// See also the [local_recv_win_scale](struct.Socket.html#method.local_recv_win_scale) method.
pub fn replace_recv_buffer<T: Into<SocketBuffer<'a>>>(
Contributor

This looks like a generic method that has no direct correspondence to TCP receive window resizing and as such I don't see why it should be a part of this PR at all.

Contributor Author

This method is what enables TCP receive window resizing. Smoltcp uses rx_buffer.capacity()-rx_buffer.len() to compute the current receive window. Without this method, it is impossible to adjust rx_buffer.capacity(). Once the receive buffer has been replaced with a larger one, smoltcp knows it can safely advertise a larger receive window without correctness issues.
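
The arithmetic described above can be sketched in isolation. This is a simplified model, not smoltcp's actual code: the free functions `recv_window`, `advertised_window_field`, and `new_buffer_is_acceptable` are hypothetical helpers, and clamping to the 16-bit window field is an assumption about how a scaled window would be placed on the wire. The acceptance check mirrors the two conditions tested in `replace_recv_buffer` in this diff.

```rust
/// Free space in the receive buffer, i.e. capacity - len, which the
/// comment above names as the basis for the receive window.
fn recv_window(capacity: usize, len: usize) -> usize {
    capacity.saturating_sub(len)
}

/// Value for the 16-bit TCP window field after applying the window
/// shift (assumed here: shift right, then clamp to u16::MAX).
fn advertised_window_field(window: usize, shift: u8) -> u16 {
    (window >> shift).min(u16::MAX as usize) as u16
}

/// The two conditions from `replace_recv_buffer`: the new capacity must
/// hold the queued data and be a multiple of `1 << shift`.
fn new_buffer_is_acceptable(new_capacity: usize, queued_len: usize, shift: u8) -> bool {
    new_capacity >= queued_len && new_capacity % (1usize << shift) == 0
}
```

Using the numbers from the PR's own test (shift 10, 31 KiB queued): a 64 KiB buffer leaves a 33 KiB window, while swapping in a 32 KiB buffer leaves 1 KiB; a buffer of 32 KiB + 512 bytes is rejected because it is not a multiple of 1024, and a 16 KiB buffer is rejected because it cannot hold the queued data.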

&mut self,
new_buffer: T,
) -> Result<SocketBuffer<'a>, SocketBuffer<'a>> {
let mut replaced_buf = new_buffer.into();
/* Check if the new buffer is valid
* Requirements:
* 1. Its capacity must be at least the length of the data remaining in the current buffer
* 2. Its capacity must be a multiple of (1 << self.remote_win_shift)
*/
if replaced_buf.capacity() < self.rx_buffer.len()
|| replaced_buf.capacity() % (1 << self.remote_win_shift) != 0
{
return Err(replaced_buf);
}
replaced_buf.clear();
self.rx_buffer.dequeue_many_with(|buf| {
let enqueued_len = replaced_buf.enqueue_slice(buf);
assert_eq!(enqueued_len, buf.len());
(enqueued_len, replaced_buf.get_allocated(0, enqueued_len))
});
if !self.rx_buffer.is_empty() {
// copy the wrapped around part
self.rx_buffer.dequeue_many_with(|buf| {
let enqueued_len = replaced_buf.enqueue_slice(buf);
assert_eq!(enqueued_len, buf.len());
(
enqueued_len,
replaced_buf.get_allocated(buf.len() - enqueued_len, enqueued_len),
)
});
}
assert_eq!(self.rx_buffer.len(), 0);
mem::swap(&mut self.rx_buffer, &mut replaced_buf);
Ok(replaced_buf)
}
}

impl<'a> fmt::Write for Socket<'a> {
@@ -7840,4 +7951,70 @@ mod test {
}]
);
}

// =========================================================================================//
// Tests for window scaling
// =========================================================================================//

fn socket_established_with_window_scaling() -> TestSocket {
let mut s = socket_established();
s.remote_win_shift = 10;
const BASE: usize = 1 << 10;
s.tx_buffer = SocketBuffer::new(vec![0u8; 64 * BASE]);
s.rx_buffer = SocketBuffer::new(vec![0u8; 64 * BASE]);
s
}

#[test]
fn test_too_large_window_scale() {
let mut socket = Socket::new(
SocketBuffer::new(vec![0; 8 * (1 << 15)]),
SocketBuffer::new(vec![0; 8 * (1 << 15)]),
);
assert!(socket.set_local_recv_win_scale(15).is_err())
}

#[test]
fn test_set_window_scale() {
let mut socket = Socket::new(
SocketBuffer::new(vec![0; 128]),
SocketBuffer::new(vec![0; 128]),
);
assert!(matches!(socket.state, State::Closed));
assert_eq!(socket.rx_buffer.capacity(), 128);
assert!(socket.set_local_recv_win_scale(6).is_ok());
assert!(socket.set_local_recv_win_scale(14).is_err());
assert_eq!(socket.local_recv_win_scale(), 6);
}

#[test]
fn test_set_scale_with_tcp_state() {
let mut socket = socket();
assert!(socket.set_local_recv_win_scale(1).is_ok());
let mut socket = socket_established();
assert!(socket.set_local_recv_win_scale(1).is_err());
let mut socket = socket_listen();
assert!(socket.set_local_recv_win_scale(1).is_ok());
let mut socket = socket_syn_received();
assert!(socket.set_local_recv_win_scale(1).is_err());
}

#[test]
fn test_resize_recv_buffer_invalid_size() {
let mut s = socket_established_with_window_scaling();
assert_eq!(s.rx_buffer.enqueue_slice(&[42; 31 * 1024]), 31 * 1024);
assert_eq!(s.rx_buffer.len(), 31 * 1024);
assert!(s
.replace_recv_buffer(SocketBuffer::new(vec![7u8; 32 * 1024 + 512]))
.is_err());
assert!(s
.replace_recv_buffer(SocketBuffer::new(vec![7u8; 16 * 1024]))
.is_err());
let old_buffer = s
.replace_recv_buffer(SocketBuffer::new(vec![7u8; 32 * 1024]))
.unwrap();
assert_eq!(old_buffer.capacity(), 64 * 1024);
assert_eq!(s.rx_buffer.len(), 31 * 1024);
assert_eq!(s.rx_buffer.capacity(), 32 * 1024);
}
}