diff --git a/interceptor/src/nack/responder/mod.rs b/interceptor/src/nack/responder/mod.rs index 48bbfbf85..7498985fb 100644 --- a/interceptor/src/nack/responder/mod.rs +++ b/interceptor/src/nack/responder/mod.rs @@ -17,12 +17,14 @@ use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; +use std::time::Duration; use tokio::sync::Mutex; /// GeneratorBuilder can be used to configure Responder Interceptor #[derive(Default)] pub struct ResponderBuilder { log2_size: Option, + max_packet_age: Option, } impl ResponderBuilder { @@ -32,6 +34,15 @@ impl ResponderBuilder { self.log2_size = Some(log2_size); self } + + /// with_max_packet_age sets the max age of packets that will be resent. + /// + /// When a resend is requested, packets that were first sent more than `max_packet_age` ago + /// will not be resent. + pub fn with_max_packet_age(mut self, max_packet_age: Duration) -> ResponderBuilder { + self.max_packet_age = Some(max_packet_age); + self + } } impl InterceptorBuilder for ResponderBuilder { @@ -43,6 +54,7 @@ impl InterceptorBuilder for ResponderBuilder { } else { 13 // 8192 = 1 << 13 }, + max_packet_age: self.max_packet_age, streams: Arc::new(Mutex::new(HashMap::new())), }), })) @@ -51,6 +63,7 @@ impl InterceptorBuilder for ResponderBuilder { pub struct ResponderInternal { log2_size: u8, + max_packet_age: Option, streams: Arc>>>, } @@ -58,6 +71,7 @@ impl ResponderInternal { async fn resend_packets( streams: Arc>>>, nack: TransportLayerNack, + max_packet_age: Option, ) { let stream = { let m = streams.lock().await; @@ -73,10 +87,19 @@ impl ResponderInternal { n.range(Box::new( move |seq: u16| -> Pin + Send + 'static>> { let stream3 = Arc::clone(&stream2); + Box::pin(async move { if let Some(p) = stream3.get(seq).await { + let should_send = max_packet_age + .map(|max_age| p.age() < max_age) + .unwrap_or(true); + + if !should_send { + return true; + } + let a = Attributes::new(); - if let Err(err) = stream3.next_rtp_writer.write(&p, &a).await { + if let Err(err) = stream3.next_rtp_writer.write(&p.packet, &a).await { log::warn!("failed resending nacked packet: {}", err); } } @@ -92,6 +115,7 @@ impl ResponderInternal { pub struct ResponderRtcpReader { parent_rtcp_reader: Arc, + max_packet_age: Option, internal: Arc, } @@ -106,8 +130,9 @@ impl RTCPReader for ResponderRtcpReader { if let Some(nack) = p.as_any().downcast_ref::() { let nack = nack.clone(); let streams = Arc::clone(&self.internal.streams); + let max_packet_age = self.max_packet_age; tokio::spawn(async move { - ResponderInternal::resend_packets(streams, nack).await; + ResponderInternal::resend_packets(streams, nack, max_packet_age).await; }); } } @@ -138,6 +163,7 @@ impl Interceptor for Responder { ) -> Arc { Arc::new(ResponderRtcpReader { internal: Arc::clone(&self.internal), + max_packet_age: self.internal.max_packet_age, parent_rtcp_reader: reader, }) as Arc } diff --git a/interceptor/src/nack/responder/responder_stream.rs b/interceptor/src/nack/responder/responder_stream.rs index 86e022162..f036c1a4a 100644 --- a/interceptor/src/nack/responder/responder_stream.rs +++ b/interceptor/src/nack/responder/responder_stream.rs @@ -4,10 +4,11 @@ use crate::{Attributes, RTPWriter}; use async_trait::async_trait; use std::sync::Arc; +use std::time::{Duration, Instant}; use tokio::sync::Mutex; struct ResponderStreamInternal { - packets: Vec>, + packets: Vec>, size: u16, last_added: u16, started: bool, @@ -26,7 +27,7 @@ impl ResponderStreamInternal { fn add(&mut self, packet: &rtp::packet::Packet) { let seq = packet.header.sequence_number; if !self.started { - self.packets[(seq % self.size) as usize] = Some(packet.clone()); + self.packets[(seq % self.size) as usize] = Some(packet.clone().into()); self.last_added = seq; self.started = true; return; @@ -43,11 +44,11 @@ impl ResponderStreamInternal { } } - self.packets[(seq % self.size) as usize] = Some(packet.clone()); + self.packets[(seq % self.size) as usize] = Some(packet.clone().into()); self.last_added = seq; } - fn get(&self, seq: u16) -> Option<&rtp::packet::Packet> { + fn get(&self, seq: u16) -> Option<&SentPacket> { let diff = self.last_added.wrapping_sub(seq); if diff >= UINT16SIZE_HALF { return None; @@ -79,7 +80,7 @@ impl ResponderStream { internal.add(pkt); } - pub(super) async fn get(&self, seq: u16) -> Option { + pub(super) async fn get(&self, seq: u16) -> Option { let internal = self.internal.lock().await; internal.get(seq).cloned() } @@ -96,6 +97,28 @@ impl RTPWriter for ResponderStream { } } +#[derive(Clone)] +/// A packet that has been sent, or at least been queued to send. +pub struct SentPacket { + pub(super) packet: rtp::packet::Packet, + sent_at: Instant, +} + +impl SentPacket { + pub(super) fn age(&self) -> Duration { + self.sent_at.elapsed() + } +} + +impl From for SentPacket { + fn from(packet: rtp::packet::Packet) -> Self { + Self { + packet, + sent_at: Instant::now(), + } + } +} + #[cfg(test)] mod test { use super::*; @@ -127,9 +150,9 @@ mod test { let seq = start.wrapping_add(*n); if let Some(packet) = sb.get(seq) { assert_eq!( - packet.header.sequence_number, seq, + packet.packet.header.sequence_number, seq, "packet for {} returned with incorrect SequenceNumber: {}", - seq, packet.header.sequence_number + seq, packet.packet.header.sequence_number ); } else { assert!(false, "packet not found: {}", seq); @@ -144,7 +167,7 @@ mod test { assert!( false, "packet found for {}: {}", - seq, packet.header.sequence_number + seq, packet.packet.header.sequence_number ); } }