Skip to content

Commit

Permalink
refactor: readability changes to sslclient and wsclient
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaskowicz1 committed Jul 23, 2024
1 parent ced36fd commit 1aa25dc
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 150 deletions.
2 changes: 1 addition & 1 deletion include/dpp/httpsclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

namespace dpp {

static inline const std::string http_version = "DiscordBot (https://github.com/brainboxdotcc/DPP, "
static inline constexpr std::string http_version = "DiscordBot (https://github.com/brainboxdotcc/DPP, "
+ to_hex(DPP_VERSION_MAJOR, false) + "."
+ to_hex(DPP_VERSION_MINOR, false) + "."
+ to_hex(DPP_VERSION_PATCH, false) + ")";
Expand Down
7 changes: 3 additions & 4 deletions include/dpp/wsclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,10 @@ class DPP_EXPORT websocket_client : public ssl_client {
size_t fill_header(unsigned char* outbuf, size_t sendlength, ws_opcode opcode);

/**
* @brief Handle ping and pong requests.
* @param ping True if this is a ping, false if it is a pong
* @param payload The ping payload, to be returned as-is for a ping
* @brief Handle ping requests.
* @param payload The ping payload, to be returned as-is for a pong
*/
void handle_ping_pong(bool ping, const std::string &payload);
void handle_ping(const std::string &payload);

protected:

Expand Down
56 changes: 29 additions & 27 deletions src/dpp/sslclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
#include <dpp/dns.h>

/* Maximum allowed time in milliseconds for socket read/write timeouts and connect() */
#define SOCKET_OP_TIMEOUT 5000
constexpr uint16_t SOCKET_OP_TIMEOUT{5000};

namespace dpp {

Expand Down Expand Up @@ -123,10 +123,10 @@ thread_local std::unordered_map<std::string, keepalive_cache_t> keepalives;
* SSL_read in non-blocking mode will only read 16k at a time. There's no point in a bigger buffer as
* it'd go unused.
*/
#define DPP_BUFSIZE 16 * 1024
constexpr uint32_t DPP_BUFSIZE{16 * 1024};

/* Represents a failed socket system call, e.g. connect() failure */
const int ERROR_STATUS = -1;
constexpr int ERROR_STATUS{-1};

bool close_socket(dpp::socket sfd)
{
Expand Down Expand Up @@ -197,25 +197,26 @@ int connect_with_timeout(dpp::socket sockfd, const struct sockaddr *addr, sockle
#endif
if (rc == -1 && err != EWOULDBLOCK && err != EINPROGRESS) {
throw connection_exception(err_connect_failure, strerror(errno));
} else {
/* Set a deadline timestamp 'timeout' ms from now */
double deadline = utility::time_f() + (timeout_ms / 1000.0);
do {
rc = -1;
if (utility::time_f() >= deadline) {
throw connection_exception(err_connection_timed_out, "Connection timed out");
}
pollfd pfd = {};
pfd.fd = sockfd;
pfd.events = POLLOUT;
int r = ::poll(&pfd, 1, 10);
if (r > 0 && pfd.revents & POLLOUT) {
rc = 0;
} else if (r != 0 || pfd.revents & POLLERR) {
throw connection_exception(err_connection_timed_out, strerror(errno));
}
} while (rc == -1);
}

/* Set a deadline timestamp 'timeout' ms from now */
double deadline = utility::time_f() + (timeout_ms / 1000.0);

do {
if (utility::time_f() >= deadline) {
throw connection_exception(err_connection_timed_out, "Connection timed out");
}
pollfd pfd = {};
pfd.fd = sockfd;
pfd.events = POLLOUT;
const int r = ::poll(&pfd, 1, 10);
if (r > 0 && pfd.revents & POLLOUT) {
rc = 0;
} else if (r != 0 || pfd.revents & POLLERR) {
throw connection_exception(err_connection_timed_out, strerror(errno));
}
} while (rc == -1);

if (!set_nonblocking(sockfd, false)) {
throw connection_exception(err_nonblocking_failure, "Can't switch socket to blocking mode!");
}
Expand Down Expand Up @@ -502,16 +503,17 @@ void ssl_client::read_loop()
read_blocked_on_write = false;
read_blocked = false;
r = (int) ::recv(sfd, server_to_client_buffer, DPP_BUFSIZE, 0);

if (r <= 0) {
/* error or EOF */
return;
} else {
buffer.append(server_to_client_buffer, r);
if (!this->handle_buffer(buffer)) {
return;
}
bytes_in += r;
}

buffer.append(server_to_client_buffer, r);
if (!this->handle_buffer(buffer)) {
return;
}
bytes_in += r;
} else {
do {
read_blocked_on_write = false;
Expand Down
242 changes: 124 additions & 118 deletions src/dpp/wsclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,50 +127,57 @@ void websocket_client::write(const std::string &data)

bool websocket_client::handle_buffer(std::string &buffer)
{
switch (state) {
case HTTP_HEADERS:
if (buffer.find("\r\n\r\n") != std::string::npos) {
/* Got all headers, proceed to new state */

/* Get headers string */
std::string headers = buffer.substr(0, buffer.find("\r\n\r\n"));

/* Modify buffer, remove headers section */
buffer.erase(0, buffer.find("\r\n\r\n") + 4);

/* Process headers into map */
std::vector<std::string> h = utility::tokenize(headers);
if (h.size()) {
std::string status_line = h[0];
h.erase(h.begin());
/* HTTP/1.1 101 Switching Protocols */
std::vector<std::string> status = utility::tokenize(status_line, " ");
if (status.size() >= 3 && status[1] == "101") {
for(auto &hd : h) {
std::string::size_type sep = hd.find(": ");
if (sep != std::string::npos) {
std::string key = hd.substr(0, sep);
std::string value = hd.substr(sep + 2, hd.length());
http_headers[key] = value;
}
}

state = CONNECTED;
} else if (status.size() < 3) {
log(ll_warning, "Malformed HTTP response on websocket");
return false;
} else if (status[1] != "200" && status[1] != "204") {
log(ll_warning, "Received unhandled code: " + status[1]);
return false;
}
if (state == HTTP_HEADERS) {
/* We can expect Discord to end all packets with this.
* If they don't, something is wrong and we should abort.
*/
if (buffer.find("\r\n\r\n") == std::string::npos) {
return false;
}

/* Got all headers, proceed to new state */

/* Get headers string */
std::string headers = buffer.substr(0, buffer.find("\r\n\r\n"));

/* Modify buffer, remove headers section */
buffer.erase(0, buffer.find("\r\n\r\n") + 4);

/* Process headers into map */
std::vector<std::string> h = utility::tokenize(headers);

/* No headers? Something aint right. */
if (h.empty()) {
return false;
}

std::string status_line = h[0];
h.erase(h.begin());
std::vector<std::string> status = utility::tokenize(status_line, " ");
/* HTTP/1.1 101 Switching Protocols */
if (status.size() >= 3 && status[1] == "101") {
for(auto &hd : h) {
std::string::size_type sep = hd.find(": ");
if (sep != std::string::npos) {
std::string key = hd.substr(0, sep);
std::string value = hd.substr(sep + 2, hd.length());
http_headers[key] = value;
}
}
break;
case CONNECTED:
/* Process packets until we can't */
while (this->parseheader(buffer));
break;

state = CONNECTED;
} else if (status.size() < 3) {
log(ll_warning, "Malformed HTTP response on websocket");
return false;
} else if (status[1] != "200" && status[1] != "204") {
log(ll_warning, "Received unhandled code: " + status[1]);
return false;
}
} else if (state == CONNECTED) {
/* Process packets until we can't */
while (this->parseheader(buffer));
}

return true;
}

Expand All @@ -184,88 +191,89 @@ bool websocket_client::parseheader(std::string &data)
if (data.size() < 4) {
/* Not enough data to form a frame yet */
return false;
} else {
unsigned char opcode = data[0];
switch (opcode & ~WS_FINBIT) {
case OP_CONTINUATION:
case OP_TEXT:
case OP_BINARY:
case OP_PING:
case OP_PONG: {
unsigned char len1 = data[1];
unsigned int payloadstartoffset = 2;

if (len1 & WS_MASKBIT) {
len1 &= ~WS_MASKBIT;
payloadstartoffset += 2;
/* We don't handle masked data, because discord doesn't send it */
return true;
}
}

/* 6 bit ("small") length frame */
uint64_t len = len1;

if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE) {
/* 24 bit ("large") length frame */
if (data.length() < 8) {
/* We don't have a complete header yet */
return false;
}

unsigned char len2 = (unsigned char)data[2];
unsigned char len3 = (unsigned char)data[3];
len = (len2 << 8) | len3;

payloadstartoffset += 2;
} else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE) {
/* 64 bit ("huge") length frame */
if (data.length() < 10) {
/* We don't have a complete header yet */
return false;
}
len = 0;
for (int v = 2, shift = 56; v < 10; ++v, shift -= 8) {
unsigned char l = (unsigned char)data[v];
len |= (uint64_t)(l & 0xff) << shift;
}
payloadstartoffset += 8;
}
unsigned char opcode = data[0];
switch (opcode & ~WS_FINBIT) {
case OP_CONTINUATION:
case OP_TEXT:
case OP_BINARY:
case OP_PING:
case OP_PONG: {
unsigned char len1 = data[1];
unsigned int payloadstartoffset = 2;

if (len1 & WS_MASKBIT) {
len1 &= ~WS_MASKBIT;
payloadstartoffset += 2;
/* We don't handle masked data, because discord doesn't send it */
return true;
}

if (data.length() < payloadstartoffset + len) {
/* We don't have a complete frame yet */
return false;
}
/* 6 bit ("small") length frame */
uint64_t len = len1;

if ((opcode & ~WS_FINBIT) == OP_PING || (opcode & ~WS_FINBIT) == OP_PONG) {
handle_ping_pong((opcode & ~WS_FINBIT) == OP_PING, data.substr(payloadstartoffset, len));
} else {
/* Pass this frame to the deriving class */
this->handle_frame(data.substr(payloadstartoffset, len));
if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE) {
/* 24 bit ("large") length frame */
if (data.length() < 8) {
/* We don't have a complete header yet */
return false;
}

/* Remove this frame from the input buffer */
data.erase(data.begin(), data.begin() + payloadstartoffset + len);
unsigned char len2 = (unsigned char)data[2];
unsigned char len3 = (unsigned char)data[3];
len = (len2 << 8) | len3;

return true;
payloadstartoffset += 2;
} else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE) {
/* 64 bit ("huge") length frame */
if (data.length() < 10) {
/* We don't have a complete header yet */
return false;
}
len = 0;
for (int v = 2, shift = 56; v < 10; ++v, shift -= 8) {
unsigned char l = (unsigned char)data[v];
len |= (uint64_t)(l & 0xff) << shift;
}
payloadstartoffset += 8;
}
break;

case OP_CLOSE: {
uint16_t error = data[2] & 0xff;
error <<= 8;
error |= (data[3] & 0xff);
this->error(error);
if (data.length() < payloadstartoffset + len) {
/* We don't have a complete frame yet */
return false;
}
break;

default: {
this->error(0);
return false;
if ((opcode & ~WS_FINBIT) == OP_PING) {
handle_ping(data.substr(payloadstartoffset, len));
} else if ((opcode & ~WS_FINBIT) != OP_PONG) {
/* Pass this frame to the deriving class */
this->handle_frame(data.substr(payloadstartoffset, len));
}
break;

/* Remove this frame from the input buffer */
data.erase(data.begin(), data.begin() + payloadstartoffset + len);

return true;
}
break;

case OP_CLOSE: {
uint16_t error = data[2] & 0xff;
error <<= 8;
error |= (data[3] & 0xff);
this->error(error);
return false;
}
break;

default: {
this->error(0);
return false;
}
break;
}

return false;
}

Expand All @@ -282,16 +290,14 @@ void websocket_client::one_second_timer()
}
}

void websocket_client::handle_ping_pong(bool ping, const std::string &payload)
void websocket_client::handle_ping(const std::string &payload)
{
if (ping) {
/* For receiving pings we echo back their payload with the type OP_PONG */
unsigned char out[MAXHEADERSIZE];
size_t s = this->fill_header(out, payload.length(), OP_PONG);
std::string header((const char*)out, s);
ssl_client::write(header);
ssl_client::write(payload);
}
/* For receiving pings we echo back their payload with the type OP_PONG */
unsigned char out[MAXHEADERSIZE];
size_t s = this->fill_header(out, payload.length(), OP_PONG);
std::string header((const char*)out, s);
ssl_client::write(header);
ssl_client::write(payload);
}

void websocket_client::send_close_packet()
Expand Down

0 comments on commit 1aa25dc

Please sign in to comment.