diff --git a/src/protocol.rs b/src/protocol.rs index 43324ff..759ddc1 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -305,6 +305,11 @@ impl Channel { self.outgoing_requests.len() < self.config.request_limit as usize } + /// Returns the configured request limit for this channel. + pub fn request_limit(&self) -> u16 { + self.config.request_limit + } + /// Creates a new request, bypassing all client-side checks. /// /// Low-level function that does nothing but create a syntactically correct request and track @@ -474,7 +479,7 @@ impl Display for CompletedRead { } } -/// The caller of the this crate has violated the protocol. +/// The caller of this crate has violated the protocol. /// /// A correct implementation of a client should never encounter this, thus simply unwrapping every /// instance of this as part of a `Result<_, LocalProtocolViolation>` is usually a valid choice. @@ -487,18 +492,26 @@ pub enum LocalProtocolViolation { /// /// Wait for additional requests to be cancelled or answered. Calling /// [`JulietProtocol::allowed_to_send_request()`] beforehand is recommended. - #[error("sending would exceed request limit")] - WouldExceedRequestLimit, + #[error("sending would exceed request limit of {limit}")] + WouldExceedRequestLimit { + /// The configured limit for requests on the channel. + limit: u16, + }, /// The channel given does not exist. /// /// The given [`ChannelId`] exceeds `N` of [`JulietProtocol`]. - #[error("invalid channel")] - InvalidChannel(ChannelId), + #[error("channel {channel} not a member of configured {channel_count} channels")] + InvalidChannel { + /// The provided channel ID. + channel: ChannelId, + /// The configured number of channels. + channel_count: usize, + }, /// The given payload exceeds the configured limit. /// /// See [`ChannelConfiguration::with_max_request_payload_size()`] and /// [`ChannelConfiguration::with_max_response_payload_size()`] for details. - #[error("payload exceeds configured limit")] + #[error("payload length of {payload_length} bytes exceeds configured limit of {limit}")] PayloadExceedsLimit { /// The payload length in bytes. payload_length: usize, @@ -508,8 +521,16 @@ pub enum LocalProtocolViolation { /// The given error payload exceeds a single frame. /// /// Error payloads may not span multiple frames, shorten the payload or increase frame size. - #[error("error payload would be multi-frame")] - ErrorPayloadIsMultiFrame, + #[error( + "error payload of {payload_length} bytes exceeds a single frame with configured max size \ + of {max_frame_size})" + )] + ErrorPayloadIsMultiFrame { + /// The payload length in bytes. + payload_length: usize, + /// The configured maximum frame size in bytes. + max_frame_size: u32, + }, } macro_rules! log_frame { @@ -539,7 +560,10 @@ impl JulietProtocol { #[inline(always)] const fn lookup_channel(&self, channel: ChannelId) -> Result<&Channel, LocalProtocolViolation> { if channel.0 as usize >= N { - Err(LocalProtocolViolation::InvalidChannel(channel)) + Err(LocalProtocolViolation::InvalidChannel { + channel, + channel_count: N, + }) } else { Ok(&self.channels[channel.0 as usize]) } @@ -554,7 +578,10 @@ impl JulietProtocol { channel: ChannelId, ) -> Result<&mut Channel, LocalProtocolViolation> { if channel.0 as usize >= N { - Err(LocalProtocolViolation::InvalidChannel(channel)) + Err(LocalProtocolViolation::InvalidChannel { + channel, + channel_count: N, + }) } else { Ok(&mut self.channels[channel.0 as usize]) } @@ -608,7 +635,9 @@ impl JulietProtocol { } if !chan.allowed_to_send_request() { - return Err(LocalProtocolViolation::WouldExceedRequestLimit); + return Err(LocalProtocolViolation::WouldExceedRequestLimit { + limit: chan.request_limit(), + }); } Ok(chan.create_unchecked_request(channel, payload)) @@ -723,11 +752,15 @@ impl JulietProtocol { id: Id, payload: Bytes, ) -> Result { - let header = Header::new_error(header::ErrorKind::Other, channel, id); + let header = Header::new_error(ErrorKind::Other, channel, id); + let payload_length = payload.len(); let msg = OutgoingMessage::new(header, Some(payload)); if msg.is_multi_frame(self.max_frame_size) { - Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame) + Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame { + payload_length, + max_frame_size: self.max_frame_size.0, + }) } else { Ok(msg) } @@ -1264,7 +1297,8 @@ mod tests { #[test] fn test_channel_lookups_work() { - let mut protocol: JulietProtocol<3> = ProtocolBuilder::new().build(); + const CHANNEL_COUNT: usize = 3; + let mut protocol: JulietProtocol = ProtocolBuilder::new().build(); // We mark channels by inserting an ID into them, that way we can ensure we're not getting // back the same channel every time. @@ -1285,15 +1319,24 @@ mod tests { .insert(Id::new(102)); assert!(matches!( protocol.lookup_channel_mut(ChannelId(3)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(3))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(3), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel_mut(ChannelId(4)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(4))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(4), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel_mut(ChannelId(255)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(255))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(255), + channel_count: CHANNEL_COUNT + }) )); // Now look up the channels and ensure they contain the right values @@ -1320,15 +1363,24 @@ mod tests { ); assert!(matches!( protocol.lookup_channel(ChannelId(3)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(3))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(3), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel(ChannelId(4)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(4))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(4), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel(ChannelId(255)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(255))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(255), + channel_count: CHANNEL_COUNT + }) )); } @@ -1453,7 +1505,10 @@ mod tests { // Try an invalid channel, should result in an error. assert!(matches!( protocol.create_request(ChannelId::new(2), payload.get()), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(2))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(2), + channel_count: 2 + }) )); assert!(protocol @@ -1465,7 +1520,7 @@ mod tests { assert!(matches!( protocol.create_request(channel, payload.get()), - Err(LocalProtocolViolation::WouldExceedRequestLimit) + Err(LocalProtocolViolation::WouldExceedRequestLimit { limit: 1 }) )); } }