diff --git a/include/quic/endpoint.hpp b/include/quic/endpoint.hpp index b9d54bd3..9e1fb7df 100644 --- a/include/quic/endpoint.hpp +++ b/include/quic/endpoint.hpp @@ -42,6 +42,13 @@ namespace oxen::quic // called when a connection closes or times out before the handshake completes using connection_closed_callback = std::function; + // Called after we are done reading currently-available UDP packets to allow batch processing of + // incoming data (or any other just-before-potentially-blocking needed handling). First we fire + // off any available callbacks triggered for incoming packets, then just before we go back to + // potentially block waiting for more packets, we fire this to let the application know that + // there might not be more callbacks immediately arriving and so it should process what it has. + using post_receive_callback = std::function; + class Endpoint : std::enable_shared_from_this { private: @@ -50,6 +57,7 @@ namespace oxen::quic void handle_ep_opt(opt::inbound_alpns alpns); void handle_ep_opt(opt::handshake_timeout timeout); void handle_ep_opt(dgram_data_callback dgram_cb); + void handle_ep_opt(post_receive_callback post_recv_cb); void handle_ep_opt(connection_established_callback conn_established_cb); void handle_ep_opt(connection_closed_callback conn_closed_cb); @@ -238,6 +246,8 @@ namespace oxen::quic bool _packet_splitting{false}; Splitting _policy{Splitting::NONE}; + post_receive_callback _post_receive{nullptr}; + std::shared_ptr outbound_ctx; std::shared_ptr inbound_ctx; diff --git a/include/quic/udp.hpp b/include/quic/udp.hpp index a3631fec..f5dcc76c 100644 --- a/include/quic/udp.hpp +++ b/include/quic/udp.hpp @@ -38,6 +38,7 @@ namespace oxen::quic ; using receive_callback_t = std::function; + using post_receive_callback_t = std::function; UDPSocket() = delete; @@ -45,10 +46,23 @@ namespace oxen::quic /// binding to an any address (or any port) you can retrieve the realized address via /// address() after construction. /// - /// When packets are received they will be fed into the given callback. + /// When packets are received they will be fed into the given `on_receive` callback. + /// + /// The optional `post_receive` callback will be invoked after processing available incoming + /// packets but before returning to polling the socket for additional incoming packets. + /// This is meant to allow the caller to bundle incoming packets into batches without + /// introducing delays: each time one or more packets are read from the socket there will be + /// a sequence of `on_receive(...)` calls for each packet, followed by a `post_receive()` + /// call immediately before the socket returns to waiting for additional packets. Thus a + /// caller can use the `on_receive` callback to collect packets and the `post_receive` + /// callback to process the collected packets all at once. /// /// ev_loop must outlive this object. - UDPSocket(event_base* ev_loop, const Address& addr, receive_callback_t cb); + UDPSocket( + event_base* ev_loop, + const Address& addr, + receive_callback_t on_receive, + post_receive_callback_t post_receive = nullptr); /// Non-copyable and non-moveable UDPSocket(const UDPSocket& s) = delete; @@ -103,6 +117,8 @@ namespace oxen::quic event_ptr rev_ = nullptr; receive_callback_t receive_callback_; + post_receive_callback_t post_receive_; + bool have_received_ = false; event_ptr wev_ = nullptr; std::vector> writeable_callbacks_; }; diff --git a/src/endpoint.cpp b/src/endpoint.cpp index 8050dfd3..b7e916e0 100644 --- a/src/endpoint.cpp +++ b/src/endpoint.cpp @@ -56,6 +56,12 @@ namespace oxen::quic dgram_recv_cb = std::move(func); } + void Endpoint::handle_ep_opt(post_receive_callback func) + { + log::trace(log_cat, "Endpoint given post-receive callback"); + _post_receive = std::move(func); + } + void Endpoint::handle_ep_opt(connection_established_callback conn_established_cb) { log::trace(log_cat, "Endpoint given connection established callback"); @@ -71,8 +77,14 @@ namespace oxen::quic void Endpoint::_init_internals() { log::debug(log_cat, "Starting new UDP socket on {}", _local); - socket = - std::make_unique(get_loop().get(), _local, [this](const auto& packet) { handle_packet(packet); }); + socket = std::make_unique( + get_loop().get(), + _local, + [this](const auto& packet) { handle_packet(packet); }, + [this] { + if (_post_receive) + _post_receive(); + }); _local = socket->address(); diff --git a/src/udp.cpp b/src/udp.cpp index 7213c29f..ae53621d 100644 --- a/src/udp.cpp +++ b/src/udp.cpp @@ -85,8 +85,9 @@ namespace oxen::quic } #endif - UDPSocket::UDPSocket(event_base* ev_loop, const Address& addr, receive_callback_t on_receive) : - ev_{ev_loop}, receive_callback_{std::move(on_receive)} + UDPSocket::UDPSocket( + event_base* ev_loop, const Address& addr, receive_callback_t on_receive, post_receive_callback_t post_receive) : + ev_{ev_loop}, receive_callback_{std::move(on_receive)}, post_receive_{std::move(post_receive)} { assert(ev_); @@ -125,7 +126,16 @@ namespace oxen::quic ev_, sock_, EV_READ | EV_PERSIST, - [](evutil_socket_t, short, void* self) { static_cast(self)->receive(); }, + [](evutil_socket_t, short, void* self_) { + auto& self = *static_cast(self_); + self.receive(); + if (self.have_received_) + { + self.have_received_ = false; + if (self.post_receive_) + self.post_receive_(); + } + }, this)); event_add(rev_.get(), nullptr); @@ -190,6 +200,7 @@ namespace oxen::quic return; } + have_received_ = true; receive_callback_(Packet{bound_, payload, hdr}); } diff --git a/tests/007-datagrams.cpp b/tests/007-datagrams.cpp index 112809e0..16375e5b 100644 --- a/tests/007-datagrams.cpp +++ b/tests/007-datagrams.cpp @@ -333,7 +333,7 @@ namespace oxen::quic::test good_msg += v++; for (int i = 0; i < n; ++i) - conn_interface->send_datagram(std::basic_string_view{good_msg}); + conn_interface->send_datagram(good_msg); for (auto& f : data_futures) REQUIRE(f.get()); @@ -638,4 +638,83 @@ namespace oxen::quic::test }; #endif }; + + TEST_CASE("007 - Datagram support: packet post-receive triggers", "[007][datagrams][packet-post-receive]") + { + auto client_established = callback_waiter{[](connection_interface&) {}}; + + Network test_net{}; + + std::basic_string big_msg{}; + + for (int v = 0; big_msg.size() < 1000; v++) + big_msg += static_cast(v % 256); + + std::atomic recv_counter{0}; + std::atomic data_counter{0}; + + std::promise got_first, acked_first; + std::promise got_all_n_recvs; + + dgram_data_callback recv_dgram_cb = [&](dgram_interface&, bstring value) { + auto count = ++data_counter; + CHECK(value == big_msg); + if (count == 1) + { + // We get one datagram, then stall the quic thread so that the test can fire + // multiple packets that we should then receive in one go. + got_first.set_value(); + REQUIRE(acked_first.get_future().wait_for(1s) == std::future_status::ready); + } + else if (count == 31) + { + got_all_n_recvs.set_value(recv_counter); + } + }; + + auto recv_notifier = [&] { recv_counter++; }; + + opt::local_addr server_local{}; + opt::local_addr client_local{}; + + auto server_tls = GNUTLSCreds::make("./serverkey.pem"s, "./servercert.pem"s, "./clientcert.pem"s); + auto client_tls = GNUTLSCreds::make("./clientkey.pem"s, "./clientcert.pem"s, "./servercert.pem"s); + + auto server_endpoint = test_net.endpoint(server_local, recv_dgram_cb, recv_notifier, opt::enable_datagrams{}); + REQUIRE_NOTHROW(server_endpoint->listen(server_tls)); + + opt::remote_addr client_remote{"127.0.0.1"s, server_endpoint->local().port()}; + + auto client = test_net.endpoint(client_local, client_established, opt::enable_datagrams{}); + auto conn = client->connect(client_remote, client_tls); + + REQUIRE(client_established.wait()); + + // Start off with *one* datagram; the first one the server receives will stall the + // server until we signal it via the acked_first promise, during which we'll send a + // bunch more that ought to be processed in a single batch. + conn->send_datagram(big_msg); + + REQUIRE(got_first.get_future().wait_for(1s) == std::future_status::ready); + + int batches_before_flood = recv_counter; + + for (int i = 0; i < 30; i++) + conn->send_datagram(big_msg); + + acked_first.set_value(); + + auto f = got_all_n_recvs.get_future(); + REQUIRE(f.wait_for(1s) == std::future_status::ready); + auto recv_counter_before_final = f.get(); + REQUIRE(data_counter == 31); + REQUIRE(recv_counter_before_final > batches_before_flood); + // There should be a recv callback fired *immediately* after the data callback that + // fulfilled the above proimise, so a miniscule wait here should guarantee that it has been + // set. + std::this_thread::sleep_for(1ms); + auto final_recv_counter = recv_counter.load(); + REQUIRE(final_recv_counter > recv_counter_before_final); + }; + } // namespace oxen::quic::test