diff --git a/include/ur_client_library/comm/socket_t.h b/include/ur_client_library/comm/socket_t.h index 79d2358c..8a7d61f7 100644 --- a/include/ur_client_library/comm/socket_t.h +++ b/include/ur_client_library/comm/socket_t.h @@ -16,6 +16,7 @@ #pragma once +#include #ifdef _WIN32 # define NOMINMAX @@ -33,6 +34,7 @@ typedef SOCKET socket_t; typedef SSIZE_T ssize_t; +typedef int socklen_t; static inline int ur_setsockopt(socket_t s, int level, int optname, const void* optval, unsigned int optlen) { @@ -64,3 +66,29 @@ typedef int socket_t; # define ur_close close #endif // _WIN32 + +#ifndef MSG_NOSIGNAL +# define MSG_NOSIGNAL 0 +#endif + +/*! + * \brief Get the last socket error as an std::error_code + * + * On Windows, this will use WSAGetLastError and the system category, while on other platforms it + * will use errno and the generic category. + * + * \return The last socket error + */ +inline std::error_code getLastSocketErrorCode() +{ +#ifdef _WIN32 + return std::error_code(WSAGetLastError(), std::system_category()); +#else + return std::error_code(errno, std::generic_category()); +#endif +} + +inline std::system_error makeSocketError(const std::string& message) +{ + return std::system_error(getLastSocketErrorCode(), message); +} diff --git a/include/ur_client_library/comm/tcp_server.h b/include/ur_client_library/comm/tcp_server.h index 8f212883..465e7907 100644 --- a/include/ur_client_library/comm/tcp_server.h +++ b/include/ur_client_library/comm/tcp_server.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -79,9 +80,13 @@ class TCPServer * * \param func Function handling the event information. The file descriptor created by the * connection event will be passed to the function. + * + * \note: The connection callback will be triggered with the socket being accepted. Hence, it + * is possible to send data from the connection callback directly. */ void setConnectCallback(std::function func) { + std::lock_guard lk(callback_mutex_); new_connection_callback_ = func; } @@ -90,9 +95,13 @@ class TCPServer * * \param func Function handling the event information. The file descriptor created by the * connection event will be passed to the function. + * + * \note: The socket will already be closed when the disconnect callback is triggered, thus + * trying to interact with the socket from the disconnect callback will fail. */ void setDisconnectCallback(std::function func) { + std::lock_guard lk(callback_mutex_); disconnect_callback_ = func; } @@ -104,6 +113,7 @@ class TCPServer */ void setMessageCallback(std::function func) { + std::lock_guard lk(message_mutex_); message_callback_ = func; } @@ -116,9 +126,11 @@ class TCPServer void start(); /*! - * \brief Shut down the event listener thread. After calling this, no events will be handled - * anymore, but the socket will remain open and bound to the port. Call start() in order to - * restart event handling. + * \brief Shutdown the server and close all client connections. + * + * \note: This should not be called from within any of the registered callback functions, as + * it will cause a deadlock. If you want to shutdown the server from a callback, you can e.g. + * start a new thread that calls shutdown() from there. */ void shutdown(); @@ -135,6 +147,21 @@ class TCPServer */ bool write(const socket_t fd, const uint8_t* buf, const size_t buf_len, size_t& written); + /*! + * \brief Writes to a filedescriptor without verifying that it is a client or even a valid + * filedescriptor. It is the caller's responsibility to ensure that the filedescriptor is valid + * and belongs to a client. + * + * \param[in] fd File descriptor belonging to the client the data should be sent to. The file + * descriptor will be given from the connection callback. + * \param[in] buf Buffer of bytes to write + * \param[in] buf_len Number of bytes in the buffer + * \param[out] written Number of bytes actually written + * + * \returns True on success, false otherwise + */ + bool writeUnchecked(const socket_t fd, const uint8_t* buf, const size_t buf_len, size_t& written); + /*! * \brief Get the maximum number of clients allowed to connect to this server * @@ -182,7 +209,7 @@ class TCPServer void handleDisconnect(const socket_t fd); //! read data from socket - void readData(const socket_t fd); + bool readData(const socket_t fd); //! Event handler. Blocks until activity on any client or connection attempt void spin(); @@ -190,7 +217,7 @@ class TCPServer //! Runs spin() as long as keep_running_ is set to true. void worker(); - std::atomic keep_running_; + std::atomic keep_running_{ false }; std::thread worker_thread_; std::atomic listen_fd_; @@ -202,6 +229,10 @@ class TCPServer uint32_t max_clients_allowed_; std::vector client_fds_; + std::mutex clients_mutex_; + std::mutex message_mutex_; + std::mutex listen_fd_mutex_; + std::mutex callback_mutex_; static const int INPUT_BUFFER_SIZE = 4096; char input_buffer_[INPUT_BUFFER_SIZE]; diff --git a/include/ur_client_library/log.h b/include/ur_client_library/log.h index ed71d397..87de90e3 100644 --- a/include/ur_client_library/log.h +++ b/include/ur_client_library/log.h @@ -88,5 +88,4 @@ void setLogLevel(LogLevel level); * \param fmt Format string */ void log(const char* file, int line, LogLevel level, const char* fmt, ...); - } // namespace urcl diff --git a/src/comm/tcp_server.cpp b/src/comm/tcp_server.cpp index 22505b6d..90f5d65b 100644 --- a/src/comm/tcp_server.cpp +++ b/src/comm/tcp_server.cpp @@ -29,13 +29,13 @@ #include #include +#include #include #include #include +#include "ur_client_library/comm/socket_t.h" #include -#include -#include namespace urcl { @@ -62,10 +62,10 @@ TCPServer::~TCPServer() void TCPServer::init() { - socket_t err = (listen_fd_ = socket(AF_INET, SOCK_STREAM, 0)); - if (err < 0) + listen_fd_ = socket(AF_INET, SOCK_STREAM, 0); + if (listen_fd_ == INVALID_SOCKET) { - throw std::system_error(std::error_code(errno, std::generic_category()), "Failed to create socket endpoint"); + throw makeSocketError("Failed to create socket endpoint"); } int flag = 1; #ifndef _WIN32 @@ -81,12 +81,29 @@ void TCPServer::init() void TCPServer::shutdown() { + std::unique_lock listen_lk(listen_fd_mutex_, std::try_to_lock); + if (listen_fd_ == INVALID_SOCKET) + { + URCL_LOG_INFO("Listen FD already closed by another thread. Nothing to do here."); + return; + } + if (!listen_lk.owns_lock()) + { + URCL_LOG_WARN("Could not acquire lock for listen FD when shutting down. Is there another thread shutting the " + "server down already? Waiting for lock to be released."); + listen_lk.lock(); + if (listen_fd_ == INVALID_SOCKET) + { + URCL_LOG_INFO("Listen FD already closed by another thread. Nothing to do here."); + return; + } + } keep_running_ = false; socket_t shutdown_socket = ::socket(AF_INET, SOCK_STREAM, 0); if (shutdown_socket == INVALID_SOCKET) { - throw std::system_error(std::error_code(errno, std::generic_category()), "Unable to create shutdown socket."); + throw makeSocketError("Unable to create shutdown socket."); } #ifdef _WIN32 @@ -115,17 +132,22 @@ void TCPServer::shutdown() URCL_LOG_DEBUG("Worker thread joined."); } + std::lock_guard lk(clients_mutex_); for (const auto& client_fd : client_fds_) { ur_close(client_fd); } + // This will effectively deactivate the disconnection handler. + client_fds_.clear(); ur_close(shutdown_socket); ur_close(listen_fd_); + listen_fd_ = INVALID_SOCKET; } void TCPServer::bind(const size_t max_num_tries, const std::chrono::milliseconds reconnection_time) { struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); server_addr.sin_family = AF_INET; // INADDR_ANY is a special constant that signalizes "ANY IFACE", @@ -138,8 +160,9 @@ void TCPServer::bind(const size_t max_num_tries, const std::chrono::milliseconds err = ::bind(listen_fd_, (struct sockaddr*)&server_addr, sizeof(server_addr)); if (err == -1) { + auto error_code = getLastSocketErrorCode(); std::ostringstream ss; - ss << "Failed to bind socket for port " << port_ << " to address. Reason: " << strerror(errno); + ss << "Failed to bind socket for port " << port_ << " to address. Reason: " << error_code.message(); if (connection_counter++ < max_num_tries || max_num_tries == 0) { @@ -150,7 +173,7 @@ void TCPServer::bind(const size_t max_num_tries, const std::chrono::milliseconds } else { - throw std::system_error(std::error_code(errno, std::generic_category()), ss.str()); + throw std::system_error(error_code, ss.str()); } } } while (err == -1 && (connection_counter <= max_num_tries || max_num_tries == 0)); @@ -168,13 +191,14 @@ void TCPServer::startListen() { std::ostringstream ss; ss << "Failed to start listen on port " << port_; - throw std::system_error(std::error_code(errno, std::generic_category()), ss.str()); + throw makeSocketError(ss.str()); } struct sockaddr_in sin; socklen_t len = sizeof(sin); if (getsockname(listen_fd_, (struct sockaddr*)&sin, &len) == -1) { - URCL_LOG_ERROR("getsockname() failed to get port number for listening socket: %s", strerror(errno)); + URCL_LOG_ERROR("getsockname() failed to get port number for listening socket: %s", + getLastSocketErrorCode().message().c_str()); } else @@ -191,30 +215,54 @@ void TCPServer::handleConnect() socket_t client_fd = accept(listen_fd_, (struct sockaddr*)&client_addr, &addrlen); if (client_fd == INVALID_SOCKET) { - std::ostringstream ss; - ss << "Failed to accept connection request on port " << port_; - throw std::system_error(std::error_code(errno, std::generic_category()), ss.str()); + URCL_LOG_ERROR("Failed to accept connection request on port %d. Reason: %s", port_, + getLastSocketErrorCode().message().c_str()); + return; + } + +#ifdef _WIN32 + bool set_size_exceeded = client_fds_.size() >= FD_SETSIZE - 1; // -1 because listen_fd_ also occupies one + // slot in masterfds_ +#else + bool set_size_exceeded = client_fd >= FD_SETSIZE; // On Unix-like systems, the client FD itself must be less than + // FD_SETSIZE, otherwise it cannot be added to the fd_set. +#endif + + if (set_size_exceeded) + { + URCL_LOG_ERROR("Accepted client FD %d exceeds FD_SETSIZE (%d). Closing connection.", (int)client_fd, FD_SETSIZE); + ur_close(client_fd); + return; } - if (client_fds_.size() < max_clients_allowed_ || max_clients_allowed_ == 0) + bool accepted = false; + { - client_fds_.push_back(client_fd); - FD_SET(client_fd, &masterfds_); - if (client_fd > maxfd_) + std::lock_guard lk(clients_mutex_); + if (client_fds_.size() < max_clients_allowed_ || max_clients_allowed_ == 0) { - maxfd_ = client_fd; + client_fds_.push_back(client_fd); + FD_SET(client_fd, &masterfds_); + if (client_fd > maxfd_) + { + maxfd_ = client_fd; + } + accepted = true; } - if (new_connection_callback_) + else { - new_connection_callback_(client_fd); + URCL_LOG_WARN("Connection attempt on port %d while maximum number of clients (%d) is already connected. Closing " + "connection.", + port_, max_clients_allowed_); + ur_close(client_fd); } } - else { - URCL_LOG_WARN("Connection attempt on port %d while maximum number of clients (%d) is already connected. Closing " - "connection.", - port_, max_clients_allowed_); - ur_close(client_fd); + std::lock_guard lk(callback_mutex_); + if (new_connection_callback_ && accepted) + { + new_connection_callback_(client_fd); + } } } @@ -222,8 +270,12 @@ void TCPServer::spin() { tempfds_ = masterfds_; + timeval timeout; + timeout.tv_sec = 1; + timeout.tv_usec = 0; + // blocks until activity on any socket from tempfds - int sel = select(static_cast(maxfd_ + 1), &tempfds_, NULL, NULL, NULL); + int sel = select(static_cast(maxfd_ + 1), &tempfds_, NULL, NULL, &timeout); if (sel < 0) { URCL_LOG_ERROR("select() failed. Shutting down socket event handler."); @@ -231,56 +283,91 @@ void TCPServer::spin() return; } - if (!keep_running_) + if (!keep_running_ || sel == 0) { return; } - // Check which fd has an activity - for (socket_t i = 0; i <= maxfd_; i++) + if (FD_ISSET(listen_fd_, &tempfds_)) + { + URCL_LOG_DEBUG("Activity on listen FD %d", (int)listen_fd_); + handleConnect(); + } + + std::vector disconnected_clients; + std::vector client_fds_with_activity; + { - if (FD_ISSET(i, &tempfds_)) + std::lock_guard lk(clients_mutex_); + for (const auto& client_fd : client_fds_) { - URCL_LOG_DEBUG("Activity on FD %d", i); - if (listen_fd_ == i) + if (FD_ISSET(client_fd, &tempfds_)) { - // Activity on the listen_fd means we have a new connection - handleConnect(); - } - else - { - readData(i); + URCL_LOG_DEBUG("Activity on client FD %d", (int)client_fd); + client_fds_with_activity.push_back(client_fd); } } } + // We handle client activity outside the clients_mutex_ lock to avoid holding it during potentially slow I/O and + // message callbacks. + // The clients_mutex_ lock is only needed to protect the client_fds_ vector, but once we have copied the FDs with + // activity to a separate vector, we can safely handle them without holding the lock. + for (const auto& client_fd : client_fds_with_activity) + { + if (!readData(client_fd)) + { + disconnected_clients.push_back(client_fd); + } + } + for (const auto& client_fd : disconnected_clients) + { + handleDisconnect(client_fd); + } } void TCPServer::handleDisconnect(const socket_t fd) { - URCL_LOG_DEBUG("%d disconnected.", fd); - ur_close(fd); - if (disconnect_callback_) + URCL_LOG_INFO("%d disconnected.", fd); { - disconnect_callback_(fd); + std::lock_guard lk(clients_mutex_); + ur_close(fd); + FD_CLR(fd, &masterfds_); + + for (size_t i = 0; i < client_fds_.size(); ++i) + { + if (client_fds_[i] == fd) + { + client_fds_.erase(client_fds_.begin() + i); + break; + } + } + + maxfd_ = listen_fd_; + for (const auto& client_fd : client_fds_) + { + if (client_fd > maxfd_) + { + maxfd_ = client_fd; + } + } } - FD_CLR(fd, &masterfds_); - for (size_t i = 0; i < client_fds_.size(); ++i) { - if (client_fds_[i] == fd) + std::lock_guard lk(callback_mutex_); + if (disconnect_callback_ && keep_running_) { - client_fds_.erase(client_fds_.begin() + i); - break; + disconnect_callback_(fd); } } } -void TCPServer::readData(const socket_t fd) +bool TCPServer::readData(const socket_t fd) { memset(input_buffer_, 0, INPUT_BUFFER_SIZE); // clear input buffer int nbytesrecv = recv(fd, input_buffer_, INPUT_BUFFER_SIZE, 0); if (nbytesrecv > 0) { + std::lock_guard lk(message_mutex_); if (message_callback_) { message_callback_(fd, input_buffer_, nbytesrecv); @@ -290,7 +377,14 @@ void TCPServer::readData(const socket_t fd) { if (nbytesrecv < 0) { - if (errno == ECONNRESET) // if connection gets reset by client, we want to suppress this output + auto check_err = []() { +#ifdef _WIN32 + return WSAGetLastError() == WSAECONNRESET; +#else + return errno == ECONNRESET; +#endif + }; + if (check_err()) // if connection gets reset by client, we want to suppress this output { URCL_LOG_DEBUG("client from FD %d sent a connection reset package.", fd); } @@ -303,8 +397,9 @@ void TCPServer::readData(const socket_t fd) { // normal disconnect } - handleDisconnect(fd); + return false; } + return true; } void TCPServer::worker() @@ -326,13 +421,36 @@ void TCPServer::start() bool TCPServer::write(const socket_t fd, const uint8_t* buf, const size_t buf_len, size_t& written) { written = 0; + if (fd == INVALID_SOCKET) + { + URCL_LOG_ERROR("Invalid socket provided for writing."); + return false; + } + { + std::lock_guard lk(clients_mutex_); + if (std::find(client_fds_.begin(), client_fds_.end(), fd) == client_fds_.end()) + { + URCL_LOG_ERROR("Trying to write to FD %d, but this client is not connected.", fd); + return false; + } + } + + // We don't use a lock around the send call here, since writing on a closed socket would raise + // an error anyway, and the client FD is only removed from client_fds_ after the socket is + // closed. Thus, even if the client gets disconnected right after the check, the send call will + // just fail and return false, which is the expected behavior. + return writeUnchecked(fd, buf, buf_len, written); +} +bool TCPServer::writeUnchecked(const socket_t fd, const uint8_t* buf, const size_t buf_len, size_t& written) +{ size_t remaining = buf_len; // handle partial sends while (written < buf_len) { - ssize_t sent = ::send(fd, reinterpret_cast(buf + written), static_cast(remaining), 0); + ssize_t sent = + ::send(fd, reinterpret_cast(buf + written), static_cast(remaining), MSG_NOSIGNAL); if (sent <= 0) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6a470915..99db15ba 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -137,6 +137,9 @@ gtest_add_tests(TARGET rtde_parser_tests ) add_executable(tcp_server_tests test_tcp_server.cpp) +if (MSVC) + target_compile_options(tcp_server_tests PRIVATE /Zc:lambda) +endif() target_link_libraries(tcp_server_tests PRIVATE ur_client_library::urcl GTest::gtest_main) gtest_add_tests(TARGET tcp_server_tests ) diff --git a/tests/fake_rtde_server.cpp b/tests/fake_rtde_server.cpp index 72e23a63..5985b92e 100644 --- a/tests/fake_rtde_server.cpp +++ b/tests/fake_rtde_server.cpp @@ -53,7 +53,7 @@ void RTDEServer::messageCallback(const socket_t filedescriptor, char* buffer, in size += serializer.serialize(buffer + size, accepted); size_t written; - server_.write(filedescriptor, buffer, size, written); + server_.writeUnchecked(filedescriptor, buffer, size, written); break; } case rtde_interface::PackageType::RTDE_GET_URCONTROL_VERSION: @@ -70,7 +70,7 @@ void RTDEServer::messageCallback(const socket_t filedescriptor, char* buffer, in size += serializer.serialize(buffer + size, version); // build size_t written; - server_.write(filedescriptor, buffer, size, written); + server_.writeUnchecked(filedescriptor, buffer, size, written); break; } case rtde_interface::PackageType::RTDE_CONTROL_PACKAGE_SETUP_OUTPUTS: @@ -98,7 +98,7 @@ void RTDEServer::messageCallback(const socket_t filedescriptor, char* buffer, in // only important, that no field is "NOT_FOUND". size_t written; - server_.write(filedescriptor, buffer, size, written); + server_.writeUnchecked(filedescriptor, buffer, size, written); URCL_LOG_INFO("Output recipe set"); break; } @@ -124,7 +124,7 @@ void RTDEServer::messageCallback(const socket_t filedescriptor, char* buffer, in // only important, that no field is "NOT_FOUND". size_t written; - server_.write(filedescriptor, buffer, size, written); + server_.writeUnchecked(filedescriptor, buffer, size, written); URCL_LOG_INFO("Input recipe set with %zu variables.", input_recipe_.size()); break; @@ -140,7 +140,7 @@ void RTDEServer::messageCallback(const socket_t filedescriptor, char* buffer, in size += serializer.serialize(buffer + size, accepted); size_t written; - server_.write(filedescriptor, buffer, size, written); + server_.writeUnchecked(filedescriptor, buffer, size, written); startSendingDataPackages(); break; } @@ -155,7 +155,7 @@ void RTDEServer::messageCallback(const socket_t filedescriptor, char* buffer, in size += serializer.serialize(buffer + size, accepted); size_t written; - server_.write(filedescriptor, buffer, size, written); + server_.writeUnchecked(filedescriptor, buffer, size, written); stopSendingDataPackages(); break; } diff --git a/tests/test_tcp_server.cpp b/tests/test_tcp_server.cpp index 56113141..5e28e032 100644 --- a/tests/test_tcp_server.cpp +++ b/tests/test_tcp_server.cpp @@ -29,9 +29,13 @@ // -- END LICENSE BLOCK ------------------------------------------------ #include +#include +#include #include #include #include +#include +#include #include #include @@ -210,21 +214,29 @@ TEST_F(TCPServerTest, callback_functions) EXPECT_TRUE(waitForDisconnectionCallback()); } -TEST_F(TCPServerTest, unlimited_clients_allowed) +TEST_F(TCPServerTest, many_clients_allowed) { comm::TCPServer server(port_); - server.setMessageCallback(std::bind(&TCPServerTest_unlimited_clients_allowed_Test::messageCallback, this, + server.setMessageCallback(std::bind(&TCPServerTest_many_clients_allowed_Test::messageCallback, this, std::placeholders::_1, std::placeholders::_2)); server.setConnectCallback( - std::bind(&TCPServerTest_unlimited_clients_allowed_Test::connectionCallback, this, std::placeholders::_1)); + std::bind(&TCPServerTest_many_clients_allowed_Test::connectionCallback, this, std::placeholders::_1)); server.setDisconnectCallback( - std::bind(&TCPServerTest_unlimited_clients_allowed_Test::disconnectionCallback, this, std::placeholders::_1)); + std::bind(&TCPServerTest_many_clients_allowed_Test::disconnectionCallback, this, std::placeholders::_1)); server.start(); +#ifdef _WIN32 + // Windows has a maximum of 64 sockets per process, so we can only test with 63 clients since the server also uses one + // socket. + constexpr int num_clients = 63; +#else + constexpr int num_clients = 100; +#endif + // Test that a large number of clients can connect to the server std::vector> clients; std::unique_ptr client; - for (unsigned int i = 0; i < 100; ++i) + for (unsigned int i = 0; i < num_clients; ++i) { clients.push_back(std::make_unique(port_)); ASSERT_TRUE(waitForConnectionCallback()); @@ -374,6 +386,244 @@ TEST_F(TCPServerTest, check_shutting_down_server_while_listening) EXPECT_EQ(client.getState(), comm::SocketState::Disconnected); } +TEST_F(TCPServerTest, double_shutdown) +{ + comm::TCPServer server(port_); + server.setConnectCallback( + std::bind(&TCPServerTest_double_shutdown_Test::connectionCallback, this, std::placeholders::_1)); + server.setDisconnectCallback( + std::bind(&TCPServerTest_double_shutdown_Test::disconnectionCallback, this, std::placeholders::_1)); + server.start(); + + Client client(port_); + EXPECT_TRUE(waitForConnectionCallback()); + + EXPECT_NO_THROW(server.shutdown()); + EXPECT_NO_THROW(server.shutdown()); +} + +TEST_F(TCPServerTest, concurrent_writes_same_client) +{ + comm::TCPServer server(0); + server.setConnectCallback([this](const socket_t fd) { connectionCallback(fd); }); + server.setDisconnectCallback([this](const socket_t fd) { disconnectionCallback(fd); }); + server.start(); + + Client client(server.getPort()); + ASSERT_TRUE(waitForConnectionCallback()); + const socket_t fd = client_fd_; + + const std::string message = "test data\n"; + const auto* data = reinterpret_cast(message.c_str()); + const size_t len = message.size(); + + constexpr int num_threads = 10; + constexpr int writes_per_thread = 100; + std::atomic success_count{ 0 }; + std::vector writers; + + for (int i = 0; i < num_threads; ++i) + { + writers.emplace_back([&server, fd, data, len, &success_count]() { + for (int j = 0; j < writes_per_thread; ++j) + { + size_t written; + if (server.write(fd, data, len, written)) + { + ++success_count; + } + } + }); + } + + for (auto& t : writers) + { + t.join(); + } + + EXPECT_EQ(success_count.load(), num_threads * writes_per_thread); +} + +TEST_F(TCPServerTest, write_during_client_disconnect) +{ + comm::TCPServer server(0); + server.setConnectCallback([this](const socket_t fd) { connectionCallback(fd); }); + server.setDisconnectCallback([this](const socket_t fd) { disconnectionCallback(fd); }); + server.start(); + + Client client(server.getPort()); + ASSERT_TRUE(waitForConnectionCallback()); + const socket_t fd = client_fd_; + + const std::string message = "test data\n"; + const auto* data = reinterpret_cast(message.c_str()); + const size_t len = message.size(); + + std::atomic stop{ false }; + + std::thread writer([&server, fd, data, len, &stop]() { + while (!stop.load()) + { + size_t written; + server.write(fd, data, len, written); + } + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + client.close(); + ASSERT_TRUE(waitForDisconnectionCallback()); + + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + stop.store(true); + writer.join(); +} + +TEST_F(TCPServerTest, rapid_connect_disconnect_with_concurrent_writes) +{ + comm::TCPServer server(0); + + std::mutex fds_mutex; + std::vector connected_fds; + + server.setConnectCallback([&](const socket_t fd) { + std::lock_guard lk(fds_mutex); + connected_fds.push_back(fd); + }); + server.setDisconnectCallback([&](const socket_t fd) { + std::lock_guard lk(fds_mutex); + connected_fds.erase(std::remove(connected_fds.begin(), connected_fds.end(), fd), connected_fds.end()); + }); + server.start(); + + const std::string message = "test data\n"; + const auto* data = reinterpret_cast(message.c_str()); + const size_t len = message.size(); + + std::atomic stop{ false }; + + std::thread writer([&]() { + while (!stop.load()) + { + std::vector snapshot; + { + std::lock_guard lk(fds_mutex); + snapshot = connected_fds; + } + for (const auto& fd : snapshot) + { + size_t written; + server.write(fd, data, len, written); + } + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + }); + + constexpr int num_iterations = 50; + for (int i = 0; i < num_iterations; ++i) + { + Client client(server.getPort()); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + stop.store(true); + writer.join(); +} + +TEST_F(TCPServerTest, concurrent_writes_multiple_clients) +{ + comm::TCPServer server(0); + + std::mutex fds_mutex; + std::vector connected_fds; + std::condition_variable all_connected_cv; + constexpr size_t num_clients = 10; + + server.setConnectCallback([&](const socket_t fd) { + std::lock_guard lk(fds_mutex); + connected_fds.push_back(fd); + if (connected_fds.size() == num_clients) + { + all_connected_cv.notify_all(); + } + }); + server.start(); + + std::vector> clients; + for (size_t i = 0; i < num_clients; ++i) + { + clients.push_back(std::make_unique(server.getPort())); + } + + { + std::unique_lock lk(fds_mutex); + ASSERT_TRUE( + all_connected_cv.wait_for(lk, std::chrono::seconds(5), [&]() { return connected_fds.size() == num_clients; })); + } + + const std::string message = "test data\n"; + const auto* data = reinterpret_cast(message.c_str()); + const size_t len = message.size(); + + constexpr int writes_per_thread = 100; + std::atomic total_successes{ 0 }; + + std::vector writers; + { + std::lock_guard lk(fds_mutex); + for (const auto& fd : connected_fds) + { + writers.emplace_back([&server, fd, data, len, &total_successes]() { + for (int j = 0; j < writes_per_thread; ++j) + { + size_t written; + if (server.write(fd, data, len, written)) + { + ++total_successes; + } + } + }); + } + } + + for (auto& t : writers) + { + t.join(); + } + + EXPECT_EQ(total_successes.load(), static_cast(num_clients) * writes_per_thread); +} + +TEST_F(TCPServerTest, shutdown_during_active_writes) +{ + comm::TCPServer server(0); + server.setConnectCallback([this](const socket_t fd) { connectionCallback(fd); }); + server.setDisconnectCallback([this](const socket_t fd) { disconnectionCallback(fd); }); + server.start(); + + Client client(server.getPort()); + ASSERT_TRUE(waitForConnectionCallback()); + const socket_t fd = client_fd_; + + const std::string message = "test data\n"; + const auto* data = reinterpret_cast(message.c_str()); + const size_t len = message.size(); + + std::atomic stop{ false }; + + std::thread writer([&server, fd, data, len, &stop]() { + while (!stop.load()) + { + size_t written; + server.write(fd, data, len, written); + } + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + server.shutdown(); + stop.store(true); + writer.join(); +} + int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv);