diff --git a/include/dpp/httpsclient.h b/include/dpp/httpsclient.h index 0090c68af4..eec9c7b257 100644 --- a/include/dpp/httpsclient.h +++ b/include/dpp/httpsclient.h @@ -32,7 +32,7 @@ namespace dpp { -static inline const std::string http_version = "DiscordBot (https://github.com/brainboxdotcc/DPP, " +static inline constexpr std::string http_version = "DiscordBot (https://github.com/brainboxdotcc/DPP, " + to_hex(DPP_VERSION_MAJOR, false) + "." + to_hex(DPP_VERSION_MINOR, false) + "." + to_hex(DPP_VERSION_PATCH, false) + ")"; diff --git a/include/dpp/wsclient.h b/include/dpp/wsclient.h index b1a1b1870d..18f483dcb6 100644 --- a/include/dpp/wsclient.h +++ b/include/dpp/wsclient.h @@ -151,11 +151,10 @@ class DPP_EXPORT websocket_client : public ssl_client { size_t fill_header(unsigned char* outbuf, size_t sendlength, ws_opcode opcode); /** - * @brief Handle ping and pong requests. - * @param ping True if this is a ping, false if it is a pong - * @param payload The ping payload, to be returned as-is for a ping + * @brief Handle ping requests. + * @param payload The ping payload, to be returned as-is for a pong */ - void handle_ping_pong(bool ping, const std::string &payload); + void handle_ping(const std::string &payload); protected: diff --git a/src/dpp/sslclient.cpp b/src/dpp/sslclient.cpp index ee7f4d7af0..d3a191e1c8 100644 --- a/src/dpp/sslclient.cpp +++ b/src/dpp/sslclient.cpp @@ -73,7 +73,7 @@ #include /* Maximum allowed time in milliseconds for socket read/write timeouts and connect() */ -#define SOCKET_OP_TIMEOUT 5000 +constexpr uint16_t SOCKET_OP_TIMEOUT{5000}; namespace dpp { @@ -123,10 +123,10 @@ thread_local std::unordered_map keepalives; * SSL_read in non-blocking mode will only read 16k at a time. There's no point in a bigger buffer as * it'd go unused. */ -#define DPP_BUFSIZE 16 * 1024 +constexpr uint32_t DPP_BUFSIZE{16 * 1024}; /* Represents a failed socket system call, e.g. connect() failure */ -const int ERROR_STATUS = -1; +constexpr int ERROR_STATUS{-1}; bool close_socket(dpp::socket sfd) { @@ -197,25 +197,26 @@ int connect_with_timeout(dpp::socket sockfd, const struct sockaddr *addr, sockle #endif if (rc == -1 && err != EWOULDBLOCK && err != EINPROGRESS) { throw connection_exception(err_connect_failure, strerror(errno)); - } else { - /* Set a deadline timestamp 'timeout' ms from now */ - double deadline = utility::time_f() + (timeout_ms / 1000.0); - do { - rc = -1; - if (utility::time_f() >= deadline) { - throw connection_exception(err_connection_timed_out, "Connection timed out"); - } - pollfd pfd = {}; - pfd.fd = sockfd; - pfd.events = POLLOUT; - int r = ::poll(&pfd, 1, 10); - if (r > 0 && pfd.revents & POLLOUT) { - rc = 0; - } else if (r != 0 || pfd.revents & POLLERR) { - throw connection_exception(err_connection_timed_out, strerror(errno)); - } - } while (rc == -1); } + + /* Set a deadline timestamp 'timeout' ms from now */ + double deadline = utility::time_f() + (timeout_ms / 1000.0); + + do { + if (utility::time_f() >= deadline) { + throw connection_exception(err_connection_timed_out, "Connection timed out"); + } + pollfd pfd = {}; + pfd.fd = sockfd; + pfd.events = POLLOUT; + const int r = ::poll(&pfd, 1, 10); + if (r > 0 && pfd.revents & POLLOUT) { + rc = 0; + } else if (r != 0 || pfd.revents & POLLERR) { + throw connection_exception(err_connection_timed_out, strerror(errno)); + } + } while (rc == -1); + if (!set_nonblocking(sockfd, false)) { throw connection_exception(err_nonblocking_failure, "Can't switch socket to blocking mode!"); } @@ -502,16 +503,17 @@ void ssl_client::read_loop() read_blocked_on_write = false; read_blocked = false; r = (int) ::recv(sfd, server_to_client_buffer, DPP_BUFSIZE, 0); + if (r <= 0) { /* error or EOF */ return; - } else { - buffer.append(server_to_client_buffer, r); - if (!this->handle_buffer(buffer)) { - return; - } - bytes_in += r; } + + buffer.append(server_to_client_buffer, r); + if (!this->handle_buffer(buffer)) { + return; + } + bytes_in += r; } else { do { read_blocked_on_write = false; diff --git a/src/dpp/wsclient.cpp b/src/dpp/wsclient.cpp index 91d1281065..e5e6eae41f 100644 --- a/src/dpp/wsclient.cpp +++ b/src/dpp/wsclient.cpp @@ -127,50 +127,57 @@ void websocket_client::write(const std::string &data) bool websocket_client::handle_buffer(std::string &buffer) { - switch (state) { - case HTTP_HEADERS: - if (buffer.find("\r\n\r\n") != std::string::npos) { - /* Got all headers, proceed to new state */ - - /* Get headers string */ - std::string headers = buffer.substr(0, buffer.find("\r\n\r\n")); - - /* Modify buffer, remove headers section */ - buffer.erase(0, buffer.find("\r\n\r\n") + 4); - - /* Process headers into map */ - std::vector h = utility::tokenize(headers); - if (h.size()) { - std::string status_line = h[0]; - h.erase(h.begin()); - /* HTTP/1.1 101 Switching Protocols */ - std::vector status = utility::tokenize(status_line, " "); - if (status.size() >= 3 && status[1] == "101") { - for(auto &hd : h) { - std::string::size_type sep = hd.find(": "); - if (sep != std::string::npos) { - std::string key = hd.substr(0, sep); - std::string value = hd.substr(sep + 2, hd.length()); - http_headers[key] = value; - } - } - - state = CONNECTED; - } else if (status.size() < 3) { - log(ll_warning, "Malformed HTTP response on websocket"); - return false; - } else if (status[1] != "200" && status[1] != "204") { - log(ll_warning, "Received unhandled code: " + status[1]); - return false; - } + if (state == HTTP_HEADERS) { + /* We can expect Discord to end all packets with this. + * If they don't, something is wrong and we should abort. + */ + if (buffer.find("\r\n\r\n") == std::string::npos) { + return false; + } + + /* Got all headers, proceed to new state */ + + /* Get headers string */ + std::string headers = buffer.substr(0, buffer.find("\r\n\r\n")); + + /* Modify buffer, remove headers section */ + buffer.erase(0, buffer.find("\r\n\r\n") + 4); + + /* Process headers into map */ + std::vector h = utility::tokenize(headers); + + /* No headers? Something aint right. */ + if (h.empty()) { + return false; + } + + std::string status_line = h[0]; + h.erase(h.begin()); + std::vector status = utility::tokenize(status_line, " "); + /* HTTP/1.1 101 Switching Protocols */ + if (status.size() >= 3 && status[1] == "101") { + for(auto &hd : h) { + std::string::size_type sep = hd.find(": "); + if (sep != std::string::npos) { + std::string key = hd.substr(0, sep); + std::string value = hd.substr(sep + 2, hd.length()); + http_headers[key] = value; } } - break; - case CONNECTED: - /* Process packets until we can't */ - while (this->parseheader(buffer)); - break; + + state = CONNECTED; + } else if (status.size() < 3) { + log(ll_warning, "Malformed HTTP response on websocket"); + return false; + } else if (status[1] != "200" && status[1] != "204") { + log(ll_warning, "Received unhandled code: " + status[1]); + return false; + } + } else if (state == CONNECTED) { + /* Process packets until we can't */ + while (this->parseheader(buffer)); } + return true; } @@ -184,88 +191,89 @@ bool websocket_client::parseheader(std::string &data) if (data.size() < 4) { /* Not enough data to form a frame yet */ return false; - } else { - unsigned char opcode = data[0]; - switch (opcode & ~WS_FINBIT) { - case OP_CONTINUATION: - case OP_TEXT: - case OP_BINARY: - case OP_PING: - case OP_PONG: { - unsigned char len1 = data[1]; - unsigned int payloadstartoffset = 2; - - if (len1 & WS_MASKBIT) { - len1 &= ~WS_MASKBIT; - payloadstartoffset += 2; - /* We don't handle masked data, because discord doesn't send it */ - return true; - } + } - /* 6 bit ("small") length frame */ - uint64_t len = len1; - - if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE) { - /* 24 bit ("large") length frame */ - if (data.length() < 8) { - /* We don't have a complete header yet */ - return false; - } - - unsigned char len2 = (unsigned char)data[2]; - unsigned char len3 = (unsigned char)data[3]; - len = (len2 << 8) | len3; - - payloadstartoffset += 2; - } else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE) { - /* 64 bit ("huge") length frame */ - if (data.length() < 10) { - /* We don't have a complete header yet */ - return false; - } - len = 0; - for (int v = 2, shift = 56; v < 10; ++v, shift -= 8) { - unsigned char l = (unsigned char)data[v]; - len |= (uint64_t)(l & 0xff) << shift; - } - payloadstartoffset += 8; - } + unsigned char opcode = data[0]; + switch (opcode & ~WS_FINBIT) { + case OP_CONTINUATION: + case OP_TEXT: + case OP_BINARY: + case OP_PING: + case OP_PONG: { + unsigned char len1 = data[1]; + unsigned int payloadstartoffset = 2; + + if (len1 & WS_MASKBIT) { + len1 &= ~WS_MASKBIT; + payloadstartoffset += 2; + /* We don't handle masked data, because discord doesn't send it */ + return true; + } - if (data.length() < payloadstartoffset + len) { - /* We don't have a complete frame yet */ - return false; - } + /* 6 bit ("small") length frame */ + uint64_t len = len1; - if ((opcode & ~WS_FINBIT) == OP_PING || (opcode & ~WS_FINBIT) == OP_PONG) { - handle_ping_pong((opcode & ~WS_FINBIT) == OP_PING, data.substr(payloadstartoffset, len)); - } else { - /* Pass this frame to the deriving class */ - this->handle_frame(data.substr(payloadstartoffset, len)); + if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE) { + /* 24 bit ("large") length frame */ + if (data.length() < 8) { + /* We don't have a complete header yet */ + return false; } - /* Remove this frame from the input buffer */ - data.erase(data.begin(), data.begin() + payloadstartoffset + len); + unsigned char len2 = (unsigned char)data[2]; + unsigned char len3 = (unsigned char)data[3]; + len = (len2 << 8) | len3; - return true; + payloadstartoffset += 2; + } else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE) { + /* 64 bit ("huge") length frame */ + if (data.length() < 10) { + /* We don't have a complete header yet */ + return false; + } + len = 0; + for (int v = 2, shift = 56; v < 10; ++v, shift -= 8) { + unsigned char l = (unsigned char)data[v]; + len |= (uint64_t)(l & 0xff) << shift; + } + payloadstartoffset += 8; } - break; - case OP_CLOSE: { - uint16_t error = data[2] & 0xff; - error <<= 8; - error |= (data[3] & 0xff); - this->error(error); + if (data.length() < payloadstartoffset + len) { + /* We don't have a complete frame yet */ return false; } - break; - default: { - this->error(0); - return false; + if ((opcode & ~WS_FINBIT) == OP_PING) { + handle_ping(data.substr(payloadstartoffset, len)); + } else if ((opcode & ~WS_FINBIT) != OP_PONG) { + /* Pass this frame to the deriving class */ + this->handle_frame(data.substr(payloadstartoffset, len)); } - break; + + /* Remove this frame from the input buffer */ + data.erase(data.begin(), data.begin() + payloadstartoffset + len); + + return true; + } + break; + + case OP_CLOSE: { + uint16_t error = data[2] & 0xff; + error <<= 8; + error |= (data[3] & 0xff); + this->error(error); + return false; + } + break; + + default: { + this->error(0); + return false; } + break; } + return false; } @@ -282,16 +290,14 @@ void websocket_client::one_second_timer() } } -void websocket_client::handle_ping_pong(bool ping, const std::string &payload) +void websocket_client::handle_ping(const std::string &payload) { - if (ping) { - /* For receiving pings we echo back their payload with the type OP_PONG */ - unsigned char out[MAXHEADERSIZE]; - size_t s = this->fill_header(out, payload.length(), OP_PONG); - std::string header((const char*)out, s); - ssl_client::write(header); - ssl_client::write(payload); - } + /* For receiving pings we echo back their payload with the type OP_PONG */ + unsigned char out[MAXHEADERSIZE]; + size_t s = this->fill_header(out, payload.length(), OP_PONG); + std::string header((const char*)out, s); + ssl_client::write(header); + ssl_client::write(payload); } void websocket_client::send_close_packet()