diff --git a/CMakeLists.txt b/CMakeLists.txt index b451b4e8..3f2c5475 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,7 @@ slime_option(USE_MACA "USE in MACA Platform" OFF) slime_option(BUILD_NVLINK "Build NVLINK" OFF) slime_option(BUILD_ASCEND_DIRECT "Build Ascend direct transport" OFF) +slime_option(BUILD_TCP "Build TCP transport" ON) # Slime options for custom python wrapper slime_option(BUILD_PYTHON "Build python wrapper" OFF) diff --git a/dlslime/csrc/CMakeLists.txt b/dlslime/csrc/CMakeLists.txt index 0947f3c1..045f74ad 100644 --- a/dlslime/csrc/CMakeLists.txt +++ b/dlslime/csrc/CMakeLists.txt @@ -27,6 +27,10 @@ if(BUILD_RDMA) target_link_libraries(dlslime INTERFACE _slime_rdma) endif() +if(BUILD_TCP) + target_link_libraries(dlslime INTERFACE _slime_tcp) +endif() + # rpc/ has no independent C++ consumers (the session API is all # pybind11). It is compiled straight into _slime_c.so by the python # subdirectory below. Keep the source files here as a marker so future diff --git a/dlslime/csrc/device/cuda/cuda_signal.h b/dlslime/csrc/device/cuda/cuda_signal.h index 0c690ab9..3b66bde9 100755 --- a/dlslime/csrc/device/cuda/cuda_signal.h +++ b/dlslime/csrc/device/cuda/cuda_signal.h @@ -8,7 +8,7 @@ #include "dlslime/csrc/device/signal.h" #include "dlslime/csrc/engine/rdma/rdma_env.h" #include "dlslime/csrc/logging.h" -#include "dlslime/csrc/pause.h" +#include "dlslime/csrc/common/pause.h" #include "nvtx_helper.h" namespace dlslime { diff --git a/dlslime/csrc/engine/CMakeLists.txt b/dlslime/csrc/engine/CMakeLists.txt index c03c88c7..c9bffdf5 100755 --- a/dlslime/csrc/engine/CMakeLists.txt +++ b/dlslime/csrc/engine/CMakeLists.txt @@ -34,3 +34,7 @@ endif() if (BUILD_RDMA) add_subdirectory(rdma) endif() + +if (BUILD_TCP) + add_subdirectory(tcp) +endif() diff --git a/dlslime/csrc/engine/tcp/CMakeLists.txt b/dlslime/csrc/engine/tcp/CMakeLists.txt new file mode 100644 index 00000000..ab69db49 --- /dev/null +++ b/dlslime/csrc/engine/tcp/CMakeLists.txt @@ -0,0 
+1,47 @@ +# asio is header-only. Try find_package, fall back to manual detection. +find_package(asio QUIET) +if(NOT asio_FOUND) + if(EXISTS /usr/include/asio.hpp) + add_library(asio::asio INTERFACE IMPORTED) + target_include_directories(asio::asio INTERFACE /usr/include) + elseif(EXISTS /usr/include/boost/asio.hpp) + add_library(asio::asio INTERFACE IMPORTED) + target_include_directories(asio::asio INTERFACE /usr/include/boost) + target_compile_definitions(asio::asio INTERFACE ASIO_STANDALONE) + else() + message(FATAL_ERROR "asio not found. Install libasio-dev or boost.") + endif() +endif() + +add_library(_slime_tcp SHARED + tcp_memory_pool.cpp + tcp_connection_pool.cpp + tcp_context.cpp + tcp_session.cpp + tcp_endpoint.cpp +) + +target_compile_definitions(_slime_tcp PRIVATE ASIO_STANDALONE) + +if (USE_CUDA) + find_package(CUDAToolkit REQUIRED) + target_compile_definitions(_slime_tcp PRIVATE USE_CUDA) + target_include_directories(_slime_tcp PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + target_link_libraries(_slime_tcp PUBLIC CUDA::cudart) +endif() + +target_link_libraries(_slime_tcp PUBLIC + asio::asio + _slime_device + _slime_engine +) + +set_target_properties(_slime_tcp PROPERTIES + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH "$ORIGIN" +) + +install(TARGETS _slime_tcp + EXPORT dlslimeTargets + LIBRARY DESTINATION ${DLSLIME_INSTALL_PATH} +) diff --git a/dlslime/csrc/engine/tcp/build_and_test.sh b/dlslime/csrc/engine/tcp/build_and_test.sh new file mode 100755 index 00000000..d30e99c9 --- /dev/null +++ b/dlslime/csrc/engine/tcp/build_and_test.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../../.." 
&& pwd)" +BUILD_DIR="$REPO_ROOT/build_tcp" +MODE="${1:-all}" + +# Optional: USE_CUDA=ON ./build_and_test.sh all +USE_CUDA="${USE_CUDA:-OFF}" + +header() { echo; echo -e "\033[1;36m==>\033[m \033[1m$*\033[m"; } +ok() { echo -e " \033[1;32mOK\033[m $*"; } + +do_build() { + rm -rf "$BUILD_DIR" + local cuda_label="" + [[ "$USE_CUDA" == "ON" ]] && cuda_label=" + USE_CUDA=ON" + + header "Configuring (BUILD_TCP=ON, BUILD_RDMA=OFF${cuda_label})" + cmake -S "$REPO_ROOT" -B "$BUILD_DIR" -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DDLSLIME_INSTALL_PATH=dlslime \ + -DBUILD_PYTHON=ON \ + -DBUILD_RDMA=OFF \ + -DBUILD_TCP=ON \ + -DBUILD_NVLINK=OFF \ + -DBUILD_ASCEND_DIRECT=OFF \ + -DUSE_CUDA="$USE_CUDA" \ + -DSKBUILD_PROJECT_NAME=dlslime 2>&1 | tail -3 + ok "CMake configure" + + header "Building _slime_c" + cmake --build "$BUILD_DIR" --target _slime_c -j"$(nproc)" 2>&1 | tail -8 + ok "Build complete" + + cp "$BUILD_DIR/lib/"*.so "$REPO_ROOT/dlslime/" + ok "Copied .so files to dlslime/" +} + +do_test() { + header "Running TcpEndpoint tests" + export SLIME_LOG_LEVEL=1 + export LD_LIBRARY_PATH="$REPO_ROOT/dlslime${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" + export PYTHONPATH="$REPO_ROOT${PYTHONPATH:+:$PYTHONPATH}" + python3 "$SCRIPT_DIR/test_tcp_endpoint.py" 2>&1 | while IFS= read -r line; do + if [[ "$line" == *"PASSED"* ]]; then echo -e " \033[1;32m✓\033[m $line" + elif [[ "$line" == *"SKIP"* ]]; then echo -e " \033[1;33m⊘\033[m $line" + elif [[ "$line" == *"FAIL"* ]]; then echo -e " \033[1;91m✗\033[m $line" + else echo " $line" + fi + done + echo " tests done " +} + +case "$MODE" in + all) do_build; do_test ;; + build) do_build ;; + test) do_test ;; + clean) rm -rf "$BUILD_DIR" "$REPO_ROOT/dlslime/_slime_c"*.so "$REPO_ROOT/dlslime/lib_slime_"*.so + ok "Cleaned" ;; + *) echo "Usage: $0 {all|build|test|clean}" >&2 + echo " USE_CUDA=ON $0 all # build + test with CUDA" >&2 + exit 1 ;; +esac diff --git a/dlslime/csrc/engine/tcp/plan.md b/dlslime/csrc/engine/tcp/plan.md new file mode 100755 index 00000000..915a8240 --- /dev/null +++ 
b/dlslime/csrc/engine/tcp/plan.md @@ -0,0 +1,729 @@ +# DLSlime TcpEndpoint v3 Primitives 架构与实现计划 + +**分支**: `tcp-v3` | **基准**: `main` | **日期**: 2026-05-14 + +--- + +## 1. 架构设计 + +### 1.1 总体架构 + +``` +┌──────────────────────────────────────────────────────────────┐ +│ Python 调用者线程 │ +│ ep.async_send(chunk, timeout_ms=30000) → Future │ +│ ep.async_recv(chunk) → Future │ +│ ep.async_read(assign, timeout_ms=30000) → Future │ +│ ep.async_write(assign, timeout_ms=30000) → Future │ +│ │ │ +│ │ post lambda │ +│ ▼ │ +│ ┌──────────────────────┐ ┌─────────────────────────────┐ │ +│ │ asio::io_context │ │ TcpConnectionPool │ │ +│ │ (单后台线程) │◄───│ (host, port) → deque │ │ +│ │ │ │ IDLE / ACTIVE / RESERVED │ │ +│ │ async_write ────────┼───►│ 300s 空闲超时 │ │ +│ │ async_read ◄────────┼───►│ │ │ +│ │ async_accept ───────┼───►│ ServerSession │ │ +│ │ │ │ (readHeader→dispatch→ │ │ +│ │ │ │ readBody→readHeader 循环) │ │ +│ └──────────────────────┘ └─────────────────────────────┘ │ +└──────────────────────────────────────────────────────────────┘ +``` + +### 1.2 线程模型 + +| 角色 | 线程 | 职责 | +|------|------|------| +| io_context | 1 个 daemon 线程 | `io_ctx_.run()` — 所有 asio async I/O 回调 | +| 调用者 | N 个 Python 线程 | 调 async_* → 立即返回 Future;wait() 自旋阻塞 | +| accept | io_context | `async_accept` 回调链,每连接创建 ServerSession | + +### 1.3 asio 操作模型 + +``` +调用者线程 io_context 线程 +────────── ────────────── +async_send(chunk, 5000): + ├─ getConnection() [sync, fast] ┌─ async_write(header+payload) + ├─ SO_SNDTIMEO=5s │ → 归还连接 → signal op_state + ├─ asio::post(lambda) ──────────────► │ + └─ return Future ◄─── signal ────────┘ + +async_recv(chunk): + ├─ pending_recvs_.push(op_state) ┌─ ServerSession::dispatch(OP_SEND) + └─ return Future │ → pop pending_recvs_ + │ │ → memcpy → signal op_state + └── wait_for(5.0) ── timeout? 
──┘ + +async_read(assign, 5000): + ├─ getConnection() [RESERVE] ┌─ async_write(OP_READ header) + ├─ asio::post(lambda) ──────────────► │ → async_read(response data) + └─ return Future ◄─── signal ────────┘ → 归还连接 → signal op_state + +async_write(assign, 5000): + ├─ getConnection() [sync, fast] ┌─ async_write(header+payload) + ├─ SO_SNDTIMEO=5s │ → 归还连接 → signal op_state + ├─ asio::post(lambda) ──────────────► │ + └─ return Future ◄─── signal ────────┘ +``` + +--- + +## 2. 线协议设计 + +### 2.1 SessionHeader (17 字节,对齐 Mooncake) + +``` +偏移 大小 字段 +0 8 size (payload 字节数, little-endian: htole64 / le64toh) +8 8 addr (远端 buffer 虚拟地址) +16 1 opcode (操作码) +───────────────── + 17 bytes total +``` + +### 2.2 为什么 3 个 opcode 支持 4 个原语? + +OP_SEND 同时承载 `async_send`(发起方主动 push 数据)和 `async_recv`(接收方 +被动等待)。recv 方不在线上发送任何操作码——它只是向本地 `pending_recvs_` 队列注册 +一个 buffer,然后对端 ServerSession 在收到 OP_SEND 时通过 `RecvMatcher` 回调 pop +队列前端、memcpy 数据并 signal op_state。 + +这与 Mooncake 的设计一致:ServerSession::dispatch(OP_SEND) 先分块读取 payload, +然后通过 recv_matcher_ 匹配本地注册的 recv buffer。不需要独立的 recv opcode—— +SEND 到达本身就隐含了"有一端在等待"的语义。 + +OP_READ 和 OP_WRITE 各需独立 opcode,因为服务端 dispatch 分支逻辑完全不同: +- OP_READ:读取本地内存后异步写回原始数据(无 header) +- OP_WRITE:读取 payload 后 memcpy 到 hdr.addr + +如果有 4 个 opcode(比如独立的 OP_RECV),反而增加冗余——OP_RECV 在语义上等于 +"我准备好接收了",但这已在连接建立时通过 endpoint_info 交换 MR 信息隐式表达, +不需要每个操作发一次。 + +| opcode | 值 | 线格式 | 远端 ServerSession 动作 | DLSlime 原语 | +|--------|-----|--------|------------------------|-------------| +| `OP_SEND` | 0x00 | header{sz, 0, 0x00} + payload | 读 payload → recv_matcher pop → memcpy → signal | **async_send** (发起) / **async_recv** (被动) | +| `OP_READ` | 0x01 | 仅 header{sz, addr, 0x01} | 从本地 addr 读 sz 字节 → async_write 原始数据发回 | **async_read** (调用者 pull) | +| `OP_WRITE` | 0x02 | header{sz, addr, 0x02} + payload | 读 payload → memcpy 到本地 addr | **async_write** (调用者 push) | + +### 2.3 四个原语在线上的完整流程 + +``` +async_send(chunk): + 调用者: getConnection → post to io_ctx → return Future + io_ctx: async_write(sock, 
[header{OP_SEND}|payload]) + → on_complete: returnConnection → signal op_state + 对端 ServerSession: async_read(header) → dispatch(OP_SEND) + → chunk_buf_.resize → readBody 分块读 payload → recv_matcher_() + → pop pending_recv → memcpy → signal recv op_state + +async_recv(chunk): + 调用者: pending_recvs_.push({buffer, op_state}) → return Future → wait_for(timeout) + (无 opcode 在线路上 — recv 是 SEND 的被动消费方) + +async_read(assign): + 调用者: getConnection(RESERVED) → post to io_ctx → return Future + io_ctx: async_write(sock, header{OP_READ, sz, remote_addr}) + → async_read(sock, user_buffer, sz) + → on_complete: returnConnection → signal op_state + 对端 ServerSession: async_read(header) → dispatch(OP_READ) + → async_write(sock, local[addr], sz) → readHeader 继续 + +async_write(assign): + 调用者: getConnection → post to io_ctx → return Future + io_ctx: async_write(sock, [header{OP_WRITE, sz, remote_addr}|payload]) + → on_complete: returnConnection → signal op_state + 对端 ServerSession: async_read(header) → dispatch(OP_WRITE) + → chunk_buf_.resize → readBody 分块读 payload → memcpy 到 addr +``` + +--- + +## 3. 
接口设计 + +### 3.1 C++ TcpEndpoint 公共接口 + +```cpp +class TcpEndpoint : public std::enable_shared_from_this<TcpEndpoint> { +public: + // 默认超时 30 秒 + static constexpr int64_t kDefaultTimeoutMs = 30000; + + // ── 构造 ── + + // 【主构造】ip 绑定网卡地址 (默认 0.0.0.0), port=0 随机端口 + explicit TcpEndpoint(const std::string& ip = "0.0.0.0", uint16_t port = 0); + + // 【次构造】共享 TcpContext — 暂禁用 + // (涉及 context 所有权 / conn_pool 跨 endpoint 管理 / 析构顺序) + TcpEndpoint(TcpContext& ctx, uint16_t port = 0) = delete; + + // ── 连接 ── + json endpoint_info() const; // {host, port, mr_info} + void connect(const json& remote_info); + void shutdown(); + + // ── 内存 ── + int32_t register_memory_region(const std::string& name, + uintptr_t ptr, uintptr_t offset, size_t length); + int32_t register_remote_memory_region(const std::string& name, + const json& mr_info); + json mr_info() const; + + // ── 异步通信原语 (全部返回 Future, I/O 在 io_context 线程) ── + // + // timeout_ms 由调用者通过 future.wait_for() 控制实际操作时限; + // 方法签名的 timeout_ms 仅作为 op_state 的提示值传入。 + // recv 的超时完全由 future.wait_for() 控制, 不需要 timeout_ms 参数。 + + // 双边发送 + std::shared_ptr<TcpSendFuture> async_send( + const chunk_tuple_t& chunk, + int64_t timeout_ms = kDefaultTimeoutMs); + + // 双边接收 (超时通过 future.wait_for()) + std::shared_ptr<TcpRecvFuture> async_recv( + const chunk_tuple_t& chunk); + + // 单边读 + std::shared_ptr<TcpReadWriteFuture> async_read( + const std::vector<assign_tuple_t>& assign, + int64_t timeout_ms = kDefaultTimeoutMs); + + // 单边写 + std::shared_ptr<TcpReadWriteFuture> async_write( + const std::vector<assign_tuple_t>& assign, + int64_t timeout_ms = kDefaultTimeoutMs); + + // ── 访问器 ── + void setId(int64_t id); + int64_t getId() const; + bool is_connected() const; +}; +``` + +### 3.2 C++ TcpFuture 接口 + +```cpp +class TcpFuture : public DeviceFuture { +public: + // 无限期阻塞等待 + int32_t wait() const override; + + // 限时等待: timeout_ms 毫秒, 成功返回 true 并写 *out + // 超时返回 false (操作仍在进行, 可重试) + bool wait_for(int64_t timeout_ms, int32_t* out) const; +}; + +class TcpSendFuture : public TcpFuture { }; +class TcpRecvFuture : public TcpFuture { }; +class TcpReadWriteFuture : public 
TcpFuture { }; +``` + +### 3.3 Python 接口 + +```python +from dlslime import TcpEndpoint, TcpMemoryPool + +pool = TcpMemoryPool() +buf = ctypes.create_string_buffer(4096) +h = pool.register_memory_region(ctypes.addressof(buf), 0, 4096, "buf") + +ep = TcpEndpoint(port=0) # 0 = 随机端口 +info = ep.endpoint_info() # {'host': '...', 'port': N, 'mr_info': {...}} + +ep.connect(peer_info) + +# ── 异步原语, 默认 30s 超时 ── +fut = ep.async_send((h, 0, 128)) # 30s 默认超时 +fut = ep.async_send((h, 0, 128), 5000) # 5s 超时 +status = fut.wait() # 阻塞直到完成, 返回 0=成功 + +fut = ep.async_recv((h, 0, 128)) # 超时通过 future 控制 +result = fut.wait_for(3.0) # 3 秒超时, 返回 int 或 None + +fut = ep.async_read([(local_h, remote_h, 0, 0, 128)]) +fut = ep.async_write([(local_h, remote_h, 0, 0, 128)]) + +ep.shutdown() +``` + +--- + +## 4. 通信原语设计详解 + +### 4.1 async_send(chunk, timeout_ms = 30000) + +**语义**: 将本地注册内存的数据异步发送到对端。对端必须已调用 `async_recv()` 注册接收缓冲区。 + +**调用者线程**: +1. `local_pool_->get_mr_fast(mr_key)` — resolve 本地 MR +2. `conn_pool_.getConnection(peer_host_, peer_port_)` — 获取或创建 TCP 连接 +3. `TcpOpState::create()` + `signal->reset_all()` — 创建完成信号 +4. 如果 `timeout_ms > 0`: `setsockopt(fd, SO_SNDTIMEO, timeout_ms)` +5. `asio::post(io_ctx_, lambda)` — 提交到 io_context +6. 立即返回 `TcpSendFuture(op_state)` + +**io_context 线程**: +1. `hdr_hton()` — 字节序转换 header +2. `asio::async_write(sock, [header_buf, payload_buf], callback)` — gather write +3. callback: + - 如果 `timeout_ms > 0`: 恢复 `SO_SNDTIMEO = 0` + - `op->completion_status = ec ? TCP_FAILED : TCP_SUCCESS` + - `conn_pool_.returnConnection(conn)` + - `op->signal->set_comm_done(0)` + +**超时行为**: socket 写超时 → write 失败 → completion_status = TCP_FAILED。调用者 `future.wait()` 得到 -1。 + +### 4.2 async_recv(chunk) + +**语义**: 注册接收意图。当对端 `async_send()` 的数据到达时,io_context 线程自动匹配并 memcpy 到注册的 buffer。 + +**调用者线程**: +1. `local_pool_->get_mr_fast(mr_key)` — resolve 本地 MR +2. `TcpOpState::create()` + 设置 `user_buffer`, `user_length` +3. 
`pending_recvs_.push_back({op_state})` — FIFO 入队 +4. 立即返回 `TcpRecvFuture(op_state)` + +**io_context 线程** (ServerSession::dispatch, OP_SEND 分支): +1. `readBody()` — 分块读取 payload 到 `chunk_buf_` +2. `RecvSlot slot = recv_matcher_()` — pop FIFO 前端 +3. `memcpy(slot.buffer, chunk_buf_.data(), min(payload_len, slot.length))` +4. `slot.op_state->completion_status = TCP_SUCCESS` +5. `slot.op_state->signal->set_comm_done(0)` + +**超时行为**: 调用者使用 `future.wait_for(timeout_ms)` 限时等待。超时返回 None,但 recv 保留在队列中——后续到达的 SEND 仍会完成它(调用者可重试)。 + +### 4.3 async_read(assign, timeout_ms = 30000) + +**语义**: 从对端的注册内存异步读取数据。两步异步操作:发 OP_READ header → 收原始响应数据。 + +**调用者线程**: +1. resolve local + remote MRs +2. `conn_pool_.getConnection(peer_host_, peer_port_)` — RESERVE 连接 +3. `TcpOpState::create()` + 设置 `user_buffer`, `user_length` +4. `asio::post(io_ctx_, lambda)` — 提交到 io_context +5. 立即返回 `TcpReadWriteFuture(op_state)` + +**io_context 线程**: +1. `hdr_hton()` → `asio::async_write(sock, header_buf, callback_1)` +2. callback_1: 如果写失败 → signal TCP_FAILED + returnConnection +3. `asio::async_read(sock, user_buffer_buf, callback_2)` +4. callback_2: + - `op->completion_status = ec ? TCP_FAILED : TCP_SUCCESS` + - `conn_pool_.returnConnection(conn)` + - `op->signal->set_comm_done(0)` + +**对端 ServerSession** (OP_READ 分支): +1. 从 `hdr.addr` 读取 `hdr.size` 字节本地内存 +2. `asio::async_write(sock, raw_data, callback)` — 直接写回原始数据(无 header) +3. `readHeader()` — 继续监听下个请求 + +**超时行为**: `future.wait_for(timeout_ms)`。连接在整个读取期间被 RESERVED,超时后操作继续在后台运行。 + +### 4.4 async_write(assign, timeout_ms = 30000) + +**语义**: 将本地注册内存的数据异步写入对端注册内存。 + +与 `async_send` 相同的 post+async_write 模式,区别: +- header.opcode = OP_WRITE +- header.addr = remote_addr(对端目标 buffer 地址) +- 对端 ServerSession dispatch(OP_WRITE) → readBody → memcpy 到 `hdr.addr` + +**超时行为**: 同 async_send — SO_SNDTIMEO + future.wait_for()。 + +--- + +## 5. 
连接池设计 + +### 5.1 状态机 + +``` + getConnection() + [不存在] ────────────────────────► [ACTIVE] (in_use=true) + │ + returnConnection() + │ + ▼ + [IDLE] (in_use=false, 在 deque 中) ──► 300s 无使用 → cleanupIdleConnections() → 关闭 + │ + │ getConnection() 命中 + ▼ + [ACTIVE] (in_use=true, 离开 deque) +``` + +### 5.2 接口 + +```cpp +class TcpConnectionPool { + // 获取 IDLE 连接或创建新 TCP 连接 + std::shared_ptr getConnection(host, port); + + // 归还连接到 IDLE 状态 (或关闭, 如果 socket 已断开) + void returnConnection(std::shared_ptr conn); + + // 淘汰超过 kIdleTimeout (300s) 的空闲连接 + void cleanupIdleConnections(); + + // 关闭所有连接 (shutdown 时调用) + void clear(); +}; +``` + +--- + +## 6. ServerSession 设计 + +### 6.1 生命周期 + +``` +acceptor.async_accept(socket) + → ServerSession(socket, local_pool, recv_matcher) + → session->start() + → readHeader() ──────────────────────────────────────┐ + → async_read(sock, 17B header) │ + → hdr_to_host() │ + → dispatch() │ + ├─ OP_SEND: chunk_buf_.resize → readBody() │ + │ → memcpy → recv_matcher_() → signal │ + ├─ OP_WRITE: chunk_buf_.resize → readBody() │ + │ → memcpy → hdr.addr │ + └─ OP_READ: async_write(sock, local[addr]) │ + → readHeader() ──────────────────────────────────┘ +``` + +### 6.2 RecvMatcher + +```cpp +// ServerSession 持有的回调, 由 TcpEndpoint 注入 +using RecvMatcher = std::function; + +// TcpEndpoint::make_recv_matcher(): +// 返回一个 lambda, 持有 weak_ptr +// 在 recv_mu_ 下 pop pending_recvs_ 队列前端 +// 返回 {buffer, length, op_state} +``` + +--- + +## 7. 
文件结构 + +### 新建文件 + +``` +dlslime/csrc/engine/tcp/ +├── CMakeLists.txt # asio 依赖 + _slime_tcp 共享库 +├── tcp_header.h # 17B SessionHeader + 3 opcodes +├── tcp_memory_pool.h/.cpp # 纯簿记 (addr, offset, length) +├── tcp_context.h/.cpp # 共享 io_context + connection_pool + thread +├── tcp_session.h/.cpp # ServerSession (accept 端) + 分块 I/O +├── tcp_connection_pool.h/.cpp # (host, port) 连接池 +├── tcp_op_state.h # 操作状态 (signal + atomic status) +├── tcp_future.h # TcpFuture 层次 (header-only) +├── tcp_endpoint.h/.cpp # TcpEndpoint: async_send/recv/read/write +├── build_and_test.sh # 一键构建+测试 +└── test_tcp_endpoint.py # Python 端到端测试 (5 用例) +``` + +### 修改文件 + +| 文件 | 变更 | +|------|------| +| `CMakeLists.txt` | `slime_option(BUILD_TCP "Build TCP transport" ON)` | +| `dlslime/csrc/engine/CMakeLists.txt` | `if(BUILD_TCP) add_subdirectory(tcp) endif()` | +| `dlslime/csrc/CMakeLists.txt` | `if(BUILD_TCP) target_link_libraries(dlslime INTERFACE _slime_tcp) endif()` | +| `dlslime/csrc/python/CMakeLists.txt` | `if(BUILD_TCP) target_compile_definitions + list(APPEND ... _slime_tcp) endif()` | +| `dlslime/csrc/python/bind.cpp` | `#ifdef BUILD_TCP` — TcpEndpoint, TcpMemoryPool, TcpFuture pybind11 bindings | + +--- + +## 8. 
超时机制总结 + +| 原语 | 超时位置 | 默认值 | 实现方式 | +|------|---------|--------|---------| +| async_send | socket write | 30000ms | `setsockopt(SO_SNDTIMEO)` + `future.wait_for()` | +| async_recv | 等待数据到达 | 30000ms | `future.wait_for(timeout_ms)` — 定时自旋轮询 signal | +| async_read | 等待远端响应 | 30000ms | `future.wait_for(timeout_ms)` — 定时自旋轮询 signal | +| async_write | socket write | 30000ms | `setsockopt(SO_SNDTIMEO)` + `future.wait_for()` | + +**wait_for 实现**: +```cpp +bool TcpFuture::wait_for(int64_t timeout_ms, int32_t* out) const { + auto deadline = steady_clock::now() + milliseconds(timeout_ms); + while (true) { + if (signal->get_comm_done_mask() matches expected_mask) { + *out = completion_status; return true; + } + if (steady_clock::now() >= deadline) { + // last check before declaring timeout + if (signal->get_comm_done_mask() matches expected_mask) { + *out = completion_status; return true; + } + return false; + } + machnet_pause(); // CPU relax + } +} +``` + +--- + +## 11. 实现步骤 + +| 阶段 | 文件 | 说明 | +|------|------|------| +| 1. 分支 | `git checkout -b tcp-v4` | 基于 v3 创建了新分支 | +| 2. 头文件 | tcp_header.h, tcp_op_state.h | 17B header + 3 opcodes + op state | +| 3. 内存池 | tcp_memory_pool.h/.cpp | 纯簿记, 无硬件注册 | +| 4. Future | tcp_future.h | header-only, wait + wait_for | +| 5. Context | tcp_context.h/.cpp | 共享 io_context + connection_pool + thread | +| 6. 连接池 | tcp_connection_pool.h/.cpp | get/return/cleanup/clear | +| 7. Session | tcp_session.h/.cpp | ServerSession async_read 回调链 | +| 8. 端点 | tcp_endpoint.h/.cpp | async_send/recv/read/write | +| 9. 构建 | CMakeLists 链 + bind.cpp | BUILD_TCP + pybind11 | +| 10. 测试 | test_tcp_endpoint.py | 5 用例 + timeout 测试 | +| 11. 脚本 | build_and_test.sh | 一键构建+测试 | +| 12. 提交 | git commit | 单 commit, 清晰消息 | + +--- + +## 9. 
send/recv 设计深度分析 + +### 核心矛盾:RDMA vs TCP 的 send/recv 语义差异 + +RDMA 的 send/recv 是**硬件匹配**的: +- 发送方 post Send WR → 硬件从本地 buffer 取数据 → 发到对端 RQ +- 接收方 post Recv WR → 硬件在 RQ 上预置 WQE (buffer地址 + 长度) +- 硬件按**FIFO 顺序**匹配:第 N 个到达的 SEND 消费第 N 个预置的 RECV +- 如果 SEND 到达时没有 RECV → RNR NAK (Receiver Not Ready) → 发送方重试 +- 如果 SEND 数据量 > RECV buffer → 截断或报错 + +TCP **没有硬件匹配**,所有匹配逻辑必须在软件中实现。这带来了三个核心问题: + +| 问题 | RDMA 方案 | TCP 需要解决 | +|------|---------|------------| +| 匹配: 哪个 SEND 对哪个 RECV? | 硬件 RQ FIFO | 软件队列或 tag 匹配 | +| 顺序: SEND 先到还是 RECV 先到? | 硬件 RNR 重试 | 缓冲或拒绝 | +| 大小: 发送量 > 接收 buffer? | 截断/报错 | 截断或分片 | + +### 三种匹配策略 + +#### 策略 A: FIFO 队列匹配(v3 plan 默认) + +``` +recv(chunk) → pending_recvs_.push_back({buffer, op_state}) +ServerSession dispatch(OP_SEND): + payload = readBody() + slot = recv_matcher_() // pop front + memcpy(slot.buffer, payload, min(len, slot.length)) + signal slot.op_state +``` + +**优点**: 实现简单,与 RDMA 语义一致,足够支持双端点 ping-pong 通信。 +**缺点**: 严格 FIFO——调用者无法指定"这个 recv 对应后面第 N 个 send"。多 slot 场景(如 SlimeRPC 的 slotted mailbox)无法用 FIFO 区分。 + +#### 策略 B: Tag 匹配(Gloo 风格) + +``` +wire: [header{OP_SEND, sz, tag}] + payload +recv(tag, buffer) → pending_recvs_[tag].push({buffer, op_state}) +ServerSession dispatch(OP_SEND): + payload = readBody() + slot = pending_recvs_[hdr.addr_as_tag].pop() + memcpy(slot.buffer, payload) +``` + +**优点**: 灵活,支持多路复用——一个 TCP 连接可以承载多个逻辑流(如 RPC slot)。 +**缺点**: header.addr 字段被复用为 tag(牺牲了 addr 的原始语义),协议复杂度增加。 + +#### 策略 C: Slot 预注册(Gloo Buffer 风格) + +``` +每个 Pair 预先创建 N 个 slot buffer: + pair.createSendBuffer(slot=0, ptr, size) + pair.createRecvBuffer(slot=1, ptr, size) +wire: [header{OP_SEND, sz, 0, slot}] + payload +ServerSession: 直接 lookup slot → memcpy +``` + +**优点**: 零队列开销,O(1) slot 查找,SlimeRPC 天然适配。 +**缺点**: 需要预注册 slot(与当前 DLSlime MR 模型不兼容),灵活度低。 + +### 推荐策略:分层渐进 + +``` +Phase 1 (v3) — FIFO 基础: + pending_recvs_ = deque<{buffer, op_state}> + wire: header{OP_SEND, sz, addr=0} + 匹配: 严格 FIFO + 足够: 双端点 ping-pong、简单 RPC + +Phase 2 — 缓冲早到 SEND: + early_sends_ = 
deque<{payload_data}> + 如果 dispatch(OP_SEND) 时 pending_recvs_ 为空: + → 缓存 payload 到 early_sends_(带大小上限) + → 下次 recv() 先检查 early_sends_ 再入队 + 避免数据丢失 + +Phase 3 — Tag 匹配 (如需要): + 扩展 header: 用 2 字节 reserved 字段承载 tag + pending_recvs_ = map<tag, deque<{buffer, op_state}>> + 支持多路复用 +``` + +### send/recv 与 read/write 的本质区别 + +很多人混淆 send/recv 和 write/read: + +| | send/recv | write/read | +|---|---|---| +| 语义 | **双边**:双方都需要显式操作 | **单边**:一方发起,另一方无感知 | +| 数据方向 | send=push, recv=pull (被动) | write=push to remote addr, read=pull from remote addr | +| 远端参与 | recv 方必须预先注册 buffer | 远端 ServerSession 自动处理,无需注册 | +| 寻址方式 | **无地址**(匹配决定目标 buffer) | **有地址**(header.addr 指定远端 buffer) | +| RDMA 类比 | ibv_post_send / ibv_post_recv | ibv_post_send with RDMA_WRITE/RDMA_READ | + +核心洞察:**send/recv 的"地址"是隐式的——通过匹配关系决定; +write/read 的"地址"是显式的——header.addr 直接指向远端内存。** + +这就是为什么 v3 plan 中: +- OP_SEND: header.addr = 0(不使用),通过 FIFO 匹配目标 buffer +- OP_WRITE: header.addr = remote_addr(直接指定远端目标地址) +- OP_READ: header.addr = remote_addr(直接指定远端源地址) + +### v3 实现策略 + +v3 采用策略 A(FIFO),但为策略 B(tag)预留空间: + +```cpp +// 当前: deque<RecvSlot> — 简单 FIFO +std::deque<RecvSlot> pending_recvs_; + +// Phase 3 可演进为: map<tag, deque<RecvSlot>> — tag 匹配 +// std::unordered_map<uint16_t, std::deque<RecvSlot>> pending_recvs_; +// 同时扩展 header: 用 reserved 字段承载 tag + +void TcpEndpoint::async_recv(const chunk_tuple_t& chunk, + int64_t timeout_ms, void*) { + // resolve MR → op_state → push to FIFO + // Phase 3: push to pending_recvs_[tag] instead +} + +ServerSession::dispatch(OP_SEND): + readBody() → chunk_buf_ + RecvSlot slot = recv_matcher_() + if (slot.buffer == 0): + // Phase 2: buffer early send to early_sends_ + return + memcpy(slot.buffer, chunk_buf_, min(payload_len, slot.length)) + signal slot.op_state +``` + +**recv timeout 语义**(区别于 socket timeout): +- SO_RCVTIMEO 是 socket 级超时(读数据超时) +- `future.wait_for()` 是**注册后等待匹配**的超时 +- 超时后 recv 保留在队列中:后续 SEND 仍可完成它(调用者可重试 wait_for) + +## 12. 
验证计划 + +```bash +# 构建 +./dlslime/csrc/engine/tcp/build_and_test.sh build + +# 测试 +./dlslime/csrc/engine/tcp/build_and_test.sh test + +# 全流程 +./dlslime/csrc/engine/tcp/build_and_test.sh +``` + +**测试用例**: +1. `test_async_send_recv` — A async_send → B async_recv, B async_send → A async_recv +2. `test_async_write_read` — A async_write → B buffer, A async_read → verify +3. `test_recv_timeout` — async_recv + wait_for(0.3s) → None (无对端发送) +4. `test_send_timeout` — async_send(timeout_ms=10000) 参数 +5. `test_default_timeout` — async_send() 无参数 → 使用 30000ms 默认值 + +## 10. TcpContext 设计 — 为同步通信和资源共享做准备 + +### 使用优先级 + +TcpContext 类始终存在,ctx_ 成员始终非空。但构造方式有两种优先级: + +| 优先级 | 构造 | 场景 | 占比 | +|--------|------|------|------| +| **主** | `TcpEndpoint(port)` | 单 endpoint, 内部自动 new TcpContext | ~90% | +| **次** | `TcpEndpoint(ctx, port)` | 多 endpoint 共享 io_context 线程 | ~10% | + +**默认路径**:调用者无需感知 TcpContext——每个 endpoint 构造时内部 `make_unique()`, +自动创建 io_context + 后台线程 + 连接池。代码最简洁。 + +**高级路径**:当 PeerAgent 连接 N 个 peer 时,可手动创建一个 TcpContext 并注入到 N 个 +TcpEndpoint,将 N 个线程合并为 1 个。TcpContext 也用于测试中精确控制 io_context 生命周期。 + +两种路径不互斥——同一进程可混合使用。TcpContext 类永不删除,ctx_ 成员永不删除。 + +### TcpContext 接口 + +```cpp +class TcpContext { +public: + TcpContext(); // 创建 io_context + 启动后台线程 + ~TcpContext(); // stop + join + clear pool + + asio::io_context& io_context() { return io_ctx_; } + TcpConnectionPool& conn_pool() { return conn_pool_; } + void shutdown(); + +private: + asio::io_context io_ctx_; + std::thread io_thread_; + TcpConnectionPool conn_pool_{io_ctx_}; + bool running_{true}; +}; +``` + +### TcpEndpoint 与 TcpContext 的关系 + +```cpp +class TcpEndpoint { + // 【主构造】自包含 — 内部创建 TcpContext + explicit TcpEndpoint(uint16_t port = 0) + : own_ctx_(std::make_unique()) // 自动创建 + , acceptor_(own_ctx_->io_context()) + , ... { + ctx_ = own_ctx_.get(); // ctx_ → 内部 context + } + + // 【次构造】共享 — 注入外部 TcpContext + TcpEndpoint(TcpContext& ctx, uint16_t port = 0) + : acceptor_(ctx.io_context()) + , ... 
{ + ctx_ = &ctx; // ctx_ → 外部 context, own_ctx_ = nullptr + } + +private: + TcpContext* ctx_{nullptr}; // 始终非空 + std::unique_ptr own_ctx_; // 仅主构造时非空 + // ... +}; +``` + +### 为同步通信预留 + +有了共享 TcpContext,同步包装器可以不依赖单个 endpoint 的事件循环: + +```cpp +// 未来 sync_send: 调 async_send + 立刻 future.wait() +std::shared_ptr sync_send(TcpEndpoint& ep, + const chunk_tuple_t& chunk, + int64_t timeout_ms = 30000) { + auto fut = ep.async_send(chunk, timeout_ms); + fut->wait(); // 阻塞调用者线程直到 io_context 完成 + return fut; +} +``` + +同步版本只是 async + wait() 的语法糖,不需要独立的底层实现。 \ No newline at end of file diff --git a/dlslime/csrc/engine/tcp/plan_v4.md b/dlslime/csrc/engine/tcp/plan_v4.md new file mode 100644 index 00000000..10f9d25a --- /dev/null +++ b/dlslime/csrc/engine/tcp/plan_v4.md @@ -0,0 +1,315 @@ +# TcpEndpoint v4 — Future / OpState / Session / Primitive 关系重构 + +**状态**: 已实现并测试通过 (2026-05-18) + +## 已实现功能 + +- 4 个 async 原语基于 ClientSession + Future + OpState 模型 +- ClientSession 与 ServerSession 对称:start_write/start_read vs readBody/writeBody +- 多 assign 支持:迭代 vector,每个 assign 创建独立 ClientSession,共享 OpState +- CUDA 两端 staging:async_send/write/read + ServerSession readBody/writeBody +- send/recv 脱离 MemoryPool(裸指针模式),read/write 继续使用 MR 寻址 +- `register_memory_region(name, ptr, offset, length)` 接口对齐 RDMAEndpoint +- 编译开关:`USE_CUDA=ON ./build_and_test.sh all` 启用 CUDA 路径 +- 宽松截断 + exact_size 拒绝 + overflow 保护 + +## 当前状态(已过时,仅供参考) + +``` +async_send(chunk): + 取连接 → TcpOpState → asio::post(lambda) → return Future + lambda: async_write(header+payload) → signal op → return conn + +async_read(assign): + 取连接(RESERVE) → TcpOpState → asio::post(lambda) → return Future + lambda: async_write(header) → async_read(response) → signal op → return conn +``` + +问题: +1. I/O 生命周期散落在 lambda 捕获中,无显式状态机 +2. async_read 的 write_header → read_response 是两个回调嵌套 +3. ServerSession 有清晰的 `readHeader → dispatch → readBody/writeBody`, + 但客户端没有对应的 ClientSession +4. 
`assign_tuple_t` (local_mr, remote_mr, remote_off, local_off, length) 的解析 + 散落在 endpoint 方法中,与 I/O 执行耦合 + +## 接口对齐 + +与 RDMAEndpoint 保持一致(已去除 void* stream, writeWithImm/immRecv): + +```cpp +// TwoSide (对应 RDMA send/recv, 异步化) +std::shared_ptr async_send(const chunk_tuple_t& chunk, + int64_t timeout_ms = kDefaultTimeoutMs); +std::shared_ptr async_recv(const chunk_tuple_t& chunk); + +// OneSide (对应 RDMA read/write, 异步化) +std::shared_ptr async_read( + const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs); +std::shared_ptr async_write( + const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs); +``` + +- `async_` 前缀:TCP 全部为异步(I/O 在 io_context 线程),与 RDMA 的同步 Future 区分 +- `void* stream`:已删除(TCP 无 CUDA stream) +- `std::vector`:接口接受 vector,但 v4 不对多个 assign 做聚合。 + 每个原语调用 = 一个 ClientSession = 一个 Future。 + 多个 assign 的聚合留给上层(SlimeRPC)。 + +## 关键数据结构 + +### 两种 tuple,两种寻址模型 + +```cpp +// send/recv — 双边,只需要本地 buffer 信息 +using chunk_tuple_t = std::tuple; +// mr_handle offset length + +// read/write — 单边,指定本地+远端两个 buffer +using assign_tuple_t = std::tuple; +// local_mr remote_mr remote_off local_off length +``` + +`assign_tuple_t` 已经包含了完成一次单边操作所需的**所有**寻址信息: +- 远端地址 = remote_mr.addr + remote_off → `SessionHeader.addr` +- 本地地址 = local_mr.addr + local_off → 本地读写位置 +- 长度 = length → `SessionHeader.size` + +### 从 assign_tuple_t 到 SessionHeader 的映射(在 Primitive 中完成 MR 解析) + +```cpp +// async_write: assign_tuple_t → SessionHeader + local_src +const auto& a = assign[0]; +auto local = local_pool_->get_mr_fast(std::get<0>(a)); // local_mr handle +auto remote = remote_pool_->get_remote_mr_fast(std::get<1>(a)); // remote_mr handle + +uint64_t remote_addr = remote.addr + std::get<2>(a); // remote_off +uint64_t local_src = local.addr + std::get<3>(a); // local_off +size_t len = std::get<4>(a); // length + +SessionHeader hdr{len, remote_addr, OP_WRITE}; +// ClientSession 拿到的是解析后的 hdr + local_src,不接触 assign_tuple_t +``` + +## v4 目标:四者关系 + +``` 
+┌─────────────────────────────────────────────────────────┐ +│ Primitive (TcpEndpoint::async_xxx) │ +│ │ +│ 1. 解析 assign_tuple_t / chunk_tuple_t → MR 寻址 │ +│ 2. 构建 SessionHeader (wire format) │ +│ 3. 创建 OpState (completion signal) │ +│ 4. 获取连接 (from pool) │ +│ 5. 创建 ClientSession(sock, op, hdr, payload_src/dst) │ +│ 6. return Future(op) │ +└────────────┬────────────────────────────────────────────┘ + │ 创建 + ┌────────▼──────────┐ ┌──────────────────┐ + │ ClientSession │────────→│ TcpOpState │←──────┐ + │ (I/O 状态机) │ signal │ (完成信号) │ │ + │ shared_ptr 自管理 │ └────────┬─────────┘ │ + │ │ │ 被持有 │ + │ start_write() │ │ │ + │ start_read() │ ┌────────▼─────────┐ │ + │ on_done → 归还连接 │ │ TcpFuture │ │ + └────────────────────┘ │ (用户句柄) │───────┘ + │ wait()/wait_for()│ + └──────────────────┘ +``` + +### 关系矩阵 + +| 对象 | 生命周期 | 知道什么 | 不知道什么 | +|------|---------|---------|-----------| +| **Primitive** | 单次调用 | MR 寻址, hdr 构建, assign_tuple_t 解析 | 线协议细节, async I/O 回调链 | +| **OpState** | ≥ Future 生命周期 | completion_status, signal | I/O 如何完成, 谁在驱动 | +| **Future** | 调用者持有 | wait()/wait_for() | 线协议, socket, 连接池 | +| **ClientSession** | I/O 进行中 | hdr, socket, payload 指针 | MR handle, assign_tuple_t | +| **ServerSession** | 连接存续期间 | socket, recv_matcher | 连接池, Future, OpState | + +## ClientSession 设计 + +一个 ClientSession = 一次出站 I/O 操作的完整生命周期。 + +```cpp +class ClientSession : public std::enable_shared_from_this { +public: + using DoneCallback = std::function; + + ClientSession(asio::ip::tcp::socket sock, DoneCallback on_done); + + // write: header + payload (both async_write, gather) + void start_write(const SessionHeader& hdr, const void* payload); + + // read: write header → read response into dst + void start_read(const SessionHeader& hdr, void* dst); + +private: + asio::ip::tcp::socket socket_; + DoneCallback on_done_; + SessionHeader hdr_{}; + // chunk_buf_ 不需要 — write 直接用 payload 指针, read 直接用 dst 指针 +}; +``` + +关键设计决策: +- **ClientSession 不持有 OpState** — 它只报告 `ec`。由 Primitive 在 on_done 
中 signal OpState +- **ClientSession 不持有 PooledConnection** — 它只持有 socket。由 Primitive 在 on_done 中归还连接 +- 这样 ClientSession 是纯粹的 I/O 状态机,不耦合 Future/OpState/Pool + +### 原语 → ClientSession 映射 + +``` +async_send(chunk): + ┌─ 解析 chunk_tuple_t → mr.addr + offset → src_ptr, length + ├─ hdr = {length, 0, OP_SEND} + ├─ op = TcpOpState::create() + ├─ conn = pool.getConnection() + ├─ session = make_shared(move(conn->socket), + │ [op, conn, &pool](ec) { + │ op->completion_status = ec ? FAILED : SUCCESS; + │ op->signal->set_comm_done(0); + │ pool.returnConnection(conn); + │ }); + ├─ session->start_write(hdr, src_ptr); + └─ return TcpSendFuture(op); + +async_write(assign): ← 同上, hdr.opcode = OP_WRITE, hdr.addr = remote_addr +async_read(assign): ← session->start_read(hdr, dst_ptr) + dst_ptr = local_mr.addr + local_off +async_recv(chunk): ← 无 ClientSession (注册到 pending_recvs_) +``` + +### std::vector 的多 assign 处理 + +RDMA 中多个 assign 可聚合为一个 WR chain(一次 `ibv_post_send`,一个 Future)。 +TCP 没有硬件聚合——每个 assign 对应一个独立的线消息(一个 header + payload)。 +但接口约定是一个 `std::vector` → 一个 Future。 + +处理方式:**迭代 vector,每个 assign 创建一个 ClientSession,共享一个 OpState**。 + +``` +async_write([assign_0, assign_1, assign_2]): + op = TcpOpState::create() + op->expected_mask = (1 << 3) - 1 // 3 个 assign, 等 3 个 session 完成 + + for i, a in enumerate(assign): + 解析 a → hdr + src_ptr + conn = pool.getConnection() // 复用同一连接 + session = ClientSession(sock, [op, conn, i, &pool](ec) { + if (!ec) op->signal->set_comm_done(i); // 设置第 i 位 + pool.returnConnection(conn); + }) + session->start_write(hdr, src_ptr) + + return TcpReadWriteFuture(op) // wait 等待 expected_mask 所有位就绪 +``` + +每个 assign → 一个 session → 一次 `async_write`(串行在线路上,同连接)。 +Future.wait() 自旋等待 `completion_mask` 达到 `expected_mask`。 + +**与单 assign 的统一**:单 assign 是 `expected_mask = 1` 的特例。 +ClientSession 不感知是单还是多——只负责一个 I/O 操作。 + +### 不再需要的 + +- `asio::post` — ClientSession 构造后直接在调用者线程调 start_xxx,asio async_write/async_read 已经在 io_context 上 +- `weak_ptr` — ClientSession 不持有 
endpoint 引用 +- `pending_reads_` map — 不再需要按 request_id 匹配响应。async_read 创建的 ClientSession 在 start_read 的 on_done 中直接拿到结果 + +## 入站/出站对称 + +``` +ServerSession (入站, 持久) ClientSession (出站, 瞬态) +────────────────────────── ────────────────────────── +readHeader() ← socket start_write(hdr, payload) → socket +dispatch() start_read(hdr, dst) → socket + ├─ OP_SEND: async_read → signal write_header → callback + ├─ OP_WRITE: readBody → memcpy read_response → callback + └─ OP_READ: writeBody → done on_done → Primitive signal → 析构 +readHeader() ← 循环 +``` + +## 文件变更 + +| 文件 | 变更 | +|------|------| +| `tcp_session.h` | 新增 ClientSession 类 (约 35 行) | +| `tcp_session.cpp` | 新增 ClientSession 实现 (约 50 行): start_write, start_read | +| `tcp_endpoint.cpp` | async_send/write/read 从 ad-hoc lambda → ClientSession; 删除 pending_reads_ 相关逻辑; 删除 asio::post | +| `tcp_endpoint.h` | 删除 `pending_reads_`, `read_mu_`, `next_req_id_` (不再需要 request_id 匹配); 公开 API 不变 | + +## 不聚合的理由 + +`assign_tuple_t` 是一个单次 I/O 操作的完整描述——不是可拆分的子操作集合。 +每个 async_read/async_write 调用对应一个 ClientSession。 +多个 assign 的聚合留给上层(如 SlimeRPC channel 的多个 slot), +不在 TcpEndpoint 层处理。 + +## Timeout 设计 + +### 两层 timeout,不同归属 + +| 层 | 机制 | 归属 | 语义 | +|----|------|------|------| +| **Future 层** | `wait_for(ms)` 定时自旋轮询 signal | Future / 调用者 | "我等不了了,但操作还在后台跑" | +| **I/O 层** | `asio::steady_timer` + `socket.cancel()` | ClientSession | "真的取消这个 I/O" | + +### v4 实现 Future 层,v5 实现 I/O 层 + +**v4**: +- `timeout_ms` 参数保留在方法签名中,但仅作为 OpState 的提示值存储 +- 真正的超时由 `future.wait_for(seconds)` 控制——调用者决定等待多久 +- ClientSession 不感知 timeout——它总是跑完 I/O 链 + +```cpp +fut = ep.async_send((h, 0, 128), timeout_ms=5000); +// timeout_ms 存入 op_state, 但 async I/O 链不受影响 +status = fut.wait_for(3.0); // 调用者侧超时 — 3 秒后返回 None +// 3 秒后 ClientSession 可能还在写, 完成后仍会 signal op_state +// 只是没有人等这个 signal 了 +``` + +**v5**:加 `asio::steady_timer` 给 ClientSession +```cpp +void ClientSession::start_write(...) 
{ + if (timeout_ms_ > 0) { + timer_.expires_after(ms(timeout_ms_)); + timer_.async_wait([this](ec) { if (!ec) socket_.cancel(); }); + } + asio::async_write(socket_, bufs, ...); +} +// timer 触发 → socket.cancel() → async_write 回调收到 operation_aborted +// → on_done(operation_aborted) → op->completion_status = TCP_TIMEOUT +``` + +### 为什么不把 timeout_ms 去掉 + +保留它的两个理由: +1. 接口与 RDMA 的 `send(chunk, stream)` 模式一致——都有一个"额外控制参数"的位置 +2. 它为 v5 的 timer 实现预留了参数位,届时只需改内部实现,不改变 API + +## 为什么不做 + +- **recv 无 ClientSession** — 无出站 I/O +- **不拆 WriteSession/ReadSession** — 差异小,合并为一个 ClientSession +- **不在 Future 中持有 Session** — Future 只 wait,通过 OpState 间接关联 +- **ClientSession 不持有 OpState** — 只报 ec,由 Primitive 的 on_done 统一 signal + +## 未来规划 + +### CUDA 锁页内存 + +当前 CUDA staging 使用 `new char[]`(可分页内存),D2H/H2D `cudaMemcpy` 走的是同步 device→host 拷贝,pageable memory 路径较慢。 + +后续改为 `cudaHostAlloc()` 分配锁页(pinned)内存,使 `cudaMemcpy` 能走 DMA 快速路径。同时可考虑 `cudaMemcpyAsync` + `cudaStream` 与 io_context 的异步重叠。 + +### async_recv exact_size 自适应 + +当前 `exact_size` 是 opt-in boolean 参数,默认 `false`(宽松截断)。未来改为默认自适应: +- 当 `send_size <= recv_size`:自动启用严格检查(exact match) +- 当 `send_size > recv_size`:自动宽松截断 +- 移除 `exact_size` 参数,行为由实际数据量驱动 diff --git a/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp b/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp new file mode 100755 index 00000000..2bbaa5bb --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp @@ -0,0 +1,137 @@ +#include "tcp_connection_pool.h" + +#include + +#include "dlslime/csrc/logging.h" + +namespace dlslime { +namespace tcp { + +using tcp = asio::ip::tcp; + +std::shared_ptr +TcpConnectionPool::getConnection(const std::string& host, uint16_t port) { + ConnKey key{host, port}; + + { + std::lock_guard lk(mu_); + auto it = pool_.find(key); + if (it != pool_.end()) { + for (auto& c : it->second) { + if (!c->in_use && c->socket.is_open()) { + c->in_use = true; + c->last_used = std::chrono::steady_clock::now(); + return c; + } + } + } + } + + tcp::resolver 
resolver(io_ctx_); + asio::error_code ec; /* shared by resolve and connect so that either failure takes the nullptr path below instead of throwing */ + auto endpoints = resolver.resolve(host, std::to_string(port), ec); + tcp::socket sock(io_ctx_); + if (!ec) asio::connect(sock, endpoints, ec); + if (ec) { + SLIME_LOG_WARN("TcpConnectionPool: connect to ", host, ":", port, + " failed: ", ec.message()); + return nullptr; + } + sock.set_option(tcp::no_delay(true)); + + auto conn = std::make_shared(std::move(sock), host, port); + { + std::lock_guard lk(mu_); + // Remove idle connection + cleanupIdleConnections(false); + + auto& q = pool_[key]; + for (auto q_i = q.begin(); q_i != q.end();) { + auto& c = *q_i; + if (!c->in_use) { + if (c->socket.is_open()) { + c->in_use = true; + c->last_used = std::chrono::steady_clock::now(); + asio::error_code ign; + conn->socket.close(ign); + return c; + } else { + // Remove dead connection + q_i = q.erase(q_i); + continue; + } + } + q_i++; + } + q.push_back(conn); + } + return conn; +} + +void TcpConnectionPool::returnConnection( + std::shared_ptr conn) { + if (!conn) return; + ConnKey key{conn->host, conn->port}; + + std::lock_guard lk(mu_); + auto it = pool_.find(key); + if (it != pool_.end()) { + auto& q = it->second; + for (auto qi = q.begin(); qi != q.end(); ++qi) + if (*qi == conn) { + if (conn->socket.is_open()) { + conn->in_use = false; + conn->last_used = std::chrono::steady_clock::now(); + } else { + q.erase(qi); + } + break; + } + if (q.empty()) pool_.erase(it); + return; + } + + // Connection not found in pool (temporary), close it. 
+ if (conn->socket.is_open()) { + asio::error_code ec; + conn->socket.close(ec); + if (ec) + SLIME_LOG_WARN("TcpConnectionPool: close temp conn ", conn->host, + ":", conn->port, " failed: ", ec.message()); + } + +} + +void TcpConnectionPool::cleanupIdleConnections(bool lock) { + auto now = std::chrono::steady_clock::now(); + std::unique_lock<std::mutex> lk(mu_, std::defer_lock); if (lock) lk.lock(); /* mu_ must stay held for the whole sweep; a guard declared as the body of the if would be released before the loop. lock=false means the caller already holds mu_. */ + for (auto it = pool_.begin(); it != pool_.end(); ) { + auto& q = it->second; + while (!q.empty()) { + auto& c = q.back(); + if (!c->in_use) { + auto idle = std::chrono::duration_cast( + now - c->last_used).count(); + if (idle > kIdleTimeout.count()) { + asio::error_code ign; + c->socket.close(ign); + q.pop_back(); + continue; + } + } + break; + } + if (q.empty()) it = pool_.erase(it); else ++it; + } +} + +void TcpConnectionPool::clear() { + std::lock_guard lk(mu_); + for (auto& [_, q] : pool_) + // force close + for (auto& c : q) { c->in_use = false; asio::error_code ign; c->socket.close(ign);} + pool_.clear(); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_connection_pool.h b/dlslime/csrc/engine/tcp/tcp_connection_pool.h new file mode 100755 index 00000000..2565a72f --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_connection_pool.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace dlslime { +namespace tcp { + +struct PooledConnection { + asio::ip::tcp::socket socket; + std::string host; + uint16_t port{0}; + std::chrono::steady_clock::time_point last_used; + bool in_use{true}; + + PooledConnection(asio::ip::tcp::socket s, std::string h, uint16_t p) + : socket(std::move(s)), host(std::move(h)), port(p), + last_used(std::chrono::steady_clock::now()) {} +}; + +// Keyed by (host, port). Thread-safe. 
+// States: IDLE (in deque, in_use=false) / ACTIVE (checked out) / RESERVED +class TcpConnectionPool { +public: + static constexpr std::chrono::seconds kIdleTimeout{300}; + + explicit TcpConnectionPool(asio::io_context& io_ctx) : io_ctx_(io_ctx) {} + + std::shared_ptr getConnection( + const std::string& host, uint16_t port); + + void returnConnection(std::shared_ptr conn); + + void cleanupIdleConnections(bool lock = true); + void clear(); + +private: + struct ConnKey { + std::string host; + uint16_t port; + bool operator==(const ConnKey& o) const { + return host == o.host && port == o.port; + } + }; + struct ConnKeyHash { + size_t operator()(const ConnKey& k) const { + return std::hash{}(k.host) + ^ (std::hash{}(k.port) << 1); + } + }; + + asio::io_context& io_ctx_; + std::mutex mu_; + std::unordered_map>, + ConnKeyHash> pool_; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_context.cpp b/dlslime/csrc/engine/tcp/tcp_context.cpp new file mode 100644 index 00000000..f669e9be --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_context.cpp @@ -0,0 +1,27 @@ +#include "tcp_context.h" + +namespace dlslime { +namespace tcp { + +TcpContext::TcpContext() { + // Keep io_context alive even when there's no work posted yet. 
+ auto work = asio::make_work_guard(io_ctx_); + io_thread_ = std::thread([this, w = std::move(work)]() { + io_ctx_.run(); + }); +} + +TcpContext::~TcpContext() { + shutdown(); +} + +void TcpContext::shutdown() { + if (!running_) return; + running_ = false; + io_ctx_.stop(); + if (io_thread_.joinable()) io_thread_.join(); + conn_pool_.clear(); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_context.h b/dlslime/csrc/engine/tcp/tcp_context.h new file mode 100644 index 00000000..a3bd5185 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_context.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +#include +#include + +#include "tcp_connection_pool.h" + +namespace dlslime { +namespace tcp { + +// Process-level shared resource holder. +// Multiple TcpEndpoints can share one TcpContext to run on a single +// io_context thread, reducing thread count. +// +// For sync wrappers: sync_send() = async_send() + future.wait() +// — the io_context drives the I/O, the caller thread just blocks. 
+class TcpContext { +public: + TcpContext(); + ~TcpContext(); + + TcpContext(const TcpContext&) = delete; + TcpContext& operator=(const TcpContext&) = delete; + + asio::io_context& io_context() { return io_ctx_; } + TcpConnectionPool& conn_pool() { return conn_pool_; } + + void shutdown(); + +private: + asio::io_context io_ctx_; + std::thread io_thread_; + TcpConnectionPool conn_pool_{io_ctx_}; + bool running_{true}; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp new file mode 100755 index 00000000..f65649b4 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -0,0 +1,437 @@ +#include "tcp_endpoint.h" + +#include +#include + +#include +#include + +#include "dlslime/csrc/logging.h" + +#ifdef USE_CUDA +#include +#endif + +namespace dlslime { +namespace tcp { + +using tcp = asio::ip::tcp; + +// ── helpers ───────────────────────────────────────────── + +static void hdr_hton(SessionHeader& h) { + h.size = htole64(h.size); + h.addr = htole64(h.addr); +} + +#ifdef USE_CUDA +static bool is_cuda_memory(const void* addr) { + cudaPointerAttributes attr; + auto st = cudaPointerGetAttributes(&attr, addr); + return (st == cudaSuccess && attr.type == cudaMemoryTypeDevice); +} +#endif + +// ── RecvMatcher factory ──────────────────────────────── + +ServerSession::RecvMatcher TcpEndpoint::make_recv_matcher() { + std::weak_ptr weak = shared_from_this(); + return [weak]() -> RecvSlot { + auto self = weak.lock(); + if (!self) return {}; + std::lock_guard lk(self->recv_mu_); + if (self->pending_recvs_.empty()) return {}; + auto pr = std::move(self->pending_recvs_.front()); + self->pending_recvs_.pop_front(); + + RecvSlot slot{pr.op_state->user_buffer, pr.op_state->user_length, + pr.op_state, {}, pr.exact_size}; +#ifdef USE_CUDA + if (pr.cuda_dst) { + slot.buffer = reinterpret_cast(pr.staging_buf.get()); + slot.post_read = [buf = std::shared_ptr(std::move(pr.staging_buf)), + 
dst = pr.cuda_dst, len = pr.op_state->user_length]() { + auto cu_err = cudaMemcpy(reinterpret_cast(dst), buf.get(), + len, cudaMemcpyHostToDevice); + if (cu_err != cudaSuccess) + SLIME_LOG_ERROR("cudaMemcpy H2D (recv): ", cudaGetErrorString(cu_err)); + }; + } +#endif + return slot; + }; +} + +// ── Constructor ──────────────────────────────────────── + +TcpEndpoint::TcpEndpoint(const std::string& ip, uint16_t port) + : own_ctx_(std::make_unique()) + , acceptor_(own_ctx_->io_context()) + , local_pool_(std::make_shared()) + , remote_pool_(std::make_shared()) + , local_host_(ip) { + ctx_ = own_ctx_.get(); + local_port_ = port; + start_io(); +} + +TcpEndpoint::~TcpEndpoint() { + shutdown(); +} + +void TcpEndpoint::start_io() { + auto addr = asio::ip::make_address(local_host_); + auto ep = tcp::endpoint(addr, local_port_); + acceptor_.open(ep.protocol()); + acceptor_.set_option(tcp::acceptor::reuse_address(true)); + acceptor_.bind(ep); + acceptor_.listen(64); + + if (local_port_ == 0) { + asio::error_code ec; + local_port_ = acceptor_.local_endpoint(ec).port(); + } + + do_accept(); +} + +// ── do_accept ─────────────────────────────────────────── + +void TcpEndpoint::do_accept() { + if (!running_.load(std::memory_order_acquire)) return; + acceptor_.async_accept( + [this](asio::error_code ec, tcp::socket sock) { + if (ec) { + if (ec != asio::error::operation_aborted) + SLIME_LOG_WARN("TcpEndpoint accept: ", ec.message()); + return; + } + sock.set_option(tcp::no_delay(true)); + auto session = std::make_shared( + std::move(sock), local_pool_.get(), make_recv_matcher()); + session->start(); + do_accept(); + }); +} + +// ── endpoint_info / connect ───────────────────────────── + +json TcpEndpoint::endpoint_info() const { + return { + {"host", local_host_}, + {"port", local_port_}, + {"mr_info", local_pool_->mr_info()} + }; +} + +json TcpEndpoint::mr_info() const { + return local_pool_->mr_info(); +} + +void TcpEndpoint::connect(const json& remote_endpoint_info) { + auto host 
= remote_endpoint_info.value("host", ""); + auto port = static_cast(remote_endpoint_info.value("port", 0)); + + // Verify reachability before accepting the peer identity. + auto conn = ctx_->conn_pool().getConnection(host, port); + if (!conn) { + SLIME_LOG_WARN("TcpEndpoint::connect: cannot reach ", host, ":", port); + return; + } + + peer_host_ = host; + peer_port_ = port; + if (remote_endpoint_info.contains("mr_info")) { + for (const auto& [name, info] : remote_endpoint_info["mr_info"].items()) + remote_pool_->register_remote_memory_region(info, name); + } + connected_.store(true, std::memory_order_release); + ctx_->conn_pool().returnConnection(std::move(conn)); +} + +// ── memory registration ───────────────────────────────── + +int32_t TcpEndpoint::register_memory_region(const std::string& name, + uintptr_t ptr, uintptr_t offset, + size_t length) { + return local_pool_->register_memory_region(ptr + offset, length, name); +} + +int32_t TcpEndpoint::register_remote_memory_region(const std::string& name, + const json& mr_info) { + return remote_pool_->register_remote_memory_region(mr_info, name); +} + +// ── async_send ────────────────────────────────────────── +// chunk_tuple_t = (src_ptr, offset, length) — raw pointers, no MR lookup. 
+ +std::shared_ptr +TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/) { + uintptr_t src = std::get<0>(chunk) + std::get<1>(chunk); + size_t len = std::get<2>(chunk); + + auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); + auto op = TcpOpState::create(); + op->signal->reset_all(); + + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + return std::make_shared(op); + } + + SessionHeader hdr{len, 0, OP_SEND}; + auto& pool = ctx_->conn_pool(); + + auto* send_ptr = reinterpret_cast(src); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(send_ptr)) { + auto* buf = new char[len]; + auto cu_err = cudaMemcpy(buf, send_ptr, len, cudaMemcpyDeviceToHost); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("async_send cudaMemcpy D2H: ", cudaGetErrorString(cu_err)); + delete[] buf; + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + pool.returnConnection(conn); + return std::make_shared(op); + } + send_ptr = buf; + is_cuda = true; + } +#endif + + auto session = std::make_shared( + std::move(conn->socket), + [op, conn, &pool, send_ptr, is_cuda](asio::error_code ec) { + if (ec) + SLIME_LOG_WARN("async_send: ", ec.message()); + op->completion_status.store( + ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); + if (op->signal) op->signal->set_comm_done(0); + pool.returnConnection(conn); +#ifdef USE_CUDA + if (is_cuda) delete[] send_ptr; +#endif + }); + session->start_write(hdr, send_ptr); + + return std::make_shared(op); +} + +// ── async_recv ────────────────────────────────────────── +// chunk_tuple_t = (dst_ptr, offset, length) — raw pointers, no MR lookup. 
+ +std::shared_ptr +TcpEndpoint::async_recv(const chunk_tuple_t& chunk, bool exact_size) { + auto op = TcpOpState::create(); + op->signal->reset_all(); + uintptr_t dst = std::get<0>(chunk) + std::get<1>(chunk); + size_t length = std::get<2>(chunk); + op->user_buffer = dst; + op->user_length = length; + + PendingRecv pr{op, nullptr, 0, exact_size}; +#ifdef USE_CUDA + if (is_cuda_memory(reinterpret_cast(dst))) { + auto* buf = new char[length]; + pr.staging_buf.reset(buf); + pr.cuda_dst = dst; + op->user_buffer = reinterpret_cast(buf); + } +#endif + + { + std::lock_guard lk(recv_mu_); + pending_recvs_.push_back(std::move(pr)); + } + + return std::make_shared(op); +} + +// ── async_read ────────────────────────────────────────── +// Each assign creates an independent ClientSession; all share one OpState. +// Future.wait() blocks until every session has signalled its bit. + +std::shared_ptr +TcpEndpoint::async_read(const std::vector& assign, + int64_t /*timeout_ms*/) { + if (assign.empty()) + throw std::runtime_error("TcpEndpoint::async_read: empty assignment"); + + size_t N = assign.size(); if (N > 32) /* completion is tracked with one bit per assign in a 32-bit mask, so a larger batch could report done before its later sessions finish */ throw std::runtime_error("TcpEndpoint::async_read: more than 32 assignments"); + auto op = TcpOpState::create(); + op->signal->reset_all(); + op->expected_mask = (N < 32) ? 
(1u << N) - 1 : 0xFFFFFFFFu; + op->completion_status.store(TCP_SUCCESS, std::memory_order_release); + op->completion_mask.store(0, std::memory_order_release); + + auto& pool = ctx_->conn_pool(); + + for (size_t i = 0; i < N; i++) { + const auto& a = assign[i]; + int32_t local_h = static_cast(std::get<0>(a)); + int32_t remote_h = static_cast(std::get<1>(a)); + uint64_t remote_off = std::get<2>(a); + uint64_t local_off = std::get<3>(a); + size_t length = std::get<4>(a); + + auto local = local_pool_->get_mr_fast(local_h); + auto remote = remote_pool_->get_remote_mr_fast(remote_h); + if (local.length == 0 || remote.length == 0) + throw std::runtime_error("TcpEndpoint::async_read: invalid MR handle"); + + uintptr_t local_dst = local.addr + local_off; + SessionHeader hdr{length, remote.addr + remote_off, OP_READ}; + + auto conn = pool.getConnection(peer_host_, peer_port_); + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->set_comm_done(i); + continue; + } + + auto* read_dst = reinterpret_cast(local_dst); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(read_dst)) { + read_dst = new char[length]; + is_cuda = true; + } +#endif + + auto session = std::make_shared( + std::move(conn->socket), + [op, conn, i, &pool, read_dst, is_cuda, + real_dst = local_dst, len = length](asio::error_code ec) { + if (ec) { + SLIME_LOG_WARN("async_read session ", i, ": ", ec.message()); + op->completion_status.store(TCP_FAILED, std::memory_order_release); + } +#ifdef USE_CUDA + if (!ec && is_cuda) { + auto cu_err = cudaMemcpy(reinterpret_cast(real_dst), + read_dst, len, cudaMemcpyHostToDevice); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("async_read cudaMemcpy H2D: ", cudaGetErrorString(cu_err)); + op->completion_status.store(TCP_FAILED, std::memory_order_release); + } + } + if (is_cuda) delete[] read_dst; +#endif + if (op->signal) op->signal->set_comm_done(i); + pool.returnConnection(conn); + }); + session->start_read(hdr, 
read_dst); + } + + return std::make_shared(op); +} + +// ── async_write ───────────────────────────────────────── +// Each assign creates an independent ClientSession; all share one OpState. + +std::shared_ptr +TcpEndpoint::async_write(const std::vector& assign, + int64_t /*timeout_ms*/) { + if (assign.empty()) + throw std::runtime_error("TcpEndpoint::async_write: empty assignment"); + + size_t N = assign.size(); if (N > 32) /* completion is tracked with one bit per assign in a 32-bit mask, so a larger batch could report done before its later sessions finish */ throw std::runtime_error("TcpEndpoint::async_write: more than 32 assignments"); + auto op = TcpOpState::create(); + op->signal->reset_all(); + op->expected_mask = (N < 32) ? (1u << N) - 1 : 0xFFFFFFFFu; + op->completion_status.store(TCP_SUCCESS, std::memory_order_release); + op->completion_mask.store(0, std::memory_order_release); + + auto& pool = ctx_->conn_pool(); + + for (size_t i = 0; i < N; i++) { + const auto& a = assign[i]; + int32_t local_h = static_cast(std::get<0>(a)); + int32_t remote_h = static_cast(std::get<1>(a)); + uint64_t remote_off = std::get<2>(a); + uint64_t local_off = std::get<3>(a); + size_t length = std::get<4>(a); + + auto local = local_pool_->get_mr_fast(local_h); + auto remote = remote_pool_->get_remote_mr_fast(remote_h); + if (local.length == 0 || remote.length == 0) + throw std::runtime_error("TcpEndpoint::async_write: invalid MR handle"); + + uintptr_t src = local.addr + local_off; + SessionHeader hdr{length, remote.addr + remote_off, OP_WRITE}; + + auto conn = pool.getConnection(peer_host_, peer_port_); + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->set_comm_done(i); + continue; + } + + auto* send_ptr = reinterpret_cast(src); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(send_ptr)) { + auto* buf = new char[length]; + auto cu_err = cudaMemcpy(buf, send_ptr, length, cudaMemcpyDeviceToHost); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("async_write cudaMemcpy D2H: ", cudaGetErrorString(cu_err)); + delete[] buf; + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + 
pool.returnConnection(conn); + return std::make_shared(op); + } + send_ptr = buf; + is_cuda = true; + } +#endif + + auto session = std::make_shared( + std::move(conn->socket), + [op, conn, i, &pool, send_ptr, is_cuda](asio::error_code ec) { + if (ec) { + SLIME_LOG_WARN("async_write session ", i, ": ", ec.message()); + op->completion_status.store(TCP_FAILED, std::memory_order_release); + } + if (op->signal) op->signal->set_comm_done(i); + pool.returnConnection(conn); +#ifdef USE_CUDA + if (is_cuda) delete[] send_ptr; +#endif + }); + session->start_write(hdr, send_ptr); + } + + return std::make_shared(op); +} + +// ── shutdown ──────────────────────────────────────────── + +void TcpEndpoint::shutdown() { + bool expected = true; + if (!running_.compare_exchange_strong(expected, false)) + return; + + connected_.store(false, std::memory_order_release); + acceptor_.close(); + + { + std::lock_guard lk(recv_mu_); + for (auto& pr : pending_recvs_) { + if (pr.op_state && pr.op_state->signal) { + pr.op_state->completion_status.store(TCP_CLOSED, std::memory_order_release); + pr.op_state->signal->force_complete(); + } + } + pending_recvs_.clear(); + } + + if (own_ctx_) + own_ctx_->shutdown(); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_endpoint.h b/dlslime/csrc/engine/tcp/tcp_endpoint.h new file mode 100755 index 00000000..66f7a6f4 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_endpoint.h @@ -0,0 +1,112 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "dlslime/csrc/common/json.hpp" +#include "dlslime/csrc/engine/assignment.h" +#include "tcp_connection_pool.h" +#include "tcp_context.h" +#include "tcp_future.h" +#include "tcp_header.h" +#include "tcp_memory_pool.h" +#include "tcp_op_state.h" +#include "tcp_session.h" + +namespace dlslime { +namespace tcp { + +using json = nlohmann::json; + +class TcpEndpoint : public std::enable_shared_from_this { +public: + 
static constexpr int64_t kDefaultTimeoutMs = 30000; + + explicit TcpEndpoint(const std::string& ip = "0.0.0.0", uint16_t port = 0); + + TcpEndpoint(TcpContext& ctx, const std::string& ip = "0.0.0.0", uint16_t port = 0) = delete; + + ~TcpEndpoint(); + + TcpEndpoint(const TcpEndpoint&) = delete; + TcpEndpoint& operator=(const TcpEndpoint&) = delete; + + // ── Connection ────────────────────────────────────── + json endpoint_info() const; + void connect(const json& remote_endpoint_info); + void shutdown(); + + // ── Memory ────────────────────────────────────────── + int32_t register_memory_region(const std::string& name, + uintptr_t ptr, uintptr_t offset, size_t length); + int32_t register_remote_memory_region(const std::string& name, + const json& mr_info); + json mr_info() const; + + // ── Async I/O (all return Future immediately; I/O runs on io_context thread) ── + + std::shared_ptr async_send( + const chunk_tuple_t& chunk, + int64_t timeout_ms = kDefaultTimeoutMs); + + std::shared_ptr async_recv( + const chunk_tuple_t& chunk, + bool exact_size = false); + + std::shared_ptr async_read( + const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs); + + std::shared_ptr async_write( + const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs); + + // ── Accessors ─────────────────────────────────────── + void setId(int64_t id) { id_.store(id, std::memory_order_relaxed); } + int64_t getId() const { return id_.load(std::memory_order_relaxed); } + bool is_connected() const { return connected_.load(std::memory_order_acquire); } + +private: + void start_io(); + void do_accept(); + ServerSession::RecvMatcher make_recv_matcher(); + + // ── identity ──────────────────────────────────────── + std::atomic id_{-1}; + std::string local_host_{"0.0.0.0"}; + uint16_t local_port_{0}; + std::string peer_host_; + uint16_t peer_port_{0}; + std::atomic connected_{false}; + + // ── asio core ─────────────────────────────────────── + TcpContext* ctx_{nullptr}; + 
std::unique_ptr own_ctx_; + asio::ip::tcp::acceptor acceptor_; + std::atomic running_{true}; + + // ── memory ────────────────────────────────────────── + std::shared_ptr local_pool_; + std::shared_ptr remote_pool_; + + // ── recv matching ─────────────────────────────────── + struct PendingRecv { + std::shared_ptr op_state; + std::unique_ptr staging_buf; + uintptr_t cuda_dst{0}; + bool exact_size{false}; + }; + std::mutex recv_mu_; + std::deque pending_recvs_; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_future.h b/dlslime/csrc/engine/tcp/tcp_future.h new file mode 100644 index 00000000..dcf53a4e --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_future.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include + +#include "dlslime/csrc/common/pause.h" +#include "dlslime/csrc/device/device_future.h" +#include "tcp_op_state.h" + +namespace dlslime { +namespace tcp { + +class TcpFuture : public DeviceFuture { +public: + explicit TcpFuture(std::shared_ptr op) + : op_state_(std::move(op)) { + if (!op_state_) + throw std::runtime_error("TcpFuture: null op_state"); + } + + int32_t wait() const override { + if (op_state_->signal) + op_state_->signal->wait_comm_done_cpu(op_state_->expected_mask); + return op_state_->completion_status.load(std::memory_order_acquire); + } + + // Block up to timeout_ms milliseconds. Returns true iff completed; + // writes status to *out. On timeout the operation is still in-flight. 
+ bool wait_for(int64_t timeout_ms, int32_t* out) const { + auto deadline = std::chrono::steady_clock::now() + + std::chrono::milliseconds(timeout_ms); + while (true) { + if (op_state_->signal) { + uint32_t m = op_state_->signal->get_comm_done_mask(); + if ((m & op_state_->expected_mask) == op_state_->expected_mask) { + if (out) *out = op_state_->completion_status.load( + std::memory_order_acquire); + return true; + } + } + if (std::chrono::steady_clock::now() >= deadline) { + if (op_state_->signal) { + uint32_t m = op_state_->signal->get_comm_done_mask(); + if ((m & op_state_->expected_mask) == op_state_->expected_mask) { + if (out) *out = op_state_->completion_status.load( + std::memory_order_acquire); + return true; + } + } + return false; + } + machnet_pause(); + } + } + +protected: + std::shared_ptr op_state_; +}; + +class TcpSendFuture : public TcpFuture { public: using TcpFuture::TcpFuture; }; +class TcpRecvFuture : public TcpFuture { public: using TcpFuture::TcpFuture; }; +class TcpReadWriteFuture : public TcpFuture { public: using TcpFuture::TcpFuture; }; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_header.h b/dlslime/csrc/engine/tcp/tcp_header.h new file mode 100644 index 00000000..313187d6 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_header.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace dlslime { +namespace tcp { + +// 17-byte wire header, referenced from Mooncake SessionHeader. 
+// offset size field +// 0 8 size payload byte count (htole64 / le64toh) +// 8 8 addr remote buffer virtual address +// 16 1 opcode SEND=0x00 READ=0x01 WRITE=0x02 + +#pragma pack(push, 1) +struct SessionHeader { + uint64_t size; + uint64_t addr; + uint8_t opcode; +}; +#pragma pack(pop) + +static_assert(sizeof(SessionHeader) == 17, "SessionHeader must be 17 bytes"); + +enum OpCode : uint8_t { + OP_SEND = 0x00, // header + payload → peer recv matches + OP_READ = 0x01, // header only → peer reads local memory → sends data back + OP_WRITE = 0x02, // header + payload → peer writes to local memory +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp b/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp new file mode 100755 index 00000000..dc7d8ab3 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp @@ -0,0 +1,142 @@ +#include "tcp_memory_pool.h" + +#include "dlslime/csrc/logging.h" + +namespace dlslime { +namespace tcp { + +// ── local MR ──────────────────────────────────────────── + +int32_t TcpMemoryPool::register_memory_region( + uintptr_t addr, size_t length, const std::string& name) { + + if (name.empty()) { + SLIME_LOG_WARN("TcpMemoryPool: empty name rejected"); + return -1; + } + if (name_to_handle_.find(name) != name_to_handle_.end()) { + SLIME_LOG_WARN("TcpMemoryPool: duplicate name '", name, "' rejected"); + return -1; + } + + auto pit = ptr_to_handle_.find(addr); + if (pit != ptr_to_handle_.end()) { + int32_t h = pit->second; + if (h >= 0 && static_cast(h) < handle_to_mr_.size() + && handle_to_mr_[h].addr == addr + && handle_to_mr_[h].length >= length) { + name_to_handle_[name] = h; + return h; + } + } + + int32_t h = static_cast(handle_to_mr_.size()); + handle_to_mr_.push_back({addr, length}); + ptr_to_handle_[addr] = h; + name_to_handle_[name] = h; + return h; +} + +int32_t TcpMemoryPool::unregister_memory_region(int32_t handle) { + if (handle < 0 || static_cast(handle) >= handle_to_mr_.size()) + 
return -1; + + auto& mr = handle_to_mr_[handle]; + ptr_to_handle_.erase(mr.addr); + + // Remove all name→handle entries pointing to this handle. + for (auto it = name_to_handle_.begin(); it != name_to_handle_.end(); ) { + if (it->second == handle) + it = name_to_handle_.erase(it); + else + ++it; + } + + mr = {}; + return 0; +} + +// ── remote MR ─────────────────────────────────────────── + +int32_t TcpMemoryPool::register_remote_memory_region( + const json& mr_info, std::optional name) { + + std::string mr_name = name.value_or(mr_info.value("name", "")); + + if (!mr_name.empty()) { + auto it = remote_name_to_handle_.find(mr_name); + if (it != remote_name_to_handle_.end()) { + int32_t h = it->second; + auto& rm = remote_handle_to_mr_[h]; + rm.addr = mr_info.value("addr", 0UL); + rm.length = mr_info.value("length", 0UL); + return h; + } + } + + int32_t h = static_cast(remote_handle_to_mr_.size()); + remote_handle_to_mr_.push_back({ + mr_info.value("addr", 0UL), + mr_info.value("length", 0UL) + }); + remote_handle_to_name_.push_back(mr_name); + if (!mr_name.empty()) remote_name_to_handle_[mr_name] = h; + return h; +} + +int32_t TcpMemoryPool::unregister_remote_memory_region(int32_t handle) { + if (handle < 0 || static_cast(handle) >= remote_handle_to_mr_.size()) + return -1; + auto& s = remote_handle_to_name_[handle]; + if (!s.empty()) remote_name_to_handle_.erase(s); + remote_handle_to_mr_[handle] = {}; + s.clear(); + return 0; +} + +// ── fast lookup ───────────────────────────────────────── + +TcpMr TcpMemoryPool::get_mr_fast(int32_t handle) const { + if (handle < 0 || static_cast(handle) >= handle_to_mr_.size()) + return {}; + return handle_to_mr_[handle]; +} + +TcpMr TcpMemoryPool::get_remote_mr_fast(int32_t handle) const { + if (handle < 0 || static_cast(handle) >= remote_handle_to_mr_.size()) + return {}; + return remote_handle_to_mr_[handle]; +} + +int32_t TcpMemoryPool::get_mr_handle(const std::string& name) const { + auto it = name_to_handle_.find(name); + 
return it != name_to_handle_.end() ? it->second : -1; +} + +int32_t TcpMemoryPool::get_remote_mr_handle(const std::string& name) const { + auto it = remote_name_to_handle_.find(name); + return it != remote_name_to_handle_.end() ? it->second : -1; +} + +// ── serialization ─────────────────────────────────────── + +json TcpMemoryPool::mr_info() const { + json j = json::object(); + for (const auto& [name, h] : name_to_handle_) + if (h >= 0 && static_cast(h) < handle_to_mr_.size() + && handle_to_mr_[h].length > 0) + j[name] = handle_to_mr_[h].json_info(name); + return j; +} + +json TcpMemoryPool::remote_mr_info() const { + json j = json::object(); + for (const auto& [name, h] : remote_name_to_handle_) + if (h >= 0 && static_cast(h) < remote_handle_to_mr_.size() + && remote_handle_to_mr_[h].length > 0) + j[name] = remote_handle_to_mr_[h].json_info(name); + return j; +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_memory_pool.h b/dlslime/csrc/engine/tcp/tcp_memory_pool.h new file mode 100755 index 00000000..c9061708 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_memory_pool.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "dlslime/csrc/common/json.hpp" + +namespace dlslime { +namespace tcp { + +using json = nlohmann::json; + +struct TcpMr { + uintptr_t addr{0}; + size_t length{0}; + + json json_info(const std::string& name) const { + return {{"name", name}, {"addr", addr}, {"length", length}}; + } +}; + +// Pure-bookkeeping pool. No hardware registration needed for TCP. +class TcpMemoryPool { +public: + TcpMemoryPool() = default; + + // name must be non-empty and unique; returns -1 on violation. + int32_t register_memory_region(uintptr_t addr, size_t length, + const std::string& name); + int32_t unregister_memory_region(int32_t handle); + + // remote MR — name is optional (may come from peer's mr_info). 
+ int32_t register_remote_memory_region(const json& mr_info, + std::optional name = std::nullopt); + int32_t unregister_remote_memory_region(int32_t handle); + + TcpMr get_mr_fast(int32_t handle) const; + TcpMr get_remote_mr_fast(int32_t handle) const; + int32_t get_mr_handle(const std::string& name) const; + int32_t get_remote_mr_handle(const std::string& name) const; + + json mr_info() const; + json remote_mr_info() const; + +private: + // local MRs + std::unordered_map name_to_handle_; + std::unordered_map ptr_to_handle_; + std::vector handle_to_mr_; + + // remote MRs + std::unordered_map remote_name_to_handle_; + std::vector remote_handle_to_mr_; + std::vector remote_handle_to_name_; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_op_state.h b/dlslime/csrc/engine/tcp/tcp_op_state.h new file mode 100644 index 00000000..dbf89a2a --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_op_state.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include + +#include "dlslime/csrc/device/device_api.h" +#include "dlslime/csrc/device/signal.h" + +namespace dlslime { +namespace tcp { + +enum Status : int32_t { + TCP_SUCCESS = 0, + TCP_FAILED = -1, + TCP_TIMEOUT = -2, + TCP_CLOSED = -3, +}; + +// One per in-flight operation. The io_context thread (or caller for sync +// ops) signals completion via DeviceSignal; the future's wait() spins on +// wait_comm_done_cpu(). 
+struct TcpOpState { + std::shared_ptr signal; + uint32_t expected_mask{1}; + std::atomic completion_mask{0}; + std::atomic completion_status{TCP_SUCCESS}; + + uintptr_t user_buffer{0}; + size_t user_length{0}; + size_t bytes_copied{0}; + + static std::shared_ptr create() { + auto s = std::make_shared(); + s->signal = dlslime::device::createSignal(false); + return s; + } +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_session.cpp b/dlslime/csrc/engine/tcp/tcp_session.cpp new file mode 100755 index 00000000..994a2a46 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_session.cpp @@ -0,0 +1,252 @@ +#include "tcp_session.h" + +#include +#include +#include + +#include +#include + +#include "dlslime/csrc/logging.h" + +#ifdef USE_CUDA +#include +#endif + +namespace dlslime { +namespace tcp { + +// ── helpers ───────────────────────────────────────────── + +static void hdr_to_net(SessionHeader& hdr) { + hdr.size = htole64(hdr.size); + hdr.addr = htole64(hdr.addr); +} + +static void hdr_to_host(SessionHeader& hdr) { + hdr.size = le64toh(hdr.size); + hdr.addr = le64toh(hdr.addr); +} + +static bool is_fatal(asio::error_code ec) { + return ec && ec != asio::error::eof; +} + +#ifdef USE_CUDA +static bool is_cuda_memory(const void* addr) { + cudaPointerAttributes attr; + auto st = cudaPointerGetAttributes(&attr, addr); + return (st == cudaSuccess && attr.type == cudaMemoryTypeDevice); +} +#endif + +// ── ServerSession ─────────────────────────────────────── + +ServerSession::ServerSession(asio::ip::tcp::socket socket, + TcpMemoryPool* local_pool, + RecvMatcher recv_matcher) + : socket_(std::move(socket)) + , local_pool_(local_pool) + , recv_matcher_(std::move(recv_matcher)) {} + +void ServerSession::start() { + readHeader(); +} + +void ServerSession::readHeader() { + auto self = shared_from_this(); + header_ = {}; + asio::async_read(socket_, asio::buffer(&header_, sizeof(header_)), + [this, self](asio::error_code ec, size_t /*n*/) { + if 
(ec) { + if (is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::readHeader ", ec.message()); + return; + } + hdr_to_host(header_); + dispatch(); + }); +} + +void ServerSession::dispatch() { + switch (header_.opcode) { + + case OP_SEND: { + if (header_.size == 0) { readHeader(); return; } + auto slot = recv_matcher_(); + if (!slot.buffer || slot.length == 0) { + SLIME_LOG_WARN("ServerSession: OP_SEND with no pending recv"); + readHeader(); + return; + } + if (slot.exact_size && header_.size != slot.length) { + SLIME_LOG_WARN("ServerSession: size mismatch, send ", header_.size, + " != recv ", slot.length); + if (slot.op_state) { + slot.op_state->completion_status.store( + TCP_FAILED, std::memory_order_release); + if (slot.op_state->signal) + slot.op_state->signal->set_comm_done(0); + } + readHeader(); + return; + } + + // Always drain the full send payload from the wire. If recv buffer + // is smaller, read into a temp buffer then copy what fits. + size_t n_read = static_cast(header_.size); + size_t n_copy = std::min(n_read, slot.length); + auto* dst = reinterpret_cast(slot.buffer); + bool overflow = false; + + if (header_.size > slot.length) { + dst = new char[n_read]; + overflow = true; + } + + auto self = shared_from_this(); + asio::async_read(socket_, asio::buffer(dst, n_read), + [this, self, slot, n_copy, dst, overflow]( + asio::error_code ec, size_t /*rn*/) { + if (ec) { + if (is_fatal(ec)) + SLIME_LOG_WARN("ServerSession SEND read: ", ec.message()); + if (overflow) delete[] dst; + return; + } + if (overflow) { + std::memcpy(reinterpret_cast(slot.buffer), dst, n_copy); + delete[] dst; + } + if (slot.post_read) slot.post_read(); + if (slot.op_state) { + slot.op_state->bytes_copied = n_copy; + slot.op_state->completion_status.store( + TCP_SUCCESS, std::memory_order_release); + if (slot.op_state->signal) + slot.op_state->signal->set_comm_done(0); + } + readHeader(); + }); + break; + } + + case OP_WRITE: + if (header_.size == 0) { readHeader(); return; } + 
readBody(reinterpret_cast(header_.addr), header_.size); + break; + + case OP_READ: { + uintptr_t addr = static_cast(header_.addr); + size_t sz = static_cast(header_.size); + if (sz == 0) { readHeader(); return; } + writeBody(reinterpret_cast(addr), sz); + break; + } + + default: + SLIME_LOG_WARN("ServerSession: unknown opcode ", + static_cast(header_.opcode)); + readHeader(); + break; + } +} + +void ServerSession::readBody(void* dst, size_t len) { + auto* ptr = static_cast(dst); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(dst)) { + ptr = new char[len]; + is_cuda = true; + } +#endif + + auto self = shared_from_this(); + asio::async_read(socket_, asio::buffer(ptr, len), + [this, self, real_addr = reinterpret_cast(dst), + len, is_cuda, ptr](asio::error_code ec, size_t /*n*/) { + if (ec) { + if (is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::readBody ", ec.message()); + if (is_cuda) delete[] ptr; + return; + } +#ifdef USE_CUDA + if (is_cuda) { + auto cu_err = cudaMemcpy(reinterpret_cast(real_addr), ptr, + len, cudaMemcpyHostToDevice); + if (cu_err != cudaSuccess) + SLIME_LOG_ERROR("readBody cudaMemcpy H2D: ", cudaGetErrorString(cu_err)); + delete[] ptr; + } +#endif + readHeader(); + }); +} + +void ServerSession::writeBody(const void* src, size_t len) { + auto* ptr = static_cast(src); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(src)) { + auto* buf = new char[len]; + auto cu_err = cudaMemcpy(buf, src, len, cudaMemcpyDeviceToHost); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("writeBody cudaMemcpy D2H: ", cudaGetErrorString(cu_err)); + delete[] buf; + ptr = static_cast(src); + } else { + ptr = buf; + is_cuda = true; + } + } +#endif + + auto self = shared_from_this(); + asio::async_write(socket_, asio::buffer(ptr, len), + [this, self, is_cuda, ptr](asio::error_code ec, size_t /*n*/) { + if (is_cuda) delete[] ptr; + if (ec && is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::writeBody ", ec.message()); + readHeader(); + }); +} + +// ── 
ClientSession ─────────────────────────────────────── + +ClientSession::ClientSession(asio::ip::tcp::socket sock, DoneCallback on_done) + : socket_(std::move(sock)), on_done_(std::move(on_done)) {} + +void ClientSession::start_write(const SessionHeader& hdr, const void* payload) { + auto self = shared_from_this(); + SessionHeader net = hdr; + hdr_to_net(net); + std::array bufs = { + asio::buffer(&net, sizeof(net)), + asio::buffer(payload, hdr.size) + }; + asio::async_write(socket_, bufs, + [this, self](asio::error_code ec, size_t) { + if (on_done_) on_done_(ec); + }); +} + +void ClientSession::start_read(const SessionHeader& hdr, void* dst) { + auto self = shared_from_this(); + hdr_ = hdr; + SessionHeader net = hdr; + hdr_to_net(net); + asio::async_write(socket_, asio::buffer(&net, sizeof(net)), + [this, self, dst](asio::error_code ec, size_t) { + if (ec) { if (on_done_) on_done_(ec); return; } + asio::async_read(socket_, + asio::buffer(dst, hdr_.size), + [this, self](asio::error_code ec, size_t) { + if (on_done_) on_done_(ec); + }); + }); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/tcp_session.h b/dlslime/csrc/engine/tcp/tcp_session.h new file mode 100644 index 00000000..2e70e7d8 --- /dev/null +++ b/dlslime/csrc/engine/tcp/tcp_session.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "tcp_header.h" +#include "tcp_memory_pool.h" +#include "tcp_op_state.h" + +namespace dlslime { +namespace tcp { + +class TcpConnectionPool; + +struct RecvSlot { + uintptr_t buffer{0}; + size_t length{0}; + std::shared_ptr op_state; + std::function post_read; + bool exact_size{false}; // reject send size != recv size +}; + +// ── ServerSession: handles incoming requests on one persistent connection ── +// +// Lifecycle: start() → readHeader → dispatch → readBody/writeBody → readHeader ↻ +class ServerSession : public std::enable_shared_from_this { +public: + using RecvMatcher = std::function; + + 
ServerSession(asio::ip::tcp::socket socket, + TcpMemoryPool* local_pool, + RecvMatcher recv_matcher); + + void start(); + +private: + void readHeader(); + void dispatch(); + void readBody(void* dst, size_t len); + void writeBody(const void* src, size_t len); + + asio::ip::tcp::socket socket_; + TcpMemoryPool* local_pool_; + RecvMatcher recv_matcher_; + SessionHeader header_{}; +}; + +// ── ClientSession: drives one outbound I/O operation ───── +// +// Lifecycle: construct → start_write/start_read → on_done → self-destruct +// Does NOT own OpState or PooledConnection — only drives the I/O and reports ec. +class ClientSession : public std::enable_shared_from_this { +public: + using DoneCallback = std::function; + + ClientSession(asio::ip::tcp::socket sock, DoneCallback on_done); + + // Write header + payload to socket (gather async_write). + void start_write(const SessionHeader& hdr, const void* payload); + + // Write OP_READ header → read raw response into dst. + void start_read(const SessionHeader& hdr, void* dst); + +private: + asio::ip::tcp::socket socket_; + DoneCallback on_done_; + SessionHeader hdr_{}; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/csrc/engine/tcp/test_tcp_endpoint.py b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py new file mode 100755 index 00000000..8bab9a8b --- /dev/null +++ b/dlslime/csrc/engine/tcp/test_tcp_endpoint.py @@ -0,0 +1,596 @@ +"""End-to-end test for TcpEndpoint v3 async primitives with timeout. + +Usage: + LD_LIBRARY_PATH=dlslime PYTHONPATH=. 
DLSLIME_LOG_LEVEL=0 python3 \ + dlslime/csrc/engine/tcp/test_tcp_endpoint.py +""" + +import ctypes +import os +import threading +import time + +from dlslime import TcpEndpoint, TcpMemoryPool + +# ── optional torch / CUDA support ──────────────────────── + +_HAS_TORCH = False +_HAS_CUDA = False + +try: + import torch + + _HAS_TORCH = True + _HAS_CUDA = torch.cuda.is_available() +except Exception: + pass + +_CUDA_FORCE_OFF = os.environ.get("DLSLIME_TCP_TEST_CUDA", "") in ("0", "false", "no") + + +def _torch_skip(): + return not _HAS_TORCH + + +def _cuda_skip(): + if _CUDA_FORCE_OFF: + return True + return not _HAS_CUDA + + +# ── test harness ───────────────────────────────────────── + +def _sync_run(name, fn_a, fn_b, timeout=120): + err = [] + b = threading.Barrier(2) + + def wrap(fn): + try: + b.wait(10) + fn() + except Exception as e: + err.append(e) + + ta = threading.Thread(target=wrap, args=(fn_a,), daemon=False) + tb = threading.Thread(target=wrap, args=(fn_b,), daemon=False) + ta.start() + tb.start() + ta.join(timeout) + tb.join(timeout) + if ta.is_alive() or tb.is_alive(): + raise RuntimeError(f"{name} FAIL!{timeout}s timeout!") + if len(err) > 0: + print(f"{name} FAIL {err}", flush=True) + return False + else: + print(f"{name} SUCC ", flush=True) + return True + + +# ── ctypes-based tests ─────────────────────────────────── + +def test_async_send_recv(): + buf_a = ctypes.create_string_buffer(128) + buf_b = ctypes.create_string_buffer(128) + + ep_a = TcpEndpoint(port=10001) + ep_b = TcpEndpoint(port=10002) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + + def run_a(): + ep_a.connect(info_b) + ctypes.memmove(ctypes.addressof(buf_a), b"hello", 5) + time.sleep(5) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + st = ep_a.async_recv((ctypes.addressof(buf_a), 5, 5)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + if bytes(buf_a[5:10]) != b"world": + raise 
RuntimeError(f"data: {bytes(buf_a[5:10])}") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + if bytes(buf_b[:5]) != b"hello": + raise RuntimeError(f"data: {bytes(buf_b[:5])}") + ctypes.memmove(ctypes.addressof(buf_b), b"world", 5) + time.sleep(5) + st = ep_b.async_send((ctypes.addressof(buf_b), 0, 5)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + ep_b.shutdown() + + _sync_run("test_async_send_recv", run_a, run_b) + + +def test_async_send2recv(): + buf_a = ctypes.create_string_buffer(32) + buf_b = ctypes.create_string_buffer(32) + + ep_a = TcpEndpoint(port=10401) + ep_b = TcpEndpoint(port=10402) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + + def run_a(): + ep_a.connect(info_b) + ctypes.memmove(ctypes.addressof(buf_a), b"one", 3) + time.sleep(5) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 3)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 3)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + if bytes(buf_b[:3]) != b"one": + raise RuntimeError(f"data: {bytes(buf_b[:3])}") + ep_b.shutdown() + + _sync_run("test_async_send_recv_one", run_a, run_b) + + +def test_async_write(): + buf_a = ctypes.create_string_buffer(256) + buf_b = ctypes.create_string_buffer(256) + addr_a = ctypes.addressof(buf_a) + + ep_a = TcpEndpoint(port=10003) + ep_b = TcpEndpoint(port=10004) + h_a = ep_a.register_memory_region("a", addr_a, 0, 256) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 256) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + + test_data = b"hello_from_a" + + def run_a(): + ep_a.connect(info_b) + ctypes.memmove(addr_a, test_data, len(test_data)) + st = ep_a.async_write([(h_a, 
h_br, 0, 0, len(test_data))]).wait() + if st != 0: + raise RuntimeError(f"write: {st}") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + for _ in range(50): + if bytes(buf_b[:len(test_data)]) == test_data: + break + time.sleep(0.5) + if bytes(buf_b[:len(test_data)]) != test_data: + raise RuntimeError(f"B write not received in {50 * 0.5}s") + ep_b.shutdown() + + _sync_run("test_async_write", run_a, run_b) + + +def test_async_read(): + buf_a = ctypes.create_string_buffer(256) + buf_b = ctypes.create_string_buffer(256) + addr_a = ctypes.addressof(buf_a) + + ep_a = TcpEndpoint(port=10005) + ep_b = TcpEndpoint(port=10006) + h_a = ep_a.register_memory_region("a", addr_a, 0, 256) + h_b = ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 256) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + + test_data = b"hello_from_b" + ctypes.memmove(ctypes.addressof(buf_b), test_data, 12) + + def run_a(): + ep_a.connect(info_b) + st = ep_a.async_read([(h_a, h_br, 0, 0, len(test_data))]).wait() + if st != 0: + raise RuntimeError(f"read: {st}") + if bytes(buf_a[:len(test_data)]) != test_data: + raise RuntimeError("read data mismatch") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + time.sleep(25) + ep_b.shutdown() + + _sync_run("test_async_read", run_a, run_b) + + +# ── skip test ── + + +def test_recv_timeout(): + buf_a = ctypes.create_string_buffer(32) + + ep_a = TcpEndpoint(port=10007) + ep_b = TcpEndpoint(port=10008) + + def run_b(): + ep_b.connect(ep_a.endpoint_info()) + time.sleep(1.0) + ep_b.shutdown() + + def run_a(): + ep_a.connect(ep_b.endpoint_info()) + fut = ep_a.async_recv((ctypes.addressof(buf_a), 0, 5)) + result = fut.wait_for(0.3) + if result is not None: + raise RuntimeError(f"expected None, got {result}") + ep_a.shutdown() + + _sync_run("test_recv_timeout", run_a, run_b) + + +def test_send_timeout_ms(): + buf_a = ctypes.create_string_buffer(64) + 
buf_b = ctypes.create_string_buffer(64) + + ep_a = TcpEndpoint(port=10009) + ep_b = TcpEndpoint(port=10010) + + def run_b(): + ep_b.connect(ep_a.endpoint_info()) + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + ep_b.shutdown() + + def run_a(): + ep_a.connect(ep_b.endpoint_info()) + ctypes.memmove(ctypes.addressof(buf_a), b"world", 5) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5), timeout_ms=10000).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + ep_a.shutdown() + + _sync_run("test_send_timeout_ms", run_a, run_b) + + +def test_default_timeout(): + buf_a = ctypes.create_string_buffer(32) + buf_b = ctypes.create_string_buffer(32) + + ep_a = TcpEndpoint(port=10011) + ep_b = TcpEndpoint(port=10012) + + def run_b(): + ep_b.connect(ep_a.endpoint_info()) + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + ep_b.shutdown() + + def run_a(): + ep_a.connect(ep_b.endpoint_info()) + ctypes.memmove(ctypes.addressof(buf_a), b"test!", 5) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + ep_a.shutdown() + + _sync_run("test_default_timeout", run_a, run_b) + + +def test_exact_size_mismatch(): + buf_a = ctypes.create_string_buffer(32) + buf_b = ctypes.create_string_buffer(32) + + ep_a = TcpEndpoint(port=10011) + ep_b = TcpEndpoint(port=10012) + + def run_b(): + ep_b.connect(ep_a.endpoint_info()) + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 4), exact_size=True).wait() + if st != -1: + raise RuntimeError(f"expected TCP_FAILED(-1), got {st}") + ep_b.shutdown() + + def run_a(): + ep_a.connect(ep_b.endpoint_info()) + ctypes.memmove(ctypes.addressof(buf_a), b"overflow", 8) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 8)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + ep_a.shutdown() + + _sync_run("test_exact_size_mismatch", run_a, run_b) + + +def 
test_overflow_truncate(): + buf_a = ctypes.create_string_buffer(64) + buf_b = ctypes.create_string_buffer(64) + + ep_a = TcpEndpoint(port=10013) + ep_b = TcpEndpoint(port=10014) + + def run_b(): + ep_b.connect(ep_a.endpoint_info()) + st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 4)).wait() + if st != 0: + raise RuntimeError(f"recv1: {st}") + if bytes(buf_b[:4]) != b"LONG": + raise RuntimeError(f"truncated: {bytes(buf_b[:4])}") + st = ep_b.async_recv((ctypes.addressof(buf_b), 4, 5)).wait() + if st != 0: + raise RuntimeError(f"recv2: {st}") + if bytes(buf_b[4:9]) != b"HELLO": + raise RuntimeError(f"follow-up: {bytes(buf_b[4:9])}") + ep_b.shutdown() + + def run_a(): + ep_a.connect(ep_b.endpoint_info()) + ctypes.memmove(ctypes.addressof(buf_a), b"LONGDATA", 8) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 8)).wait() + if st != 0: + raise RuntimeError(f"send1: {st}") + ctypes.memmove(ctypes.addressof(buf_a), b"HELLO", 5) + st = ep_a.async_send((ctypes.addressof(buf_a), 0, 5)).wait() + if st != 0: + raise RuntimeError(f"send2: {st}") + ep_a.shutdown() + + _sync_run("test_overflow_truncate", run_a, run_b) + + +def test_mr_name_validation(): + ep = TcpEndpoint(port=0) + buf = ctypes.create_string_buffer(32) + + h = ep.register_memory_region("valid", ctypes.addressof(buf), 0, 32) + if h < 0: + raise RuntimeError(f"valid name: {h}") + + h = ep.register_memory_region("", ctypes.addressof(buf), 0, 32) + if h != -1: + raise RuntimeError(f"empty name should return -1, got {h}") + + h = ep.register_memory_region("valid", ctypes.addressof(buf), 0, 32) + if h != -1: + raise RuntimeError(f"duplicate name should return -1, got {h}") + + ep.shutdown() + + +def test_connect_unreachable(): + ep = TcpEndpoint(port=10015) + unreachable = {"host": "127.0.0.1", "port": 65535, "mr_info": {}} + ep.connect(unreachable) + if ep.is_connected(): + raise RuntimeError("should not be connected") + ep.shutdown() + + +# ── parameterized torch tests (device="cpu" or "cuda") ── + + +def 
_make_tensor(shape, device, dtype, **kw): + """Create a tensor on the given device. CPU tensor for recv on cuda path + uses ctypes buffer so data_ptr() gives host pointer (needed for cudaMemcpy).""" + return torch.randn(shape, dtype=dtype, + device=device if isinstance(device, torch.device) else torch.device(device), + **kw) + + +def test_torch_send_recv(device="cpu", dtype=torch.float32): + """Round-trip: A send full → B recv → B send slice → A recv.""" + SZ, SL = 32, 5 # elements + t_a = _make_tensor(SZ, device, dtype) + t_b = _make_tensor(SZ, device, dtype) + expected = t_a.clone() + n_bytes = SZ * 4 + sl_bytes = SL * 4 + + ep_a = TcpEndpoint(port=10101) + ep_b = TcpEndpoint(port=10102) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + + def run_a(): + ep_a.connect(info_b) + time.sleep(5) + st = ep_a.async_send((t_a.data_ptr(), 0, n_bytes)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + st = ep_a.async_recv((t_a.data_ptr(), 10 * 4, sl_bytes)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + if not torch.equal(t_a[10:15], t_b[20:25]): + raise RuntimeError("slice mismatch") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + st = ep_b.async_recv((t_b.data_ptr(), 0, n_bytes)).wait() + if st != 0: + raise RuntimeError(f"recv: {st}") + if not torch.equal(expected, t_b): + raise RuntimeError("full tensor mismatch") + time.sleep(5) + st = ep_b.async_send((t_b.data_ptr(), 20 * 4, sl_bytes)).wait() + if st != 0: + raise RuntimeError(f"send: {st}") + ep_b.shutdown() + + _sync_run(f"test_torch_send_recv_{device}", run_a, run_b, 120) + + +def test_torch_write(device="cpu", dtype=torch.float32): + """One-sided write: A async_write → B verifies data received.""" + SZ = 64 + t_a = _make_tensor(SZ, device, dtype) + t_b = _make_tensor(SZ, device, dtype) + expected = t_a.clone() + + n_bytes = SZ * 4 + + ep_a = TcpEndpoint(port=10103) + ep_b = TcpEndpoint(port=10104) + h_a = ep_a.register_memory_region("a", t_a.data_ptr(), 0, n_bytes) 
+ h_b = ep_b.register_memory_region("b", t_b.data_ptr(), 0, n_bytes) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + + def run_a(): + ep_a.connect(info_b) + st = ep_a.async_write([(h_a, h_br, 0, 0, n_bytes)]).wait() + if st != 0: + raise RuntimeError(f"write: {st}") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + for _ in range(40): + if torch.equal(expected, t_b): + break + time.sleep(0.5) + if not torch.equal(expected, t_b): + raise RuntimeError("write data not received") + ep_b.shutdown() + + _sync_run(f"test_torch_write_{device}", run_a, run_b) + + +def test_torch_read(device="cpu", dtype=torch.float32): + """One-sided read: B buffer pre-filled, A async_read and verifies.""" + dsize = 4 + SZ = 64 + t_a = _make_tensor(SZ, device, dtype) + t_b = _make_tensor(SZ, device, dtype) + expected = t_b.clone() + + n_bytes = SZ * dsize + + ep_a = TcpEndpoint(port=10105) + ep_b = TcpEndpoint(port=10106) + h_a = ep_a.register_memory_region("a", t_a.data_ptr(), 0, n_bytes) + h_b = ep_b.register_memory_region("b", t_b.data_ptr(), 0, n_bytes) + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"]) + + def run_a(): + ep_a.connect(info_b) + st = ep_a.async_read([(h_a, h_br, 0, 0, n_bytes)]).wait() + if st != 0: + raise RuntimeError(f"read: {st}") + if not torch.equal(t_a, expected): + raise RuntimeError("read data mismatch") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + time.sleep(20) + ep_b.shutdown() + + _sync_run(f"test_torch_read_{device}", run_a, run_b) + + +def test_torch_write_batch(device="cpu", dtype=torch.float32, n_batch=4): + """One async_write with multiple assignments.""" + dsize = 4 + SZ = 64 + t_a_batch = [_make_tensor(SZ, device, dtype) for i in range(n_batch)] + t_b_batch = [_make_tensor(SZ, device, dtype) for i in range(n_batch)] + expected_batch = [i.clone() for i 
in t_a_batch] + + n_bytes = SZ * dsize + + ep_a = TcpEndpoint(port=10107) + ep_b = TcpEndpoint(port=10108) + h_a_batch = [ep_a.register_memory_region(f"a_{i}", t_a_batch[i].data_ptr(), 0, n_bytes) for i in range(n_batch)] + h_b_batch = [ep_b.register_memory_region(f"b_{i}", t_b_batch[i].data_ptr(), 0, n_bytes) for i in range(n_batch)] + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br_batch = [ep_a.register_remote_memory_region(f"rb_{i}", info_b["mr_info"][f"b_{i}"]) for i in range(n_batch)] + + def run_a(): + ep_a.connect(info_b) + assigns = [(h_a_batch[i], h_br_batch[i], i * dsize, i * dsize, dsize) for i in range(n_batch)] + st = ep_a.async_write(assigns).wait() + if st != 0: + raise RuntimeError(f"write batch: {st}") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + time.sleep(3) + for i in range(n_batch): + time.sleep(2) + if not torch.equal(t_b_batch[i][i], expected_batch[i][i]): + raise RuntimeError(f"batch {i}: mismatch") + ep_b.shutdown() + + _sync_run(f"test_torch_write_batch_{device}", run_a, run_b) + + +def test_torch_read_batch(device="cpu", dtype=torch.float32, n_batch=4): + """One async_read with multiple assignments.""" + dsize = 4 + SZ = 64 + t_a_batch = [_make_tensor(SZ, device, dtype) for i in range(n_batch)] + t_b_batch = [_make_tensor(SZ, device, dtype) for i in range(n_batch)] + expected_batch = [i.clone() for i in t_b_batch] + + n_bytes = SZ * dsize + + ep_a = TcpEndpoint(port=10109) + ep_b = TcpEndpoint(port=10110) + h_a_batch = [ep_a.register_memory_region(f"a_{i}", t_a_batch[i].data_ptr(), 0, n_bytes) for i in range(n_batch)] + h_b_batch = [ep_b.register_memory_region(f"b_{i}", t_b_batch[i].data_ptr(), 0, n_bytes) for i in range(n_batch)] + info_a = ep_a.endpoint_info() + info_b = ep_b.endpoint_info() + h_br_batch = [ep_a.register_remote_memory_region(f"rb_{i}", info_b["mr_info"][f"b_{i}"]) for i in range(n_batch)] + + def run_a(): + ep_a.connect(info_b) + assigns = [(h_a_batch[i], h_br_batch[i], i * dsize, i 
* dsize, dsize) for i in range(n_batch)] + st = ep_a.async_read(assigns).wait() + if st != 0: + raise RuntimeError(f"read batch: {st}") + ep_a.shutdown() + + def run_b(): + ep_b.connect(info_a) + time.sleep(3) + for i in range(n_batch): + time.sleep(2) + if not torch.equal(t_a_batch[i][i], expected_batch[i][i]): + raise RuntimeError(f"batch {i}: mismatch") + ep_b.shutdown() + + _sync_run(f"test_torch_read_batch_{device}", run_a, run_b) + + +# ── main ───────────────────────────────────────────────── + +if __name__ == "__main__": + test_async_send_recv() + test_async_send2recv() + test_async_write() + test_async_read() + + if not _torch_skip(): + device_list = ["cpu", "cuda"] + if _cuda_skip(): + print("No Cuda, Skip", flush = True) + device_list = ["cpu", ] + + for dev in device_list: + test_torch_send_recv(dev) + test_torch_write(dev) + test_torch_read(dev) + test_torch_write_batch(dev) + test_torch_read_batch(dev) diff --git a/dlslime/csrc/python/CMakeLists.txt b/dlslime/csrc/python/CMakeLists.txt index 389be03a..1584babc 100755 --- a/dlslime/csrc/python/CMakeLists.txt +++ b/dlslime/csrc/python/CMakeLists.txt @@ -67,6 +67,11 @@ if (BUILD_ASCEND_DIRECT) ) endif() +if (BUILD_TCP) + target_compile_definitions(_slime_c PRIVATE BUILD_TCP) + list(APPEND _slime_c_link_libraries _slime_tcp) +endif() + # Ops moved to NanoCCL - link to NanoCCL if needed # if (BUILD_INTRA_OPS OR BUILD_INTER_OPS) # if (BUILD_INTRA_OPS) diff --git a/dlslime/csrc/python/bind.cpp b/dlslime/csrc/python/bind.cpp index b5a56591..93eb641d 100644 --- a/dlslime/csrc/python/bind.cpp +++ b/dlslime/csrc/python/bind.cpp @@ -27,6 +27,12 @@ #include "dlslime/csrc/engine/ascend_direct/ascend_remote_memory_pool.h" #endif +#ifdef BUILD_TCP +#include "dlslime/csrc/engine/tcp/tcp_endpoint.h" +#include "dlslime/csrc/engine/tcp/tcp_future.h" +#include "dlslime/csrc/engine/tcp/tcp_memory_pool.h" +#endif + #include "dlslime/csrc/device/signal.h" #ifdef BUILD_RDMA @@ -89,6 +95,12 @@ namespace py = pybind11; #define 
BUILD_RPC_ENABLED false #endif +#ifdef BUILD_TCP +#define BUILD_TCP_ENABLED true +#else +#define BUILD_TCP_ENABLED false +#endif + // Ops moved to NanoCCL #define BUILD_INTRA_OPS_ENABLED false #define BUILD_INTER_OPS_ENABLED false @@ -102,6 +114,7 @@ PYBIND11_MODULE(_slime_c, m) EXPOSE_BUILD_FLAG(m, BUILD_INTRA_OPS); EXPOSE_BUILD_FLAG(m, BUILD_INTER_OPS); EXPOSE_BUILD_FLAG(m, BUILD_RPC); + EXPOSE_BUILD_FLAG(m, BUILD_TCP); m.def("discover_topology", &dlslime::topology::discoverTopology, @@ -512,6 +525,112 @@ PYBIND11_MODULE(_slime_c, m) py::call_guard()); #endif +#ifdef BUILD_TCP + // ========================================================================= + // TCP Transport + // ========================================================================= + py::class_>( + m, "SlimeTcpSendFuture") + .def("wait", &dlslime::tcp::TcpSendFuture::wait, + py::call_guard()) + .def("wait_for", + [](const dlslime::tcp::TcpSendFuture& self, double sec) -> py::object { + int32_t st = 0; + int64_t ms = static_cast(sec * 1000.0); + if (ms < 0) ms = 0; + if (self.wait_for(ms, &st)) return py::cast(st); + return py::none(); + }, + py::arg("timeout_seconds")); + + py::class_>( + m, "SlimeTcpRecvFuture") + .def("wait", &dlslime::tcp::TcpRecvFuture::wait, + py::call_guard()) + .def("wait_for", + [](const dlslime::tcp::TcpRecvFuture& self, double sec) -> py::object { + int32_t st = 0; + int64_t ms = static_cast(sec * 1000.0); + if (ms < 0) ms = 0; + if (self.wait_for(ms, &st)) return py::cast(st); + return py::none(); + }, + py::arg("timeout_seconds")); + + py::class_>( + m, "SlimeTcpReadWriteFuture") + .def("wait", &dlslime::tcp::TcpReadWriteFuture::wait, + py::call_guard()) + .def("wait_for", + [](const dlslime::tcp::TcpReadWriteFuture& self, double sec) -> py::object { + int32_t st = 0; + int64_t ms = static_cast(sec * 1000.0); + if (ms < 0) ms = 0; + if (self.wait_for(ms, &st)) return py::cast(st); + return py::none(); + }, + py::arg("timeout_seconds")); + + py::class_>( + m, 
"TcpMemoryPool") + .def(py::init<>()) + .def("register_memory_region", + &dlslime::tcp::TcpMemoryPool::register_memory_region, + py::arg("addr"), py::arg("length"), py::arg("name")) + .def("register_remote_memory_region", + [](dlslime::tcp::TcpMemoryPool& self, const json& mr_info, + py::object name_obj) { + std::optional name; + if (!name_obj.is_none()) name = name_obj.cast(); + return self.register_remote_memory_region(mr_info, name); + }, + py::arg("mr_info"), py::arg("name") = py::none()) + .def("get_mr_handle", &dlslime::tcp::TcpMemoryPool::get_mr_handle, + py::arg("name")) + .def("mr_info", &dlslime::tcp::TcpMemoryPool::mr_info); + + py::class_>( + m, "TcpEndpoint") + .def(py::init(), + py::arg("ip") = "0.0.0.0", py::arg("port") = 0) + .def("connect", &dlslime::tcp::TcpEndpoint::connect, + py::arg("remote_info"), py::call_guard()) + .def("endpoint_info", &dlslime::tcp::TcpEndpoint::endpoint_info) + .def("mr_info", &dlslime::tcp::TcpEndpoint::mr_info) + .def("shutdown", &dlslime::tcp::TcpEndpoint::shutdown, + py::call_guard()) + .def("is_connected", &dlslime::tcp::TcpEndpoint::is_connected) + .def("register_memory_region", + &dlslime::tcp::TcpEndpoint::register_memory_region, + py::arg("name"), py::arg("data_ptr"), py::arg("offset"), py::arg("length"), + py::call_guard()) + .def("register_remote_memory_region", + &dlslime::tcp::TcpEndpoint::register_remote_memory_region, + py::arg("name"), py::arg("mr_info"), py::call_guard()) + .def("async_send", + py::overload_cast( + &dlslime::tcp::TcpEndpoint::async_send), + py::arg("chunk"), + py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, + py::call_guard()) + .def("async_recv", + &dlslime::tcp::TcpEndpoint::async_recv, + py::arg("chunk"), py::arg("exact_size") = false, + py::call_guard()) + .def("async_read", + py::overload_cast&, int64_t>( + &dlslime::tcp::TcpEndpoint::async_read), + py::arg("assign"), + py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, + py::call_guard()) + 
.def("async_write", + py::overload_cast&, int64_t>( + &dlslime::tcp::TcpEndpoint::async_write), + py::arg("assign"), + py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, + py::call_guard()); +#endif // BUILD_TCP + // Ops moved to NanoCCL - Python bindings should be in NanoCCL's Python module // #ifdef BUILD_INTRA_OPS // py::class_(m, "AllToAllIntraLLBuffer")