Skip to content

Commit

Permalink
feat: move decryption to courier thread where it's actually needed
Browse files Browse the repository at this point in the history
  • Loading branch information
Neko-Life committed Oct 20, 2024
1 parent 45fc7ca commit c1a2f0d
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 119 deletions.
187 changes: 155 additions & 32 deletions src/dpp/voice/enabled/courier_loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
************************************************************************************/

#include <string_view>
#include <utility>
#include <dpp/exception.h>
#include <dpp/isa_detection.h>
#include <dpp/discordvoiceclient.h>
Expand All @@ -32,11 +33,32 @@

namespace dpp {

struct iter_bench {
std::chrono::time_point<std::chrono::steady_clock> start;
std::string n;

iter_bench(std::string _n) : n(std::move(_n)) {

Check notice on line 40 in src/dpp/voice/enabled/courier_loop.cpp

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

src/dpp/voice/enabled/courier_loop.cpp#L40

Struct 'iter_bench' has a constructor with 1 argument that is not explicit.
start = std::chrono::steady_clock::now();

Check warning on line 41 in src/dpp/voice/enabled/courier_loop.cpp

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

src/dpp/voice/enabled/courier_loop.cpp#L41

Variable 'start' is assigned in constructor body. Consider performing initialization in initialization list.
std::cout << n << "START: " << std::chrono::duration_cast<std::chrono::milliseconds>(start.time_since_epoch()).count() << "\n";
}

~iter_bench() {
auto end = std::chrono::steady_clock::now();
std::cout << n << "END: " << std::chrono::duration_cast<std::chrono::milliseconds>(end.time_since_epoch()).count() << "\n";
std::cout << n << "BENCH: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "\n";
}
};

void discord_voice_client::voice_courier_loop(discord_voice_client& client, courier_shared_state_t& shared_state) {
utility::set_thread_name(std::string("vcourier/") + std::to_string(client.server_id));
size_t iter = 0;
while (true) {
std::cout << "voice_courier_loop ITER: " << ++iter << "\n";

std::this_thread::sleep_for(std::chrono::milliseconds{client.iteration_interval});

iter_bench b("voice_courier_loop ITER_");

struct flush_data_t {
snowflake user_id;
rtp_seq_t min_seq;
Expand All @@ -51,6 +73,9 @@ void discord_voice_client::voice_courier_loop(discord_voice_client& client, cour
* release the lock as soon as possible.
*/
{
std::cout << "shared_state.mtx LOCK\n";
iter_bench c("voice_courier_loop FLUSH_DATA_");

std::unique_lock lk(shared_state.mtx);

/* mitigates vector resizing while holding the mutex */
Expand All @@ -60,43 +85,53 @@ void discord_voice_client::voice_courier_loop(discord_voice_client& client, cour
for (auto &[user_id, parking_lot]: shared_state.parked_voice_payloads) {
has_payload_to_deliver = has_payload_to_deliver || !parking_lot.parked_payloads.empty();
flush_data.push_back(flush_data_t{user_id,
parking_lot.range.min_seq,
std::move(parking_lot.parked_payloads),
parking_lot.range.min_seq,
std::move(parking_lot.parked_payloads),
/* Quickly check if we already have a decoder and only take the pending ctls if so. */
parking_lot.decoder ? std::move(parking_lot.pending_decoder_ctls)
: decltype(parking_lot.pending_decoder_ctls){},
parking_lot.decoder});
parking_lot.decoder ? std::move(parking_lot.pending_decoder_ctls)
: decltype(parking_lot.pending_decoder_ctls){},
parking_lot.decoder});
parking_lot.range.min_seq = parking_lot.range.max_seq + 1;
parking_lot.range.min_timestamp = parking_lot.range.max_timestamp + 1;
}

if (!has_payload_to_deliver) {
if (shared_state.terminating) {
/* We have delivered all data to handlers. Terminate now. */
std::cout << "shared_state.mtx RELEASE\n";
break;
}

shared_state.signal_iteration.wait(lk, [&shared_state](){
/*
* Actually check the state we're looking for instead of waking up
* everytime read_ready was called.
*/
for (auto &[user_id, parking_lot]: shared_state.parked_voice_payloads) {
if (parking_lot.parked_payloads.empty()) {
continue;
}

return true;
}
return false;
});
shared_state.signal_iteration.wait(lk, [&shared_state](){
std::cout << "shared_state.mtx LOCKED\n";
if (shared_state.terminating) {
return true;
}

/*
* Actually check the state we're looking for instead of waking up
* everytime read_ready was called.
*/
for (auto &[user_id, parking_lot]: shared_state.parked_voice_payloads) {
if (parking_lot.parked_payloads.empty()) {
continue;
}

std::cout << "shared_state.mtx RELEASE\n";
return true;
}
std::cout << "shared_state.mtx RELEASE\n";
return false;
});
std::cout << "shared_state.mtx RELEASE\n";

/*
* More data came or about to terminate, or just a spurious wake.
* We need to collect the payloads again to determine what to do next.
*/
continue;
}
std::cout << "shared_state.mtx RELEASE\n";
}

if (client.creator->on_voice_receive.empty() && client.creator->on_voice_receive_combined.empty()) {
Expand All @@ -116,52 +151,140 @@ void discord_voice_client::voice_courier_loop(discord_voice_client& client, cour
int max_samples = 0;
int samples = 0;

opus_int16 flush_data_pcm[23040];
for (auto &d: flush_data) {
iter_bench c("voice_courier_loop FLUSH_LOOP_");
if (!d.decoder) {
continue;
}
for (const auto &decoder_ctl: d.pending_decoder_ctls) {
decoder_ctl(*d.decoder);
}

for (rtp_seq_t seq = d.min_seq; !d.parked_payloads.empty(); ++seq) {
opus_int16 pcm[23040];
iter_bench e("voice_courier_loop SEQ_LOOP_");
std::cout << "TOP_SEQ: " << d.parked_payloads.top().seq << "\nCURRENT_SEQ: "<< seq << "\n";

if (d.parked_payloads.top().seq != seq) {
/*
* Lost a packet with sequence number "seq",
* But Opus decoder might be able to guess something.
*/
if (int samples = opus_decode(d.decoder.get(), nullptr, 0, pcm, 5760, 0);
if (int samples = opus_decode(d.decoder.get(), nullptr, 0, flush_data_pcm, 5760, 0);
samples >= 0) {
/*
* Since this sample comes from a lost packet,
* we can only pretend there is an event, without any raw payload byte.
*/
voice_receive_t vr(nullptr, "", &client, d.user_id, reinterpret_cast<uint8_t *>(pcm),
samples * opus_channel_count * sizeof(opus_int16));
voice_receive_t vr(nullptr, "", &client, d.user_id, reinterpret_cast<uint8_t *>(flush_data_pcm),
samples * opus_channel_count * sizeof(opus_int16));

park_count = audio_mix(client, *client.mixer, pcm_mix, pcm, park_count, samples, max_samples);
park_count = audio_mix(client, *client.mixer, pcm_mix, flush_data_pcm, park_count, samples, max_samples);
client.creator->on_voice_receive.call(vr);
}
} else {
voice_receive_t &vr = *d.parked_payloads.top().vr;
if (vr.audio_data.size() > 0x7FFFFFFF) {

/* We do decryption here to avoid blocking ssl_client from receiving more audio data */
constexpr size_t header_size = 12;

uint8_t *buffer = vr.audio_data.data();
int packet_size = vr.audio_data.size();

constexpr size_t nonce_size = sizeof(uint32_t);
/* Nonce is 4 byte at the end of payload with zero padding */
uint8_t nonce[24] = { 0 };
std::memcpy(nonce, buffer + packet_size - nonce_size, nonce_size);

/* Get the number of CSRC in header */
const size_t csrc_count = buffer[0] & 0b0000'1111;
/* Skip to the encrypted voice data */
const ptrdiff_t offset_to_data = header_size + sizeof(uint32_t) * csrc_count;
size_t total_header_len = offset_to_data;

uint8_t* ciphertext = buffer + offset_to_data;
size_t ciphertext_len = packet_size - offset_to_data - nonce_size;

size_t ext_len = 0;
if ([[maybe_unused]] const bool uses_extension = (buffer[0] >> 4) & 0b0001) {
/**
* Get the RTP Extensions size, we only get the size here because
* the extension itself is encrypted along with the opus packet
*/
{
uint16_t ext_len_in_words;
memcpy(&ext_len_in_words, &ciphertext[2], sizeof(uint16_t));
ext_len_in_words = ntohs(ext_len_in_words);
ext_len = sizeof(uint32_t) * ext_len_in_words;
}
constexpr size_t ext_header_len = sizeof(uint16_t) * 2;
ciphertext += ext_header_len;
ciphertext_len -= ext_header_len;
total_header_len += ext_header_len;
}

uint8_t decrypted[65535] = { 0 };
unsigned long long opus_packet_len = 0;
if (ssl_crypto_aead_xchacha20poly1305_ietf_decrypt(
decrypted, &opus_packet_len,
nullptr,
ciphertext, ciphertext_len,
buffer,
/**
* Additional Data:
* The whole header (including csrc list) +
* 4 byte extension header (magic 0xBEDE + 16-bit denoting extension length)
*/
total_header_len,
nonce, vr.voice_client->secret_key.data()) != 0) {
/* Invalid Discord RTP payload. */
return;
}

uint8_t *opus_packet = decrypted;
if (ext_len > 0) {
/* Skip previously encrypted RTP Header Extension */
opus_packet += ext_len;
opus_packet_len -= ext_len;
}

/**
* If DAVE is enabled, use the user's ratchet to decrypt the OPUS audio data
*/
std::vector<uint8_t> frame;
if (vr.voice_client->is_end_to_end_encrypted()) {
auto decryptor = vr.voice_client->mls_state->decryptors.find(vr.user_id);
if (decryptor != vr.voice_client->mls_state->decryptors.end()) {
frame.resize(decryptor->second->get_max_plaintext_byte_size(dave::media_type::media_audio, opus_packet_len));
size_t enc_len = decryptor->second->decrypt(
dave::media_type::media_audio,
dave::make_array_view<const uint8_t>(opus_packet, opus_packet_len),
dave::make_array_view(frame)
);
if (enc_len > 0) {
opus_packet = frame.data();
opus_packet_len = enc_len;
}
}
}

if (opus_packet_len > 0x7FFFFFFF) {
throw dpp::length_exception(err_massive_audio, "audio_data > 2GB! This should never happen!");
}
if (samples = opus_decode(d.decoder.get(), vr.audio_data.data(),
static_cast<opus_int32>(vr.audio_data.size() & 0x7FFFFFFF), pcm, 5760, 0);
if (samples = opus_decode(d.decoder.get(), opus_packet,
static_cast<opus_int32>(opus_packet_len & 0x7FFFFFFF), flush_data_pcm, 5760, 0);
samples >= 0) {
vr.reassign(&client, d.user_id, reinterpret_cast<uint8_t *>(pcm),
samples * opus_channel_count * sizeof(opus_int16));
vr.reassign(&client, d.user_id, reinterpret_cast<uint8_t *>(flush_data_pcm),
samples * opus_channel_count * sizeof(opus_int16));
client.end_gain = 1.0f / client.moving_average;
park_count = audio_mix(client, *client.mixer, pcm_mix, pcm, park_count, samples, max_samples);
park_count = audio_mix(client, *client.mixer, pcm_mix, flush_data_pcm, park_count, samples, max_samples);
client.creator->on_voice_receive.call(vr);
}

d.parked_payloads.pop();
}
}
}

/* If combined receive is bound, dispatch it */
if (park_count) {

Expand All @@ -178,7 +301,7 @@ void discord_voice_client::voice_courier_loop(discord_voice_client& client, cour
}

voice_receive_t vr(nullptr, "", &client, 0, reinterpret_cast<uint8_t *>(pcm_downsample),
max_samples * opus_channel_count * sizeof(opus_int16));
max_samples * opus_channel_count * sizeof(opus_int16));

client.creator->on_voice_receive_combined.call(vr);
}
Expand Down
Loading

0 comments on commit c1a2f0d

Please sign in to comment.