diff --git a/include/quic/endpoint.hpp b/include/quic/endpoint.hpp index b9d54bd3..39351a8a 100644 --- a/include/quic/endpoint.hpp +++ b/include/quic/endpoint.hpp @@ -42,6 +42,13 @@ namespace oxen::quic // called when a connection closes or times out before the handshake completes using connection_closed_callback = std::function; + // called after we are done reading currently-available UDP packets to allow batch processing of + // incoming data; first we fire off any available callbacks triggered for incoming packets, then + // just before we go back to potentially block waiting for more packets, we fire this to let the + // application know that there might not be more callbacks immediately arriving and so it should + // process what it has. + using packet_batch_callback = std::function; + class Endpoint : std::enable_shared_from_this { private: @@ -50,6 +57,7 @@ namespace oxen::quic void handle_ep_opt(opt::inbound_alpns alpns); void handle_ep_opt(opt::handshake_timeout timeout); void handle_ep_opt(dgram_data_callback dgram_cb); + void handle_ep_opt(packet_batch_callback batch_cb); void handle_ep_opt(connection_established_callback conn_established_cb); void handle_ep_opt(connection_closed_callback conn_closed_cb); @@ -238,6 +246,8 @@ namespace oxen::quic bool _packet_splitting{false}; Splitting _policy{Splitting::NONE}; + packet_batch_callback _packet_batcher{nullptr}; + std::shared_ptr outbound_ctx; std::shared_ptr inbound_ctx; diff --git a/include/quic/udp.hpp b/include/quic/udp.hpp index a3631fec..1c3ac194 100644 --- a/include/quic/udp.hpp +++ b/include/quic/udp.hpp @@ -38,6 +38,7 @@ namespace oxen::quic ; using receive_callback_t = std::function; + using receive_batch_callback_t = std::function; UDPSocket() = delete; @@ -45,10 +46,19 @@ namespace oxen::quic /// binding to an any address (or any port) you can retrieve the realized address via /// address() after construction. /// - /// When packets are received they will be fed into the given callback. + /// When packets are received they will be fed into the given `on_receive` callback. + /// + /// The optional `on_receive_batch` callback will be invoked after processing a batch of + /// incoming packets but before returning to polling the socket for additional incoming + /// packets. This is meant to allow the caller to bundle incoming packets into batches + /// without introducing delays: each time one or more packets are read from the socket there + /// will be a sequence of `on_receive(...)` calls for each packet, followed by an + /// `on_receive_batch()` call immediately before the socket returns to waiting for + /// additional packets. Thus a caller can use the `on_receive` callback to collect packets + /// and the `on_receive_batch` callback to process the collected packets all at once. /// /// ev_loop must outlive this object. - UDPSocket(event_base* ev_loop, const Address& addr, receive_callback_t cb); + UDPSocket(event_base* ev_loop, const Address& addr, receive_callback_t on_receive, receive_batch_callback_t on_receive_batch = nullptr); /// Non-copyable and non-moveable UDPSocket(const UDPSocket& s) = delete; @@ -103,6 +113,8 @@ namespace oxen::quic event_ptr rev_ = nullptr; receive_callback_t receive_callback_; + receive_batch_callback_t receive_callback_batch_; + bool pending_receive_batch_ = false; event_ptr wev_ = nullptr; std::vector> writeable_callbacks_; }; diff --git a/src/endpoint.cpp b/src/endpoint.cpp index 8050dfd3..8b6733a3 100644 --- a/src/endpoint.cpp +++ b/src/endpoint.cpp @@ -56,6 +56,12 @@ namespace oxen::quic dgram_recv_cb = std::move(func); } + void Endpoint::handle_ep_opt(packet_batch_callback func) + { + log::trace(log_cat, "Endpoint given packet batch callback"); + _packet_batcher = std::move(func); + } + void Endpoint::handle_ep_opt(connection_established_callback conn_established_cb) { log::trace(log_cat, "Endpoint given connection established callback"); @@ -71,8 +77,11 @@ namespace oxen::quic void Endpoint::_init_internals() { log::debug(log_cat, "Starting new UDP socket on {}", _local); - socket = - std::make_unique(get_loop().get(), _local, [this](const auto& packet) { handle_packet(packet); }); + socket = std::make_unique( + get_loop().get(), + _local, + [this](const auto& packet) { handle_packet(packet); }, + [this] { if (_packet_batcher) _packet_batcher(); }); _local = socket->address(); diff --git a/src/udp.cpp b/src/udp.cpp index 7213c29f..d61436e7 100644 --- a/src/udp.cpp +++ b/src/udp.cpp @@ -85,8 +85,12 @@ namespace oxen::quic } #endif - UDPSocket::UDPSocket(event_base* ev_loop, const Address& addr, receive_callback_t on_receive) : - ev_{ev_loop}, receive_callback_{std::move(on_receive)} + UDPSocket::UDPSocket( + event_base* ev_loop, + const Address& addr, + receive_callback_t on_receive, + receive_batch_callback_t on_receive_batch) : + ev_{ev_loop}, receive_callback_{std::move(on_receive)}, receive_callback_batch_{std::move(on_receive_batch)} { assert(ev_); @@ -125,7 +129,16 @@ namespace oxen::quic ev_, sock_, EV_READ | EV_PERSIST, - [](evutil_socket_t, short, void* self) { static_cast(self)->receive(); }, + [](evutil_socket_t, short, void* self_) { + auto& self = *static_cast(self_); + self.receive(); + if (self.pending_receive_batch_) + { + self.pending_receive_batch_ = false; + if (self.receive_callback_batch_) + self.receive_callback_batch_(); + } + }, this)); event_add(rev_.get(), nullptr); @@ -190,6 +203,7 @@ namespace oxen::quic return; } + pending_receive_batch_ = true; receive_callback_(Packet{bound_, payload, hdr}); } diff --git a/tests/007-datagrams.cpp b/tests/007-datagrams.cpp index 75be1610..538eb0f2 100644 --- a/tests/007-datagrams.cpp +++ b/tests/007-datagrams.cpp @@ -333,7 +333,7 @@ namespace oxen::quic::test good_msg += v++; for (int i = 0; i < n; ++i) - conn_interface->send_datagram(std::basic_string_view{good_msg}); + conn_interface->send_datagram(good_msg); for (auto& f : data_futures) REQUIRE(f.get()); @@ -638,4 +638,78 @@ namespace oxen::quic::test }; #endif }; + + TEST_CASE("007 - Datagram support: packet batch triggers", "[007][datagrams][packet-batch]") + { + auto client_established = callback_waiter{[](connection_interface&) {}}; + + Network test_net{}; + + std::basic_string big_msg{}; + + for (int v = 0; big_msg.size() < 1000; v++) + big_msg += static_cast(v % 256); + + std::atomic batch_counter{0}; + std::atomic data_counter{0}; + + std::promise got_first, acked_first; + std::promise got_all_n_batches; + + dgram_data_callback recv_dgram_cb = [&](dgram_interface&, bstring value) { + auto count = ++data_counter; + CHECK(value == big_msg); + if (count == 1) + { + // We get one datagram, then stall the quic thread so that the test can fire + // multiple packets that we should then receive in one go. + got_first.set_value(); + REQUIRE(acked_first.get_future().wait_for(1s) == std::future_status::ready); + } + else if (count == 31) + { + got_all_n_batches.set_value(batch_counter); + } + }; + + auto batch_notifier = [&] { batch_counter++; }; + + opt::local_addr server_local{}; + opt::local_addr client_local{}; + + auto server_tls = GNUTLSCreds::make("./serverkey.pem"s, "./servercert.pem"s, "./clientcert.pem"s); + auto client_tls = GNUTLSCreds::make("./clientkey.pem"s, "./clientcert.pem"s, "./servercert.pem"s); + + auto server_endpoint = test_net.endpoint(server_local, recv_dgram_cb, batch_notifier, opt::enable_datagrams{}); + REQUIRE_NOTHROW(server_endpoint->listen(server_tls)); + + opt::remote_addr client_remote{"127.0.0.1"s, server_endpoint->local().port()}; + + auto client = test_net.endpoint(client_local, client_established, opt::enable_datagrams{}); + auto conn = client->connect(client_remote, client_tls); + + REQUIRE(client_established.wait()); + + // Start off with *one* datagram; the first one the server receives will stall the + // server until we signal it via the acked_first promise, during which we'll send a + // bunch more that ought to be processed in a single batch. + conn->send_datagram(big_msg); + + REQUIRE(got_first.get_future().wait_for(1s) == std::future_status::ready); + + int batches_before_flood = batch_counter; + + for (int i = 0; i < 30; i++) + conn->send_datagram(big_msg); + + acked_first.set_value(); + + auto f = got_all_n_batches.get_future(); + REQUIRE(f.wait_for(1s) == std::future_status::ready); + auto batch_counter_before_final = f.get(); + REQUIRE(data_counter == 31); + REQUIRE(batch_counter_before_final > batches_before_flood); + REQUIRE(batch_counter == batch_counter_before_final + 1); + }; + } // namespace oxen::quic::test