Skip to content

Commit

Permalink
Socket processing improvements.
Browse files Browse the repository at this point in the history
1. Add SimpleSocketSender friend class for SimpleSocket for avoiding concurrent state modifications during racing receive and send operations.
2. Don't log warning if socket shutdown initiated by client.
  • Loading branch information
mirasrael committed Jan 2, 2024
1 parent 6c204c7 commit d200db2
Show file tree
Hide file tree
Showing 11 changed files with 202 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ void ByteBufferAsyncProcessor::ThreadProc()
return;
}

while (data.empty() && queue.empty() || interrupt_balance != 0)
while ((data.empty() && queue.empty()) || interrupt_balance != 0)
{
if (state >= StateKind::Stopping)
{
Expand Down
29 changes: 18 additions & 11 deletions rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <SimpleSocket.h>
#include <ActiveSocket.h>
#include <PassiveSocket.h>
#include <SimpleSocketSender.h>

#include <utility>
#include <thread>
Expand Down Expand Up @@ -81,16 +82,16 @@ bool SocketWire::Base::send0(Buffer::ByteArray const& msg, sequence_number_t seq
send_package_header.write_integral(seqn);

RD_ASSERT_THROW_MSG(
socket_provider->Send(send_package_header.data(), send_package_header.get_position()) == PACKAGE_HEADER_LENGTH,
socket_sender->Send(send_package_header.data(), send_package_header.get_position()) == PACKAGE_HEADER_LENGTH,
this->id +
": failed to send header over the network"
", reason: " +
socket_provider->DescribeError())
socket_sender->DescribeError())

RD_ASSERT_THROW_MSG(socket_provider->Send(msg.data(), msglen) == msglen, this->id +
RD_ASSERT_THROW_MSG(socket_sender->Send(msg.data(), msglen) == msglen, this->id +
": failed to send package over the network"
", reason: " +
socket_provider->DescribeError());
socket_sender->DescribeError());
logger->info("{}: were sent {} bytes", this->id, msglen);
// RD_ASSERT_MSG(socketProvider->Flush(), "{}: failed to flush");
return true;
Expand Down Expand Up @@ -126,6 +127,7 @@ void SocketWire::Base::set_socket_provider(std::shared_ptr<CActiveSocket> new_so
{
std::lock_guard<decltype(socket_send_lock)> guard(socket_send_lock);
socket_provider = std::move(new_socket);
socket_sender = std::make_unique<CSimpleSocketSender>(socket_provider);
socket_send_var.notify_all();
}
{
Expand All @@ -136,8 +138,8 @@ void SocketWire::Base::set_socket_provider(std::shared_ptr<CActiveSocket> new_so
}
}

auto heartbeat = LifetimeDefinition::use([this](Lifetime heartbeatLifetime) {
const auto heartbeat = start_heartbeat(heartbeatLifetime).share();
const auto heartbeat = LifetimeDefinition::use([this](Lifetime heartbeatLifetime) {
const auto heartbeat = start_heartbeat(std::move(heartbeatLifetime)).share();

async_send_buffer.resume();

Expand All @@ -159,6 +161,11 @@ void SocketWire::Base::set_socket_provider(std::shared_ptr<CActiveSocket> new_so
{
logger->debug("{}: socket was already shut down", this->id);
}
else if (socket_provider->GetSocketError() == CSimpleSocket::SocketNotconnected)
{
logger->debug("{}: socket not connected (shutdown likely was initiated by client)");
socket_provider->Close();
}
else if (!socket_provider->Shutdown(CSimpleSocket::Both))
{
// double close?
Expand Down Expand Up @@ -393,14 +400,14 @@ void SocketWire::Base::ping() const
ping_pkg_header.write_integral(counterpart_timestamp);
{
std::lock_guard<decltype(socket_send_lock)> guard(socket_send_lock);
int32_t sent = socket_provider->Send(ping_pkg_header.data(), ping_pkg_header.get_position());
if (sent == 0 && !socket_provider->IsSocketValid())
int32_t sent = socket_sender->Send(ping_pkg_header.data(), ping_pkg_header.get_position());
if (sent == 0 && !socket_sender->IsSocketValid())
{
logger->debug("{}: failed to send ping over the network, reason: socket was shut down for sending", this->id);
return;
}
RD_ASSERT_THROW_MSG(sent == PACKAGE_HEADER_LENGTH,
fmt::format("{}: failed to send ping over the network, reason: {}", this->id, socket_provider->DescribeError()))
fmt::format("{}: failed to send ping over the network, reason: {}", this->id, socket_sender->DescribeError()))
}

++current_timestamp;
Expand All @@ -421,11 +428,11 @@ bool SocketWire::Base::send_ack(sequence_number_t seqn) const
ack_buffer.write_integral(seqn);
{
std::lock_guard<decltype(socket_send_lock)> guard(socket_send_lock);
RD_ASSERT_THROW_MSG(socket_provider->Send(ack_buffer.data(), ack_buffer.get_position()) == PACKAGE_HEADER_LENGTH,
RD_ASSERT_THROW_MSG(socket_sender->Send(ack_buffer.data(), ack_buffer.get_position()) == PACKAGE_HEADER_LENGTH,
this->id +
": failed to send ack over the network"
", reason: " +
socket_provider->DescribeError())
socket_sender->DescribeError())
}
return true;
}
Expand Down
3 changes: 3 additions & 0 deletions rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
class CSimpleSocket;
class CActiveSocket;
class CPassiveSocket;
class CSimpleSocketSender;

namespace rd
{
Expand All @@ -37,6 +38,8 @@ class RD_FRAMEWORK_API SocketWire
std::string id;
IScheduler* scheduler = nullptr;
std::shared_ptr<CSimpleSocket> socket_provider;
// we do use separate sender for socket_provider to avoid concurrent state modifications during contesting receive and send operations
std::unique_ptr<CSimpleSocketSender> socket_sender;

std::shared_ptr<CActiveSocket> socket;

Expand Down
2 changes: 2 additions & 0 deletions rd-cpp/thirdparty/clsocket/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ SET(CLSOCKET_HEADERS
src/PassiveSocket.h
src/SimpleSocket.h
src/StatTimer.h
src/SimpleSocketSender.h
)

SET(CLSOCKET_SOURCES
src/SimpleSocket.cpp
src/ActiveSocket.cpp
src/PassiveSocket.cpp
src/SimpleSocketSender.cpp
)

# mark headers as headers...
Expand Down
15 changes: 3 additions & 12 deletions rd-cpp/thirdparty/clsocket/src/ActiveSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ bool CActiveSocket::ConnectTCP(const char *pAddr, uint16_t nPort)
// Connect to address "xxx.xxx.xxx.xxx" (IPv4) address only.
//
//------------------------------------------------------------------
m_timer.Initialize();
m_timer.SetStartTime();
CStatTimerCookie timer_cookie(timer);

if (connect(m_socket, (struct sockaddr*)&m_stServerSockaddr, sizeof(m_stServerSockaddr)) ==
CSimpleSocket::SocketError)
Expand Down Expand Up @@ -121,8 +120,6 @@ bool CActiveSocket::ConnectTCP(const char *pAddr, uint16_t nPort)
bRetVal = true;
}

m_timer.SetEndTime();

return bRetVal;
}

Expand Down Expand Up @@ -170,8 +167,7 @@ bool CActiveSocket::ConnectUDP(const char *pAddr, uint16_t nPort)
// Connect to address "xxx.xxx.xxx.xxx" (IPv4) address only.
//
//------------------------------------------------------------------
m_timer.Initialize();
m_timer.SetStartTime();
CStatTimerCookie timer_cookie(timer);

if (connect(m_socket, (struct sockaddr*)&m_stServerSockaddr, sizeof(m_stServerSockaddr)) != CSimpleSocket::SocketError)
{
Expand All @@ -180,8 +176,6 @@ bool CActiveSocket::ConnectUDP(const char *pAddr, uint16_t nPort)

TranslateSocketError();

m_timer.SetEndTime();

return bRetVal;
}

Expand Down Expand Up @@ -228,8 +222,7 @@ bool CActiveSocket::ConnectRAW(const char *pAddr, uint16_t nPort)
// Connect to address "xxx.xxx.xxx.xxx" (IPv4) address only.
//
//------------------------------------------------------------------
m_timer.Initialize();
m_timer.SetStartTime();
CStatTimerCookie timer_cookie(timer);

if (connect(m_socket, (struct sockaddr*)&m_stServerSockaddr, sizeof(m_stServerSockaddr)) != CSimpleSocket::SocketError)
{
Expand All @@ -238,8 +231,6 @@ bool CActiveSocket::ConnectRAW(const char *pAddr, uint16_t nPort)

TranslateSocketError();

m_timer.SetEndTime();

return bRetVal;
}

Expand Down
92 changes: 41 additions & 51 deletions rd-cpp/thirdparty/clsocket/src/PassiveSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,22 @@ bool CPassiveSocket::BindMulticast(const char *pInterface, const char *pGroup, u
//--------------------------------------------------------------------------
// Bind to the specified port
//--------------------------------------------------------------------------
if (bind(m_socket, (struct sockaddr *) &m_stMulticastGroup, sizeof(m_stMulticastGroup)) == 0) {
//----------------------------------------------------------------------
// Join the multicast group
//----------------------------------------------------------------------
m_stMulticastRequest.imr_multiaddr.s_addr = inet_addr(pGroup);
m_stMulticastRequest.imr_interface.s_addr = m_stMulticastGroup.sin_addr.s_addr;

if (SETSOCKOPT(m_socket, IPPROTO_IP, IP_ADD_MEMBERSHIP,
(void *) &m_stMulticastRequest,
sizeof(m_stMulticastRequest)) == CSimpleSocket::SocketSuccess) {
bRetVal = true;
}

m_timer.SetEndTime();
}

m_timer.Initialize();
m_timer.SetStartTime();

{
CStatTimerCookie timer_cookie(timer);
if (bind(m_socket, (struct sockaddr *) &m_stMulticastGroup, sizeof(m_stMulticastGroup)) == 0) {
//----------------------------------------------------------------------
// Join the multicast group
//----------------------------------------------------------------------
m_stMulticastRequest.imr_multiaddr.s_addr = inet_addr(pGroup);
m_stMulticastRequest.imr_interface.s_addr = m_stMulticastGroup.sin_addr.s_addr;

if (SETSOCKOPT(m_socket, IPPROTO_IP, IP_ADD_MEMBERSHIP,
(void *) &m_stMulticastRequest,
sizeof(m_stMulticastRequest)) == CSimpleSocket::SocketSuccess) {
bRetVal = true;
}
}
}

//--------------------------------------------------------------------------
// If there was a new_socket error then close the new_socket to clean out the
Expand Down Expand Up @@ -152,29 +149,28 @@ bool CPassiveSocket::Listen(const char *pAddr, uint16_t nPort, int32_t nConnecti
}
}

m_timer.Initialize();
m_timer.SetStartTime();

//--------------------------------------------------------------------------
// Bind to the specified port
//--------------------------------------------------------------------------
if (bind(m_socket, (struct sockaddr *) &m_stServerSockaddr, sizeof(m_stServerSockaddr)) !=
CSimpleSocket::SocketError) {
socklen_t namelen = sizeof(m_stServerSockaddr);
if (getsockname(m_socket, (struct sockaddr *) &m_stServerSockaddr, &namelen) != CSimpleSocket::SocketError) {
if (m_nSocketType == CSimpleSocket::SocketTypeTcp) {
if (listen(m_socket, nConnectionBacklog) != CSimpleSocket::SocketError) {
bRetVal = true;
}
} else {
bRetVal = true;
}
} else {
bRetVal = false;
}
}

m_timer.SetEndTime();
{
CStatTimerCookie timer_cookie(timer);

//--------------------------------------------------------------------------
// Bind to the specified port
//--------------------------------------------------------------------------
if (bind(m_socket, (struct sockaddr *) &m_stServerSockaddr, sizeof(m_stServerSockaddr)) !=
CSimpleSocket::SocketError) {
socklen_t namelen = sizeof(m_stServerSockaddr);
if (getsockname(m_socket, (struct sockaddr *) &m_stServerSockaddr, &namelen) != CSimpleSocket::SocketError) {
if (m_nSocketType == CSimpleSocket::SocketTypeTcp) {
if (listen(m_socket, nConnectionBacklog) != CSimpleSocket::SocketError) {
bRetVal = true;
}
} else {
bRetVal = true;
}
} else {
bRetVal = false;
}
}
}

//--------------------------------------------------------------------------
// If there was a new_socket error then close the new_socket to clean out the
Expand Down Expand Up @@ -213,10 +209,9 @@ CActiveSocket *CPassiveSocket::Accept() {
// Wait for incoming connection.
//--------------------------------------------------------------------------
if (pClientSocket != NULL) {
CSocketError socketErrno = SocketSuccess;
CSocketError socketErrno;

m_timer.Initialize();
m_timer.SetStartTime();
CStatTimerCookie timer_cookie(timer);

nClientSockLen = sizeof(m_stClientSockaddr);

Expand Down Expand Up @@ -246,8 +241,6 @@ CActiveSocket *CPassiveSocket::Accept() {

} while (socketErrno == CSimpleSocket::SocketInterrupted);

m_timer.SetEndTime();

if (socketErrno != CSimpleSocket::SocketSuccess) {
delete pClientSocket;
pClientSocket = NULL;
Expand All @@ -271,14 +264,11 @@ int32_t CPassiveSocket::Send(const uint8_t *pBuf, size_t bytesToSend) {
case CSimpleSocket::SocketTypeUdp: {
if (IsSocketValid()) {
if ((bytesToSend > 0) && (pBuf != NULL)) {
m_timer.Initialize();
m_timer.SetStartTime();
CStatTimerCookie timer_cookie(timer);

m_nBytesSent = static_cast<int32_t>(SENDTO(m_socket, pBuf, bytesToSend, 0,
reinterpret_cast<const sockaddr*>(&m_stClientSockaddr), sizeof(m_stClientSockaddr)));

m_timer.SetEndTime();

if (m_nBytesSent == CSimpleSocket::SocketError) {
TranslateSocketError();
}
Expand Down
Loading

0 comments on commit d200db2

Please sign in to comment.