diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index ceb5eda..01e4293 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -44,6 +44,10 @@ pub const NUM_DECOY_BYTES: usize = 1; pub const LENGTH_BYTES: usize = 3; /// Value for decoy flag. pub const DECOY_BYTE: u8 = 128; +/// Initial buffer for decoy and version packets in the handshake. +/// The buffer may have to be expanded if a party is sending large +/// decoy packets. +pub const INITIAL_HANDSHAKE_BUFFER_BYTES: usize = 4096; // Version content is always empty for the current version of the protocol. const VERSION_CONTENT: [u8; 0] = []; @@ -59,6 +63,10 @@ const GARBAGE_TERMINTOR_BYTES: usize = 16; pub enum Error { /// The message decoded is smaller than expected MessageLengthTooSmall, + /// Allocated memory is too small for packet, returns + /// total required bytes for the failed packet so the + /// caller can re-allocate and re-attempt. + BufferTooSmall { required_bytes: usize }, /// There is a mismatch in the encoding of a message IncompatableV1Message, /// The maximum amount of garbage bytes was exceeded in the handshake. @@ -82,6 +90,11 @@ impl fmt::Display for Error { write!(f, "Unable to generate secret materials {}", e) } Error::MessageLengthTooSmall => write!(f, "Message length too small allocation"), + Error::BufferTooSmall { required_bytes } => write!( + f, + "Buffer memory allocation too small, need at least {} bytes", + required_bytes + ), Error::IncompatableV1Message => write!(f, "Incompatable V1 message"), Error::MaxGarbageLength => { write!(f, "More than 4095 bytes of garbage in the handshake") @@ -106,6 +119,7 @@ impl std::error::Error for Error { Error::Cipher(e) => Some(e), Error::OutOfSync => None, Error::SecretExpansion => None, + Error::BufferTooSmall { required_bytes: _ } => None, } } } @@ -219,7 +233,7 @@ impl PacketReader { /// /// - `ciphertext` - The message from the peer. /// - `contents` - Mutable buffer to write plaintext. - /// - `aad` - Optional authentication for the peer, currently only used for the first round of messages. + /// - `aad` - Optional authentication for the peer. /// /// # Errors /// @@ -232,6 +246,12 @@ impl PacketReader { ) -> Result<(), Error> { let auth = aad.unwrap_or_default(); let (msg, tag) = ciphertext.split_at(ciphertext.len() - TAG_BYTES); + // Bounds check that the contents buffer is large enough. + if contents.len() < msg.len() { + return Err(Error::BufferTooSmall { + required_bytes: msg.len(), + }); + } contents[0..msg.len()].copy_from_slice(msg); self.packet_decoding_cipher.decrypt( auth, @@ -248,7 +268,7 @@ impl PacketReader { /// # Arguments /// /// - `ciphertext` - The message from the peer. - /// - `aad` - Optional authentication for the peer, currently only used for the first round of messages. + /// - `aad` - Optional authentication for the peer. /// /// # Errors /// @@ -564,8 +584,10 @@ pub struct Handshake<'a> { remote_garbage_terminator: Option<[u8; 16]>, /// Packet handler output. packet_handler: Option, - /// Stored state between authentication attempts, decrypted length for next packet. - authentication_packet_bytes: Option, + /// Decrypted length for next packet. Store state between authentication attempts to avoid resetting ciphers. + current_packet_length_bytes: Option, + /// Processesed message index. Store state between authentication attempts to avoid resetting ciphers. + current_message_index: usize, } impl<'a> Handshake<'a> { @@ -640,7 +662,8 @@ impl<'a> Handshake<'a> { garbage, remote_garbage_terminator: None, packet_handler: None, - authentication_packet_bytes: None, + current_packet_length_bytes: None, + current_message_index: 0, }) } @@ -705,74 +728,145 @@ impl<'a> Handshake<'a> { Ok(()) } + /// Authenticate the channel and manage the memory allocation required. + /// + /// This function wraps [`authenticate_garbage_and_version`] and handles buffer allocation. + /// + /// # Arguments + /// + /// * `buffer` - The input buffer + /// + /// # Errors + /// + /// Returns the same errors of [`authenticate_garbage_and_version`], except `BufferTooSmall` is managed internally. + #[cfg(feature = "alloc")] + pub fn authenticate_garbage_and_version_with_alloc( + &mut self, + buffer: &[u8], + ) -> Result<(), Error> { + let mut packet_buffer = vec![0u8; INITIAL_HANDSHAKE_BUFFER_BYTES]; + + loop { + match self.authenticate_garbage_and_version(buffer, &mut packet_buffer) { + Ok(()) => return Ok(()), + Err(Error::BufferTooSmall { required_bytes }) => { + packet_buffer.resize(required_bytes, 0); + } + Err(e) => return Err(e), + } + } + } + /// Authenticate the channel. /// /// Designed to be called multiple times until succesful in order to flush - /// garbage and decoy packets from channel. + /// garbage and decoy packets from channel. If a `BufferTooSmall ` is + /// returned, the buffer should be extended until `BufferTooSmall ` is + /// not returned. All other errors are fatal for the handshake and it should + /// be completely restarted. /// /// # Arguments /// - /// - `buffer` - Should contain all garbage, the garbage terminator, and the version packet received from peer. + /// - `buffer` - Should contain all garbage, the garbage terminator, any decoy packets, and finally the version packet received from peer. + /// - `packet_buffer` - Required memory allocation for decrypting decoy and version packets. /// /// # Error /// - /// - `MessageLengthTooSmall` - The buffer did not contain all required information and should be extended (e.g. read more off a socket) and authentication re-tried. + /// - `BufferTooSmall ` - The buffer did not contain all required information and should be extended (e.g. read more off a socket) and authentication re-tried. /// - `HandshakeOutOfOrder` - The handshake sequence is in a bad state and should be restarted. - /// - `MaxGarbageLength` - Buffer did not contain the garbage terminator and contains too much garbage, should not be retried. - pub fn authenticate_garbage_and_version(&mut self, buffer: &[u8]) -> Result<(), Error> { - // Find the end of the garbage + /// - `MaxGarbageLength` - Buffer did not contain the garbage terminator, should not be retried. + pub fn authenticate_garbage_and_version( + &mut self, + buffer: &[u8], + packet_buffer: &mut [u8], + ) -> Result<(), Error> { + // Find the end of the garbage. let (garbage, message) = split_garbage_and_message( buffer, self.remote_garbage_terminator .ok_or(Error::HandshakeOutOfOrder)?, )?; - // Quickly fail if the message doesn't even have enough bytes for a length packet. - if message.len() < LENGTH_BYTES { - return Err(Error::MessageLengthTooSmall); + // Flag to track if the version packet has been received to signal the end of the handshake. + let mut found_version_packet = false; + + // The first packet, even if it is a decoy packet, + // is used to authenticate the received garbage through + // the AAD. + if self.current_message_index == 0 { + found_version_packet = self.decrypt_packet(message, packet_buffer, Some(garbage))?; + } + + // If the first packet is a decoy, or if this is a follow up + // authentication attempt, the decoys need to be flushed and + // the version packet found. + // + // The version packet is essentially ignored in the current + // version of the protocol, but it does move the cipher + // states forward. It could be extended in the future. + while !found_version_packet { + found_version_packet = self.decrypt_packet(message, packet_buffer, None)?; } + Ok(()) + } + + /// Decrypt the next packet in the message buffer while + /// book keeping relevant lengths and indexes. This allows + /// the buffer to be re-processed without throwing off + /// the state of the ciphers. + /// + /// # Returns + /// + /// True if the decrypted packet is the version packet. + fn decrypt_packet( + &mut self, + message: &[u8], + packet_buffer: &mut [u8], + garbage: Option<&[u8]>, + ) -> Result { let packet_handler = self .packet_handler .as_mut() .ok_or(Error::HandshakeOutOfOrder)?; - // TODO: Drain decoy packets, will require some more state to be store between attempts, like a message index. - - // Grab the packet length from internal statem, else decrypt it and store incase of failure. - let packet_length = if self.authentication_packet_bytes.is_some() { - self.authentication_packet_bytes - .ok_or(Error::HandshakeOutOfOrder) - } else { + if self.current_packet_length_bytes.is_none() { + // Bounds check on the input buffer. + if message.len() < self.current_message_index + LENGTH_BYTES { + return Err(Error::MessageLengthTooSmall); + } let packet_length = packet_handler.decypt_len( - message[0..LENGTH_BYTES] + message[self.current_message_index..LENGTH_BYTES] .try_into() - .map_err(|_| Error::MessageLengthTooSmall)?, + .expect("Buffer slice must be exactly 3 bytes long"), ); - // Hang on to decrypted length incase next steps fail to avoid using the cipher again re-attempting authentication. - self.authentication_packet_bytes = Some(packet_length); - Ok(packet_length) - }?; - - // Fail if there is not enough bytes to parse the message. - if message.len() < LENGTH_BYTES + packet_length { - return Err(Error::MessageLengthTooSmall); + // Hang on to decrypted length incase follow up steps fail + // and another authentication attempt is required. Avoids + // throwing off the cipher state. + self.current_packet_length_bytes = Some(packet_length); } - // Authenticate received garbage and get version packet. - // Assuming no decoy packets so AAD is set on version packet. - // The version packet is ignored in this version of the protocol, but - // moves along state in the ciphers. + let packet_length = self + .current_packet_length_bytes + .ok_or(Error::HandshakeOutOfOrder)?; - // Version packets have 0 contents. - let mut version_packet = [0u8; NUM_DECOY_BYTES + TAG_BYTES]; + // Bounds check on input buffer. + if message.len() < self.current_message_index + LENGTH_BYTES + packet_length { + return Err(Error::MessageLengthTooSmall); + } packet_handler.packet_reader.decrypt_contents( - &message[LENGTH_BYTES..packet_length + LENGTH_BYTES], - &mut version_packet, - Some(garbage), + &message[self.current_message_index + LENGTH_BYTES + ..self.current_message_index + LENGTH_BYTES + packet_length], + packet_buffer, + garbage, )?; - Ok(()) + // Mark current decryption point in the buffer. + self.current_message_index = self.current_message_index + LENGTH_BYTES + packet_length + 1; + self.current_packet_length_bytes = None; + + // The version packet is currently just an empty packet. + Ok(packet_buffer[0] != DECOY_BYTE) } /// Complete the handshake and return the packet handler for further communication. @@ -1067,13 +1161,13 @@ mod tests { // The initiator verifies the second half of the responders message which // includes the garbage terminator and version packet. init_handshake - .authenticate_garbage_and_version(&resp_message[64..]) + .authenticate_garbage_and_version_with_alloc(&resp_message[64..]) .unwrap(); // The responder verifies the second message from the initiator which // includes the garbage terminator and version packet. resp_handshake - .authenticate_garbage_and_version(&init_message_2) + .authenticate_garbage_and_version_with_alloc(&init_message_2) .unwrap(); let mut alice = init_handshake.finalize().unwrap(); @@ -1122,10 +1216,10 @@ mod tests { .unwrap(); init_handshake - .authenticate_garbage_and_version(&resp_message[64..]) + .authenticate_garbage_and_version_with_alloc(&resp_message[64..]) .unwrap(); resp_handshake - .authenticate_garbage_and_version(&init_finalize_message) + .authenticate_garbage_and_version_with_alloc(&init_finalize_message) .unwrap(); let mut alice = init_handshake.finalize().unwrap(); diff --git a/protocol/tests/round_trips.rs b/protocol/tests/round_trips.rs index 7d4af2f..9cf60f8 100644 --- a/protocol/tests/round_trips.rs +++ b/protocol/tests/round_trips.rs @@ -31,10 +31,10 @@ fn hello_world_happy_path() { .unwrap(); init_handshake - .authenticate_garbage_and_version(&resp_message[64..]) + .authenticate_garbage_and_version_with_alloc(&resp_message[64..]) .unwrap(); resp_handshake - .authenticate_garbage_and_version(&init_finalize_message) + .authenticate_garbage_and_version_with_alloc(&init_finalize_message) .unwrap(); let mut alice = init_handshake.finalize().unwrap(); @@ -111,7 +111,7 @@ fn regtest_handshake() { let response = &mut max_response[..size]; dbg!("Authenticating the handshake"); handshake - .authenticate_garbage_and_version(response) + .authenticate_garbage_and_version_with_alloc(response) .unwrap(); dbg!("Finalizing the handshake"); let packet_handler = handshake.finalize().unwrap(); diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index b9bada1..3a90d5f 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -11,6 +11,8 @@ use tokio::net::{TcpListener, TcpStream}; configure_me::include_config!(); +const HANDSHAKE_BUFFER_BYTES: usize = 4096; + /// Validate and bootstrap proxy connection. async fn proxy_conn(client: TcpStream) -> Result<(), bip324_proxy::Error> { let remote_ip = bip324_proxy::peek_addr(&client) @@ -63,30 +65,33 @@ async fn proxy_conn(client: TcpStream) -> Result<(), bip324_proxy::Error> { .expect("send garbage and version"); // Keep pulling bytes from the buffer until the garbage is flushed. - // Capacity is arbitrary, could use some tuning. - let mut remote_garbage_and_version_buffer = BytesMut::with_capacity(4096); + let mut remote_garbage_and_version_buffer = BytesMut::with_capacity(HANDSHAKE_BUFFER_BYTES); loop { println!("Authenticating garbage and version packet..."); - let read = remote + + // Read from the remote, hopefully contains all garbage, decoy packets, and version packet. + // BytesMut is keeping track of its internal posistion, so this read should only ever + // extend the buffer on retries. Not overwrite it. The buffer will grow if required. + if let Err(e) = remote .read_buf(&mut remote_garbage_and_version_buffer) - .await; - match read { - Err(e) => panic!("unable to read garbage {}", e), - _ => { - let auth = - handshake.authenticate_garbage_and_version(&remote_garbage_and_version_buffer); - match auth { - Err(e) => match e { - // Read again if too small, other wise surface error. - bip324::Error::MessageLengthTooSmall => continue, - e => panic!("unable to authenticate garbage {}", e), - }, - _ => { - println!("Channel authenticated."); - break; - } - } + .await + { + panic!("unable to read garbage {}", e) + } + + // Attempt to authenticate the channel. + match handshake + .authenticate_garbage_and_version_with_alloc(&remote_garbage_and_version_buffer) + { + Ok(()) => { + println!("Channel authenticated."); + break; + } + Err(bip324::Error::MessageLengthTooSmall) => { + // Attempt to pull more from the buffer and retry. + continue; } + Err(e) => panic!("unable to authenticate garbage and version {}", e), } }