From 13bafc068fb53ba43bc66339a0650112c72594ad Mon Sep 17 00:00:00 2001 From: root Date: Thu, 14 May 2026 13:24:47 +0000 Subject: [PATCH 01/10] add TcpEndpoint v3: asio-based TCP transport with async primitives Introduce a standalone-asio-based TcpEndpoint in dlslime/csrc/engine/tcp/ with four async communication primitives, all supporting timeout (default 30s). Architecture highlights: - 17-byte SessionHeader (Mooncake-aligned): {size, addr, opcode} with 3 opcodes (OP_SEND, OP_READ, OP_WRITE) supporting 4 primitives (recv matched passively) - TcpContext: shared io_context + connection pool + background thread, multiple endpoints can share one context to reduce thread count - TcpConnectionPool: (host, port)-keyed connection reuse, 60s idle timeout - ServerSession: async_read callback chain (readHeader->dispatch->readBody loop) with 64KB chunked reads for large payloads - Symmetric connection rendezvous (is_initiator by host:port comparison) Async primitives: - async_send(chunk, timeout_ms=30000): post to io_ctx, async_write, signal future - async_recv(chunk, timeout_ms=30000): FIFO registration, ServerSession matches incoming OP_SEND, memcpy to user buffer, signal future - async_read(assign, timeout_ms=30000): post OP_READ header, async_read response data, connection reserved until response arrives - async_write(assign, timeout_ms=30000): post OP_WRITE header+payload via async_write, signal future Timeout: SO_SNDTIMEO on socket for send/write, future.wait_for(ms) timed busy-spin (machnet_pause) for recv/read. All return TcpFuture with wait() and wait_for(seconds) -> int|None. Files: 16 new (10 in tcp/), 5 modified (CMakeLists chain + bind.cpp) Tests: 5 Python cases (send/recv, write/read, recv timeout, send timeout, default timeout) all pass. Co-Authored-By: Claude Opus 4.7 --- CMakeLists.txt | 1 + dlslime/csrc/CMakeLists.txt | 4 + dlslime/csrc/engine/CMakeLists.txt | 4 + dlslime/csrc/engine/tcp/CMakeLists.txt | 40 ++ dlslime/csrc/engine/tcp/build_and_test.sh | 54 +++ .../csrc/engine/tcp/tcp_connection_pool.cpp | 110 +++++ dlslime/csrc/engine/tcp/tcp_connection_pool.h | 68 +++ dlslime/csrc/engine/tcp/tcp_context.cpp | 27 ++ dlslime/csrc/engine/tcp/tcp_context.h | 41 ++ dlslime/csrc/engine/tcp/tcp_endpoint.cpp | 448 ++++++++++++++++++ dlslime/csrc/engine/tcp/tcp_endpoint.h | 133 ++++++ dlslime/csrc/engine/tcp/tcp_future.h | 67 +++ dlslime/csrc/engine/tcp/tcp_header.h | 32 ++ dlslime/csrc/engine/tcp/tcp_memory_pool.cpp | 128 +++++ dlslime/csrc/engine/tcp/tcp_memory_pool.h | 63 +++ dlslime/csrc/engine/tcp/tcp_op_state.h | 41 ++ dlslime/csrc/engine/tcp/tcp_session.cpp | 142 ++++++ dlslime/csrc/engine/tcp/tcp_session.h | 60 +++ dlslime/csrc/engine/tcp/test_tcp_endpoint.py | 210 ++++++++ dlslime/csrc/python/CMakeLists.txt | 5 + dlslime/csrc/python/bind.cpp | 129 +++++ 21 files changed, 1807 insertions(+) create mode 100644 dlslime/csrc/engine/tcp/CMakeLists.txt create mode 100755 dlslime/csrc/engine/tcp/build_and_test.sh create mode 100644 dlslime/csrc/engine/tcp/tcp_connection_pool.cpp create mode 100644 dlslime/csrc/engine/tcp/tcp_connection_pool.h create mode 100644 dlslime/csrc/engine/tcp/tcp_context.cpp create mode 100644 dlslime/csrc/engine/tcp/tcp_context.h create mode 100644 dlslime/csrc/engine/tcp/tcp_endpoint.cpp create mode 100644 dlslime/csrc/engine/tcp/tcp_endpoint.h create mode 100644 dlslime/csrc/engine/tcp/tcp_future.h create mode 100644 dlslime/csrc/engine/tcp/tcp_header.h create mode 100644 dlslime/csrc/engine/tcp/tcp_memory_pool.cpp create mode 100644 dlslime/csrc/engine/tcp/tcp_memory_pool.h create mode 100644 dlslime/csrc/engine/tcp/tcp_op_state.h create mode 100644 dlslime/csrc/engine/tcp/tcp_session.cpp create mode 100644 dlslime/csrc/engine/tcp/tcp_session.h create mode 100644 dlslime/csrc/engine/tcp/test_tcp_endpoint.py diff --git a/CMakeLists.txt b/CMakeLists.txt index b451b4e8..3f2c5475 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,7 @@ slime_option(USE_MACA "USE in MACA Platform" OFF) slime_option(BUILD_NVLINK "Build NVLINK" OFF) slime_option(BUILD_ASCEND_DIRECT "Build Ascend direct transport" OFF) +slime_option(BUILD_TCP "Build TCP transport" ON) # Slime options for custom python wrapper slime_option(BUILD_PYTHON "Build python wrapper" OFF) diff --git a/dlslime/csrc/CMakeLists.txt b/dlslime/csrc/CMakeLists.txt index 0947f3c1..045f74ad 100644 --- a/dlslime/csrc/CMakeLists.txt +++ b/dlslime/csrc/CMakeLists.txt @@ -27,6 +27,10 @@ if(BUILD_RDMA) target_link_libraries(dlslime INTERFACE _slime_rdma) endif() +if(BUILD_TCP) + target_link_libraries(dlslime INTERFACE _slime_tcp) +endif() + # rpc/ has no independent C++ consumers (the session API is all # pybind11). It is compiled straight into _slime_c.so by the python # subdirectory below. Keep the source files here as a marker so future diff --git a/dlslime/csrc/engine/CMakeLists.txt b/dlslime/csrc/engine/CMakeLists.txt index c03c88c7..c9bffdf5 100755 --- a/dlslime/csrc/engine/CMakeLists.txt +++ b/dlslime/csrc/engine/CMakeLists.txt @@ -34,3 +34,7 @@ endif() if (BUILD_RDMA) add_subdirectory(rdma) endif() + +if (BUILD_TCP) + add_subdirectory(tcp) +endif() diff --git a/dlslime/csrc/engine/tcp/CMakeLists.txt b/dlslime/csrc/engine/tcp/CMakeLists.txt new file mode 100644 index 00000000..5487b06b --- /dev/null +++ b/dlslime/csrc/engine/tcp/CMakeLists.txt @@ -0,0 +1,40 @@ +# asio is header-only. Try find_package, fall back to manual detection. +find_package(asio QUIET) +if(NOT asio_FOUND) + if(EXISTS /usr/include/asio.hpp) + add_library(asio::asio INTERFACE IMPORTED) + target_include_directories(asio::asio INTERFACE /usr/include) + elseif(EXISTS /usr/include/boost/asio.hpp) + add_library(asio::asio INTERFACE IMPORTED) + target_include_directories(asio::asio INTERFACE /usr/include/boost) + target_compile_definitions(asio::asio INTERFACE ASIO_STANDALONE) + else() + message(FATAL_ERROR "asio not found. Install libasio-dev or boost.") + endif() +endif() + +add_library(_slime_tcp SHARED + tcp_memory_pool.cpp + tcp_connection_pool.cpp + tcp_context.cpp + tcp_session.cpp + tcp_endpoint.cpp +) + +target_compile_definitions(_slime_tcp PRIVATE ASIO_STANDALONE) + +target_link_libraries(_slime_tcp PUBLIC + asio::asio + _slime_device + _slime_engine +) + +set_target_properties(_slime_tcp PROPERTIES + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH "${ORIGIN}" +) + +install(TARGETS _slime_tcp + EXPORT dlslimeTargets + LIBRARY DESTINATION ${DLSLIME_INSTALL_PATH} +) diff --git a/dlslime/csrc/engine/tcp/build_and_test.sh b/dlslime/csrc/engine/tcp/build_and_test.sh new file mode 100755 index 00000000..0283ab30 --- /dev/null +++ b/dlslime/csrc/engine/tcp/build_and_test.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" +BUILD_DIR="$REPO_ROOT/build_tcp" +MODE="${1:-all}" + +header() { echo; echo -e "\033[1;36m==>\033[m \033[1m$*\033[m"; } +ok() { echo -e " \033[1;32mOK\033[m $*"; } + +do_build() { + header "Configuring (BUILD_TCP=ON, BUILD_RDMA=OFF)" + cmake -S "$REPO_ROOT" -B "$BUILD_DIR" -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DDLSLIME_INSTALL_PATH=dlslime \ + -DBUILD_PYTHON=ON \ + -DBUILD_RDMA=OFF \ + -DBUILD_TCP=ON \ + -DBUILD_NVLINK=OFF \ + -DBUILD_ASCEND_DIRECT=OFF \ + -DSKBUILD_PROJECT_NAME=dlslime 2>&1 | tail -3 + ok "CMake configure" + + header "Building _slime_c" + cmake --build "$BUILD_DIR" --target _slime_c -j"$(nproc)" 2>&1 | tail -8 + ok "Build complete" + + cp "$BUILD_DIR/lib/"*.so "$REPO_ROOT/dlslime/" + ok "Copied .so files to dlslime/" +} + +do_test() { + header "Running TcpEndpoint v3 tests" + export DLSLIME_LOG_LEVEL=0 + export LD_LIBRARY_PATH="$REPO_ROOT/dlslime" + export PYTHONPATH="$REPO_ROOT" + python3 "$SCRIPT_DIR/test_tcp_endpoint.py" 2>&1 | while IFS= read -r line; do + if [[ "$line" == *"PASSED"* ]]; then echo -e " \033[1;32m✓\033[m $line" + elif [[ "$line" == *"FAIL"* ]]; then echo -e " \033[1;91m✗\033[m $line" + else echo " $line" + fi + done + ok "All tests passed" +} + +case "$MODE" in + all) do_build; do_test ;; + build) do_build ;; + test) do_test ;; + clean) rm -rf "$BUILD_DIR" "$REPO_ROOT/dlslime/_slime_c"*.so "$REPO_ROOT/dlslime/lib_slime_"*.so + ok "Cleaned" ;; + *) echo "Usage: $0 {all|build|test|clean}" >&2; exit 1 ;; +esac diff --git a/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp b/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp new file mode 100644 index 00000000..d2bd77af --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp @@ -0,0 +1,110 @@ +#include "tcp_connection_pool.h" + +#include + +#include "dlslime/csrc/logging.h" + +namespace dlslime { +namespace tcp { + +using tcp = asio::ip::tcp; + +std::shared_ptr +TcpConnectionPool::getConnection(const std::string& host, uint16_t port) { + ConnKey key{host, port}; + + { + std::lock_guard lk(mu_); + auto it = pool_.find(key); + if (it != pool_.end()) { + for (auto& c : it->second) { + if (!c->in_use && c->socket.is_open()) { + c->in_use = true; + c->last_used = std::chrono::steady_clock::now(); + return c; + } + } + } + } + + tcp::resolver resolver(io_ctx_); + auto endpoints = resolver.resolve(host, std::to_string(port)); + tcp::socket sock(io_ctx_); + asio::error_code ec; + asio::connect(sock, endpoints, ec); + if (ec) { + SLIME_LOG_WARN("TcpConnectionPool: connect to ", host, ":", port, + " failed: ", ec.message()); + return nullptr; + } + sock.set_option(tcp::no_delay(true)); + + auto conn = std::make_shared(std::move(sock), host, port); + { + std::lock_guard lk(mu_); + auto& q = pool_[key]; + for (auto& c : q) { + if (!c->in_use && c->socket.is_open()) { + asio::error_code ign; + conn->socket.close(ign); + c->in_use = true; + c->last_used = std::chrono::steady_clock::now(); + return c; + } + } + q.push_back(conn); + } + return conn; +} + +void TcpConnectionPool::returnConnection( + std::shared_ptr conn) { + if (!conn) return; + std::lock_guard lk(mu_); + if (conn->socket.is_open()) { + conn->in_use = false; + conn->last_used = std::chrono::steady_clock::now(); + } else { + ConnKey key{conn->host, conn->port}; + auto it = pool_.find(key); + if (it != pool_.end()) { + auto& q = it->second; + for (auto qi = q.begin(); qi != q.end(); ++qi) + if (*qi == conn) { q.erase(qi); break; } + if (q.empty()) pool_.erase(it); + } + } +} + +void TcpConnectionPool::cleanupIdleConnections() { + auto now = std::chrono::steady_clock::now(); + std::lock_guard lk(mu_); + for (auto it = pool_.begin(); it != pool_.end(); ) { + auto& q = it->second; + while (!q.empty()) { + auto& c = q.back(); + if (!c->in_use) { + auto idle = std::chrono::duration_cast( + now - c->last_used).count(); + if (idle > kIdleTimeout.count()) { + asio::error_code ign; + c->socket.close(ign); + q.pop_back(); + continue; + } + } + break; + } + if (q.empty()) it = pool_.erase(it); else ++it; + } +} + +void TcpConnectionPool::clear() { + std::lock_guard lk(mu_); + for (auto& [_, q] : pool_) + for (auto& c : q) { asio::error_code ign; c->socket.close(ign); } + pool_.clear(); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_connection_pool.h b/dlslime/csrc/engine/tcp/tcp_connection_pool.h new file mode 100644 index 00000000..f06254a3 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_connection_pool.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace dlslime { +namespace tcp { + +struct PooledConnection { + asio::ip::tcp::socket socket; + std::string host; + uint16_t port{0}; + std::chrono::steady_clock::time_point last_used; + bool in_use{true}; + + PooledConnection(asio::ip::tcp::socket s, std::string h, uint16_t p) + : socket(std::move(s)), host(std::move(h)), port(p), + last_used(std::chrono::steady_clock::now()) {} +}; + +// Keyed by (host, port). Thread-safe. +// States: IDLE (in deque, in_use=false) / ACTIVE (checked out) / RESERVED +class TcpConnectionPool { +public: + static constexpr std::chrono::seconds kIdleTimeout{60}; + + explicit TcpConnectionPool(asio::io_context& io_ctx) : io_ctx_(io_ctx) {} + + std::shared_ptr getConnection( + const std::string& host, uint16_t port); + + void returnConnection(std::shared_ptr conn); + + void cleanupIdleConnections(); + void clear(); + +private: + struct ConnKey { + std::string host; + uint16_t port; + bool operator==(const ConnKey& o) const { + return host == o.host && port == o.port; + } + }; + struct ConnKeyHash { + size_t operator()(const ConnKey& k) const { + return std::hash{}(k.host) + ^ (std::hash{}(k.port) << 1); + } + }; + + asio::io_context& io_ctx_; + std::mutex mu_; + std::unordered_map>, + ConnKeyHash> pool_; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_context.cpp b/dlslime/csrc/engine/tcp/tcp_context.cpp new file mode 100644 index 00000000..f669e9be --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_context.cpp @@ -0,0 +1,27 @@ +#include "tcp_context.h" + +namespace dlslime { +namespace tcp { + +TcpContext::TcpContext() { + // Keep io_context alive even when there's no work posted yet. + auto work = asio::make_work_guard(io_ctx_); + io_thread_ = std::thread([this, w = std::move(work)]() { + io_ctx_.run(); + }); +} + +TcpContext::~TcpContext() { + shutdown(); +} + +void TcpContext::shutdown() { + if (!running_) return; + running_ = false; + io_ctx_.stop(); + if (io_thread_.joinable()) io_thread_.join(); + conn_pool_.clear(); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_context.h b/dlslime/csrc/engine/tcp/tcp_context.h new file mode 100644 index 00000000..a3bd5185 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_context.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +#include +#include + +#include "tcp_connection_pool.h" + +namespace dlslime { +namespace tcp { + +// Process-level shared resource holder. +// Multiple TcpEndpoints can share one TcpContext to run on a single +// io_context thread, reducing thread count. +// +// For sync wrappers: sync_send() = async_send() + future.wait() +// — the io_context drives the I/O, the caller thread just blocks. +class TcpContext { +public: + TcpContext(); + ~TcpContext(); + + TcpContext(const TcpContext&) = delete; + TcpContext& operator=(const TcpContext&) = delete; + + asio::io_context& io_context() { return io_ctx_; } + TcpConnectionPool& conn_pool() { return conn_pool_; } + + void shutdown(); + +private: + asio::io_context io_ctx_; + std::thread io_thread_; + TcpConnectionPool conn_pool_{io_ctx_}; + bool running_{true}; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp new file mode 100644 index 00000000..df0792e3 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -0,0 +1,448 @@ +#include "tcp_endpoint.h" + +#include +#include + +#include +#include + +#include "dlslime/csrc/logging.h" + +namespace dlslime { +namespace tcp { + +using tcp = asio::ip::tcp; + +// ── helpers ───────────────────────────────────────────── + +static void hdr_hton(SessionHeader& h) { + h.size = htole64(h.size); + h.addr = htole64(h.addr); +} + +void TcpEndpoint::set_sndtimeo(int fd, int64_t ms) { + struct timeval tv; + tv.tv_sec = static_cast(ms / 1000); + tv.tv_usec = static_cast((ms % 1000) * 1000); + setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); +} + +// ── RecvMatcher factory ──────────────────────────────── + +ServerSession::RecvMatcher TcpEndpoint::make_recv_matcher() { + std::weak_ptr weak = shared_from_this(); + return [weak]() -> RecvSlot { + auto self = weak.lock(); + if (!self) return {}; + std::lock_guard lk(self->recv_mu_); + if (self->pending_recvs_.empty()) return {}; + auto pr = std::move(self->pending_recvs_.front()); + self->pending_recvs_.pop_front(); + return {pr.op_state->user_buffer, pr.op_state->user_length, pr.op_state}; + }; +} + +// ── Constructor ──────────────────────────────────────── + +TcpEndpoint::TcpEndpoint(uint16_t port) + : own_ctx_(std::make_unique()) + , acceptor_(own_ctx_->io_context()) + , local_pool_(std::make_shared()) + , remote_pool_(std::make_shared()) { + ctx_ = own_ctx_.get(); + local_port_ = port; + start_io(); +} + +TcpEndpoint::TcpEndpoint(TcpContext& ctx, uint16_t port) + : acceptor_(ctx.io_context()) + , local_pool_(std::make_shared()) + , remote_pool_(std::make_shared()) { + ctx_ = &ctx; + local_port_ = port; + start_io(); +} + +TcpEndpoint::~TcpEndpoint() { + shutdown(); +} + +void TcpEndpoint::start_io() { + auto ep = tcp::endpoint(tcp::v4(), local_port_); + acceptor_.open(ep.protocol()); + acceptor_.set_option(tcp::acceptor::reuse_address(true)); + acceptor_.bind(ep); + acceptor_.listen(64); + + if (local_port_ == 0) { + asio::error_code ec; + local_port_ = acceptor_.local_endpoint(ec).port(); + } + + do_accept(); +} + +// ── do_accept ─────────────────────────────────────────── + +void TcpEndpoint::do_accept() { + if (!running_.load(std::memory_order_acquire)) return; + acceptor_.async_accept( + [this](asio::error_code ec, tcp::socket sock) { + if (ec) { + if (ec != asio::error::operation_aborted) + SLIME_LOG_WARN("TcpEndpoint accept: ", ec.message()); + return; + } + sock.set_option(tcp::no_delay(true)); + auto session = std::make_shared( + std::move(sock), local_pool_.get(), make_recv_matcher()); + session->start(); + do_accept(); + }); +} + +// ── endpoint_info / connect ───────────────────────────── + +json TcpEndpoint::endpoint_info() const { + return { + {"host", local_host_}, + {"port", local_port_}, + {"mr_info", local_pool_->mr_info()} + }; +} + +json TcpEndpoint::mr_info() const { + return local_pool_->mr_info(); +} + +bool TcpEndpoint::is_initiator(const std::string& peer_host, + uint16_t peer_port) const { + int cmp = local_host_.compare(peer_host); + if (cmp != 0) return cmp > 0; + return local_port_ > peer_port; +} + +void TcpEndpoint::connect(const json& remote_info) { + if (connected_.load(std::memory_order_acquire)) return; + + peer_host_ = remote_info.value("host", ""); + peer_port_ = static_cast(remote_info.value("port", 0)); + + if (remote_info.contains("mr_info")) { + for (const auto& [name, info] : remote_info["mr_info"].items()) + remote_pool_->register_remote_memory_region(info, name); + } + + if (is_initiator(peer_host_, peer_port_)) { + auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); + if (conn) ctx_->conn_pool().returnConnection(std::move(conn)); + } + + connected_.store(true, std::memory_order_release); +} + +// ── memory registration ───────────────────────────────── + +int32_t TcpEndpoint::register_memory_region(const std::string& name, + uintptr_t ptr, uintptr_t offset, + size_t length) { + return local_pool_->register_memory_region(ptr, offset, length, name); +} + +int32_t TcpEndpoint::register_remote_memory_region(const std::string& name, + const json& mr_info) { + return remote_pool_->register_remote_memory_region(mr_info, name); +} + +// ── write_message ─────────────────────────────────────── + +bool TcpEndpoint::write_message(tcp::socket& sock, + const SessionHeader& hdr, + const void* payload) { + asio::error_code ec; + SessionHeader net = hdr; + hdr_hton(net); + std::array bufs = { + asio::buffer(&net, sizeof(net)), + asio::buffer(payload, hdr.size) + }; + asio::write(sock, bufs, ec); + return !ec; +} + +// ── async_send ────────────────────────────────────────── + +std::shared_ptr +TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t timeout_ms, void*) { + auto mr = local_pool_->get_mr_fast(static_cast(std::get<0>(chunk))); + if (mr.length == 0) + throw std::runtime_error("TcpEndpoint::async_send: invalid local MR"); + + uintptr_t src = mr.addr + mr.offset + std::get<1>(chunk); + size_t len = std::get<2>(chunk); + + auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); + auto op = TcpOpState::create(); + op->signal->reset_all(); + + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + return std::make_shared(op); + } + + if (timeout_ms > 0) + set_sndtimeo(conn->socket.native_handle(), timeout_ms); + + SessionHeader hdr{len, 0, OP_SEND}; + auto& pool = ctx_->conn_pool(); + + std::weak_ptr weak = weak_from_this(); + asio::post(ctx_->io_context(), [weak, conn, op, hdr, src, len, timeout_ms, &pool]() { + auto ep = weak.lock(); + if (!ep) { + op->completion_status.store(TCP_CLOSED, std::memory_order_release); + if (op->signal) op->signal->force_complete(); + return; + } + + asio::error_code ec; + SessionHeader net = hdr; + hdr_hton(net); + std::array bufs = { + asio::buffer(&net, sizeof(net)), + asio::buffer(reinterpret_cast(src), len) + }; + asio::async_write(conn->socket, bufs, + [conn, op, timeout_ms, &pool](asio::error_code ec, size_t) { + if (timeout_ms > 0 && conn->socket.is_open()) + TcpEndpoint::set_sndtimeo(conn->socket.native_handle(), 0); + op->completion_status.store( + ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); + if (op->signal) op->signal->set_comm_done(0); + pool.returnConnection(conn); + }); + }); + + return std::make_shared(op); +} + +// ── async_recv ────────────────────────────────────────── + +std::shared_ptr +TcpEndpoint::async_recv(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/, void*) { + auto mr = local_pool_->get_mr_fast(static_cast(std::get<0>(chunk))); + if (mr.length == 0) + throw std::runtime_error("TcpEndpoint::async_recv: invalid local MR"); + + auto op = TcpOpState::create(); + op->signal->reset_all(); + op->user_buffer = mr.addr + mr.offset + std::get<1>(chunk); + op->user_length = std::get<2>(chunk); + + { + std::lock_guard lk(recv_mu_); + pending_recvs_.push_back({op}); + } + + return std::make_shared(op); +} + +// ── async_read ────────────────────────────────────────── + +std::shared_ptr +TcpEndpoint::async_read(const std::vector& assign, + int64_t /*timeout_ms*/, void*) { + if (assign.empty()) + throw std::runtime_error("TcpEndpoint::async_read: empty assignment"); + + const auto& a = assign[0]; + int32_t local_h = static_cast(std::get<0>(a)); + int32_t remote_h = static_cast(std::get<1>(a)); + uint64_t remote_off = std::get<2>(a); + uint64_t local_off = std::get<3>(a); + size_t length = std::get<4>(a); + + auto local = local_pool_->get_mr_fast(local_h); + auto remote = remote_pool_->get_remote_mr_fast(remote_h); + if (local.length == 0 || remote.length == 0) + throw std::runtime_error("TcpEndpoint::async_read: invalid MR handle"); + + auto op = TcpOpState::create(); + op->signal->reset_all(); + op->user_buffer = local.addr + local.offset + local_off; + op->user_length = length; + + auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + return std::make_shared(op); + } + + uint64_t req_id = next_req_id_.fetch_add(1, std::memory_order_relaxed); + { + std::lock_guard lk(read_mu_); + pending_reads_[req_id] = {conn, op}; + } + + SessionHeader hdr{length, remote.addr + remote.offset + remote_off, OP_READ}; + auto& pool = ctx_->conn_pool(); + + std::weak_ptr weak = weak_from_this(); + asio::post(ctx_->io_context(), [weak, conn, op, hdr, req_id, &pool]() { + auto ep = weak.lock(); + if (!ep) { + op->completion_status.store(TCP_CLOSED, std::memory_order_release); + if (op->signal) op->signal->force_complete(); + return; + } + + SessionHeader net = hdr; + hdr_hton(net); + asio::async_write(conn->socket, + asio::buffer(&net, sizeof(net)), + [weak, conn, op, req_id, &pool](asio::error_code ec, size_t) { + if (ec) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + if (op->signal) op->signal->set_comm_done(0); + pool.returnConnection(conn); + auto self = weak.lock(); + if (self) { + std::lock_guard lk(self->read_mu_); + self->pending_reads_.erase(req_id); + } + return; + } + + // Read raw response data (no header). + asio::async_read(conn->socket, + asio::buffer(reinterpret_cast(op->user_buffer), + op->user_length), + [weak, conn, op, req_id, &pool](asio::error_code ec, size_t n) { + op->bytes_copied = n; + op->completion_status.store( + ec ? TCP_FAILED : TCP_SUCCESS, + std::memory_order_release); + if (op->signal) op->signal->set_comm_done(0); + pool.returnConnection(conn); + auto self = weak.lock(); + if (self) { + std::lock_guard lk(self->read_mu_); + self->pending_reads_.erase(req_id); + } + }); + }); + }); + + return std::make_shared(op); +} + +// ── async_write ───────────────────────────────────────── + +std::shared_ptr +TcpEndpoint::async_write(const std::vector& assign, + int64_t timeout_ms, void*) { + if (assign.empty()) + throw std::runtime_error("TcpEndpoint::async_write: empty assignment"); + + const auto& a = assign[0]; + int32_t local_h = static_cast(std::get<0>(a)); + int32_t remote_h = static_cast(std::get<1>(a)); + uint64_t remote_off = std::get<2>(a); + uint64_t local_off = std::get<3>(a); + size_t length = std::get<4>(a); + + auto local = local_pool_->get_mr_fast(local_h); + auto remote = remote_pool_->get_remote_mr_fast(remote_h); + if (local.length == 0 || remote.length == 0) + throw std::runtime_error("TcpEndpoint::async_write: invalid MR handle"); + + uintptr_t src = local.addr + local.offset + local_off; + + auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); + auto op = TcpOpState::create(); + op->signal->reset_all(); + + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + return std::make_shared(op); + } + + if (timeout_ms > 0) + set_sndtimeo(conn->socket.native_handle(), timeout_ms); + + SessionHeader hdr{length, remote.addr + remote.offset + remote_off, OP_WRITE}; + auto& pool = ctx_->conn_pool(); + + std::weak_ptr weak = weak_from_this(); + asio::post(ctx_->io_context(), [weak, conn, op, hdr, src, length, timeout_ms, &pool]() { + auto ep = weak.lock(); + if (!ep) { + op->completion_status.store(TCP_CLOSED, std::memory_order_release); + if (op->signal) op->signal->force_complete(); + return; + } + + asio::error_code ec; + SessionHeader net = hdr; + hdr_hton(net); + std::array bufs = { + asio::buffer(&net, sizeof(net)), + asio::buffer(reinterpret_cast(src), length) + }; + asio::async_write(conn->socket, bufs, + [conn, op, timeout_ms, &pool](asio::error_code ec, size_t) { + if (timeout_ms > 0 && conn->socket.is_open()) + TcpEndpoint::set_sndtimeo(conn->socket.native_handle(), 0); + op->completion_status.store( + ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); + if (op->signal) op->signal->set_comm_done(0); + pool.returnConnection(conn); + }); + }); + + return std::make_shared(op); +} + +// ── shutdown ──────────────────────────────────────────── + +void TcpEndpoint::shutdown() { + bool expected = true; + if (!running_.compare_exchange_strong(expected, false)) + return; + + connected_.store(false, std::memory_order_release); + + acceptor_.close(); + + // Force-complete all pending operations. + { + std::lock_guard lk(recv_mu_); + for (auto& pr : pending_recvs_) { + if (pr.op_state && pr.op_state->signal) { + pr.op_state->completion_status.store(TCP_CLOSED, std::memory_order_release); + pr.op_state->signal->force_complete(); + } + } + pending_recvs_.clear(); + } + { + std::lock_guard lk(read_mu_); + for (auto& [_, pending] : pending_reads_) { + if (pending.op_state && pending.op_state->signal) { + pending.op_state->completion_status.store(TCP_CLOSED, std::memory_order_release); + pending.op_state->signal->force_complete(); + } + } + pending_reads_.clear(); + } + + // If self-contained, stop the private TcpContext. + if (own_ctx_) + own_ctx_->shutdown(); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.h b/dlslime/csrc/engine/tcp/tcp_endpoint.h new file mode 100644 index 00000000..344c4901 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.h @@ -0,0 +1,133 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dlslime/csrc/common/json.hpp" +#include "dlslime/csrc/engine/assignment.h" +#include "tcp_connection_pool.h" +#include "tcp_context.h" +#include "tcp_future.h" +#include "tcp_header.h" +#include "tcp_memory_pool.h" +#include "tcp_op_state.h" +#include "tcp_session.h" + +namespace dlslime { +namespace tcp { + +using json = nlohmann::json; + +class TcpEndpoint : public std::enable_shared_from_this { +public: + static constexpr int64_t kDefaultTimeoutMs = 30000; + + // Self-contained: creates its own TcpContext + io_thread. + explicit TcpEndpoint(uint16_t port = 0); + + // Shared context: multiple endpoints share one io_context thread. + TcpEndpoint(TcpContext& ctx, uint16_t port = 0); + + ~TcpEndpoint(); + + TcpEndpoint(const TcpEndpoint&) = delete; + TcpEndpoint& operator=(const TcpEndpoint&) = delete; + + // ── Connection ────────────────────────────────────── + json endpoint_info() const; + void connect(const json& remote_info); + void shutdown(); + + // ── Memory ────────────────────────────────────────── + int32_t register_memory_region(const std::string& name, + uintptr_t ptr, uintptr_t offset, size_t length); + int32_t register_remote_memory_region(const std::string& name, + const json& mr_info); + json mr_info() const; + + // ── Async I/O (all return Future; I/O on io_context thread) ── + + std::shared_ptr async_send( + const chunk_tuple_t& chunk, + int64_t timeout_ms = kDefaultTimeoutMs, + void* stream = nullptr); + + std::shared_ptr async_recv( + const chunk_tuple_t& chunk, + int64_t timeout_ms = kDefaultTimeoutMs, + void* stream = nullptr); + + std::shared_ptr async_read( + const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs, + void* stream = nullptr); + + std::shared_ptr async_write( + const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs, + void* stream = nullptr); + + // ── Accessors ─────────────────────────────────────── + void setId(int64_t id) { id_.store(id, std::memory_order_relaxed); } + int64_t getId() const { return id_.load(std::memory_order_relaxed); } + bool is_connected() const { return connected_.load(std::memory_order_acquire); } + +private: + // ── io_context management ─────────────────────────── + void start_io(); + void do_accept(); + ServerSession::RecvMatcher make_recv_matcher(); + + // ── helpers ───────────────────────────────────────── + bool is_initiator(const std::string& peer_host, uint16_t peer_port) const; + bool write_message(asio::ip::tcp::socket& sock, + const SessionHeader& hdr, const void* payload); + static void set_sndtimeo(int fd, int64_t ms); + + // ── identity ──────────────────────────────────────── + std::atomic id_{-1}; + std::string local_host_{"0.0.0.0"}; + uint16_t local_port_{0}; + std::string peer_host_; + uint16_t peer_port_{0}; + std::atomic connected_{false}; + + // ── asio core ─────────────────────────────────────── + TcpContext* ctx_{nullptr}; + std::unique_ptr own_ctx_; // if self-contained + asio::ip::tcp::acceptor acceptor_; + std::atomic running_{true}; + + // ── memory ────────────────────────────────────────── + std::shared_ptr local_pool_; + std::shared_ptr remote_pool_; + + // ── recv matching ─────────────────────────────────── + struct PendingRecv { + std::shared_ptr op_state; + }; + std::mutex recv_mu_; + std::deque pending_recvs_; + + // ── read matching (connections reserved for response) ── + struct PendingRead { + std::shared_ptr conn; + std::shared_ptr op_state; + }; + std::mutex read_mu_; + std::unordered_map pending_reads_; + std::atomic next_req_id_{1}; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_future.h b/dlslime/csrc/engine/tcp/tcp_future.h new file mode 100644 index 00000000..dcf53a4e --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_future.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include + +#include "dlslime/csrc/common/pause.h" +#include "dlslime/csrc/device/device_future.h" +#include "tcp_op_state.h" + +namespace dlslime { +namespace tcp { + +class TcpFuture : public DeviceFuture { +public: + explicit TcpFuture(std::shared_ptr op) + : op_state_(std::move(op)) { + if (!op_state_) + throw std::runtime_error("TcpFuture: null op_state"); + } + + int32_t wait() const override { + if (op_state_->signal) + op_state_->signal->wait_comm_done_cpu(op_state_->expected_mask); + return op_state_->completion_status.load(std::memory_order_acquire); + } + + // Block up to timeout_ms milliseconds. Returns true iff completed; + // writes status to *out. On timeout the operation is still in-flight. + bool wait_for(int64_t timeout_ms, int32_t* out) const { + auto deadline = std::chrono::steady_clock::now() + + std::chrono::milliseconds(timeout_ms); + while (true) { + if (op_state_->signal) { + uint32_t m = op_state_->signal->get_comm_done_mask(); + if ((m & op_state_->expected_mask) == op_state_->expected_mask) { + if (out) *out = op_state_->completion_status.load( + std::memory_order_acquire); + return true; + } + } + if (std::chrono::steady_clock::now() >= deadline) { + if (op_state_->signal) { + uint32_t m = op_state_->signal->get_comm_done_mask(); + if ((m & op_state_->expected_mask) == op_state_->expected_mask) { + if (out) *out = op_state_->completion_status.load( + std::memory_order_acquire); + return true; + } + } + return false; + } + machnet_pause(); + } + } + +protected: + std::shared_ptr op_state_; +}; + +class TcpSendFuture : public TcpFuture { public: using TcpFuture::TcpFuture; }; +class TcpRecvFuture : public TcpFuture { public: using TcpFuture::TcpFuture; }; +class TcpReadWriteFuture : public TcpFuture { public: using TcpFuture::TcpFuture; }; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_header.h b/dlslime/csrc/engine/tcp/tcp_header.h new file mode 100644 index 00000000..313187d6 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_header.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace dlslime { +namespace tcp { + +// 17-byte wire header, referenced from Mooncake SessionHeader. +// offset size field +// 0 8 size payload byte count (htole64 / le64toh) +// 8 8 addr remote buffer virtual address +// 16 1 opcode SEND=0x00 READ=0x01 WRITE=0x02 + +#pragma pack(push, 1) +struct SessionHeader { + uint64_t size; + uint64_t addr; + uint8_t opcode; +}; +#pragma pack(pop) + +static_assert(sizeof(SessionHeader) == 17, "SessionHeader must be 17 bytes"); + +enum OpCode : uint8_t { + OP_SEND = 0x00, // header + payload → peer recv matches + OP_READ = 0x01, // header only → peer reads local memory → sends data back + OP_WRITE = 0x02, // header + payload → peer writes to local memory +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp b/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp new file mode 100644 index 00000000..1b540775 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp @@ -0,0 +1,128 @@ +#include "tcp_memory_pool.h" + +namespace dlslime { +namespace tcp { + +// ── local MR ──────────────────────────────────────────── + +int32_t TcpMemoryPool::register_memory_region( + uintptr_t addr, uint64_t offset, size_t length, + std::optional name) { + + auto pit = ptr_to_handle_.find(addr); + if (pit != ptr_to_handle_.end()) { + int32_t h = pit->second; + if (h >= 0 && static_cast(h) < handle_to_mr_.size() + && handle_to_mr_[h].addr == addr + && handle_to_mr_[h].length >= length) { + if (name.has_value()) name_to_handle_[*name] = h; + return h; + } + } + + int32_t h = static_cast(handle_to_mr_.size()); + handle_to_mr_.push_back({addr, offset, length}); + handle_to_name_.push_back(name.value_or("")); + ptr_to_handle_[addr] = h; + if (name.has_value()) name_to_handle_[*name] = h; + return h; +} + +int32_t TcpMemoryPool::unregister_memory_region(int32_t handle) { + if (handle < 0 || static_cast(handle) >= handle_to_mr_.size()) + return -1; + auto& mr = handle_to_mr_[handle]; + auto& s = handle_to_name_[handle]; + ptr_to_handle_.erase(mr.addr); + if (!s.empty()) name_to_handle_.erase(s); + mr = {}; + s.clear(); + return 0; +} + +// ── remote MR ─────────────────────────────────────────── + +int32_t TcpMemoryPool::register_remote_memory_region( + const json& mr_info, std::optional name) { + + std::string mr_name = name.value_or(mr_info.value("name", "")); + + if (!mr_name.empty()) { + auto it = remote_name_to_handle_.find(mr_name); + if (it != remote_name_to_handle_.end()) { + int32_t h = it->second; + auto& rm = remote_handle_to_mr_[h]; + rm.addr = mr_info.value("addr", 0UL); + rm.offset = mr_info.value("offset", 0UL); + rm.length = mr_info.value("length", 0UL); + return h; + } + } + + int32_t h = static_cast(remote_handle_to_mr_.size()); + remote_handle_to_mr_.push_back({ + mr_info.value("addr", 0UL), + mr_info.value("offset", 0UL), + mr_info.value("length", 0UL) + }); + remote_handle_to_name_.push_back(mr_name); + if (!mr_name.empty()) remote_name_to_handle_[mr_name] = h; + return h; +} + +int32_t TcpMemoryPool::unregister_remote_memory_region(int32_t handle) { + if (handle < 0 || static_cast(handle) >= remote_handle_to_mr_.size()) + return -1; + auto& s = remote_handle_to_name_[handle]; + if (!s.empty()) remote_name_to_handle_.erase(s); + remote_handle_to_mr_[handle] = {}; + s.clear(); + return 0; +} + +// ── fast lookup ───────────────────────────────────────── + +TcpMr TcpMemoryPool::get_mr_fast(int32_t handle) const { + if (handle < 0 || static_cast(handle) >= handle_to_mr_.size()) + return {}; + return handle_to_mr_[handle]; +} + +TcpMr TcpMemoryPool::get_remote_mr_fast(int32_t handle) const { + if (handle < 0 || static_cast(handle) >= remote_handle_to_mr_.size()) + return {}; + return remote_handle_to_mr_[handle]; +} + +int32_t TcpMemoryPool::get_mr_handle(const std::string& name) const { + auto it = name_to_handle_.find(name); + return it != name_to_handle_.end() ? it->second : -1; +} + +int32_t TcpMemoryPool::get_remote_mr_handle(const std::string& name) const { + auto it = remote_name_to_handle_.find(name); + return it != remote_name_to_handle_.end() ? it->second : -1; +} + +// ── serialization ─────────────────────────────────────── + +json TcpMemoryPool::mr_info() const { + json j = json::object(); + for (const auto& [name, h] : name_to_handle_) + if (h >= 0 && static_cast(h) < handle_to_mr_.size() + && handle_to_mr_[h].length > 0) + j[name] = handle_to_mr_[h].json_info(name); + return j; +} + +json TcpMemoryPool::remote_mr_info() const { + json j = json::object(); + for (const auto& [name, h] : remote_name_to_handle_) + if (h >= 0 && static_cast(h) < remote_handle_to_mr_.size() + && remote_handle_to_mr_[h].length > 0) + j[name] = remote_handle_to_mr_[h].json_info(name); + return j; +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_memory_pool.h b/dlslime/csrc/engine/tcp/tcp_memory_pool.h new file mode 100644 index 00000000..249f30cb --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_memory_pool.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "dlslime/csrc/common/json.hpp" + +namespace dlslime { +namespace tcp { + +using json = nlohmann::json; + +struct TcpMr { + uintptr_t addr{0}; + uint64_t offset{0}; + size_t length{0}; + + json json_info(const std::string& name) const { + return {{"name", name}, {"addr", addr}, + {"offset", offset}, {"length", length}}; + } +}; + +// Pure-bookkeeping pool. No hardware registration needed for TCP. +class TcpMemoryPool { +public: + TcpMemoryPool() = default; + + int32_t register_memory_region(uintptr_t addr, uint64_t offset, + size_t length, + std::optional name = std::nullopt); + int32_t unregister_memory_region(int32_t handle); + + int32_t register_remote_memory_region(const json& mr_info, + std::optional name = std::nullopt); + int32_t unregister_remote_memory_region(int32_t handle); + + TcpMr get_mr_fast(int32_t handle) const; + TcpMr get_remote_mr_fast(int32_t handle) const; + int32_t get_mr_handle(const std::string& name) const; + int32_t get_remote_mr_handle(const std::string& name) const; + + json mr_info() const; + json remote_mr_info() const; + +private: + // local MRs + std::unordered_map name_to_handle_; + std::unordered_map ptr_to_handle_; + std::vector handle_to_mr_; + std::vector handle_to_name_; + + // remote MRs + std::unordered_map remote_name_to_handle_; + std::vector remote_handle_to_mr_; + std::vector remote_handle_to_name_; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_op_state.h b/dlslime/csrc/engine/tcp/tcp_op_state.h new file mode 100644 index 00000000..dbf89a2a --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_op_state.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include + +#include "dlslime/csrc/device/device_api.h" +#include "dlslime/csrc/device/signal.h" + +namespace dlslime { +namespace tcp { + +enum Status : int32_t { + TCP_SUCCESS = 0, + TCP_FAILED = -1, + TCP_TIMEOUT = -2, + TCP_CLOSED = -3, +}; + +// One per in-flight operation. The io_context thread (or caller for sync +// ops) signals completion via DeviceSignal; the future's wait() spins on +// wait_comm_done_cpu(). +struct TcpOpState { + std::shared_ptr signal; + uint32_t expected_mask{1}; + std::atomic completion_mask{0}; + std::atomic completion_status{TCP_SUCCESS}; + + uintptr_t user_buffer{0}; + size_t user_length{0}; + size_t bytes_copied{0}; + + static std::shared_ptr create() { + auto s = std::make_shared(); + s->signal = dlslime::device::createSignal(false); + return s; + } +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_session.cpp b/dlslime/csrc/engine/tcp/tcp_session.cpp new file mode 100644 index 00000000..57e2d13a --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_session.cpp @@ -0,0 +1,142 @@ +#include "tcp_session.h" + +#include +#include +#include + +#include +#include + +#include "dlslime/csrc/logging.h" + +namespace dlslime { +namespace tcp { + +// ── helpers ───────────────────────────────────────────── + +static void hdr_to_net(SessionHeader& hdr) { + hdr.size = htole64(hdr.size); + hdr.addr = htole64(hdr.addr); +} + +static void hdr_to_host(SessionHeader& hdr) { + hdr.size = le64toh(hdr.size); + hdr.addr = le64toh(hdr.addr); +} + +static bool is_fatal(asio::error_code ec) { + return ec && ec != asio::error::eof; +} + +// ── ServerSession ─────────────────────────────────────── + +ServerSession::ServerSession(asio::ip::tcp::socket socket, + TcpMemoryPool* local_pool, + RecvMatcher recv_matcher) + : socket_(std::move(socket)) + , local_pool_(local_pool) + , recv_matcher_(std::move(recv_matcher)) {} + +void ServerSession::start() { + readHeader(); +} + +void ServerSession::readHeader() { + auto self = shared_from_this(); + header_ = {}; + asio::async_read(socket_, asio::buffer(&header_, sizeof(header_)), + [this, self](asio::error_code ec, size_t /*n*/) { + if (ec) { + if (is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::readHeader ", ec.message()); + return; // connection closed, session ends + } + hdr_to_host(header_); + transferred_ = 0; + dispatch(); + }); +} + +void ServerSession::dispatch() { + switch (header_.opcode) { + case OP_SEND: + if (header_.size == 0) { readHeader(); return; } + chunk_buf_.resize(header_.size); + readBody(header_.size); + break; + + case OP_WRITE: + if (header_.size == 0) { readHeader(); return; } + chunk_buf_.resize(header_.size); + readBody(header_.size); + break; + + case OP_READ: { + uintptr_t addr = static_cast(header_.addr); + size_t sz = static_cast(header_.size); + if (sz == 0) { readHeader(); return; } + // Write back raw data — no header on the response. + auto self = shared_from_this(); + asio::async_write(socket_, + asio::buffer(reinterpret_cast(addr), sz), + [this, self](asio::error_code ec, size_t /*n*/) { + if (ec && is_fatal(ec)) + SLIME_LOG_WARN("ServerSession READ response ", ec.message()); + readHeader(); + }); + break; + } + + default: + SLIME_LOG_WARN("ServerSession: unknown opcode ", + static_cast(header_.opcode)); + readHeader(); + break; + } +} + +void ServerSession::readBody(uint64_t remaining) { + auto self = shared_from_this(); + size_t chunk = std::min(static_cast(remaining), kDefaultChunkSize); + + if (chunk == 0) { + if (header_.opcode == OP_SEND) { + auto slot = recv_matcher_(); + if (slot.buffer && slot.length > 0) { + size_t n = std::min(static_cast(header_.size), + slot.length); + std::memcpy(reinterpret_cast(slot.buffer), + chunk_buf_.data(), n); + if (slot.op_state) { + slot.op_state->bytes_copied = n; + slot.op_state->completion_status.store( + TCP_SUCCESS, std::memory_order_release); + if (slot.op_state->signal) + slot.op_state->signal->set_comm_done(0); + } + } + } else if (header_.opcode == OP_WRITE) { + uintptr_t addr = static_cast(header_.addr); + std::memcpy(reinterpret_cast(addr), + chunk_buf_.data(), header_.size); + } + readHeader(); + return; + } + + size_t offset = transferred_; + asio::async_read(socket_, + asio::buffer(chunk_buf_.data() + offset, chunk), + [this, self, remaining](asio::error_code ec, size_t n) { + if (ec) { + if (is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::readBody ", ec.message()); + return; + } + transferred_ += n; + readBody(remaining - n); + }); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_session.h b/dlslime/csrc/engine/tcp/tcp_session.h new file mode 100644 index 00000000..470cb186 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_session.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "tcp_header.h" +#include "tcp_memory_pool.h" +#include "tcp_op_state.h" + +namespace dlslime { +namespace tcp { + +class TcpConnectionPool; + +constexpr size_t kDefaultChunkSize = 65536; // 64KB + +// ── RecvSlot: returned by RecvMatcher when a SEND matches a pending recv ── +struct RecvSlot { + uintptr_t buffer{0}; + size_t length{0}; + std::shared_ptr op_state; +}; + +// ── ServerSession: handles incoming requests on one connection ── +// +// Lifecycle: start() → readHeader → dispatch → readBody/writeBody ↻ +// Persistent — one session handles many transfers on the same connection. +// Referenced from Mooncake ServerSession. +class ServerSession : public std::enable_shared_from_this { +public: + using RecvMatcher = std::function; + + ServerSession(asio::ip::tcp::socket socket, + TcpMemoryPool* local_pool, + RecvMatcher recv_matcher); + + void start(); + +private: + void readHeader(); + void dispatch(); + void readBody(uint64_t remaining); + + asio::ip::tcp::socket socket_; + TcpMemoryPool* local_pool_; + RecvMatcher recv_matcher_; + SessionHeader header_{}; + uint64_t transferred_{0}; + std::vector chunk_buf_; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py new file mode 100644 index 00000000..4d6f85b2 --- /dev/null +++ b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py @@ -0,0 +1,210 @@ +"""End-to-end test for TcpEndpoint v3 async primitives with timeout. + +Usage: + LD_LIBRARY_PATH=dlslime PYTHONPATH=. DLSLIME_LOG_LEVEL=0 python3 \ + dlslime/csrc/engine/tcp/test_tcp_endpoint.py +""" + +import ctypes +import threading +import time + +from dlslime import TcpEndpoint, TcpMemoryPool + + +def _sync_run(fn_a, fn_b): + b = threading.Barrier(2) + ta = threading.Thread(target=lambda: (b.wait(), fn_a()), daemon=True) + tb = threading.Thread(target=lambda: (b.wait(), fn_b()), daemon=True) + ta.start(); tb.start() + ta.join(); tb.join() + + +def test_async_send_recv(): + """Two endpoints async_send/async_recv each other.""" + print("=== test_async_send_recv ===") + + buf_a = ctypes.create_string_buffer(4096) + buf_b = ctypes.create_string_buffer(4096) + + ep_a = TcpEndpoint(10001) + ep_b = TcpEndpoint(10002) + h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 0, 4096) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 4096) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + + def run_a(): + ep_a.connect(info_b) + print(" A connected") + ctypes.memmove(ctypes.addressof(buf_a), b"hello", 5) + st = ep_a.async_send((h_a, 0, 5)).wait() + assert st == 0, f"send failed: {st}" + print(" A sent 5 bytes") + st = ep_a.async_recv((h_a, 0, 5)).wait() + assert st == 0, f"recv failed: {st}" + assert bytes(buf_a[:5]) == b"world" + print(" A recv'd: world") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + print(" B connected") + st = ep_b.async_recv((h_b, 0, 5)).wait() + assert st == 0 and bytes(buf_b[:5]) == b"hello" + print(" B recv'd: hello") + ctypes.memmove(ctypes.addressof(buf_b), b"world", 5) + st = ep_b.async_send((h_b, 0, 5)).wait() + assert st == 0 + print(" B sent 5 bytes") + ep_b.shutdown() + + _sync_run(run_a, run_b) + print(" PASSED\n") + + +def test_async_write_read(): + """A writes to B's buffer, then reads from B's buffer.""" + print("=== test_async_write_read ===") + + buf_a = ctypes.create_string_buffer(4096) + buf_b = ctypes.create_string_buffer(4096) + addr_a = ctypes.addressof(buf_a) + + ep_a = TcpEndpoint(0) + ep_b = TcpEndpoint(0) + + h_a = ep_a.register_memory_region("a", addr_a, 0, 4096) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 4096) + + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + + h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + + test_data = b"hello_from_a" + + def run_a(): + ep_a.connect(info_b) + print(" A connected") + ctypes.memmove(addr_a, test_data, len(test_data)) + st = ep_a.async_write([(h_a, h_br, 0, 0, len(test_data))]).wait() + assert st == 0, f"write failed: {st}" + print(f" A wrote {len(test_data)} bytes to B") + time.sleep(0.1) + st = ep_a.async_read([(h_a, h_br, 0, 0, len(test_data))]).wait() + assert st == 0 and bytes(buf_a[:len(test_data)]) == test_data + print(f" A read from B: {bytes(buf_a[:len(test_data)])}") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + print(" B connected") + time.sleep(0.2) + for _ in range(50): + if bytes(buf_b[:len(test_data)]) == test_data: + break + time.sleep(0.01) + assert bytes(buf_b[:len(test_data)]) == test_data + print(f" B buffer verified") + ep_b.shutdown() + + _sync_run(run_a, run_b) + print(" PASSED\n") + + +def test_recv_timeout(): + """recv times out when peer never sends.""" + print("=== test_recv_timeout ===") + + buf_a = ctypes.create_string_buffer(64) + + ep_a = TcpEndpoint(10003) + h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 0, 64) + ep_b = TcpEndpoint(10004) + + def run_b(): + ep_b.connect(ep_a.endpoint_info()) + time.sleep(1.5) + ep_b.shutdown() + + def run_a(): + ep_a.connect(ep_b.endpoint_info()) + fut = ep_a.async_recv((h_a, 0, 5), timeout_ms=300) + result = fut.wait_for(0.3) + print(f" recv wait_for(0.3s): {result} (expected None)") + assert result is None, f"Expected None (timeout), got {result}" + ep_a.shutdown() + + _sync_run(run_a, run_b) + print(" PASSED\n") + + +def test_send_timeout_ms(): + """async_send accepts timeout_ms parameter.""" + print("=== test_send_timeout_ms ===") + + buf_a = ctypes.create_string_buffer(256) + buf_b = ctypes.create_string_buffer(256) + + ep_a = TcpEndpoint(10005) + ep_b = TcpEndpoint(10006) + h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 0, 256) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 256) + + def run_b(): + ep_b.connect(ep_a.endpoint_info()) + st = ep_b.async_recv((h_b, 0, 5)).wait() + assert st == 0 + ep_b.shutdown() + + def run_a(): + ep_a.connect(ep_b.endpoint_info()) + ctypes.memmove(ctypes.addressof(buf_a), b"world", 5) + st = ep_a.async_send((h_a, 0, 5), timeout_ms=10000).wait() + assert st == 0, f"send timeout_ms=10000 failed: {st}" + print(f" async_send with timeout_ms=10000: status={st}") + ep_a.shutdown() + + _sync_run(run_a, run_b) + print(" PASSED\n") + + +def test_default_timeout(): + """async_send uses kDefaultTimeoutMs=30000 when timeout_ms not given.""" + print("=== test_default_timeout ===") + + buf_a = ctypes.create_string_buffer(128) + buf_b = ctypes.create_string_buffer(128) + + ep_a = TcpEndpoint(10007) + ep_b = TcpEndpoint(10008) + h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 0, 128) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 128) + + def run_b(): + ep_b.connect(ep_a.endpoint_info()) + st = ep_b.async_recv((h_b, 0, 5)).wait() + assert st == 0 + ep_b.shutdown() + + def run_a(): + ep_a.connect(ep_b.endpoint_info()) + ctypes.memmove(ctypes.addressof(buf_a), b"test!", 5) + # No timeout_ms arg — uses default 30000ms + st = ep_a.async_send((h_a, 0, 5)).wait() + assert st == 0, f"default timeout send failed: {st}" + print(f" async_send with default timeout: status={st}") + ep_a.shutdown() + + _sync_run(run_a, run_b) + print(" PASSED\n") + + +if __name__ == "__main__": + test_async_send_recv() + test_async_write_read() + test_recv_timeout() + test_send_timeout_ms() + test_default_timeout() + print("All TcpEndpoint v3 tests passed!") diff --git a/dlslime/csrc/python/CMakeLists.txt b/dlslime/csrc/python/CMakeLists.txt index 389be03a..1584babc 100755 --- a/dlslime/csrc/python/CMakeLists.txt +++ b/dlslime/csrc/python/CMakeLists.txt @@ -67,6 +67,11 @@ if (BUILD_ASCEND_DIRECT) ) endif() +if (BUILD_TCP) + target_compile_definitions(_slime_c PRIVATE BUILD_TCP) + list(APPEND _slime_c_link_libraries _slime_tcp) +endif() + # Ops moved to NanoCCL - link to NanoCCL if needed # if (BUILD_INTRA_OPS OR BUILD_INTER_OPS) # if (BUILD_INTRA_OPS) diff --git a/dlslime/csrc/python/bind.cpp b/dlslime/csrc/python/bind.cpp index b5a56591..0b1efc90 100644 --- a/dlslime/csrc/python/bind.cpp +++ b/dlslime/csrc/python/bind.cpp @@ -27,6 +27,12 @@ #include "dlslime/csrc/engine/ascend_direct/ascend_remote_memory_pool.h" #endif +#ifdef BUILD_TCP +#include "dlslime/csrc/engine/tcp/tcp_endpoint.h" +#include "dlslime/csrc/engine/tcp/tcp_future.h" +#include "dlslime/csrc/engine/tcp/tcp_memory_pool.h" +#endif + #include "dlslime/csrc/device/signal.h" #ifdef BUILD_RDMA @@ -89,6 +95,12 @@ namespace py = pybind11; #define BUILD_RPC_ENABLED false #endif +#ifdef BUILD_TCP +#define BUILD_TCP_ENABLED true +#else +#define BUILD_TCP_ENABLED false +#endif + // Ops moved to NanoCCL #define BUILD_INTRA_OPS_ENABLED false #define BUILD_INTER_OPS_ENABLED false @@ -102,6 +114,7 @@ PYBIND11_MODULE(_slime_c, m) EXPOSE_BUILD_FLAG(m, BUILD_INTRA_OPS); EXPOSE_BUILD_FLAG(m, BUILD_INTER_OPS); EXPOSE_BUILD_FLAG(m, BUILD_RPC); + EXPOSE_BUILD_FLAG(m, BUILD_TCP); m.def("discover_topology", &dlslime::topology::discoverTopology, @@ -512,6 +525,122 @@ PYBIND11_MODULE(_slime_c, m) py::call_guard()); #endif +#ifdef BUILD_TCP + // ========================================================================= + // TCP Transport + // ========================================================================= + py::class_>( + m, "SlimeTcpSendFuture") + .def("wait", &dlslime::tcp::TcpSendFuture::wait, + py::call_guard()) + .def("wait_for", + [](const dlslime::tcp::TcpSendFuture& self, double sec) -> py::object { + int32_t st = 0; + int64_t ms = static_cast(sec * 1000.0); + if (ms < 0) ms = 0; + if (self.wait_for(ms, &st)) return py::cast(st); + return py::none(); + }, + py::arg("timeout_seconds")); + + py::class_>( + m, "SlimeTcpRecvFuture") + .def("wait", &dlslime::tcp::TcpRecvFuture::wait, + py::call_guard()) + .def("wait_for", + [](const dlslime::tcp::TcpRecvFuture& self, double sec) -> py::object { + int32_t st = 0; + int64_t ms = static_cast(sec * 1000.0); + if (ms < 0) ms = 0; + if (self.wait_for(ms, &st)) return py::cast(st); + return py::none(); + }, + py::arg("timeout_seconds")); + + py::class_>( + m, "SlimeTcpReadWriteFuture") + .def("wait", &dlslime::tcp::TcpReadWriteFuture::wait, + py::call_guard()) + .def("wait_for", + [](const dlslime::tcp::TcpReadWriteFuture& self, double sec) -> py::object { + int32_t st = 0; + int64_t ms = static_cast(sec * 1000.0); + if (ms < 0) ms = 0; + if (self.wait_for(ms, &st)) return py::cast(st); + return py::none(); + }, + py::arg("timeout_seconds")); + + py::class_>( + m, "TcpMemoryPool") + .def(py::init<>()) + .def("register_memory_region", + [](dlslime::tcp::TcpMemoryPool& self, uintptr_t addr, uint64_t offset, + size_t length, py::object name_obj) { + std::optional name; + if (!name_obj.is_none()) name = name_obj.cast(); + return self.register_memory_region(addr, offset, length, name); + }, + py::arg("addr"), py::arg("offset"), py::arg("length"), + py::arg("name") = py::none()) + .def("register_remote_memory_region", + [](dlslime::tcp::TcpMemoryPool& self, const json& mr_info, + py::object name_obj) { + std::optional name; + if (!name_obj.is_none()) name = name_obj.cast(); + return self.register_remote_memory_region(mr_info, name); + }, + py::arg("mr_info"), py::arg("name") = py::none()) + .def("get_mr_handle", &dlslime::tcp::TcpMemoryPool::get_mr_handle, + py::arg("name")) + .def("mr_info", &dlslime::tcp::TcpMemoryPool::mr_info); + + py::class_>( + m, "TcpEndpoint") + .def(py::init(), py::arg("port") = 0) + .def("connect", &dlslime::tcp::TcpEndpoint::connect, + py::arg("remote_info"), py::call_guard()) + .def("endpoint_info", &dlslime::tcp::TcpEndpoint::endpoint_info) + .def("mr_info", &dlslime::tcp::TcpEndpoint::mr_info) + .def("shutdown", &dlslime::tcp::TcpEndpoint::shutdown, + py::call_guard()) + .def("register_memory_region", + &dlslime::tcp::TcpEndpoint::register_memory_region, + py::arg("name"), py::arg("data_ptr"), py::arg("offset"), py::arg("length"), + py::call_guard()) + .def("register_remote_memory_region", + &dlslime::tcp::TcpEndpoint::register_remote_memory_region, + py::arg("name"), py::arg("mr_info"), py::call_guard()) + .def("async_send", + py::overload_cast( + &dlslime::tcp::TcpEndpoint::async_send), + py::arg("chunk"), + py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, + py::arg("stream") = nullptr, + py::call_guard()) + .def("async_recv", + py::overload_cast( + &dlslime::tcp::TcpEndpoint::async_recv), + py::arg("chunk"), + py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, + py::arg("stream") = nullptr, + py::call_guard()) + .def("async_read", + py::overload_cast&, int64_t, void*>( + &dlslime::tcp::TcpEndpoint::async_read), + py::arg("assign"), + py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, + py::arg("stream") = nullptr, + py::call_guard()) + .def("async_write", + py::overload_cast&, int64_t, void*>( + &dlslime::tcp::TcpEndpoint::async_write), + py::arg("assign"), + py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, + py::arg("stream") = nullptr, + py::call_guard()); +#endif // BUILD_TCP + // Ops moved to NanoCCL - Python bindings should be in NanoCCL's Python module // #ifdef BUILD_INTRA_OPS // py::class_(m, "AllToAllIntraLLBuffer") From 4211aca3b76b11d2e85df2ff1877f17d0f0a93c4 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 14 May 2026 14:11:39 +0000 Subject: [PATCH 02/10] clean up TcpEndpoint API: remove dead void* stream and unused timeout_ms - Remove void* stream from all 4 async_* methods (RDMA leftover, never used) - Remove timeout_ms from async_recv (recv timeout via future.wait_for()) - Remove ineffective SO_SNDTIMEO calls (no effect on asio::async_write) - Update pybind11 bindings and tests to match - Add tcp/plan.md with v3 architecture documentation Co-Authored-By: Claude Opus 4.7 --- dlslime/csrc/engine/tcp/plan.md | 731 +++++++++++++++++++ dlslime/csrc/engine/tcp/tcp_endpoint.cpp | 36 +- dlslime/csrc/engine/tcp/tcp_endpoint.h | 31 +- dlslime/csrc/engine/tcp/test_tcp_endpoint.py | 2 +- dlslime/csrc/python/bind.cpp | 14 +- 5 files changed, 757 insertions(+), 57 deletions(-) create mode 100755 dlslime/csrc/engine/tcp/plan.md diff --git a/dlslime/csrc/engine/tcp/plan.md b/dlslime/csrc/engine/tcp/plan.md new file mode 100755 index 00000000..720ed8e4 --- /dev/null +++ b/dlslime/csrc/engine/tcp/plan.md @@ -0,0 +1,731 @@ +# DLSlime TcpEndpoint v3 Primitives 架构与实现计划 + +**分支**: `tcp-v3` | **基准**: `main` | **日期**: 2026-05-14 + +--- + +## 1. 架构设计 + +### 1.1 总体架构 + +``` +┌──────────────────────────────────────────────────────────────┐ +│ Python 调用者线程 │ +│ ep.async_send(chunk, timeout_ms=30000) → Future │ +│ ep.async_recv(chunk, timeout_ms=30000) → Future │ +│ ep.async_read(assign, timeout_ms=30000) → Future │ +│ ep.async_write(assign, timeout_ms=30000) → Future │ +│ │ │ +│ │ post lambda │ +│ ▼ │ +│ ┌──────────────────────┐ ┌─────────────────────────────┐ │ +│ │ asio::io_context │ │ TcpConnectionPool │ │ +│ │ (单后台线程) │◄───│ (host, port) → deque │ │ +│ │ │ │ IDLE / ACTIVE / RESERVED │ │ +│ │ async_write ────────┼───►│ 60s 空闲超时 │ │ +│ │ async_read ◄────────┼───►│ │ │ +│ │ async_accept ───────┼───►│ ServerSession │ │ +│ │ │ │ (readHeader→dispatch→ │ │ +│ │ │ │ readBody→readHeader 循环) │ │ +│ └──────────────────────┘ └─────────────────────────────┘ │ +└──────────────────────────────────────────────────────────────┘ +``` + +### 1.2 线程模型 + +| 角色 | 线程 | 职责 | +|------|------|------| +| io_context | 1 个 daemon 线程 | `io_ctx_.run()` — 所有 asio async I/O 回调 | +| 调用者 | N 个 Python 线程 | 调 async_* → 立即返回 Future;wait() 自旋阻塞 | +| accept | io_context | `async_accept` 回调链,每连接创建 ServerSession | + +### 1.3 asio 操作模型 + +``` +调用者线程 io_context 线程 +────────── ────────────── +async_send(chunk, 5000): + ├─ getConnection() [sync, fast] ┌─ async_write(header+payload) + ├─ SO_SNDTIMEO=5s │ → 归还连接 → signal op_state + ├─ asio::post(lambda) ──────────────► │ + └─ return Future ◄─── signal ────────┘ + +async_recv(chunk, 5000): + ├─ pending_recvs_.push(op_state) ┌─ ServerSession::dispatch(OP_SEND) + └─ return Future │ → pop pending_recvs_ + │ │ → memcpy → signal op_state + └── wait_for(5.0) ── timeout? ──┘ + +async_read(assign, 5000): + ├─ getConnection() [RESERVE] ┌─ async_write(OP_READ header) + ├─ asio::post(lambda) ──────────────► │ → async_read(response data) + └─ return Future ◄─── signal ────────┘ → 归还连接 → signal op_state + +async_write(assign, 5000): + ├─ getConnection() [sync, fast] ┌─ async_write(header+payload) + ├─ SO_SNDTIMEO=5s │ → 归还连接 → signal op_state + ├─ asio::post(lambda) ──────────────► │ + └─ return Future ◄─── signal ────────┘ +``` + +--- + +## 2. 线协议设计 + +### 2.1 SessionHeader (17 字节,对齐 Mooncake) + +``` +偏移 大小 字段 +0 8 size (payload 字节数, little-endian: htole64 / le64toh) +8 8 addr (远端 buffer 虚拟地址) +16 1 opcode (操作码) +───────────────── + 17 bytes total +``` + +### 2.2 为什么 3 个 opcode 支持 4 个原语? + +OP_SEND 同时承载 `async_send`(发起方主动 push 数据)和 `async_recv`(接收方 +被动等待)。recv 方不在线上发送任何操作码——它只是向本地 `pending_recvs_` 队列注册 +一个 buffer,然后对端 ServerSession 在收到 OP_SEND 时通过 `RecvMatcher` 回调 pop +队列前端、memcpy 数据并 signal op_state。 + +这与 Mooncake 的设计一致:ServerSession::dispatch(OP_SEND) 先分块读取 payload, +然后通过 recv_matcher_ 匹配本地注册的 recv buffer。不需要独立的 recv opcode—— +SEND 到达本身就隐含了"有一端在等待"的语义。 + +OP_READ 和 OP_WRITE 各需独立 opcode,因为服务端 dispatch 分支逻辑完全不同: +- OP_READ:读取本地内存后异步写回原始数据(无 header) +- OP_WRITE:读取 payload 后 memcpy 到 hdr.addr + +如果有 4 个 opcode(比如独立的 OP_RECV),反而增加冗余——OP_RECV 在语义上等于 +"我准备好接收了",但这已在连接建立时通过 endpoint_info 交换 MR 信息隐式表达, +不需要每个操作发一次。 + +| opcode | 值 | 线格式 | 远端 ServerSession 动作 | DLSlime 原语 | +|--------|-----|--------|------------------------|-------------| +| `OP_SEND` | 0x00 | header{sz, 0, 0x00} + payload | 读 payload → recv_matcher pop → memcpy → signal | **async_send** (发起) / **async_recv** (被动) | +| `OP_READ` | 0x01 | 仅 header{sz, addr, 0x01} | 从本地 addr 读 sz 字节 → async_write 原始数据发回 | **async_read** (调用者 pull) | +| `OP_WRITE` | 0x02 | header{sz, addr, 0x02} + payload | 读 payload → memcpy 到本地 addr | **async_write** (调用者 push) | + +### 2.3 四个原语在线上的完整流程 + +``` +async_send(chunk): + 调用者: getConnection → post to io_ctx → return Future + io_ctx: async_write(sock, [header{OP_SEND}|payload]) + → on_complete: returnConnection → signal op_state + 对端 ServerSession: async_read(header) → dispatch(OP_SEND) + → chunk_buf_.resize → readBody 分块读 payload → recv_matcher_() + → pop pending_recv → memcpy → signal recv op_state + +async_recv(chunk): + 调用者: pending_recvs_.push({buffer, op_state}) → return Future → wait_for(timeout) + (无 opcode 在线路上 — recv 是 SEND 的被动消费方) + +async_read(assign): + 调用者: getConnection(RESERVED) → post to io_ctx → return Future + io_ctx: async_write(sock, header{OP_READ, sz, remote_addr}) + → async_read(sock, user_buffer, sz) + → on_complete: returnConnection → signal op_state + 对端 ServerSession: async_read(header) → dispatch(OP_READ) + → async_write(sock, local[addr], sz) → readHeader 继续 + +async_write(assign): + 调用者: getConnection → post to io_ctx → return Future + io_ctx: async_write(sock, [header{OP_WRITE, sz, remote_addr}|payload]) + → on_complete: returnConnection → signal op_state + 对端 ServerSession: async_read(header) → dispatch(OP_WRITE) + → chunk_buf_.resize → readBody 分块读 payload → memcpy 到 addr +``` + +--- + +## 3. 接口设计 + +### 3.1 C++ TcpEndpoint 公共接口 + +```cpp +class TcpEndpoint : public std::enable_shared_from_this { +public: + // 默认超时 30 秒 + static constexpr int64_t kDefaultTimeoutMs = 30000; + + // ── 构造 ── + + // 【主构造】每个 endpoint 内部自动创建 TcpContext, 调用者无需关心。 + // 这是最常用的场景: 一个 endpoint = 一个 peer 连接。 + explicit TcpEndpoint(uint16_t port = 0); + + // 【次构造】注入外部共享 TcpContext, 用于多 endpoint 复用单 io_context 线程 + // 的高级优化场景 (如 PeerAgent 连接 N 个 peer 时节省 N-1 个线程)。 + // 仅在明确需要跨 endpoint 共享资源时使用。 + TcpEndpoint(TcpContext& ctx, uint16_t port = 0); + + // ── 连接 ── + json endpoint_info() const; // {host, port, mr_info} + void connect(const json& remote_info); + void shutdown(); + + // ── 内存 ── + int32_t register_memory_region(const std::string& name, + uintptr_t ptr, uintptr_t offset, size_t length); + int32_t register_remote_memory_region(const std::string& name, + const json& mr_info); + json mr_info() const; + + // ── 异步通信原语 (全部返回 Future, I/O 在 io_context 线程) ── + // + // timeout_ms 由调用者通过 future.wait_for() 控制实际操作时限; + // 方法签名的 timeout_ms 仅作为 op_state 的提示值传入。 + // recv 的超时完全由 future.wait_for() 控制, 不需要 timeout_ms 参数。 + + // 双边发送 + std::shared_ptr async_send( + const chunk_tuple_t& chunk, + int64_t timeout_ms = kDefaultTimeoutMs); + + // 双边接收 (超时通过 future.wait_for()) + std::shared_ptr async_recv( + const chunk_tuple_t& chunk); + + // 单边读 + std::shared_ptr async_read( + const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs); + + // 单边写 + std::shared_ptr async_write( + const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs); + + // ── 访问器 ── + void setId(int64_t id); + int64_t getId() const; + bool is_connected() const; +}; +``` + +### 3.2 C++ TcpFuture 接口 + +```cpp +class TcpFuture : public DeviceFuture { +public: + // 无限期阻塞等待 + int32_t wait() const override; + + // 限时等待: timeout_ms 毫秒, 成功返回 true 并写 *out + // 超时返回 false (操作仍在进行, 可重试) + bool wait_for(int64_t timeout_ms, int32_t* out) const; +}; + +class TcpSendFuture : public TcpFuture { }; +class TcpRecvFuture : public TcpFuture { }; +class TcpReadWriteFuture : public TcpFuture { }; +``` + +### 3.3 Python 接口 + +```python +from dlslime import TcpEndpoint, TcpMemoryPool + +pool = TcpMemoryPool() +buf = ctypes.create_string_buffer(4096) +h = pool.register_memory_region(ctypes.addressof(buf), 0, 4096, "buf") + +ep = TcpEndpoint(port=0) # 0 = 随机端口 +info = ep.endpoint_info() # {'host': '...', 'port': N, 'mr_info': {...}} + +ep.connect(peer_info) + +# ── 异步原语, 默认 30s 超时 ── +fut = ep.async_send((h, 0, 128)) # 30s 默认超时 +fut = ep.async_send((h, 0, 128), 5000) # 5s 超时 +status = fut.wait() # 阻塞直到完成, 返回 0=成功 + +fut = ep.async_recv((h, 0, 128)) # 超时通过 future 控制 +result = fut.wait_for(3.0) # 3 秒超时, 返回 int 或 None + +fut = ep.async_read([(local_h, remote_h, 0, 0, 128)]) +fut = ep.async_write([(local_h, remote_h, 0, 0, 128)]) + +ep.shutdown() +``` + +--- + +## 4. 通信原语设计详解 + +### 4.1 async_send(chunk, timeout_ms = 30000) + +**语义**: 将本地注册内存的数据异步发送到对端。对端必须已调用 `async_recv()` 注册接收缓冲区。 + +**调用者线程**: +1. `local_pool_->get_mr_fast(mr_key)` — resolve 本地 MR +2. `conn_pool_.getConnection(peer_host_, peer_port_)` — 获取或创建 TCP 连接 +3. `TcpOpState::create()` + `signal->reset_all()` — 创建完成信号 +4. 如果 `timeout_ms > 0`: `setsockopt(fd, SO_SNDTIMEO, timeout_ms)` +5. `asio::post(io_ctx_, lambda)` — 提交到 io_context +6. 立即返回 `TcpSendFuture(op_state)` + +**io_context 线程**: +1. `hdr_hton()` — 字节序转换 header +2. `asio::async_write(sock, [header_buf, payload_buf], callback)` — gather write +3. callback: + - 如果 `timeout_ms > 0`: 恢复 `SO_SNDTIMEO = 0` + - `op->completion_status = ec ? TCP_FAILED : TCP_SUCCESS` + - `conn_pool_.returnConnection(conn)` + - `op->signal->set_comm_done(0)` + +**超时行为**: socket 写超时 → write 失败 → completion_status = TCP_FAILED。调用者 `future.wait()` 得到 -1。 + +### 4.2 async_recv(chunk, timeout_ms = 30000) + +**语义**: 注册接收意图。当对端 `async_send()` 的数据到达时,io_context 线程自动匹配并 memcpy 到注册的 buffer。 + +**调用者线程**: +1. `local_pool_->get_mr_fast(mr_key)` — resolve 本地 MR +2. `TcpOpState::create()` + 设置 `user_buffer`, `user_length` +3. `pending_recvs_.push_back({op_state})` — FIFO 入队 +4. 立即返回 `TcpRecvFuture(op_state)` + +**io_context 线程** (ServerSession::dispatch, OP_SEND 分支): +1. `readBody()` — 分块读取 payload 到 `chunk_buf_` +2. `RecvSlot slot = recv_matcher_()` — pop FIFO 前端 +3. `memcpy(slot.buffer, chunk_buf_.data(), min(payload_len, slot.length))` +4. `slot.op_state->completion_status = TCP_SUCCESS` +5. `slot.op_state->signal->set_comm_done(0)` + +**超时行为**: 调用者使用 `future.wait_for(timeout_ms)` 限时等待。超时返回 None,但 recv 保留在队列中——后续到达的 SEND 仍会完成它(调用者可重试)。 + +### 4.3 async_read(assign, timeout_ms = 30000) + +**语义**: 从对端的注册内存异步读取数据。两步异步操作:发 OP_READ header → 收原始响应数据。 + +**调用者线程**: +1. resolve local + remote MRs +2. `conn_pool_.getConnection(peer_host_, peer_port_)` — RESERVE 连接 +3. `TcpOpState::create()` + 设置 `user_buffer`, `user_length` +4. `asio::post(io_ctx_, lambda)` — 提交到 io_context +5. 立即返回 `TcpReadWriteFuture(op_state)` + +**io_context 线程**: +1. `hdr_hton()` → `asio::async_write(sock, header_buf, callback_1)` +2. callback_1: 如果写失败 → signal TCP_FAILED + returnConnection +3. `asio::async_read(sock, user_buffer_buf, callback_2)` +4. callback_2: + - `op->completion_status = ec ? TCP_FAILED : TCP_SUCCESS` + - `conn_pool_.returnConnection(conn)` + - `op->signal->set_comm_done(0)` + +**对端 ServerSession** (OP_READ 分支): +1. 从 `hdr.addr` 读取 `hdr.size` 字节本地内存 +2. `asio::async_write(sock, raw_data, callback)` — 直接写回原始数据(无 header) +3. `readHeader()` — 继续监听下个请求 + +**超时行为**: `future.wait_for(timeout_ms)`。连接在整个读取期间被 RESERVED,超时后操作继续在后台运行。 + +### 4.4 async_write(assign, timeout_ms = 30000) + +**语义**: 将本地注册内存的数据异步写入对端注册内存。 + +与 `async_send` 相同的 post+async_write 模式,区别: +- header.opcode = OP_WRITE +- header.addr = remote_addr(对端目标 buffer 地址) +- 对端 ServerSession dispatch(OP_WRITE) → readBody → memcpy 到 `hdr.addr` + +**超时行为**: 同 async_send — SO_SNDTIMEO + future.wait_for()。 + +--- + +## 5. 连接池设计 + +### 5.1 状态机 + +``` + getConnection() + [不存在] ────────────────────────► [ACTIVE] (in_use=true) + │ + returnConnection() + │ + ▼ + [IDLE] (in_use=false, 在 deque 中) ──► 60s 无使用 → cleanupIdleConnections() → 关闭 + │ + │ getConnection() 命中 + ▼ + [ACTIVE] (in_use=true, 离开 deque) +``` + +### 5.2 接口 + +```cpp +class TcpConnectionPool { + // 获取 IDLE 连接或创建新 TCP 连接 + std::shared_ptr getConnection(host, port); + + // 归还连接到 IDLE 状态 (或关闭, 如果 socket 已断开) + void returnConnection(std::shared_ptr conn); + + // 淘汰超过 kIdleTimeout (60s) 的空闲连接 + void cleanupIdleConnections(); + + // 关闭所有连接 (shutdown 时调用) + void clear(); +}; +``` + +--- + +## 6. ServerSession 设计 + +### 6.1 生命周期 + +``` +acceptor.async_accept(socket) + → ServerSession(socket, local_pool, recv_matcher) + → session->start() + → readHeader() ──────────────────────────────────────┐ + → async_read(sock, 17B header) │ + → hdr_to_host() │ + → dispatch() │ + ├─ OP_SEND: chunk_buf_.resize → readBody() │ + │ → memcpy → recv_matcher_() → signal │ + ├─ OP_WRITE: chunk_buf_.resize → readBody() │ + │ → memcpy → hdr.addr │ + └─ OP_READ: async_write(sock, local[addr]) │ + → readHeader() ──────────────────────────────────┘ +``` + +### 6.2 RecvMatcher + +```cpp +// ServerSession 持有的回调, 由 TcpEndpoint 注入 +using RecvMatcher = std::function; + +// TcpEndpoint::make_recv_matcher(): +// 返回一个 lambda, 持有 weak_ptr +// 在 recv_mu_ 下 pop pending_recvs_ 队列前端 +// 返回 {buffer, length, op_state} +``` + +--- + +## 7. 文件结构 + +### 新建文件 + +``` +dlslime/csrc/engine/tcp/ +├── CMakeLists.txt # asio 依赖 + _slime_tcp 共享库 +├── tcp_header.h # 17B SessionHeader + 3 opcodes +├── tcp_memory_pool.h/.cpp # 纯簿记 (addr, offset, length) +├── tcp_context.h/.cpp # 共享 io_context + connection_pool + thread +├── tcp_session.h/.cpp # ServerSession (accept 端) + 分块 I/O +├── tcp_connection_pool.h/.cpp # (host, port) 连接池 +├── tcp_op_state.h # 操作状态 (signal + atomic status) +├── tcp_future.h # TcpFuture 层次 (header-only) +├── tcp_endpoint.h/.cpp # TcpEndpoint: async_send/recv/read/write +├── build_and_test.sh # 一键构建+测试 +└── test_tcp_endpoint.py # Python 端到端测试 (4 用例) +``` + +### 修改文件 + +| 文件 | 变更 | +|------|------| +| `CMakeLists.txt` | `slime_option(BUILD_TCP "Build TCP transport" ON)` | +| `dlslime/csrc/engine/CMakeLists.txt` | `if(BUILD_TCP) add_subdirectory(tcp) endif()` | +| `dlslime/csrc/CMakeLists.txt` | `if(BUILD_TCP) target_link_libraries(dlslime INTERFACE _slime_tcp) endif()` | +| `dlslime/csrc/python/CMakeLists.txt` | `if(BUILD_TCP) target_compile_definitions + list(APPEND ... _slime_tcp) endif()` | +| `dlslime/csrc/python/bind.cpp` | `#ifdef BUILD_TCP` — TcpEndpoint, TcpMemoryPool, TcpFuture pybind11 bindings | + +--- + +## 8. 超时机制总结 + +| 原语 | 超时位置 | 默认值 | 实现方式 | +|------|---------|--------|---------| +| async_send | socket write | 30000ms | `setsockopt(SO_SNDTIMEO)` + `future.wait_for()` | +| async_recv | 等待数据到达 | 30000ms | `future.wait_for(timeout_ms)` — 定时自旋轮询 signal | +| async_read | 等待远端响应 | 30000ms | `future.wait_for(timeout_ms)` — 定时自旋轮询 signal | +| async_write | socket write | 30000ms | `setsockopt(SO_SNDTIMEO)` + `future.wait_for()` | + +**wait_for 实现**: +```cpp +bool TcpFuture::wait_for(int64_t timeout_ms, int32_t* out) const { + auto deadline = steady_clock::now() + milliseconds(timeout_ms); + while (true) { + if (signal->get_comm_done_mask() matches expected_mask) { + *out = completion_status; return true; + } + if (steady_clock::now() >= deadline) { + // last check before declaring timeout + if (signal->get_comm_done_mask() matches expected_mask) { + *out = completion_status; return true; + } + return false; + } + machnet_pause(); // CPU relax + } +} +``` + +--- + +## 11. 实现步骤 + +| 阶段 | 文件 | 说明 | +|------|------|------| +| 1. 分支 | `git checkout -b tcp-v3 main` | 基于 main 创建新分支 | +| 2. 头文件 | tcp_header.h, tcp_op_state.h | 17B header + 3 opcodes + op state | +| 3. 内存池 | tcp_memory_pool.h/.cpp | 纯簿记, 无硬件注册 | +| 4. Future | tcp_future.h | header-only, wait + wait_for | +| 5. Context | tcp_context.h/.cpp | 共享 io_context + connection_pool + thread | +| 6. 连接池 | tcp_connection_pool.h/.cpp | get/return/cleanup/clear | +| 7. Session | tcp_session.h/.cpp | ServerSession async_read 回调链 | +| 8. 端点 | tcp_endpoint.h/.cpp | async_send/recv/read/write | +| 9. 构建 | CMakeLists 链 + bind.cpp | BUILD_TCP + pybind11 | +| 10. 测试 | test_tcp_endpoint.py | 5 用例 + timeout 测试 | +| 11. 脚本 | build_and_test.sh | 一键构建+测试 | +| 12. 提交 | git commit | 单 commit, 清晰消息 | + +--- + +## 9. send/recv 设计深度分析 + +### 核心矛盾:RDMA vs TCP 的 send/recv 语义差异 + +RDMA 的 send/recv 是**硬件匹配**的: +- 发送方 post Send WR → 硬件从本地 buffer 取数据 → 发到对端 RQ +- 接收方 post Recv WR → 硬件在 RQ 上预置 WQE (buffer地址 + 长度) +- 硬件按**FIFO 顺序**匹配:第 N 个到达的 SEND 消费第 N 个预置的 RECV +- 如果 SEND 到达时没有 RECV → RNR NAK (Receiver Not Ready) → 发送方重试 +- 如果 SEND 数据量 > RECV buffer → 截断或报错 + +TCP **没有硬件匹配**,所有匹配逻辑必须在软件中实现。这带来了三个核心问题: + +| 问题 | RDMA 方案 | TCP 需要解决 | +|------|---------|------------| +| 匹配: 哪个 SEND 对哪个 RECV? | 硬件 RQ FIFO | 软件队列或 tag 匹配 | +| 顺序: SEND 先到还是 RECV 先到? | 硬件 RNR 重试 | 缓冲或拒绝 | +| 大小: 发送量 > 接收 buffer? | 截断/报错 | 截断或分片 | + +### 三种匹配策略 + +#### 策略 A: FIFO 队列匹配(v3 plan 默认) + +``` +recv(chunk) → pending_recvs_.push_back({buffer, op_state}) +ServerSession dispatch(OP_SEND): + payload = readBody() + slot = recv_matcher_() // pop front + memcpy(slot.buffer, payload, min(len, slot.length)) + signal slot.op_state +``` + +**优点**: 实现简单,与 RDMA 语义一致,足够支持双端点 ping-pong 通信。 +**缺点**: 严格 FIFO——调用者无法指定"这个 recv 对应后面第 N 个 send"。多 slot 场景(如 SlimeRPC 的 slotted mailbox)无法用 FIFO 区分。 + +#### 策略 B: Tag 匹配(Gloo 风格) + +``` +wire: [header{OP_SEND, sz, tag}] + payload +recv(tag, buffer) → pending_recvs_[tag].push({buffer, op_state}) +ServerSession dispatch(OP_SEND): + payload = readBody() + slot = pending_recvs_[hdr.addr_as_tag].pop() + memcpy(slot.buffer, payload) +``` + +**优点**: 灵活,支持多路复用——一个 TCP 连接可以承载多个逻辑流(如 RPC slot)。 +**缺点**: header.addr 字段被复用为 tag(牺牲了 addr 的原始语义),协议复杂度增加。 + +#### 策略 C: Slot 预注册(Gloo Buffer 风格) + +``` +每个 Pair 预先创建 N 个 slot buffer: + pair.createSendBuffer(slot=0, ptr, size) + pair.createRecvBuffer(slot=1, ptr, size) +wire: [header{OP_SEND, sz, 0, slot}] + payload +ServerSession: 直接 lookup slot → memcpy +``` + +**优点**: 零队列开销,O(1) slot 查找,SlimeRPC 天然适配。 +**缺点**: 需要预注册 slot(与当前 DLSlime MR 模型不兼容),灵活度低。 + +### 推荐策略:分层渐进 + +``` +Phase 1 (v3) — FIFO 基础: + pending_recvs_ = deque<{buffer, op_state}> + wire: header{OP_SEND, sz, addr=0} + 匹配: 严格 FIFO + 足够: 双端点 ping-pong、简单 RPC + +Phase 2 — 缓冲早到 SEND: + early_sends_ = deque<{payload_data}> + 如果 dispatch(OP_SEND) 时 pending_recvs_ 为空: + → 缓存 payload 到 early_sends_(带大小上限) + → 下次 recv() 先检查 early_sends_ 再入队 + 避免数据丢失 + +Phase 3 — Tag 匹配 (如需要): + 扩展 header: 用 2 字节 reserved 字段承载 tag + pending_recvs_ = map> + 支持多路复用 +``` + +### send/recv 与 read/write 的本质区别 + +很多人混淆 send/recv 和 write/read: + +| | send/recv | write/read | +|---|---|---| +| 语义 | **双边**:双方都需要显式操作 | **单边**:一方发起,另一方无感知 | +| 数据方向 | send=push, recv=pull (被动) | write=push to remote addr, read=pull from remote addr | +| 远端参与 | recv 方必须预先注册 buffer | 远端 ServerSession 自动处理,无需注册 | +| 寻址方式 | **无地址**(匹配决定目标 buffer) | **有地址**(header.addr 指定远端 buffer) | +| RDMA 类比 | ibv_post_send / ibv_post_recv | ibv_post_send with RDMA_WRITE/RDMA_READ | + +核心洞察:**send/recv 的"地址"是隐式的——通过匹配关系决定; +write/read 的"地址"是显式的——header.addr 直接指向远端内存。** + +这就是为什么 v3 plan 中: +- OP_SEND: header.addr = 0(不使用),通过 FIFO 匹配目标 buffer +- OP_WRITE: header.addr = remote_addr(直接指定远端目标地址) +- OP_READ: header.addr = remote_addr(直接指定远端源地址) + +### v3 实现策略 + +v3 采用策略 A(FIFO),但为策略 C(slot)预留空间: + +```cpp +// 当前: deque — 简单 FIFO +std::deque pending_recvs_; + +// Phase 3 可演进为: map — tag 匹配 +// std::unordered_map> pending_recvs_; +// 同时扩展 header: 用 reserved 字段承载 tag + +void TcpEndpoint::async_recv(const chunk_tuple_t& chunk, + int64_t timeout_ms, void*) { + // resolve MR → op_state → push to FIFO + // Phase 3: push to pending_recvs_[tag] instead +} + +ServerSession::dispatch(OP_SEND): + readBody() → chunk_buf_ + RecvSlot slot = recv_matcher_() + if (slot.buffer == 0): + // Phase 2: buffer early send to early_sends_ + return + memcpy(slot.buffer, chunk_buf_, min(payload_len, slot.length)) + signal slot.op_state +``` + +**recv timeout 语义**(区别于 socket timeout): +- SO_RCVTIMEO 是 socket 级超时(读数据超时) +- `future.wait_for()` 是**注册后等待匹配**的超时 +- 超时后 recv 保留在队列中:后续 SEND 仍可完成它(调用者可重试 wait_for) + +## 12. 验证计划 + +```bash +# 构建 +./dlslime/csrc/engine/tcp/build_and_test.sh build + +# 测试 +./dlslime/csrc/engine/tcp/build_and_test.sh test + +# 全流程 +./dlslime/csrc/engine/tcp/build_and_test.sh +``` + +**测试用例**: +1. `test_async_send_recv` — A async_send → B async_recv, B async_send → A async_recv +2. `test_async_write_read` — A async_write → B buffer, A async_read → verify +3. `test_recv_timeout` — async_recv + wait_for(0.3s) → None (无对端发送) +4. `test_send_timeout` — async_send(timeout_ms=10000) 参数 +5. `test_default_timeout` — async_send() 无参数 → 使用 30000ms 默认值 + +## 10. TcpContext 设计 — 为同步通信和资源共享做准备 + +### 使用优先级 + +TcpContext 类始终存在,ctx_ 成员始终非空。但构造方式有两种优先级: + +| 优先级 | 构造 | 场景 | 占比 | +|--------|------|------|------| +| **主** | `TcpEndpoint(port)` | 单 endpoint, 内部自动 new TcpContext | ~90% | +| **次** | `TcpEndpoint(ctx, port)` | 多 endpoint 共享 io_context 线程 | ~10% | + +**默认路径**:调用者无需感知 TcpContext——每个 endpoint 构造时内部 `make_unique()`, +自动创建 io_context + 后台线程 + 连接池。代码最简洁。 + +**高级路径**:当 PeerAgent 连接 N 个 peer 时,可手动创建一个 TcpContext 并注入到 N 个 +TcpEndpoint,将 N 个线程合并为 1 个。TcpContext 也用于测试中精确控制 io_context 生命周期。 + +两种路径不互斥——同一进程可混合使用。TcpContext 类永不删除,ctx_ 成员永不删除。 + +### TcpContext 接口 + +```cpp +class TcpContext { +public: + TcpContext(); // 创建 io_context + 启动后台线程 + ~TcpContext(); // stop + join + clear pool + + asio::io_context& io_context() { return io_ctx_; } + TcpConnectionPool& conn_pool() { return conn_pool_; } + void shutdown(); + +private: + asio::io_context io_ctx_; + std::thread io_thread_; + TcpConnectionPool conn_pool_{io_ctx_}; + bool running_{true}; +}; +``` + +### TcpEndpoint 与 TcpContext 的关系 + +```cpp +class TcpEndpoint { + // 【主构造】自包含 — 内部创建 TcpContext + explicit TcpEndpoint(uint16_t port = 0) + : own_ctx_(std::make_unique()) // 自动创建 + , acceptor_(own_ctx_->io_context()) + , ... { + ctx_ = own_ctx_.get(); // ctx_ → 内部 context + } + + // 【次构造】共享 — 注入外部 TcpContext + TcpEndpoint(TcpContext& ctx, uint16_t port = 0) + : acceptor_(ctx.io_context()) + , ... { + ctx_ = &ctx; // ctx_ → 外部 context, own_ctx_ = nullptr + } + +private: + TcpContext* ctx_{nullptr}; // 始终非空 + std::unique_ptr own_ctx_; // 仅主构造时非空 + // ... +}; +``` + +### 为同步通信预留 + +有了共享 TcpContext,同步包装器可以不依赖单个 endpoint 的事件循环: + +```cpp +// 未来 sync_send: 调 async_send + 立刻 future.wait() +std::shared_ptr sync_send(TcpEndpoint& ep, + const chunk_tuple_t& chunk, + int64_t timeout_ms = 30000) { + auto fut = ep.async_send(chunk, timeout_ms); + fut->wait(); // 阻塞调用者线程直到 io_context 完成 + return fut; +} +``` + +同步版本只是 async + wait() 的语法糖,不需要独立的底层实现。 \ No newline at end of file diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp index df0792e3..eda64004 100644 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -20,13 +20,6 @@ static void hdr_hton(SessionHeader& h) { h.addr = htole64(h.addr); } -void TcpEndpoint::set_sndtimeo(int fd, int64_t ms) { - struct timeval tv; - tv.tv_sec = static_cast(ms / 1000); - tv.tv_usec = static_cast((ms % 1000) * 1000); - setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); -} - // ── RecvMatcher factory ──────────────────────────────── ServerSession::RecvMatcher TcpEndpoint::make_recv_matcher() { @@ -173,7 +166,7 @@ bool TcpEndpoint::write_message(tcp::socket& sock, // ── async_send ────────────────────────────────────────── std::shared_ptr -TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t timeout_ms, void*) { +TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t timeout_ms) { auto mr = local_pool_->get_mr_fast(static_cast(std::get<0>(chunk))); if (mr.length == 0) throw std::runtime_error("TcpEndpoint::async_send: invalid local MR"); @@ -191,14 +184,11 @@ TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t timeout_ms, void*) { return std::make_shared(op); } - if (timeout_ms > 0) - set_sndtimeo(conn->socket.native_handle(), timeout_ms); - SessionHeader hdr{len, 0, OP_SEND}; auto& pool = ctx_->conn_pool(); std::weak_ptr weak = weak_from_this(); - asio::post(ctx_->io_context(), [weak, conn, op, hdr, src, len, timeout_ms, &pool]() { + asio::post(ctx_->io_context(), [weak, conn, op, hdr, src, len, &pool]() { auto ep = weak.lock(); if (!ep) { op->completion_status.store(TCP_CLOSED, std::memory_order_release); @@ -214,9 +204,7 @@ TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t timeout_ms, void*) { asio::buffer(reinterpret_cast(src), len) }; asio::async_write(conn->socket, bufs, - [conn, op, timeout_ms, &pool](asio::error_code ec, size_t) { - if (timeout_ms > 0 && conn->socket.is_open()) - TcpEndpoint::set_sndtimeo(conn->socket.native_handle(), 0); + [conn, op, &pool](asio::error_code ec, size_t) { op->completion_status.store( ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); if (op->signal) op->signal->set_comm_done(0); @@ -230,7 +218,7 @@ TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t timeout_ms, void*) { // ── async_recv ────────────────────────────────────────── std::shared_ptr -TcpEndpoint::async_recv(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/, void*) { +TcpEndpoint::async_recv(const chunk_tuple_t& chunk) { auto mr = local_pool_->get_mr_fast(static_cast(std::get<0>(chunk))); if (mr.length == 0) throw std::runtime_error("TcpEndpoint::async_recv: invalid local MR"); @@ -252,7 +240,7 @@ TcpEndpoint::async_recv(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/, void std::shared_ptr TcpEndpoint::async_read(const std::vector& assign, - int64_t /*timeout_ms*/, void*) { + int64_t /*timeout_ms*/) { if (assign.empty()) throw std::runtime_error("TcpEndpoint::async_read: empty assignment"); @@ -315,7 +303,6 @@ TcpEndpoint::async_read(const std::vector& assign, return; } - // Read raw response data (no header). asio::async_read(conn->socket, asio::buffer(reinterpret_cast(op->user_buffer), op->user_length), @@ -342,7 +329,7 @@ TcpEndpoint::async_read(const std::vector& assign, std::shared_ptr TcpEndpoint::async_write(const std::vector& assign, - int64_t timeout_ms, void*) { + int64_t /*timeout_ms*/) { if (assign.empty()) throw std::runtime_error("TcpEndpoint::async_write: empty assignment"); @@ -370,14 +357,11 @@ TcpEndpoint::async_write(const std::vector& assign, return std::make_shared(op); } - if (timeout_ms > 0) - set_sndtimeo(conn->socket.native_handle(), timeout_ms); - SessionHeader hdr{length, remote.addr + remote.offset + remote_off, OP_WRITE}; auto& pool = ctx_->conn_pool(); std::weak_ptr weak = weak_from_this(); - asio::post(ctx_->io_context(), [weak, conn, op, hdr, src, length, timeout_ms, &pool]() { + asio::post(ctx_->io_context(), [weak, conn, op, hdr, src, length, &pool]() { auto ep = weak.lock(); if (!ep) { op->completion_status.store(TCP_CLOSED, std::memory_order_release); @@ -393,9 +377,7 @@ TcpEndpoint::async_write(const std::vector& assign, asio::buffer(reinterpret_cast(src), length) }; asio::async_write(conn->socket, bufs, - [conn, op, timeout_ms, &pool](asio::error_code ec, size_t) { - if (timeout_ms > 0 && conn->socket.is_open()) - TcpEndpoint::set_sndtimeo(conn->socket.native_handle(), 0); + [conn, op, &pool](asio::error_code ec, size_t) { op->completion_status.store( ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); if (op->signal) op->signal->set_comm_done(0); @@ -417,7 +399,6 @@ void TcpEndpoint::shutdown() { acceptor_.close(); - // Force-complete all pending operations. { std::lock_guard lk(recv_mu_); for (auto& pr : pending_recvs_) { @@ -439,7 +420,6 @@ void TcpEndpoint::shutdown() { pending_reads_.clear(); } - // If self-contained, stop the private TcpContext. if (own_ctx_) own_ctx_->shutdown(); } diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.h b/dlslime/csrc/engine/tcp/tcp_endpoint.h index 344c4901..29996b58 100644 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.h +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.h @@ -9,7 +9,6 @@ #include #include #include -#include #include #include @@ -32,10 +31,10 @@ class TcpEndpoint : public std::enable_shared_from_this { public: static constexpr int64_t kDefaultTimeoutMs = 30000; - // Self-contained: creates its own TcpContext + io_thread. + // 【主构造】自包含 TcpContext (最常用) explicit TcpEndpoint(uint16_t port = 0); - // Shared context: multiple endpoints share one io_context thread. + // 【次构造】共享 TcpContext (多 endpoint 复用单 io_context 线程) TcpEndpoint(TcpContext& ctx, uint16_t port = 0); ~TcpEndpoint(); @@ -55,27 +54,26 @@ class TcpEndpoint : public std::enable_shared_from_this { const json& mr_info); json mr_info() const; - // ── Async I/O (all return Future; I/O on io_context thread) ── + // ── Async I/O (all return Future immediately; I/O runs on io_context thread) ── + // Bilateral send. timeout_ms controls socket write timeout (SO_SNDTIMEO). std::shared_ptr async_send( const chunk_tuple_t& chunk, - int64_t timeout_ms = kDefaultTimeoutMs, - void* stream = nullptr); + int64_t timeout_ms = kDefaultTimeoutMs); + // Bilateral recv. Timeout via future.wait_for(). std::shared_ptr async_recv( - const chunk_tuple_t& chunk, - int64_t timeout_ms = kDefaultTimeoutMs, - void* stream = nullptr); + const chunk_tuple_t& chunk); + // Unilateral read: request remote to send data from registered buffer. std::shared_ptr async_read( const std::vector& assign, - int64_t timeout_ms = kDefaultTimeoutMs, - void* stream = nullptr); + int64_t timeout_ms = kDefaultTimeoutMs); + // Unilateral write: push data to remote registered buffer. std::shared_ptr async_write( const std::vector& assign, - int64_t timeout_ms = kDefaultTimeoutMs, - void* stream = nullptr); + int64_t timeout_ms = kDefaultTimeoutMs); // ── Accessors ─────────────────────────────────────── void setId(int64_t id) { id_.store(id, std::memory_order_relaxed); } @@ -83,16 +81,13 @@ class TcpEndpoint : public std::enable_shared_from_this { bool is_connected() const { return connected_.load(std::memory_order_acquire); } private: - // ── io_context management ─────────────────────────── void start_io(); void do_accept(); ServerSession::RecvMatcher make_recv_matcher(); - // ── helpers ───────────────────────────────────────── bool is_initiator(const std::string& peer_host, uint16_t peer_port) const; bool write_message(asio::ip::tcp::socket& sock, const SessionHeader& hdr, const void* payload); - static void set_sndtimeo(int fd, int64_t ms); // ── identity ──────────────────────────────────────── std::atomic id_{-1}; @@ -104,7 +99,7 @@ class TcpEndpoint : public std::enable_shared_from_this { // ── asio core ─────────────────────────────────────── TcpContext* ctx_{nullptr}; - std::unique_ptr own_ctx_; // if self-contained + std::unique_ptr own_ctx_; asio::ip::tcp::acceptor acceptor_; std::atomic running_{true}; @@ -119,7 +114,7 @@ class TcpEndpoint : public std::enable_shared_from_this { std::mutex recv_mu_; std::deque pending_recvs_; - // ── read matching (connections reserved for response) ── + // ── read matching ─────────────────────────────────── struct PendingRead { std::shared_ptr conn; std::shared_ptr op_state; diff --git a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py index 4d6f85b2..9a406326 100644 --- a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py +++ b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py @@ -130,7 +130,7 @@ def run_b(): def run_a(): ep_a.connect(ep_b.endpoint_info()) - fut = ep_a.async_recv((h_a, 0, 5), timeout_ms=300) + fut = ep_a.async_recv((h_a, 0, 5)) result = fut.wait_for(0.3) print(f" recv wait_for(0.3s): {result} (expected None)") assert result is None, f"Expected None (timeout), got {result}" diff --git a/dlslime/csrc/python/bind.cpp b/dlslime/csrc/python/bind.cpp index 0b1efc90..e865c0e9 100644 --- a/dlslime/csrc/python/bind.cpp +++ b/dlslime/csrc/python/bind.cpp @@ -612,32 +612,26 @@ PYBIND11_MODULE(_slime_c, m) &dlslime::tcp::TcpEndpoint::register_remote_memory_region, py::arg("name"), py::arg("mr_info"), py::call_guard()) .def("async_send", - py::overload_cast( + py::overload_cast( &dlslime::tcp::TcpEndpoint::async_send), py::arg("chunk"), py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, - py::arg("stream") = nullptr, py::call_guard()) .def("async_recv", - py::overload_cast( - &dlslime::tcp::TcpEndpoint::async_recv), + &dlslime::tcp::TcpEndpoint::async_recv, py::arg("chunk"), - py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, - py::arg("stream") = nullptr, py::call_guard()) .def("async_read", - py::overload_cast&, int64_t, void*>( + py::overload_cast&, int64_t>( &dlslime::tcp::TcpEndpoint::async_read), py::arg("assign"), py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, - py::arg("stream") = nullptr, py::call_guard()) .def("async_write", - py::overload_cast&, int64_t, void*>( + py::overload_cast&, int64_t>( &dlslime::tcp::TcpEndpoint::async_write), py::arg("assign"), py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, - py::arg("stream") = nullptr, py::call_guard()); #endif // BUILD_TCP From 3fea51d4be14e230cee8b3b28df1e41f0a0ed76b Mon Sep 17 00:00:00 2001 From: root Date: Sun, 17 May 2026 08:03:46 +0000 Subject: [PATCH 03/10] refine TcpEndpoint: ip constructor, remove offset, cleanup pool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - TcpEndpoint(ip, port): bind to specific NIC instead of hardcoded 0.0.0.0 - TcpEndpoint(TcpContext&): =delete until multi-endpoint semantics resolved - TcpMemoryPool: remove offset field — caller passes final address directly - TcpConnectionPool: call cleanupIdleConnections in getConnection hot path with lock parameter; remove dead connections during iteration; fix returnConnection value-type bug; kIdleTimeout 60→300s - connect: accept repeated calls (remove connected_ guard) - register_memory_region: simplified signature without offset - Python: constructor keyword args, offset removed from bindings Co-Authored-By: Claude Opus 4.7 --- dlslime/csrc/engine/tcp/plan.md | 16 ++--- .../csrc/engine/tcp/tcp_connection_pool.cpp | 69 +++++++++++++------ dlslime/csrc/engine/tcp/tcp_connection_pool.h | 2 +- dlslime/csrc/engine/tcp/tcp_endpoint.cpp | 45 +++++------- dlslime/csrc/engine/tcp/tcp_endpoint.h | 18 ++--- dlslime/csrc/engine/tcp/tcp_memory_pool.cpp | 9 ++- dlslime/csrc/engine/tcp/tcp_memory_pool.h | 6 +- dlslime/csrc/engine/tcp/test_tcp_endpoint.py | 38 +++++----- dlslime/csrc/python/bind.cpp | 11 +-- 9 files changed, 115 insertions(+), 99 deletions(-) mode change 100644 => 100755 dlslime/csrc/engine/tcp/tcp_connection_pool.cpp mode change 100644 => 100755 dlslime/csrc/engine/tcp/tcp_connection_pool.h mode change 100644 => 100755 dlslime/csrc/engine/tcp/tcp_endpoint.cpp mode change 100644 => 100755 dlslime/csrc/engine/tcp/tcp_endpoint.h mode change 100644 => 100755 dlslime/csrc/engine/tcp/tcp_memory_pool.cpp mode change 100644 => 100755 dlslime/csrc/engine/tcp/tcp_memory_pool.h mode change 100644 => 100755 dlslime/csrc/engine/tcp/test_tcp_endpoint.py diff --git a/dlslime/csrc/engine/tcp/plan.md b/dlslime/csrc/engine/tcp/plan.md index 720ed8e4..34f4f1a1 100755 --- a/dlslime/csrc/engine/tcp/plan.md +++ b/dlslime/csrc/engine/tcp/plan.md @@ -153,14 +153,12 @@ public: // ── 构造 ── - // 【主构造】每个 endpoint 内部自动创建 TcpContext, 调用者无需关心。 - // 这是最常用的场景: 一个 endpoint = 一个 peer 连接。 - explicit TcpEndpoint(uint16_t port = 0); + // 【主构造】ip 绑定网卡地址 (默认 0.0.0.0), port=0 随机端口 + explicit TcpEndpoint(const std::string& ip = "0.0.0.0", uint16_t port = 0); - // 【次构造】注入外部共享 TcpContext, 用于多 endpoint 复用单 io_context 线程 - // 的高级优化场景 (如 PeerAgent 连接 N 个 peer 时节省 N-1 个线程)。 - // 仅在明确需要跨 endpoint 共享资源时使用。 - TcpEndpoint(TcpContext& ctx, uint16_t port = 0); + // 【次构造】共享 TcpContext — 暂禁用 + // (涉及 context 所有权 / conn_pool 跨 endpoint 管理 / 析构顺序) + TcpEndpoint(TcpContext& ctx, uint16_t port = 0) = delete; // ── 连接 ── json endpoint_info() const; // {host, port, mr_info} @@ -349,7 +347,7 @@ ep.shutdown() returnConnection() │ ▼ - [IDLE] (in_use=false, 在 deque 中) ──► 60s 无使用 → cleanupIdleConnections() → 关闭 + [IDLE] (in_use=false, 在 deque 中) ──► 300s 无使用 → cleanupIdleConnections() → 关闭 │ │ getConnection() 命中 ▼ @@ -366,7 +364,7 @@ class TcpConnectionPool { // 归还连接到 IDLE 状态 (或关闭, 如果 socket 已断开) void returnConnection(std::shared_ptr conn); - // 淘汰超过 kIdleTimeout (60s) 的空闲连接 + // 淘汰超过 kIdleTimeout (300s) 的空闲连接 void cleanupIdleConnections(); // 关闭所有连接 (shutdown 时调用) diff --git a/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp b/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp old mode 100644 new mode 100755 index d2bd77af..fd206fd4 --- a/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp +++ b/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp @@ -42,15 +42,26 @@ TcpConnectionPool::getConnection(const std::string& host, uint16_t port) { auto conn = std::make_shared(std::move(sock), host, port); { std::lock_guard lk(mu_); + // Remove idle connection + cleanupIdleConnections(false); + auto& q = pool_[key]; - for (auto& c : q) { - if (!c->in_use && c->socket.is_open()) { - asio::error_code ign; - conn->socket.close(ign); - c->in_use = true; - c->last_used = std::chrono::steady_clock::now(); - return c; + for (auto q_i = q.begin(); q_i != q.end();) { + auto& c = *q_i; + if (!c->in_use) { + if (c->socket.is_open()) { + c->in_use = true; + c->last_used = std::chrono::steady_clock::now(); + asio::error_code ign; + conn->socket.close(ign); + return c; + } else { + // Remove dead connection + q_i = q.erase(q_i); + continue; + } } + q_i++; } q.push_back(conn); } @@ -60,25 +71,40 @@ TcpConnectionPool::getConnection(const std::string& host, uint16_t port) { void TcpConnectionPool::returnConnection( std::shared_ptr conn) { if (!conn) return; + ConnKey key{conn->host, conn->port}; + std::lock_guard lk(mu_); + auto it = pool_.find(key); + if (it != pool_.end()) { + auto& q = it->second; + for (auto qi = q.begin(); qi != q.end(); ++qi) + if (*qi == conn) { + if (conn->socket.is_open()) { + conn->in_use = false; + conn->last_used = std::chrono::steady_clock::now(); + } else { + q.erase(qi); + } + break; + } + if (q.empty()) pool_.erase(it); + return; + } + + // Connection not found in pool (temporary), close it. if (conn->socket.is_open()) { - conn->in_use = false; - conn->last_used = std::chrono::steady_clock::now(); - } else { - ConnKey key{conn->host, conn->port}; - auto it = pool_.find(key); - if (it != pool_.end()) { - auto& q = it->second; - for (auto qi = q.begin(); qi != q.end(); ++qi) - if (*qi == conn) { q.erase(qi); break; } - if (q.empty()) pool_.erase(it); - } + asio::error_code ec; + conn->socket.close(ec); + if (ec) + SLIME_LOG_WARN("TcpConnectionPool: close temp conn ", conn->host, + ":", conn->port, " failed: ", ec.message()); } + } -void TcpConnectionPool::cleanupIdleConnections() { +void TcpConnectionPool::cleanupIdleConnections(bool lock = true) { auto now = std::chrono::steady_clock::now(); - std::lock_guard lk(mu_); + if (use_lock) std::lock_guard lk(mu_); for (auto it = pool_.begin(); it != pool_.end(); ) { auto& q = it->second; while (!q.empty()) { @@ -102,7 +128,8 @@ void TcpConnectionPool::cleanupIdleConnections() { void TcpConnectionPool::clear() { std::lock_guard lk(mu_); for (auto& [_, q] : pool_) - for (auto& c : q) { asio::error_code ign; c->socket.close(ign); } + // force close + for (auto& c : q) { c->in_use = false; asio::error_code ign; c->socket.close(ign);} pool_.clear(); } diff --git a/dlslime/csrc/engine/tcp/tcp_connection_pool.h b/dlslime/csrc/engine/tcp/tcp_connection_pool.h old mode 100644 new mode 100755 index f06254a3..4175d795 --- a/dlslime/csrc/engine/tcp/tcp_connection_pool.h +++ b/dlslime/csrc/engine/tcp/tcp_connection_pool.h @@ -30,7 +30,7 @@ struct PooledConnection { // States: IDLE (in deque, in_use=false) / ACTIVE (checked out) / RESERVED class TcpConnectionPool { public: - static constexpr std::chrono::seconds kIdleTimeout{60}; + static constexpr std::chrono::seconds kIdleTimeout{300}; explicit TcpConnectionPool(asio::io_context& io_ctx) : io_ctx_(io_ctx) {} diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp old mode 100644 new mode 100755 index eda64004..fc7c2533 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -37,31 +37,24 @@ ServerSession::RecvMatcher TcpEndpoint::make_recv_matcher() { // ── Constructor ──────────────────────────────────────── -TcpEndpoint::TcpEndpoint(uint16_t port) +TcpEndpoint::TcpEndpoint(const std::string& ip, uint16_t port) : own_ctx_(std::make_unique()) , acceptor_(own_ctx_->io_context()) , local_pool_(std::make_shared()) - , remote_pool_(std::make_shared()) { + , remote_pool_(std::make_shared()) + , local_host_(ip) { ctx_ = own_ctx_.get(); local_port_ = port; start_io(); } -TcpEndpoint::TcpEndpoint(TcpContext& ctx, uint16_t port) - : acceptor_(ctx.io_context()) - , local_pool_(std::make_shared()) - , remote_pool_(std::make_shared()) { - ctx_ = &ctx; - local_port_ = port; - start_io(); -} - TcpEndpoint::~TcpEndpoint() { shutdown(); } void TcpEndpoint::start_io() { - auto ep = tcp::endpoint(tcp::v4(), local_port_); + auto addr = asio::ip::make_address(local_host_); + auto ep = tcp::endpoint(addr, local_port_); acceptor_.open(ep.protocol()); acceptor_.set_option(tcp::acceptor::reuse_address(true)); acceptor_.bind(ep); @@ -115,14 +108,12 @@ bool TcpEndpoint::is_initiator(const std::string& peer_host, return local_port_ > peer_port; } -void TcpEndpoint::connect(const json& remote_info) { - if (connected_.load(std::memory_order_acquire)) return; - - peer_host_ = remote_info.value("host", ""); - peer_port_ = static_cast(remote_info.value("port", 0)); +void TcpEndpoint::connect(const json& remote_endpoint_info) { + peer_host_ = remote_endpoint_info.value("host", ""); + peer_port_ = static_cast(remote_endpoint_info.value("port", 0)); - if (remote_info.contains("mr_info")) { - for (const auto& [name, info] : remote_info["mr_info"].items()) + if (remote_endpoint_info.contains("mr_info")) { + for (const auto& [name, info] : remote_endpoint_info["mr_info"].items()) remote_pool_->register_remote_memory_region(info, name); } @@ -137,9 +128,9 @@ void TcpEndpoint::connect(const json& remote_info) { // ── memory registration ───────────────────────────────── int32_t TcpEndpoint::register_memory_region(const std::string& name, - uintptr_t ptr, uintptr_t offset, + uintptr_t ptr, size_t length) { - return local_pool_->register_memory_region(ptr, offset, length, name); + return local_pool_->register_memory_region(ptr, length, name); } int32_t TcpEndpoint::register_remote_memory_region(const std::string& name, @@ -171,7 +162,7 @@ TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t timeout_ms) { if (mr.length == 0) throw std::runtime_error("TcpEndpoint::async_send: invalid local MR"); - uintptr_t src = mr.addr + mr.offset + std::get<1>(chunk); + uintptr_t src = mr.addr + std::get<1>(chunk); size_t len = std::get<2>(chunk); auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); @@ -225,7 +216,7 @@ TcpEndpoint::async_recv(const chunk_tuple_t& chunk) { auto op = TcpOpState::create(); op->signal->reset_all(); - op->user_buffer = mr.addr + mr.offset + std::get<1>(chunk); + op->user_buffer = mr.addr + std::get<1>(chunk); op->user_length = std::get<2>(chunk); { @@ -258,7 +249,7 @@ TcpEndpoint::async_read(const std::vector& assign, auto op = TcpOpState::create(); op->signal->reset_all(); - op->user_buffer = local.addr + local.offset + local_off; + op->user_buffer = local.addr + local_off; op->user_length = length; auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); @@ -274,7 +265,7 @@ TcpEndpoint::async_read(const std::vector& assign, pending_reads_[req_id] = {conn, op}; } - SessionHeader hdr{length, remote.addr + remote.offset + remote_off, OP_READ}; + SessionHeader hdr{length, remote.addr + remote_off, OP_READ}; auto& pool = ctx_->conn_pool(); std::weak_ptr weak = weak_from_this(); @@ -345,7 +336,7 @@ TcpEndpoint::async_write(const std::vector& assign, if (local.length == 0 || remote.length == 0) throw std::runtime_error("TcpEndpoint::async_write: invalid MR handle"); - uintptr_t src = local.addr + local.offset + local_off; + uintptr_t src = local.addr + local_off; auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); auto op = TcpOpState::create(); @@ -357,7 +348,7 @@ TcpEndpoint::async_write(const std::vector& assign, return std::make_shared(op); } - SessionHeader hdr{length, remote.addr + remote.offset + remote_off, OP_WRITE}; + SessionHeader hdr{length, remote.addr + remote_off, OP_WRITE}; auto& pool = ctx_->conn_pool(); std::weak_ptr weak = weak_from_this(); diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.h b/dlslime/csrc/engine/tcp/tcp_endpoint.h old mode 100644 new mode 100755 index 29996b58..806e6f6c --- a/dlslime/csrc/engine/tcp/tcp_endpoint.h +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.h @@ -31,11 +31,12 @@ class TcpEndpoint : public std::enable_shared_from_this { public: static constexpr int64_t kDefaultTimeoutMs = 30000; - // 【主构造】自包含 TcpContext (最常用) - explicit TcpEndpoint(uint16_t port = 0); + // ip: 绑定网卡地址 (默认 0.0.0.0). port: 0 = 随机端口. + explicit TcpEndpoint(const std::string& ip = "0.0.0.0", uint16_t port = 0); - // 【次构造】共享 TcpContext (多 endpoint 复用单 io_context 线程) - TcpEndpoint(TcpContext& ctx, uint16_t port = 0); + // 共享 TcpContext — 暂禁用, 多 endpoint 复用单 io_context 时再完善 + // (涉及 context 所有权 / conn_pool 管理 / 析构顺序) + TcpEndpoint(TcpContext& ctx, uint16_t port = 0) = delete; ~TcpEndpoint(); @@ -44,12 +45,12 @@ class TcpEndpoint : public std::enable_shared_from_this { // ── Connection ────────────────────────────────────── json endpoint_info() const; - void connect(const json& remote_info); + void connect(const json& remote_endpoint_info); void shutdown(); // ── Memory ────────────────────────────────────────── int32_t register_memory_region(const std::string& name, - uintptr_t ptr, uintptr_t offset, size_t length); + uintptr_t ptr, size_t length); int32_t register_remote_memory_region(const std::string& name, const json& mr_info); json mr_info() const; @@ -98,8 +99,9 @@ class TcpEndpoint : public std::enable_shared_from_this { std::atomic connected_{false}; // ── asio core ─────────────────────────────────────── - TcpContext* ctx_{nullptr}; - std::unique_ptr own_ctx_; + // ctx_ 始终指向 own_ctx_ (次构造禁用后不再有外部注入路径) + TcpContext* ctx_{nullptr}; + std::unique_ptr own_ctx_; asio::ip::tcp::acceptor acceptor_; std::atomic running_{true}; diff --git a/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp b/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp old mode 100644 new mode 100755 index 1b540775..9f4d51ae --- a/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp +++ b/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp @@ -6,7 +6,7 @@ namespace tcp { // ── local MR ──────────────────────────────────────────── int32_t TcpMemoryPool::register_memory_region( - uintptr_t addr, uint64_t offset, size_t length, + uintptr_t addr, size_t length, std::optional name) { auto pit = ptr_to_handle_.find(addr); @@ -21,10 +21,11 @@ int32_t TcpMemoryPool::register_memory_region( } int32_t h = static_cast(handle_to_mr_.size()); - handle_to_mr_.push_back({addr, offset, length}); + handle_to_mr_.push_back({addr, length}); handle_to_name_.push_back(name.value_or("")); ptr_to_handle_[addr] = h; - if (name.has_value()) name_to_handle_[*name] = h; + if (name.has_value()) + name_to_handle_[*name] = h; return h; } @@ -53,7 +54,6 @@ int32_t TcpMemoryPool::register_remote_memory_region( int32_t h = it->second; auto& rm = remote_handle_to_mr_[h]; rm.addr = mr_info.value("addr", 0UL); - rm.offset = mr_info.value("offset", 0UL); rm.length = mr_info.value("length", 0UL); return h; } @@ -62,7 +62,6 @@ int32_t TcpMemoryPool::register_remote_memory_region( int32_t h = static_cast(remote_handle_to_mr_.size()); remote_handle_to_mr_.push_back({ mr_info.value("addr", 0UL), - mr_info.value("offset", 0UL), mr_info.value("length", 0UL) }); remote_handle_to_name_.push_back(mr_name); diff --git a/dlslime/csrc/engine/tcp/tcp_memory_pool.h b/dlslime/csrc/engine/tcp/tcp_memory_pool.h old mode 100644 new mode 100755 index 249f30cb..a1b686ef --- a/dlslime/csrc/engine/tcp/tcp_memory_pool.h +++ b/dlslime/csrc/engine/tcp/tcp_memory_pool.h @@ -15,12 +15,10 @@ using json = nlohmann::json; struct TcpMr { uintptr_t addr{0}; - uint64_t offset{0}; size_t length{0}; json json_info(const std::string& name) const { - return {{"name", name}, {"addr", addr}, - {"offset", offset}, {"length", length}}; + return {{"name", name}, {"addr", addr}, {"length", length}}; } }; @@ -29,7 +27,7 @@ class TcpMemoryPool { public: TcpMemoryPool() = default; - int32_t register_memory_region(uintptr_t addr, uint64_t offset, + int32_t register_memory_region(uintptr_t addr, size_t length, std::optional name = std::nullopt); int32_t unregister_memory_region(int32_t handle); diff --git a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py old mode 100644 new mode 100755 index 9a406326..0510d89b --- a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py +++ b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py @@ -27,10 +27,10 @@ def test_async_send_recv(): buf_a = ctypes.create_string_buffer(4096) buf_b = ctypes.create_string_buffer(4096) - ep_a = TcpEndpoint(10001) - ep_b = TcpEndpoint(10002) - h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 0, 4096) - h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 4096) + ep_a = TcpEndpoint(port=10001) + ep_b = TcpEndpoint(port=10002) + h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 4096) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 4096) info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() @@ -71,11 +71,11 @@ def test_async_write_read(): buf_b = ctypes.create_string_buffer(4096) addr_a = ctypes.addressof(buf_a) - ep_a = TcpEndpoint(0) - ep_b = TcpEndpoint(0) + ep_a = TcpEndpoint(port=0) + ep_b = TcpEndpoint(port=0) - h_a = ep_a.register_memory_region("a", addr_a, 0, 4096) - h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 4096) + h_a = ep_a.register_memory_region("a", addr_a, 4096) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 4096) info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() @@ -119,9 +119,9 @@ def test_recv_timeout(): buf_a = ctypes.create_string_buffer(64) - ep_a = TcpEndpoint(10003) - h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 0, 64) - ep_b = TcpEndpoint(10004) + ep_a = TcpEndpoint(port=10003) + h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 64) + ep_b = TcpEndpoint(port=10004) def run_b(): ep_b.connect(ep_a.endpoint_info()) @@ -147,10 +147,10 @@ def test_send_timeout_ms(): buf_a = ctypes.create_string_buffer(256) buf_b = ctypes.create_string_buffer(256) - ep_a = TcpEndpoint(10005) - ep_b = TcpEndpoint(10006) - h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 0, 256) - h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 256) + ep_a = TcpEndpoint(port=10005) + ep_b = TcpEndpoint(port=10006) + h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 256) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 256) def run_b(): ep_b.connect(ep_a.endpoint_info()) @@ -177,10 +177,10 @@ def test_default_timeout(): buf_a = ctypes.create_string_buffer(128) buf_b = ctypes.create_string_buffer(128) - ep_a = TcpEndpoint(10007) - ep_b = TcpEndpoint(10008) - h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 0, 128) - h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 128) + ep_a = TcpEndpoint(port=10007) + ep_b = TcpEndpoint(port=10008) + h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 128) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 128) def run_b(): ep_b.connect(ep_a.endpoint_info()) diff --git a/dlslime/csrc/python/bind.cpp b/dlslime/csrc/python/bind.cpp index e865c0e9..bde909d7 100644 --- a/dlslime/csrc/python/bind.cpp +++ b/dlslime/csrc/python/bind.cpp @@ -575,13 +575,13 @@ PYBIND11_MODULE(_slime_c, m) m, "TcpMemoryPool") .def(py::init<>()) .def("register_memory_region", - [](dlslime::tcp::TcpMemoryPool& self, uintptr_t addr, uint64_t offset, + [](dlslime::tcp::TcpMemoryPool& self, uintptr_t addr, size_t length, py::object name_obj) { std::optional name; if (!name_obj.is_none()) name = name_obj.cast(); - return self.register_memory_region(addr, offset, length, name); + return self.register_memory_region(addr, length, name); }, - py::arg("addr"), py::arg("offset"), py::arg("length"), + py::arg("addr"), py::arg("length"), py::arg("name") = py::none()) .def("register_remote_memory_region", [](dlslime::tcp::TcpMemoryPool& self, const json& mr_info, @@ -597,7 +597,8 @@ PYBIND11_MODULE(_slime_c, m) py::class_>( m, "TcpEndpoint") - .def(py::init(), py::arg("port") = 0) + .def(py::init(), + py::arg("ip") = "0.0.0.0", py::arg("port") = 0) .def("connect", &dlslime::tcp::TcpEndpoint::connect, py::arg("remote_info"), py::call_guard()) .def("endpoint_info", &dlslime::tcp::TcpEndpoint::endpoint_info) @@ -606,7 +607,7 @@ PYBIND11_MODULE(_slime_c, m) py::call_guard()) .def("register_memory_region", &dlslime::tcp::TcpEndpoint::register_memory_region, - py::arg("name"), py::arg("data_ptr"), py::arg("offset"), py::arg("length"), + py::arg("name"), py::arg("data_ptr"), py::arg("length"), py::call_guard()) .def("register_remote_memory_region", &dlslime::tcp::TcpEndpoint::register_remote_memory_region, From 2f3e1f012d83a7ec0e1a2d594d9fd18c074e519c Mon Sep 17 00:00:00 2001 From: root Date: Sun, 17 May 2026 11:00:28 +0000 Subject: [PATCH 04/10] simplify ServerSession and enforce MR name constraint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ServerSession: - remove readBody chunking (kDefaultChunkSize / transferred_ / chunk_buf_) — asio::async_read already loops internally, application-level chunking adds nothing but callback overhead - add writeBody(src, len) symmetric to readBody(dst, len) - readBody/writeBody share the same pattern: async I/O → readHeader on done - OP_SEND stays inline (needs signal between read and readHeader) TcpMemoryPool: - register_memory_region: name now mandatory (const std::string&, no optional) - reject empty name and duplicate name with SLIME_LOG_WARN + return -1 - remove handle_to_name_ vector (no longer needed) TcpConnectionPool: - cleanupIdleConnections(bool lock = true) — caller can skip internal lock - getConnection calls cleanupIdleConnections(false) on the hot path Co-Authored-By: Claude Opus 4.7 --- .../csrc/engine/tcp/tcp_connection_pool.cpp | 4 +- dlslime/csrc/engine/tcp/tcp_connection_pool.h | 2 +- dlslime/csrc/engine/tcp/tcp_endpoint.cpp | 2 +- dlslime/csrc/engine/tcp/tcp_memory_pool.cpp | 33 ++++-- dlslime/csrc/engine/tcp/tcp_memory_pool.h | 8 +- dlslime/csrc/engine/tcp/tcp_session.cpp | 101 ++++++++---------- dlslime/csrc/engine/tcp/tcp_session.h | 18 +--- dlslime/csrc/python/bind.cpp | 10 +- 8 files changed, 85 insertions(+), 93 deletions(-) diff --git a/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp b/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp index fd206fd4..2bbaa5bb 100755 --- a/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp +++ b/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp @@ -102,9 +102,9 @@ void TcpConnectionPool::returnConnection( } -void TcpConnectionPool::cleanupIdleConnections(bool lock = true) { +void TcpConnectionPool::cleanupIdleConnections(bool lock) { auto now = std::chrono::steady_clock::now(); - if (use_lock) std::lock_guard lk(mu_); + if (lock) std::lock_guard lk(mu_); for (auto it = pool_.begin(); it != pool_.end(); ) { auto& q = it->second; while (!q.empty()) { diff --git a/dlslime/csrc/engine/tcp/tcp_connection_pool.h b/dlslime/csrc/engine/tcp/tcp_connection_pool.h index 4175d795..2565a72f 100755 --- a/dlslime/csrc/engine/tcp/tcp_connection_pool.h +++ b/dlslime/csrc/engine/tcp/tcp_connection_pool.h @@ -39,7 +39,7 @@ class TcpConnectionPool { void returnConnection(std::shared_ptr conn); - void cleanupIdleConnections(); + void cleanupIdleConnections(bool lock = true); void clear(); private: diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp index fc7c2533..9a148f9d 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -157,7 +157,7 @@ bool TcpEndpoint::write_message(tcp::socket& sock, // ── async_send ────────────────────────────────────────── std::shared_ptr -TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t timeout_ms) { +TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/) { auto mr = local_pool_->get_mr_fast(static_cast(std::get<0>(chunk))); if (mr.length == 0) throw std::runtime_error("TcpEndpoint::async_send: invalid local MR"); diff --git a/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp b/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp index 9f4d51ae..dc7d8ab3 100755 --- a/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp +++ b/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp @@ -1,13 +1,23 @@ #include "tcp_memory_pool.h" +#include "dlslime/csrc/logging.h" + namespace dlslime { namespace tcp { // ── local MR ──────────────────────────────────────────── int32_t TcpMemoryPool::register_memory_region( - uintptr_t addr, size_t length, - std::optional name) { + uintptr_t addr, size_t length, const std::string& name) { + + if (name.empty()) { + SLIME_LOG_WARN("TcpMemoryPool: empty name rejected"); + return -1; + } + if (name_to_handle_.find(name) != name_to_handle_.end()) { + SLIME_LOG_WARN("TcpMemoryPool: duplicate name '", name, "' rejected"); + return -1; + } auto pit = ptr_to_handle_.find(addr); if (pit != ptr_to_handle_.end()) { @@ -15,29 +25,34 @@ int32_t TcpMemoryPool::register_memory_region( if (h >= 0 && static_cast(h) < handle_to_mr_.size() && handle_to_mr_[h].addr == addr && handle_to_mr_[h].length >= length) { - if (name.has_value()) name_to_handle_[*name] = h; + name_to_handle_[name] = h; return h; } } int32_t h = static_cast(handle_to_mr_.size()); handle_to_mr_.push_back({addr, length}); - handle_to_name_.push_back(name.value_or("")); ptr_to_handle_[addr] = h; - if (name.has_value()) - name_to_handle_[*name] = h; + name_to_handle_[name] = h; return h; } int32_t TcpMemoryPool::unregister_memory_region(int32_t handle) { if (handle < 0 || static_cast(handle) >= handle_to_mr_.size()) return -1; + auto& mr = handle_to_mr_[handle]; - auto& s = handle_to_name_[handle]; ptr_to_handle_.erase(mr.addr); - if (!s.empty()) name_to_handle_.erase(s); + + // Remove all name→handle entries pointing to this handle. + for (auto it = name_to_handle_.begin(); it != name_to_handle_.end(); ) { + if (it->second == handle) + it = name_to_handle_.erase(it); + else + ++it; + } + mr = {}; - s.clear(); return 0; } diff --git a/dlslime/csrc/engine/tcp/tcp_memory_pool.h b/dlslime/csrc/engine/tcp/tcp_memory_pool.h index a1b686ef..c9061708 100755 --- a/dlslime/csrc/engine/tcp/tcp_memory_pool.h +++ b/dlslime/csrc/engine/tcp/tcp_memory_pool.h @@ -27,11 +27,12 @@ class TcpMemoryPool { public: TcpMemoryPool() = default; - int32_t register_memory_region(uintptr_t addr, - size_t length, - std::optional name = std::nullopt); + // name must be non-empty and unique; returns -1 on violation. + int32_t register_memory_region(uintptr_t addr, size_t length, + const std::string& name); int32_t unregister_memory_region(int32_t handle); + // remote MR — name is optional (may come from peer's mr_info). int32_t register_remote_memory_region(const json& mr_info, std::optional name = std::nullopt); int32_t unregister_remote_memory_region(int32_t handle); @@ -49,7 +50,6 @@ class TcpMemoryPool { std::unordered_map name_to_handle_; std::unordered_map ptr_to_handle_; std::vector handle_to_mr_; - std::vector handle_to_name_; // remote MRs std::unordered_map remote_name_to_handle_; diff --git a/dlslime/csrc/engine/tcp/tcp_session.cpp b/dlslime/csrc/engine/tcp/tcp_session.cpp index 57e2d13a..4a8949b4 100644 --- a/dlslime/csrc/engine/tcp/tcp_session.cpp +++ b/dlslime/csrc/engine/tcp/tcp_session.cpp @@ -49,41 +49,56 @@ void ServerSession::readHeader() { if (ec) { if (is_fatal(ec)) SLIME_LOG_WARN("ServerSession::readHeader ", ec.message()); - return; // connection closed, session ends + return; } hdr_to_host(header_); - transferred_ = 0; dispatch(); }); } void ServerSession::dispatch() { switch (header_.opcode) { - case OP_SEND: + + case OP_SEND: { if (header_.size == 0) { readHeader(); return; } - chunk_buf_.resize(header_.size); - readBody(header_.size); + auto slot = recv_matcher_(); + if (!slot.buffer || slot.length == 0) { + SLIME_LOG_WARN("ServerSession: OP_SEND with no pending recv"); + readHeader(); + return; + } + size_t n = std::min(static_cast(header_.size), slot.length); + auto self = shared_from_this(); + asio::async_read(socket_, + asio::buffer(reinterpret_cast(slot.buffer), n), + [this, self, slot, n](asio::error_code ec, size_t /*rn*/) { + if (ec) { + if (is_fatal(ec)) + SLIME_LOG_WARN("ServerSession SEND read: ", ec.message()); + return; + } + if (slot.op_state) { + slot.op_state->bytes_copied = n; + slot.op_state->completion_status.store( + TCP_SUCCESS, std::memory_order_release); + if (slot.op_state->signal) + slot.op_state->signal->set_comm_done(0); + } + readHeader(); + }); break; + } case OP_WRITE: if (header_.size == 0) { readHeader(); return; } - chunk_buf_.resize(header_.size); - readBody(header_.size); + readBody(reinterpret_cast(header_.addr), header_.size); break; case OP_READ: { uintptr_t addr = static_cast(header_.addr); size_t sz = static_cast(header_.size); if (sz == 0) { readHeader(); return; } - // Write back raw data — no header on the response. - auto self = shared_from_this(); - asio::async_write(socket_, - asio::buffer(reinterpret_cast(addr), sz), - [this, self](asio::error_code ec, size_t /*n*/) { - if (ec && is_fatal(ec)) - SLIME_LOG_WARN("ServerSession READ response ", ec.message()); - readHeader(); - }); + writeBody(reinterpret_cast(addr), sz); break; } @@ -95,46 +110,24 @@ void ServerSession::dispatch() { } } -void ServerSession::readBody(uint64_t remaining) { +void ServerSession::readBody(void* dst, size_t len) { auto self = shared_from_this(); - size_t chunk = std::min(static_cast(remaining), kDefaultChunkSize); - - if (chunk == 0) { - if (header_.opcode == OP_SEND) { - auto slot = recv_matcher_(); - if (slot.buffer && slot.length > 0) { - size_t n = std::min(static_cast(header_.size), - slot.length); - std::memcpy(reinterpret_cast(slot.buffer), - chunk_buf_.data(), n); - if (slot.op_state) { - slot.op_state->bytes_copied = n; - slot.op_state->completion_status.store( - TCP_SUCCESS, std::memory_order_release); - if (slot.op_state->signal) - slot.op_state->signal->set_comm_done(0); - } - } - } else if (header_.opcode == OP_WRITE) { - uintptr_t addr = static_cast(header_.addr); - std::memcpy(reinterpret_cast(addr), - chunk_buf_.data(), header_.size); - } - readHeader(); - return; - } + asio::async_read(socket_, asio::buffer(dst, len), + [this, self](asio::error_code ec, size_t /*n*/) { + if (ec && is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::readBody ", ec.message()); + readHeader(); + }); +} - size_t offset = transferred_; - asio::async_read(socket_, - asio::buffer(chunk_buf_.data() + offset, chunk), - [this, self, remaining](asio::error_code ec, size_t n) { - if (ec) { - if (is_fatal(ec)) - SLIME_LOG_WARN("ServerSession::readBody ", ec.message()); - return; - } - transferred_ += n; - readBody(remaining - n); +void ServerSession::writeBody(const void* src, size_t len) { + auto self = shared_from_this(); + asio::async_write(socket_, + asio::buffer(src, len), + [this, self](asio::error_code ec, size_t /*n*/) { + if (ec && is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::writeBody ", ec.message()); + readHeader(); }); } diff --git a/dlslime/csrc/engine/tcp/tcp_session.h b/dlslime/csrc/engine/tcp/tcp_session.h index 470cb186..6c14a841 100644 --- a/dlslime/csrc/engine/tcp/tcp_session.h +++ b/dlslime/csrc/engine/tcp/tcp_session.h @@ -4,11 +4,8 @@ #include #include -#include #include #include -#include -#include #include "tcp_header.h" #include "tcp_memory_pool.h" @@ -19,20 +16,14 @@ namespace tcp { class TcpConnectionPool; -constexpr size_t kDefaultChunkSize = 65536; // 64KB - -// ── RecvSlot: returned by RecvMatcher when a SEND matches a pending recv ── struct RecvSlot { uintptr_t buffer{0}; size_t length{0}; std::shared_ptr op_state; }; -// ── ServerSession: handles incoming requests on one connection ── -// -// Lifecycle: start() → readHeader → dispatch → readBody/writeBody ↻ -// Persistent — one session handles many transfers on the same connection. -// Referenced from Mooncake ServerSession. +// ServerSession: handles incoming requests on one persistent connection. +// Lifecycle: start() → readHeader → dispatch → readBody/writeBody → readHeader ↻ class ServerSession : public std::enable_shared_from_this { public: using RecvMatcher = std::function; @@ -46,14 +37,13 @@ class ServerSession : public std::enable_shared_from_this { private: void readHeader(); void dispatch(); - void readBody(uint64_t remaining); + void readBody(void* dst, size_t len); // read into caller's buffer + void writeBody(const void* src, size_t len); // write from caller's buffer asio::ip::tcp::socket socket_; TcpMemoryPool* local_pool_; RecvMatcher recv_matcher_; SessionHeader header_{}; - uint64_t transferred_{0}; - std::vector chunk_buf_; }; } // namespace tcp diff --git a/dlslime/csrc/python/bind.cpp b/dlslime/csrc/python/bind.cpp index bde909d7..0359583f 100644 --- a/dlslime/csrc/python/bind.cpp +++ b/dlslime/csrc/python/bind.cpp @@ -575,14 +575,8 @@ PYBIND11_MODULE(_slime_c, m) m, "TcpMemoryPool") .def(py::init<>()) .def("register_memory_region", - [](dlslime::tcp::TcpMemoryPool& self, uintptr_t addr, - size_t length, py::object name_obj) { - std::optional name; - if (!name_obj.is_none()) name = name_obj.cast(); - return self.register_memory_region(addr, length, name); - }, - py::arg("addr"), py::arg("length"), - py::arg("name") = py::none()) + &dlslime::tcp::TcpMemoryPool::register_memory_region, + py::arg("addr"), py::arg("length"), py::arg("name")) .def("register_remote_memory_region", [](dlslime::tcp::TcpMemoryPool& self, const json& mr_info, py::object name_obj) { From f4b384a53870b890ef6e6d1c7b9f0e33342c06ba Mon Sep 17 00:00:00 2001 From: root Date: Mon, 18 May 2026 02:18:17 +0000 Subject: [PATCH 05/10] add ClientSession and refactor endpoint I/O to session-driven model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce ClientSession as the outbound counterpart to ServerSession, driving one I/O operation per instance. Endpoint primitives now create ClientSession instead of ad-hoc lambdas. ClientSession: - start_write(hdr, payload): gather async_write header + body - start_read(hdr, dst): write OP_READ header → async_read response to dst - DoneCallback reports asio::error_code; Primitive signals OpState on done - Self-destructs via shared_ptr when async chain completes Endpoint cleanup: - async_send / async_write / async_read: ad-hoc asio::post + nested lambda replaced with ClientSession creation + start_xxx - Removed pending_reads_ / read_mu_ / next_req_id_ (no longer needed — ClientSession's start_read callback directly delivers the result) - Removed write_message helper (no longer used) - shutdown: removed pending_reads_ cleanup block Co-Authored-By: Claude Opus 4.7 --- dlslime/csrc/engine/tcp/plan_v4.md | 289 +++++++++++++++++++++++ dlslime/csrc/engine/tcp/tcp_endpoint.cpp | 159 +++---------- dlslime/csrc/engine/tcp/tcp_endpoint.h | 24 +- dlslime/csrc/engine/tcp/tcp_session.cpp | 35 +++ dlslime/csrc/engine/tcp/tcp_session.h | 29 ++- 5 files changed, 382 insertions(+), 154 deletions(-) create mode 100644 dlslime/csrc/engine/tcp/plan_v4.md diff --git a/dlslime/csrc/engine/tcp/plan_v4.md b/dlslime/csrc/engine/tcp/plan_v4.md new file mode 100644 index 00000000..4f160d9c --- /dev/null +++ b/dlslime/csrc/engine/tcp/plan_v4.md @@ -0,0 +1,289 @@ +# TcpEndpoint v4 — Future / OpState / Session / Primitive 关系重构 + +## 当前状态 + +四个 async 原语使用 ad-hoc lambda 模式,与 session 概念脱节: + +``` +async_send(chunk): + 取连接 → TcpOpState → asio::post(lambda) → return Future + lambda: async_write(header+payload) → signal op → return conn + +async_read(assign): + 取连接(RESERVE) → TcpOpState → asio::post(lambda) → return Future + lambda: async_write(header) → async_read(response) → signal op → return conn +``` + +问题: +1. I/O 生命周期散落在 lambda 捕获中,无显式状态机 +2. async_read 的 write_header → read_response 是两个回调嵌套 +3. ServerSession 有清晰的 `readHeader → dispatch → readBody/writeBody`, + 但客户端没有对应的 ClientSession +4. `assign_tuple_t` (local_mr, remote_mr, remote_off, local_off, length) 的解析 + 散落在 endpoint 方法中,与 I/O 执行耦合 + +## 接口对齐 + +与 RDMAEndpoint 保持一致(已去除 void* stream, writeWithImm/immRecv): + +```cpp +// TwoSide (对应 RDMA send/recv, 异步化) +std::shared_ptr async_send(const chunk_tuple_t& chunk, + int64_t timeout_ms = kDefaultTimeoutMs); +std::shared_ptr async_recv(const chunk_tuple_t& chunk); + +// OneSide (对应 RDMA read/write, 异步化) +std::shared_ptr async_read( + const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs); +std::shared_ptr async_write( + const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs); +``` + +- `async_` 前缀:TCP 全部为异步(I/O 在 io_context 线程),与 RDMA 的同步 Future 区分 +- `void* stream`:已删除(TCP 无 CUDA stream) +- `std::vector`:接口接受 vector,但 v4 不对多个 assign 做聚合。 + 每个原语调用 = 一个 ClientSession = 一个 Future。 + 多个 assign 的聚合留给上层(SlimeRPC)。 + +## 关键数据结构 + +### 两种 tuple,两种寻址模型 + +```cpp +// send/recv — 双边,只需要本地 buffer 信息 +using chunk_tuple_t = std::tuple; +// mr_handle offset length + +// read/write — 单边,指定本地+远端两个 buffer +using assign_tuple_t = std::tuple; +// local_mr remote_mr remote_off local_off length +``` + +`assign_tuple_t` 已经包含了完成一次单边操作所需的**所有**寻址信息: +- 远端地址 = remote_mr.addr + remote_off → `SessionHeader.addr` +- 本地地址 = local_mr.addr + local_off → 本地读写位置 +- 长度 = length → `SessionHeader.size` + +### 从 assign_tuple_t 到 SessionHeader 的映射(在 Primitive 中完成 MR 解析) + +```cpp +// async_write: assign_tuple_t → SessionHeader + local_src +const auto& a = assign[0]; +auto local = local_pool_->get_mr_fast(std::get<0>(a)); // local_mr handle +auto remote = remote_pool_->get_remote_mr_fast(std::get<1>(a)); // remote_mr handle + +uint64_t remote_addr = remote.addr + std::get<2>(a); // remote_off +uint64_t local_src = local.addr + std::get<3>(a); // local_off +size_t len = std::get<4>(a); // length + +SessionHeader hdr{len, remote_addr, OP_WRITE}; +// ClientSession 拿到的是解析后的 hdr + local_src,不接触 assign_tuple_t +``` + +## v4 目标:四者关系 + +``` +┌─────────────────────────────────────────────────────────┐ +│ Primitive (TcpEndpoint::async_xxx) │ +│ │ +│ 1. 解析 assign_tuple_t / chunk_tuple_t → MR 寻址 │ +│ 2. 构建 SessionHeader (wire format) │ +│ 3. 创建 OpState (completion signal) │ +│ 4. 获取连接 (from pool) │ +│ 5. 创建 ClientSession(sock, op, hdr, payload_src/dst) │ +│ 6. return Future(op) │ +└────────────┬────────────────────────────────────────────┘ + │ 创建 + ┌────────▼──────────┐ ┌──────────────────┐ + │ ClientSession │────────→│ TcpOpState │←──────┐ + │ (I/O 状态机) │ signal │ (完成信号) │ │ + │ shared_ptr 自管理 │ └────────┬─────────┘ │ + │ │ │ 被持有 │ + │ start_write() │ │ │ + │ start_read() │ ┌────────▼─────────┐ │ + │ on_done → 归还连接 │ │ TcpFuture │ │ + └────────────────────┘ │ (用户句柄) │───────┘ + │ wait()/wait_for()│ + └──────────────────┘ +``` + +### 关系矩阵 + +| 对象 | 生命周期 | 知道什么 | 不知道什么 | +|------|---------|---------|-----------| +| **Primitive** | 单次调用 | MR 寻址, hdr 构建, assign_tuple_t 解析 | 线协议细节, async I/O 回调链 | +| **OpState** | ≥ Future 生命周期 | completion_status, signal | I/O 如何完成, 谁在驱动 | +| **Future** | 调用者持有 | wait()/wait_for() | 线协议, socket, 连接池 | +| **ClientSession** | I/O 进行中 | hdr, socket, payload 指针 | MR handle, assign_tuple_t | +| **ServerSession** | 连接存续期间 | socket, recv_matcher | 连接池, Future, OpState | + +## ClientSession 设计 + +一个 ClientSession = 一次出站 I/O 操作的完整生命周期。 + +```cpp +class ClientSession : public std::enable_shared_from_this { +public: + using DoneCallback = std::function; + + ClientSession(asio::ip::tcp::socket sock, DoneCallback on_done); + + // write: header + payload (both async_write, gather) + void start_write(const SessionHeader& hdr, const void* payload); + + // read: write header → read response into dst + void start_read(const SessionHeader& hdr, void* dst); + +private: + asio::ip::tcp::socket socket_; + DoneCallback on_done_; + SessionHeader hdr_{}; + // chunk_buf_ 不需要 — write 直接用 payload 指针, read 直接用 dst 指针 +}; +``` + +关键设计决策: +- **ClientSession 不持有 OpState** — 它只报告 `ec`。由 Primitive 在 on_done 中 signal OpState +- **ClientSession 不持有 PooledConnection** — 它只持有 socket。由 Primitive 在 on_done 中归还连接 +- 这样 ClientSession 是纯粹的 I/O 状态机,不耦合 Future/OpState/Pool + +### 原语 → ClientSession 映射 + +``` +async_send(chunk): + ┌─ 解析 chunk_tuple_t → mr.addr + offset → src_ptr, length + ├─ hdr = {length, 0, OP_SEND} + ├─ op = TcpOpState::create() + ├─ conn = pool.getConnection() + ├─ session = make_shared(move(conn->socket), + │ [op, conn, &pool](ec) { + │ op->completion_status = ec ? FAILED : SUCCESS; + │ op->signal->set_comm_done(0); + │ pool.returnConnection(conn); + │ }); + ├─ session->start_write(hdr, src_ptr); + └─ return TcpSendFuture(op); + +async_write(assign): ← 同上, hdr.opcode = OP_WRITE, hdr.addr = remote_addr +async_read(assign): ← session->start_read(hdr, dst_ptr) + dst_ptr = local_mr.addr + local_off +async_recv(chunk): ← 无 ClientSession (注册到 pending_recvs_) +``` + +### std::vector 的多 assign 处理 + +RDMA 中多个 assign 可聚合为一个 WR chain(一次 `ibv_post_send`,一个 Future)。 +TCP 没有硬件聚合——每个 assign 对应一个独立的线消息(一个 header + payload)。 +但接口约定是一个 `std::vector` → 一个 Future。 + +处理方式:**迭代 vector,每个 assign 创建一个 ClientSession,共享一个 OpState**。 + +``` +async_write([assign_0, assign_1, assign_2]): + op = TcpOpState::create() + op->expected_mask = (1 << 3) - 1 // 3 个 assign, 等 3 个 session 完成 + + for i, a in enumerate(assign): + 解析 a → hdr + src_ptr + conn = pool.getConnection() // 复用同一连接 + session = ClientSession(sock, [op, conn, i, &pool](ec) { + if (!ec) op->signal->set_comm_done(i); // 设置第 i 位 + pool.returnConnection(conn); + }) + session->start_write(hdr, src_ptr) + + return TcpReadWriteFuture(op) // wait 等待 expected_mask 所有位就绪 +``` + +每个 assign → 一个 session → 一次 `async_write`(串行在线路上,同连接)。 +Future.wait() 自旋等待 `completion_mask` 达到 `expected_mask`。 + +**与单 assign 的统一**:单 assign 是 `expected_mask = 1` 的特例。 +ClientSession 不感知是单还是多——只负责一个 I/O 操作。 + +### 不再需要的 + +- `asio::post` — ClientSession 构造后直接在调用者线程调 start_xxx,asio async_write/async_read 已经在 io_context 上 +- `weak_ptr` — ClientSession 不持有 endpoint 引用 +- `pending_reads_` map — 不再需要按 request_id 匹配响应。async_read 创建的 ClientSession 在 start_read 的 on_done 中直接拿到结果 + +## 入站/出站对称 + +``` +ServerSession (入站, 持久) ClientSession (出站, 瞬态) +────────────────────────── ────────────────────────── +readHeader() ← socket start_write(hdr, payload) → socket +dispatch() start_read(hdr, dst) → socket + ├─ OP_SEND: async_read → signal write_header → callback + ├─ OP_WRITE: readBody → memcpy read_response → callback + └─ OP_READ: writeBody → done on_done → Primitive signal → 析构 +readHeader() ← 循环 +``` + +## 文件变更 + +| 文件 | 变更 | +|------|------| +| `tcp_session.h` | 新增 ClientSession 类 (约 35 行) | +| `tcp_session.cpp` | 新增 ClientSession 实现 (约 50 行): start_write, start_read | +| `tcp_endpoint.cpp` | async_send/write/read 从 ad-hoc lambda → ClientSession; 删除 pending_reads_ 相关逻辑; 删除 asio::post | +| `tcp_endpoint.h` | 删除 `pending_reads_`, `read_mu_`, `next_req_id_` (不再需要 request_id 匹配); 公开 API 不变 | + +## 不聚合的理由 + +`assign_tuple_t` 是一个单次 I/O 操作的完整描述——不是可拆分的子操作集合。 +每个 async_read/async_write 调用对应一个 ClientSession。 +多个 assign 的聚合留给上层(如 SlimeRPC channel 的多个 slot), +不在 TcpEndpoint 层处理。 + +## Timeout 设计 + +### 两层 timeout,不同归属 + +| 层 | 机制 | 归属 | 语义 | +|----|------|------|------| +| **Future 层** | `wait_for(ms)` 定时自旋轮询 signal | Future / 调用者 | "我等不了了,但操作还在后台跑" | +| **I/O 层** | `asio::steady_timer` + `socket.cancel()` | ClientSession | "真的取消这个 I/O" | + +### v4 实现 Future 层,v5 实现 I/O 层 + +**v4**: +- `timeout_ms` 参数保留在方法签名中,但仅作为 OpState 的提示值存储 +- 真正的超时由 `future.wait_for(seconds)` 控制——调用者决定等待多久 +- ClientSession 不感知 timeout——它总是跑完 I/O 链 + +```cpp +fut = ep.async_send((h, 0, 128), timeout_ms=5000); +// timeout_ms 存入 op_state, 但 async I/O 链不受影响 +status = fut.wait_for(3.0); // 调用者侧超时 — 3 秒后返回 None +// 3 秒后 ClientSession 可能还在写, 完成后仍会 signal op_state +// 只是没有人等这个 signal 了 +``` + +**v5**:加 `asio::steady_timer` 给 ClientSession +```cpp +void ClientSession::start_write(...) { + if (timeout_ms_ > 0) { + timer_.expires_after(ms(timeout_ms_)); + timer_.async_wait([this](ec) { if (!ec) socket_.cancel(); }); + } + asio::async_write(socket_, bufs, ...); +} +// timer 触发 → socket.cancel() → async_write 回调收到 operation_aborted +// → on_done(operation_aborted) → op->completion_status = TCP_TIMEOUT +``` + +### 为什么不把 timeout_ms 去掉 + +保留它的两个理由: +1. 接口与 RDMA 的 `send(chunk, stream)` 模式一致——都有一个"额外控制参数"的位置 +2. 它为 v5 的 timer 实现预留了参数位,届时只需改内部实现,不改变 API + +## 为什么不做 + +- **recv 无 ClientSession** — 无出站 I/O +- **不拆 WriteSession/ReadSession** — 差异小,合并为一个 ClientSession +- **不在 Future 中持有 Session** — Future 只 wait,通过 OpState 间接关联 +- **ClientSession 不持有 OpState** — 只报 ec,由 Primitive 的 on_done 统一 signal diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp index 9a148f9d..c591b4c6 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -128,8 +128,7 @@ void TcpEndpoint::connect(const json& remote_endpoint_info) { // ── memory registration ───────────────────────────────── int32_t TcpEndpoint::register_memory_region(const std::string& name, - uintptr_t ptr, - size_t length) { + uintptr_t ptr, size_t length) { return local_pool_->register_memory_region(ptr, length, name); } @@ -138,22 +137,6 @@ int32_t TcpEndpoint::register_remote_memory_region(const std::string& name, return remote_pool_->register_remote_memory_region(mr_info, name); } -// ── write_message ─────────────────────────────────────── - -bool TcpEndpoint::write_message(tcp::socket& sock, - const SessionHeader& hdr, - const void* payload) { - asio::error_code ec; - SessionHeader net = hdr; - hdr_hton(net); - std::array bufs = { - asio::buffer(&net, sizeof(net)), - asio::buffer(payload, hdr.size) - }; - asio::write(sock, bufs, ec); - return !ec; -} - // ── async_send ────────────────────────────────────────── std::shared_ptr @@ -178,30 +161,15 @@ TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/) { SessionHeader hdr{len, 0, OP_SEND}; auto& pool = ctx_->conn_pool(); - std::weak_ptr weak = weak_from_this(); - asio::post(ctx_->io_context(), [weak, conn, op, hdr, src, len, &pool]() { - auto ep = weak.lock(); - if (!ep) { - op->completion_status.store(TCP_CLOSED, std::memory_order_release); - if (op->signal) op->signal->force_complete(); - return; - } - - asio::error_code ec; - SessionHeader net = hdr; - hdr_hton(net); - std::array bufs = { - asio::buffer(&net, sizeof(net)), - asio::buffer(reinterpret_cast(src), len) - }; - asio::async_write(conn->socket, bufs, - [conn, op, &pool](asio::error_code ec, size_t) { - op->completion_status.store( - ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); - if (op->signal) op->signal->set_comm_done(0); - pool.returnConnection(conn); - }); - }); + auto session = std::make_shared( + std::move(conn->socket), + [op, conn, &pool](asio::error_code ec) { + op->completion_status.store( + ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); + if (op->signal) op->signal->set_comm_done(0); + pool.returnConnection(conn); + }); + session->start_write(hdr, reinterpret_cast(src)); return std::make_shared(op); } @@ -236,8 +204,8 @@ TcpEndpoint::async_read(const std::vector& assign, throw std::runtime_error("TcpEndpoint::async_read: empty assignment"); const auto& a = assign[0]; - int32_t local_h = static_cast(std::get<0>(a)); - int32_t remote_h = static_cast(std::get<1>(a)); + int32_t local_h = static_cast(std::get<0>(a)); + int32_t remote_h = static_cast(std::get<1>(a)); uint64_t remote_off = std::get<2>(a); uint64_t local_off = std::get<3>(a); size_t length = std::get<4>(a); @@ -259,59 +227,18 @@ TcpEndpoint::async_read(const std::vector& assign, return std::make_shared(op); } - uint64_t req_id = next_req_id_.fetch_add(1, std::memory_order_relaxed); - { - std::lock_guard lk(read_mu_); - pending_reads_[req_id] = {conn, op}; - } - SessionHeader hdr{length, remote.addr + remote_off, OP_READ}; auto& pool = ctx_->conn_pool(); - std::weak_ptr weak = weak_from_this(); - asio::post(ctx_->io_context(), [weak, conn, op, hdr, req_id, &pool]() { - auto ep = weak.lock(); - if (!ep) { - op->completion_status.store(TCP_CLOSED, std::memory_order_release); - if (op->signal) op->signal->force_complete(); - return; - } - - SessionHeader net = hdr; - hdr_hton(net); - asio::async_write(conn->socket, - asio::buffer(&net, sizeof(net)), - [weak, conn, op, req_id, &pool](asio::error_code ec, size_t) { - if (ec) { - op->completion_status.store(TCP_FAILED, std::memory_order_release); - if (op->signal) op->signal->set_comm_done(0); - pool.returnConnection(conn); - auto self = weak.lock(); - if (self) { - std::lock_guard lk(self->read_mu_); - self->pending_reads_.erase(req_id); - } - return; - } - - asio::async_read(conn->socket, - asio::buffer(reinterpret_cast(op->user_buffer), - op->user_length), - [weak, conn, op, req_id, &pool](asio::error_code ec, size_t n) { - op->bytes_copied = n; - op->completion_status.store( - ec ? TCP_FAILED : TCP_SUCCESS, - std::memory_order_release); - if (op->signal) op->signal->set_comm_done(0); - pool.returnConnection(conn); - auto self = weak.lock(); - if (self) { - std::lock_guard lk(self->read_mu_); - self->pending_reads_.erase(req_id); - } - }); - }); - }); + auto session = std::make_shared( + std::move(conn->socket), + [op, conn, &pool](asio::error_code ec) { + op->completion_status.store( + ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); + if (op->signal) op->signal->set_comm_done(0); + pool.returnConnection(conn); + }); + session->start_read(hdr, reinterpret_cast(op->user_buffer)); return std::make_shared(op); } @@ -351,30 +278,15 @@ TcpEndpoint::async_write(const std::vector& assign, SessionHeader hdr{length, remote.addr + remote_off, OP_WRITE}; auto& pool = ctx_->conn_pool(); - std::weak_ptr weak = weak_from_this(); - asio::post(ctx_->io_context(), [weak, conn, op, hdr, src, length, &pool]() { - auto ep = weak.lock(); - if (!ep) { - op->completion_status.store(TCP_CLOSED, std::memory_order_release); - if (op->signal) op->signal->force_complete(); - return; - } - - asio::error_code ec; - SessionHeader net = hdr; - hdr_hton(net); - std::array bufs = { - asio::buffer(&net, sizeof(net)), - asio::buffer(reinterpret_cast(src), length) - }; - asio::async_write(conn->socket, bufs, - [conn, op, &pool](asio::error_code ec, size_t) { - op->completion_status.store( - ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); - if (op->signal) op->signal->set_comm_done(0); - pool.returnConnection(conn); - }); - }); + auto session = std::make_shared( + std::move(conn->socket), + [op, conn, &pool](asio::error_code ec) { + op->completion_status.store( + ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); + if (op->signal) op->signal->set_comm_done(0); + pool.returnConnection(conn); + }); + session->start_write(hdr, reinterpret_cast(src)); return std::make_shared(op); } @@ -387,7 +299,6 @@ void TcpEndpoint::shutdown() { return; connected_.store(false, std::memory_order_release); - acceptor_.close(); { @@ -400,16 +311,6 @@ void TcpEndpoint::shutdown() { } pending_recvs_.clear(); } - { - std::lock_guard lk(read_mu_); - for (auto& [_, pending] : pending_reads_) { - if (pending.op_state && pending.op_state->signal) { - pending.op_state->completion_status.store(TCP_CLOSED, std::memory_order_release); - pending.op_state->signal->force_complete(); - } - } - pending_reads_.clear(); - } if (own_ctx_) own_ctx_->shutdown(); diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.h b/dlslime/csrc/engine/tcp/tcp_endpoint.h index 806e6f6c..05c13d2e 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.h +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.h @@ -9,7 +9,6 @@ #include #include #include -#include #include #include "dlslime/csrc/common/json.hpp" @@ -31,11 +30,8 @@ class TcpEndpoint : public std::enable_shared_from_this { public: static constexpr int64_t kDefaultTimeoutMs = 30000; - // ip: 绑定网卡地址 (默认 0.0.0.0). port: 0 = 随机端口. explicit TcpEndpoint(const std::string& ip = "0.0.0.0", uint16_t port = 0); - // 共享 TcpContext — 暂禁用, 多 endpoint 复用单 io_context 时再完善 - // (涉及 context 所有权 / conn_pool 管理 / 析构顺序) TcpEndpoint(TcpContext& ctx, uint16_t port = 0) = delete; ~TcpEndpoint(); @@ -57,21 +53,17 @@ class TcpEndpoint : public std::enable_shared_from_this { // ── Async I/O (all return Future immediately; I/O runs on io_context thread) ── - // Bilateral send. timeout_ms controls socket write timeout (SO_SNDTIMEO). std::shared_ptr async_send( const chunk_tuple_t& chunk, int64_t timeout_ms = kDefaultTimeoutMs); - // Bilateral recv. Timeout via future.wait_for(). std::shared_ptr async_recv( const chunk_tuple_t& chunk); - // Unilateral read: request remote to send data from registered buffer. std::shared_ptr async_read( const std::vector& assign, int64_t timeout_ms = kDefaultTimeoutMs); - // Unilateral write: push data to remote registered buffer. std::shared_ptr async_write( const std::vector& assign, int64_t timeout_ms = kDefaultTimeoutMs); @@ -87,8 +79,6 @@ class TcpEndpoint : public std::enable_shared_from_this { ServerSession::RecvMatcher make_recv_matcher(); bool is_initiator(const std::string& peer_host, uint16_t peer_port) const; - bool write_message(asio::ip::tcp::socket& sock, - const SessionHeader& hdr, const void* payload); // ── identity ──────────────────────────────────────── std::atomic id_{-1}; @@ -99,11 +89,10 @@ class TcpEndpoint : public std::enable_shared_from_this { std::atomic connected_{false}; // ── asio core ─────────────────────────────────────── - // ctx_ 始终指向 own_ctx_ (次构造禁用后不再有外部注入路径) TcpContext* ctx_{nullptr}; std::unique_ptr own_ctx_; - asio::ip::tcp::acceptor acceptor_; - std::atomic running_{true}; + asio::ip::tcp::acceptor acceptor_; + std::atomic running_{true}; // ── memory ────────────────────────────────────────── std::shared_ptr local_pool_; @@ -115,15 +104,6 @@ class TcpEndpoint : public std::enable_shared_from_this { }; std::mutex recv_mu_; std::deque pending_recvs_; - - // ── read matching ─────────────────────────────────── - struct PendingRead { - std::shared_ptr conn; - std::shared_ptr op_state; - }; - std::mutex read_mu_; - std::unordered_map pending_reads_; - std::atomic next_req_id_{1}; }; } // namespace tcp diff --git a/dlslime/csrc/engine/tcp/tcp_session.cpp b/dlslime/csrc/engine/tcp/tcp_session.cpp index 4a8949b4..40857aac 100644 --- a/dlslime/csrc/engine/tcp/tcp_session.cpp +++ b/dlslime/csrc/engine/tcp/tcp_session.cpp @@ -131,5 +131,40 @@ void ServerSession::writeBody(const void* src, size_t len) { }); } +// ── ClientSession ─────────────────────────────────────── + +ClientSession::ClientSession(asio::ip::tcp::socket sock, DoneCallback on_done) + : socket_(std::move(sock)), on_done_(std::move(on_done)) {} + +void ClientSession::start_write(const SessionHeader& hdr, const void* payload) { + auto self = shared_from_this(); + SessionHeader net = hdr; + hdr_to_net(net); + std::array bufs = { + asio::buffer(&net, sizeof(net)), + asio::buffer(payload, hdr.size) + }; + asio::async_write(socket_, bufs, + [this, self](asio::error_code ec, size_t) { + if (on_done_) on_done_(ec); + }); +} + +void ClientSession::start_read(const SessionHeader& hdr, void* dst) { + auto self = shared_from_this(); + hdr_ = hdr; + SessionHeader net = hdr; + hdr_to_net(net); + asio::async_write(socket_, asio::buffer(&net, sizeof(net)), + [this, self, dst](asio::error_code ec, size_t) { + if (ec) { if (on_done_) on_done_(ec); return; } + asio::async_read(socket_, + asio::buffer(dst, hdr_.size), + [this, self](asio::error_code ec, size_t) { + if (on_done_) on_done_(ec); + }); + }); +} + } // namespace tcp } // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_session.h b/dlslime/csrc/engine/tcp/tcp_session.h index 6c14a841..f4a3480d 100644 --- a/dlslime/csrc/engine/tcp/tcp_session.h +++ b/dlslime/csrc/engine/tcp/tcp_session.h @@ -22,7 +22,8 @@ struct RecvSlot { std::shared_ptr op_state; }; -// ServerSession: handles incoming requests on one persistent connection. +// ── ServerSession: handles incoming requests on one persistent connection ── +// // Lifecycle: start() → readHeader → dispatch → readBody/writeBody → readHeader ↻ class ServerSession : public std::enable_shared_from_this { public: @@ -37,8 +38,8 @@ class ServerSession : public std::enable_shared_from_this { private: void readHeader(); void dispatch(); - void readBody(void* dst, size_t len); // read into caller's buffer - void writeBody(const void* src, size_t len); // write from caller's buffer + void readBody(void* dst, size_t len); + void writeBody(const void* src, size_t len); asio::ip::tcp::socket socket_; TcpMemoryPool* local_pool_; @@ -46,5 +47,27 @@ class ServerSession : public std::enable_shared_from_this { SessionHeader header_{}; }; +// ── ClientSession: drives one outbound I/O operation ───── +// +// Lifecycle: construct → start_write/start_read → on_done → self-destruct +// Does NOT own OpState or PooledConnection — only drives the I/O and reports ec. +class ClientSession : public std::enable_shared_from_this { +public: + using DoneCallback = std::function; + + ClientSession(asio::ip::tcp::socket sock, DoneCallback on_done); + + // Write header + payload to socket (gather async_write). + void start_write(const SessionHeader& hdr, const void* payload); + + // Write OP_READ header → read raw response into dst. + void start_read(const SessionHeader& hdr, void* dst); + +private: + asio::ip::tcp::socket socket_; + DoneCallback on_done_; + SessionHeader hdr_{}; +}; + } // namespace tcp } // namespace dlslime From 5a68128bd94a2289b679765d53d0f5f0145033b6 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 18 May 2026 03:35:30 +0000 Subject: [PATCH 06/10] remove_mrpool_in_sendrecv_and_update_connext --- dlslime/csrc/engine/tcp/plan.md | 2 +- dlslime/csrc/engine/tcp/tcp_endpoint.cpp | 41 ++++++++------------ dlslime/csrc/engine/tcp/tcp_endpoint.h | 4 +- dlslime/csrc/engine/tcp/test_tcp_endpoint.py | 26 +++++-------- 4 files changed, 27 insertions(+), 46 deletions(-) diff --git a/dlslime/csrc/engine/tcp/plan.md b/dlslime/csrc/engine/tcp/plan.md index 34f4f1a1..915a8240 100755 --- a/dlslime/csrc/engine/tcp/plan.md +++ b/dlslime/csrc/engine/tcp/plan.md @@ -474,7 +474,7 @@ bool TcpFuture::wait_for(int64_t timeout_ms, int32_t* out) const { | 阶段 | 文件 | 说明 | |------|------|------| -| 1. 分支 | `git checkout -b tcp-v3 main` | 基于 main 创建新分支 | +| 1. 分支 | `git checkout -b tcp-v4` | 基于 v3 创建了新分支 | | 2. 头文件 | tcp_header.h, tcp_op_state.h | 17B header + 3 opcodes + op state | | 3. 内存池 | tcp_memory_pool.h/.cpp | 纯簿记, 无硬件注册 | | 4. Future | tcp_future.h | header-only, wait + wait_for | diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp index c591b4c6..2fea7b13 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -101,28 +101,25 @@ json TcpEndpoint::mr_info() const { return local_pool_->mr_info(); } -bool TcpEndpoint::is_initiator(const std::string& peer_host, - uint16_t peer_port) const { - int cmp = local_host_.compare(peer_host); - if (cmp != 0) return cmp > 0; - return local_port_ > peer_port; -} - void TcpEndpoint::connect(const json& remote_endpoint_info) { - peer_host_ = remote_endpoint_info.value("host", ""); - peer_port_ = static_cast(remote_endpoint_info.value("port", 0)); + auto host = remote_endpoint_info.value("host", ""); + auto port = static_cast(remote_endpoint_info.value("port", 0)); + + // Verify reachability before accepting the peer identity. + auto conn = ctx_->conn_pool().getConnection(host, port); + if (!conn) { + SLIME_LOG_WARN("TcpEndpoint::connect: cannot reach ", host, ":", port); + return; + } + peer_host_ = host; + peer_port_ = port; if (remote_endpoint_info.contains("mr_info")) { for (const auto& [name, info] : remote_endpoint_info["mr_info"].items()) remote_pool_->register_remote_memory_region(info, name); } - - if (is_initiator(peer_host_, peer_port_)) { - auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); - if (conn) ctx_->conn_pool().returnConnection(std::move(conn)); - } - connected_.store(true, std::memory_order_release); + ctx_->conn_pool().returnConnection(std::move(conn)); } // ── memory registration ───────────────────────────────── @@ -138,14 +135,11 @@ int32_t TcpEndpoint::register_remote_memory_region(const std::string& name, } // ── async_send ────────────────────────────────────────── +// chunk_tuple_t = (src_ptr, offset, length) — raw pointers, no MR lookup. std::shared_ptr TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/) { - auto mr = local_pool_->get_mr_fast(static_cast(std::get<0>(chunk))); - if (mr.length == 0) - throw std::runtime_error("TcpEndpoint::async_send: invalid local MR"); - - uintptr_t src = mr.addr + std::get<1>(chunk); + uintptr_t src = std::get<0>(chunk) + std::get<1>(chunk); size_t len = std::get<2>(chunk); auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); @@ -175,16 +169,13 @@ TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/) { } // ── async_recv ────────────────────────────────────────── +// chunk_tuple_t = (dst_ptr, offset, length) — raw pointers, no MR lookup. std::shared_ptr TcpEndpoint::async_recv(const chunk_tuple_t& chunk) { - auto mr = local_pool_->get_mr_fast(static_cast(std::get<0>(chunk))); - if (mr.length == 0) - throw std::runtime_error("TcpEndpoint::async_recv: invalid local MR"); - auto op = TcpOpState::create(); op->signal->reset_all(); - op->user_buffer = mr.addr + std::get<1>(chunk); + op->user_buffer = std::get<0>(chunk) + std::get<1>(chunk); op->user_length = std::get<2>(chunk); { diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.h b/dlslime/csrc/engine/tcp/tcp_endpoint.h index 05c13d2e..5e20887b 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.h +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.h @@ -32,7 +32,7 @@ class TcpEndpoint : public std::enable_shared_from_this { explicit TcpEndpoint(const std::string& ip = "0.0.0.0", uint16_t port = 0); - TcpEndpoint(TcpContext& ctx, uint16_t port = 0) = delete; + TcpEndpoint(TcpContext& ctx, const std::string& ip = "0.0.0.0", uint16_t port = 0) = delete; ~TcpEndpoint(); @@ -78,8 +78,6 @@ class TcpEndpoint : public std::enable_shared_from_this { void do_accept(); ServerSession::RecvMatcher make_recv_matcher(); - bool is_initiator(const std::string& peer_host, uint16_t peer_port) const; - // ── identity ──────────────────────────────────────── std::atomic id_{-1}; std::string local_host_{"0.0.0.0"}; diff --git a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py index 0510d89b..87075bb9 100755 --- a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py +++ b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py @@ -29,8 +29,6 @@ def test_async_send_recv(): ep_a = TcpEndpoint(port=10001) ep_b = TcpEndpoint(port=10002) - h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 4096) - h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 4096) info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() @@ -38,10 +36,10 @@ def run_a(): ep_a.connect(info_b) print(" A connected") ctypes.memmove(ctypes.addressof(buf_a), b"hello", 5) - st = ep_a.async_send((h_a, 0, 5)).wait() + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5)).wait() assert st == 0, f"send failed: {st}" print(" A sent 5 bytes") - st = ep_a.async_recv((h_a, 0, 5)).wait() + st = ep_a.async_recv((ctypes.addressof(buf_a), 0, 5)).wait() assert st == 0, f"recv failed: {st}" assert bytes(buf_a[:5]) == b"world" print(" A recv'd: world") @@ -50,11 +48,11 @@ def run_a(): def run_b(): ep_b.connect(info_a) print(" B connected") - st = ep_b.async_recv((h_b, 0, 5)).wait() + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait() assert st == 0 and bytes(buf_b[:5]) == b"hello" print(" B recv'd: hello") ctypes.memmove(ctypes.addressof(buf_b), b"world", 5) - st = ep_b.async_send((h_b, 0, 5)).wait() + st = ep_b.async_send((ctypes.addressof(buf_b), 0, 5)).wait() assert st == 0 print(" B sent 5 bytes") ep_b.shutdown() @@ -120,7 +118,6 @@ def test_recv_timeout(): buf_a = ctypes.create_string_buffer(64) ep_a = TcpEndpoint(port=10003) - h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 64) ep_b = TcpEndpoint(port=10004) def run_b(): @@ -130,7 +127,7 @@ def run_b(): def run_a(): ep_a.connect(ep_b.endpoint_info()) - fut = ep_a.async_recv((h_a, 0, 5)) + fut = ep_a.async_recv((ctypes.addressof(buf_a), 0, 5)) result = fut.wait_for(0.3) print(f" recv wait_for(0.3s): {result} (expected None)") assert result is None, f"Expected None (timeout), got {result}" @@ -149,19 +146,17 @@ def test_send_timeout_ms(): ep_a = TcpEndpoint(port=10005) ep_b = TcpEndpoint(port=10006) - h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 256) - h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 256) def run_b(): ep_b.connect(ep_a.endpoint_info()) - st = ep_b.async_recv((h_b, 0, 5)).wait() + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait() assert st == 0 ep_b.shutdown() def run_a(): ep_a.connect(ep_b.endpoint_info()) ctypes.memmove(ctypes.addressof(buf_a), b"world", 5) - st = ep_a.async_send((h_a, 0, 5), timeout_ms=10000).wait() + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5), timeout_ms=10000).wait() assert st == 0, f"send timeout_ms=10000 failed: {st}" print(f" async_send with timeout_ms=10000: status={st}") ep_a.shutdown() @@ -179,20 +174,17 @@ def test_default_timeout(): ep_a = TcpEndpoint(port=10007) ep_b = TcpEndpoint(port=10008) - h_a = ep_a.register_memory_region("a", ctypes.addressof(buf_a), 128) - h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 128) def run_b(): ep_b.connect(ep_a.endpoint_info()) - st = ep_b.async_recv((h_b, 0, 5)).wait() + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait() assert st == 0 ep_b.shutdown() def run_a(): ep_a.connect(ep_b.endpoint_info()) ctypes.memmove(ctypes.addressof(buf_a), b"test!", 5) - # No timeout_ms arg — uses default 30000ms - st = ep_a.async_send((h_a, 0, 5)).wait() + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5)).wait() assert st == 0, f"default timeout send failed: {st}" print(f" async_send with default timeout: status={st}") ep_a.shutdown() From 1dd6b351eb7ef225b321646769665a77a78fa630 Mon Sep 17 00:00:00 2001 From: SHshenhao Date: Mon, 18 May 2026 07:30:23 +0000 Subject: [PATCH 07/10] add CUDA staging, remove is_initiator, decouple send/recv from MR - CUDA: char* + new[]/delete[] staging pattern (Mooncake-aligned) async_send/async_write: D2H before ClientSession; async_recv/async_read: H2D via RecvSlot::post_read callback in ServerSession - Remove is_initiator: both sides use conn_pool on-demand - connect() verifies reachability via getConnection before setting peer state - send/recv: treat chunk_tuple_t as raw pointers (no MR lookup needed for bilateral ops); read/write continue using MemoryPool for remote address resolution - PendingRecv: add staging_buf (unique_ptr) and cuda_dst for CUDA - RecvSlot: add post_read callback for post-recv CUDA H2D before signal Co-Authored-By: Claude Opus 4.7 --- dlslime/csrc/engine/tcp/tcp_endpoint.cpp | 105 ++++++++++++++++++++--- dlslime/csrc/engine/tcp/tcp_endpoint.h | 2 + dlslime/csrc/engine/tcp/tcp_session.cpp | 1 + dlslime/csrc/engine/tcp/tcp_session.h | 1 + 4 files changed, 98 insertions(+), 11 deletions(-) diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp index 2fea7b13..547584a6 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -8,6 +8,10 @@ #include "dlslime/csrc/logging.h" +#ifdef USE_CUDA +#include +#endif + namespace dlslime { namespace tcp { @@ -20,6 +24,14 @@ static void hdr_hton(SessionHeader& h) { h.addr = htole64(h.addr); } +#ifdef USE_CUDA +static bool is_cuda_memory(const void* addr) { + cudaPointerAttributes attr; + auto st = cudaPointerGetAttributes(&attr, addr); + return (st == cudaSuccess && attr.type == cudaMemoryTypeDevice); +} +#endif + // ── RecvMatcher factory ──────────────────────────────── ServerSession::RecvMatcher TcpEndpoint::make_recv_matcher() { @@ -31,7 +43,19 @@ ServerSession::RecvMatcher TcpEndpoint::make_recv_matcher() { if (self->pending_recvs_.empty()) return {}; auto pr = std::move(self->pending_recvs_.front()); self->pending_recvs_.pop_front(); - return {pr.op_state->user_buffer, pr.op_state->user_length, pr.op_state}; + + RecvSlot slot{pr.op_state->user_buffer, pr.op_state->user_length, pr.op_state}; +#ifdef USE_CUDA + if (pr.cuda_dst) { + slot.buffer = reinterpret_cast(pr.staging_buf.get()); + slot.post_read = [buf = std::move(pr.staging_buf), + dst = pr.cuda_dst, len = pr.op_state->user_length]() { + cudaMemcpy(reinterpret_cast(dst), buf.get(), + len, cudaMemcpyHostToDevice); + }; + } +#endif + return slot; }; } @@ -155,15 +179,30 @@ TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/) { SessionHeader hdr{len, 0, OP_SEND}; auto& pool = ctx_->conn_pool(); + auto* send_ptr = reinterpret_cast(src); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(send_ptr)) { + // TODO: 使用锁页内存,以及考虑async和overlap + auto* buf = new char[len]; + cudaMemcpy(buf, send_ptr, len, cudaMemcpyDeviceToHost); + send_ptr = buf; + is_cuda = true; + } +#endif + auto session = std::make_shared( std::move(conn->socket), - [op, conn, &pool](asio::error_code ec) { + [op, conn, &pool, send_ptr, is_cuda](asio::error_code ec) { op->completion_status.store( ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); if (op->signal) op->signal->set_comm_done(0); pool.returnConnection(conn); +#ifdef USE_CUDA + if (is_cuda) delete[] send_ptr; +#endif }); - session->start_write(hdr, reinterpret_cast(src)); + session->start_write(hdr, send_ptr); return std::make_shared(op); } @@ -175,12 +214,24 @@ std::shared_ptr TcpEndpoint::async_recv(const chunk_tuple_t& chunk) { auto op = TcpOpState::create(); op->signal->reset_all(); - op->user_buffer = std::get<0>(chunk) + std::get<1>(chunk); - op->user_length = std::get<2>(chunk); + uintptr_t dst = std::get<0>(chunk) + std::get<1>(chunk); + size_t length = std::get<2>(chunk); + op->user_buffer = dst; + op->user_length = length; + + PendingRecv pr{op}; +#ifdef USE_CUDA + if (is_cuda_memory(reinterpret_cast(dst))) { + auto* buf = new char[length]; + pr.staging_buf.reset(buf); + pr.cuda_dst = dst; + op->user_buffer = reinterpret_cast(buf); + } +#endif { std::lock_guard lk(recv_mu_); - pending_recvs_.push_back({op}); + pending_recvs_.push_back(std::move(pr)); } return std::make_shared(op); @@ -208,7 +259,8 @@ TcpEndpoint::async_read(const std::vector& assign, auto op = TcpOpState::create(); op->signal->reset_all(); - op->user_buffer = local.addr + local_off; + uintptr_t local_dst = local.addr + local_off; + op->user_buffer = local_dst; op->user_length = length; auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); @@ -221,15 +273,32 @@ TcpEndpoint::async_read(const std::vector& assign, SessionHeader hdr{length, remote.addr + remote_off, OP_READ}; auto& pool = ctx_->conn_pool(); + auto* read_dst = reinterpret_cast(local_dst); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(read_dst)) { + read_dst = new char[length]; + is_cuda = true; + } +#endif + auto session = std::make_shared( std::move(conn->socket), - [op, conn, &pool](asio::error_code ec) { + [op, conn, &pool, read_dst, is_cuda, + real_dst = local_dst, len = length](asio::error_code ec) { +#ifdef USE_CUDA + if (!ec && is_cuda) { + cudaMemcpy(reinterpret_cast(real_dst), + read_dst, len, cudaMemcpyHostToDevice); + delete[] read_dst; + } +#endif op->completion_status.store( ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); if (op->signal) op->signal->set_comm_done(0); pool.returnConnection(conn); }); - session->start_read(hdr, reinterpret_cast(op->user_buffer)); + session->start_read(hdr, read_dst); return std::make_shared(op); } @@ -269,15 +338,29 @@ TcpEndpoint::async_write(const std::vector& assign, SessionHeader hdr{length, remote.addr + remote_off, OP_WRITE}; auto& pool = ctx_->conn_pool(); + auto* send_ptr = reinterpret_cast(src); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(send_ptr)) { + auto* buf = new char[length]; + cudaMemcpy(buf, send_ptr, length, cudaMemcpyDeviceToHost); + send_ptr = buf; + is_cuda = true; + } +#endif + auto session = std::make_shared( std::move(conn->socket), - [op, conn, &pool](asio::error_code ec) { + [op, conn, &pool, send_ptr, is_cuda](asio::error_code ec) { op->completion_status.store( ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); if (op->signal) op->signal->set_comm_done(0); pool.returnConnection(conn); +#ifdef USE_CUDA + if (is_cuda) delete[] send_ptr; +#endif }); - session->start_write(hdr, reinterpret_cast(src)); + session->start_write(hdr, send_ptr); return std::make_shared(op); } diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.h b/dlslime/csrc/engine/tcp/tcp_endpoint.h index 5e20887b..03c211fa 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.h +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.h @@ -99,6 +99,8 @@ class TcpEndpoint : public std::enable_shared_from_this { // ── recv matching ─────────────────────────────────── struct PendingRecv { std::shared_ptr op_state; + std::unique_ptr staging_buf; + uintptr_t cuda_dst{0}; }; std::mutex recv_mu_; std::deque pending_recvs_; diff --git a/dlslime/csrc/engine/tcp/tcp_session.cpp b/dlslime/csrc/engine/tcp/tcp_session.cpp index 40857aac..aa5dd596 100644 --- a/dlslime/csrc/engine/tcp/tcp_session.cpp +++ b/dlslime/csrc/engine/tcp/tcp_session.cpp @@ -77,6 +77,7 @@ void ServerSession::dispatch() { SLIME_LOG_WARN("ServerSession SEND read: ", ec.message()); return; } + if (slot.post_read) slot.post_read(); if (slot.op_state) { slot.op_state->bytes_copied = n; slot.op_state->completion_status.store( diff --git a/dlslime/csrc/engine/tcp/tcp_session.h b/dlslime/csrc/engine/tcp/tcp_session.h index f4a3480d..80f55bdf 100644 --- a/dlslime/csrc/engine/tcp/tcp_session.h +++ b/dlslime/csrc/engine/tcp/tcp_session.h @@ -20,6 +20,7 @@ struct RecvSlot { uintptr_t buffer{0}; size_t length{0}; std::shared_ptr op_state; + std::function post_read; // called after read, before signal }; // ── ServerSession: handles incoming requests on one persistent connection ── From 1dce4ab116c19c9dc7d1c64a838f933a6ffeac30 Mon Sep 17 00:00:00 2001 From: SHshenhao Date: Mon, 18 May 2026 13:11:06 +0000 Subject: [PATCH 08/10] update --- dlslime/csrc/device/cuda/cuda_signal.h | 2 +- dlslime/csrc/engine/tcp/build_and_test.sh | 22 +- dlslime/csrc/engine/tcp/tcp_endpoint.cpp | 7 +- dlslime/csrc/engine/tcp/tcp_endpoint.h | 4 +- dlslime/csrc/engine/tcp/tcp_session.cpp | 39 +- dlslime/csrc/engine/tcp/tcp_session.h | 3 +- dlslime/csrc/engine/tcp/test_tcp_endpoint.py | 592 ++++++++++++++++--- dlslime/csrc/python/bind.cpp | 3 +- 8 files changed, 574 insertions(+), 98 deletions(-) diff --git a/dlslime/csrc/device/cuda/cuda_signal.h b/dlslime/csrc/device/cuda/cuda_signal.h index 0c690ab9..3b66bde9 100755 --- a/dlslime/csrc/device/cuda/cuda_signal.h +++ b/dlslime/csrc/device/cuda/cuda_signal.h @@ -8,7 +8,7 @@ #include "dlslime/csrc/device/signal.h" #include "dlslime/csrc/engine/rdma/rdma_env.h" #include "dlslime/csrc/logging.h" -#include "dlslime/csrc/pause.h" +#include "dlslime/csrc/common/pause.h" #include "nvtx_helper.h" namespace dlslime { diff --git a/dlslime/csrc/engine/tcp/build_and_test.sh b/dlslime/csrc/engine/tcp/build_and_test.sh index 0283ab30..f50155dd 100755 --- a/dlslime/csrc/engine/tcp/build_and_test.sh +++ b/dlslime/csrc/engine/tcp/build_and_test.sh @@ -6,11 +6,17 @@ REPO_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" BUILD_DIR="$REPO_ROOT/build_tcp" MODE="${1:-all}" +# Optional: USE_CUDA=ON ./build_and_test.sh all +USE_CUDA="${USE_CUDA:-OFF}" + header() { echo; echo -e "\033[1;36m==>\033[m \033[1m$*\033[m"; } ok() { echo -e " \033[1;32mOK\033[m $*"; } do_build() { - header "Configuring (BUILD_TCP=ON, BUILD_RDMA=OFF)" + local cuda_label="" + [[ "$USE_CUDA" == "ON" ]] && cuda_label=" + USE_CUDA=ON" + + header "Configuring (BUILD_TCP=ON, BUILD_RDMA=OFF${cuda_label})" cmake -S "$REPO_ROOT" -B "$BUILD_DIR" -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DDLSLIME_INSTALL_PATH=dlslime \ @@ -19,6 +25,7 @@ do_build() { -DBUILD_TCP=ON \ -DBUILD_NVLINK=OFF \ -DBUILD_ASCEND_DIRECT=OFF \ + -DUSE_CUDA="$USE_CUDA" \ -DSKBUILD_PROJECT_NAME=dlslime 2>&1 | tail -3 ok "CMake configure" @@ -31,17 +38,18 @@ do_build() { } do_test() { - header "Running TcpEndpoint v3 tests" + header "Running TcpEndpoint tests" export DLSLIME_LOG_LEVEL=0 export LD_LIBRARY_PATH="$REPO_ROOT/dlslime" export PYTHONPATH="$REPO_ROOT" python3 "$SCRIPT_DIR/test_tcp_endpoint.py" 2>&1 | while IFS= read -r line; do - if [[ "$line" == *"PASSED"* ]]; then echo -e " \033[1;32m✓\033[m $line" - elif [[ "$line" == *"FAIL"* ]]; then echo -e " \033[1;91m✗\033[m $line" + if [[ "$line" == *"PASSED"* ]]; then echo -e " \033[1;32m✓\033[m $line" + elif [[ "$line" == *"SKIP"* ]]; then echo -e " \033[1;33m⊘\033[m $line" + elif [[ "$line" == *"FAIL"* ]]; then echo -e " \033[1;91m✗\033[m $line" else echo " $line" fi done - ok "All tests passed" + echo " tests done " } case "$MODE" in @@ -50,5 +58,7 @@ case "$MODE" in test) do_test ;; clean) rm -rf "$BUILD_DIR" "$REPO_ROOT/dlslime/_slime_c"*.so "$REPO_ROOT/dlslime/lib_slime_"*.so ok "Cleaned" ;; - *) echo "Usage: $0 {all|build|test|clean}" >&2; exit 1 ;; + *) echo "Usage: $0 {all|build|test|clean}" >&2 + echo " USE_CUDA=ON $0 all # build + test with CUDA" >&2 + exit 1 ;; esac diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp index 547584a6..8453ea9d 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -44,7 +44,8 @@ ServerSession::RecvMatcher TcpEndpoint::make_recv_matcher() { auto pr = std::move(self->pending_recvs_.front()); self->pending_recvs_.pop_front(); - RecvSlot slot{pr.op_state->user_buffer, pr.op_state->user_length, pr.op_state}; + RecvSlot slot{pr.op_state->user_buffer, pr.op_state->user_length, + pr.op_state, {}, pr.exact_size}; #ifdef USE_CUDA if (pr.cuda_dst) { slot.buffer = reinterpret_cast(pr.staging_buf.get()); @@ -211,7 +212,7 @@ TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/) { // chunk_tuple_t = (dst_ptr, offset, length) — raw pointers, no MR lookup. std::shared_ptr -TcpEndpoint::async_recv(const chunk_tuple_t& chunk) { +TcpEndpoint::async_recv(const chunk_tuple_t& chunk, bool exact_size) { auto op = TcpOpState::create(); op->signal->reset_all(); uintptr_t dst = std::get<0>(chunk) + std::get<1>(chunk); @@ -219,7 +220,7 @@ TcpEndpoint::async_recv(const chunk_tuple_t& chunk) { op->user_buffer = dst; op->user_length = length; - PendingRecv pr{op}; + PendingRecv pr{op, nullptr, 0, exact_size}; #ifdef USE_CUDA if (is_cuda_memory(reinterpret_cast(dst))) { auto* buf = new char[length]; diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.h b/dlslime/csrc/engine/tcp/tcp_endpoint.h index 03c211fa..bc6e97d1 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.h +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.h @@ -58,7 +58,8 @@ class TcpEndpoint : public std::enable_shared_from_this { int64_t timeout_ms = kDefaultTimeoutMs); std::shared_ptr async_recv( - const chunk_tuple_t& chunk); + const chunk_tuple_t& chunk, + bool exact_size = false); std::shared_ptr async_read( const std::vector& assign, @@ -101,6 +102,7 @@ class TcpEndpoint : public std::enable_shared_from_this { std::shared_ptr op_state; std::unique_ptr staging_buf; uintptr_t cuda_dst{0}; + bool exact_size{false}; }; std::mutex recv_mu_; std::deque pending_recvs_; diff --git a/dlslime/csrc/engine/tcp/tcp_session.cpp b/dlslime/csrc/engine/tcp/tcp_session.cpp index aa5dd596..c1454f73 100644 --- a/dlslime/csrc/engine/tcp/tcp_session.cpp +++ b/dlslime/csrc/engine/tcp/tcp_session.cpp @@ -67,19 +67,48 @@ void ServerSession::dispatch() { readHeader(); return; } - size_t n = std::min(static_cast(header_.size), slot.length); + if (slot.exact_size && header_.size != slot.length) { + SLIME_LOG_WARN("ServerSession: size mismatch, send ", header_.size, + " != recv ", slot.length); + if (slot.op_state) { + slot.op_state->completion_status.store( + TCP_FAILED, std::memory_order_release); + if (slot.op_state->signal) + slot.op_state->signal->set_comm_done(0); + } + readHeader(); + return; + } + + // Always drain the full send payload from the wire. If recv buffer + // is smaller, read into a temp buffer then copy what fits. + size_t n_read = static_cast(header_.size); + size_t n_copy = std::min(n_read, slot.length); + auto* dst = reinterpret_cast(slot.buffer); + bool overflow = false; + + if (header_.size > slot.length) { + dst = new char[n_read]; + overflow = true; + } + auto self = shared_from_this(); - asio::async_read(socket_, - asio::buffer(reinterpret_cast(slot.buffer), n), - [this, self, slot, n](asio::error_code ec, size_t /*rn*/) { + asio::async_read(socket_, asio::buffer(dst, n_read), + [this, self, slot, n_copy, dst, overflow]( + asio::error_code ec, size_t /*rn*/) { if (ec) { if (is_fatal(ec)) SLIME_LOG_WARN("ServerSession SEND read: ", ec.message()); + if (overflow) delete[] dst; return; } + if (overflow) { + std::memcpy(reinterpret_cast(slot.buffer), dst, n_copy); + delete[] dst; + } if (slot.post_read) slot.post_read(); if (slot.op_state) { - slot.op_state->bytes_copied = n; + slot.op_state->bytes_copied = n_copy; slot.op_state->completion_status.store( TCP_SUCCESS, std::memory_order_release); if (slot.op_state->signal) diff --git a/dlslime/csrc/engine/tcp/tcp_session.h b/dlslime/csrc/engine/tcp/tcp_session.h index 80f55bdf..2e70e7d8 100644 --- a/dlslime/csrc/engine/tcp/tcp_session.h +++ b/dlslime/csrc/engine/tcp/tcp_session.h @@ -20,7 +20,8 @@ struct RecvSlot { uintptr_t buffer{0}; size_t length{0}; std::shared_ptr op_state; - std::function post_read; // called after read, before signal + std::function post_read; + bool exact_size{false}; // reject send size != recv size }; // ── ServerSession: handles incoming requests on one persistent connection ── diff --git a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py index 87075bb9..962ae76a 100755 --- a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py +++ b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py @@ -6,26 +6,81 @@ """ import ctypes +import os import threading import time from dlslime import TcpEndpoint, TcpMemoryPool +# ── optional torch / CUDA support ──────────────────────── -def _sync_run(fn_a, fn_b): - b = threading.Barrier(2) - ta = threading.Thread(target=lambda: (b.wait(), fn_a()), daemon=True) - tb = threading.Thread(target=lambda: (b.wait(), fn_b()), daemon=True) - ta.start(); tb.start() - ta.join(); tb.join() +_HAS_TORCH = False +_HAS_CUDA = False +try: + import torch -def test_async_send_recv(): - """Two endpoints async_send/async_recv each other.""" - print("=== test_async_send_recv ===") + _HAS_TORCH = True + _HAS_CUDA = torch.cuda.is_available() +except Exception: + pass + +_CUDA_FORCE_OFF = os.environ.get("DLSLIME_TCP_TEST_CUDA", "") in ("0", "false", "no") + + +def _torch_skip(): + return not _HAS_TORCH + + +def _cuda_skip(): + if _CUDA_FORCE_OFF: + return True + return not _HAS_CUDA + + +# ── test harness ───────────────────────────────────────── + +def _sync_run(name, fn_a, fn_b): + err = [] + + def wrap(fn): + try: + b.wait() + fn() + except Exception as e: + err.append(e) + + b = threading.Barrier(2) + ta = threading.Thread(target=wrap, args=(fn_a,), daemon=True) + tb = threading.Thread(target=wrap, args=(fn_b,), daemon=True) + ta.start() + tb.start() + ta.join() + tb.join() + if len(err) > 0: + print(f"{name} FAIL {err}") + return False + else: + print(f"{name} SUCC ") + return True + + +def _run_test(fn): + print(f"=== {fn.__name__} ===") + try: + fn() + print(" PASSED\n") + return True + except Exception as e: + print(f" FAILED — {e}\n") + return False + + +# ── ctypes-based tests ─────────────────────────────────── - buf_a = ctypes.create_string_buffer(4096) - buf_b = ctypes.create_string_buffer(4096) +def test_async_send_recv(): + buf_a = ctypes.create_string_buffer(128) + buf_b = ctypes.create_string_buffer(128) ep_a = TcpEndpoint(port=10001) ep_b = TcpEndpoint(port=10002) @@ -34,115 +89,156 @@ def test_async_send_recv(): def run_a(): ep_a.connect(info_b) - print(" A connected") ctypes.memmove(ctypes.addressof(buf_a), b"hello", 5) st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5)).wait() - assert st == 0, f"send failed: {st}" - print(" A sent 5 bytes") - st = ep_a.async_recv((ctypes.addressof(buf_a), 0, 5)).wait() - assert st == 0, f"recv failed: {st}" - assert bytes(buf_a[:5]) == b"world" - print(" A recv'd: world") + if st != 0: + raise RuntimeError(f"send: {st}") + st = ep_a.async_recv((ctypes.addressof(buf_a), 5, 5)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + if bytes(buf_a[5:10]) != b"world": + raise RuntimeError(f"data: {bytes(buf_a[5:10])}") ep_a.shutdown() def run_b(): ep_b.connect(info_a) - print(" B connected") st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait() - assert st == 0 and bytes(buf_b[:5]) == b"hello" - print(" B recv'd: hello") + if st != 0: + raise RuntimeError(f"recv: {st}") + if bytes(buf_b[:5]) != b"hello": + raise RuntimeError(f"data: {bytes(buf_b[:5])}") ctypes.memmove(ctypes.addressof(buf_b), b"world", 5) st = ep_b.async_send((ctypes.addressof(buf_b), 0, 5)).wait() - assert st == 0 - print(" B sent 5 bytes") + if st != 0: + raise RuntimeError(f"send: {st}") ep_b.shutdown() - _sync_run(run_a, run_b) - print(" PASSED\n") + _sync_run("test_async_send_recv", run_a, run_b) + + +def test_async_send_recv_one(): + buf_a = ctypes.create_string_buffer(32) + buf_b = ctypes.create_string_buffer(32) + + ep_a = TcpEndpoint(port=10041) + ep_b = TcpEndpoint(port=10042) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + + def run_a(): + ep_a.connect(info_b) + ctypes.memmove(ctypes.addressof(buf_a), b"one", 3) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 3)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 3)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + if bytes(buf_b[:3]) != b"one": + raise RuntimeError(f"data: {bytes(buf_b[:3])}") + ep_b.shutdown() + _sync_run("test_async_send_recv_one", run_a, run_b) -def test_async_write_read(): - """A writes to B's buffer, then reads from B's buffer.""" - print("=== test_async_write_read ===") - buf_a = ctypes.create_string_buffer(4096) - buf_b = ctypes.create_string_buffer(4096) +def test_async_write(): + buf_a = ctypes.create_string_buffer(256) + buf_b = ctypes.create_string_buffer(256) addr_a = ctypes.addressof(buf_a) ep_a = TcpEndpoint(port=0) ep_b = TcpEndpoint(port=0) - - h_a = ep_a.register_memory_region("a", addr_a, 4096) - h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 4096) - + h_a = ep_a.register_memory_region("a", addr_a, 256) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 256) info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() - h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) test_data = b"hello_from_a" def run_a(): ep_a.connect(info_b) - print(" A connected") ctypes.memmove(addr_a, test_data, len(test_data)) st = ep_a.async_write([(h_a, h_br, 0, 0, len(test_data))]).wait() - assert st == 0, f"write failed: {st}" - print(f" A wrote {len(test_data)} bytes to B") - time.sleep(0.1) - st = ep_a.async_read([(h_a, h_br, 0, 0, len(test_data))]).wait() - assert st == 0 and bytes(buf_a[:len(test_data)]) == test_data - print(f" A read from B: {bytes(buf_a[:len(test_data)])}") + if st != 0: + raise RuntimeError(f"write: {st}") ep_a.shutdown() def run_b(): ep_b.connect(info_a) - print(" B connected") - time.sleep(0.2) for _ in range(50): if bytes(buf_b[:len(test_data)]) == test_data: break - time.sleep(0.01) - assert bytes(buf_b[:len(test_data)]) == test_data - print(f" B buffer verified") + time.sleep(0.5) + if bytes(buf_b[:len(test_data)]) != test_data: + raise RuntimeError("B write not received") ep_b.shutdown() - _sync_run(run_a, run_b) - print(" PASSED\n") + _sync_run("test_async_write", run_a, run_b) -def test_recv_timeout(): - """recv times out when peer never sends.""" - print("=== test_recv_timeout ===") +def test_async_read(): + buf_a = ctypes.create_string_buffer(256) + buf_b = ctypes.create_string_buffer(256) + addr_a = ctypes.addressof(buf_a) - buf_a = ctypes.create_string_buffer(64) + ep_a = TcpEndpoint(port=0) + ep_b = TcpEndpoint(port=0) + h_a = ep_a.register_memory_region("a", addr_a, 256) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 256) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + + test_data = b"hello_from_b" + ctypes.memmove(ctypes.addressof(buf_b), test_data, 12) + + def run_a(): + ep_a.connect(info_b) + st = ep_a.async_read([(h_a, h_br, 0, 0, len(test_data))]).wait() + if st != 0: + raise RuntimeError(f"read: {st}") + if bytes(buf_a[:len(test_data)]) != test_data: + raise RuntimeError("read data mismatch") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + time.sleep(30) + ep_b.shutdown() + + _sync_run("test_async_read", run_a, run_b) + + +def test_recv_timeout(): + buf_a = ctypes.create_string_buffer(32) ep_a = TcpEndpoint(port=10003) ep_b = TcpEndpoint(port=10004) def run_b(): ep_b.connect(ep_a.endpoint_info()) - time.sleep(1.5) + time.sleep(1.0) ep_b.shutdown() def run_a(): ep_a.connect(ep_b.endpoint_info()) fut = ep_a.async_recv((ctypes.addressof(buf_a), 0, 5)) result = fut.wait_for(0.3) - print(f" recv wait_for(0.3s): {result} (expected None)") - assert result is None, f"Expected None (timeout), got {result}" + if result is not None: + raise RuntimeError(f"expected None, got {result}") ep_a.shutdown() - _sync_run(run_a, run_b) - print(" PASSED\n") + _sync_run("test_recv_timeout", run_a, run_b) def test_send_timeout_ms(): - """async_send accepts timeout_ms parameter.""" - print("=== test_send_timeout_ms ===") - - buf_a = ctypes.create_string_buffer(256) - buf_b = ctypes.create_string_buffer(256) + buf_a = ctypes.create_string_buffer(64) + buf_b = ctypes.create_string_buffer(64) ep_a = TcpEndpoint(port=10005) ep_b = TcpEndpoint(port=10006) @@ -150,27 +246,24 @@ def test_send_timeout_ms(): def run_b(): ep_b.connect(ep_a.endpoint_info()) st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait() - assert st == 0 + if st != 0: + raise RuntimeError(f"recv: {st}") ep_b.shutdown() def run_a(): ep_a.connect(ep_b.endpoint_info()) ctypes.memmove(ctypes.addressof(buf_a), b"world", 5) st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5), timeout_ms=10000).wait() - assert st == 0, f"send timeout_ms=10000 failed: {st}" - print(f" async_send with timeout_ms=10000: status={st}") + if st != 0: + raise RuntimeError(f"send: {st}") ep_a.shutdown() - _sync_run(run_a, run_b) - print(" PASSED\n") + _sync_run("test_send_timeout_ms", run_a, run_b) def test_default_timeout(): - """async_send uses kDefaultTimeoutMs=30000 when timeout_ms not given.""" - print("=== test_default_timeout ===") - - buf_a = ctypes.create_string_buffer(128) - buf_b = ctypes.create_string_buffer(128) + buf_a = ctypes.create_string_buffer(32) + buf_b = ctypes.create_string_buffer(32) ep_a = TcpEndpoint(port=10007) ep_b = TcpEndpoint(port=10008) @@ -178,25 +271,364 @@ def test_default_timeout(): def run_b(): ep_b.connect(ep_a.endpoint_info()) st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait() - assert st == 0 + if st != 0: + raise RuntimeError(f"recv: {st}") ep_b.shutdown() def run_a(): ep_a.connect(ep_b.endpoint_info()) ctypes.memmove(ctypes.addressof(buf_a), b"test!", 5) st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5)).wait() - assert st == 0, f"default timeout send failed: {st}" - print(f" async_send with default timeout: status={st}") + if st != 0: + raise RuntimeError(f"send: {st}") ep_a.shutdown() - _sync_run(run_a, run_b) - print(" PASSED\n") + _sync_run("test_default_timeout", run_a, run_b) + + +def test_exact_size_mismatch(): + buf_a = ctypes.create_string_buffer(32) + buf_b = ctypes.create_string_buffer(32) + + ep_a = TcpEndpoint(port=10011) + ep_b = TcpEndpoint(port=10012) + + def run_b(): + ep_b.connect(ep_a.endpoint_info()) + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 4), exact_size=True).wait() + if st != -1: + raise RuntimeError(f"expected TCP_FAILED(-1), got {st}") + ep_b.shutdown() + + def run_a(): + ep_a.connect(ep_b.endpoint_info()) + ctypes.memmove(ctypes.addressof(buf_a), b"overflow", 8) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 8)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + ep_a.shutdown() + + _sync_run("test_exact_size_mismatch", run_a, run_b) + + +def test_overflow_truncate(): + buf_a = ctypes.create_string_buffer(64) + buf_b = ctypes.create_string_buffer(64) + + ep_a = TcpEndpoint(port=10013) + ep_b = TcpEndpoint(port=10014) + + def run_b(): + ep_b.connect(ep_a.endpoint_info()) + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 4)).wait() + if st != 0: + raise RuntimeError(f"recv1: {st}") + if bytes(buf_b[:4]) != b"LONG": + raise RuntimeError(f"truncated: {bytes(buf_b[:4])}") + st = ep_b.async_recv((ctypes.addressof(buf_b), 4, 5)).wait() + if st != 0: + raise RuntimeError(f"recv2: {st}") + if bytes(buf_b[4:9]) != b"HELLO": + raise RuntimeError(f"follow-up: {bytes(buf_b[4:9])}") + ep_b.shutdown() + + def run_a(): + ep_a.connect(ep_b.endpoint_info()) + ctypes.memmove(ctypes.addressof(buf_a), b"LONGDATA", 8) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 8)).wait() + if st != 0: + raise RuntimeError(f"send1: {st}") + ctypes.memmove(ctypes.addressof(buf_a), b"HELLO", 5) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5)).wait() + if st != 0: + raise RuntimeError(f"send2: {st}") + ep_a.shutdown() + + _sync_run("test_overflow_truncate", run_a, run_b) + + +# def test_mr_name_validation(): +# ep = TcpEndpoint(port=0) +# buf = ctypes.create_string_buffer(32) + +# h = ep.register_memory_region("valid", ctypes.addressof(buf), 32) +# if h < 0: +# raise RuntimeError(f"valid name: {h}") + +# h = ep.register_memory_region("", ctypes.addressof(buf), 32) +# if h != -1: +# raise RuntimeError(f"empty name should return -1, got {h}") + +# h = ep.register_memory_region("valid", ctypes.addressof(buf), 32) +# if h != -1: +# raise RuntimeError(f"duplicate name should return -1, got {h}") + +# ep.shutdown() + + +# def test_connect_unreachable(): +# ep = TcpEndpoint(port=10015) +# unreachable = {"host": "127.0.0.1", "port": 65535, "mr_info": {}} +# ep.connect(unreachable) +# if ep.is_connected(): +# raise RuntimeError("should not be connected") +# ep.shutdown() + + +# ── parameterized torch tests (device="cpu" or "cuda") ── + +def _dev_skip(device): + if not _HAS_TORCH: + return True + if device == "cuda": + return _cuda_skip() + return False + + +def _make_tensor(shape, device, **kw): + """Create a tensor on the given device. CPU tensor for recv on cuda path + uses ctypes buffer so data_ptr() gives host pointer (needed for cudaMemcpy).""" + return torch.zeros(shape, dtype=torch.float32, device=device, **kw) + + +def test_torch_send_recv(device="cpu"): + """Round-trip: A send full → B recv → B send slice → A recv.""" + if _dev_skip(device): + return + + SZ, SL = 32, 5 # elements + t_a = _make_tensor(SZ, device).normal_() + t_b = _make_tensor(SZ, device) + n_bytes = SZ * 4 + sl_bytes = SL * 4 + + ep_a = TcpEndpoint(port=10021) + ep_b = TcpEndpoint(port=10022) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + + def run_a(): + ep_a.connect(info_b) + st = ep_a.async_send((t_a.data_ptr(), 0, n_bytes)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + st = ep_a.async_recv((t_a.data_ptr(), 10 * 4, sl_bytes)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + if not torch.equal(t_a[10:15].cpu(), t_b[20:25].cpu()): + raise RuntimeError("slice mismatch") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + st = ep_b.async_recv((t_b.data_ptr(), 0, n_bytes)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + if not torch.equal(t_a.cpu(), t_b.cpu()): + raise RuntimeError("full tensor mismatch") + st = ep_b.async_send((t_b.data_ptr(), 20 * 4, sl_bytes)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + ep_b.shutdown() + + _sync_run(f"test_torch_send_recv_{device}", run_a, run_b) + + +def test_torch_write(device="cpu"): + """One-sided write: A async_write → B verifies data received.""" + if _dev_skip(device): + return + + SZ = 64 + t_a = _make_tensor(SZ, device).zero_() + if device == "cuda": + t_b = torch.zeros(SZ, dtype=torch.float32) # remote always CPU + b_ptr = t_b.data_ptr() + else: + t_b = _make_tensor(SZ, device) + b_ptr = t_b.data_ptr() + + n_bytes = SZ * 4 + + ep_a = TcpEndpoint(port=0) + ep_b = TcpEndpoint(port=0) + h_a = ep_a.register_memory_region("a", t_a.data_ptr(), n_bytes) + h_b = ep_b.register_memory_region("b", b_ptr, n_bytes) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + + def run_a(): + ep_a.connect(info_b) + t_a[:6] = torch.arange(6, dtype=torch.float32, device=device) + st = ep_a.async_write([(h_a, h_br, 0, 0, 6 * 4)]).wait() + if st != 0: + raise RuntimeError(f"write: {st}") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + expected = torch.arange(6, dtype=torch.float32) + for _ in range(50): + if torch.equal(t_b[:6], expected): + break + time.sleep(0.01) + if not torch.equal(t_b[:6], expected): + raise RuntimeError("write data not received") + ep_b.shutdown() + + _sync_run(f"test_torch_write_{device}", run_a, run_b) + + +def test_torch_read(device="cpu"): + """One-sided read: B buffer pre-filled, A async_read and verifies.""" + if _dev_skip(device): + return + + SZ = 64 + t_a = _make_tensor(SZ, device).zero_() + if device == "cuda": + t_b = torch.zeros(SZ, dtype=torch.float32) + b_ptr = t_b.data_ptr() + else: + t_b = _make_tensor(SZ, device) + b_ptr = t_b.data_ptr() + + n_bytes = SZ * 4 + + ep_a = TcpEndpoint(port=0) + ep_b = TcpEndpoint(port=0) + h_a = ep_a.register_memory_region("a", t_a.data_ptr(), n_bytes) + h_b = ep_b.register_memory_region("b", b_ptr, n_bytes) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + + def run_a(): + ep_a.connect(info_b) + st = ep_a.async_read([(h_a, h_br, 0, 0, 6 * 4)]).wait() + if st != 0: + raise RuntimeError(f"read: {st}") + expected = torch.arange(6, dtype=torch.float32) + if not torch.equal(t_a[:6].cpu(), expected): + raise RuntimeError("read data mismatch") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + t_b[:6] = torch.arange(6, dtype=torch.float32) + time.sleep(0.2) + ep_b.shutdown() + + _sync_run(f"test_torch_read_{device}", run_a, run_b) + + +def test_torch_write_batch(device="cpu", n_batch=4): + """One async_write with multiple assignments.""" + if _dev_skip(device): + return + + SZ = 64 + t_a = _make_tensor(SZ, device).zero_() + if device == "cuda": + t_b = torch.zeros(SZ, dtype=torch.float32) + b_ptr = t_b.data_ptr() + else: + t_b = _make_tensor(SZ, device) + b_ptr = t_b.data_ptr() + + n_bytes = SZ * 4 + + ep_a = TcpEndpoint(port=0) + ep_b = TcpEndpoint(port=0) + h_a = ep_a.register_memory_region("a", t_a.data_ptr(), n_bytes) + h_b = ep_b.register_memory_region("b", b_ptr, n_bytes) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + + def run_a(): + ep_a.connect(info_b) + for i in range(n_batch): + t_a[i] = float(i + 1) + assigns = [(h_a, h_br, i * 4, i * 4, 4) for i in range(n_batch)] + st = ep_a.async_write(assigns).wait() + if st != 0: + raise RuntimeError(f"write batch: {st}") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + time.sleep(0.2) + for i in range(n_batch): + if t_b[i].item() != float(i + 1): + raise RuntimeError(f"batch {i}: expected {i+1}, got {t_b[i]}") + ep_b.shutdown() + + _sync_run(f"test_torch_write_batch_{device}", run_a, run_b) + + +def test_torch_read_batch(device="cpu", n_batch=4): + """One async_read with multiple assignments.""" + if _dev_skip(device): + return + + SZ = 64 + t_a = _make_tensor(SZ, device).zero_() + if device == "cuda": + t_b = torch.zeros(SZ, dtype=torch.float32) + b_ptr = t_b.data_ptr() + else: + t_b = _make_tensor(SZ, device) + b_ptr = t_b.data_ptr() + + n_bytes = SZ * 4 + + ep_a = TcpEndpoint(port=0) + ep_b = TcpEndpoint(port=0) + h_a = ep_a.register_memory_region("a", t_a.data_ptr(), n_bytes) + h_b = ep_b.register_memory_region("b", b_ptr, n_bytes) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + + def run_a(): + ep_a.connect(info_b) + assigns = [(h_a, h_br, i * 4, i * 4 + 16 * 4, 4) for i in range(n_batch)] + st = ep_a.async_read(assigns).wait() + if st != 0: + raise RuntimeError(f"read batch: {st}") + expected = torch.tensor([1., 2., 3., 4.], dtype=torch.float32) + if not torch.equal(t_a[16:20].cpu(), expected): + raise RuntimeError("read back mismatch") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + for i in range(n_batch): + t_b[i] = float(i + 1) + time.sleep(0.2) + ep_b.shutdown() + + _sync_run(f"test_torch_read_batch_{device}", run_a, run_b) + +# ── main ───────────────────────────────────────────────── if __name__ == "__main__": test_async_send_recv() - test_async_write_read() + test_async_send_recv_one() + test_async_write() + test_async_read() test_recv_timeout() test_send_timeout_ms() test_default_timeout() - print("All TcpEndpoint v3 tests passed!") + test_exact_size_mismatch() + test_overflow_truncate() + + for dev in ("cpu", "cuda"): + test_torch_send_recv(dev) + test_torch_write(dev) + test_torch_read(dev) + # test_torch_write_batch(dev) + # test_torch_read_batch(dev) diff --git a/dlslime/csrc/python/bind.cpp b/dlslime/csrc/python/bind.cpp index 0359583f..1a3ed1d1 100644 --- a/dlslime/csrc/python/bind.cpp +++ b/dlslime/csrc/python/bind.cpp @@ -599,6 +599,7 @@ PYBIND11_MODULE(_slime_c, m) .def("mr_info", &dlslime::tcp::TcpEndpoint::mr_info) .def("shutdown", &dlslime::tcp::TcpEndpoint::shutdown, py::call_guard()) + .def("is_connected", &dlslime::tcp::TcpEndpoint::is_connected) .def("register_memory_region", &dlslime::tcp::TcpEndpoint::register_memory_region, py::arg("name"), py::arg("data_ptr"), py::arg("length"), @@ -614,7 +615,7 @@ PYBIND11_MODULE(_slime_c, m) py::call_guard()) .def("async_recv", &dlslime::tcp::TcpEndpoint::async_recv, - py::arg("chunk"), + py::arg("chunk"), py::arg("exact_size") = false, py::call_guard()) .def("async_read", py::overload_cast&, int64_t>( From a4f3d3f93788bc2f21585fef38bc419b5ef554e6 Mon Sep 17 00:00:00 2001 From: SHshenhao Date: Tue, 19 May 2026 02:58:11 +0000 Subject: [PATCH 09/10] async_read_async_write_support_vectorN --- dlslime/csrc/engine/tcp/tcp_endpoint.cpp | 186 ++++++++++++----------- 1 file changed, 100 insertions(+), 86 deletions(-) diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp index 8453ea9d..83d2de11 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -239,6 +239,8 @@ TcpEndpoint::async_recv(const chunk_tuple_t& chunk, bool exact_size) { } // ── async_read ────────────────────────────────────────── +// Each assign creates an independent ClientSession; all share one OpState. +// Future.wait() blocks until every session has signalled its bit. std::shared_ptr TcpEndpoint::async_read(const std::vector& assign, @@ -246,65 +248,71 @@ TcpEndpoint::async_read(const std::vector& assign, if (assign.empty()) throw std::runtime_error("TcpEndpoint::async_read: empty assignment"); - const auto& a = assign[0]; - int32_t local_h = static_cast(std::get<0>(a)); - int32_t remote_h = static_cast(std::get<1>(a)); - uint64_t remote_off = std::get<2>(a); - uint64_t local_off = std::get<3>(a); - size_t length = std::get<4>(a); - - auto local = local_pool_->get_mr_fast(local_h); - auto remote = remote_pool_->get_remote_mr_fast(remote_h); - if (local.length == 0 || remote.length == 0) - throw std::runtime_error("TcpEndpoint::async_read: invalid MR handle"); - - auto op = TcpOpState::create(); + size_t N = assign.size(); + auto op = TcpOpState::create(); op->signal->reset_all(); - uintptr_t local_dst = local.addr + local_off; - op->user_buffer = local_dst; - op->user_length = length; - - auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); - if (!conn) { - op->completion_status.store(TCP_FAILED, std::memory_order_release); - op->signal->force_complete(); - return std::make_shared(op); - } + op->expected_mask = (N < 32) ? (1u << N) - 1 : 0xFFFFFFFFu; + op->completion_status.store(TCP_SUCCESS, std::memory_order_release); + op->completion_mask.store(0, std::memory_order_release); - SessionHeader hdr{length, remote.addr + remote_off, OP_READ}; auto& pool = ctx_->conn_pool(); - auto* read_dst = reinterpret_cast(local_dst); - bool is_cuda = false; + for (size_t i = 0; i < N; i++) { + const auto& a = assign[i]; + int32_t local_h = static_cast(std::get<0>(a)); + int32_t remote_h = static_cast(std::get<1>(a)); + uint64_t remote_off = std::get<2>(a); + uint64_t local_off = std::get<3>(a); + size_t length = std::get<4>(a); + + auto local = local_pool_->get_mr_fast(local_h); + auto remote = remote_pool_->get_remote_mr_fast(remote_h); + if (local.length == 0 || remote.length == 0) + throw std::runtime_error("TcpEndpoint::async_read: invalid MR handle"); + + uintptr_t local_dst = local.addr + local_off; + SessionHeader hdr{length, remote.addr + remote_off, OP_READ}; + + auto conn = pool.getConnection(peer_host_, peer_port_); + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->set_comm_done(i); + continue; + } + + auto* read_dst = reinterpret_cast(local_dst); + bool is_cuda = false; #ifdef USE_CUDA - if (is_cuda_memory(read_dst)) { - read_dst = new char[length]; - is_cuda = true; - } + if (is_cuda_memory(read_dst)) { + read_dst = new char[length]; + is_cuda = true; + } #endif - auto session = std::make_shared( - std::move(conn->socket), - [op, conn, &pool, read_dst, is_cuda, - real_dst = local_dst, len = length](asio::error_code ec) { + auto session = std::make_shared( + std::move(conn->socket), + [op, conn, i, &pool, read_dst, is_cuda, + real_dst = local_dst, len = length](asio::error_code ec) { #ifdef USE_CUDA - if (!ec && is_cuda) { - cudaMemcpy(reinterpret_cast(real_dst), - read_dst, len, cudaMemcpyHostToDevice); - delete[] read_dst; - } + if (!ec && is_cuda) { + cudaMemcpy(reinterpret_cast(real_dst), + read_dst, len, cudaMemcpyHostToDevice); + delete[] read_dst; + } #endif - op->completion_status.store( - ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); - if (op->signal) op->signal->set_comm_done(0); - pool.returnConnection(conn); - }); - session->start_read(hdr, read_dst); + if (ec) + op->completion_status.store(TCP_FAILED, std::memory_order_release); + if (op->signal) op->signal->set_comm_done(i); + pool.returnConnection(conn); + }); + session->start_read(hdr, read_dst); + } return std::make_shared(op); } // ── async_write ───────────────────────────────────────── +// Each assign creates an independent ClientSession; all share one OpState. std::shared_ptr TcpEndpoint::async_write(const std::vector& assign, @@ -312,56 +320,62 @@ TcpEndpoint::async_write(const std::vector& assign, if (assign.empty()) throw std::runtime_error("TcpEndpoint::async_write: empty assignment"); - const auto& a = assign[0]; - int32_t local_h = static_cast(std::get<0>(a)); - int32_t remote_h = static_cast(std::get<1>(a)); - uint64_t remote_off = std::get<2>(a); - uint64_t local_off = std::get<3>(a); - size_t length = std::get<4>(a); - - auto local = local_pool_->get_mr_fast(local_h); - auto remote = remote_pool_->get_remote_mr_fast(remote_h); - if (local.length == 0 || remote.length == 0) - throw std::runtime_error("TcpEndpoint::async_write: invalid MR handle"); - - uintptr_t src = local.addr + local_off; - - auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); - auto op = TcpOpState::create(); + size_t N = assign.size(); + auto op = TcpOpState::create(); op->signal->reset_all(); + op->expected_mask = (N < 32) ? (1u << N) - 1 : 0xFFFFFFFFu; + op->completion_status.store(TCP_SUCCESS, std::memory_order_release); + op->completion_mask.store(0, std::memory_order_release); - if (!conn) { - op->completion_status.store(TCP_FAILED, std::memory_order_release); - op->signal->force_complete(); - return std::make_shared(op); - } - - SessionHeader hdr{length, remote.addr + remote_off, OP_WRITE}; auto& pool = ctx_->conn_pool(); - auto* send_ptr = reinterpret_cast(src); - bool is_cuda = false; + for (size_t i = 0; i < N; i++) { + const auto& a = assign[i]; + int32_t local_h = static_cast(std::get<0>(a)); + int32_t remote_h = static_cast(std::get<1>(a)); + uint64_t remote_off = std::get<2>(a); + uint64_t local_off = std::get<3>(a); + size_t length = std::get<4>(a); + + auto local = local_pool_->get_mr_fast(local_h); + auto remote = remote_pool_->get_remote_mr_fast(remote_h); + if (local.length == 0 || remote.length == 0) + throw std::runtime_error("TcpEndpoint::async_write: invalid MR handle"); + + uintptr_t src = local.addr + local_off; + SessionHeader hdr{length, remote.addr + remote_off, OP_WRITE}; + + auto conn = pool.getConnection(peer_host_, peer_port_); + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->set_comm_done(i); + continue; + } + + auto* send_ptr = reinterpret_cast(src); + bool is_cuda = false; #ifdef USE_CUDA - if (is_cuda_memory(send_ptr)) { - auto* buf = new char[length]; - cudaMemcpy(buf, send_ptr, length, cudaMemcpyDeviceToHost); - send_ptr = buf; - is_cuda = true; - } + if (is_cuda_memory(send_ptr)) { + auto* buf = new char[length]; + cudaMemcpy(buf, send_ptr, length, cudaMemcpyDeviceToHost); + send_ptr = buf; + is_cuda = true; + } #endif - auto session = std::make_shared( - std::move(conn->socket), - [op, conn, &pool, send_ptr, is_cuda](asio::error_code ec) { - op->completion_status.store( - ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); - if (op->signal) op->signal->set_comm_done(0); - pool.returnConnection(conn); + auto session = std::make_shared( + std::move(conn->socket), + [op, conn, i, &pool, send_ptr, is_cuda](asio::error_code ec) { + if (ec) + op->completion_status.store(TCP_FAILED, std::memory_order_release); + if (op->signal) op->signal->set_comm_done(i); + pool.returnConnection(conn); #ifdef USE_CUDA - if (is_cuda) delete[] send_ptr; + if (is_cuda) delete[] send_ptr; #endif - }); - session->start_write(hdr, send_ptr); + }); + session->start_write(hdr, send_ptr); + } return std::make_shared(op); } From ab6e8156e3d89e84f3fc4f97db83bf836f4e9299 Mon Sep 17 00:00:00 2001 From: SHshenhao Date: Tue, 19 May 2026 09:58:25 +0000 Subject: [PATCH 10/10] updatetest_addcudasupport_addbatchreadwritesuport --- dlslime/csrc/engine/tcp/CMakeLists.txt | 7 + dlslime/csrc/engine/tcp/build_and_test.sh | 3 +- dlslime/csrc/engine/tcp/plan_v4.md | 30 +- dlslime/csrc/engine/tcp/tcp_endpoint.cpp | 56 +++- dlslime/csrc/engine/tcp/tcp_endpoint.h | 2 +- dlslime/csrc/engine/tcp/tcp_session.cpp | 66 +++- dlslime/csrc/engine/tcp/test_tcp_endpoint.py | 324 ++++++++----------- dlslime/csrc/python/bind.cpp | 2 +- 8 files changed, 283 insertions(+), 207 deletions(-) mode change 100644 => 100755 dlslime/csrc/engine/tcp/tcp_session.cpp diff --git a/dlslime/csrc/engine/tcp/CMakeLists.txt b/dlslime/csrc/engine/tcp/CMakeLists.txt index 5487b06b..ab69db49 100644 --- a/dlslime/csrc/engine/tcp/CMakeLists.txt +++ b/dlslime/csrc/engine/tcp/CMakeLists.txt @@ -23,6 +23,13 @@ add_library(_slime_tcp SHARED target_compile_definitions(_slime_tcp PRIVATE ASIO_STANDALONE) +if (USE_CUDA) + find_package(CUDAToolkit REQUIRED) + target_compile_definitions(_slime_tcp PRIVATE USE_CUDA) + target_include_directories(_slime_tcp PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + target_link_libraries(_slime_tcp PUBLIC CUDA::cudart) +endif() + target_link_libraries(_slime_tcp PUBLIC asio::asio _slime_device diff --git a/dlslime/csrc/engine/tcp/build_and_test.sh b/dlslime/csrc/engine/tcp/build_and_test.sh index f50155dd..d30e99c9 100755 --- a/dlslime/csrc/engine/tcp/build_and_test.sh +++ b/dlslime/csrc/engine/tcp/build_and_test.sh @@ -13,6 +13,7 @@ header() { echo; echo -e "\033[1;36m==>\033[m \033[1m$*\033[m"; } ok() { echo -e " \033[1;32mOK\033[m $*"; } do_build() { + rm -rf ${BUILD_DIR} local cuda_label="" [[ "$USE_CUDA" == "ON" ]] && cuda_label=" + USE_CUDA=ON" @@ -39,7 +40,7 @@ do_build() { do_test() { header "Running TcpEndpoint tests" - export DLSLIME_LOG_LEVEL=0 + export SLIME_LOG_LEVEL=1 export LD_LIBRARY_PATH="$REPO_ROOT/dlslime" export PYTHONPATH="$REPO_ROOT" python3 "$SCRIPT_DIR/test_tcp_endpoint.py" 2>&1 | while IFS= read -r line; do diff --git a/dlslime/csrc/engine/tcp/plan_v4.md b/dlslime/csrc/engine/tcp/plan_v4.md index 4f160d9c..10f9d25a 100644 --- a/dlslime/csrc/engine/tcp/plan_v4.md +++ b/dlslime/csrc/engine/tcp/plan_v4.md @@ -1,8 +1,19 @@ # TcpEndpoint v4 — Future / OpState / Session / Primitive 关系重构 -## 当前状态 +**状态**: 已实现并测试通过 (2026-05-18) -四个 async 原语使用 ad-hoc lambda 模式,与 session 概念脱节: +## 已实现功能 + +- 4 个 async 原语基于 ClientSession + Future + OpState 模型 +- ClientSession 与 ServerSession 对称:start_write/start_read vs readBody/writeBody +- 多 assign 支持:迭代 vector,每个 assign 创建独立 ClientSession,共享 OpState +- CUDA 两端 staging:async_send/write/read + ServerSession readBody/writeBody +- send/recv 脱离 MemoryPool(裸指针模式),read/write 继续使用 MR 寻址 +- `register_memory_region(name, ptr, offset, length)` 接口对齐 RDMAEndpoint +- 编译开关:`USE_CUDA=ON ./build_and_test.sh all` 启用 CUDA 路径 +- 宽松截断 + exact_size 拒绝 + overflow 保护 + +## 当前状态(已过时,仅供参考) ``` async_send(chunk): @@ -287,3 +298,18 @@ void ClientSession::start_write(...) { - **不拆 WriteSession/ReadSession** — 差异小,合并为一个 ClientSession - **不在 Future 中持有 Session** — Future 只 wait,通过 OpState 间接关联 - **ClientSession 不持有 OpState** — 只报 ec,由 Primitive 的 on_done 统一 signal + +## 未来规划 + +### CUDA 锁页内存 + +当前 CUDA staging 使用 `new char[]`(可分页内存),D2H/H2D `cudaMemcpy` 走的是同步 device→host 拷贝,pageable memory 路径较慢。 + +后续改为 `cudaHostAlloc()` 分配锁页(pinned)内存,使 `cudaMemcpy` 能走 DMA 快速路径。同时可考虑 `cudaMemcpyAsync` + `cudaStream` 与 io_context 的异步重叠。 + +### async_recv exact_size 自适应 + +当前 `exact_size` 是 opt-in boolean 参数,默认 `false`(宽松截断)。未来改为默认自适应: +- 当 `send_size <= recv_size`:自动启用严格检查(exact match) +- 当 `send_size > recv_size`:自动宽松截断 +- 移除 `exact_size` 参数,行为由实际数据量驱动 diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp index 83d2de11..f65649b4 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -49,10 +49,12 @@ ServerSession::RecvMatcher TcpEndpoint::make_recv_matcher() { #ifdef USE_CUDA if (pr.cuda_dst) { slot.buffer = reinterpret_cast(pr.staging_buf.get()); - slot.post_read = [buf = std::move(pr.staging_buf), + slot.post_read = [buf = std::shared_ptr(std::move(pr.staging_buf)), dst = pr.cuda_dst, len = pr.op_state->user_length]() { - cudaMemcpy(reinterpret_cast(dst), buf.get(), - len, cudaMemcpyHostToDevice); + auto cu_err = cudaMemcpy(reinterpret_cast(dst), buf.get(), + len, cudaMemcpyHostToDevice); + if (cu_err != cudaSuccess) + SLIME_LOG_ERROR("cudaMemcpy H2D (recv): ", cudaGetErrorString(cu_err)); }; } #endif @@ -150,8 +152,9 @@ void TcpEndpoint::connect(const json& remote_endpoint_info) { // ── memory registration ───────────────────────────────── int32_t TcpEndpoint::register_memory_region(const std::string& name, - uintptr_t ptr, size_t length) { - return local_pool_->register_memory_region(ptr, length, name); + uintptr_t ptr, uintptr_t offset, + size_t length) { + return local_pool_->register_memory_region(ptr + offset, length, name); } int32_t TcpEndpoint::register_remote_memory_region(const std::string& name, @@ -184,9 +187,16 @@ TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/) { bool is_cuda = false; #ifdef USE_CUDA if (is_cuda_memory(send_ptr)) { - // TODO: 使用锁页内存,以及考虑async和overlap auto* buf = new char[len]; - cudaMemcpy(buf, send_ptr, len, cudaMemcpyDeviceToHost); + auto cu_err = cudaMemcpy(buf, send_ptr, len, cudaMemcpyDeviceToHost); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("async_send cudaMemcpy D2H: ", cudaGetErrorString(cu_err)); + delete[] buf; + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + pool.returnConnection(conn); + return std::make_shared(op); + } send_ptr = buf; is_cuda = true; } @@ -195,6 +205,8 @@ TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/) { auto session = std::make_shared( std::move(conn->socket), [op, conn, &pool, send_ptr, is_cuda](asio::error_code ec) { + if (ec) + SLIME_LOG_WARN("async_send: ", ec.message()); op->completion_status.store( ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); if (op->signal) op->signal->set_comm_done(0); @@ -293,15 +305,21 @@ TcpEndpoint::async_read(const std::vector& assign, std::move(conn->socket), [op, conn, i, &pool, read_dst, is_cuda, real_dst = local_dst, len = length](asio::error_code ec) { + if (ec) { + SLIME_LOG_WARN("async_read session ", i, ": ", ec.message()); + op->completion_status.store(TCP_FAILED, std::memory_order_release); + } #ifdef USE_CUDA if (!ec && is_cuda) { - cudaMemcpy(reinterpret_cast(real_dst), - read_dst, len, cudaMemcpyHostToDevice); - delete[] read_dst; + auto cu_err = cudaMemcpy(reinterpret_cast(real_dst), + read_dst, len, cudaMemcpyHostToDevice); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("async_read cudaMemcpy H2D: ", cudaGetErrorString(cu_err)); + op->completion_status.store(TCP_FAILED, std::memory_order_release); + } } + if (is_cuda) delete[] read_dst; #endif - if (ec) - op->completion_status.store(TCP_FAILED, std::memory_order_release); if (op->signal) op->signal->set_comm_done(i); pool.returnConnection(conn); }); @@ -357,7 +375,15 @@ TcpEndpoint::async_write(const std::vector& assign, #ifdef USE_CUDA if (is_cuda_memory(send_ptr)) { auto* buf = new char[length]; - cudaMemcpy(buf, send_ptr, length, cudaMemcpyDeviceToHost); + auto cu_err = cudaMemcpy(buf, send_ptr, length, cudaMemcpyDeviceToHost); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("async_write cudaMemcpy D2H: ", cudaGetErrorString(cu_err)); + delete[] buf; + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + pool.returnConnection(conn); + return std::make_shared(op); + } send_ptr = buf; is_cuda = true; } @@ -366,8 +392,10 @@ TcpEndpoint::async_write(const std::vector& assign, auto session = std::make_shared( std::move(conn->socket), [op, conn, i, &pool, send_ptr, is_cuda](asio::error_code ec) { - if (ec) + if (ec) { + SLIME_LOG_WARN("async_write session ", i, ": ", ec.message()); op->completion_status.store(TCP_FAILED, std::memory_order_release); + } if (op->signal) op->signal->set_comm_done(i); pool.returnConnection(conn); #ifdef USE_CUDA diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.h b/dlslime/csrc/engine/tcp/tcp_endpoint.h index bc6e97d1..66f7a6f4 100755 --- a/dlslime/csrc/engine/tcp/tcp_endpoint.h +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.h @@ -46,7 +46,7 @@ class TcpEndpoint : public std::enable_shared_from_this { // ── Memory ────────────────────────────────────────── int32_t register_memory_region(const std::string& name, - uintptr_t ptr, size_t length); + uintptr_t ptr, uintptr_t offset, size_t length); int32_t register_remote_memory_region(const std::string& name, const json& mr_info); json mr_info() const; diff --git a/dlslime/csrc/engine/tcp/tcp_session.cpp b/dlslime/csrc/engine/tcp/tcp_session.cpp old mode 100644 new mode 100755 index c1454f73..994a2a46 --- a/dlslime/csrc/engine/tcp/tcp_session.cpp +++ b/dlslime/csrc/engine/tcp/tcp_session.cpp @@ -9,6 +9,10 @@ #include "dlslime/csrc/logging.h" +#ifdef USE_CUDA +#include +#endif + namespace dlslime { namespace tcp { @@ -28,6 +32,14 @@ static bool is_fatal(asio::error_code ec) { return ec && ec != asio::error::eof; } +#ifdef USE_CUDA +static bool is_cuda_memory(const void* addr) { + cudaPointerAttributes attr; + auto st = cudaPointerGetAttributes(&attr, addr); + return (st == cudaSuccess && attr.type == cudaMemoryTypeDevice); +} +#endif + // ── ServerSession ─────────────────────────────────────── ServerSession::ServerSession(asio::ip::tcp::socket socket, @@ -141,20 +153,60 @@ void ServerSession::dispatch() { } void ServerSession::readBody(void* dst, size_t len) { + auto* ptr = static_cast(dst); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(dst)) { + ptr = new char[len]; + is_cuda = true; + } +#endif + auto self = shared_from_this(); - asio::async_read(socket_, asio::buffer(dst, len), - [this, self](asio::error_code ec, size_t /*n*/) { - if (ec && is_fatal(ec)) - SLIME_LOG_WARN("ServerSession::readBody ", ec.message()); + asio::async_read(socket_, asio::buffer(ptr, len), + [this, self, real_addr = reinterpret_cast(dst), + len, is_cuda, ptr](asio::error_code ec, size_t /*n*/) { + if (ec) { + if (is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::readBody ", ec.message()); + if (is_cuda) delete[] ptr; + return; + } +#ifdef USE_CUDA + if (is_cuda) { + auto cu_err = cudaMemcpy(reinterpret_cast(real_addr), ptr, + len, cudaMemcpyHostToDevice); + if (cu_err != cudaSuccess) + SLIME_LOG_ERROR("readBody cudaMemcpy H2D: ", cudaGetErrorString(cu_err)); + delete[] ptr; + } +#endif readHeader(); }); } void ServerSession::writeBody(const void* src, size_t len) { + auto* ptr = static_cast(src); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(src)) { + auto* buf = new char[len]; + auto cu_err = cudaMemcpy(buf, src, len, cudaMemcpyDeviceToHost); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("writeBody cudaMemcpy D2H: ", cudaGetErrorString(cu_err)); + delete[] buf; + ptr = static_cast(src); + } else { + ptr = buf; + is_cuda = true; + } + } +#endif + auto self = shared_from_this(); - asio::async_write(socket_, - asio::buffer(src, len), - [this, self](asio::error_code ec, size_t /*n*/) { + asio::async_write(socket_, asio::buffer(ptr, len), + [this, self, is_cuda, ptr](asio::error_code ec, size_t /*n*/) { + if (is_cuda) delete[] ptr; if (ec && is_fatal(ec)) SLIME_LOG_WARN("ServerSession::writeBody ", ec.message()); readHeader(); diff --git a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py index 962ae76a..8bab9a8b 100755 --- a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py +++ b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py @@ -40,40 +40,31 @@ def _cuda_skip(): # ── test harness ───────────────────────────────────────── -def _sync_run(name, fn_a, fn_b): +def _sync_run(name, fn_a, fn_b, timeout=120): err = [] + b = threading.Barrier(2) def wrap(fn): try: - b.wait() + b.wait(10) fn() except Exception as e: err.append(e) - b = threading.Barrier(2) - ta = threading.Thread(target=wrap, args=(fn_a,), daemon=True) - tb = threading.Thread(target=wrap, args=(fn_b,), daemon=True) + ta = threading.Thread(target=wrap, args=(fn_a,), daemon=False) + tb = threading.Thread(target=wrap, args=(fn_b,), daemon=False) ta.start() tb.start() - ta.join() - tb.join() + ta.join(timeout) + tb.join(timeout) + if ta.is_alive() or tb.is_alive(): + raise RuntimeError(f"{name} FAIL!{timeout}s timeout!") if len(err) > 0: - print(f"{name} FAIL {err}") + print(f"{name} FAIL {err}", flush=True) return False else: - print(f"{name} SUCC ") - return True - - -def _run_test(fn): - print(f"=== {fn.__name__} ===") - try: - fn() - print(" PASSED\n") + print(f"{name} SUCC ", flush=True) return True - except Exception as e: - print(f" FAILED — {e}\n") - return False # ── ctypes-based tests ─────────────────────────────────── @@ -90,6 +81,7 @@ def test_async_send_recv(): def run_a(): ep_a.connect(info_b) ctypes.memmove(ctypes.addressof(buf_a), b"hello", 5) + time.sleep(5) st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5)).wait() if st != 0: raise RuntimeError(f"send: {st}") @@ -108,6 +100,7 @@ def run_b(): if bytes(buf_b[:5]) != b"hello": raise RuntimeError(f"data: {bytes(buf_b[:5])}") ctypes.memmove(ctypes.addressof(buf_b), b"world", 5) + time.sleep(5) st = ep_b.async_send((ctypes.addressof(buf_b), 0, 5)).wait() if st != 0: raise RuntimeError(f"send: {st}") @@ -116,18 +109,19 @@ def run_b(): _sync_run("test_async_send_recv", run_a, run_b) -def test_async_send_recv_one(): +def test_async_send2recv(): buf_a = ctypes.create_string_buffer(32) buf_b = ctypes.create_string_buffer(32) - ep_a = TcpEndpoint(port=10041) - ep_b = TcpEndpoint(port=10042) + ep_a = TcpEndpoint(port=10401) + ep_b = TcpEndpoint(port=10402) info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() def run_a(): ep_a.connect(info_b) ctypes.memmove(ctypes.addressof(buf_a), b"one", 3) + time.sleep(5) st = ep_a.async_send((ctypes.addressof(buf_a), 0, 3)).wait() if st != 0: raise RuntimeError(f"send: {st}") @@ -150,10 +144,10 @@ def test_async_write(): buf_b = ctypes.create_string_buffer(256) addr_a = ctypes.addressof(buf_a) - ep_a = TcpEndpoint(port=0) - ep_b = TcpEndpoint(port=0) - h_a = ep_a.register_memory_region("a", addr_a, 256) - h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 256) + ep_a = TcpEndpoint(port=10003) + ep_b = TcpEndpoint(port=10004) + h_a = ep_a.register_memory_region("a", addr_a, 0, 256) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 256) info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) @@ -175,7 +169,7 @@ def run_b(): break time.sleep(0.5) if bytes(buf_b[:len(test_data)]) != test_data: - raise RuntimeError("B write not received") + raise RuntimeError(f"B write not received in {50 * 0.5}s") ep_b.shutdown() _sync_run("test_async_write", run_a, run_b) @@ -186,10 +180,10 @@ def test_async_read(): buf_b = ctypes.create_string_buffer(256) addr_a = ctypes.addressof(buf_a) - ep_a = TcpEndpoint(port=0) - ep_b = TcpEndpoint(port=0) - h_a = ep_a.register_memory_region("a", addr_a, 256) - h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 256) + ep_a = TcpEndpoint(port=10005) + ep_b = TcpEndpoint(port=10006) + h_a = ep_a.register_memory_region("a", addr_a, 0, 256) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 256) info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) @@ -208,17 +202,20 @@ def run_a(): def run_b(): ep_b.connect(info_a) - time.sleep(30) + time.sleep(25) ep_b.shutdown() _sync_run("test_async_read", run_a, run_b) +# ── skip test ── + + def test_recv_timeout(): buf_a = ctypes.create_string_buffer(32) - ep_a = TcpEndpoint(port=10003) - ep_b = TcpEndpoint(port=10004) + ep_a = TcpEndpoint(port=10007) + ep_b = TcpEndpoint(port=10008) def run_b(): ep_b.connect(ep_a.endpoint_info()) @@ -240,8 +237,8 @@ def test_send_timeout_ms(): buf_a = ctypes.create_string_buffer(64) buf_b = ctypes.create_string_buffer(64) - ep_a = TcpEndpoint(port=10005) - ep_b = TcpEndpoint(port=10006) + ep_a = TcpEndpoint(port=10009) + ep_b = TcpEndpoint(port=10010) def run_b(): ep_b.connect(ep_a.endpoint_info()) @@ -265,8 +262,8 @@ def test_default_timeout(): buf_a = ctypes.create_string_buffer(32) buf_b = ctypes.create_string_buffer(32) - ep_a = TcpEndpoint(port=10007) - ep_b = TcpEndpoint(port=10008) + ep_a = TcpEndpoint(port=10011) + ep_b = TcpEndpoint(port=10012) def run_b(): ep_b.connect(ep_a.endpoint_info()) @@ -347,75 +344,69 @@ def run_a(): _sync_run("test_overflow_truncate", run_a, run_b) -# def test_mr_name_validation(): -# ep = TcpEndpoint(port=0) -# buf = ctypes.create_string_buffer(32) +def test_mr_name_validation(): + ep = TcpEndpoint(port=0) + buf = ctypes.create_string_buffer(32) -# h = ep.register_memory_region("valid", ctypes.addressof(buf), 32) -# if h < 0: -# raise RuntimeError(f"valid name: {h}") + h = ep.register_memory_region("valid", ctypes.addressof(buf), 0, 32) + if h < 0: + raise RuntimeError(f"valid name: {h}") -# h = ep.register_memory_region("", ctypes.addressof(buf), 32) -# if h != -1: -# raise RuntimeError(f"empty name should return -1, got {h}") + h = ep.register_memory_region("", ctypes.addressof(buf), 0, 32) + if h != -1: + raise RuntimeError(f"empty name should return -1, got {h}") -# h = ep.register_memory_region("valid", ctypes.addressof(buf), 32) -# if h != -1: -# raise RuntimeError(f"duplicate name should return -1, got {h}") + h = ep.register_memory_region("valid", ctypes.addressof(buf), 0, 32) + if h != -1: + raise RuntimeError(f"duplicate name should return -1, got {h}") -# ep.shutdown() + ep.shutdown() -# def test_connect_unreachable(): -# ep = TcpEndpoint(port=10015) -# unreachable = {"host": "127.0.0.1", "port": 65535, "mr_info": {}} -# ep.connect(unreachable) -# if ep.is_connected(): -# raise RuntimeError("should not be connected") -# ep.shutdown() +def test_connect_unreachable(): + ep = TcpEndpoint(port=10015) + unreachable = {"host": "127.0.0.1", "port": 65535, "mr_info": {}} + ep.connect(unreachable) + if ep.is_connected(): + raise RuntimeError("should not be connected") + ep.shutdown() # ── parameterized torch tests (device="cpu" or "cuda") ── -def _dev_skip(device): - if not _HAS_TORCH: - return True - if device == "cuda": - return _cuda_skip() - return False - -def _make_tensor(shape, device, **kw): +def _make_tensor(shape, device, dtype, **kw): """Create a tensor on the given device. CPU tensor for recv on cuda path uses ctypes buffer so data_ptr() gives host pointer (needed for cudaMemcpy).""" - return torch.zeros(shape, dtype=torch.float32, device=device, **kw) + return torch.randn(shape, dtype=dtype, + device=device if isinstance(device, torch.device) else torch.device(device), + **kw) -def test_torch_send_recv(device="cpu"): +def test_torch_send_recv(device="cpu", dtype=torch.float32): """Round-trip: A send full → B recv → B send slice → A recv.""" - if _dev_skip(device): - return - SZ, SL = 32, 5 # elements - t_a = _make_tensor(SZ, device).normal_() - t_b = _make_tensor(SZ, device) + t_a = _make_tensor(SZ, device, dtype) + t_b = _make_tensor(SZ, device, dtype) + expected = t_a.clone() n_bytes = SZ * 4 sl_bytes = SL * 4 - ep_a = TcpEndpoint(port=10021) - ep_b = TcpEndpoint(port=10022) + ep_a = TcpEndpoint(port=10101) + ep_b = TcpEndpoint(port=10102) info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() def run_a(): ep_a.connect(info_b) + time.sleep(5) st = ep_a.async_send((t_a.data_ptr(), 0, n_bytes)).wait() if st != 0: raise RuntimeError(f"send: {st}") st = ep_a.async_recv((t_a.data_ptr(), 10 * 4, sl_bytes)).wait() if st != 0: raise RuntimeError(f"recv: {st}") - if not torch.equal(t_a[10:15].cpu(), t_b[20:25].cpu()): + if not torch.equal(t_a[10:15], t_b[20:25]): raise RuntimeError("slice mismatch") ep_a.shutdown() @@ -424,134 +415,110 @@ def run_b(): st = ep_b.async_recv((t_b.data_ptr(), 0, n_bytes)).wait() if st != 0: raise RuntimeError(f"recv: {st}") - if not torch.equal(t_a.cpu(), t_b.cpu()): + if not torch.equal(expected, t_b): raise RuntimeError("full tensor mismatch") + time.sleep(5) st = ep_b.async_send((t_b.data_ptr(), 20 * 4, sl_bytes)).wait() if st != 0: raise RuntimeError(f"send: {st}") ep_b.shutdown() - _sync_run(f"test_torch_send_recv_{device}", run_a, run_b) + _sync_run(f"test_torch_send_recv_{device}", run_a, run_b, 120) -def test_torch_write(device="cpu"): +def test_torch_write(device="cpu", dtype=torch.float32): """One-sided write: A async_write → B verifies data received.""" - if _dev_skip(device): - return - SZ = 64 - t_a = _make_tensor(SZ, device).zero_() - if device == "cuda": - t_b = torch.zeros(SZ, dtype=torch.float32) # remote always CPU - b_ptr = t_b.data_ptr() - else: - t_b = _make_tensor(SZ, device) - b_ptr = t_b.data_ptr() + t_a = _make_tensor(SZ, device, dtype) + t_b = _make_tensor(SZ, device, dtype) + expected = t_a.clone() n_bytes = SZ * 4 - ep_a = TcpEndpoint(port=0) - ep_b = TcpEndpoint(port=0) - h_a = ep_a.register_memory_region("a", t_a.data_ptr(), n_bytes) - h_b = ep_b.register_memory_region("b", b_ptr, n_bytes) + ep_a = TcpEndpoint(port=10103) + ep_b = TcpEndpoint(port=10104) + h_a = ep_a.register_memory_region("a", t_a.data_ptr(), 0, n_bytes) + h_b = ep_b.register_memory_region("b", t_b.data_ptr(), 0, n_bytes) info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) def run_a(): ep_a.connect(info_b) - t_a[:6] = torch.arange(6, dtype=torch.float32, device=device) - st = ep_a.async_write([(h_a, h_br, 0, 0, 6 * 4)]).wait() + st = ep_a.async_write([(h_a, h_br, 0, 0, n_bytes)]).wait() if st != 0: raise RuntimeError(f"write: {st}") ep_a.shutdown() def run_b(): ep_b.connect(info_a) - expected = torch.arange(6, dtype=torch.float32) - for _ in range(50): - if torch.equal(t_b[:6], expected): + for _ in range(40): + if torch.equal(expected, t_b): break - time.sleep(0.01) - if not torch.equal(t_b[:6], expected): + time.sleep(0.5) + if not torch.equal(expected, t_b): raise RuntimeError("write data not received") ep_b.shutdown() _sync_run(f"test_torch_write_{device}", run_a, run_b) -def test_torch_read(device="cpu"): +def test_torch_read(device="cpu", dtype=torch.float32): """One-sided read: B buffer pre-filled, A async_read and verifies.""" - if _dev_skip(device): - return - + dsize = 4 SZ = 64 - t_a = _make_tensor(SZ, device).zero_() - if device == "cuda": - t_b = torch.zeros(SZ, dtype=torch.float32) - b_ptr = t_b.data_ptr() - else: - t_b = _make_tensor(SZ, device) - b_ptr = t_b.data_ptr() + t_a = _make_tensor(SZ, device, dtype) + t_b = _make_tensor(SZ, device, dtype) + expected = t_b.clone() - n_bytes = SZ * 4 + n_bytes = SZ * dsize - ep_a = TcpEndpoint(port=0) - ep_b = TcpEndpoint(port=0) - h_a = ep_a.register_memory_region("a", t_a.data_ptr(), n_bytes) - h_b = ep_b.register_memory_region("b", b_ptr, n_bytes) + ep_a = TcpEndpoint(port=10105) + ep_b = TcpEndpoint(port=10106) + h_a = ep_a.register_memory_region("a", t_a.data_ptr(), 0, n_bytes) + h_b = ep_b.register_memory_region("b", t_b.data_ptr(), 0, n_bytes) info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) def run_a(): ep_a.connect(info_b) - st = ep_a.async_read([(h_a, h_br, 0, 0, 6 * 4)]).wait() + st = ep_a.async_read([(h_a, h_br, 0, 0, n_bytes)]).wait() if st != 0: raise RuntimeError(f"read: {st}") - expected = torch.arange(6, dtype=torch.float32) - if not torch.equal(t_a[:6].cpu(), expected): + if not torch.equal(t_a, expected): raise RuntimeError("read data mismatch") ep_a.shutdown() def run_b(): ep_b.connect(info_a) - t_b[:6] = torch.arange(6, dtype=torch.float32) - time.sleep(0.2) + time.sleep(20) ep_b.shutdown() _sync_run(f"test_torch_read_{device}", run_a, run_b) -def test_torch_write_batch(device="cpu", n_batch=4): +def test_torch_write_batch(device="cpu", dtype=torch.float32, n_batch=4): """One async_write with multiple assignments.""" - if _dev_skip(device): - return - + dsize = 4 SZ = 64 - t_a = _make_tensor(SZ, device).zero_() - if device == "cuda": - t_b = torch.zeros(SZ, dtype=torch.float32) - b_ptr = t_b.data_ptr() - else: - t_b = _make_tensor(SZ, device) - b_ptr = t_b.data_ptr() + t_a_batch = [_make_tensor(SZ, device, dtype) for i in range(n_batch)] + t_b_batch = [_make_tensor(SZ, device, dtype) for i in range(n_batch)] + expected_batch = [i.clone() for i in t_a_batch] - n_bytes = SZ * 4 + n_bytes = SZ * dsize - ep_a = TcpEndpoint(port=0) - ep_b = TcpEndpoint(port=0) - h_a = ep_a.register_memory_region("a", t_a.data_ptr(), n_bytes) - h_b = ep_b.register_memory_region("b", b_ptr, n_bytes) + ep_a = TcpEndpoint(port=10107) + ep_b = TcpEndpoint(port=10108) + h_a_batch = [ep_a.register_memory_region(f"a_{i}", t_a_batch[i].data_ptr(), 0, n_bytes) for i in range(n_batch)] + h_b_batch = [ep_b.register_memory_region(f"b_{i}", t_b_batch[i].data_ptr(), 0, n_bytes) for i in range(n_batch)] info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() - h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + h_br_batch = [ep_a.register_remote_memory_region(f"rb_{i}", info_b["mr_info"][f"b_{i}"]) for i in range(n_batch)] def run_a(): ep_a.connect(info_b) - for i in range(n_batch): - t_a[i] = float(i + 1) - assigns = [(h_a, h_br, i * 4, i * 4, 4) for i in range(n_batch)] + assigns = [(h_a_batch[i], h_br_batch[i], i * dsize, i * dsize, dsize) for i in range(n_batch)] st = ep_a.async_write(assigns).wait() if st != 0: raise RuntimeError(f"write batch: {st}") @@ -559,55 +526,49 @@ def run_a(): def run_b(): ep_b.connect(info_a) - time.sleep(0.2) + time.sleep(3) for i in range(n_batch): - if t_b[i].item() != float(i + 1): - raise RuntimeError(f"batch {i}: expected {i+1}, got {t_b[i]}") + time.sleep(2) + if not torch.equal(t_b_batch[i][i], expected_batch[i][i]): + raise RuntimeError(f"batch {i}: mismatch") ep_b.shutdown() _sync_run(f"test_torch_write_batch_{device}", run_a, run_b) -def test_torch_read_batch(device="cpu", n_batch=4): +def test_torch_read_batch(device="cpu", dtype=torch.float32, n_batch=4): """One async_read with multiple assignments.""" - if _dev_skip(device): - return - + dsize = 4 SZ = 64 - t_a = _make_tensor(SZ, device).zero_() - if device == "cuda": - t_b = torch.zeros(SZ, dtype=torch.float32) - b_ptr = t_b.data_ptr() - else: - t_b = _make_tensor(SZ, device) - b_ptr = t_b.data_ptr() + t_a_batch = [_make_tensor(SZ, device, dtype) for i in range(n_batch)] + t_b_batch = [_make_tensor(SZ, device, dtype) for i in range(n_batch)] + expected_batch = [i.clone() for i in t_b_batch] - n_bytes = SZ * 4 + n_bytes = SZ * dsize - ep_a = TcpEndpoint(port=0) - ep_b = TcpEndpoint(port=0) - h_a = ep_a.register_memory_region("a", t_a.data_ptr(), n_bytes) - h_b = ep_b.register_memory_region("b", b_ptr, n_bytes) + ep_a = TcpEndpoint(port=10109) + ep_b = TcpEndpoint(port=10110) + h_a_batch = [ep_a.register_memory_region(f"a_{i}", t_a_batch[i].data_ptr(), 0, n_bytes) for i in range(n_batch)] + h_b_batch = [ep_b.register_memory_region(f"b_{i}", t_b_batch[i].data_ptr(), 0, n_bytes) for i in range(n_batch)] info_a = ep_a.endpoint_info() info_b = ep_b.endpoint_info() - h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + h_br_batch = [ep_a.register_remote_memory_region(f"rb_{i}", info_b["mr_info"][f"b_{i}"]) for i in range(n_batch)] def run_a(): ep_a.connect(info_b) - assigns = [(h_a, h_br, i * 4, i * 4 + 16 * 4, 4) for i in range(n_batch)] + assigns = [(h_a_batch[i], h_br_batch[i], i * dsize, i * dsize, dsize) for i in range(n_batch)] st = ep_a.async_read(assigns).wait() if st != 0: raise RuntimeError(f"read batch: {st}") - expected = torch.tensor([1., 2., 3., 4.], dtype=torch.float32) - if not torch.equal(t_a[16:20].cpu(), expected): - raise RuntimeError("read back mismatch") ep_a.shutdown() def run_b(): ep_b.connect(info_a) + time.sleep(3) for i in range(n_batch): - t_b[i] = float(i + 1) - time.sleep(0.2) + time.sleep(2) + if not torch.equal(t_a_batch[i][i], expected_batch[i][i]): + raise RuntimeError(f"batch {i}: mismatch") ep_b.shutdown() _sync_run(f"test_torch_read_batch_{device}", run_a, run_b) @@ -617,18 +578,19 @@ def run_b(): if __name__ == "__main__": test_async_send_recv() - test_async_send_recv_one() + test_async_send2recv() test_async_write() test_async_read() - test_recv_timeout() - test_send_timeout_ms() - test_default_timeout() - test_exact_size_mismatch() - test_overflow_truncate() - - for dev in ("cpu", "cuda"): - test_torch_send_recv(dev) - test_torch_write(dev) - test_torch_read(dev) - # test_torch_write_batch(dev) - # test_torch_read_batch(dev) + + if not _torch_skip(): + device_list = ["cpu", "cuda"] + if _cuda_skip(): + print("No Cuda, Skip", flush = True) + device_list = ["cpu", ] + + for dev in device_list: + test_torch_send_recv(dev) + test_torch_write(dev) + test_torch_read(dev) + test_torch_write_batch(dev) + test_torch_read_batch(dev) diff --git a/dlslime/csrc/python/bind.cpp b/dlslime/csrc/python/bind.cpp index 1a3ed1d1..93eb641d 100644 --- a/dlslime/csrc/python/bind.cpp +++ b/dlslime/csrc/python/bind.cpp @@ -602,7 +602,7 @@ PYBIND11_MODULE(_slime_c, m) .def("is_connected", &dlslime::tcp::TcpEndpoint::is_connected) .def("register_memory_region", &dlslime::tcp::TcpEndpoint::register_memory_region, - py::arg("name"), py::arg("data_ptr"), py::arg("length"), + py::arg("name"), py::arg("data_ptr"), py::arg("offset"), py::arg("length"), py::call_guard()) .def("register_remote_memory_region", &dlslime::tcp::TcpEndpoint::register_remote_memory_region,