Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make read method on packet reader no std #45

Merged
merged 4 commits into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 66 additions & 32 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use secp256k1::{
const DECOY_BYTES: usize = 1;
/// Number of bytes for the authentication tag of a packet.
const TAG_BYTES: usize = 16;
/// Number of bytes for the length encoding of a packet.
/// Number of bytes for the length encoding prefix of a packet.
const LENGTH_BYTES: usize = 3;
/// Value for decoy flag.
const DECOY: u8 = 128;
Expand Down Expand Up @@ -136,6 +136,19 @@ pub struct ReceivedMessage {
pub message: Option<Vec<u8>>,
}

impl ReceivedMessage {
pub fn new(msg_bytes: &[u8]) -> Result<Self, Error> {
let header = msg_bytes.first().ok_or(Error::MessageLengthTooSmall)?;
if header.eq(&DECOY) {
Ok(ReceivedMessage { message: None })
} else {
Ok(ReceivedMessage {
message: Some(msg_bytes[1..].to_vec()),
})
}
}
}

#[derive(Clone, Debug)]
pub struct PacketReader {
length_decoding_cipher: FSChaCha20,
Expand All @@ -158,7 +171,6 @@ impl PacketReader {
/// The length to be read into the buffer next to receive the full message from the peer.
pub fn decypt_len(&mut self, len_bytes: [u8; 3]) -> usize {
let mut enc_content_len = [0u8; 3];
// TODO: should we just make len_butes mutable?
enc_content_len.copy_from_slice(&len_bytes);
self.length_decoding_cipher
.crypt(&mut enc_content_len)
Expand All @@ -176,40 +188,51 @@ impl PacketReader {
///
/// # Arguments
///
/// - `contents` - The message from the peer.
/// - `aad` - Optional authentication for the peer, currently only used for the first round of messages.
///
/// # Returns
///
/// The message from the peer.
/// - `ciphertext` - The message from the peer.
/// - `contents` - Mutable buffer to write plaintext.
/// - `aad` - Optional authentication for the peer, currently only used for the first round of messages.
///
/// # Errors
///
/// Fails if the packet was not decrypted or authenticated properly.
pub fn decrypt_contents(
&mut self,
contents: Vec<u8>,
aad: Option<Vec<u8>>,
) -> Result<ReceivedMessage, Error> {
ciphertext: &[u8],
contents: &mut [u8],
aad: Option<&[u8]>,
) -> Result<(), Error> {
let auth = aad.unwrap_or_default();
let mut contents = contents.clone();
let contents_len = contents.len();
let (ciphertext, tag) = contents.split_at_mut(contents_len - 16);
let (msg, tag) = ciphertext.split_at(ciphertext.len() - TAG_BYTES);
contents[0..msg.len()].copy_from_slice(msg);
self.packet_decoding_cipher.decrypt(
&auth,
ciphertext,
tag.try_into().expect("16 bytes"),
auth,
&mut contents[0..msg.len()],
tag.try_into().map_err(|_| Error::MessageLengthTooSmall)?,
)?;
let header = *ciphertext
.first()
.expect("All contents should include a header.");
if header.eq(&DECOY) {
return Ok(ReceivedMessage { message: None });
}
let message = ciphertext[1..].to_vec();
Ok(ReceivedMessage {
message: Some(message),
})

Ok(())
}

/// Decrypt the rest of the message from the peer, excluding the 3 length bytes. This method should only be called after
/// calling `decrypt_len` on the first three bytes of the buffer.
///
/// # Arguments
///
/// - `ciphertext` - The message from the peer.
/// - `aad` - Optional authentication for the peer, currently only used for the first round of messages.
///
/// # Errors
///
/// Fails if the packet was not decrypted or authenticated properly.
#[cfg(feature = "std")]
pub fn decrypt_contents_with_alloc(
&mut self,
ciphertext: &[u8],
aad: Option<&[u8]>,
) -> Result<Vec<u8>, Error> {
let mut contents = vec![0u8; ciphertext.len() - TAG_BYTES];
self.decrypt_contents(ciphertext, &mut contents, aad)?;
Ok(contents)
}
}

Expand Down Expand Up @@ -403,12 +426,19 @@ impl PacketHandler {
/// # Errors
///
/// Fails if the packet was not decrypted or authenticated properly.
#[cfg(feature = "std")]
pub fn decrypt_contents(
&mut self,
contents: Vec<u8>,
aad: Option<Vec<u8>>,
aad: Option<&[u8]>,
) -> Result<ReceivedMessage, Error> {
self.packet_reader.decrypt_contents(contents, aad)
let contents = self
.packet_reader
.decrypt_contents_with_alloc(&contents, aad)?;

let message = ReceivedMessage::new(&contents)?;

Ok(message)
}

/// Decrypt the one or more messages from bytes received by a V2 peer.
Expand Down Expand Up @@ -774,9 +804,13 @@ impl<'a> Handshake<'a> {
// Assuming no decoy packets so AAD is set on version packet.
// The version packet is ignored in this version of the protocol, but
// moves along state in the ciphers.
packet_handler.decrypt_contents(
message[LENGTH_BYTES..packet_length + LENGTH_BYTES].to_vec(),
Some(garbage.to_vec()),

// Version packets have 0 contents.
let mut version_packet = [0u8; DECOY_BYTES + TAG_BYTES];
packet_handler.packet_reader.decrypt_contents(
&message[LENGTH_BYTES..packet_length + LENGTH_BYTES],
&mut version_packet,
Some(garbage),
)?;

Ok(())
Expand Down
12 changes: 8 additions & 4 deletions proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use std::fmt;
use std::net::SocketAddr;

use bip324::ReceivedMessage;
use bip324::{PacketReader, PacketWriter};
use bitcoin::consensus::Decodable;
use bitcoin::hashes::sha256d;
Expand Down Expand Up @@ -165,12 +166,15 @@ pub async fn read_v2<T: AsyncRead + Unpin>(
let packet_bytes_len = decrypter.decypt_len(length_bytes);
let mut packet_bytes = vec![0u8; packet_bytes_len];
input.read_exact(&mut packet_bytes).await?;
let contents = decrypter
.decrypt_contents(packet_bytes, None)
.expect("decrypt")
.message
let raw = decrypter
.decrypt_contents_with_alloc(&packet_bytes, None)
.expect("decrypt");

let contents = ReceivedMessage::new(&raw)
.expect("some bytes")
.message
.expect("not a decoy");

// If packet is using short or full ID.
let (cmd, cmd_index) = if contents.starts_with(&[0u8]) {
(to_ascii(contents[1..13].try_into().expect("12 bytes")), 13)
Expand Down
Loading