Skip to content

Commit

Permalink
add further info to LocalProtocolViolation variants
Browse files Browse the repository at this point in the history
  • Loading branch information
Fraser999 authored and marc-casperlabs committed Mar 14, 2024
1 parent c64366a commit 4c71055
Showing 1 changed file with 77 additions and 22 deletions.
99 changes: 77 additions & 22 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ impl Channel {
self.outgoing_requests.len() < self.config.request_limit as usize
}

/// Returns the configured request limit for this channel.
pub fn request_limit(&self) -> u16 {
self.config.request_limit
}

/// Creates a new request, bypassing all client-side checks.
///
/// Low-level function that does nothing but create a syntactically correct request and track
Expand Down Expand Up @@ -474,7 +479,7 @@ impl Display for CompletedRead {
}
}

/// The caller of the this crate has violated the protocol.
/// The caller of this crate has violated the protocol.
///
/// A correct implementation of a client should never encounter this, thus simply unwrapping every
/// instance of this as part of a `Result<_, LocalProtocolViolation>` is usually a valid choice.
Expand All @@ -487,18 +492,26 @@ pub enum LocalProtocolViolation {
///
/// Wait for additional requests to be cancelled or answered. Calling
/// [`JulietProtocol::allowed_to_send_request()`] beforehand is recommended.
#[error("sending would exceed request limit")]
WouldExceedRequestLimit,
#[error("sending would exceed request limit of {limit}")]
WouldExceedRequestLimit {
/// The configured limit for requests on the channel.
limit: u16,
},
/// The channel given does not exist.
///
/// The given [`ChannelId`] exceeds `N` of [`JulietProtocol<N>`].
#[error("invalid channel")]
InvalidChannel(ChannelId),
#[error("channel {channel} not a member of configured {channel_count} channels")]
InvalidChannel {
/// The provided channel ID.
channel: ChannelId,
/// The configured number of channels.
channel_count: usize,
},
/// The given payload exceeds the configured limit.
///
/// See [`ChannelConfiguration::with_max_request_payload_size()`] and
/// [`ChannelConfiguration::with_max_response_payload_size()`] for details.
#[error("payload exceeds configured limit")]
#[error("payload length of {payload_length} bytes exceeds configured limit of {limit}")]
PayloadExceedsLimit {
/// The payload length in bytes.
payload_length: usize,
Expand All @@ -508,8 +521,16 @@ pub enum LocalProtocolViolation {
/// The given error payload exceeds a single frame.
///
/// Error payloads may not span multiple frames, shorten the payload or increase frame size.
#[error("error payload would be multi-frame")]
ErrorPayloadIsMultiFrame,
#[error(
"error payload of {payload_length} bytes exceeds a single frame with configured max size \
of {max_frame_size})"
)]
ErrorPayloadIsMultiFrame {
/// The payload length in bytes.
payload_length: usize,
/// The configured maximum frame size in bytes.
max_frame_size: u32,
},
}

macro_rules! log_frame {
Expand Down Expand Up @@ -539,7 +560,10 @@ impl<const N: usize> JulietProtocol<N> {
#[inline(always)]
const fn lookup_channel(&self, channel: ChannelId) -> Result<&Channel, LocalProtocolViolation> {
if channel.0 as usize >= N {
Err(LocalProtocolViolation::InvalidChannel(channel))
Err(LocalProtocolViolation::InvalidChannel {
channel,
channel_count: N,
})
} else {
Ok(&self.channels[channel.0 as usize])
}
Expand All @@ -554,7 +578,10 @@ impl<const N: usize> JulietProtocol<N> {
channel: ChannelId,
) -> Result<&mut Channel, LocalProtocolViolation> {
if channel.0 as usize >= N {
Err(LocalProtocolViolation::InvalidChannel(channel))
Err(LocalProtocolViolation::InvalidChannel {
channel,
channel_count: N,
})
} else {
Ok(&mut self.channels[channel.0 as usize])
}
Expand Down Expand Up @@ -608,7 +635,9 @@ impl<const N: usize> JulietProtocol<N> {
}

if !chan.allowed_to_send_request() {
return Err(LocalProtocolViolation::WouldExceedRequestLimit);
return Err(LocalProtocolViolation::WouldExceedRequestLimit {
limit: chan.request_limit(),
});
}

Ok(chan.create_unchecked_request(channel, payload))
Expand Down Expand Up @@ -723,11 +752,15 @@ impl<const N: usize> JulietProtocol<N> {
id: Id,
payload: Bytes,
) -> Result<OutgoingMessage, LocalProtocolViolation> {
let header = Header::new_error(header::ErrorKind::Other, channel, id);
let header = Header::new_error(ErrorKind::Other, channel, id);

let payload_length = payload.len();
let msg = OutgoingMessage::new(header, Some(payload));
if msg.is_multi_frame(self.max_frame_size) {
Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame)
Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame {
payload_length,
max_frame_size: self.max_frame_size.0,
})
} else {
Ok(msg)
}
Expand Down Expand Up @@ -1264,7 +1297,8 @@ mod tests {

#[test]
fn test_channel_lookups_work() {
let mut protocol: JulietProtocol<3> = ProtocolBuilder::new().build();
const CHANNEL_COUNT: usize = 3;
let mut protocol: JulietProtocol<CHANNEL_COUNT> = ProtocolBuilder::new().build();

// We mark channels by inserting an ID into them, that way we can ensure we're not getting
// back the same channel every time.
Expand All @@ -1285,15 +1319,24 @@ mod tests {
.insert(Id::new(102));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(3)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(3)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(3),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(4)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(4)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(4),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(255)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(255)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(255),
channel_count: CHANNEL_COUNT
})
));

// Now look up the channels and ensure they contain the right values
Expand All @@ -1320,15 +1363,24 @@ mod tests {
);
assert!(matches!(
protocol.lookup_channel(ChannelId(3)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(3)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(3),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel(ChannelId(4)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(4)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(4),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel(ChannelId(255)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(255)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(255),
channel_count: CHANNEL_COUNT
})
));
}

Expand Down Expand Up @@ -1453,7 +1505,10 @@ mod tests {
// Try an invalid channel, should result in an error.
assert!(matches!(
protocol.create_request(ChannelId::new(2), payload.get()),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(2)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(2),
channel_count: 2
})
));

assert!(protocol
Expand All @@ -1465,7 +1520,7 @@ mod tests {

assert!(matches!(
protocol.create_request(channel, payload.get()),
Err(LocalProtocolViolation::WouldExceedRequestLimit)
Err(LocalProtocolViolation::WouldExceedRequestLimit { limit: 1 })
));
}
}
Expand Down

0 comments on commit 4c71055

Please sign in to comment.