From 9811b3445471e8f3cec4f713852d27eaf6adf323 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 14:59:17 -0700 Subject: [PATCH 01/39] EP PyTorch: NCCL EP backend + autograd ops + tests, route zero_copy via cfg.zero_copy Signed-off-by: Phuong Nguyen --- build_tools/pytorch.py | 10 + examples/pytorch/ep/bench/ep_bench.py | 398 ++++++++++ examples/pytorch/ep/bench/run_ep_bench.sh | 72 ++ .../pytorch/ep/bench/run_nccl_ep_bench.sh | 62 ++ examples/pytorch/ep/ep_moe.py | 228 ++++++ examples/pytorch/ep/run_test_ep.sh | 37 + tests/pytorch/distributed/run_ep.py | 338 ++++++++ tests/pytorch/distributed/run_test_ep.sh | 55 ++ tests/pytorch/distributed/test_ep.py | 31 + transformer_engine/pytorch/csrc/extensions.h | 37 + .../pytorch/csrc/extensions/ep.cpp | 331 ++++++++ .../pytorch/csrc/extensions/pybind.cpp | 4 + transformer_engine/pytorch/ep.py | 734 ++++++++++++++++++ 13 files changed, 2337 insertions(+) create mode 100644 examples/pytorch/ep/bench/ep_bench.py create mode 100755 examples/pytorch/ep/bench/run_ep_bench.sh create mode 100755 examples/pytorch/ep/bench/run_nccl_ep_bench.sh create mode 100644 examples/pytorch/ep/ep_moe.py create mode 100755 examples/pytorch/ep/run_test_ep.sh create mode 100644 tests/pytorch/distributed/run_ep.py create mode 100755 tests/pytorch/distributed/run_test_ep.sh create mode 100644 tests/pytorch/distributed/test_ep.py create mode 100644 transformer_engine/pytorch/csrc/extensions/ep.cpp create mode 100644 transformer_engine/pytorch/ep.py diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index e2e6d09c29..ca54e72434 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -77,6 +77,16 @@ def setup_pytorch_extension( setup_mpi_flags(include_dirs, cxx_flags) + # Mirror the NCCL EP gate from setup.py / common CMake. When disabled, the + # ep.cpp source no-ops at the #ifdef boundary; without the define it would + # produce undefined references to nvte_ep_*. 
+ if bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))): + cxx_flags.append("-DNVTE_WITH_NCCL_EP") + # PyTorch's symm-mem headers gate the NCCL_HAS_SYMMEM_* feature macros on + # USE_NCCL. The EP extension shares the symm-mem NCCL comm with torch, so + # it needs those macros visible. + cxx_flags.append("-DUSE_NCCL") + library_dirs = [] libraries = [] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py new file mode 100644 index 0000000000..86217b7f91 --- /dev/null +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -0,0 +1,398 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""PyTorch EP perf bench: raw and autograd dispatch/combine on a single EP group. + +One process per GPU; launched via run_ep_bench.sh (torchrun). + +Stages (each timed in its own loop): + - dispatch_raw: _ep_dispatch_raw (no autograd, no prepare) + - ep_dispatch_fwd: ep_dispatch forward only + - ep_dispatch_fwd_bwd: ep_dispatch + backward on 0.5 * ||recv||^2 + - combine_raw: _ep_combine_raw (no autograd) + - ep_combine_fwd: ep_combine forward only + - ep_combine_fwd_bwd: ep_combine + backward + +ep_prepare runs once outside the timed loops. --kineto DIR dumps a Chrome +trace plus a per-kernel summary on rank 0. 
+""" + +import argparse +import gc +import os +import sys +import time +from contextlib import nullcontext + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpBuffer, + EpHandle, + ep_bootstrap, + ep_combine, + ep_dispatch, + ep_finalize, + ep_prepare, + _ep_combine_raw, + _ep_dispatch_raw, +) + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-PyTorch EP perf bench") + p.add_argument("--tokens-per-rank", type=int, default=8192) + p.add_argument("--hidden", type=int, default=7168) + p.add_argument("--top-k", type=int, default=8) + p.add_argument("--num-experts", type=int, default=256) + p.add_argument("--warmup", type=int, default=2) + p.add_argument("--iters", type=int, default=10) + p.add_argument( + "--max-num-sms", + type=int, + default=0, + help="Max SMs for dispatch/combine/preprocess kernels (0 = auto).", + ) + p.add_argument( + "--kineto", + default=None, + help="If set, dump a Kineto Chrome trace + per-kernel summary into this dir (rank 0).", + ) + p.add_argument( + "--cuda-graph", + action="store_true", + default=False, + help=( + "Capture each stage into a CUDA graph and time replay() instead of the eager call. " + "Raw + fwd-only stages use torch.cuda.graph; fwd+bwd stages use " + "torch.cuda.make_graphed_callables to capture forward and backward together." + ), + ) + p.add_argument( + "--mode-label", + default=None, + help="Optional suffix for NVTX range names (e.g. 
'fused' / 'unfused').", + ) + return p.parse_args() + + +def _nvtx_funcs(): + """Return push/pop helpers using torch.cuda.nvtx if available.""" + try: + push = torch.cuda.nvtx.range_push + pop = torch.cuda.nvtx.range_pop + return push, pop + except AttributeError: + return lambda _name: None, lambda: None + + +def _device_sm() -> int: + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def _make_inputs(rank, world_size, T, H, K, E, device): + """Round-robin identity routing + uniform top-k weights.""" + topk_idx = np.empty((T, K), dtype=np.int64) + for t in range(T): + for k in range(K): + topk_idx[t, k] = ((rank * T + t) * K + k) % E + rng = np.random.default_rng(seed=42 + rank) + tokens_np = (rng.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32) + return ( + torch.from_numpy(topk_idx).to(device), + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16), + torch.full((T, K), 1.0 / K, dtype=torch.float32, device=device), + ) + + +def _time_stage_us(name, fn, iters, nvtx_suffix, push, pop): + """Time fn for iters iterations after one untimed warmup; returns mean us.""" + # Run iters+1 times; drop the first (autotune outlier) and frame NVTX from iter 1. 
+ total_ns = 0 + counted = 0 + for i in range(iters + 1): + if i == 1: + push(f"{name}{nvtx_suffix}") + torch.cuda.synchronize() + t0 = time.perf_counter_ns() + fn() + torch.cuda.synchronize() + dt = time.perf_counter_ns() - t0 + if i == 0: + continue + total_ns += dt + counted += 1 + pop() + return total_ns / 1e3 / counted + + +def main(): + args = _parse_args() + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank))) + device = torch.device("cuda", torch.cuda.current_device()) + + if _device_sm() < 90: + if rank == 0: + print(f"[ep_bench] SKIPPED: EP requires SM>=90 (got SM{_device_sm()})") + dist.destroy_process_group() + return + if world_size < 4: + if rank == 0: + print(f"[ep_bench] SKIPPED: EP requires >=4 ranks (got {world_size})") + dist.destroy_process_group() + return + + ep_size = world_size + E = args.num_experts + assert E % ep_size == 0, f"num_experts ({E}) must be divisible by ep_size ({ep_size})" + num_local_experts = E // ep_size + T = args.tokens_per_rank + H = args.hidden + K = args.top_k + # Conservative cap: every token could land on every local expert. 
+ recv_pr = world_size * T * K // 2 + if rank == 0: + print( + f"[ep_bench] world={world_size} ep={ep_size} T={T} H={H} K={K} " + f"E={E} (local={num_local_experts}) recv_pr={recv_pr}" + + (f" mode={args.mode_label}" if args.mode_label else ""), + flush=True, + ) + + ep_group = dist.new_group(ranks=list(range(world_size)), backend="nccl") + ep_bootstrap( + ep_group, + num_experts=E, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=H, + max_num_sms=args.max_num_sms, + ) + + topk_idx, tokens_hbm, topk_w_hbm = _make_inputs(rank, world_size, T, H, K, E, device) + + handle = EpHandle( + top_k=K, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=H, + num_local_experts=num_local_experts, + ) + buffer = EpBuffer(handle) + + tokens = tokens_hbm + topk_w = topk_w_hbm + recv_tokens = torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device) + recv_w = torch.empty(recv_pr, dtype=torch.float32, device=device) + + # -- Prepare once outside the timed loops ------------------------------ + ep_prepare(handle, topk_idx) + torch.cuda.synchronize() + + # Pre-dispatch a steady recv_tokens / recv_w so combine stages have valid input. + _ep_dispatch_raw(handle, topk_idx, tokens, topk_w, recv_tokens, recv_w) + torch.cuda.synchronize() + # fp-equivalent stand-in for an MLP output. + expert_out = recv_tokens.clone() + + nvtx_suffix = f"[{args.mode_label}]" if args.mode_label else "" + push, pop = _nvtx_funcs() + + # -- Stage closures ---------------------------------------------------- + # Persistent fwd+bwd inputs (make_graphed_callables needs stable storage). + tokens_p = tokens.detach().clone().requires_grad_(True) + eo_p = recv_tokens.detach().clone().requires_grad_(True) + + # Stand-in callables; the cuda-graph branch below swaps in graphed versions. 
+ fwd_bwd_dispatch_fn = lambda x: ep_dispatch(handle, buffer, x, topk_idx, topk_w)[ # noqa: E731 + 0 + ] + fwd_bwd_combine_fn = lambda eo: ep_combine(handle, buffer, eo) # noqa: E731 + + def _dispatch_raw(): + _ep_dispatch_raw(handle, topk_idx, tokens, topk_w, recv_tokens, recv_w) + + def _combine_raw(): + out_buf = torch.empty(T, H, dtype=torch.bfloat16, device=device) + _ep_combine_raw(handle, expert_out, out_buf) + + def _ep_dispatch_fwd(): + ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + + def _ep_dispatch_fwd_bwd(): + tokens_p.grad = None + r = fwd_bwd_dispatch_fn(tokens_p) + (0.5 * (r * r).sum(dtype=torch.float32)).backward() + + def _ep_combine_fwd(): + ep_combine(handle, buffer, recv_tokens) + + def _ep_combine_fwd_bwd(): + eo_p.grad = None + out = fwd_bwd_combine_fn(eo_p) + (0.5 * (out * out).sum(dtype=torch.float32)).backward() + + stages = [ + ("dispatch_raw", _dispatch_raw, True), + ("ep_dispatch_fwd", _ep_dispatch_fwd, True), + ("ep_dispatch_fwd_bwd", _ep_dispatch_fwd_bwd, False), + ("combine_raw", _combine_raw, True), + ("ep_combine_fwd", _ep_combine_fwd, True), + ("ep_combine_fwd_bwd", _ep_combine_fwd_bwd, False), + ] + # Third tuple element: True = direct torch.cuda.graph capture; False = use + # make_graphed_callables (autograd-aware) instead. + + # -- Warmup ----------------------------------------------------------- + for _ in range(args.warmup): + for _name, fn, _capt in stages: + fn() + torch.cuda.synchronize() + + # -- Optional CUDA-graph capture -------------------------------------- + # Capture each capturable stage on a side stream and time .replay() + # instead of the eager call. Outputs allocated inside the + # autograd.Function's forward go through the per-capture private pool + # so addresses stay stable across replays. + captured_runners = {} + if args.cuda_graph: + # Graph fwd+bwd of the autograd-wrapped ops via make_graphed_callables. 
+ class _DispatchMod(torch.nn.Module): + def forward(self, x): + return ep_dispatch(handle, buffer, x, topk_idx, topk_w)[0] + + class _CombineMod(torch.nn.Module): + def forward(self, eo): + return ep_combine(handle, buffer, eo) + + disp_mod = _DispatchMod().cuda() + comb_mod = _CombineMod().cuda() + g_disp, g_comb = torch.cuda.make_graphed_callables( + (disp_mod, comb_mod), + ((tokens_p,), (eo_p,)), + ) + fwd_bwd_dispatch_fn = g_disp + fwd_bwd_combine_fn = g_comb + + # Direct torch.cuda.graph capture for raw + fwd-only stages. + side = torch.cuda.Stream() + side.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(side): + for name, fn, direct_capturable in stages: + if not direct_capturable: + continue + fn() # prime the allocator for stable replay addresses + torch.cuda.synchronize() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + captured_runners[name] = g + torch.cuda.current_stream().wait_stream(side) + torch.cuda.synchronize() + + # -- Optional Kineto profiling ---------------------------------------- + kineto_ctx = nullcontext() + if args.kineto and rank == 0: + os.makedirs(args.kineto, exist_ok=True) + kineto_ctx = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=False, + with_stack=False, + ) + + # -- Timed loops ------------------------------------------------------ + results = {} + with kineto_ctx as prof: + for name, fn, _ in stages: + runner = fn + if name in captured_runners: + # Time replay() instead of the eager call. 
+ graph = captured_runners[name] + runner = graph.replay + results[name] = _time_stage_us(name, runner, args.iters, nvtx_suffix, push, pop) + + if rank == 0: + label = f" [{args.mode_label}]" if args.mode_label else "" + print("", flush=True) + print(f"| stage | mean wall (us){label} |", flush=True) + print("|----------------------|---------------:|", flush=True) + for name in ( + "dispatch_raw", + "ep_dispatch_fwd", + "ep_dispatch_fwd_bwd", + "combine_raw", + "ep_combine_fwd", + "ep_combine_fwd_bwd", + ): + print(f"| {name:20s} | {results[name]:14.1f} |", flush=True) + print( + "| (dispatch fwd-raw) |" + f" {results['ep_dispatch_fwd'] - results['dispatch_raw']:14.1f} |", + flush=True, + ) + print( + "| (dispatch bwd-fwd) |" + f" {results['ep_dispatch_fwd_bwd'] - results['ep_dispatch_fwd']:14.1f} |", + flush=True, + ) + print( + "| (combine fwd-raw) |" + f" {results['ep_combine_fwd'] - results['combine_raw']:14.1f} |", + flush=True, + ) + print( + "| (combine bwd-fwd) |" + f" {results['ep_combine_fwd_bwd'] - results['ep_combine_fwd']:14.1f} |", + flush=True, + ) + print("", flush=True) + + if args.kineto and rank == 0 and prof is not None: + trace_path = os.path.join(args.kineto, "ep_bench_trace.json") + prof.export_chrome_trace(trace_path) + print(f"[ep_bench] kineto trace: {trace_path}", flush=True) + print( + prof.key_averages().table(sort_by="cuda_time_total", row_limit=30), + flush=True, + ) + kern_csv = os.path.join(args.kineto, "ep_bench_kernels.csv") + with open(kern_csv, "w") as f: + f.write("name,cuda_time_us,cpu_time_us,count\n") + for evt in prof.key_averages(): + if evt.device_time_total == 0 and evt.cpu_time_total == 0: + continue + f.write(f"{evt.key},{evt.device_time_total},{evt.cpu_time_total},{evt.count}\n") + print(f"[ep_bench] per-kernel CSV: {kern_csv}", flush=True) + + # Captured CUDA graphs (when --cuda-graph) hold references to NCCL EP + # handles and per-pool streams; drop them and sync before ep_finalize, + # otherwise the post-finalize 
dist.barrier can deadlock against pending + # graph state. + torch.cuda.synchronize() + if args.cuda_graph: + fwd_bwd_dispatch_fn = None + fwd_bwd_combine_fn = None + captured_runners.clear() + del g_disp, g_comb, disp_mod, comb_mod + del tokens_p, eo_p, buffer, handle, recv_tokens, recv_w, tokens, topk_w, expert_out + gc.collect() + torch.cuda.synchronize() + # Release NCCL EP's borrowed comm before torch destroys it. + ep_finalize() + dist.barrier() + dist.destroy_process_group() + sys.stdout.flush() + sys.stderr.flush() + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/ep/bench/run_ep_bench.sh b/examples/pytorch/ep/bench/run_ep_bench.sh new file mode 100755 index 0000000000..fefecd7fa9 --- /dev/null +++ b/examples/pytorch/ep/bench/run_ep_bench.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Launcher for examples/pytorch/ep/bench/ep_bench.py. +# Examples: +# bash run_ep_bench.sh # plain run, stdout only +# bash run_ep_bench.sh --cuda-graph # capture + replay each stage as a CUDA graph +# bash run_ep_bench.sh --kineto # Chrome trace + per-kernel CSV (rank 0) +# bash run_ep_bench.sh --nsys # nsys profile on rank 0 -> results/pyt_nsys.nsys-rep + +set -uo pipefail + +NSYS=0; KINETO=0; CGRAPH=0 +for a in "$@"; do + case "$a" in + --nsys) NSYS=1 ;; + --kineto) KINETO=1 ;; + --cuda-graph) CGRAPH=1 ;; + *) echo "unknown arg: $a" >&2; exit 2 ;; + esac +done +if [ "${NSYS}" -eq 1 ] && [ "${KINETO}" -eq 1 ]; then + echo "--nsys and --kineto both attach CUPTI; pick one." >&2; exit 2 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." 
&& pwd)" +RESULTS="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS}" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +NUM_GPUS="${NUM_GPUS:-${DETECTED_GPUS}}" +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "EP bench requires >=4 GPUs (found ${NUM_GPUS}); SKIPPING."; exit 0 +fi +if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi + +: "${TIMEOUT_S:=1800}" +: "${NCCL_EP_JIT_CACHE_DIR:=${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)}" +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "${NCCL_EP_JIT_CACHE_DIR}" + +EXTRA_ARGS=() +TAG="pyt" +[ "${CGRAPH}" -eq 1 ] && EXTRA_ARGS+=(--cuda-graph) && TAG="${TAG}_cg" +if [ "${KINETO}" -eq 1 ]; then + EXTRA_ARGS+=(--kineto "${RESULTS}/kineto_${TAG}") +fi + +EP_BENCH_EXTRA_FLAGS="${EP_BENCH_EXTRA_FLAGS:-}" +LAUNCH=(torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_GPUS}" + "${SCRIPT_DIR}/ep_bench.py" "${EXTRA_ARGS[@]}" ${EP_BENCH_EXTRA_FLAGS}) + +if [ "${NSYS}" -eq 1 ]; then + NSYS_CMD=(nsys profile + --output "${RESULTS}/pyt_${TAG}_nsys" + --force-overwrite=true + --trace=cuda,nvtx + --gpu-metrics-devices=none + --cuda-um-cpu-page-faults=false + --cuda-um-gpu-page-faults=false) + echo "[run_ep_bench] launching with nsys (results/${TAG}_nsys.nsys-rep)" + timeout --foreground --signal=TERM "${TIMEOUT_S}" "${NSYS_CMD[@]}" "${LAUNCH[@]}" + RC=$? +else + timeout --foreground --signal=TERM "${TIMEOUT_S}" "${LAUNCH[@]}" + RC=$? +fi +exit $RC diff --git a/examples/pytorch/ep/bench/run_nccl_ep_bench.sh b/examples/pytorch/ep/bench/run_nccl_ep_bench.sh new file mode 100755 index 0000000000..8f6da04a00 --- /dev/null +++ b/examples/pytorch/ep/bench/run_nccl_ep_bench.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Launcher for the native NCCL EP ``ep_bench`` (baseline for PyTorch comparison). 
+# Usage: +# bash run_nccl_ep_bench.sh # plain run, stdout only +# bash run_nccl_ep_bench.sh --nsys # nsys → results/nccl_ep_nsys.nsys-rep + +set -uo pipefail + +NSYS=0 +for a in "$@"; do + case "$a" in + --nsys) NSYS=1 ;; + *) echo "unknown arg: $a" >&2; exit 2 ;; + esac +done + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)" +RESULTS="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS}" + +BIN="${TE_REPO_ROOT}/3rdparty/nccl/build/test/nccl_ep/ep_bench" +LIB="${TE_REPO_ROOT}/3rdparty/nccl/build/lib" +[ -x "${BIN}" ] || { echo "ep_bench not built at ${BIN}" >&2; exit 2; } + +NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "NCCL EP bench requires >=4 GPUs (found ${NUM_GPUS}); SKIPPING."; exit 0 +fi +if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi + +if [ "${NSYS}" -eq 1 ]; then + ITERS=10 +else + ITERS=50 +fi +ARGS=(--algorithm ht --layout em --tokens 2048 --hidden 7168 --top-k 8 + --experts 256 --warmup 5 --iters "${ITERS}") +[ "${NSYS}" -eq 1 ] && ARGS+=(--profile) # enables NVTX ranges + cudaProfilerStart/Stop + +CMD=(/usr/local/mpi/bin/mpirun --allow-run-as-root --oversubscribe -np "${NUM_GPUS}" + -x LD_LIBRARY_PATH="${LIB}:${LD_LIBRARY_PATH:-}" + "${BIN}" "${ARGS[@]}") + +if [ "${NSYS}" -eq 1 ]; then + CMD=(nsys profile + --output "${RESULTS}/nccl_ep_nsys" + --force-overwrite=true + --capture-range=cudaProfilerApi + --capture-range-end=stop + --trace=cuda,nvtx,osrt + "${CMD[@]}") +fi + +[ "${NSYS}" -eq 1 ] && SUFFIX="_nsys" || SUFFIX="" +LOG="${RESULTS}/stdout_nccl_ep${SUFFIX}.txt" +"${CMD[@]}" 2>&1 | tee "${LOG}" +echo "Done. Log: ${LOG}" diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py new file mode 100644 index 0000000000..934d88d8c7 --- /dev/null +++ b/examples/pytorch/ep/ep_moe.py @@ -0,0 +1,228 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""End-to-end MoE example: dispatch -> batched expert linear -> combine, fwd + bwd. + +One process per GPU; launched via run_test_ep.sh (torchrun). +""" + +import argparse +import os +import sys + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpHandle, + EpBuffer, + ep_scope, + ep_dispatch, + ep_combine, +) + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-PyTorch EP MoE example (fwd + bwd)") + p.add_argument("--num-tokens", type=int, default=8, help="Per-rank token count.") + p.add_argument("--top-k", type=int, default=2) + p.add_argument("--hidden", type=int, default=32) + p.add_argument("--hidden-out", type=int, default=32) + p.add_argument("--num-experts", type=int, default=None) + p.add_argument("--check", action="store_true", default=True) + p.add_argument( + "--benchmark", + action="store_true", + help="Time fwd over HBM buffers.", + ) + p.add_argument("--benchmark-iters", type=int, default=20) + p.add_argument("--benchmark-warmup", type=int, default=5) + return p.parse_args() + + +def _make_routing(rank, T, K, E, num_local_experts): + """Deterministic routing: topk_idx[t, k] = (rank*NLE + t*K + k) % E.""" + topk_idx = np.empty((T, K), dtype=np.int64) + for t in range(T): + for k in range(K): + topk_idx[t, k] = (rank * num_local_experts + t * K + k) % E + return topk_idx + + +def _batched_expert_linear(recv_tokens, kernels, num_local_experts): + """Per-expert linear via bmm; ``recv_pr // num_local_experts`` slots per expert.""" + recv_pr, _H = recv_tokens.shape + H_out = kernels.shape[-1] + slots_per_expert = recv_pr // num_local_experts + grouped = recv_tokens.view(num_local_experts, slots_per_expert, recv_tokens.shape[-1]) + out = torch.bmm(grouped, kernels.to(grouped.dtype)) + return out.view(recv_pr, H_out) + + +def _reference_moe(tokens, topk_idx, topk_w, kernels): + T, K = topk_idx.shape + H_out = kernels.shape[-1] + out = np.zeros((T, H_out), dtype=np.float32) 
+ for t in range(T): + tok = tokens[t].astype(np.float32) + for k in range(K): + e = int(topk_idx[t, k]) + out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32)) + return out + + +def _reference_grad(tokens, topk_idx, topk_w, kernels): + T, K = topk_idx.shape + H = tokens.shape[-1] + ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels) + grad = np.zeros((T, H), dtype=np.float32) + for t in range(T): + mixed = np.zeros_like(kernels[0]) + for k in range(K): + mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])] + grad[t] = ref_out[t] @ mixed.T + return ref_out, grad + + +def main(): + args = _parse_args() + + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank))) + device = torch.device("cuda", torch.cuda.current_device()) + + major, minor = torch.cuda.get_device_capability() + if major * 10 + minor < 90: + if rank == 0: + print(f"[ep_moe] SKIPPED: EP requires SM>=90 (got SM{major}{minor})") + dist.destroy_process_group() + return + + if world_size < 4: + if rank == 0: + print(f"[ep_moe] SKIPPED: EP requires >= 4 ranks (got {world_size})") + dist.destroy_process_group() + return + + ep_size = world_size + num_experts = args.num_experts if args.num_experts is not None else world_size + assert num_experts % ep_size == 0 + num_local_experts = num_experts // ep_size + T = args.num_tokens + recv_pr = ep_size * T * args.top_k + + ep_group = dist.new_group(ranks=list(range(world_size)), backend="nccl") + with ep_scope( + ep_group, + num_experts=num_experts, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=args.hidden, + ): + _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device) + dist.destroy_process_group() + + +def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device): + rng = np.random.default_rng(seed=42 + rank) + 
tokens_np = (rng.standard_normal((T, args.hidden), dtype=np.float32) * 0.5).astype(np.float32) + topk_idx_np = _make_routing(rank, T, args.top_k, num_experts, num_local_experts) + w_np = np.full((T, args.top_k), 1.0 / args.top_k, dtype=np.float32) + # Same seed across ranks -> identical kernel array everywhere. + kr = np.random.default_rng(seed=42) + kernels_np = ( + kr.standard_normal((num_experts, args.hidden, args.hidden_out), dtype=np.float32) + * (1.0 / np.sqrt(args.hidden)) + ).astype(np.float32) + + tokens = ( + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16).requires_grad_(True) + ) + topk_idx = torch.from_numpy(topk_idx_np).to(device) + topk_w = torch.from_numpy(w_np).to(device) + kernels_local = torch.from_numpy( + kernels_np[rank * num_local_experts : (rank + 1) * num_local_experts] + ).to(device=device, dtype=torch.bfloat16) + + handle = EpHandle( + top_k=args.top_k, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=args.hidden, + num_local_experts=num_local_experts, + ) + buffer = EpBuffer(handle) + + recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, topk_w) + expert_out = _batched_expert_linear(recv_t, kernels_local, num_local_experts) + # Apply per-slot topk weighting before combine. + expert_out = expert_out * recv_w_out.unsqueeze(-1).to(expert_out.dtype) + out = ep_combine(handle, buffer, expert_out) + + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + + if rank == 0: + print( + f"[ep_moe] loss={loss.item():.4f} grad_tokens.shape={tuple(tokens.grad.shape)} " + f"ep={ep_size} num_experts={num_experts} recv_pr={recv_pr}" + ) + + if args.benchmark: + # Time dispatch + expert + combine over HBM buffers. 
+ import time + + torch.cuda.synchronize() + dist.barrier() + for _ in range(args.benchmark_warmup): + rt, rw, _tc = ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + eo = _batched_expert_linear(rt, kernels_local, num_local_experts) + eo = eo * rw.unsqueeze(-1).to(eo.dtype) + ep_combine(handle, buffer, eo) + torch.cuda.synchronize() + dist.barrier() + t0 = time.perf_counter() + for _ in range(args.benchmark_iters): + rt, rw, _tc = ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + eo = _batched_expert_linear(rt, kernels_local, num_local_experts) + eo = eo * rw.unsqueeze(-1).to(eo.dtype) + ep_combine(handle, buffer, eo) + torch.cuda.synchronize() + dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters + if rank == 0: + print( + f"[ep_moe --benchmark] HBM: {dt_ms:.3f} ms/iter " + f"(iters={args.benchmark_iters})" + ) + + if args.check: + # All-gather inputs/outputs/grads for a global reference comparison. + global_tokens = [torch.empty_like(tokens) for _ in range(world_size)] + global_topk_idx = [torch.empty_like(topk_idx) for _ in range(world_size)] + global_topk_w = [torch.empty_like(topk_w) for _ in range(world_size)] + global_out = [torch.empty_like(out) for _ in range(world_size)] + global_grad = [torch.empty_like(tokens.grad) for _ in range(world_size)] + dist.all_gather(global_tokens, tokens.detach()) + dist.all_gather(global_topk_idx, topk_idx) + dist.all_gather(global_topk_w, topk_w) + dist.all_gather(global_out, out.detach()) + dist.all_gather(global_grad, tokens.grad) + if rank == 0: + all_tokens = torch.cat(global_tokens).float().cpu().numpy() + all_idx = torch.cat(global_topk_idx).cpu().numpy() + all_w = torch.cat(global_topk_w).cpu().numpy() + all_out = torch.cat(global_out).float().cpu().numpy() + all_grad = torch.cat(global_grad).float().cpu().numpy() + ref_out, ref_grad = _reference_grad(all_tokens, all_idx, all_w, kernels_np) + np.testing.assert_allclose(all_out, ref_out, rtol=5e-2, atol=5e-2) + 
np.testing.assert_allclose(all_grad, ref_grad, rtol=5e-2, atol=5e-2) + print(f"[ep_moe] --check PASSED (ref_out.sum()={float(ref_out.sum()):.4f})") + + +if __name__ == "__main__": + main() + sys.exit(0) diff --git a/examples/pytorch/ep/run_test_ep.sh b/examples/pytorch/ep/run_test_ep.sh new file mode 100755 index 0000000000..13b41f4cb2 --- /dev/null +++ b/examples/pytorch/ep/run_test_ep.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -uo pipefail + +DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +NUM_GPUS="${NUM_GPUS:-${DETECTED_GPUS}}" +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "EP requires >= 4 GPUs (found ${NUM_GPUS}); SKIPPING." + exit 0 +fi +if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi + +: ${TE_PATH:=/opt/transformerengine} +: ${TEST_TIMEOUT_S:=120} + +SCRIPT="${TE_PATH}/examples/pytorch/ep/ep_moe.py" +export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}" + +# Stage JIT cubins on tmpfs for fast iteration. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +echo "*** Executing ep_moe.py across ${NUM_GPUS} GPUs (timeout=${TEST_TIMEOUT_S}s) ***" +timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_GPUS}" \ + "${SCRIPT}" --check 2>&1 | tee stdout_ep_moe.txt +RC=${PIPESTATUS[0]} + +RET=0 +if [ "${RC}" -ne 0 ]; then RET=1; fi +if grep -qE "(^|]:)FAILED|(^|]:)Traceback" stdout_ep_moe.txt; then RET=1; fi +rm -f stdout_ep_moe.txt +exit $RET diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py new file mode 100644 index 0000000000..7f74a454aa --- /dev/null +++ b/tests/pytorch/distributed/run_ep.py @@ -0,0 +1,338 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Multi-process PyTorch EP tests, launched via torchrun (one process per GPU).""" + +import os +import sys +import unittest + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpHandle, + EpBuffer, + ep_bootstrap, + ep_finalize, + ep_prepare, + ep_dispatch, + ep_combine, + _ep_combine_raw, + _ep_dispatch_raw, +) + +# Must come after the transformer_engine import so libtransformer_engine.so is loaded. +import transformer_engine_torch as tex # noqa: F401 + + +NUM_LOCAL_EXPERTS = 2 +HIDDEN_DIM = 32 +TOP_K = 2 +TOKENS_PER_RANK = 4 + + +def _device_sm() -> int: + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def _build_ep_group(): + """EP group spanning all ranks of the default PG.""" + world_pg = dist.distributed_c10d._get_default_group() + ranks = list(range(world_pg.size())) + return dist.new_group(ranks=ranks, backend="nccl") + + +def _make_identity_inputs(rank, ep_size, device="cuda"): + """Per-rank identity routing + uniform weights so combine matches tokens.""" + T = TOKENS_PER_RANK + E = ep_size * NUM_LOCAL_EXPERTS + topk_idx = np.empty((T, TOP_K), dtype=np.int64) + base = rank * T + for t in range(T): + for k in range(TOP_K): + topk_idx[t, k] = ((base + t) * TOP_K + k) % E + tokens_np = np.linspace( + 0.1 + rank * 0.01, 0.9 + rank * 0.01, T * HIDDEN_DIM, dtype=np.float32 + ).reshape(T, HIDDEN_DIM) + topk_weights = np.full((T, TOP_K), 1.0 / TOP_K, dtype=np.float32) + return ( + torch.from_numpy(topk_idx).to(device), + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16), + torch.from_numpy(topk_weights).to(device), + ) + + +class _Cfg: + rank: int + world_size: int + ep_size: int + num_experts: int + recv_capacity_per_rank: int + device: torch.device + + +def _make_cfg() -> _Cfg: + cfg = _Cfg() + cfg.rank = dist.get_rank() + cfg.world_size = dist.get_world_size() + cfg.ep_size = cfg.world_size + cfg.num_experts = NUM_LOCAL_EXPERTS * 
cfg.ep_size + T = TOKENS_PER_RANK + active = min(cfg.num_experts, T * cfg.ep_size * TOP_K) + overconc = cfg.num_experts // active + cfg.recv_capacity_per_rank = NUM_LOCAL_EXPERTS * max(T * cfg.ep_size * TOP_K, 16) * overconc * 2 + cfg.device = torch.device("cuda", torch.cuda.current_device()) + return cfg + + +class TestEP(unittest.TestCase): + cfg: _Cfg + ep_group: dist.ProcessGroup + + @classmethod + def setUpClass(cls): + if _device_sm() < 90: + raise unittest.SkipTest(f"NCCL EP requires SM>=90 (got SM{_device_sm()})") + cls.cfg = _make_cfg() + cls.ep_group = _build_ep_group() + ep_bootstrap( + cls.ep_group, + num_experts=cls.cfg.num_experts, + max_tokens_per_rank=TOKENS_PER_RANK, + recv_capacity_per_rank=cls.cfg.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + zero_copy=True, + ) + + def _make_handle(self, alignment=0, top_k=TOP_K): + return EpHandle( + top_k=top_k, + max_tokens_per_rank=TOKENS_PER_RANK, + recv_capacity_per_rank=self.cfg.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + num_local_experts=NUM_LOCAL_EXPERTS, + alignment=alignment, + ) + + def _make_buffers(self, dtype=torch.bfloat16): + """Allocate raw recv buffers + token_counts for the primitive tests.""" + rc = self.cfg.recv_capacity_per_rank + return ( + torch.empty(rc, HIDDEN_DIM, dtype=dtype, device=self.cfg.device), + torch.empty(rc, dtype=torch.float32, device=self.cfg.device), + torch.empty(NUM_LOCAL_EXPERTS, dtype=torch.int32, device=self.cfg.device), + ) + + def _make_ep_buffer(self, handle): + return EpBuffer(handle) + + @staticmethod + def _weighted(recv_tokens, recv_w): + """fp32 per-slot weighting + cast back; matches the upstream combine input.""" + mask = (recv_w != 0).to(torch.float32).unsqueeze(-1) + return (recv_tokens.float() * recv_w.unsqueeze(-1).float() * mask).to(recv_tokens.dtype) + + def _moe_step(self, handle, buffer, topk_idx, tokens, w): + recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, w) + eo = self._weighted(recv_t, recv_w_out) + return 
ep_combine(handle, buffer, eo) + + # Prepare + + def test_primitive_prepare(self): + handle = self._make_handle() + topk_idx, _toks, _w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + token_counts = ep_prepare(handle, topk_idx) + torch.cuda.synchronize() + self.assertEqual(token_counts.shape, (NUM_LOCAL_EXPERTS,)) + local = int(token_counts.sum().item()) + total = torch.tensor([local], dtype=torch.int64, device=self.cfg.device) + dist.all_reduce(total, op=dist.ReduceOp.SUM, group=self.ep_group) + self.assertEqual(int(total.item()), self.cfg.world_size * TOKENS_PER_RANK * TOP_K) + + # Identity round-trip via raw primitives + + def test_primitive_dispatch_combine_identity(self): + handle = self._make_handle() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_buffers() + ep_prepare(handle, topk_idx) + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + result = torch.empty_like(tokens) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + torch.cuda.synchronize() + torch.testing.assert_close(result.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + # Autograd + + def test_dispatch_fwd_bwd(self): + """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens.""" + handle = self._make_handle() + buffer = self._make_ep_buffer(handle) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + recv_t, _recv_w, _tc = ep_dispatch(handle, buffer, tokens_p, topk_idx, w) + loss = 0.5 * (recv_t.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + torch.testing.assert_close( + tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 + ) + + def test_combine_fwd_bwd(self): + """Full dispatch + combine fwd+bwd; identity inputs round-trip.""" + handle = self._make_handle() + buffer = self._make_ep_buffer(handle) + topk_idx, tokens, w = 
_make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + out = self._moe_step(handle, buffer, topk_idx, tokens_p, w) + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) + torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + # Multi-iter stability + + def test_dispatch_fwd_bwd_multiple_iterations(self): + """5 fwd+bwd iters on the same EpHandle + EpBuffer must be bit-stable.""" + handle = self._make_handle() + buffer = self._make_ep_buffer(handle) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + + def one_step(): + tokens_p = tokens.detach().clone().requires_grad_(True) + out = self._moe_step(handle, buffer, topk_idx, tokens_p, w) + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + return out.detach().clone(), tokens_p.grad.detach().clone() + + out_ref, grad_ref = one_step() + torch.cuda.synchronize() + for _ in range(4): + out_i, grad_i = one_step() + torch.cuda.synchronize() + torch.testing.assert_close(out_i, out_ref, atol=0, rtol=0) + torch.testing.assert_close(grad_i, grad_ref, atol=0, rtol=0) + + # CUDA graph + + def test_cuda_graph_capture(self): + """Capture raw dispatch+combine into a CUDA graph; replay must be bit-stable.""" + handle = self._make_handle() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_buffers() + result = torch.empty_like(tokens) + + def step(): + ep_prepare(handle, topk_idx) + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + + for _ in range(3): + step() + torch.cuda.synchronize() + + # Routing is fixed per layer; prepare runs once before capture. 
+ ep_prepare(handle, topk_idx) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + with torch.cuda.graph(graph): + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + + ref = result.clone() + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + torch.testing.assert_close(result.float(), ref.float(), atol=0, rtol=0) + + # PP-1F1B handle isolation + + def test_pp_1f1b_two_handles(self): + """PP-1F1B interleave (F0 F1 B0 F2 B1 B2) over 3 per-microbatch handles.""" + T, H = TOKENS_PER_RANK, HIDDEN_DIM + idx, _toks, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + scales = (0.13, 0.41, 0.77) + handles, buffers, tokens, tokens_p = [], [], [], [] + for s in scales: + h = self._make_handle() + handles.append(h) + buffers.append(self._make_ep_buffer(h)) + t = torch.full( + (T, H), s + self.cfg.rank * 0.01, dtype=torch.bfloat16, device=self.cfg.device + ) + tokens.append(t) + tokens_p.append(t.detach().clone().requires_grad_(True)) + + recv = [None, None, None] + + def fwd(k): + recv[k], _, _ = ep_dispatch(handles[k], buffers[k], tokens_p[k], idx, w) + + def bwd(k): + (0.5 * (recv[k].float() ** 2).sum()).backward() + recv[k] = None + + fwd(0) + fwd(1) + bwd(0) + fwd(2) + bwd(1) + bwd(2) + torch.cuda.synchronize() + for k in range(3): + torch.testing.assert_close( + tokens_p[k].grad.float(), + tokens[k].float() * float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + # Input validation + + def test_topk_int32_raises_clear_error(self): + handle = self._make_handle() + topk_idx_int32 = torch.zeros( + TOKENS_PER_RANK, TOP_K, dtype=torch.int32, device=self.cfg.device + ) + with self.assertRaises(RuntimeError) as cm: + ep_prepare(handle, topk_idx_int32) + msg = str(cm.exception) + 
self.assertIn("topk_idx", msg) + self.assertIn(".long()", msg) + + +def _init_distributed(): + dist.init_process_group(backend="nccl") + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + try: + from torch.distributed import _symmetric_memory as _symm_mem + + _symm_mem.set_backend("NCCL") + except (ImportError, RuntimeError): + pass + + +if __name__ == "__main__": + _init_distributed() + loader = unittest.TestLoader() + name_filter = os.environ.get("NVTE_EP_TEST_FILTER") + if name_filter: + loader.testMethodPrefix = name_filter + suite = loader.loadTestsFromTestCase(TestEP) + runner = unittest.TextTestRunner(stream=sys.stdout, verbosity=2) + result = runner.run(suite) + dist.barrier() + ep_finalize() + dist.destroy_process_group() + sys.exit(0 if result.wasSuccessful() else 1) diff --git a/tests/pytorch/distributed/run_test_ep.sh b/tests/pytorch/distributed/run_test_ep.sh new file mode 100755 index 0000000000..92d63cff7e --- /dev/null +++ b/tests/pytorch/distributed/run_test_ep.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Launcher for tests/pytorch/distributed/run_ep.py. Auto-detects GPU count. +# Short timeout by default to surface hangs early. + +set -uo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +if [ "${DETECTED_GPUS}" -lt 4 ]; then + echo "EP requires >= 4 GPUs (found ${DETECTED_GPUS}); SKIPPING." + exit 0 +fi +NUM_RANKS="${NVTE_TEST_EP_NUM_RANKS:-${DETECTED_GPUS}}" +if [ "${NUM_RANKS}" -gt 8 ]; then NUM_RANKS=8; fi + +# Short timeout to detect hangs early. +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-120}" + +# Stage NCCL EP JIT cubins on tmpfs to keep iteration fast. 
+: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +SCRIPT="${SCRIPT_DIR}/run_ep.py" +echo "=== Running ${SCRIPT} on ${NUM_RANKS} GPUs (timeout=${TEST_TIMEOUT_S}s) ===" + +# setsid + kill-after so SIGKILL takes down the whole process group, not just torchrun. +setsid timeout --foreground --kill-after=10 --signal=TERM "${TEST_TIMEOUT_S}" \ + torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_RANKS}" \ + "${SCRIPT}" 2>&1 | tee stdout_ep.txt +RC=${PIPESTATUS[0]} +pkill -9 -f "tests/pytorch/distributed/run_ep.py" 2>/dev/null || true + +RET=0 +if [ "${RC}" -ne 0 ]; then + echo "torchrun exited with ${RC}" + RET=1 +fi +# Match unittest failure markers and unhandled Python tracebacks; torchrun +# prefixes per-rank stderr with "[rankN]:" so don't anchor at column 0. +if grep -qE "(^|]:)FAILED|(^|]:)Traceback" stdout_ep.txt; then RET=1; fi +if ! grep -qE "Ran [0-9]+ test|^OK$" stdout_ep.txt; then + echo "ERROR: no test summary — likely hang or early crash" + RET=1 +fi + +if [ -z "${KEEP_EP_LOGS:-}" ]; then rm -f stdout_ep.txt; fi +exit $RET diff --git a/tests/pytorch/distributed/test_ep.py b/tests/pytorch/distributed/test_ep.py new file mode 100644 index 0000000000..81eef9a3c1 --- /dev/null +++ b/tests/pytorch/distributed/test_ep.py @@ -0,0 +1,31 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Pytest driver — spawns run_ep.py under torchrun and asserts the suite passed.""" + +import os +import subprocess +from pathlib import Path + +import pytest +import torch + +TEST_ROOT = Path(__file__).parent.resolve() +WORKER = TEST_ROOT / "run_ep.py" +LAUNCHER = TEST_ROOT / "run_test_ep.sh" + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="EP requires >= 4 GPUs") +def test_multi_process_ep(): + """Launch the EP unit-test suite across all visible GPUs. 
+ + Short timeout so a hang on any rank surfaces fast rather than burning CI time. + """ + timeout_s = int(os.environ.get("NVTE_TEST_EP_TIMEOUT_S", "180")) + proc = subprocess.run( + ["bash", str(LAUNCHER)], + env={**os.environ, "KEEP_EP_LOGS": "1", "TEST_TIMEOUT_S": str(timeout_s)}, + timeout=timeout_s + 30, + check=False, + ) + assert proc.returncode == 0, f"EP test suite failed (rc={proc.returncode})" diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 13d872392d..c4fe055933 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -9,6 +9,7 @@ #include +#include #include #include #include @@ -648,6 +649,42 @@ void inplace_multi_tensor_swizzle_scales_for_gemm_unchecked(std::vector +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "transformer_engine/comm_window.h" + +#ifdef NCCL_HAS_SYMMEM_SUPPORT +#include +#endif + +#include "../common.h" +#include "../extensions.h" +#include "transformer_engine/gemm.h" + +namespace transformer_engine::pytorch { + +namespace { + +// EP process group name, captured at ep_initialize. Used by the symm-mem +// window resolver below to look up SymmetricMemory for payload tensors. +// Empty until ep_initialize. +std::string g_ep_group_name; // NOLINT(runtime/string) + +// True while the EP backend holds a borrowed reference to torch's NCCL comm. +bool g_ep_initialized = false; + +// Zero-copy IO toggle. Placeholder for the symm-mem fast path; per-step ops +// always pass kNoWindow in this release regardless of the toggle. Wired up +// so the switch is a one-line change when the backend lands the fast path. +// Atomic so the Python-side toggle is safe against concurrent +// ep_dispatch/combine (which release the GIL). 
+std::atomic g_zero_copy_enabled{false}; + +// Per-step ops always pass kNoWindow in this release; the symm-mem IO path is +// planned for a near-future release. +constexpr NVTECommWindow kNoWindow = {nullptr, 0}; + +// Resolve ``t`` to an NCCL symm-mem window for the zero-copy one-sided path. +// Returns ``kNoWindow`` when symm-mem support isn't compiled in, zero-copy is +// disabled, no group is set, or ``t`` isn't symm-mem-backed. Currently unused +// at per-step call sites (they hardcode kNoWindow); kept so flipping +// ``g_zero_copy_enabled`` is the only change needed once the backend's +// symm-mem IO path is exposed. +[[maybe_unused]] NVTECommWindow maybe_make_window(const at::Tensor& t) { +#ifdef NCCL_HAS_SYMMEM_SUPPORT + if (!g_zero_copy_enabled.load(std::memory_order_relaxed)) return kNoWindow; + if (g_ep_group_name.empty()) return kNoWindow; + auto sm = c10d::symmetric_memory::rendezvous(t, g_ep_group_name); + if (sm == nullptr) return kNoWindow; + auto* nccl_sm = dynamic_cast(sm.get()); + NVTE_CHECK(nccl_sm != nullptr, + "Symm-mem backend mismatch: expected NCCLSymmetricMemory. Set the backend to " + "\"NCCL\" before allocating EP payload buffers."); + return NVTECommWindow{static_cast(nccl_sm->get_window()), + static_cast(nccl_sm->get_offset())}; +#else + (void)t; + return kNoWindow; +#endif +} + +// The backend only accepts int64 topk_idx. The PyTorch wrapper enforces this +// at the boundary so the per-step ops don't need an upcast workspace. +void check_topk_idx_int64(at::Tensor topk_idx) { + NVTE_CHECK(topk_idx.is_contiguous(), "topk_idx must be contiguous"); + NVTE_CHECK(topk_idx.scalar_type() == at::kLong, + "topk_idx must be int64; got dtype=", c10::toString(topk_idx.scalar_type()), + ". 
Cast with topk_idx.long() before calling."); +} + +using Shape = std::vector; + +} // namespace + +bool ep_get_zero_copy() { return g_zero_copy_enabled.load(std::memory_order_relaxed); } + +// ── Bootstrap ──────────────────────────────────────────────────────────────── +// Borrows torch's NCCL host comm (from ``ProcessGroupNCCL._comm_ptr()``). +// ``group_name`` is captured for the symm-mem window resolver. + +void ep_initialize(uintptr_t comm_ptr, const std::string& group_name, int64_t num_experts, + int64_t max_tokens_per_rank, int64_t max_recv_tokens_per_rank, + int64_t hidden_dim, int64_t max_num_sms, + pybind11::object max_token_dtype, bool zero_copy) { + NVTE_CHECK(!group_name.empty(), "group_name must be non-empty (used for symm-mem lookup)"); + NVTE_CHECK(comm_ptr != 0, "comm_ptr must be non-null (torch NCCL host comm pointer)"); + NVTE_CHECK(!g_ep_initialized, "ep_initialize called twice without ep_finalize"); + + auto ep_comm = reinterpret_cast(comm_ptr); + int ep_size = 0; + NVTE_CHECK(ncclCommCount(ep_comm, &ep_size) == ncclSuccess, "ncclCommCount failed"); + auto torch_dtype = max_token_dtype.cast(); + NVTEEpGroupConfig cfg{ + /*ep_size=*/ep_size, + /*num_experts=*/static_cast(num_experts), + /*max_tokens_per_rank=*/static_cast(max_tokens_per_rank), + /*max_recv_tokens_per_rank=*/static_cast(max_recv_tokens_per_rank), + /*hidden_dim=*/static_cast(hidden_dim), + /*max_num_sms=*/static_cast(max_num_sms), + /*max_token_dtype=*/static_cast(GetTransformerEngineDType(torch_dtype)), + /*zero_copy=*/zero_copy ? 1 : 0, + }; + nvte_ep_initialize(static_cast(ep_comm), cfg); + g_zero_copy_enabled.store(zero_copy, std::memory_order_relaxed); + g_ep_initialized = true; + g_ep_group_name = group_name; +} + +void ep_finalize() { + if (!g_ep_initialized) return; + // The borrowed comm is owned by torch's symm-mem layer; don't destroy it. 
+ nvte_ep_shutdown(); + g_ep_initialized = false; + g_ep_group_name.clear(); + g_zero_copy_enabled.store(false, std::memory_order_relaxed); +} + +namespace { + +NVTEEpLayerConfig make_layer_cfg(int64_t top_k, int64_t dispatch_output_per_expert_alignment) { + return NVTEEpLayerConfig{ + /*top_k=*/static_cast(top_k), + /*dispatch_output_per_expert_alignment=*/ + static_cast(dispatch_output_per_expert_alignment), + }; +} + +} // namespace + +int64_t ep_handle_mem_size(int64_t top_k, int64_t dispatch_output_per_expert_alignment) { + return static_cast( + nvte_ep_handle_mem_size(make_layer_cfg(top_k, dispatch_output_per_expert_alignment))); +} + +// ── Per-step ops ───────────────────────────────────────────────────────────── + +void ep_prepare(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor token_counts, int64_t top_k, + int64_t dispatch_output_per_expert_alignment) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(topk_idx.dim() >= 2, "topk_idx must be at least 2D [..., top_k]"); + check_topk_idx_int64(topk_idx); + const size_t T_flat = topk_idx.numel() / topk_idx.size(-1); + const size_t topk_n = static_cast(topk_idx.size(-1)); + + auto topk_idx_te = + makeTransformerEngineTensor(topk_idx.data_ptr(), Shape{T_flat, topk_n}, DType::kInt64); + auto token_counts_te = makeTransformerEngineTensor( + token_counts.data_ptr(), Shape{static_cast(token_counts.numel())}, DType::kInt32); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + + nvte_ep_prepare(handle_mem_te.data(), topk_idx_te.data(), token_counts_te.data(), + make_layer_cfg(top_k, dispatch_output_per_expert_alignment), stream); +} + +void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, + at::Tensor topk_weights, at::Tensor recv_tokens, at::Tensor recv_topk_weights) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(tokens.dim() >= 2, "tokens must be at 
least 2D [..., H]"); + NVTE_CHECK(topk_idx.dim() >= 2, "topk_idx must be at least 2D [..., top_k]"); + NVTE_CHECK(topk_weights.dim() >= 2, "topk_weights must be at least 2D [..., top_k]"); + NVTE_CHECK(recv_tokens.dim() >= 2, "recv_tokens must be at least 2D [..., recv_pr, H]"); + check_topk_idx_int64(topk_idx); + + const size_t H = static_cast(tokens.size(-1)); + const size_t T_flat = tokens.numel() / H; + const size_t topk_n = static_cast(topk_idx.size(-1)); + const size_t recv_pr = recv_tokens.numel() / H; + + NVTE_CHECK(static_cast(topk_weights.size(-1)) == topk_n, + "topk_weights last dim must equal topk_idx last dim"); + NVTE_CHECK(static_cast(recv_topk_weights.numel()) == recv_pr, + "recv_topk_weights total size must equal recv_tokens recv_pr"); + NVTE_CHECK(recv_tokens.scalar_type() == tokens.scalar_type(), "recv_tokens dtype (", + c10::toString(recv_tokens.scalar_type()), ") must match tokens dtype (", + c10::toString(tokens.scalar_type()), ")"); + + auto tok_dtype = GetTransformerEngineDType(tokens.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto topk_idx_te = + makeTransformerEngineTensor(topk_idx.data_ptr(), Shape{T_flat, topk_n}, DType::kInt64); + auto tokens_te = makeTransformerEngineTensor(tokens.data_ptr(), Shape{T_flat, H}, tok_dtype); + auto topk_w_te = + makeTransformerEngineTensor(topk_weights.data_ptr(), Shape{T_flat, topk_n}, DType::kFloat32); + auto recv_tokens_te = + makeTransformerEngineTensor(recv_tokens.data_ptr(), Shape{recv_pr, H}, tok_dtype); + auto recv_topk_w_te = + makeTransformerEngineTensor(recv_topk_weights.data_ptr(), Shape{recv_pr}, DType::kFloat32); + + // top_k / alignment are carried by the cached layer_cfg seeded at ep_prepare; + // per-step ops look them up by handle_mem pointer in the backend. 
+ nvte_ep_dispatch(handle_mem_te.data(), topk_idx_te.data(), tokens_te.data(), kNoWindow, + topk_w_te.data(), kNoWindow, recv_tokens_te.data(), kNoWindow, + recv_topk_w_te.data(), kNoWindow, stream); +} + +void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(expert_out.dim() >= 2, "expert_out must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(result.dim() >= 2, "result must be at least 2D [..., H]"); + + const size_t H = static_cast(expert_out.size(-1)); + const size_t recv_pr = expert_out.numel() / H; + const size_t T_flat = result.numel() / H; + NVTE_CHECK(static_cast(result.size(-1)) == H, + "result hidden dim must equal expert_out hidden dim"); + NVTE_CHECK(result.scalar_type() == expert_out.scalar_type(), "result dtype (", + c10::toString(result.scalar_type()), ") must match expert_out dtype (", + c10::toString(expert_out.scalar_type()), ")"); + + auto eo_dtype = GetTransformerEngineDType(expert_out.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto expert_out_te = + makeTransformerEngineTensor(expert_out.data_ptr(), Shape{recv_pr, H}, eo_dtype); + auto result_te = makeTransformerEngineTensor(result.data_ptr(), Shape{T_flat, H}, eo_dtype); + + nvte_ep_combine(handle_mem_te.data(), expert_out_te.data(), kNoWindow, result_te.data(), stream); +} + +void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_topk_weights, + at::Tensor grad_tokens, at::Tensor grad_topk_weights) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(grad_tokens.dim() >= 2, "grad_tokens must be at least 2D [..., H]"); + NVTE_CHECK(grad_topk_weights.dim() >= 2, "grad_topk_weights must be at least 2D [..., top_k]"); + + const size_t H = static_cast(grad.size(-1)); + 
const size_t recv_pr = grad.numel() / H; + const size_t T_flat = grad_tokens.numel() / H; + const size_t topk_n = static_cast(grad_topk_weights.size(-1)); + NVTE_CHECK(static_cast(g_recv_topk_weights.numel()) == recv_pr, + "g_recv_topk_weights total size must equal grad recv_pr"); + NVTE_CHECK(static_cast(grad_tokens.size(-1)) == H, + "grad_tokens hidden dim must equal grad H"); + NVTE_CHECK(static_cast(grad_topk_weights.numel()) == T_flat * topk_n, + "grad_topk_weights numel (", grad_topk_weights.numel(), + ") must equal T_flat * top_k (", T_flat * topk_n, ")"); + NVTE_CHECK(grad_tokens.scalar_type() == grad.scalar_type(), "grad_tokens dtype (", + c10::toString(grad_tokens.scalar_type()), ") must match grad dtype (", + c10::toString(grad.scalar_type()), ")"); + + auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto grad_te = makeTransformerEngineTensor(grad.data_ptr(), Shape{recv_pr, H}, g_dtype); + auto g_recv_w_te = + makeTransformerEngineTensor(g_recv_topk_weights.data_ptr(), Shape{recv_pr}, DType::kFloat32); + auto grad_tokens_te = + makeTransformerEngineTensor(grad_tokens.data_ptr(), Shape{T_flat, H}, g_dtype); + auto grad_topk_w_te = makeTransformerEngineTensor(grad_topk_weights.data_ptr(), + Shape{T_flat, topk_n}, DType::kFloat32); + + nvte_ep_dispatch_bwd(handle_mem_te.data(), grad_te.data(), kNoWindow, g_recv_w_te.data(), + kNoWindow, grad_tokens_te.data(), grad_topk_w_te.data(), stream); +} + +void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expert_out) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., H]"); + NVTE_CHECK(grad_expert_out.dim() >= 2, "grad_expert_out must be at least 2D [..., recv_pr, H]"); + + const size_t H = static_cast(grad.size(-1)); + const size_t T_flat = grad.numel() / H; + const size_t 
recv_pr = grad_expert_out.numel() / H; + NVTE_CHECK(static_cast(grad_expert_out.size(-1)) == H, + "grad_expert_out hidden dim must match grad H"); + NVTE_CHECK(grad_expert_out.scalar_type() == grad.scalar_type(), "grad_expert_out dtype (", + c10::toString(grad_expert_out.scalar_type()), ") must match grad dtype (", + c10::toString(grad.scalar_type()), ")"); + + auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto grad_te = makeTransformerEngineTensor(grad.data_ptr(), Shape{T_flat, H}, g_dtype); + auto grad_eo_te = + makeTransformerEngineTensor(grad_expert_out.data_ptr(), Shape{recv_pr, H}, g_dtype); + + nvte_ep_combine_bwd(handle_mem_te.data(), grad_te.data(), kNoWindow, grad_eo_te.data(), + kNoWindow, stream); +} + +void register_ep_bindings(pybind11::module_& m) { + namespace py = pybind11; + m.def("ep_initialize", &ep_initialize, + "Initialize the EP backend; borrows torch's NCCL comm pointed to by ``comm_ptr``.", + py::arg("comm_ptr"), py::arg("group_name"), py::arg("num_experts"), + py::arg("max_tokens_per_rank"), py::arg("max_recv_tokens_per_rank"), py::arg("hidden_dim"), + py::arg("max_num_sms") = 0, py::arg("max_token_dtype"), py::arg("zero_copy") = false, + py::call_guard()); + m.def("ep_finalize", &ep_finalize, "Tear down the EP backend. 
Idempotent.", + py::call_guard()); + m.def("ep_get_zero_copy", &ep_get_zero_copy, "Return the current EP zero-copy toggle state."); + m.def("ep_handle_mem_size", &ep_handle_mem_size, + "Return the handle_mem byte size for the given layer config.", py::arg("top_k"), + py::arg("dispatch_output_per_expert_alignment") = 0); + m.def("ep_prepare", &ep_prepare, "EP prepare", py::call_guard()); + m.def("ep_dispatch", &ep_dispatch, "EP dispatch", py::call_guard()); + m.def("ep_combine", &ep_combine, "EP combine", py::call_guard()); + m.def("ep_dispatch_bwd", &ep_dispatch_bwd, "EP dispatch backward", + py::call_guard()); + m.def("ep_combine_bwd", &ep_combine_bwd, "EP combine backward", + py::call_guard()); +} + +} // namespace transformer_engine::pytorch + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d1890872c0..9394f85cb3 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -292,6 +292,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); +#ifdef NVTE_WITH_NCCL_EP + transformer_engine::pytorch::register_ep_bindings(m); +#endif // NVTE_WITH_NCCL_EP + // Permutation functions m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD", py::call_guard()); diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py new file mode 100644 index 0000000000..bc4b3bb5d1 --- /dev/null +++ b/transformer_engine/pytorch/ep.py @@ -0,0 +1,734 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""PyTorch Expert Parallelism (EP) API.""" + +from __future__ import annotations + +import atexit +from contextlib import contextmanager +from typing import Iterator, Optional + +import torch +import torch.distributed as dist + +import transformer_engine_torch as tex + + +__all__ = [ + "EpHandle", + "EpBuffer", + "ep_bootstrap", + "ep_finalize", + "ep_scope", + "ep_dispatch", + "ep_combine", + "symm_mem_alloc", +] + + +# Symmetric-memory buffer allocator +# +# Used for the symm-mem zero-copy IO path. Set ``ep_bootstrap(zero_copy=True)`` +# to opt in; the C++ backend then operates the EP group in zero-copy mode. + + +def symm_mem_alloc( + shape, + dtype: torch.dtype, + ep_group: dist.ProcessGroup, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """Allocate and rendezvous a symm-mem buffer on ep_group. Collective on ep_group.""" + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + try: + from torch.distributed import _symmetric_memory as _symm_mem + except ImportError as e: + raise RuntimeError( + "torch.distributed._symmetric_memory is unavailable; symm_mem_alloc " + "requires PyTorch built with NCCL symm-mem support." + ) from e + if _symm_mem.get_backend(device) != "NCCL": + _symm_mem.set_backend("NCCL") + t = _symm_mem.empty(*shape, dtype=dtype, device=device) + _symm_mem.rendezvous(t, group=ep_group.group_name) + return t + + +# Bootstrap + + +# NCCL EP requires NCCL >= 2.30.4 (matches the C++ backend's runtime check). 
_MIN_NCCL_VERSION = (2, 30, 4)


def _decode_nccl_version(code: int) -> tuple[int, int, int]:
    """Split NCCL's packed integer version into ``(major, minor, patch)``.

    ``nccl.h`` packs releases up to 2.8 as ``major*1000 + minor*100 + patch``
    and every later release as ``major*10000 + minor*100 + patch``. Codes below
    10000 therefore use the legacy layout; decoding them with the modern one
    would report e.g. 2.8.3 as "0.28.3".
    """
    if code < 10000:
        return code // 1000, (code % 1000) // 100, code % 100
    return code // 10000, (code % 10000) // 100, code % 100


def _check_nccl_runtime_version() -> None:
    """Raise with a clear message if the loaded libnccl is too old for NCCL EP.

    Returns silently when ``libnccl.so.2`` cannot be loaded (the C++ side is
    left to report that) or when ``ncclGetVersion`` reports failure (a warning
    is emitted instead).

    Raises:
        RuntimeError: The runtime NCCL version is below ``_MIN_NCCL_VERSION``.
    """
    import ctypes

    try:
        lib = ctypes.CDLL("libnccl.so.2", mode=ctypes.RTLD_GLOBAL)
        v = ctypes.c_int(0)
        if lib.ncclGetVersion(ctypes.byref(v)) != 0:
            import warnings

            warnings.warn("ncclGetVersion failed; skipping NCCL EP version check.")
            return
    except OSError:  # libnccl not findable; let the C++ side error
        return
    major, minor, patch = _decode_nccl_version(v.value)
    if (major, minor, patch) < _MIN_NCCL_VERSION:
        min_str = ".".join(str(x) for x in _MIN_NCCL_VERSION)
        raise RuntimeError(
            f"NCCL EP requires NCCL >= {min_str}, found {major}.{minor}.{patch} at runtime. "
            "Set LD_LIBRARY_PATH to a newer libnccl.so before launching."
        )


# Module-level lifecycle flags: whether the EP backend is currently initialized
# from this module, and whether the atexit hook has already been installed.
_BOOTSTRAPPED = False
_ATEXIT_REGISTERED = False


def _atexit_finalize() -> None:
    """Best-effort teardown at interpreter shutdown; swallows errors.

    Any exception from the backend teardown is printed via ``traceback`` and
    not re-raised; the bootstrapped flag is cleared either way.
    """
    global _BOOTSTRAPPED
    if _BOOTSTRAPPED:
        try:
            tex.ep_finalize()
        except Exception:
            import traceback

            traceback.print_exc()
        finally:
            _BOOTSTRAPPED = False


def ep_bootstrap(
    ep_group: dist.ProcessGroup,
    num_experts: int,
    max_tokens_per_rank: int,
    recv_capacity_per_rank: int,
    hidden_dim: int,
    max_num_sms: int = 0,
    zero_copy: bool = False,
    max_token_dtype: torch.dtype = torch.bfloat16,
) -> None:
    """Initialize EP by borrowing ep_group's NCCL comm. Call once per process.

    max_token_dtype sets the widest token dtype this EP group will dispatch;
    it sizes NCCL EP staging buffers.

    ``zero_copy`` opts the EP group into the symm-mem zero-copy IO path; pass
    ``True`` only when payload tensors are allocated via ``symm_mem_alloc``.
    Defaults to ``False``.

    The remaining arguments are forwarded (after ``int()`` coercion) to
    ``tex.ep_initialize``; ``recv_capacity_per_rank`` is passed in the
    ``max_recv_tokens_per_rank`` position.

    Raises:
        RuntimeError: EP is already bootstrapped in this process, or the
            runtime NCCL is older than ``_MIN_NCCL_VERSION``.
        ValueError: ``ep_group`` has fewer than two ranks.
    """
    global _BOOTSTRAPPED, _ATEXIT_REGISTERED
    if _BOOTSTRAPPED:
        raise RuntimeError("ep_bootstrap was already called in this process")
    if ep_group.size() < 2:
        raise ValueError(f"ep_bootstrap requires ep_group.size() >= 2 (got {ep_group.size()}).")
    _check_nccl_runtime_version()

    # Materialize the PG's NCCL comm before borrowing its raw handle.
    dist.barrier(group=ep_group, device_ids=[torch.cuda.current_device()])
    comm_ptr = ep_group._get_backend(torch.device("cuda"))._comm_ptr()

    tex.ep_initialize(
        int(comm_ptr),
        str(ep_group.group_name),
        int(num_experts),
        int(max_tokens_per_rank),
        int(recv_capacity_per_rank),
        int(hidden_dim),
        int(max_num_sms),
        max_token_dtype,
        bool(zero_copy),
    )
    _BOOTSTRAPPED = True
    # Install the shutdown hook only once, even across bootstrap/finalize cycles.
    if not _ATEXIT_REGISTERED:
        atexit.register(_atexit_finalize)
        _ATEXIT_REGISTERED = True


def ep_finalize() -> None:
    """Explicit EP teardown; optional and idempotent. An atexit handler covers
    normal shutdown; call this only before ``dist.destroy_process_group()``,
    since the borrowed NCCL comm is invalid once the PG is destroyed.

    Propagates errors from the C++ teardown; use ``_atexit_finalize`` for the
    best-effort interpreter-shutdown path.
    """
    global _BOOTSTRAPPED
    if not _BOOTSTRAPPED:
        return
    try:
        tex.ep_finalize()
    finally:
        # Clear the flag even if teardown raised, so a later call is a no-op.
        _BOOTSTRAPPED = False


@contextmanager
def ep_scope(
    ep_group: dist.ProcessGroup,
    num_experts: int,
    max_tokens_per_rank: int,
    recv_capacity_per_rank: int,
    hidden_dim: int,
    max_num_sms: int = 0,
    zero_copy: bool = False,
    max_token_dtype: torch.dtype = torch.bfloat16,
) -> Iterator[None]:
    """Context manager: ``ep_bootstrap`` on enter, ``ep_finalize`` on exit.

    Use when you tear down the EP process group yourself, so the borrowed NCCL
    comm is released before ``dist.destroy_process_group()``.

    All arguments are forwarded unchanged to ``ep_bootstrap``. ``ep_finalize``
    runs on exit even when the body raises.
    """
    ep_bootstrap(
        ep_group,
        num_experts=num_experts,
        max_tokens_per_rank=max_tokens_per_rank,
        recv_capacity_per_rank=recv_capacity_per_rank,
        hidden_dim=hidden_dim,
        max_num_sms=max_num_sms,
        zero_copy=zero_copy,
        max_token_dtype=max_token_dtype,
    )
    try:
        yield
    finally:
        ep_finalize()


# Handle


class EpHandle:
    """Routing context for one EP layer. Construct one per concurrently-live
    microbatch (e.g. one per in-flight PP-1F1B step).

    Stores the layer configuration passed to the constructor plus
    ``handle_mem``, a ``uint8`` tensor whose size is obtained from
    ``tex.ep_handle_mem_size(top_k, alignment)``.
    """

    __slots__ = (
        "handle_mem",
        "top_k",
        "alignment",
        "max_tokens_per_rank",
        "recv_capacity_per_rank",
        "hidden_dim",
        "num_local_experts",
        "payload_dtype",
        "device",
    )

    def __init__(
        self,
        top_k: int,
        max_tokens_per_rank: int,
        recv_capacity_per_rank: int,
        hidden_dim: int,
        num_local_experts: int,
        alignment: int = 0,
        device: Optional[torch.device] = None,
        payload_dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        """Record the layer configuration and allocate ``handle_mem``.

        Args:
            alignment: 0, 1, or a power of two; forwarded to
                ``tex.ep_handle_mem_size`` together with ``top_k``.
            device: Device for ``handle_mem``; the current CUDA device when
                omitted.

        Raises:
            ValueError: ``alignment`` is negative or not a power of two.
        """
        if device is None:
            device = torch.device("cuda", torch.cuda.current_device())
        alignment = int(alignment)
        # ``x & (x - 1)`` is zero exactly for 0 and powers of two among
        # non-negative ints; negatives are rejected explicitly because the
        # bit trick is meaningless for them.
        if alignment < 0 or (alignment & (alignment - 1)) != 0:
            raise ValueError(
                f"alignment must be 0, 1, or a power of two (got {alignment})."
            )
        self.top_k = int(top_k)
        self.alignment = alignment
        self.max_tokens_per_rank = int(max_tokens_per_rank)
        self.recv_capacity_per_rank = int(recv_capacity_per_rank)
        self.hidden_dim = int(hidden_dim)
        self.num_local_experts = int(num_local_experts)
        self.payload_dtype = payload_dtype
        self.device = device
        size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment)
        self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device)


# Buffer


class EpBuffer:
    """Persistent payload and scratch buffers for one EP layer.

    All slots are plain HBM in TE 2.17 (the symm-mem IO fast path is planned
    for a near-future release).

    Use one EpBuffer per concurrently-in-flight call on the layer (one per
    PP-1F1B microbatch); sharing between an outstanding fwd and a later call
    overwrites tensors the earlier bwd still reads. Call record_stream from
    streams other than the allocation stream.
    """

    __slots__ = (
        "recv_tokens",
        "combine_in",
        "recv_topk_weights",
        "token_counts",
        "grad_tokens",
        "grad_topk_weights",
    )

    def __init__(
        self,
        handle: EpHandle,
        ep_group: Optional[dist.ProcessGroup] = None,
        *,
        device: Optional[torch.device] = None,
    ) -> None:
        """Allocate the persistent EP slots.

        Cross-rank slots are symm-mem-backed when ``ep_bootstrap`` was called
        with ``zero_copy=True`` (requires ``ep_group``); otherwise plain HBM.
        """
        if device is None:
            device = torch.device("cuda", torch.cuda.current_device())
        recv_shape = (handle.recv_capacity_per_rank, handle.hidden_dim)
        send_shape = (handle.max_tokens_per_rank, handle.hidden_dim)
        zero_copy = bool(tex.ep_get_zero_copy())
        if zero_copy:
            if ep_group is None:
                raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).")
            self.recv_tokens = symm_mem_alloc(
                recv_shape, handle.payload_dtype, ep_group, device=device
            )
            self.combine_in = symm_mem_alloc(
                recv_shape, handle.payload_dtype, ep_group, device=device
            )
            self.recv_topk_weights = symm_mem_alloc(
                (handle.recv_capacity_per_rank,), torch.float32, ep_group, device=device
            )
            self.grad_tokens = symm_mem_alloc(
                send_shape, handle.payload_dtype, ep_group, device=device
            )
        else:
            self.recv_tokens = torch.empty(recv_shape, dtype=handle.payload_dtype, device=device)
            self.combine_in = torch.empty(recv_shape, dtype=handle.payload_dtype, device=device)
            self.recv_topk_weights = torch.empty(
                handle.recv_capacity_per_rank, dtype=torch.float32, device=device
            )
            self.grad_tokens = torch.empty(send_shape, dtype=handle.payload_dtype, device=device)
        # Per-rank scratch; never cross-rank, plain HBM regardless
of mode. + self.token_counts = torch.empty(handle.num_local_experts, dtype=torch.int32, device=device) + self.grad_topk_weights = torch.empty( + (handle.max_tokens_per_rank, handle.top_k), dtype=torch.float32, device=device + ) + + @classmethod + def from_external( + cls, + handle: EpHandle, + *, + recv_tokens: torch.Tensor, + combine_in: torch.Tensor, + recv_topk_weights: Optional[torch.Tensor] = None, + grad_tokens: Optional[torch.Tensor] = None, + token_counts: Optional[torch.Tensor] = None, + grad_topk_weights: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ) -> "EpBuffer": + """Construct from caller-allocated buffers. + + Useful for sharing a pre-allocated pool across layers/microbatches, and + for plugging in symm-mem-backed tensors once the zero-copy IO fast path + ships in a near-future release. Caller-supplied slots are validated + against the expected shape and dtype; ``None`` slots get a fresh HBM + allocation. + """ + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + recv_shape = (handle.recv_capacity_per_rank, handle.hidden_dim) + send_shape = (handle.max_tokens_per_rank, handle.hidden_dim) + topk_shape = (handle.max_tokens_per_rank, handle.top_k) + recv_w_shape = (handle.recv_capacity_per_rank,) + counts_shape = (handle.num_local_experts,) + + def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torch.Tensor: + if tuple(t.shape) != shape: + raise ValueError(f"{name} shape {tuple(t.shape)} != expected {shape}") + if t.dtype != dtype: + raise ValueError(f"{name} dtype {t.dtype} != expected {dtype}") + return t + + inst = cls.__new__(cls) + inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, handle.payload_dtype) + inst.combine_in = _check(combine_in, "combine_in", recv_shape, handle.payload_dtype) + inst.recv_topk_weights = ( + _check(recv_topk_weights, "recv_topk_weights", recv_w_shape, torch.float32) + if recv_topk_weights is not None + else 
torch.empty(recv_w_shape, dtype=torch.float32, device=device) + ) + inst.grad_tokens = ( + _check(grad_tokens, "grad_tokens", send_shape, handle.payload_dtype) + if grad_tokens is not None + else torch.empty(send_shape, dtype=handle.payload_dtype, device=device) + ) + inst.token_counts = ( + _check(token_counts, "token_counts", counts_shape, torch.int32) + if token_counts is not None + else torch.empty(counts_shape, dtype=torch.int32, device=device) + ) + inst.grad_topk_weights = ( + _check(grad_topk_weights, "grad_topk_weights", topk_shape, torch.float32) + if grad_topk_weights is not None + else torch.empty(topk_shape, dtype=torch.float32, device=device) + ) + return inst + + def record_stream(self, stream: torch.cuda.Stream) -> None: + """Record stream as a user of all owned tensors so the caching allocator + defers reclaim until stream has caught up.""" + for t in ( + self.recv_tokens, + self.combine_in, + self.recv_topk_weights, + self.grad_tokens, + self.token_counts, + self.grad_topk_weights, + ): + t.record_stream(stream) + + +# torch.library custom ops (so they don't graph-break under torch.compile) + +_LIB = "transformer_engine_ep" + + +@torch.library.custom_op( + f"{_LIB}::prepare", + mutates_args=("handle_mem", "token_counts"), + device_types="cuda", +) +def _prepare_op( + handle_mem: torch.Tensor, + top_k: int, + topk_idx: torch.Tensor, + token_counts: torch.Tensor, + alignment: int, +) -> None: + tex.ep_prepare(handle_mem, topk_idx, token_counts, top_k, alignment) + + +@_prepare_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::dispatch", + mutates_args=("recv_tokens", "recv_topk_weights"), + device_types="cuda", +) +def _dispatch_op( + handle_mem: torch.Tensor, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, +) -> None: + tex.ep_dispatch(handle_mem, topk_idx, tokens, topk_weights, recv_tokens, recv_topk_weights) 
+ + +@_dispatch_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::combine", + mutates_args=("result",), + device_types="cuda", +) +def _combine_op( + handle_mem: torch.Tensor, + expert_out: torch.Tensor, + result: torch.Tensor, +) -> None: + tex.ep_combine(handle_mem, expert_out, result) + + +@_combine_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::dispatch_bwd", + mutates_args=("grad_tokens", "grad_topk_weights"), + device_types="cuda", +) +def _dispatch_bwd_op( + handle_mem: torch.Tensor, + grad: torch.Tensor, + g_recv_topk_weights: torch.Tensor, + grad_tokens: torch.Tensor, + grad_topk_weights: torch.Tensor, +) -> None: + tex.ep_dispatch_bwd(handle_mem, grad, g_recv_topk_weights, grad_tokens, grad_topk_weights) + + +@_dispatch_bwd_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::combine_bwd", + mutates_args=("grad_expert_out",), + device_types="cuda", +) +def _combine_bwd_op( + handle_mem: torch.Tensor, + grad: torch.Tensor, + grad_expert_out: torch.Tensor, +) -> None: + tex.ep_combine_bwd(handle_mem, grad, grad_expert_out) + + +@_combine_bwd_op.register_fake +def _(*args, **kw): + return None + + +# Non-autograd primitives + + +def ep_prepare(handle: EpHandle, topk_idx: torch.Tensor) -> torch.Tensor: + """AllGather the routing map; fills handle.handle_mem and returns token_counts + (int32, shape [num_local_experts]). topk_idx must be int64. + """ + token_counts = torch.empty(handle.num_local_experts, dtype=torch.int32, device=handle.device) + torch.ops.transformer_engine_ep.prepare( + handle.handle_mem, handle.top_k, topk_idx, token_counts, handle.alignment + ) + return token_counts + + +def _ep_dispatch_raw( + handle: EpHandle, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, +) -> None: + """Raw dispatch; no autograd, no prepare. 
Caller must run ep_prepare first.""" + tex.ep_dispatch( + handle.handle_mem, topk_idx, tokens, topk_weights, recv_tokens, recv_topk_weights + ) + + +def _ep_combine_raw(handle: EpHandle, expert_out: torch.Tensor, result: torch.Tensor) -> None: + """Raw combine; no autograd. Caller pre-weights expert_out.""" + tex.ep_combine(handle.handle_mem, expert_out, result) + + +# autograd.Function wrappers + + +class _EpDispatch(torch.autograd.Function): + """Autograd-aware prepare + dispatch. Fwd/bwd share handle_mem and the + EpBuffer slots; do not re-run ep_prepare between them and do not share + EpBuffer with another in-flight call (see EpBuffer). + """ + + @staticmethod + def forward( # type: ignore[override] + ctx, + handle_mem: torch.Tensor, + top_k: int, + alignment: int, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, + token_counts: torch.Tensor, + grad_tokens_buf: torch.Tensor, + grad_topk_weights_buf: torch.Tensor, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + ): + torch.ops.transformer_engine_ep.prepare( + handle_mem, top_k, topk_idx, token_counts, alignment + ) + torch.ops.transformer_engine_ep.dispatch( + handle_mem, + topk_idx, + tokens, + topk_weights, + recv_tokens, + recv_topk_weights, + ) + ctx.handle_mem = handle_mem + ctx.grad_tokens_buf = grad_tokens_buf + ctx.grad_topk_weights_buf = grad_topk_weights_buf + ctx.tokens_shape = tokens.shape + ctx.tokens_dtype = tokens.dtype + ctx.topk_weights_shape = topk_weights.shape + ctx.topk_weights_dtype = topk_weights.dtype + ctx.tokens_T_flat = tokens.numel() // tokens.shape[-1] + ctx.topk_T_flat = topk_weights.numel() // topk_weights.shape[-1] + ctx.recv_capacity = recv_tokens.shape[0] + ctx.hidden_dim = tokens.shape[-1] + ctx.mark_non_differentiable(token_counts) + # Detach so the long-lived buffers aren't tracked as differentiable outputs; + # autograd re-attaches grad_fn pointing back at this Function. 
+ return recv_tokens.detach(), recv_topk_weights.detach(), token_counts + + @staticmethod + def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] + device = ctx.handle_mem.device + if g_recv_tokens is None: + g_recv_tokens = torch.zeros( + ctx.recv_capacity, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device + ) + if g_recv_topk_weights is None: + g_recv_topk_weights = torch.zeros(ctx.recv_capacity, dtype=torch.float32, device=device) + if not g_recv_tokens.is_contiguous(): + g_recv_tokens = g_recv_tokens.contiguous() + if not g_recv_topk_weights.is_contiguous(): + g_recv_topk_weights = g_recv_topk_weights.contiguous() + # Narrow the persistent slots to this call's flattened leading dim. + grad_tokens = ctx.grad_tokens_buf.narrow(0, 0, ctx.tokens_T_flat) + grad_topk_weights = ctx.grad_topk_weights_buf.narrow(0, 0, ctx.topk_T_flat) + torch.ops.transformer_engine_ep.dispatch_bwd( + ctx.handle_mem, + g_recv_tokens, + g_recv_topk_weights, + grad_tokens, + grad_topk_weights, + ) + # Reshape back to the original input shape so autograd's grad slot matches. + grad_tokens_out = grad_tokens.view(ctx.tokens_shape) + grad_topk_weights_out = grad_topk_weights.view(ctx.topk_weights_shape) + return ( + None, # handle_mem + None, # top_k + None, # alignment + None, # recv_tokens + None, # recv_topk_weights + None, # token_counts + None, # grad_tokens_buf + None, # grad_topk_weights_buf + None, # topk_idx + grad_tokens_out, + grad_topk_weights_out, + ) + + +class _EpCombine(torch.autograd.Function): + """Autograd-aware combine. combine_in is reused as grad_combine_in in bwd; + fwd/bwd share handle_mem so don't re-run ep_prepare between them. Caller + must pre-apply the topk weighting to expert_out. 
+ """ + + @staticmethod + def forward( # type: ignore[override] + ctx, + handle_mem: torch.Tensor, + combine_in: torch.Tensor, + num_local_tokens: int, + hidden_dim: int, + expert_out: torch.Tensor, + ): + device = expert_out.device + # Stage expert_out into the persistent combine_in slot (symm-mem-backed + # in the zero-copy path); its storage is reused as grad_combine_in in bwd. + combine_in.copy_(expert_out) + result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) + torch.ops.transformer_engine_ep.combine(handle_mem, combine_in, result) + ctx.handle_mem = handle_mem + ctx.combine_in = combine_in # reused as grad_combine_in in bwd + return result + + @staticmethod + def backward(ctx, g_result): # type: ignore[override] + grad_combine_in = ctx.combine_in + if not g_result.is_contiguous(): + g_result = g_result.contiguous() + torch.ops.transformer_engine_ep.combine_bwd(ctx.handle_mem, g_result, grad_combine_in) + return ( + None, # handle_mem + None, # combine_in + None, # num_local_tokens + None, # hidden_dim + grad_combine_in, + ) + + +# Public high-level wrappers + + +# FP8 dispatch is not yet supported by the common backend. +_FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2) + + +def _reject_fp8(*tensors: torch.Tensor) -> None: + for t in tensors: + if t.dtype in _FP8_DTYPES: + raise NotImplementedError( + f"FP8 dispatch/combine not supported (got dtype={t.dtype}); " + "quantize outside the EP boundary." + ) + + +def ep_dispatch( + handle: EpHandle, + buffer: EpBuffer, + tokens: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, +): + """Run prepare + dispatch with autograd. topk_idx must be int64. + + Returns (recv_tokens, recv_topk_weights, token_counts); views into buffer's + persistent slots; consume them before the next ep_dispatch on the same + buffer or they get overwritten. token_counts is non-differentiable. 
+ """ + _reject_fp8(tokens, buffer.recv_tokens) + return _EpDispatch.apply( + handle.handle_mem, + handle.top_k, + handle.alignment, + buffer.recv_tokens, + buffer.recv_topk_weights, + buffer.token_counts, + buffer.grad_tokens, + buffer.grad_topk_weights, + topk_idx, + tokens, + topk_weights, + ) + + +def ep_combine( + handle: EpHandle, + buffer: EpBuffer, + expert_out: torch.Tensor, + *, + num_local_tokens: Optional[int] = None, +): + """Combine expert outputs back to the source rank, with autograd. The + caller must pre-apply the topk weighting to expert_out. + + Result shape is (num_local_tokens, handle.hidden_dim); defaults to + handle.max_tokens_per_rank rows. + """ + _reject_fp8(expert_out, buffer.combine_in) + if num_local_tokens is None: + num_local_tokens = handle.max_tokens_per_rank + return _EpCombine.apply( + handle.handle_mem, + buffer.combine_in, + num_local_tokens, + handle.hidden_dim, + expert_out, + ) From d10f0ceb2ea0a46ae695253767494ead6736ba9f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 16:28:36 -0700 Subject: [PATCH 02/39] EP PyTorch: wire maybe_make_window into per-step ops for zero_copy Signed-off-by: Phuong Nguyen --- .../pytorch/csrc/extensions/ep.cpp | 41 ++++++++++++------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index cdb0b4239a..018b5addc0 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -59,15 +59,18 @@ constexpr NVTECommWindow kNoWindow = {nullptr, 0}; // Resolve ``t`` to an NCCL symm-mem window for the zero-copy one-sided path. // Returns ``kNoWindow`` when symm-mem support isn't compiled in, zero-copy is -// disabled, no group is set, or ``t`` isn't symm-mem-backed. 
Currently unused -// at per-step call sites (they hardcode kNoWindow); kept so flipping -// ``g_zero_copy_enabled`` is the only change needed once the backend's -// symm-mem IO path is exposed. -[[maybe_unused]] NVTECommWindow maybe_make_window(const at::Tensor& t) { +// disabled, no group is set, or ``t`` isn't symm-mem-backed; callers pass the +// resulting window unconditionally to the backend. +NVTECommWindow maybe_make_window(const at::Tensor& t) { #ifdef NCCL_HAS_SYMMEM_SUPPORT if (!g_zero_copy_enabled.load(std::memory_order_relaxed)) return kNoWindow; if (g_ep_group_name.empty()) return kNoWindow; - auto sm = c10d::symmetric_memory::rendezvous(t, g_ep_group_name); + c10::intrusive_ptr sm; + try { + sm = c10d::symmetric_memory::rendezvous(t, g_ep_group_name); + } catch (const std::exception&) { + return kNoWindow; // Tensor not symm-mem-backed; fall back to staged copy. + } if (sm == nullptr) return kNoWindow; auto* nccl_sm = dynamic_cast(sm.get()); NVTE_CHECK(nccl_sm != nullptr, @@ -212,9 +215,13 @@ void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, // top_k / alignment are carried by the cached layer_cfg seeded at ep_prepare; // per-step ops look them up by handle_mem pointer in the backend. 
- nvte_ep_dispatch(handle_mem_te.data(), topk_idx_te.data(), tokens_te.data(), kNoWindow, - topk_w_te.data(), kNoWindow, recv_tokens_te.data(), kNoWindow, - recv_topk_w_te.data(), kNoWindow, stream); + NVTECommWindow tokens_win = maybe_make_window(tokens); + NVTECommWindow topk_w_win = maybe_make_window(topk_weights); + NVTECommWindow recv_tokens_win = maybe_make_window(recv_tokens); + NVTECommWindow recv_topk_w_win = maybe_make_window(recv_topk_weights); + nvte_ep_dispatch(handle_mem_te.data(), topk_idx_te.data(), tokens_te.data(), tokens_win, + topk_w_te.data(), topk_w_win, recv_tokens_te.data(), recv_tokens_win, + recv_topk_w_te.data(), recv_topk_w_win, stream); } void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) { @@ -238,7 +245,9 @@ void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) makeTransformerEngineTensor(expert_out.data_ptr(), Shape{recv_pr, H}, eo_dtype); auto result_te = makeTransformerEngineTensor(result.data_ptr(), Shape{T_flat, H}, eo_dtype); - nvte_ep_combine(handle_mem_te.data(), expert_out_te.data(), kNoWindow, result_te.data(), stream); + NVTECommWindow expert_out_win = maybe_make_window(expert_out); + nvte_ep_combine(handle_mem_te.data(), expert_out_te.data(), expert_out_win, result_te.data(), + stream); } void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_topk_weights, @@ -274,8 +283,10 @@ void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_t auto grad_topk_w_te = makeTransformerEngineTensor(grad_topk_weights.data_ptr(), Shape{T_flat, topk_n}, DType::kFloat32); - nvte_ep_dispatch_bwd(handle_mem_te.data(), grad_te.data(), kNoWindow, g_recv_w_te.data(), - kNoWindow, grad_tokens_te.data(), grad_topk_w_te.data(), stream); + NVTECommWindow grad_win = maybe_make_window(grad); + NVTECommWindow g_recv_w_win = maybe_make_window(g_recv_topk_weights); + nvte_ep_dispatch_bwd(handle_mem_te.data(), grad_te.data(), grad_win, g_recv_w_te.data(), 
+ g_recv_w_win, grad_tokens_te.data(), grad_topk_w_te.data(), stream); } void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expert_out) { @@ -299,8 +310,10 @@ void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expe auto grad_eo_te = makeTransformerEngineTensor(grad_expert_out.data_ptr(), Shape{recv_pr, H}, g_dtype); - nvte_ep_combine_bwd(handle_mem_te.data(), grad_te.data(), kNoWindow, grad_eo_te.data(), - kNoWindow, stream); + NVTECommWindow grad_win = maybe_make_window(grad); + NVTECommWindow grad_eo_win = maybe_make_window(grad_expert_out); + nvte_ep_combine_bwd(handle_mem_te.data(), grad_te.data(), grad_win, grad_eo_te.data(), + grad_eo_win, stream); } void register_ep_bindings(pybind11::module_& m) { From 3509e91da21897512c7565ba0d18be7a099ee12a Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 16:28:36 -0700 Subject: [PATCH 03/39] EP PyTorch: merge EpHandle into EpBuffer; ep_dispatch/ep_combine take a single buffer Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/bench/ep_bench.py | 27 ++-- examples/pytorch/ep/ep_moe.py | 17 +-- tests/pytorch/distributed/run_ep.py | 86 ++++++----- transformer_engine/pytorch/ep.py | 202 +++++++++++++------------- 4 files changed, 159 insertions(+), 173 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index 86217b7f91..973703b710 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -30,7 +30,6 @@ from transformer_engine.pytorch.ep import ( EpBuffer, - EpHandle, ep_bootstrap, ep_combine, ep_dispatch, @@ -177,14 +176,14 @@ def main(): topk_idx, tokens_hbm, topk_w_hbm = _make_inputs(rank, world_size, T, H, K, E, device) - handle = EpHandle( + buffer = EpBuffer( top_k=K, max_tokens_per_rank=T, recv_capacity_per_rank=recv_pr, hidden_dim=H, num_local_experts=num_local_experts, + ep_group=ep_group, ) - buffer = EpBuffer(handle) tokens = tokens_hbm topk_w = topk_w_hbm 
@@ -192,11 +191,11 @@ def main(): recv_w = torch.empty(recv_pr, dtype=torch.float32, device=device) # -- Prepare once outside the timed loops ------------------------------ - ep_prepare(handle, topk_idx) + ep_prepare(buffer, topk_idx) torch.cuda.synchronize() # Pre-dispatch a steady recv_tokens / recv_w so combine stages have valid input. - _ep_dispatch_raw(handle, topk_idx, tokens, topk_w, recv_tokens, recv_w) + _ep_dispatch_raw(buffer, topk_idx, tokens, topk_w, recv_tokens, recv_w) torch.cuda.synchronize() # fp-equivalent stand-in for an MLP output. expert_out = recv_tokens.clone() @@ -210,20 +209,20 @@ def main(): eo_p = recv_tokens.detach().clone().requires_grad_(True) # Stand-in callables; the cuda-graph branch below swaps in graphed versions. - fwd_bwd_dispatch_fn = lambda x: ep_dispatch(handle, buffer, x, topk_idx, topk_w)[ # noqa: E731 + fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w)[ # noqa: E731 0 ] - fwd_bwd_combine_fn = lambda eo: ep_combine(handle, buffer, eo) # noqa: E731 + fwd_bwd_combine_fn = lambda eo: ep_combine(buffer, eo) # noqa: E731 def _dispatch_raw(): - _ep_dispatch_raw(handle, topk_idx, tokens, topk_w, recv_tokens, recv_w) + _ep_dispatch_raw(buffer, topk_idx, tokens, topk_w, recv_tokens, recv_w) def _combine_raw(): out_buf = torch.empty(T, H, dtype=torch.bfloat16, device=device) - _ep_combine_raw(handle, expert_out, out_buf) + _ep_combine_raw(buffer, expert_out, out_buf) def _ep_dispatch_fwd(): - ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) def _ep_dispatch_fwd_bwd(): tokens_p.grad = None @@ -231,7 +230,7 @@ def _ep_dispatch_fwd_bwd(): (0.5 * (r * r).sum(dtype=torch.float32)).backward() def _ep_combine_fwd(): - ep_combine(handle, buffer, recv_tokens) + ep_combine(buffer, recv_tokens) def _ep_combine_fwd_bwd(): eo_p.grad = None @@ -265,11 +264,11 @@ def _ep_combine_fwd_bwd(): # Graph fwd+bwd of the autograd-wrapped ops via 
make_graphed_callables. class _DispatchMod(torch.nn.Module): def forward(self, x): - return ep_dispatch(handle, buffer, x, topk_idx, topk_w)[0] + return ep_dispatch(buffer, x, topk_idx, topk_w)[0] class _CombineMod(torch.nn.Module): def forward(self, eo): - return ep_combine(handle, buffer, eo) + return ep_combine(buffer, eo) disp_mod = _DispatchMod().cuda() comb_mod = _CombineMod().cuda() @@ -383,7 +382,7 @@ def forward(self, eo): fwd_bwd_combine_fn = None captured_runners.clear() del g_disp, g_comb, disp_mod, comb_mod - del tokens_p, eo_p, buffer, handle, recv_tokens, recv_w, tokens, topk_w, expert_out + del tokens_p, eo_p, buffer, recv_tokens, recv_w, tokens, topk_w, expert_out gc.collect() torch.cuda.synchronize() # Release NCCL EP's borrowed comm before torch destroys it. diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index 934d88d8c7..70bf678f3c 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -15,7 +15,6 @@ import torch.distributed as dist from transformer_engine.pytorch.ep import ( - EpHandle, EpBuffer, ep_scope, ep_dispatch, @@ -147,20 +146,20 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, kernels_np[rank * num_local_experts : (rank + 1) * num_local_experts] ).to(device=device, dtype=torch.bfloat16) - handle = EpHandle( + buffer = EpBuffer( top_k=args.top_k, max_tokens_per_rank=T, recv_capacity_per_rank=recv_pr, hidden_dim=args.hidden, num_local_experts=num_local_experts, + ep_group=ep_group, ) - buffer = EpBuffer(handle) - recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, topk_w) + recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, topk_w) expert_out = _batched_expert_linear(recv_t, kernels_local, num_local_experts) # Apply per-slot topk weighting before combine. 
expert_out = expert_out * recv_w_out.unsqueeze(-1).to(expert_out.dtype) - out = ep_combine(handle, buffer, expert_out) + out = ep_combine(buffer, expert_out) loss = 0.5 * (out.float() ** 2).sum() loss.backward() @@ -179,18 +178,18 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, torch.cuda.synchronize() dist.barrier() for _ in range(args.benchmark_warmup): - rt, rw, _tc = ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) eo = _batched_expert_linear(rt, kernels_local, num_local_experts) eo = eo * rw.unsqueeze(-1).to(eo.dtype) - ep_combine(handle, buffer, eo) + ep_combine(buffer, eo) torch.cuda.synchronize() dist.barrier() t0 = time.perf_counter() for _ in range(args.benchmark_iters): - rt, rw, _tc = ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) eo = _batched_expert_linear(rt, kernels_local, num_local_experts) eo = eo * rw.unsqueeze(-1).to(eo.dtype) - ep_combine(handle, buffer, eo) + ep_combine(buffer, eo) torch.cuda.synchronize() dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters if rank == 0: diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 7f74a454aa..3b343466b9 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -12,17 +12,20 @@ import torch.distributed as dist from transformer_engine.pytorch.ep import ( - EpHandle, EpBuffer, ep_bootstrap, ep_finalize, ep_prepare, ep_dispatch, ep_combine, + symm_mem_alloc, _ep_combine_raw, _ep_dispatch_raw, ) + +ZERO_COPY = False + # Must come after the transformer_engine import so libtransformer_engine.so is loaded. 
import transformer_engine_torch as tex # noqa: F401 @@ -104,21 +107,22 @@ def setUpClass(cls): max_tokens_per_rank=TOKENS_PER_RANK, recv_capacity_per_rank=cls.cfg.recv_capacity_per_rank, hidden_dim=HIDDEN_DIM, - zero_copy=True, + zero_copy=ZERO_COPY, ) - def _make_handle(self, alignment=0, top_k=TOP_K): - return EpHandle( + def _make_buffer(self, alignment=0, top_k=TOP_K): + return EpBuffer( top_k=top_k, max_tokens_per_rank=TOKENS_PER_RANK, recv_capacity_per_rank=self.cfg.recv_capacity_per_rank, hidden_dim=HIDDEN_DIM, num_local_experts=NUM_LOCAL_EXPERTS, alignment=alignment, + ep_group=self.ep_group, ) - def _make_buffers(self, dtype=torch.bfloat16): - """Allocate raw recv buffers + token_counts for the primitive tests.""" + def _make_raw_recv(self, dtype=torch.bfloat16): + """Raw recv tensors + token_counts for the primitive tests.""" rc = self.cfg.recv_capacity_per_rank return ( torch.empty(rc, HIDDEN_DIM, dtype=dtype, device=self.cfg.device), @@ -126,26 +130,23 @@ def _make_buffers(self, dtype=torch.bfloat16): torch.empty(NUM_LOCAL_EXPERTS, dtype=torch.int32, device=self.cfg.device), ) - def _make_ep_buffer(self, handle): - return EpBuffer(handle) - @staticmethod def _weighted(recv_tokens, recv_w): """fp32 per-slot weighting + cast back; matches the upstream combine input.""" mask = (recv_w != 0).to(torch.float32).unsqueeze(-1) return (recv_tokens.float() * recv_w.unsqueeze(-1).float() * mask).to(recv_tokens.dtype) - def _moe_step(self, handle, buffer, topk_idx, tokens, w): - recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, w) + def _moe_step(self, buffer, topk_idx, tokens, w): + recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, w) eo = self._weighted(recv_t, recv_w_out) - return ep_combine(handle, buffer, eo) + return ep_combine(buffer, eo) # Prepare def test_primitive_prepare(self): - handle = self._make_handle() + buf = self._make_buffer() topk_idx, _toks, _w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) - 
token_counts = ep_prepare(handle, topk_idx) + token_counts = ep_prepare(buf, topk_idx) torch.cuda.synchronize() self.assertEqual(token_counts.shape, (NUM_LOCAL_EXPERTS,)) local = int(token_counts.sum().item()) @@ -156,13 +157,13 @@ def test_primitive_prepare(self): # Identity round-trip via raw primitives def test_primitive_dispatch_combine_identity(self): - handle = self._make_handle() + buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) - recv_tokens, recv_w, _ = self._make_buffers() - ep_prepare(handle, topk_idx) - _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + recv_tokens, recv_w, _ = self._make_raw_recv() + ep_prepare(buf, topk_idx) + _ep_dispatch_raw(buf, topk_idx, tokens, w, recv_tokens, recv_w) result = torch.empty_like(tokens) - _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + _ep_combine_raw(buf, self._weighted(recv_tokens, recv_w), result) torch.cuda.synchronize() torch.testing.assert_close(result.float(), tokens.float(), atol=5e-2, rtol=5e-2) @@ -170,11 +171,10 @@ def test_primitive_dispatch_combine_identity(self): def test_dispatch_fwd_bwd(self): """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens.""" - handle = self._make_handle() - buffer = self._make_ep_buffer(handle) + buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - recv_t, _recv_w, _tc = ep_dispatch(handle, buffer, tokens_p, topk_idx, w) + recv_t, _recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) loss = 0.5 * (recv_t.float() ** 2).sum() loss.backward() torch.cuda.synchronize() @@ -184,11 +184,10 @@ def test_dispatch_fwd_bwd(self): def test_combine_fwd_bwd(self): """Full dispatch + combine fwd+bwd; identity inputs round-trip.""" - handle = self._make_handle() - buffer = self._make_ep_buffer(handle) + buf = self._make_buffer() topk_idx, tokens, w = 
_make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - out = self._moe_step(handle, buffer, topk_idx, tokens_p, w) + out = self._moe_step(buf, topk_idx, tokens_p, w) loss = 0.5 * (out.float() ** 2).sum() loss.backward() torch.cuda.synchronize() @@ -198,14 +197,13 @@ def test_combine_fwd_bwd(self): # Multi-iter stability def test_dispatch_fwd_bwd_multiple_iterations(self): - """5 fwd+bwd iters on the same EpHandle + EpBuffer must be bit-stable.""" - handle = self._make_handle() - buffer = self._make_ep_buffer(handle) + """5 fwd+bwd iters on the same EpBuffer must be bit-stable.""" + buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) def one_step(): tokens_p = tokens.detach().clone().requires_grad_(True) - out = self._moe_step(handle, buffer, topk_idx, tokens_p, w) + out = self._moe_step(buf, topk_idx, tokens_p, w) loss = 0.5 * (out.float() ** 2).sum() loss.backward() return out.detach().clone(), tokens_p.grad.detach().clone() @@ -222,22 +220,22 @@ def one_step(): def test_cuda_graph_capture(self): """Capture raw dispatch+combine into a CUDA graph; replay must be bit-stable.""" - handle = self._make_handle() + buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) - recv_tokens, recv_w, _ = self._make_buffers() + recv_tokens, recv_w, _ = self._make_raw_recv() result = torch.empty_like(tokens) def step(): - ep_prepare(handle, topk_idx) - _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) - _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + ep_prepare(buf, topk_idx) + _ep_dispatch_raw(buf, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(buf, self._weighted(recv_tokens, recv_w), result) for _ in range(3): step() torch.cuda.synchronize() # Routing is fixed per layer; prepare runs once before capture. 
- ep_prepare(handle, topk_idx) + ep_prepare(buf, topk_idx) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() @@ -245,8 +243,8 @@ def step(): s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): with torch.cuda.graph(graph): - _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) - _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + _ep_dispatch_raw(buf, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(buf, self._weighted(recv_tokens, recv_w), result) torch.cuda.current_stream().wait_stream(s) torch.cuda.synchronize() @@ -259,15 +257,13 @@ def step(): # PP-1F1B handle isolation def test_pp_1f1b_two_handles(self): - """PP-1F1B interleave (F0 F1 B0 F2 B1 B2) over 3 per-microbatch handles.""" + """PP-1F1B interleave (F0 F1 B0 F2 B1 B2) over 3 per-microbatch buffers.""" T, H = TOKENS_PER_RANK, HIDDEN_DIM idx, _toks, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) scales = (0.13, 0.41, 0.77) - handles, buffers, tokens, tokens_p = [], [], [], [] + buffers, tokens, tokens_p = [], [], [] for s in scales: - h = self._make_handle() - handles.append(h) - buffers.append(self._make_ep_buffer(h)) + buffers.append(self._make_buffer()) t = torch.full( (T, H), s + self.cfg.rank * 0.01, dtype=torch.bfloat16, device=self.cfg.device ) @@ -277,7 +273,7 @@ def test_pp_1f1b_two_handles(self): recv = [None, None, None] def fwd(k): - recv[k], _, _ = ep_dispatch(handles[k], buffers[k], tokens_p[k], idx, w) + recv[k], _, _ = ep_dispatch(buffers[k], tokens_p[k], idx, w) def bwd(k): (0.5 * (recv[k].float() ** 2).sum()).backward() @@ -301,12 +297,12 @@ def bwd(k): # Input validation def test_topk_int32_raises_clear_error(self): - handle = self._make_handle() + buf = self._make_buffer() topk_idx_int32 = torch.zeros( TOKENS_PER_RANK, TOP_K, dtype=torch.int32, device=self.cfg.device ) with self.assertRaises(RuntimeError) as cm: - ep_prepare(handle, topk_idx_int32) + ep_prepare(buf, topk_idx_int32) msg = str(cm.exception) 
self.assertIn("topk_idx", msg) self.assertIn(".long()", msg) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index bc4b3bb5d1..9e66d882b4 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -16,7 +16,6 @@ __all__ = [ - "EpHandle", "EpBuffer", "ep_bootstrap", "ep_finalize", @@ -203,24 +202,43 @@ def ep_scope( ep_finalize() -# Handle +# Buffer -class EpHandle: - """Routing context for one EP layer. Construct one per concurrently-live - microbatch (e.g. one per in-flight PP-1F1B step). +class EpBuffer: + """Per-microbatch EP layer state: routing handle + persistent payload slots. + + Owns the per-call ``handle_mem`` routing scratch and the payload buffers + consumed by :func:`ep_dispatch` / :func:`ep_combine`. Allocate one + EpBuffer per concurrently-in-flight call on the layer (one per PP-1F1B + microbatch); sharing across overlapping calls clobbers tensors the + earlier bwd still reads. Call ``record_stream`` from streams other than + the allocation stream. + + Cross-rank payload slots are symm-mem-backed when ``ep_bootstrap`` was + called with ``zero_copy=True`` (requires ``ep_group``); otherwise plain + HBM. 
""" __slots__ = ( + # routing "handle_mem", "top_k", "alignment", + # layer config "max_tokens_per_rank", "recv_capacity_per_rank", "hidden_dim", "num_local_experts", "payload_dtype", "device", + # payload slots + "recv_tokens", + "combine_in", + "recv_topk_weights", + "token_counts", + "grad_tokens", + "grad_topk_weights", ) def __init__( @@ -231,16 +249,15 @@ def __init__( hidden_dim: int, num_local_experts: int, alignment: int = 0, - device: Optional[torch.device] = None, + ep_group: Optional[dist.ProcessGroup] = None, payload_dtype: torch.dtype = torch.bfloat16, + device: Optional[torch.device] = None, ) -> None: if device is None: device = torch.device("cuda", torch.cuda.current_device()) alignment = int(alignment) if alignment > 1 and (alignment & (alignment - 1)) != 0: - raise ValueError( - f"alignment must be 0, 1, or a power of two (got {alignment})." - ) + raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") self.top_k = int(top_k) self.alignment = alignment self.max_tokens_per_rank = int(max_tokens_per_rank) @@ -249,83 +266,43 @@ def __init__( self.num_local_experts = int(num_local_experts) self.payload_dtype = payload_dtype self.device = device + size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment) self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) - -# Buffer - - -class EpBuffer: - """Persistent payload and scratch buffers for one EP layer. - - All slots are plain HBM in TE 2.17 (the symm-mem IO fast path is planned - for a near-future release). - - Use one EpBuffer per concurrently-in-flight call on the layer (one per - PP-1F1B microbatch); sharing between an outstanding fwd and a later call - overwrites tensors the earlier bwd still reads. Call record_stream from - streams other than the allocation stream. 
- """ - - __slots__ = ( - "recv_tokens", - "combine_in", - "recv_topk_weights", - "token_counts", - "grad_tokens", - "grad_topk_weights", - ) - - def __init__( - self, - handle: EpHandle, - ep_group: Optional[dist.ProcessGroup] = None, - *, - device: Optional[torch.device] = None, - ) -> None: - """Allocate the persistent EP slots. - - Cross-rank slots are symm-mem-backed when ``ep_bootstrap`` was called - with ``zero_copy=True`` (requires ``ep_group``); otherwise plain HBM. - """ - if device is None: - device = torch.device("cuda", torch.cuda.current_device()) - recv_shape = (handle.recv_capacity_per_rank, handle.hidden_dim) - send_shape = (handle.max_tokens_per_rank, handle.hidden_dim) + recv_shape = (self.recv_capacity_per_rank, self.hidden_dim) + send_shape = (self.max_tokens_per_rank, self.hidden_dim) zero_copy = bool(tex.ep_get_zero_copy()) if zero_copy: if ep_group is None: raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") - self.recv_tokens = symm_mem_alloc( - recv_shape, handle.payload_dtype, ep_group, device=device - ) - self.combine_in = symm_mem_alloc( - recv_shape, handle.payload_dtype, ep_group, device=device - ) + self.recv_tokens = symm_mem_alloc(recv_shape, payload_dtype, ep_group, device=device) + self.combine_in = symm_mem_alloc(recv_shape, payload_dtype, ep_group, device=device) self.recv_topk_weights = symm_mem_alloc( - (handle.recv_capacity_per_rank,), torch.float32, ep_group, device=device - ) - self.grad_tokens = symm_mem_alloc( - send_shape, handle.payload_dtype, ep_group, device=device + (self.recv_capacity_per_rank,), torch.float32, ep_group, device=device ) + self.grad_tokens = symm_mem_alloc(send_shape, payload_dtype, ep_group, device=device) else: - self.recv_tokens = torch.empty(recv_shape, dtype=handle.payload_dtype, device=device) - self.combine_in = torch.empty(recv_shape, dtype=handle.payload_dtype, device=device) + self.recv_tokens = torch.empty(recv_shape, dtype=payload_dtype, device=device) + 
self.combine_in = torch.empty(recv_shape, dtype=payload_dtype, device=device) self.recv_topk_weights = torch.empty( - handle.recv_capacity_per_rank, dtype=torch.float32, device=device + self.recv_capacity_per_rank, dtype=torch.float32, device=device ) - self.grad_tokens = torch.empty(send_shape, dtype=handle.payload_dtype, device=device) + self.grad_tokens = torch.empty(send_shape, dtype=payload_dtype, device=device) # Per-rank scratch; never cross-rank, plain HBM regardless of mode. - self.token_counts = torch.empty(handle.num_local_experts, dtype=torch.int32, device=device) + self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) self.grad_topk_weights = torch.empty( - (handle.max_tokens_per_rank, handle.top_k), dtype=torch.float32, device=device + (self.max_tokens_per_rank, self.top_k), dtype=torch.float32, device=device ) @classmethod def from_external( cls, - handle: EpHandle, + top_k: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + num_local_experts: int, *, recv_tokens: torch.Tensor, combine_in: torch.Tensor, @@ -333,23 +310,27 @@ def from_external( grad_tokens: Optional[torch.Tensor] = None, token_counts: Optional[torch.Tensor] = None, grad_topk_weights: Optional[torch.Tensor] = None, + alignment: int = 0, + payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, ) -> "EpBuffer": - """Construct from caller-allocated buffers. + """Construct from caller-allocated payload buffers. - Useful for sharing a pre-allocated pool across layers/microbatches, and - for plugging in symm-mem-backed tensors once the zero-copy IO fast path - ships in a near-future release. Caller-supplied slots are validated - against the expected shape and dtype; ``None`` slots get a fresh HBM - allocation. + Useful for sharing a pre-allocated pool across layers/microbatches and + for plugging in symm-mem-backed tensors. 
Caller-supplied slots are + validated against the expected shape and dtype; ``None`` slots get a + fresh HBM allocation. ``handle_mem`` is always allocated fresh. """ if device is None: device = torch.device("cuda", torch.cuda.current_device()) - recv_shape = (handle.recv_capacity_per_rank, handle.hidden_dim) - send_shape = (handle.max_tokens_per_rank, handle.hidden_dim) - topk_shape = (handle.max_tokens_per_rank, handle.top_k) - recv_w_shape = (handle.recv_capacity_per_rank,) - counts_shape = (handle.num_local_experts,) + alignment = int(alignment) + if alignment > 1 and (alignment & (alignment - 1)) != 0: + raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") + recv_shape = (recv_capacity_per_rank, hidden_dim) + send_shape = (max_tokens_per_rank, hidden_dim) + topk_shape = (max_tokens_per_rank, top_k) + recv_w_shape = (recv_capacity_per_rank,) + counts_shape = (num_local_experts,) def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torch.Tensor: if tuple(t.shape) != shape: @@ -359,17 +340,29 @@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc return t inst = cls.__new__(cls) - inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, handle.payload_dtype) - inst.combine_in = _check(combine_in, "combine_in", recv_shape, handle.payload_dtype) + inst.top_k = int(top_k) + inst.alignment = alignment + inst.max_tokens_per_rank = int(max_tokens_per_rank) + inst.recv_capacity_per_rank = int(recv_capacity_per_rank) + inst.hidden_dim = int(hidden_dim) + inst.num_local_experts = int(num_local_experts) + inst.payload_dtype = payload_dtype + inst.device = device + + size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) + inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) + + inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, payload_dtype) + inst.combine_in = _check(combine_in, "combine_in", recv_shape, payload_dtype) 
inst.recv_topk_weights = ( _check(recv_topk_weights, "recv_topk_weights", recv_w_shape, torch.float32) if recv_topk_weights is not None else torch.empty(recv_w_shape, dtype=torch.float32, device=device) ) inst.grad_tokens = ( - _check(grad_tokens, "grad_tokens", send_shape, handle.payload_dtype) + _check(grad_tokens, "grad_tokens", send_shape, payload_dtype) if grad_tokens is not None - else torch.empty(send_shape, dtype=handle.payload_dtype, device=device) + else torch.empty(send_shape, dtype=payload_dtype, device=device) ) inst.token_counts = ( _check(token_counts, "token_counts", counts_shape, torch.int32) @@ -387,6 +380,7 @@ def record_stream(self, stream: torch.cuda.Stream) -> None: """Record stream as a user of all owned tensors so the caching allocator defers reclaim until stream has caught up.""" for t in ( + self.handle_mem, self.recv_tokens, self.combine_in, self.recv_topk_weights, @@ -502,19 +496,19 @@ def _(*args, **kw): # Non-autograd primitives -def ep_prepare(handle: EpHandle, topk_idx: torch.Tensor) -> torch.Tensor: - """AllGather the routing map; fills handle.handle_mem and returns token_counts - (int32, shape [num_local_experts]). topk_idx must be int64. +def ep_prepare(buffer: "EpBuffer", topk_idx: torch.Tensor) -> torch.Tensor: + """AllGather the routing map; fills ``buffer.handle_mem`` and returns + ``buffer.token_counts`` (int32, shape [num_local_experts]). topk_idx must + be int64. 
""" - token_counts = torch.empty(handle.num_local_experts, dtype=torch.int32, device=handle.device) torch.ops.transformer_engine_ep.prepare( - handle.handle_mem, handle.top_k, topk_idx, token_counts, handle.alignment + buffer.handle_mem, buffer.top_k, topk_idx, buffer.token_counts, buffer.alignment ) - return token_counts + return buffer.token_counts def _ep_dispatch_raw( - handle: EpHandle, + buffer: "EpBuffer", topk_idx: torch.Tensor, tokens: torch.Tensor, topk_weights: torch.Tensor, @@ -523,13 +517,13 @@ def _ep_dispatch_raw( ) -> None: """Raw dispatch; no autograd, no prepare. Caller must run ep_prepare first.""" tex.ep_dispatch( - handle.handle_mem, topk_idx, tokens, topk_weights, recv_tokens, recv_topk_weights + buffer.handle_mem, topk_idx, tokens, topk_weights, recv_tokens, recv_topk_weights ) -def _ep_combine_raw(handle: EpHandle, expert_out: torch.Tensor, result: torch.Tensor) -> None: +def _ep_combine_raw(buffer: "EpBuffer", expert_out: torch.Tensor, result: torch.Tensor) -> None: """Raw combine; no autograd. Caller pre-weights expert_out.""" - tex.ep_combine(handle.handle_mem, expert_out, result) + tex.ep_combine(buffer.handle_mem, expert_out, result) # autograd.Function wrappers @@ -681,7 +675,6 @@ def _reject_fp8(*tensors: torch.Tensor) -> None: def ep_dispatch( - handle: EpHandle, buffer: EpBuffer, tokens: torch.Tensor, topk_idx: torch.Tensor, @@ -689,15 +682,15 @@ def ep_dispatch( ): """Run prepare + dispatch with autograd. topk_idx must be int64. - Returns (recv_tokens, recv_topk_weights, token_counts); views into buffer's - persistent slots; consume them before the next ep_dispatch on the same - buffer or they get overwritten. token_counts is non-differentiable. + Returns (recv_tokens, recv_topk_weights, token_counts); views into the + buffer's persistent slots — consume them before the next ep_dispatch on + the same buffer or they get overwritten. token_counts is non-differentiable. 
""" _reject_fp8(tokens, buffer.recv_tokens) return _EpDispatch.apply( - handle.handle_mem, - handle.top_k, - handle.alignment, + buffer.handle_mem, + buffer.top_k, + buffer.alignment, buffer.recv_tokens, buffer.recv_topk_weights, buffer.token_counts, @@ -710,7 +703,6 @@ def ep_dispatch( def ep_combine( - handle: EpHandle, buffer: EpBuffer, expert_out: torch.Tensor, *, @@ -719,16 +711,16 @@ def ep_combine( """Combine expert outputs back to the source rank, with autograd. The caller must pre-apply the topk weighting to expert_out. - Result shape is (num_local_tokens, handle.hidden_dim); defaults to - handle.max_tokens_per_rank rows. + Result shape is (num_local_tokens, buffer.hidden_dim); defaults to + buffer.max_tokens_per_rank rows. """ _reject_fp8(expert_out, buffer.combine_in) if num_local_tokens is None: - num_local_tokens = handle.max_tokens_per_rank + num_local_tokens = buffer.max_tokens_per_rank return _EpCombine.apply( - handle.handle_mem, + buffer.handle_mem, buffer.combine_in, num_local_tokens, - handle.hidden_dim, + buffer.hidden_dim, expert_out, ) From 8409e2b9500fea8aa01d38e8f5be9258675cbef4 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 17:39:37 -0700 Subject: [PATCH 04/39] EP PyTorch example: drop stale ep_group kwarg from EpBuffer call Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/ep_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index 70bf678f3c..ca7469951d 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -152,7 +152,6 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, recv_capacity_per_rank=recv_pr, hidden_dim=args.hidden, num_local_experts=num_local_experts, - ep_group=ep_group, ) recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, topk_w) From b8480f7a4205c7f9c64f3b9619b461dd93c873a3 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 17:45:14 -0700 Subject: 
[PATCH 05/39] EP PyTorch: drop ep_scope; ep_finalize is optional with atexit fallback Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/ep_moe.py | 14 ++++++---- transformer_engine/pytorch/ep.py | 46 +++++--------------------------- 2 files changed, 15 insertions(+), 45 deletions(-) diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index ca7469951d..24e271f2f8 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -16,9 +16,10 @@ from transformer_engine.pytorch.ep import ( EpBuffer, - ep_scope, - ep_dispatch, + ep_bootstrap, ep_combine, + ep_dispatch, + ep_finalize, ) @@ -114,15 +115,18 @@ def main(): recv_pr = ep_size * T * args.top_k ep_group = dist.new_group(ranks=list(range(world_size)), backend="nccl") - with ep_scope( + ep_bootstrap( ep_group, num_experts=num_experts, max_tokens_per_rank=T, recv_capacity_per_rank=recv_pr, hidden_dim=args.hidden, - ): + ) + try: _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device) - dist.destroy_process_group() + finally: + ep_finalize() + dist.destroy_process_group() def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device): diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 9e66d882b4..b3612401bd 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -6,8 +6,7 @@ from __future__ import annotations import atexit -from contextlib import contextmanager -from typing import Iterator, Optional +from typing import Optional import torch import torch.distributed as dist @@ -19,7 +18,6 @@ "EpBuffer", "ep_bootstrap", "ep_finalize", - "ep_scope", "ep_dispatch", "ep_combine", "symm_mem_alloc", @@ -154,12 +152,12 @@ def ep_bootstrap( def ep_finalize() -> None: - """Explicit EP teardown; optional and idempotent. 
An atexit handler covers - normal shutdown; call this only before ``dist.destroy_process_group()``, - since the borrowed NCCL comm is invalid once the PG is destroyed. + """Optional explicit EP teardown; idempotent. - Propagates errors from the C++ teardown; use ``_atexit_finalize`` for the - best-effort interpreter-shutdown path. + An atexit handler covers normal interpreter shutdown, so most users do not + need to call this. Call it explicitly only before + ``dist.destroy_process_group()``, since the borrowed NCCL comm becomes + invalid once the PG is destroyed. """ global _BOOTSTRAPPED if not _BOOTSTRAPPED: @@ -170,38 +168,6 @@ def ep_finalize() -> None: _BOOTSTRAPPED = False -@contextmanager -def ep_scope( - ep_group: dist.ProcessGroup, - num_experts: int, - max_tokens_per_rank: int, - recv_capacity_per_rank: int, - hidden_dim: int, - max_num_sms: int = 0, - zero_copy: bool = False, - max_token_dtype: torch.dtype = torch.bfloat16, -) -> Iterator[None]: - """Context manager: ``ep_bootstrap`` on enter, ``ep_finalize`` on exit. - - Use when you tear down the EP process group yourself, so the borrowed NCCL - comm is released before ``dist.destroy_process_group()``. 
- """ - ep_bootstrap( - ep_group, - num_experts=num_experts, - max_tokens_per_rank=max_tokens_per_rank, - recv_capacity_per_rank=recv_capacity_per_rank, - hidden_dim=hidden_dim, - max_num_sms=max_num_sms, - zero_copy=zero_copy, - max_token_dtype=max_token_dtype, - ) - try: - yield - finally: - ep_finalize() - - # Buffer From 87a1ba32973e5890489d904f3b9d21f695b80d27 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 17:57:59 -0700 Subject: [PATCH 06/39] EP PyTorch: restrict payload dtype to bf16; refresh stale ep.cpp comments; drop unused ep_group kwargs Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/bench/ep_bench.py | 1 - tests/pytorch/distributed/run_ep.py | 1 - .../pytorch/csrc/extensions/ep.cpp | 11 ++++------ transformer_engine/pytorch/ep.py | 21 +++++++------------ 4 files changed, 12 insertions(+), 22 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index 973703b710..82a5429f9d 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -182,7 +182,6 @@ def main(): recv_capacity_per_rank=recv_pr, hidden_dim=H, num_local_experts=num_local_experts, - ep_group=ep_group, ) tokens = tokens_hbm diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 3b343466b9..0e29ca7eae 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -118,7 +118,6 @@ def _make_buffer(self, alignment=0, top_k=TOP_K): hidden_dim=HIDDEN_DIM, num_local_experts=NUM_LOCAL_EXPERTS, alignment=alignment, - ep_group=self.ep_group, ) def _make_raw_recv(self, dtype=torch.bfloat16): diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index 018b5addc0..46484b8b67 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -46,15 +46,12 @@ std::string g_ep_group_name; // NOLINT(runtime/string) 
// True while the EP backend holds a borrowed reference to torch's NCCL comm. bool g_ep_initialized = false; -// Zero-copy IO toggle. Placeholder for the symm-mem fast path; per-step ops -// always pass kNoWindow in this release regardless of the toggle. Wired up -// so the switch is a one-line change when the backend lands the fast path. -// Atomic so the Python-side toggle is safe against concurrent -// ep_dispatch/combine (which release the GIL). +// Zero-copy IO toggle captured at ep_initialize. Atomic so the Python-side +// toggle is safe against concurrent ep_dispatch/combine (which release the GIL). std::atomic g_zero_copy_enabled{false}; -// Per-step ops always pass kNoWindow in this release; the symm-mem IO path is -// planned for a near-future release. +// Sentinel returned by maybe_make_window when zero-copy is off or the tensor +// is not symm-mem-backed; the backend treats it as "no window, use staged copy". constexpr NVTECommWindow kNoWindow = {nullptr, 0}; // Resolve ``t`` to an NCCL symm-mem window for the zero-copy one-sided path. diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index b3612401bd..34cbe93a29 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -627,17 +627,12 @@ def backward(ctx, g_result): # type: ignore[override] # Public high-level wrappers -# FP8 dispatch is not yet supported by the common backend. -_FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2) - - -def _reject_fp8(*tensors: torch.Tensor) -> None: - for t in tensors: - if t.dtype in _FP8_DTYPES: - raise NotImplementedError( - f"FP8 dispatch/combine not supported (got dtype={t.dtype}); " - "quantize outside the EP boundary." - ) +# NCCL EP currently only supports bfloat16 payload tensors. +def _require_bf16(name: str, t: torch.Tensor) -> None: + if t.dtype is not torch.bfloat16: + raise NotImplementedError( + f"NCCL EP currently supports only bfloat16 payloads; got {name}.dtype={t.dtype}." 
+ ) def ep_dispatch( @@ -652,7 +647,7 @@ def ep_dispatch( buffer's persistent slots — consume them before the next ep_dispatch on the same buffer or they get overwritten. token_counts is non-differentiable. """ - _reject_fp8(tokens, buffer.recv_tokens) + _require_bf16("tokens", tokens) return _EpDispatch.apply( buffer.handle_mem, buffer.top_k, @@ -680,7 +675,7 @@ def ep_combine( Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. """ - _reject_fp8(expert_out, buffer.combine_in) + _require_bf16("expert_out", expert_out) if num_local_tokens is None: num_local_tokens = buffer.max_tokens_per_rank return _EpCombine.apply( From 5444954f9069edbe520b9c2082a81e37820ea441 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 18:26:42 -0700 Subject: [PATCH 07/39] EP PyTorch: clear pylint warnings in ep.py (broad-except suppression, _-prefixed stub args, autograd docstrings) Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 34cbe93a29..d5368e6e47 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -96,7 +96,7 @@ def _atexit_finalize() -> None: if _BOOTSTRAPPED: try: tex.ep_finalize() - except Exception: + except Exception: # pylint: disable=broad-exception-caught import traceback traceback.print_exc() @@ -378,7 +378,7 @@ def _prepare_op( @_prepare_op.register_fake -def _(*args, **kw): +def _(*_args, **_kw): return None @@ -399,7 +399,7 @@ def _dispatch_op( @_dispatch_op.register_fake -def _(*args, **kw): +def _(*_args, **_kw): return None @@ -417,7 +417,7 @@ def _combine_op( @_combine_op.register_fake -def _(*args, **kw): +def _(*_args, **_kw): return None @@ -437,7 +437,7 @@ def _dispatch_bwd_op( @_dispatch_bwd_op.register_fake -def _(*args, **kw): +def _(*_args, **_kw): return None @@ -455,7 +455,7 @@ def 
_combine_bwd_op( @_combine_bwd_op.register_fake -def _(*args, **kw): +def _(*_args, **_kw): return None @@ -516,6 +516,7 @@ def forward( # type: ignore[override] tokens: torch.Tensor, topk_weights: torch.Tensor, ): + """Prepare + dispatch; stashes buffer slots + shapes for the bwd pass.""" torch.ops.transformer_engine_ep.prepare( handle_mem, top_k, topk_idx, token_counts, alignment ) @@ -545,6 +546,7 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] + """Dispatch backward into the persistent grad_tokens/grad_topk_weights slots.""" device = ctx.handle_mem.device if g_recv_tokens is None: g_recv_tokens = torch.zeros( @@ -599,6 +601,7 @@ def forward( # type: ignore[override] hidden_dim: int, expert_out: torch.Tensor, ): + """Combine expert outputs; reuses combine_in as the grad slot for bwd.""" device = expert_out.device # Stage expert_out into the persistent combine_in slot (symm-mem-backed # in the zero-copy path); its storage is reused as grad_combine_in in bwd. 
@@ -611,6 +614,7 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_result): # type: ignore[override] + """Combine backward into combine_in storage; returned as grad of expert_out.""" grad_combine_in = ctx.combine_in if not g_result.is_contiguous(): g_result = g_result.contiguous() From d5a6c09992cbc0f63146feeeb78969fc7d3fb66e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jun 2026 01:29:03 +0000 Subject: [PATCH 08/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/ep/bench/ep_bench.py | 4 +--- examples/pytorch/ep/ep_moe.py | 9 ++++----- transformer_engine/pytorch/csrc/extensions/ep.cpp | 4 ++-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index 82a5429f9d..81f5b83883 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -208,9 +208,7 @@ def main(): eo_p = recv_tokens.detach().clone().requires_grad_(True) # Stand-in callables; the cuda-graph branch below swaps in graphed versions. 
- fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w)[ # noqa: E731 - 0 - ] + fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w)[0] # noqa: E731 fwd_bwd_combine_fn = lambda eo: ep_combine(buffer, eo) # noqa: E731 def _dispatch_raw(): diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index 24e271f2f8..f72912301b 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -123,7 +123,9 @@ def main(): hidden_dim=args.hidden, ) try: - _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device) + _run_layer( + args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device + ) finally: ep_finalize() dist.destroy_process_group() @@ -196,10 +198,7 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, torch.cuda.synchronize() dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters if rank == 0: - print( - f"[ep_moe --benchmark] HBM: {dt_ms:.3f} ms/iter " - f"(iters={args.benchmark_iters})" - ) + print(f"[ep_moe --benchmark] HBM: {dt_ms:.3f} ms/iter (iters={args.benchmark_iters})") if args.check: # All-gather inputs/outputs/grads for a global reference comparison. 
diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index 46484b8b67..8624605edd 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -102,8 +102,8 @@ bool ep_get_zero_copy() { return g_zero_copy_enabled.load(std::memory_order_rela void ep_initialize(uintptr_t comm_ptr, const std::string& group_name, int64_t num_experts, int64_t max_tokens_per_rank, int64_t max_recv_tokens_per_rank, - int64_t hidden_dim, int64_t max_num_sms, - pybind11::object max_token_dtype, bool zero_copy) { + int64_t hidden_dim, int64_t max_num_sms, pybind11::object max_token_dtype, + bool zero_copy) { NVTE_CHECK(!group_name.empty(), "group_name must be non-empty (used for symm-mem lookup)"); NVTE_CHECK(comm_ptr != 0, "comm_ptr must be non-null (torch NCCL host comm pointer)"); NVTE_CHECK(!g_ep_initialized, "ep_initialize called twice without ep_finalize"); From 2e9a89eb24ae838f1b4da046ffd94b3a1f7a083b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 20:28:12 -0700 Subject: [PATCH 09/39] EP PyTorch: skip combine_in staging copy in non-zero-copy mode Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index d5368e6e47..64a9a69df7 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -205,6 +205,7 @@ class EpBuffer: "token_counts", "grad_tokens", "grad_topk_weights", + "zero_copy", ) def __init__( @@ -239,6 +240,7 @@ def __init__( recv_shape = (self.recv_capacity_per_rank, self.hidden_dim) send_shape = (self.max_tokens_per_rank, self.hidden_dim) zero_copy = bool(tex.ep_get_zero_copy()) + self.zero_copy = zero_copy if zero_copy: if ep_group is None: raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") @@ -314,6 +316,7 
@@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc inst.num_local_experts = int(num_local_experts) inst.payload_dtype = payload_dtype inst.device = device + inst.zero_copy = bool(tex.ep_get_zero_copy()) size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) @@ -599,17 +602,23 @@ def forward( # type: ignore[override] combine_in: torch.Tensor, num_local_tokens: int, hidden_dim: int, + zero_copy: bool, expert_out: torch.Tensor, ): - """Combine expert outputs; reuses combine_in as the grad slot for bwd.""" + """Combine expert outputs; combine_in storage is reused as the grad slot in bwd.""" device = expert_out.device - # Stage expert_out into the persistent combine_in slot (symm-mem-backed - # in the zero-copy path); its storage is reused as grad_combine_in in bwd. - combine_in.copy_(expert_out) + # Zero-copy mode: peers read from combine_in via symm-mem, so it must + # hold the expert outputs; stage expert_out into it unless aliased. + # Otherwise the kernel reads expert_out directly, no copy needed. 
+ if zero_copy and combine_in.data_ptr() != expert_out.data_ptr(): + combine_in.copy_(expert_out) + kernel_in = combine_in + else: + kernel_in = expert_out result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) - torch.ops.transformer_engine_ep.combine(handle_mem, combine_in, result) + torch.ops.transformer_engine_ep.combine(handle_mem, kernel_in, result) ctx.handle_mem = handle_mem - ctx.combine_in = combine_in # reused as grad_combine_in in bwd + ctx.combine_in = combine_in # used as grad_combine_in in bwd return result @staticmethod @@ -624,6 +633,7 @@ def backward(ctx, g_result): # type: ignore[override] None, # combine_in None, # num_local_tokens None, # hidden_dim + None, # zero_copy grad_combine_in, ) @@ -687,5 +697,6 @@ def ep_combine( buffer.combine_in, num_local_tokens, buffer.hidden_dim, + buffer.zero_copy, expert_out, ) From e9cdaed6af0aa7e5a2893ed626805871b6ae7b5f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 20:46:55 -0700 Subject: [PATCH 10/39] EP PyTorch: gate symm-mem slot allocation on zero-copy; require expert_out to alias combine_in in zero-copy Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 151 ++++++++++++++++++++----------- 1 file changed, 99 insertions(+), 52 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 64a9a69df7..264d4e51f9 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -241,6 +241,8 @@ def __init__( send_shape = (self.max_tokens_per_rank, self.hidden_dim) zero_copy = bool(tex.ep_get_zero_copy()) self.zero_copy = zero_copy + # Cross-rank slots are pre-allocated as symm-mem only when zero-copy + # is on; non-zero-copy mode allocates plain HBM per call in the ops. 
if zero_copy: if ep_group is None: raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") @@ -251,12 +253,10 @@ def __init__( ) self.grad_tokens = symm_mem_alloc(send_shape, payload_dtype, ep_group, device=device) else: - self.recv_tokens = torch.empty(recv_shape, dtype=payload_dtype, device=device) - self.combine_in = torch.empty(recv_shape, dtype=payload_dtype, device=device) - self.recv_topk_weights = torch.empty( - self.recv_capacity_per_rank, dtype=torch.float32, device=device - ) - self.grad_tokens = torch.empty(send_shape, dtype=payload_dtype, device=device) + self.recv_tokens = None + self.combine_in = None + self.recv_topk_weights = None + self.grad_tokens = None # Per-rank scratch; never cross-rank, plain HBM regardless of mode. self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) self.grad_topk_weights = torch.empty( @@ -272,8 +272,8 @@ def from_external( hidden_dim: int, num_local_experts: int, *, - recv_tokens: torch.Tensor, - combine_in: torch.Tensor, + recv_tokens: Optional[torch.Tensor] = None, + combine_in: Optional[torch.Tensor] = None, recv_topk_weights: Optional[torch.Tensor] = None, grad_tokens: Optional[torch.Tensor] = None, token_counts: Optional[torch.Tensor] = None, @@ -284,10 +284,10 @@ def from_external( ) -> "EpBuffer": """Construct from caller-allocated payload buffers. - Useful for sharing a pre-allocated pool across layers/microbatches and - for plugging in symm-mem-backed tensors. Caller-supplied slots are - validated against the expected shape and dtype; ``None`` slots get a - fresh HBM allocation. ``handle_mem`` is always allocated fresh. + In zero-copy mode recv_tokens, combine_in, recv_topk_weights, and + grad_tokens must be supplied and symm-mem-backed; in non-zero-copy + mode they default to None and the ops allocate per call. handle_mem + is always allocated fresh. 
""" if device is None: device = torch.device("cuda", torch.cuda.current_device()) @@ -321,18 +321,28 @@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) - inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, payload_dtype) - inst.combine_in = _check(combine_in, "combine_in", recv_shape, payload_dtype) - inst.recv_topk_weights = ( - _check(recv_topk_weights, "recv_topk_weights", recv_w_shape, torch.float32) - if recv_topk_weights is not None - else torch.empty(recv_w_shape, dtype=torch.float32, device=device) - ) - inst.grad_tokens = ( - _check(grad_tokens, "grad_tokens", send_shape, payload_dtype) - if grad_tokens is not None - else torch.empty(send_shape, dtype=payload_dtype, device=device) - ) + if inst.zero_copy: + if ( + recv_tokens is None + or combine_in is None + or recv_topk_weights is None + or grad_tokens is None + ): + raise ValueError( + "EpBuffer.from_external requires recv_tokens, combine_in, recv_topk_weights, " + "and grad_tokens (all symm-mem-backed) when zero-copy is enabled." 
+ ) + inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, payload_dtype) + inst.combine_in = _check(combine_in, "combine_in", recv_shape, payload_dtype) + inst.recv_topk_weights = _check( + recv_topk_weights, "recv_topk_weights", recv_w_shape, torch.float32 + ) + inst.grad_tokens = _check(grad_tokens, "grad_tokens", send_shape, payload_dtype) + else: + inst.recv_tokens = None + inst.combine_in = None + inst.recv_topk_weights = None + inst.grad_tokens = None inst.token_counts = ( _check(token_counts, "token_counts", counts_shape, torch.int32) if token_counts is not None @@ -357,7 +367,8 @@ def record_stream(self, stream: torch.cuda.Stream) -> None: self.token_counts, self.grad_topk_weights, ): - t.record_stream(stream) + if t is not None: + t.record_stream(stream) # torch.library custom ops (so they don't graph-break under torch.compile) @@ -549,7 +560,8 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] - """Dispatch backward into the persistent grad_tokens/grad_topk_weights slots.""" + """Dispatch backward; grad_tokens uses the buffer's symm-mem slot in + zero-copy mode or a fresh HBM tensor otherwise.""" device = ctx.handle_mem.device if g_recv_tokens is None: g_recv_tokens = torch.zeros( @@ -561,8 +573,13 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: g_recv_tokens = g_recv_tokens.contiguous() if not g_recv_topk_weights.is_contiguous(): g_recv_topk_weights = g_recv_topk_weights.contiguous() - # Narrow the persistent slots to this call's flattened leading dim. - grad_tokens = ctx.grad_tokens_buf.narrow(0, 0, ctx.tokens_T_flat) + if ctx.grad_tokens_buf is not None: + # Zero-copy: narrow the persistent symm-mem slot to this call's leading dim. 
+ grad_tokens = ctx.grad_tokens_buf.narrow(0, 0, ctx.tokens_T_flat) + else: + grad_tokens = torch.empty( + ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device + ) grad_topk_weights = ctx.grad_topk_weights_buf.narrow(0, 0, ctx.topk_T_flat) torch.ops.transformer_engine_ep.dispatch_bwd( ctx.handle_mem, @@ -590,43 +607,59 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: class _EpCombine(torch.autograd.Function): - """Autograd-aware combine. combine_in is reused as grad_combine_in in bwd; - fwd/bwd share handle_mem so don't re-run ep_prepare between them. Caller - must pre-apply the topk weighting to expert_out. + """Autograd-aware combine. Zero-copy mode requires expert_out to alias + buffer.combine_in (no implicit staging), and that storage is reused as + the grad slot in bwd. Non-zero-copy mode reads expert_out directly and + allocates the bwd grad slot fresh. Caller pre-applies topk weighting. """ @staticmethod def forward( # type: ignore[override] ctx, handle_mem: torch.Tensor, - combine_in: torch.Tensor, + combine_in: Optional[torch.Tensor], num_local_tokens: int, hidden_dim: int, zero_copy: bool, expert_out: torch.Tensor, ): - """Combine expert outputs; combine_in storage is reused as the grad slot in bwd.""" + """Combine fwd; zero-copy requires expert_out to alias combine_in.""" + if zero_copy: + if combine_in is None: + raise RuntimeError( + "ep_combine: zero-copy mode requires buffer.combine_in to be allocated." + ) + if combine_in.data_ptr() != expert_out.data_ptr(): + raise RuntimeError( + "ep_combine: zero-copy mode requires expert_out to alias " + "buffer.combine_in (write expert outputs directly into that slot; " + "no implicit copy)." + ) device = expert_out.device - # Zero-copy mode: peers read from combine_in via symm-mem, so it must - # hold the expert outputs; stage expert_out into it unless aliased. - # Otherwise the kernel reads expert_out directly, no copy needed. 
- if zero_copy and combine_in.data_ptr() != expert_out.data_ptr(): - combine_in.copy_(expert_out) - kernel_in = combine_in - else: - kernel_in = expert_out result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) - torch.ops.transformer_engine_ep.combine(handle_mem, kernel_in, result) + torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result) ctx.handle_mem = handle_mem - ctx.combine_in = combine_in # used as grad_combine_in in bwd + ctx.combine_in = combine_in # None in non-zero-copy; reused as grad slot otherwise + ctx.zero_copy = zero_copy + ctx.recv_capacity = expert_out.shape[0] + ctx.hidden_dim = expert_out.shape[-1] + ctx.expert_out_dtype = expert_out.dtype return result @staticmethod def backward(ctx, g_result): # type: ignore[override] - """Combine backward into combine_in storage; returned as grad of expert_out.""" - grad_combine_in = ctx.combine_in + """Combine bwd; writes into combine_in in zero-copy or a fresh slot otherwise.""" if not g_result.is_contiguous(): g_result = g_result.contiguous() + if ctx.zero_copy: + grad_combine_in = ctx.combine_in + else: + grad_combine_in = torch.empty( + ctx.recv_capacity, + ctx.hidden_dim, + dtype=ctx.expert_out_dtype, + device=ctx.handle_mem.device, + ) torch.ops.transformer_engine_ep.combine_bwd(ctx.handle_mem, g_result, grad_combine_in) return ( None, # handle_mem @@ -657,17 +690,30 @@ def ep_dispatch( ): """Run prepare + dispatch with autograd. topk_idx must be int64. - Returns (recv_tokens, recv_topk_weights, token_counts); views into the - buffer's persistent slots — consume them before the next ep_dispatch on - the same buffer or they get overwritten. token_counts is non-differentiable. + Returns (recv_tokens, recv_topk_weights, token_counts). In zero-copy mode + recv_tokens / recv_topk_weights alias the buffer's persistent symm-mem + slots; otherwise they are freshly allocated. token_counts is non-diff. 
""" _require_bf16("tokens", tokens) + if buffer.zero_copy: + recv_tokens = buffer.recv_tokens + recv_topk_weights = buffer.recv_topk_weights + else: + recv_tokens = torch.empty( + buffer.recv_capacity_per_rank, + buffer.hidden_dim, + dtype=buffer.payload_dtype, + device=buffer.device, + ) + recv_topk_weights = torch.empty( + buffer.recv_capacity_per_rank, dtype=torch.float32, device=buffer.device + ) return _EpDispatch.apply( buffer.handle_mem, buffer.top_k, buffer.alignment, - buffer.recv_tokens, - buffer.recv_topk_weights, + recv_tokens, + recv_topk_weights, buffer.token_counts, buffer.grad_tokens, buffer.grad_topk_weights, @@ -683,8 +729,9 @@ def ep_combine( *, num_local_tokens: Optional[int] = None, ): - """Combine expert outputs back to the source rank, with autograd. The - caller must pre-apply the topk weighting to expert_out. + """Combine expert outputs back to the source rank, with autograd. Caller + pre-applies topk weighting. Zero-copy mode requires expert_out to alias + buffer.combine_in (write expert outputs into that slot directly). Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. From 5f29d90c4ac2e358eedb3f5f03a136f6157865a3 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 21:39:55 -0700 Subject: [PATCH 11/39] EP PyTorch: rename buffer slots to dispatch_/combine_symm_buf; drop grad_tokens/grad_topk_weights; alias-check bwd grads in zero-copy Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 211 ++++++++++++++++--------------- 1 file changed, 108 insertions(+), 103 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 264d4e51f9..ec99e1250a 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -198,13 +198,15 @@ class EpBuffer: "num_local_experts", "payload_dtype", "device", - # payload slots - "recv_tokens", - "combine_in", - "recv_topk_weights", + # Symm-mem slots (zero-copy only). 
Each is reused across fwd and bwd: + # dispatch_symm_buf: fwd out (recv_tokens) / bwd in (g_recv_tokens) + # dispatch_w_symm_buf: fwd out (recv_topk_w) / bwd in (g_recv_topk_w) + # combine_symm_buf: fwd in (expert_out) / bwd out (g_expert_out) + "dispatch_symm_buf", + "dispatch_w_symm_buf", + "combine_symm_buf", + # Per-rank scratch (always HBM). "token_counts", - "grad_tokens", - "grad_topk_weights", "zero_copy", ) @@ -238,30 +240,26 @@ def __init__( self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) recv_shape = (self.recv_capacity_per_rank, self.hidden_dim) - send_shape = (self.max_tokens_per_rank, self.hidden_dim) zero_copy = bool(tex.ep_get_zero_copy()) self.zero_copy = zero_copy - # Cross-rank slots are pre-allocated as symm-mem only when zero-copy - # is on; non-zero-copy mode allocates plain HBM per call in the ops. if zero_copy: if ep_group is None: raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") - self.recv_tokens = symm_mem_alloc(recv_shape, payload_dtype, ep_group, device=device) - self.combine_in = symm_mem_alloc(recv_shape, payload_dtype, ep_group, device=device) - self.recv_topk_weights = symm_mem_alloc( + self.dispatch_symm_buf = symm_mem_alloc( + recv_shape, payload_dtype, ep_group, device=device + ) + self.dispatch_w_symm_buf = symm_mem_alloc( (self.recv_capacity_per_rank,), torch.float32, ep_group, device=device ) - self.grad_tokens = symm_mem_alloc(send_shape, payload_dtype, ep_group, device=device) + self.combine_symm_buf = symm_mem_alloc( + recv_shape, payload_dtype, ep_group, device=device + ) else: - self.recv_tokens = None - self.combine_in = None - self.recv_topk_weights = None - self.grad_tokens = None - # Per-rank scratch; never cross-rank, plain HBM regardless of mode. + self.dispatch_symm_buf = None + self.dispatch_w_symm_buf = None + self.combine_symm_buf = None + # token_counts is local-only routing scratch; always plain HBM. 
self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) - self.grad_topk_weights = torch.empty( - (self.max_tokens_per_rank, self.top_k), dtype=torch.float32, device=device - ) @classmethod def from_external( @@ -272,22 +270,20 @@ def from_external( hidden_dim: int, num_local_experts: int, *, - recv_tokens: Optional[torch.Tensor] = None, - combine_in: Optional[torch.Tensor] = None, - recv_topk_weights: Optional[torch.Tensor] = None, - grad_tokens: Optional[torch.Tensor] = None, + dispatch_symm_buf: Optional[torch.Tensor] = None, + dispatch_w_symm_buf: Optional[torch.Tensor] = None, + combine_symm_buf: Optional[torch.Tensor] = None, token_counts: Optional[torch.Tensor] = None, - grad_topk_weights: Optional[torch.Tensor] = None, alignment: int = 0, payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, ) -> "EpBuffer": - """Construct from caller-allocated payload buffers. + """Construct from caller-allocated buffers. - In zero-copy mode recv_tokens, combine_in, recv_topk_weights, and - grad_tokens must be supplied and symm-mem-backed; in non-zero-copy - mode they default to None and the ops allocate per call. handle_mem - is always allocated fresh. + In zero-copy mode dispatch_symm_buf, dispatch_w_symm_buf, and + combine_symm_buf must all be supplied and symm-mem-backed; in + non-zero-copy mode they must all be None (ops allocate per call). + handle_mem is always allocated fresh. 
""" if device is None: device = torch.device("cuda", torch.cuda.current_device()) @@ -295,8 +291,6 @@ def from_external( if alignment > 1 and (alignment & (alignment - 1)) != 0: raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") recv_shape = (recv_capacity_per_rank, hidden_dim) - send_shape = (max_tokens_per_rank, hidden_dim) - topk_shape = (max_tokens_per_rank, top_k) recv_w_shape = (recv_capacity_per_rank,) counts_shape = (num_local_experts,) @@ -323,36 +317,41 @@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc if inst.zero_copy: if ( - recv_tokens is None - or combine_in is None - or recv_topk_weights is None - or grad_tokens is None + dispatch_symm_buf is None + or dispatch_w_symm_buf is None + or combine_symm_buf is None ): raise ValueError( - "EpBuffer.from_external requires recv_tokens, combine_in, recv_topk_weights, " - "and grad_tokens (all symm-mem-backed) when zero-copy is enabled." + "EpBuffer.from_external: zero-copy mode requires dispatch_symm_buf, " + "dispatch_w_symm_buf, and combine_symm_buf (all symm-mem-backed)." 
) - inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, payload_dtype) - inst.combine_in = _check(combine_in, "combine_in", recv_shape, payload_dtype) - inst.recv_topk_weights = _check( - recv_topk_weights, "recv_topk_weights", recv_w_shape, torch.float32 + inst.dispatch_symm_buf = _check( + dispatch_symm_buf, "dispatch_symm_buf", recv_shape, payload_dtype + ) + inst.dispatch_w_symm_buf = _check( + dispatch_w_symm_buf, "dispatch_w_symm_buf", recv_w_shape, torch.float32 + ) + inst.combine_symm_buf = _check( + combine_symm_buf, "combine_symm_buf", recv_shape, payload_dtype ) - inst.grad_tokens = _check(grad_tokens, "grad_tokens", send_shape, payload_dtype) else: - inst.recv_tokens = None - inst.combine_in = None - inst.recv_topk_weights = None - inst.grad_tokens = None + if ( + dispatch_symm_buf is not None + or dispatch_w_symm_buf is not None + or combine_symm_buf is not None + ): + raise ValueError( + "EpBuffer.from_external: dispatch_symm_buf / dispatch_w_symm_buf / " + "combine_symm_buf are only used in zero-copy mode." 
+ ) + inst.dispatch_symm_buf = None + inst.dispatch_w_symm_buf = None + inst.combine_symm_buf = None inst.token_counts = ( _check(token_counts, "token_counts", counts_shape, torch.int32) if token_counts is not None else torch.empty(counts_shape, dtype=torch.int32, device=device) ) - inst.grad_topk_weights = ( - _check(grad_topk_weights, "grad_topk_weights", topk_shape, torch.float32) - if grad_topk_weights is not None - else torch.empty(topk_shape, dtype=torch.float32, device=device) - ) return inst def record_stream(self, stream: torch.cuda.Stream) -> None: @@ -360,12 +359,10 @@ def record_stream(self, stream: torch.cuda.Stream) -> None: defers reclaim until stream has caught up.""" for t in ( self.handle_mem, - self.recv_tokens, - self.combine_in, - self.recv_topk_weights, - self.grad_tokens, + self.dispatch_symm_buf, + self.dispatch_w_symm_buf, + self.combine_symm_buf, self.token_counts, - self.grad_topk_weights, ): if t is not None: t.record_stream(stream) @@ -510,9 +507,10 @@ def _ep_combine_raw(buffer: "EpBuffer", expert_out: torch.Tensor, result: torch. class _EpDispatch(torch.autograd.Function): - """Autograd-aware prepare + dispatch. Fwd/bwd share handle_mem and the - EpBuffer slots; do not re-run ep_prepare between them and do not share - EpBuffer with another in-flight call (see EpBuffer). + """Autograd-aware prepare + dispatch. Fwd produces recv_tokens (alias of + dispatch_symm_buf in zero-copy, fresh otherwise). Zero-copy bwd requires + the incoming grads to alias dispatch_symm_buf / dispatch_w_symm_buf + (no implicit staging). Fwd/bwd share handle_mem; do not re-run ep_prepare. 
""" @staticmethod @@ -521,16 +519,15 @@ def forward( # type: ignore[override] handle_mem: torch.Tensor, top_k: int, alignment: int, + zero_copy: bool, recv_tokens: torch.Tensor, recv_topk_weights: torch.Tensor, token_counts: torch.Tensor, - grad_tokens_buf: torch.Tensor, - grad_topk_weights_buf: torch.Tensor, topk_idx: torch.Tensor, tokens: torch.Tensor, topk_weights: torch.Tensor, ): - """Prepare + dispatch; stashes buffer slots + shapes for the bwd pass.""" + """Prepare + dispatch; saves shapes for the bwd pass.""" torch.ops.transformer_engine_ep.prepare( handle_mem, top_k, topk_idx, token_counts, alignment ) @@ -543,14 +540,18 @@ def forward( # type: ignore[override] recv_topk_weights, ) ctx.handle_mem = handle_mem - ctx.grad_tokens_buf = grad_tokens_buf - ctx.grad_topk_weights_buf = grad_topk_weights_buf + ctx.zero_copy = zero_copy + # Stash the symm-mem slot pointers so bwd can enforce alias of the + # grad inputs. In non-zero-copy mode the slots are fresh per call; + # no enforcement is meaningful, so leave the pointers as None. 
+ ctx.dispatch_symm_ptr = recv_tokens.data_ptr() if zero_copy else None + ctx.dispatch_w_symm_ptr = recv_topk_weights.data_ptr() if zero_copy else None ctx.tokens_shape = tokens.shape ctx.tokens_dtype = tokens.dtype ctx.topk_weights_shape = topk_weights.shape - ctx.topk_weights_dtype = topk_weights.dtype ctx.tokens_T_flat = tokens.numel() // tokens.shape[-1] ctx.topk_T_flat = topk_weights.numel() // topk_weights.shape[-1] + ctx.top_k = topk_weights.shape[-1] ctx.recv_capacity = recv_tokens.shape[0] ctx.hidden_dim = tokens.shape[-1] ctx.mark_non_differentiable(token_counts) @@ -560,8 +561,7 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] - """Dispatch backward; grad_tokens uses the buffer's symm-mem slot in - zero-copy mode or a fresh HBM tensor otherwise.""" + """Dispatch bwd; in zero-copy the grad inputs must alias the symm-mem slots.""" device = ctx.handle_mem.device if g_recv_tokens is None: g_recv_tokens = torch.zeros( @@ -573,14 +573,24 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: g_recv_tokens = g_recv_tokens.contiguous() if not g_recv_topk_weights.is_contiguous(): g_recv_topk_weights = g_recv_topk_weights.contiguous() - if ctx.grad_tokens_buf is not None: - # Zero-copy: narrow the persistent symm-mem slot to this call's leading dim. - grad_tokens = ctx.grad_tokens_buf.narrow(0, 0, ctx.tokens_T_flat) - else: - grad_tokens = torch.empty( - ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device - ) - grad_topk_weights = ctx.grad_topk_weights_buf.narrow(0, 0, ctx.topk_T_flat) + if ctx.zero_copy: + if g_recv_tokens.data_ptr() != ctx.dispatch_symm_ptr: + raise RuntimeError( + "ep_dispatch bwd: zero-copy mode requires g_recv_tokens to alias " + "buffer.dispatch_symm_buf (write MLP_bwd's grad into that slot; " + "no implicit copy)." 
+ ) + if g_recv_topk_weights.data_ptr() != ctx.dispatch_w_symm_ptr: + raise RuntimeError( + "ep_dispatch bwd: zero-copy mode requires g_recv_topk_weights to alias " + "buffer.dispatch_w_symm_buf (no implicit copy)." + ) + grad_tokens = torch.empty( + ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device + ) + grad_topk_weights = torch.empty( + ctx.topk_T_flat, ctx.top_k, dtype=torch.float32, device=device + ) torch.ops.transformer_engine_ep.dispatch_bwd( ctx.handle_mem, g_recv_tokens, @@ -588,58 +598,54 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: grad_tokens, grad_topk_weights, ) - # Reshape back to the original input shape so autograd's grad slot matches. - grad_tokens_out = grad_tokens.view(ctx.tokens_shape) - grad_topk_weights_out = grad_topk_weights.view(ctx.topk_weights_shape) return ( None, # handle_mem None, # top_k None, # alignment + None, # zero_copy None, # recv_tokens None, # recv_topk_weights None, # token_counts - None, # grad_tokens_buf - None, # grad_topk_weights_buf None, # topk_idx - grad_tokens_out, - grad_topk_weights_out, + grad_tokens.view(ctx.tokens_shape), + grad_topk_weights.view(ctx.topk_weights_shape), ) class _EpCombine(torch.autograd.Function): """Autograd-aware combine. Zero-copy mode requires expert_out to alias - buffer.combine_in (no implicit staging), and that storage is reused as - the grad slot in bwd. Non-zero-copy mode reads expert_out directly and - allocates the bwd grad slot fresh. Caller pre-applies topk weighting. + combine_symm_buf (no implicit staging), and that storage is reused as the + bwd grad slot. Non-zero-copy mode reads expert_out directly and allocates + the bwd grad slot fresh. Caller pre-applies topk weighting. 
""" @staticmethod def forward( # type: ignore[override] ctx, handle_mem: torch.Tensor, - combine_in: Optional[torch.Tensor], + combine_symm_buf: Optional[torch.Tensor], num_local_tokens: int, hidden_dim: int, zero_copy: bool, expert_out: torch.Tensor, ): - """Combine fwd; zero-copy requires expert_out to alias combine_in.""" + """Combine fwd; zero-copy requires expert_out to alias combine_symm_buf.""" if zero_copy: - if combine_in is None: + if combine_symm_buf is None: raise RuntimeError( - "ep_combine: zero-copy mode requires buffer.combine_in to be allocated." + "ep_combine: zero-copy mode requires buffer.combine_symm_buf to be allocated." ) - if combine_in.data_ptr() != expert_out.data_ptr(): + if combine_symm_buf.data_ptr() != expert_out.data_ptr(): raise RuntimeError( "ep_combine: zero-copy mode requires expert_out to alias " - "buffer.combine_in (write expert outputs directly into that slot; " + "buffer.combine_symm_buf (write expert outputs directly into that slot; " "no implicit copy)." 
) device = expert_out.device result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result) ctx.handle_mem = handle_mem - ctx.combine_in = combine_in # None in non-zero-copy; reused as grad slot otherwise + ctx.combine_symm_buf = combine_symm_buf # reused as grad slot in zero-copy ctx.zero_copy = zero_copy ctx.recv_capacity = expert_out.shape[0] ctx.hidden_dim = expert_out.shape[-1] @@ -648,11 +654,11 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_result): # type: ignore[override] - """Combine bwd; writes into combine_in in zero-copy or a fresh slot otherwise.""" + """Combine bwd; writes into combine_symm_buf in zero-copy or a fresh slot otherwise.""" if not g_result.is_contiguous(): g_result = g_result.contiguous() if ctx.zero_copy: - grad_combine_in = ctx.combine_in + grad_combine_in = ctx.combine_symm_buf else: grad_combine_in = torch.empty( ctx.recv_capacity, @@ -663,7 +669,7 @@ def backward(ctx, g_result): # type: ignore[override] torch.ops.transformer_engine_ep.combine_bwd(ctx.handle_mem, g_result, grad_combine_in) return ( None, # handle_mem - None, # combine_in + None, # combine_symm_buf None, # num_local_tokens None, # hidden_dim None, # zero_copy @@ -696,8 +702,8 @@ def ep_dispatch( """ _require_bf16("tokens", tokens) if buffer.zero_copy: - recv_tokens = buffer.recv_tokens - recv_topk_weights = buffer.recv_topk_weights + recv_tokens = buffer.dispatch_symm_buf + recv_topk_weights = buffer.dispatch_w_symm_buf else: recv_tokens = torch.empty( buffer.recv_capacity_per_rank, @@ -712,11 +718,10 @@ def ep_dispatch( buffer.handle_mem, buffer.top_k, buffer.alignment, + buffer.zero_copy, recv_tokens, recv_topk_weights, buffer.token_counts, - buffer.grad_tokens, - buffer.grad_topk_weights, topk_idx, tokens, topk_weights, @@ -731,7 +736,7 @@ def ep_combine( ): """Combine expert outputs back to the source rank, with autograd. 
Caller pre-applies topk weighting. Zero-copy mode requires expert_out to alias - buffer.combine_in (write expert outputs into that slot directly). + buffer.combine_symm_buf (write expert outputs into that slot directly). Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. @@ -741,7 +746,7 @@ def ep_combine( num_local_tokens = buffer.max_tokens_per_rank return _EpCombine.apply( buffer.handle_mem, - buffer.combine_in, + buffer.combine_symm_buf, num_local_tokens, buffer.hidden_dim, buffer.zero_copy, From 91edff0f670a67e324fa78dae33f7e88283a6c7a Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 21:41:33 -0700 Subject: [PATCH 12/39] EP PyTorch: warn that ep_bootstrap(zero_copy=True) is experimental Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index ec99e1250a..6fc69df92b 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -6,6 +6,7 @@ from __future__ import annotations import atexit +import warnings from typing import Optional import torch @@ -129,6 +130,12 @@ def ep_bootstrap( if ep_group.size() < 2: raise ValueError(f"ep_bootstrap requires ep_group.size() >= 2 (got {ep_group.size()}).") _check_nccl_runtime_version() + if zero_copy: + warnings.warn( + "ep_bootstrap(zero_copy=True) is experimental; the symm-mem IO path " + "and its alias contracts on EpBuffer slots are subject to change.", + stacklevel=2, + ) # Materialize the PG's NCCL comm before borrowing its raw handle. 
dist.barrier(group=ep_group, device_ids=[torch.cuda.current_device()]) From 0415e3a6d68fbb22639e92642dc8742b94e34b01 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Jun 2026 00:23:34 +0000 Subject: [PATCH 13/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ep.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 6fc69df92b..9789b587d1 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -323,11 +323,7 @@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) if inst.zero_copy: - if ( - dispatch_symm_buf is None - or dispatch_w_symm_buf is None - or combine_symm_buf is None - ): + if dispatch_symm_buf is None or dispatch_w_symm_buf is None or combine_symm_buf is None: raise ValueError( "EpBuffer.from_external: zero-copy mode requires dispatch_symm_buf, " "dispatch_w_symm_buf, and combine_symm_buf (all symm-mem-backed)." 
From bab2bdc6ffabdf1733ff11ac6d9ff11a7bbea006 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 10 Jun 2026 17:30:04 -0700 Subject: [PATCH 14/39] EP PyTorch: wire test_ep.py into L1 distributed QA suite Signed-off-by: Phuong Nguyen --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 7eb34a62e4..50a51353d1 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -50,6 +50,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_ep.xml $TE_PATH/tests/pytorch/distributed/test_ep.py || test_fail "test_ep.py" # debug tests From 669603031eee020b5cea09a67690df2ef4386565 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 10 Jun 2026 17:36:19 -0700 Subject: [PATCH 15/39] EP PyTorch: validate contiguity of dispatch/combine inputs and topk_weights fp32 dtype Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/csrc/extensions/ep.cpp | 3 +++ transformer_engine/pytorch/ep.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index 8624605edd..e340b41acf 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -183,6 +183,8 @@ void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, NVTE_CHECK(topk_weights.dim() >= 2, "topk_weights must be at least 2D [..., top_k]"); NVTE_CHECK(recv_tokens.dim() >= 2, "recv_tokens must be at least 2D [..., recv_pr, H]"); check_topk_idx_int64(topk_idx); + NVTE_CHECK(tokens.is_contiguous(), "tokens must be contiguous"); + NVTE_CHECK(topk_weights.is_contiguous(), "topk_weights must be contiguous"); const size_t H = static_cast(tokens.size(-1)); const size_t T_flat = tokens.numel() / H; @@ -225,6 +227,7 @@ void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) auto stream = at::cuda::getCurrentCUDAStream().stream(); NVTE_CHECK(expert_out.dim() >= 2, "expert_out must be at least 2D [..., recv_pr, H]"); NVTE_CHECK(result.dim() >= 2, "result must be at least 2D [..., H]"); + NVTE_CHECK(expert_out.is_contiguous(), "expert_out must be contiguous"); const size_t H = static_cast(expert_out.size(-1)); const size_t recv_pr = expert_out.numel() / H; diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 9789b587d1..89951bc67c 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -704,6 +704,11 @@ def ep_dispatch( slots; otherwise they are freshly allocated. token_counts is non-diff. """ _require_bf16("tokens", tokens) + if topk_weights.dtype is not torch.float32: + raise TypeError( + f"topk_weights must be float32; got dtype={topk_weights.dtype}. " + "Cast with topk_weights.float() before calling." 
+ ) if buffer.zero_copy: recv_tokens = buffer.dispatch_symm_buf recv_topk_weights = buffer.dispatch_w_symm_buf From 090d05b6b70c423c1597f3676dc41b60f3110354 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 11 Jun 2026 09:15:45 -0700 Subject: [PATCH 16/39] EP PyTorch: check topk_idx/topk_weights token count matches tokens in ep_dispatch Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/csrc/extensions/ep.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index e340b41acf..a7c15a1140 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -193,6 +193,10 @@ void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, NVTE_CHECK(static_cast(topk_weights.size(-1)) == topk_n, "topk_weights last dim must equal topk_idx last dim"); + NVTE_CHECK(static_cast(topk_idx.numel()) == T_flat * topk_n, + "topk_idx token count must equal tokens token count"); + NVTE_CHECK(static_cast(topk_weights.numel()) == T_flat * topk_n, + "topk_weights token count must equal tokens token count"); NVTE_CHECK(static_cast(recv_topk_weights.numel()) == recv_pr, "recv_topk_weights total size must equal recv_tokens recv_pr"); NVTE_CHECK(recv_tokens.scalar_type() == tokens.scalar_type(), "recv_tokens dtype (", From 926bef2ac178f3bd1b59edc694efb2a9a8012594 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 11 Jun 2026 23:55:09 -0700 Subject: [PATCH 17/39] EP PyTorch: move symm-mem allocation out of EpBuffer; make ep_dispatch/ep_combine accept caller-supplied output buffers with C++ symm-mem checks under zero-copy Signed-off-by: Phuong Nguyen --- tests/pytorch/distributed/run_ep.py | 42 ++- .../pytorch/csrc/extensions/ep.cpp | 30 +++ transformer_engine/pytorch/ep.py | 243 ++++-------------- 3 files changed, 120 insertions(+), 195 deletions(-) diff --git a/tests/pytorch/distributed/run_ep.py 
b/tests/pytorch/distributed/run_ep.py index 0e29ca7eae..b09071a57c 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -173,8 +173,11 @@ def test_dispatch_fwd_bwd(self): buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - recv_t, _recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) - loss = 0.5 * (recv_t.float() ** 2).sum() + recv_t, recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) + # Pull recv_w into the loss with a zero scale so both dispatch outputs + # contribute a (possibly-zero) grad — backward respects user-supplied + # grad inputs and won't fabricate Nones into zeros. + loss = 0.5 * (recv_t.float() ** 2).sum() + 0.0 * recv_w.float().sum() loss.backward() torch.cuda.synchronize() torch.testing.assert_close( @@ -293,6 +296,41 @@ def bwd(k): rtol=5e-2, ) + # Caller-supplied output buffers (autograd) + + def test_dispatch_caller_recv_buffers_autograd(self): + """ep_dispatch with caller-supplied recv buffers; fwd+bwd matches default-alloc grads.""" + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_raw_recv() + tokens_p = tokens.detach().clone().requires_grad_(True) + rt, rw, _tc = ep_dispatch( + buf, tokens_p, topk_idx, w, recv_tokens=recv_tokens, recv_topk_weights=recv_w + ) + self.assertEqual(rt.data_ptr(), recv_tokens.data_ptr()) + self.assertEqual(rw.data_ptr(), recv_w.data_ptr()) + (0.5 * (rt.float() ** 2).sum() + 0.0 * rw.float().sum()).backward() + torch.cuda.synchronize() + torch.testing.assert_close( + tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 + ) + + def test_combine_grad_expert_out_autograd(self): + """ep_combine with caller-supplied grad_expert_out; bwd writes into that slot.""" + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, 
self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w) + eo = self._weighted(recv_t, recv_w) + grad_eo = torch.empty_like(eo) + gp = grad_eo.data_ptr() + out = ep_combine(buf, eo, grad_expert_out=grad_eo) + (0.5 * (out.float() ** 2).sum()).backward() + torch.cuda.synchronize() + self.assertEqual(grad_eo.data_ptr(), gp) + torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) + torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + # Input validation def test_topk_int32_raises_clear_error(self): diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index a7c15a1140..9953bc3993 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -81,6 +81,29 @@ NVTECommWindow maybe_make_window(const at::Tensor& t) { #endif } +// When zero-copy is enabled, the named tensor must be symm-mem-backed on the +// EP group. Throws a clear error otherwise. No-op when zero-copy is off or +// symm-mem support isn't compiled in. Mirrors maybe_make_window's resolution +// path but turns the "not symm-mem" outcome into a hard error. +void check_symm_mem_required(const at::Tensor& t, const char* name) { +#ifdef NCCL_HAS_SYMMEM_SUPPORT + if (!g_zero_copy_enabled.load(std::memory_order_relaxed)) return; + NVTE_CHECK(!g_ep_group_name.empty(), + "Zero-copy is enabled but EP group name is unset; call ep_initialize first."); + c10::intrusive_ptr sm; + try { + sm = c10d::symmetric_memory::rendezvous(t, g_ep_group_name); + } catch (const std::exception&) { + sm = nullptr; + } + NVTE_CHECK(sm != nullptr, "ep zero-copy: ", name, + " must be symm-mem-backed on the EP group (allocate via symm_mem_alloc)."); +#else + (void)t; + (void)name; +#endif +} + // The backend only accepts int64 topk_idx. 
The PyTorch wrapper enforces this // at the boundary so the per-step ops don't need an upcast workspace. void check_topk_idx_int64(at::Tensor topk_idx) { @@ -202,6 +225,8 @@ void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, NVTE_CHECK(recv_tokens.scalar_type() == tokens.scalar_type(), "recv_tokens dtype (", c10::toString(recv_tokens.scalar_type()), ") must match tokens dtype (", c10::toString(tokens.scalar_type()), ")"); + check_symm_mem_required(recv_tokens, "recv_tokens"); + check_symm_mem_required(recv_topk_weights, "recv_topk_weights"); auto tok_dtype = GetTransformerEngineDType(tokens.scalar_type()); auto handle_mem_te = makeTransformerEngineTensor( @@ -241,6 +266,7 @@ void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) NVTE_CHECK(result.scalar_type() == expert_out.scalar_type(), "result dtype (", c10::toString(result.scalar_type()), ") must match expert_out dtype (", c10::toString(expert_out.scalar_type()), ")"); + check_symm_mem_required(expert_out, "expert_out"); auto eo_dtype = GetTransformerEngineDType(expert_out.scalar_type()); auto handle_mem_te = makeTransformerEngineTensor( @@ -275,6 +301,8 @@ void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_t NVTE_CHECK(grad_tokens.scalar_type() == grad.scalar_type(), "grad_tokens dtype (", c10::toString(grad_tokens.scalar_type()), ") must match grad dtype (", c10::toString(grad.scalar_type()), ")"); + check_symm_mem_required(grad, "grad (dispatch_bwd input)"); + check_symm_mem_required(g_recv_topk_weights, "g_recv_topk_weights"); auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); auto handle_mem_te = makeTransformerEngineTensor( @@ -306,6 +334,8 @@ void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expe NVTE_CHECK(grad_expert_out.scalar_type() == grad.scalar_type(), "grad_expert_out dtype (", c10::toString(grad_expert_out.scalar_type()), ") must match grad dtype (", 
c10::toString(grad.scalar_type()), ")"); + check_symm_mem_required(grad, "grad (combine_bwd input)"); + check_symm_mem_required(grad_expert_out, "grad_expert_out"); auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); auto handle_mem_te = makeTransformerEngineTensor( diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 89951bc67c..537978f593 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -179,40 +179,22 @@ def ep_finalize() -> None: class EpBuffer: - """Per-microbatch EP layer state: routing handle + persistent payload slots. - - Owns the per-call ``handle_mem`` routing scratch and the payload buffers - consumed by :func:`ep_dispatch` / :func:`ep_combine`. Allocate one - EpBuffer per concurrently-in-flight call on the layer (one per PP-1F1B - microbatch); sharing across overlapping calls clobbers tensors the - earlier bwd still reads. Call ``record_stream`` from streams other than - the allocation stream. - - Cross-rank payload slots are symm-mem-backed when ``ep_bootstrap`` was - called with ``zero_copy=True`` (requires ``ep_group``); otherwise plain - HBM. + """Per-microbatch EP layer state holding handle_mem and token_counts. + Cross-rank payload buffers are caller-supplied to ep_dispatch and + ep_combine; allocate via symm_mem_alloc in zero-copy mode. + Use one EpBuffer per concurrently-in-flight call (e.g. per PP-1F1B microbatch). """ __slots__ = ( - # routing "handle_mem", "top_k", "alignment", - # layer config "max_tokens_per_rank", "recv_capacity_per_rank", "hidden_dim", "num_local_experts", "payload_dtype", "device", - # Symm-mem slots (zero-copy only). 
Each is reused across fwd and bwd: - # dispatch_symm_buf: fwd out (recv_tokens) / bwd in (g_recv_tokens) - # dispatch_w_symm_buf: fwd out (recv_topk_w) / bwd in (g_recv_topk_w) - # combine_symm_buf: fwd in (expert_out) / bwd out (g_expert_out) - "dispatch_symm_buf", - "dispatch_w_symm_buf", - "combine_symm_buf", - # Per-rank scratch (always HBM). "token_counts", "zero_copy", ) @@ -225,7 +207,6 @@ def __init__( hidden_dim: int, num_local_experts: int, alignment: int = 0, - ep_group: Optional[dist.ProcessGroup] = None, payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, ) -> None: @@ -242,30 +223,10 @@ def __init__( self.num_local_experts = int(num_local_experts) self.payload_dtype = payload_dtype self.device = device + self.zero_copy = bool(tex.ep_get_zero_copy()) size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment) self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) - - recv_shape = (self.recv_capacity_per_rank, self.hidden_dim) - zero_copy = bool(tex.ep_get_zero_copy()) - self.zero_copy = zero_copy - if zero_copy: - if ep_group is None: - raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") - self.dispatch_symm_buf = symm_mem_alloc( - recv_shape, payload_dtype, ep_group, device=device - ) - self.dispatch_w_symm_buf = symm_mem_alloc( - (self.recv_capacity_per_rank,), torch.float32, ep_group, device=device - ) - self.combine_symm_buf = symm_mem_alloc( - recv_shape, payload_dtype, ep_group, device=device - ) - else: - self.dispatch_symm_buf = None - self.dispatch_w_symm_buf = None - self.combine_symm_buf = None - # token_counts is local-only routing scratch; always plain HBM. 
self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) @classmethod @@ -277,37 +238,19 @@ def from_external( hidden_dim: int, num_local_experts: int, *, - dispatch_symm_buf: Optional[torch.Tensor] = None, - dispatch_w_symm_buf: Optional[torch.Tensor] = None, - combine_symm_buf: Optional[torch.Tensor] = None, token_counts: Optional[torch.Tensor] = None, alignment: int = 0, payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, ) -> "EpBuffer": - """Construct from caller-allocated buffers. - - In zero-copy mode dispatch_symm_buf, dispatch_w_symm_buf, and - combine_symm_buf must all be supplied and symm-mem-backed; in - non-zero-copy mode they must all be None (ops allocate per call). - handle_mem is always allocated fresh. - """ + """Construct from a caller-allocated token_counts; handle_mem is always fresh.""" if device is None: device = torch.device("cuda", torch.cuda.current_device()) alignment = int(alignment) if alignment > 1 and (alignment & (alignment - 1)) != 0: raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") - recv_shape = (recv_capacity_per_rank, hidden_dim) - recv_w_shape = (recv_capacity_per_rank,) counts_shape = (num_local_experts,) - def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torch.Tensor: - if tuple(t.shape) != shape: - raise ValueError(f"{name} shape {tuple(t.shape)} != expected {shape}") - if t.dtype != dtype: - raise ValueError(f"{name} dtype {t.dtype} != expected {dtype}") - return t - inst = cls.__new__(cls) inst.top_k = int(top_k) inst.alignment = alignment @@ -322,53 +265,22 @@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) - if inst.zero_copy: - if dispatch_symm_buf is None or dispatch_w_symm_buf is None or combine_symm_buf is None: + if 
token_counts is not None: + if tuple(token_counts.shape) != counts_shape: raise ValueError( - "EpBuffer.from_external: zero-copy mode requires dispatch_symm_buf, " - "dispatch_w_symm_buf, and combine_symm_buf (all symm-mem-backed)." + f"token_counts shape {tuple(token_counts.shape)} != expected {counts_shape}" ) - inst.dispatch_symm_buf = _check( - dispatch_symm_buf, "dispatch_symm_buf", recv_shape, payload_dtype - ) - inst.dispatch_w_symm_buf = _check( - dispatch_w_symm_buf, "dispatch_w_symm_buf", recv_w_shape, torch.float32 - ) - inst.combine_symm_buf = _check( - combine_symm_buf, "combine_symm_buf", recv_shape, payload_dtype - ) + if token_counts.dtype != torch.int32: + raise ValueError(f"token_counts dtype {token_counts.dtype} != expected int32") + inst.token_counts = token_counts else: - if ( - dispatch_symm_buf is not None - or dispatch_w_symm_buf is not None - or combine_symm_buf is not None - ): - raise ValueError( - "EpBuffer.from_external: dispatch_symm_buf / dispatch_w_symm_buf / " - "combine_symm_buf are only used in zero-copy mode." 
- ) - inst.dispatch_symm_buf = None - inst.dispatch_w_symm_buf = None - inst.combine_symm_buf = None - inst.token_counts = ( - _check(token_counts, "token_counts", counts_shape, torch.int32) - if token_counts is not None - else torch.empty(counts_shape, dtype=torch.int32, device=device) - ) + inst.token_counts = torch.empty(counts_shape, dtype=torch.int32, device=device) return inst def record_stream(self, stream: torch.cuda.Stream) -> None: - """Record stream as a user of all owned tensors so the caching allocator - defers reclaim until stream has caught up.""" - for t in ( - self.handle_mem, - self.dispatch_symm_buf, - self.dispatch_w_symm_buf, - self.combine_symm_buf, - self.token_counts, - ): - if t is not None: - t.record_stream(stream) + """Defer caching-allocator reclaim of owned tensors until stream catches up.""" + self.handle_mem.record_stream(stream) + self.token_counts.record_stream(stream) # torch.library custom ops (so they don't graph-break under torch.compile) @@ -510,11 +422,7 @@ def _ep_combine_raw(buffer: "EpBuffer", expert_out: torch.Tensor, result: torch. class _EpDispatch(torch.autograd.Function): - """Autograd-aware prepare + dispatch. Fwd produces recv_tokens (alias of - dispatch_symm_buf in zero-copy, fresh otherwise). Zero-copy bwd requires - the incoming grads to alias dispatch_symm_buf / dispatch_w_symm_buf - (no implicit staging). Fwd/bwd share handle_mem; do not re-run ep_prepare. 
- """ + """Autograd prepare+dispatch; bwd uses user-supplied grad inputs as-is.""" @staticmethod def forward( # type: ignore[override] @@ -522,7 +430,6 @@ def forward( # type: ignore[override] handle_mem: torch.Tensor, top_k: int, alignment: int, - zero_copy: bool, recv_tokens: torch.Tensor, recv_topk_weights: torch.Tensor, token_counts: torch.Tensor, @@ -530,7 +437,7 @@ def forward( # type: ignore[override] tokens: torch.Tensor, topk_weights: torch.Tensor, ): - """Prepare + dispatch; saves shapes for the bwd pass.""" + """Prepare + dispatch fwd.""" torch.ops.transformer_engine_ep.prepare( handle_mem, top_k, topk_idx, token_counts, alignment ) @@ -543,12 +450,6 @@ def forward( # type: ignore[override] recv_topk_weights, ) ctx.handle_mem = handle_mem - ctx.zero_copy = zero_copy - # Stash the symm-mem slot pointers so bwd can enforce alias of the - # grad inputs. In non-zero-copy mode the slots are fresh per call; - # no enforcement is meaningful, so leave the pointers as None. - ctx.dispatch_symm_ptr = recv_tokens.data_ptr() if zero_copy else None - ctx.dispatch_w_symm_ptr = recv_topk_weights.data_ptr() if zero_copy else None ctx.tokens_shape = tokens.shape ctx.tokens_dtype = tokens.dtype ctx.topk_weights_shape = topk_weights.shape @@ -564,30 +465,8 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] - """Dispatch bwd; in zero-copy the grad inputs must alias the symm-mem slots.""" + """Dispatch bwd; uses user-supplied grad inputs as-is.""" device = ctx.handle_mem.device - if g_recv_tokens is None: - g_recv_tokens = torch.zeros( - ctx.recv_capacity, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device - ) - if g_recv_topk_weights is None: - g_recv_topk_weights = torch.zeros(ctx.recv_capacity, dtype=torch.float32, device=device) - if not g_recv_tokens.is_contiguous(): - g_recv_tokens = g_recv_tokens.contiguous() - if not g_recv_topk_weights.is_contiguous(): - 
g_recv_topk_weights = g_recv_topk_weights.contiguous() - if ctx.zero_copy: - if g_recv_tokens.data_ptr() != ctx.dispatch_symm_ptr: - raise RuntimeError( - "ep_dispatch bwd: zero-copy mode requires g_recv_tokens to alias " - "buffer.dispatch_symm_buf (write MLP_bwd's grad into that slot; " - "no implicit copy)." - ) - if g_recv_topk_weights.data_ptr() != ctx.dispatch_w_symm_ptr: - raise RuntimeError( - "ep_dispatch bwd: zero-copy mode requires g_recv_topk_weights to alias " - "buffer.dispatch_w_symm_buf (no implicit copy)." - ) grad_tokens = torch.empty( ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device ) @@ -605,7 +484,6 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: None, # handle_mem None, # top_k None, # alignment - None, # zero_copy None, # recv_tokens None, # recv_topk_weights None, # token_counts @@ -616,66 +494,36 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: class _EpCombine(torch.autograd.Function): - """Autograd-aware combine. Zero-copy mode requires expert_out to alias - combine_symm_buf (no implicit staging), and that storage is reused as the - bwd grad slot. Non-zero-copy mode reads expert_out directly and allocates - the bwd grad slot fresh. Caller pre-applies topk weighting. - """ + """Autograd combine; bwd writes into grad_expert_out, or expert_out's storage if None.""" @staticmethod def forward( # type: ignore[override] ctx, handle_mem: torch.Tensor, - combine_symm_buf: Optional[torch.Tensor], num_local_tokens: int, hidden_dim: int, - zero_copy: bool, + grad_expert_out: Optional[torch.Tensor], expert_out: torch.Tensor, ): - """Combine fwd; zero-copy requires expert_out to alias combine_symm_buf.""" - if zero_copy: - if combine_symm_buf is None: - raise RuntimeError( - "ep_combine: zero-copy mode requires buffer.combine_symm_buf to be allocated." 
- ) - if combine_symm_buf.data_ptr() != expert_out.data_ptr(): - raise RuntimeError( - "ep_combine: zero-copy mode requires expert_out to alias " - "buffer.combine_symm_buf (write expert outputs directly into that slot; " - "no implicit copy)." - ) + """Combine fwd; stashes grad_expert_out (or expert_out) as the bwd output slot.""" device = expert_out.device result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result) ctx.handle_mem = handle_mem - ctx.combine_symm_buf = combine_symm_buf # reused as grad slot in zero-copy - ctx.zero_copy = zero_copy - ctx.recv_capacity = expert_out.shape[0] - ctx.hidden_dim = expert_out.shape[-1] - ctx.expert_out_dtype = expert_out.dtype + ctx.grad_expert_out = grad_expert_out if grad_expert_out is not None else expert_out return result @staticmethod def backward(ctx, g_result): # type: ignore[override] - """Combine bwd; writes into combine_symm_buf in zero-copy or a fresh slot otherwise.""" if not g_result.is_contiguous(): g_result = g_result.contiguous() - if ctx.zero_copy: - grad_combine_in = ctx.combine_symm_buf - else: - grad_combine_in = torch.empty( - ctx.recv_capacity, - ctx.hidden_dim, - dtype=ctx.expert_out_dtype, - device=ctx.handle_mem.device, - ) + grad_combine_in = ctx.grad_expert_out torch.ops.transformer_engine_ep.combine_bwd(ctx.handle_mem, g_result, grad_combine_in) return ( None, # handle_mem - None, # combine_symm_buf None, # num_local_tokens None, # hidden_dim - None, # zero_copy + None, # grad_expert_out grad_combine_in, ) @@ -696,12 +544,15 @@ def ep_dispatch( tokens: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, + *, + recv_tokens: Optional[torch.Tensor] = None, + recv_topk_weights: Optional[torch.Tensor] = None, ): - """Run prepare + dispatch with autograd. topk_idx must be int64. + """Prepare + dispatch with autograd. topk_idx must be int64. 
- Returns (recv_tokens, recv_topk_weights, token_counts). In zero-copy mode - recv_tokens / recv_topk_weights alias the buffer's persistent symm-mem - slots; otherwise they are freshly allocated. token_counts is non-diff. + recv_tokens / recv_topk_weights are used as-is if supplied, else allocated. + Zero-copy mode requires both to be supplied and symm-mem-backed. + Returns (recv_tokens, recv_topk_weights, token_counts); token_counts is non-diff. """ _require_bf16("tokens", tokens) if topk_weights.dtype is not torch.float32: @@ -709,16 +560,19 @@ def ep_dispatch( f"topk_weights must be float32; got dtype={topk_weights.dtype}. " "Cast with topk_weights.float() before calling." ) - if buffer.zero_copy: - recv_tokens = buffer.dispatch_symm_buf - recv_topk_weights = buffer.dispatch_w_symm_buf - else: + if buffer.zero_copy and (recv_tokens is None or recv_topk_weights is None): + raise ValueError( + "ep_dispatch: zero-copy mode requires caller-supplied recv_tokens and " + "recv_topk_weights (allocate via symm_mem_alloc)." + ) + if recv_tokens is None: recv_tokens = torch.empty( buffer.recv_capacity_per_rank, buffer.hidden_dim, dtype=buffer.payload_dtype, device=buffer.device, ) + if recv_topk_weights is None: recv_topk_weights = torch.empty( buffer.recv_capacity_per_rank, dtype=torch.float32, device=buffer.device ) @@ -726,7 +580,6 @@ def ep_dispatch( buffer.handle_mem, buffer.top_k, buffer.alignment, - buffer.zero_copy, recv_tokens, recv_topk_weights, buffer.token_counts, @@ -741,22 +594,26 @@ def ep_combine( expert_out: torch.Tensor, *, num_local_tokens: Optional[int] = None, + grad_expert_out: Optional[torch.Tensor] = None, ): - """Combine expert outputs back to the source rank, with autograd. Caller - pre-applies topk weighting. Zero-copy mode requires expert_out to alias - buffer.combine_symm_buf (write expert outputs into that slot directly). + """Combine with autograd; caller pre-applies topk weighting. 
- Result shape is (num_local_tokens, buffer.hidden_dim); defaults to - buffer.max_tokens_per_rank rows. + grad_expert_out is the slot the bwd writes into; if None, expert_out's storage is reused. + Zero-copy mode requires both expert_out and grad_expert_out to be symm-mem-backed. + Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. """ _require_bf16("expert_out", expert_out) + if buffer.zero_copy and grad_expert_out is None: + raise ValueError( + "ep_combine: zero-copy mode requires caller-supplied grad_expert_out " + "(allocate via symm_mem_alloc)." + ) if num_local_tokens is None: num_local_tokens = buffer.max_tokens_per_rank return _EpCombine.apply( buffer.handle_mem, - buffer.combine_symm_buf, num_local_tokens, buffer.hidden_dim, - buffer.zero_copy, + grad_expert_out, expert_out, ) From 2ca14a739b153d5147d2d99bd9a36e7301cc484f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 12 Jun 2026 00:00:04 -0700 Subject: [PATCH 18/39] EP PyTorch: enforce contiguous caller-supplied EP buffers in C++ and normalize bwd grad layout in Python Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/csrc/extensions/ep.cpp | 6 ++++++ transformer_engine/pytorch/ep.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index 9953bc3993..67e18cd70a 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -208,6 +208,8 @@ void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, check_topk_idx_int64(topk_idx); NVTE_CHECK(tokens.is_contiguous(), "tokens must be contiguous"); NVTE_CHECK(topk_weights.is_contiguous(), "topk_weights must be contiguous"); + NVTE_CHECK(recv_tokens.is_contiguous(), "recv_tokens must be contiguous"); + NVTE_CHECK(recv_topk_weights.is_contiguous(), "recv_topk_weights must be contiguous"); const 
size_t H = static_cast(tokens.size(-1)); const size_t T_flat = tokens.numel() / H; @@ -286,6 +288,8 @@ void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_t NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., recv_pr, H]"); NVTE_CHECK(grad_tokens.dim() >= 2, "grad_tokens must be at least 2D [..., H]"); NVTE_CHECK(grad_topk_weights.dim() >= 2, "grad_topk_weights must be at least 2D [..., top_k]"); + NVTE_CHECK(grad.is_contiguous(), "grad must be contiguous"); + NVTE_CHECK(g_recv_topk_weights.is_contiguous(), "g_recv_topk_weights must be contiguous"); const size_t H = static_cast(grad.size(-1)); const size_t recv_pr = grad.numel() / H; @@ -325,6 +329,8 @@ void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expe auto stream = at::cuda::getCurrentCUDAStream().stream(); NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., H]"); NVTE_CHECK(grad_expert_out.dim() >= 2, "grad_expert_out must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(grad.is_contiguous(), "grad must be contiguous"); + NVTE_CHECK(grad_expert_out.is_contiguous(), "grad_expert_out must be contiguous"); const size_t H = static_cast(grad.size(-1)); const size_t T_flat = grad.numel() / H; diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 537978f593..0caf3b8c8b 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -465,8 +465,10 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] - """Dispatch bwd; uses user-supplied grad inputs as-is.""" + """Dispatch bwd; normalizes grad-input layout, otherwise passes through.""" device = ctx.handle_mem.device + g_recv_tokens = g_recv_tokens.contiguous() + g_recv_topk_weights = g_recv_topk_weights.contiguous() grad_tokens = torch.empty( ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device ) From 
8cb815fc2aec39ff1dc05da55f16d06db2642fea Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 23 Jun 2026 06:56:51 -0700 Subject: [PATCH 19/39] EP PyTorch: store autograd ctx tensors via save_for_backward, mark EP workspace buffers non-offloadable, and let backward stage autograd-allocated upstream grads instead of requiring symm-mem Signed-off-by: Phuong Nguyen --- .../pytorch/csrc/extensions/ep.cpp | 6 ++--- transformer_engine/pytorch/ep.py | 23 +++++++++++++------ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index 67e18cd70a..c349bcb431 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -305,8 +305,7 @@ void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_t NVTE_CHECK(grad_tokens.scalar_type() == grad.scalar_type(), "grad_tokens dtype (", c10::toString(grad_tokens.scalar_type()), ") must match grad dtype (", c10::toString(grad.scalar_type()), ")"); - check_symm_mem_required(grad, "grad (dispatch_bwd input)"); - check_symm_mem_required(g_recv_topk_weights, "g_recv_topk_weights"); + // Upstream grads are autograd-allocated, so they take the staged-copy path. auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); auto handle_mem_te = makeTransformerEngineTensor( @@ -340,7 +339,8 @@ void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expe NVTE_CHECK(grad_expert_out.scalar_type() == grad.scalar_type(), "grad_expert_out dtype (", c10::toString(grad_expert_out.scalar_type()), ") must match grad dtype (", c10::toString(grad.scalar_type()), ")"); - check_symm_mem_required(grad, "grad (combine_bwd input)"); + // grad is autograd-allocated (staged-copy path); grad_expert_out is the + // caller-supplied scatter target and must be symm-mem in zero-copy mode. 
check_symm_mem_required(grad_expert_out, "grad_expert_out"); auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 0caf3b8c8b..ef0e5937d9 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -14,6 +14,8 @@ import transformer_engine_torch as tex +from .cpu_offload import mark_not_offload + __all__ = [ "EpBuffer", @@ -228,6 +230,8 @@ def __init__( size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment) self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) + # Persistent workspace; keep resident if activation CPU offloading is on. + mark_not_offload(self.handle_mem) @classmethod def from_external( @@ -264,6 +268,8 @@ def from_external( size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) + # Persistent workspace; keep resident if activation CPU offloading is on. 
+ mark_not_offload(inst.handle_mem) if token_counts is not None: if tuple(token_counts.shape) != counts_shape: @@ -449,7 +455,7 @@ def forward( # type: ignore[override] recv_tokens, recv_topk_weights, ) - ctx.handle_mem = handle_mem + ctx.save_for_backward(handle_mem) ctx.tokens_shape = tokens.shape ctx.tokens_dtype = tokens.dtype ctx.topk_weights_shape = topk_weights.shape @@ -466,7 +472,8 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] """Dispatch bwd; normalizes grad-input layout, otherwise passes through.""" - device = ctx.handle_mem.device + (handle_mem,) = ctx.saved_tensors + device = handle_mem.device g_recv_tokens = g_recv_tokens.contiguous() g_recv_topk_weights = g_recv_topk_weights.contiguous() grad_tokens = torch.empty( @@ -476,7 +483,7 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ctx.topk_T_flat, ctx.top_k, dtype=torch.float32, device=device ) torch.ops.transformer_engine_ep.dispatch_bwd( - ctx.handle_mem, + handle_mem, g_recv_tokens, g_recv_topk_weights, grad_tokens, @@ -511,16 +518,18 @@ def forward( # type: ignore[override] device = expert_out.device result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result) - ctx.handle_mem = handle_mem - ctx.grad_expert_out = grad_expert_out if grad_expert_out is not None else expert_out + grad_combine_in = grad_expert_out if grad_expert_out is not None else expert_out + # bwd write target reused across microbatches; keep resident under CPU offloading. 
+ mark_not_offload(grad_combine_in) + ctx.save_for_backward(handle_mem, grad_combine_in) return result @staticmethod def backward(ctx, g_result): # type: ignore[override] if not g_result.is_contiguous(): g_result = g_result.contiguous() - grad_combine_in = ctx.grad_expert_out - torch.ops.transformer_engine_ep.combine_bwd(ctx.handle_mem, g_result, grad_combine_in) + handle_mem, grad_combine_in = ctx.saved_tensors + torch.ops.transformer_engine_ep.combine_bwd(handle_mem, g_result, grad_combine_in) return ( None, # handle_mem None, # num_local_tokens From 53a3834e2c8a6b7624991ac84111c77038b1746b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 23 Jun 2026 09:31:30 -0700 Subject: [PATCH 20/39] EP PyTorch: reword EP buffer comments and drop redundant recv_topk_weights loss term in dispatch autograd test Signed-off-by: Phuong Nguyen --- tests/pytorch/distributed/run_ep.py | 7 ++----- transformer_engine/pytorch/ep.py | 4 ++-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index b09071a57c..7132f42815 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -173,11 +173,8 @@ def test_dispatch_fwd_bwd(self): buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - recv_t, recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) - # Pull recv_w into the loss with a zero scale so both dispatch outputs - # contribute a (possibly-zero) grad — backward respects user-supplied - # grad inputs and won't fabricate Nones into zeros. 
- loss = 0.5 * (recv_t.float() ** 2).sum() + 0.0 * recv_w.float().sum() + recv_t, _recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) + loss = 0.5 * (recv_t.float() ** 2).sum() loss.backward() torch.cuda.synchronize() torch.testing.assert_close( diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index ef0e5937d9..7d9d3b5206 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -230,7 +230,7 @@ def __init__( size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment) self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) - # Persistent workspace; keep resident if activation CPU offloading is on. + # Persistent tensor; keep resident if activation CPU offloading is on. mark_not_offload(self.handle_mem) @classmethod @@ -268,7 +268,7 @@ def from_external( size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) - # Persistent workspace; keep resident if activation CPU offloading is on. + # Persistent tensor; keep resident if activation CPU offloading is on. mark_not_offload(inst.handle_mem) if token_counts is not None: From a4cb0cb12ca725700ba0882cee1cba047f1b1c06 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 23 Jun 2026 09:48:21 -0700 Subject: [PATCH 21/39] EP PyTorch: rename build env var NVTE_BUILD_WITH_NCCL_EP to NVTE_WITH_NCCL_EP Signed-off-by: Phuong Nguyen --- build_tools/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index ca54e72434..5ed4eae9d5 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -80,7 +80,7 @@ def setup_pytorch_extension( # Mirror the NCCL EP gate from setup.py / common CMake. 
When disabled, the # ep.cpp source no-ops at the #ifdef boundary; without the define it would # produce undefined references to nvte_ep_*. - if bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))): + if bool(int(os.getenv("NVTE_WITH_NCCL_EP", "1"))): cxx_flags.append("-DNVTE_WITH_NCCL_EP") # PyTorch's symm-mem headers gate the NCCL_HAS_SYMMEM_* feature macros on # USE_NCCL. The EP extension shares the symm-mem NCCL comm with torch, so From 8c3a2222fd7c65ac76eef9f83820659b587099e1 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 03:55:36 -0700 Subject: [PATCH 22/39] EP PyTorch: source ep_combine backward grad from EpBuffer symm-mem under zero-copy, allocate in-flight otherwise Signed-off-by: Phuong Nguyen --- .../pytorch/csrc/extensions/ep.cpp | 4 +- transformer_engine/pytorch/ep.py | 91 ++++++++++++++----- 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index c349bcb431..e4b987a44c 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -340,7 +340,7 @@ void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expe c10::toString(grad_expert_out.scalar_type()), ") must match grad dtype (", c10::toString(grad.scalar_type()), ")"); // grad is autograd-allocated (staged-copy path); grad_expert_out is the - // caller-supplied scatter target and must be symm-mem in zero-copy mode. + // EpBuffer-owned scatter target and must be symm-mem in zero-copy mode. 
check_symm_mem_required(grad_expert_out, "grad_expert_out"); auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); @@ -350,6 +350,8 @@ void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expe auto grad_eo_te = makeTransformerEngineTensor(grad_expert_out.data_ptr(), Shape{recv_pr, H}, g_dtype); + // grad is autograd-allocated (staged); grad_expert_out resolves to a symm-mem + // window in zero-copy mode, else kNoWindow for the staged path. NVTECommWindow grad_win = maybe_make_window(grad); NVTECommWindow grad_eo_win = maybe_make_window(grad_expert_out); nvte_ep_combine_bwd(handle_mem_te.data(), grad_te.data(), grad_win, grad_eo_te.data(), diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 7d9d3b5206..1b540bad24 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -91,11 +91,14 @@ def _check_nccl_runtime_version() -> None: _BOOTSTRAPPED = False _ATEXIT_REGISTERED = False +# EP group captured at bootstrap; EpBuffer uses it to allocate the symm-mem +# combine grad buffer in zero-copy mode. +_EP_GROUP: Optional[dist.ProcessGroup] = None def _atexit_finalize() -> None: """Best-effort teardown at interpreter shutdown; swallows errors.""" - global _BOOTSTRAPPED + global _BOOTSTRAPPED, _EP_GROUP if _BOOTSTRAPPED: try: tex.ep_finalize() @@ -105,6 +108,7 @@ def _atexit_finalize() -> None: traceback.print_exc() finally: _BOOTSTRAPPED = False + _EP_GROUP = None def ep_bootstrap( @@ -126,7 +130,7 @@ def ep_bootstrap( ``True`` only when payload tensors are allocated via ``symm_mem_alloc``. Defaults to ``False``. 
""" - global _BOOTSTRAPPED, _ATEXIT_REGISTERED + global _BOOTSTRAPPED, _ATEXIT_REGISTERED, _EP_GROUP if _BOOTSTRAPPED: raise RuntimeError("ep_bootstrap was already called in this process") if ep_group.size() < 2: @@ -155,6 +159,7 @@ def ep_bootstrap( bool(zero_copy), ) _BOOTSTRAPPED = True + _EP_GROUP = ep_group if not _ATEXIT_REGISTERED: atexit.register(_atexit_finalize) _ATEXIT_REGISTERED = True @@ -168,13 +173,14 @@ def ep_finalize() -> None: ``dist.destroy_process_group()``, since the borrowed NCCL comm becomes invalid once the PG is destroyed. """ - global _BOOTSTRAPPED + global _BOOTSTRAPPED, _EP_GROUP if not _BOOTSTRAPPED: return try: tex.ep_finalize() finally: _BOOTSTRAPPED = False + _EP_GROUP = None # Buffer @@ -185,6 +191,10 @@ class EpBuffer: Cross-rank payload buffers are caller-supplied to ep_dispatch and ep_combine; allocate via symm_mem_alloc in zero-copy mode. Use one EpBuffer per concurrently-in-flight call (e.g. per PP-1F1B microbatch). + + In zero-copy mode the combine backward scatters into a symm-mem grad buffer + owned here (one per buffer, so each layer/microbatch is isolated); the normal + mode allocates that grad in-flight in the backward. """ __slots__ = ( @@ -199,8 +209,26 @@ class EpBuffer: "device", "token_counts", "zero_copy", + "grad_combine_symm_buf", ) + def _alloc_grad_combine_symm_buf(self) -> None: + """Allocate the zero-copy combine grad buffer; None in normal mode.""" + if not self.zero_copy: + self.grad_combine_symm_buf = None + return + if _EP_GROUP is None: + raise RuntimeError("ep_bootstrap must be called before constructing a zero-copy EpBuffer") + buf = symm_mem_alloc( + (self.recv_capacity_per_rank, self.hidden_dim), + self.payload_dtype, + _EP_GROUP, + device=self.device, + ) + # Persistent across microbatches; keep resident under CPU offloading. 
+ mark_not_offload(buf) + self.grad_combine_symm_buf = buf + def __init__( self, top_k: int, @@ -232,6 +260,7 @@ def __init__( self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) # Persistent tensor; keep resident if activation CPU offloading is on. mark_not_offload(self.handle_mem) + self._alloc_grad_combine_symm_buf() @classmethod def from_external( @@ -281,12 +310,15 @@ def from_external( inst.token_counts = token_counts else: inst.token_counts = torch.empty(counts_shape, dtype=torch.int32, device=device) + inst._alloc_grad_combine_symm_buf() return inst def record_stream(self, stream: torch.cuda.Stream) -> None: """Defer caching-allocator reclaim of owned tensors until stream catches up.""" self.handle_mem.record_stream(stream) self.token_counts.record_stream(stream) + if self.grad_combine_symm_buf is not None: + self.grad_combine_symm_buf.record_stream(stream) # torch.library custom ops (so they don't graph-break under torch.compile) @@ -503,7 +535,13 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: class _EpCombine(torch.autograd.Function): - """Autograd combine; bwd writes into grad_expert_out, or expert_out's storage if None.""" + """Autograd combine. + + bwd scatters the expert_out grad into ``grad_symm_buf`` (EpBuffer-owned + symm-mem, one-sided) in zero-copy mode, or into a plain tensor allocated + in-flight here otherwise. The latter keeps allocation torch.compile / + CUDA-graph safe and lets autograd own the grad's lifetime. 
+ """ @staticmethod def forward( # type: ignore[override] @@ -511,31 +549,41 @@ def forward( # type: ignore[override] handle_mem: torch.Tensor, num_local_tokens: int, hidden_dim: int, - grad_expert_out: Optional[torch.Tensor], + grad_symm_buf: Optional[torch.Tensor], expert_out: torch.Tensor, ): - """Combine fwd; stashes grad_expert_out (or expert_out) as the bwd output slot.""" + """Combine fwd; stashes the bwd grad target or expert_out shape to size it.""" device = expert_out.device result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result) - grad_combine_in = grad_expert_out if grad_expert_out is not None else expert_out - # bwd write target reused across microbatches; keep resident under CPU offloading. - mark_not_offload(grad_combine_in) - ctx.save_for_backward(handle_mem, grad_combine_in) + if grad_symm_buf is not None: + ctx.save_for_backward(handle_mem, grad_symm_buf) + else: + ctx.save_for_backward(handle_mem) + ctx.expert_out_shape = expert_out.shape + ctx.expert_out_dtype = expert_out.dtype + ctx.device = device return result @staticmethod def backward(ctx, g_result): # type: ignore[override] if not g_result.is_contiguous(): g_result = g_result.contiguous() - handle_mem, grad_combine_in = ctx.saved_tensors - torch.ops.transformer_engine_ep.combine_bwd(handle_mem, g_result, grad_combine_in) + saved = ctx.saved_tensors + handle_mem = saved[0] + if len(saved) == 2: + grad_expert_out = saved[1] + else: + grad_expert_out = torch.empty( + ctx.expert_out_shape, dtype=ctx.expert_out_dtype, device=ctx.device + ) + torch.ops.transformer_engine_ep.combine_bwd(handle_mem, g_result, grad_expert_out) return ( None, # handle_mem None, # num_local_tokens None, # hidden_dim - None, # grad_expert_out - grad_combine_in, + None, # grad_symm_buf + grad_expert_out, ) @@ -605,26 +653,21 @@ def ep_combine( expert_out: torch.Tensor, *, num_local_tokens: Optional[int] = None, - 
grad_expert_out: Optional[torch.Tensor] = None, ): """Combine with autograd; caller pre-applies topk weighting. - grad_expert_out is the slot the bwd writes into; if None, expert_out's storage is reused. - Zero-copy mode requires both expert_out and grad_expert_out to be symm-mem-backed. - Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. + The backward scatters the expert_out grad into buffer.grad_combine_symm_buf + (symm-mem) in zero-copy mode, else into a tensor allocated in-flight. Result + shape is (num_local_tokens, buffer.hidden_dim); defaults to + buffer.max_tokens_per_rank rows. """ _require_bf16("expert_out", expert_out) - if buffer.zero_copy and grad_expert_out is None: - raise ValueError( - "ep_combine: zero-copy mode requires caller-supplied grad_expert_out " - "(allocate via symm_mem_alloc)." - ) if num_local_tokens is None: num_local_tokens = buffer.max_tokens_per_rank return _EpCombine.apply( buffer.handle_mem, num_local_tokens, buffer.hidden_dim, - grad_expert_out, + buffer.grad_combine_symm_buf, expert_out, ) From ddbfed0344d7e4a0728ab8b9f35434d5846501aa Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 03:55:36 -0700 Subject: [PATCH 23/39] EP PyTorch: add zero-copy test pass over dispatch/combine/1f1b autograd and rename dispatch autograd tests Signed-off-by: Phuong Nguyen --- tests/pytorch/distributed/run_ep.py | 119 ++++++++++++++++++++--- tests/pytorch/distributed/run_test_ep.sh | 50 ++++++---- 2 files changed, 134 insertions(+), 35 deletions(-) diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 7132f42815..1914f9fe1b 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -24,7 +24,7 @@ ) -ZERO_COPY = False +ZERO_COPY = os.environ.get("NVTE_EP_ZERO_COPY", "0") == "1" # Must come after the transformer_engine import so libtransformer_engine.so is loaded. 
import transformer_engine_torch as tex # noqa: F401 @@ -36,6 +36,44 @@ TOKENS_PER_RANK = 4 +def _zero_copy_capable(fn): + """Mark a test to also run in the zero-copy pass; others skip there.""" + fn._zero_copy_capable = True + return fn + + +class _StageToSymm(torch.autograd.Function): + """Identity op that stages ``src`` into a symm-mem buffer; grad passes through. + Lets a test feed a symm-mem-backed, autograd-tracked tensor into ep_combine. + """ + + @staticmethod + def forward(ctx, src, symm_buf): # type: ignore[override] + symm_buf.copy_(src) + return symm_buf + + @staticmethod + def backward(ctx, g): # type: ignore[override] + return g, None + + +class _GradToSymm(torch.autograd.Function): + """Identity fwd; bwd stages the upstream grad into a symm-mem buffer and + returns it, so the next backward (dispatch_bwd) receives a symm-window grad + input — which zero-copy ncclEpCombine requires. + """ + + @staticmethod + def forward(ctx, x, symm_buf): # type: ignore[override] + ctx.symm_buf = symm_buf + return x + + @staticmethod + def backward(ctx, g): # type: ignore[override] + ctx.symm_buf.copy_(g) + return ctx.symm_buf, None + + def _device_sm() -> int: major, minor = torch.cuda.get_device_capability() return major * 10 + minor @@ -110,6 +148,13 @@ def setUpClass(cls): zero_copy=ZERO_COPY, ) + def setUp(self): + # Only the zero-copy-capable tests run in the zero-copy pass. 
+ if ZERO_COPY and not getattr( + getattr(self, self._testMethodName), "_zero_copy_capable", False + ): + self.skipTest("not exercised in zero-copy mode") + def _make_buffer(self, alignment=0, top_k=TOP_K): return EpBuffer( top_k=top_k, @@ -120,6 +165,33 @@ def _make_buffer(self, alignment=0, top_k=TOP_K): alignment=alignment, ) + def _recv_kwargs(self): + """ep_dispatch recv-buffer kwargs: symm-mem-backed under zero-copy, else default-alloc.""" + if not ZERO_COPY: + return {} + rc = self.cfg.recv_capacity_per_rank + return { + "recv_tokens": symm_mem_alloc((rc, HIDDEN_DIM), torch.bfloat16, self.ep_group), + "recv_topk_weights": symm_mem_alloc((rc,), torch.float32, self.ep_group), + } + + def _expert_out(self, eo): + """Stage the combine input into symm-mem under zero-copy (combine requires it).""" + if not ZERO_COPY: + return eo + symm_buf = symm_mem_alloc(tuple(eo.shape), eo.dtype, self.ep_group) + return _StageToSymm.apply(eo, symm_buf) + + def _stage_grad_symm(self, x, symm_buf=None): + """Route x's upstream grad through a symm-mem buffer so dispatch_bwd gets + a symm-window grad input under zero-copy; passthrough otherwise. 
Pass a + pre-allocated symm_buf to avoid allocating during an interleaved schedule.""" + if not ZERO_COPY: + return x + if symm_buf is None: + symm_buf = symm_mem_alloc(tuple(x.shape), x.dtype, self.ep_group) + return _GradToSymm.apply(x, symm_buf) + def _make_raw_recv(self, dtype=torch.bfloat16): """Raw recv tensors + token_counts for the primitive tests.""" rc = self.cfg.recv_capacity_per_rank @@ -168,13 +240,16 @@ def test_primitive_dispatch_combine_identity(self): # Autograd - def test_dispatch_fwd_bwd(self): + @_zero_copy_capable + def test_dispatch_autograd(self): """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens.""" buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - recv_t, _recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) - loss = 0.5 * (recv_t.float() ** 2).sum() + recv_t, recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w, **self._recv_kwargs()) + recv_t = self._stage_grad_symm(recv_t) + recv_w = self._stage_grad_symm(recv_w) + loss = 0.5 * (recv_t.float() ** 2).sum() + 0.0 * recv_w.float().sum() loss.backward() torch.cuda.synchronize() torch.testing.assert_close( @@ -195,7 +270,7 @@ def test_combine_fwd_bwd(self): # Multi-iter stability - def test_dispatch_fwd_bwd_multiple_iterations(self): + def test_dispatch_autograd_multiple_iterations(self): """5 fwd+bwd iters on the same EpBuffer must be bit-stable.""" buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) @@ -255,6 +330,7 @@ def step(): # PP-1F1B handle isolation + @_zero_copy_capable def test_pp_1f1b_two_handles(self): """PP-1F1B interleave (F0 F1 B0 F2 B1 B2) over 3 per-microbatch buffers.""" T, H = TOKENS_PER_RANK, HIDDEN_DIM @@ -270,13 +346,26 @@ def test_pp_1f1b_two_handles(self): tokens_p.append(t.detach().clone().requires_grad_(True)) recv = [None, None, None] + # Per-microbatch recv + grad-staging buffers, all 
symm-mem under zero-copy + # and pre-allocated so nothing is allocated/freed mid-interleave. + recv_kw = [self._recv_kwargs() for _ in scales] + recv_w = [None, None, None] + rc = self.cfg.recv_capacity_per_rank + if ZERO_COPY: + gbuf_t = [symm_mem_alloc((rc, H), torch.bfloat16, self.ep_group) for _ in scales] + gbuf_w = [symm_mem_alloc((rc,), torch.float32, self.ep_group) for _ in scales] + else: + gbuf_t = gbuf_w = [None, None, None] def fwd(k): - recv[k], _, _ = ep_dispatch(buffers[k], tokens_p[k], idx, w) + rt, rw, _ = ep_dispatch(buffers[k], tokens_p[k], idx, w, **recv_kw[k]) + recv[k] = self._stage_grad_symm(rt, gbuf_t[k]) + recv_w[k] = self._stage_grad_symm(rw, gbuf_w[k]) def bwd(k): - (0.5 * (recv[k].float() ** 2).sum()).backward() + (0.5 * (recv[k].float() ** 2).sum() + 0.0 * recv_w[k].float().sum()).backward() recv[k] = None + recv_w[k] = None fwd(0) fwd(1) @@ -312,19 +401,19 @@ def test_dispatch_caller_recv_buffers_autograd(self): tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 ) - def test_combine_grad_expert_out_autograd(self): - """ep_combine with caller-supplied grad_expert_out; bwd writes into that slot.""" + @_zero_copy_capable + def test_combine_autograd(self): + """ep_combine fwd+bwd; bwd grad target is the EpBuffer symm buffer (zc) or in-flight.""" buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w) - eo = self._weighted(recv_t, recv_w) - grad_eo = torch.empty_like(eo) - gp = grad_eo.data_ptr() - out = ep_combine(buf, eo, grad_expert_out=grad_eo) + recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w, **self._recv_kwargs()) + recv_t = self._stage_grad_symm(recv_t) + recv_w = self._stage_grad_symm(recv_w) + eo = self._expert_out(self._weighted(recv_t, recv_w)) + out = ep_combine(buf, eo) (0.5 * (out.float() ** 2).sum()).backward() 
torch.cuda.synchronize() - self.assertEqual(grad_eo.data_ptr(), gp) torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) diff --git a/tests/pytorch/distributed/run_test_ep.sh b/tests/pytorch/distributed/run_test_ep.sh index 92d63cff7e..2838973152 100755 --- a/tests/pytorch/distributed/run_test_ep.sh +++ b/tests/pytorch/distributed/run_test_ep.sh @@ -29,27 +29,37 @@ export NCCL_EP_JIT_CACHE_DIR mkdir -p "$NCCL_EP_JIT_CACHE_DIR" SCRIPT="${SCRIPT_DIR}/run_ep.py" -echo "=== Running ${SCRIPT} on ${NUM_RANKS} GPUs (timeout=${TEST_TIMEOUT_S}s) ===" - -# setsid + kill-after so SIGKILL takes down the whole process group, not just torchrun. -setsid timeout --foreground --kill-after=10 --signal=TERM "${TEST_TIMEOUT_S}" \ - torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_RANKS}" \ - "${SCRIPT}" 2>&1 | tee stdout_ep.txt -RC=${PIPESTATUS[0]} -pkill -9 -f "tests/pytorch/distributed/run_ep.py" 2>/dev/null || true RET=0 -if [ "${RC}" -ne 0 ]; then - echo "torchrun exited with ${RC}" - RET=1 -fi -# Match unittest failure markers and unhandled Python tracebacks; torchrun -# prefixes per-rank stderr with "[rankN]:" so don't anchor at column 0. -if grep -qE "(^|]:)FAILED|(^|]:)Traceback" stdout_ep.txt; then RET=1; fi -if ! grep -qE "Ran [0-9]+ test|^OK$" stdout_ep.txt; then - echo "ERROR: no test summary — likely hang or early crash" - RET=1 -fi -if [ -z "${KEEP_EP_LOGS:-}" ]; then rm -f stdout_ep.txt; fi +# Run the suite once per IO mode. Modes can't be mixed in one process +# (ep_bootstrap is once-per-process), so zero-copy gets its own run; only the +# zero-copy-capable tests execute there (the rest self-skip). 
+run_pass() { + local label="$1" + local zc="$2" + local log="stdout_ep_${label}.txt" + echo "=== Running ${SCRIPT} [${label}] on ${NUM_RANKS} GPUs (timeout=${TEST_TIMEOUT_S}s) ===" + # setsid + kill-after so SIGKILL takes down the whole process group, not just torchrun. + NVTE_EP_ZERO_COPY="${zc}" setsid timeout --foreground --kill-after=10 --signal=TERM \ + "${TEST_TIMEOUT_S}" \ + torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_RANKS}" \ + "${SCRIPT}" 2>&1 | tee "${log}" + local rc=${PIPESTATUS[0]} + pkill -9 -f "tests/pytorch/distributed/run_ep.py" 2>/dev/null || true + + if [ "${rc}" -ne 0 ]; then echo "[${label}] torchrun exited with ${rc}"; RET=1; fi + # Match unittest failure markers and unhandled Python tracebacks; torchrun + # prefixes per-rank stderr with "[rankN]:" so don't anchor at column 0. + if grep -qE "(^|]:)FAILED|(^|]:)Traceback" "${log}"; then RET=1; fi + if ! grep -qE "Ran [0-9]+ test|^OK$" "${log}"; then + echo "[${label}] ERROR: no test summary — likely hang or early crash" + RET=1 + fi + if [ -z "${KEEP_EP_LOGS:-}" ]; then rm -f "${log}"; fi +} + +run_pass "default" 0 +run_pass "zero_copy" 1 + exit $RET From bcf9bb06666382c751dd739a5dd9ad429daf5fb5 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 04:07:28 -0700 Subject: [PATCH 24/39] EP PyTorch: rename _zero_copy_capable test marker to _zero_copy_test_include and combine_bwd grad_eo locals to grad_expert_out Signed-off-by: Phuong Nguyen --- tests/pytorch/distributed/run_ep.py | 12 ++++++------ transformer_engine/pytorch/csrc/extensions/ep.cpp | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 1914f9fe1b..5130f50a93 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -36,9 +36,9 @@ TOKENS_PER_RANK = 4 -def _zero_copy_capable(fn): +def _zero_copy_test_include(fn): """Mark a test to also run in the zero-copy pass; others 
skip there.""" - fn._zero_copy_capable = True + fn._zero_copy_test_include = True return fn @@ -151,7 +151,7 @@ def setUpClass(cls): def setUp(self): # Only the zero-copy-capable tests run in the zero-copy pass. if ZERO_COPY and not getattr( - getattr(self, self._testMethodName), "_zero_copy_capable", False + getattr(self, self._testMethodName), "_zero_copy_test_include", False ): self.skipTest("not exercised in zero-copy mode") @@ -240,7 +240,7 @@ def test_primitive_dispatch_combine_identity(self): # Autograd - @_zero_copy_capable + @_zero_copy_test_include def test_dispatch_autograd(self): """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens.""" buf = self._make_buffer() @@ -330,7 +330,7 @@ def step(): # PP-1F1B handle isolation - @_zero_copy_capable + @_zero_copy_test_include def test_pp_1f1b_two_handles(self): """PP-1F1B interleave (F0 F1 B0 F2 B1 B2) over 3 per-microbatch buffers.""" T, H = TOKENS_PER_RANK, HIDDEN_DIM @@ -401,7 +401,7 @@ def test_dispatch_caller_recv_buffers_autograd(self): tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 ) - @_zero_copy_capable + @_zero_copy_test_include def test_combine_autograd(self): """ep_combine fwd+bwd; bwd grad target is the EpBuffer symm buffer (zc) or in-flight.""" buf = self._make_buffer() diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index e4b987a44c..d1ef76af40 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -347,15 +347,15 @@ void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expe auto handle_mem_te = makeTransformerEngineTensor( handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); auto grad_te = makeTransformerEngineTensor(grad.data_ptr(), Shape{T_flat, H}, g_dtype); - auto grad_eo_te = + auto grad_expert_out_te = makeTransformerEngineTensor(grad_expert_out.data_ptr(), Shape{recv_pr, H}, 
g_dtype); // grad is autograd-allocated (staged); grad_expert_out resolves to a symm-mem // window in zero-copy mode, else kNoWindow for the staged path. NVTECommWindow grad_win = maybe_make_window(grad); - NVTECommWindow grad_eo_win = maybe_make_window(grad_expert_out); - nvte_ep_combine_bwd(handle_mem_te.data(), grad_te.data(), grad_win, grad_eo_te.data(), - grad_eo_win, stream); + NVTECommWindow grad_expert_out_win = maybe_make_window(grad_expert_out); + nvte_ep_combine_bwd(handle_mem_te.data(), grad_te.data(), grad_win, grad_expert_out_te.data(), + grad_expert_out_win, stream); } void register_ep_bindings(pybind11::module_& m) { From a6e8768ed73197f45e65dc9912e5029684cd0770 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 04:36:34 -0700 Subject: [PATCH 25/39] EP PyTorch: parametrize test_dispatch_autograd over recv-buffer cases and run 1f1b interleave both eager and CUDA-graph-captured Signed-off-by: Phuong Nguyen --- tests/pytorch/distributed/run_ep.py | 121 ++++++++++++++++------------ 1 file changed, 70 insertions(+), 51 deletions(-) diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 5130f50a93..2f536b27e2 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -242,31 +242,33 @@ def test_primitive_dispatch_combine_identity(self): @_zero_copy_test_include def test_dispatch_autograd(self): - """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens.""" - buf = self._make_buffer() - topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) - tokens_p = tokens.detach().clone().requires_grad_(True) - recv_t, recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w, **self._recv_kwargs()) - recv_t = self._stage_grad_symm(recv_t) - recv_w = self._stage_grad_symm(recv_w) - loss = 0.5 * (recv_t.float() ** 2).sum() + 0.0 * recv_w.float().sum() - loss.backward() - torch.cuda.synchronize() - torch.testing.assert_close( - tokens_p.grad.float(), 
tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 - ) - - def test_combine_fwd_bwd(self): - """Full dispatch + combine fwd+bwd; identity inputs round-trip.""" - buf = self._make_buffer() - topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) - tokens_p = tokens.detach().clone().requires_grad_(True) - out = self._moe_step(buf, topk_idx, tokens_p, w) - loss = 0.5 * (out.float() ** 2).sum() - loss.backward() - torch.cuda.synchronize() - torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) - torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens. Covers both + default-allocated and caller-supplied recv buffers (caller-supplied is + required and symm-mem-backed under zero-copy).""" + if ZERO_COPY: + cases = [("caller_recv_symm", self._recv_kwargs())] + else: + rt_buf, rw_buf, _ = self._make_raw_recv() + cases = [ + ("default_alloc", {}), + ("caller_recv", {"recv_tokens": rt_buf, "recv_topk_weights": rw_buf}), + ] + for label, recv_kw in cases: + with self.subTest(case=label): + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + rt, rw, _tc = ep_dispatch(buf, tokens_p, topk_idx, w, **recv_kw) + if recv_kw: # caller-supplied buffers must be used in place + self.assertEqual(rt.data_ptr(), recv_kw["recv_tokens"].data_ptr()) + self.assertEqual(rw.data_ptr(), recv_kw["recv_topk_weights"].data_ptr()) + rt = self._stage_grad_symm(rt) + rw = self._stage_grad_symm(rw) + (0.5 * (rt.float() ** 2).sum() + 0.0 * rw.float().sum()).backward() + torch.cuda.synchronize() + torch.testing.assert_close( + tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 + ) # Multi-iter stability @@ -332,7 +334,14 @@ def step(): @_zero_copy_test_include def test_pp_1f1b_two_handles(self): - """PP-1F1B interleave (F0 F1 B0 F2 
B1 B2) over 3 per-microbatch buffers.""" + """PP-1F1B interleave (F0 F1 B0 F2 B1 B2) over 3 per-microbatch buffers, + run eagerly and replayed from a CUDA graph capturing the full fwd+bwd + schedule (prepare included; routing is fixed so replay reproduces it).""" + for capture in (False, True): + with self.subTest(capture=capture): + self._run_1f1b(capture) + + def _run_1f1b(self, capture): T, H = TOKENS_PER_RANK, HIDDEN_DIM idx, _toks, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) scales = (0.13, 0.41, 0.77) @@ -367,12 +376,41 @@ def bwd(k): recv[k] = None recv_w[k] = None - fwd(0) - fwd(1) - bwd(0) - fwd(2) - bwd(1) - bwd(2) + def interleave(): + fwd(0) + fwd(1) + bwd(0) + fwd(2) + bwd(1) + bwd(2) + + def zero_grads(): + for tp in tokens_p: + if tp.grad is not None: + tp.grad.zero_() + + if not capture: + interleave() + else: + # Warmup on a side stream, then capture the full schedule and replay. + # Grads stay pre-allocated (zeroed, not None) so backward accumulates + # in place during both capture and replay. 
+ s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + zero_grads() + interleave() + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + + zero_grads() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + interleave() + zero_grads() + graph.replay() + torch.cuda.synchronize() for k in range(3): torch.testing.assert_close( @@ -382,25 +420,6 @@ def bwd(k): rtol=5e-2, ) - # Caller-supplied output buffers (autograd) - - def test_dispatch_caller_recv_buffers_autograd(self): - """ep_dispatch with caller-supplied recv buffers; fwd+bwd matches default-alloc grads.""" - buf = self._make_buffer() - topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) - recv_tokens, recv_w, _ = self._make_raw_recv() - tokens_p = tokens.detach().clone().requires_grad_(True) - rt, rw, _tc = ep_dispatch( - buf, tokens_p, topk_idx, w, recv_tokens=recv_tokens, recv_topk_weights=recv_w - ) - self.assertEqual(rt.data_ptr(), recv_tokens.data_ptr()) - self.assertEqual(rw.data_ptr(), recv_w.data_ptr()) - (0.5 * (rt.float() ** 2).sum() + 0.0 * rw.float().sum()).backward() - torch.cuda.synchronize() - torch.testing.assert_close( - tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 - ) - @_zero_copy_test_include def test_combine_autograd(self): """ep_combine fwd+bwd; bwd grad target is the EpBuffer symm buffer (zc) or in-flight.""" From 05068bccc53252945e061fcda6d68e5fdaa23af7 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 04:36:34 -0700 Subject: [PATCH 26/39] EP PyTorch: drop unused EpBuffer.from_external; token_counts is always allocated internally Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 51 -------------------------------- 1 file changed, 51 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 1b540bad24..40a3cd3841 100644 --- a/transformer_engine/pytorch/ep.py 
+++ b/transformer_engine/pytorch/ep.py @@ -262,57 +262,6 @@ def __init__( mark_not_offload(self.handle_mem) self._alloc_grad_combine_symm_buf() - @classmethod - def from_external( - cls, - top_k: int, - max_tokens_per_rank: int, - recv_capacity_per_rank: int, - hidden_dim: int, - num_local_experts: int, - *, - token_counts: Optional[torch.Tensor] = None, - alignment: int = 0, - payload_dtype: torch.dtype = torch.bfloat16, - device: Optional[torch.device] = None, - ) -> "EpBuffer": - """Construct from a caller-allocated token_counts; handle_mem is always fresh.""" - if device is None: - device = torch.device("cuda", torch.cuda.current_device()) - alignment = int(alignment) - if alignment > 1 and (alignment & (alignment - 1)) != 0: - raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") - counts_shape = (num_local_experts,) - - inst = cls.__new__(cls) - inst.top_k = int(top_k) - inst.alignment = alignment - inst.max_tokens_per_rank = int(max_tokens_per_rank) - inst.recv_capacity_per_rank = int(recv_capacity_per_rank) - inst.hidden_dim = int(hidden_dim) - inst.num_local_experts = int(num_local_experts) - inst.payload_dtype = payload_dtype - inst.device = device - inst.zero_copy = bool(tex.ep_get_zero_copy()) - - size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) - inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) - # Persistent tensor; keep resident if activation CPU offloading is on. 
- mark_not_offload(inst.handle_mem) - - if token_counts is not None: - if tuple(token_counts.shape) != counts_shape: - raise ValueError( - f"token_counts shape {tuple(token_counts.shape)} != expected {counts_shape}" - ) - if token_counts.dtype != torch.int32: - raise ValueError(f"token_counts dtype {token_counts.dtype} != expected int32") - inst.token_counts = token_counts - else: - inst.token_counts = torch.empty(counts_shape, dtype=torch.int32, device=device) - inst._alloc_grad_combine_symm_buf() - return inst - def record_stream(self, stream: torch.cuda.Stream) -> None: """Defer caching-allocator reclaim of owned tensors until stream catches up.""" self.handle_mem.record_stream(stream) From e33e63f894e24bbc4682bbb99e1a591844285e90 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 04:56:22 -0700 Subject: [PATCH 27/39] EP PyTorch: EpBuffer owns dispatch recv-output symm buffers in zero-copy; ep_dispatch falls back to them Signed-off-by: Phuong Nguyen --- tests/pytorch/distributed/run_ep.py | 28 ++++-------- transformer_engine/pytorch/ep.py | 66 ++++++++++++++++------------- 2 files changed, 45 insertions(+), 49 deletions(-) diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 2f536b27e2..a64fb538f2 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -165,16 +165,6 @@ def _make_buffer(self, alignment=0, top_k=TOP_K): alignment=alignment, ) - def _recv_kwargs(self): - """ep_dispatch recv-buffer kwargs: symm-mem-backed under zero-copy, else default-alloc.""" - if not ZERO_COPY: - return {} - rc = self.cfg.recv_capacity_per_rank - return { - "recv_tokens": symm_mem_alloc((rc, HIDDEN_DIM), torch.bfloat16, self.ep_group), - "recv_topk_weights": symm_mem_alloc((rc,), torch.float32, self.ep_group), - } - def _expert_out(self, eo): """Stage the combine input into symm-mem under zero-copy (combine requires it).""" if not ZERO_COPY: @@ -242,11 +232,11 @@ def 
test_primitive_dispatch_combine_identity(self): @_zero_copy_test_include def test_dispatch_autograd(self): - """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens. Covers both - default-allocated and caller-supplied recv buffers (caller-supplied is - required and symm-mem-backed under zero-copy).""" + """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens. Covers the + EpBuffer-owned recv outputs (symm-mem under zero-copy) and, in normal + mode, caller-supplied recv buffers.""" if ZERO_COPY: - cases = [("caller_recv_symm", self._recv_kwargs())] + cases = [("buffer_owned", {})] else: rt_buf, rw_buf, _ = self._make_raw_recv() cases = [ @@ -355,9 +345,9 @@ def _run_1f1b(self, capture): tokens_p.append(t.detach().clone().requires_grad_(True)) recv = [None, None, None] - # Per-microbatch recv + grad-staging buffers, all symm-mem under zero-copy - # and pre-allocated so nothing is allocated/freed mid-interleave. - recv_kw = [self._recv_kwargs() for _ in scales] + # Per-microbatch grad-staging buffers, symm-mem under zero-copy and + # pre-allocated so nothing is allocated/freed mid-interleave. The recv + # outputs are owned by each EpBuffer (symm-mem under zero-copy). 
recv_w = [None, None, None] rc = self.cfg.recv_capacity_per_rank if ZERO_COPY: @@ -367,7 +357,7 @@ def _run_1f1b(self, capture): gbuf_t = gbuf_w = [None, None, None] def fwd(k): - rt, rw, _ = ep_dispatch(buffers[k], tokens_p[k], idx, w, **recv_kw[k]) + rt, rw, _ = ep_dispatch(buffers[k], tokens_p[k], idx, w) recv[k] = self._stage_grad_symm(rt, gbuf_t[k]) recv_w[k] = self._stage_grad_symm(rw, gbuf_w[k]) @@ -426,7 +416,7 @@ def test_combine_autograd(self): buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w, **self._recv_kwargs()) + recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w) recv_t = self._stage_grad_symm(recv_t) recv_w = self._stage_grad_symm(recv_w) eo = self._expert_out(self._weighted(recv_t, recv_w)) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 40a3cd3841..630dc9fc4e 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -188,13 +188,13 @@ def ep_finalize() -> None: class EpBuffer: """Per-microbatch EP layer state holding handle_mem and token_counts. - Cross-rank payload buffers are caller-supplied to ep_dispatch and - ep_combine; allocate via symm_mem_alloc in zero-copy mode. Use one EpBuffer per concurrently-in-flight call (e.g. per PP-1F1B microbatch). - In zero-copy mode the combine backward scatters into a symm-mem grad buffer - owned here (one per buffer, so each layer/microbatch is isolated); the normal - mode allocates that grad in-flight in the backward. + In zero-copy mode the buffer owns the symm-mem buffers the one-sided path + requires: the dispatch recv outputs (recv_tokens, recv_topk_weights) and the + combine backward grad target. One set per buffer, so each layer/microbatch is + isolated. 
In normal mode these are None and allocated in-flight instead (recv + outputs in the dispatch forward, the combine grad in the backward). """ __slots__ = ( @@ -209,25 +209,27 @@ class EpBuffer: "device", "token_counts", "zero_copy", + "recv_tokens_symm_buf", + "recv_topk_weights_symm_buf", "grad_combine_symm_buf", ) - def _alloc_grad_combine_symm_buf(self) -> None: - """Allocate the zero-copy combine grad buffer; None in normal mode.""" + def _alloc_symm_buffers(self) -> None: + """Allocate the EpBuffer-owned symm-mem buffers; all None in normal mode.""" if not self.zero_copy: + self.recv_tokens_symm_buf = None + self.recv_topk_weights_symm_buf = None self.grad_combine_symm_buf = None return if _EP_GROUP is None: raise RuntimeError("ep_bootstrap must be called before constructing a zero-copy EpBuffer") - buf = symm_mem_alloc( - (self.recv_capacity_per_rank, self.hidden_dim), - self.payload_dtype, - _EP_GROUP, - device=self.device, - ) + rc, h = self.recv_capacity_per_rank, self.hidden_dim + self.recv_tokens_symm_buf = symm_mem_alloc((rc, h), self.payload_dtype, _EP_GROUP, device=self.device) + self.recv_topk_weights_symm_buf = symm_mem_alloc((rc,), torch.float32, _EP_GROUP, device=self.device) + self.grad_combine_symm_buf = symm_mem_alloc((rc, h), self.payload_dtype, _EP_GROUP, device=self.device) # Persistent across microbatches; keep resident under CPU offloading. - mark_not_offload(buf) - self.grad_combine_symm_buf = buf + for buf in (self.recv_tokens_symm_buf, self.recv_topk_weights_symm_buf, self.grad_combine_symm_buf): + mark_not_offload(buf) def __init__( self, @@ -260,7 +262,7 @@ def __init__( self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) # Persistent tensor; keep resident if activation CPU offloading is on. 
mark_not_offload(self.handle_mem) - self._alloc_grad_combine_symm_buf() + self._alloc_symm_buffers() def record_stream(self, stream: torch.cuda.Stream) -> None: """Defer caching-allocator reclaim of owned tensors until stream catches up.""" @@ -558,8 +560,9 @@ def ep_dispatch( ): """Prepare + dispatch with autograd. topk_idx must be int64. - recv_tokens / recv_topk_weights are used as-is if supplied, else allocated. - Zero-copy mode requires both to be supplied and symm-mem-backed. + recv_tokens / recv_topk_weights are used as-is if supplied, else taken from + the EpBuffer-owned symm-mem buffers in zero-copy mode or allocated in-flight + otherwise. Returns (recv_tokens, recv_topk_weights, token_counts); token_counts is non-diff. """ _require_bf16("tokens", tokens) @@ -568,21 +571,24 @@ def ep_dispatch( f"topk_weights must be float32; got dtype={topk_weights.dtype}. " "Cast with topk_weights.float() before calling." ) - if buffer.zero_copy and (recv_tokens is None or recv_topk_weights is None): - raise ValueError( - "ep_dispatch: zero-copy mode requires caller-supplied recv_tokens and " - "recv_topk_weights (allocate via symm_mem_alloc)." 
- ) if recv_tokens is None: - recv_tokens = torch.empty( - buffer.recv_capacity_per_rank, - buffer.hidden_dim, - dtype=buffer.payload_dtype, - device=buffer.device, + recv_tokens = ( + buffer.recv_tokens_symm_buf + if buffer.zero_copy + else torch.empty( + buffer.recv_capacity_per_rank, + buffer.hidden_dim, + dtype=buffer.payload_dtype, + device=buffer.device, + ) ) if recv_topk_weights is None: - recv_topk_weights = torch.empty( - buffer.recv_capacity_per_rank, dtype=torch.float32, device=buffer.device + recv_topk_weights = ( + buffer.recv_topk_weights_symm_buf + if buffer.zero_copy + else torch.empty( + buffer.recv_capacity_per_rank, dtype=torch.float32, device=buffer.device + ) ) return _EpDispatch.apply( buffer.handle_mem, From 31db373fb5a08add447e11c01490eb3d2731e9cb Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 04:56:44 -0700 Subject: [PATCH 28/39] EP PyTorch: stash combine-bwd grad scatter target as plain ctx attribute instead of save_for_backward Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 630dc9fc4e..4a2c3da95b 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -492,6 +492,10 @@ class _EpCombine(torch.autograd.Function): symm-mem, one-sided) in zero-copy mode, or into a plain tensor allocated in-flight here otherwise. The latter keeps allocation torch.compile / CUDA-graph safe and lets autograd own the grad's lifetime. + + ``grad_symm_buf`` is the backward's scatter target (an output it writes, never + reads), so it is stashed as a plain ctx attribute rather than via + save_for_backward, which would version-track a tensor we mutate. 
""" @staticmethod @@ -507,10 +511,9 @@ def forward( # type: ignore[override] device = expert_out.device result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result) - if grad_symm_buf is not None: - ctx.save_for_backward(handle_mem, grad_symm_buf) - else: - ctx.save_for_backward(handle_mem) + ctx.save_for_backward(handle_mem) + ctx.grad_symm_buf = grad_symm_buf + if grad_symm_buf is None: ctx.expert_out_shape = expert_out.shape ctx.expert_out_dtype = expert_out.dtype ctx.device = device @@ -520,11 +523,9 @@ def forward( # type: ignore[override] def backward(ctx, g_result): # type: ignore[override] if not g_result.is_contiguous(): g_result = g_result.contiguous() - saved = ctx.saved_tensors - handle_mem = saved[0] - if len(saved) == 2: - grad_expert_out = saved[1] - else: + (handle_mem,) = ctx.saved_tensors + grad_expert_out = ctx.grad_symm_buf + if grad_expert_out is None: grad_expert_out = torch.empty( ctx.expert_out_shape, dtype=ctx.expert_out_dtype, device=ctx.device ) From 9496515b23bf93726d528687215b431d77d448e5 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 04:56:58 -0700 Subject: [PATCH 29/39] EP PyTorch: drop unused EpBuffer.record_stream Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 4a2c3da95b..aee6372e67 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -264,13 +264,6 @@ def __init__( mark_not_offload(self.handle_mem) self._alloc_symm_buffers() - def record_stream(self, stream: torch.cuda.Stream) -> None: - """Defer caching-allocator reclaim of owned tensors until stream catches up.""" - self.handle_mem.record_stream(stream) - self.token_counts.record_stream(stream) - if self.grad_combine_symm_buf is not None: - 
self.grad_combine_symm_buf.record_stream(stream) - # torch.library custom ops (so they don't graph-break under torch.compile) From d4f89731c5d8a9022031cb824ee856108f1369fc Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 05:21:04 -0700 Subject: [PATCH 30/39] EP PyTorch: add EpBuffer caller_provides_dispatch_recv_tokens and caller_provides_combine_grad_buffer; recv_topk_weights is always buffer-owned Signed-off-by: Phuong Nguyen --- tests/pytorch/distributed/run_ep.py | 74 ++++++++++++++++++++++++--- transformer_engine/pytorch/ep.py | 78 +++++++++++++++++++++-------- 2 files changed, 123 insertions(+), 29 deletions(-) diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index a64fb538f2..60e81e0c8e 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -155,7 +155,13 @@ def setUp(self): ): self.skipTest("not exercised in zero-copy mode") - def _make_buffer(self, alignment=0, top_k=TOP_K): + def _make_buffer( + self, + alignment=0, + top_k=TOP_K, + caller_provides_dispatch_recv_tokens=False, + caller_provides_combine_grad_buffer=False, + ): return EpBuffer( top_k=top_k, max_tokens_per_rank=TOKENS_PER_RANK, @@ -163,6 +169,8 @@ def _make_buffer(self, alignment=0, top_k=TOP_K): hidden_dim=HIDDEN_DIM, num_local_experts=NUM_LOCAL_EXPERTS, alignment=alignment, + caller_provides_dispatch_recv_tokens=caller_provides_dispatch_recv_tokens, + caller_provides_combine_grad_buffer=caller_provides_combine_grad_buffer, ) def _expert_out(self, eo): @@ -233,15 +241,15 @@ def test_primitive_dispatch_combine_identity(self): @_zero_copy_test_include def test_dispatch_autograd(self): """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens. 
Covers the - EpBuffer-owned recv outputs (symm-mem under zero-copy) and, in normal - mode, caller-supplied recv buffers.""" + EpBuffer-owned recv tokens (symm-mem under zero-copy) and, in normal + mode, a caller-supplied recv_tokens buffer.""" if ZERO_COPY: cases = [("buffer_owned", {})] else: - rt_buf, rw_buf, _ = self._make_raw_recv() + rt_buf, _rw_buf, _ = self._make_raw_recv() cases = [ ("default_alloc", {}), - ("caller_recv", {"recv_tokens": rt_buf, "recv_topk_weights": rw_buf}), + ("caller_recv", {"recv_tokens": rt_buf}), ] for label, recv_kw in cases: with self.subTest(case=label): @@ -249,9 +257,8 @@ def test_dispatch_autograd(self): topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) rt, rw, _tc = ep_dispatch(buf, tokens_p, topk_idx, w, **recv_kw) - if recv_kw: # caller-supplied buffers must be used in place + if recv_kw: # caller-supplied recv_tokens must be used in place self.assertEqual(rt.data_ptr(), recv_kw["recv_tokens"].data_ptr()) - self.assertEqual(rw.data_ptr(), recv_kw["recv_topk_weights"].data_ptr()) rt = self._stage_grad_symm(rt) rw = self._stage_grad_symm(rw) (0.5 * (rt.float() ** 2).sum() + 0.0 * rw.float().sum()).backward() @@ -260,6 +267,59 @@ def test_dispatch_autograd(self): tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 ) + @_zero_copy_test_include + def test_caller_provides_dispatch_recv_tokens(self): + """caller_provides_dispatch_recv_tokens: EpBuffer skips recv_tokens allocation + (recv_topk_weights stays owned) and ep_dispatch requires caller-supplied + recv_tokens (symm-mem under zero-copy).""" + buf = self._make_buffer(caller_provides_dispatch_recv_tokens=True) + self.assertIsNone(buf.recv_tokens_symm_buf) + if ZERO_COPY: # recv_topk_weights is always buffer-owned in zero-copy + self.assertIsNotNone(buf.recv_topk_weights_symm_buf) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = 
tokens.detach().clone().requires_grad_(True) + with self.assertRaises(ValueError): + ep_dispatch(buf, tokens_p, topk_idx, w) + if ZERO_COPY: + rc = self.cfg.recv_capacity_per_rank + rt_buf = symm_mem_alloc((rc, HIDDEN_DIM), torch.bfloat16, self.ep_group) + else: + rt_buf, _rw_buf, _ = self._make_raw_recv() + rt, rw, _ = ep_dispatch(buf, tokens_p, topk_idx, w, recv_tokens=rt_buf) + self.assertEqual(rt.data_ptr(), rt_buf.data_ptr()) + rt = self._stage_grad_symm(rt) + rw = self._stage_grad_symm(rw) + (0.5 * (rt.float() ** 2).sum() + 0.0 * rw.float().sum()).backward() + torch.cuda.synchronize() + torch.testing.assert_close( + tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 + ) + + @_zero_copy_test_include + def test_caller_provides_combine_grad_buffer(self): + """caller_provides_combine_grad_buffer: EpBuffer skips combine-grad allocation + and ep_combine requires a caller-supplied grad buffer (symm-mem under zero-copy).""" + buf = self._make_buffer(caller_provides_combine_grad_buffer=True) + self.assertIsNone(buf.grad_combine_symm_buf) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w) + recv_t = self._stage_grad_symm(recv_t) + recv_w = self._stage_grad_symm(recv_w) + eo = self._expert_out(self._weighted(recv_t, recv_w)) + with self.assertRaises(ValueError): + ep_combine(buf, eo) + rc = self.cfg.recv_capacity_per_rank + if ZERO_COPY: + gbuf = symm_mem_alloc((rc, HIDDEN_DIM), torch.bfloat16, self.ep_group) + else: + gbuf = torch.empty(rc, HIDDEN_DIM, dtype=torch.bfloat16, device=self.cfg.device) + out = ep_combine(buf, eo, grad_combine_buffer=gbuf) + (0.5 * (out.float() ** 2).sum()).backward() + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) + torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + # 
Multi-iter stability def test_dispatch_autograd_multiple_iterations(self): diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index aee6372e67..2db9de38c5 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -209,13 +209,18 @@ class EpBuffer: "device", "token_counts", "zero_copy", + "caller_provides_dispatch_recv_tokens", + "caller_provides_combine_grad_buffer", "recv_tokens_symm_buf", "recv_topk_weights_symm_buf", "grad_combine_symm_buf", ) def _alloc_symm_buffers(self) -> None: - """Allocate the EpBuffer-owned symm-mem buffers; all None in normal mode.""" + """Allocate the EpBuffer-owned symm-mem buffers; all None in normal mode. + recv_topk_weights is always owned; recv_tokens is skipped under + caller_provides_dispatch_recv_tokens and the combine grad target under + caller_provides_combine_grad_buffer (the caller supplies those).""" if not self.zero_copy: self.recv_tokens_symm_buf = None self.recv_topk_weights_symm_buf = None @@ -224,12 +229,19 @@ def _alloc_symm_buffers(self) -> None: if _EP_GROUP is None: raise RuntimeError("ep_bootstrap must be called before constructing a zero-copy EpBuffer") rc, h = self.recv_capacity_per_rank, self.hidden_dim - self.recv_tokens_symm_buf = symm_mem_alloc((rc, h), self.payload_dtype, _EP_GROUP, device=self.device) - self.recv_topk_weights_symm_buf = symm_mem_alloc((rc,), torch.float32, _EP_GROUP, device=self.device) - self.grad_combine_symm_buf = symm_mem_alloc((rc, h), self.payload_dtype, _EP_GROUP, device=self.device) # Persistent across microbatches; keep resident under CPU offloading. 
- for buf in (self.recv_tokens_symm_buf, self.recv_topk_weights_symm_buf, self.grad_combine_symm_buf): - mark_not_offload(buf) + self.recv_topk_weights_symm_buf = symm_mem_alloc((rc,), torch.float32, _EP_GROUP, device=self.device) + mark_not_offload(self.recv_topk_weights_symm_buf) + if self.caller_provides_combine_grad_buffer: + self.grad_combine_symm_buf = None + else: + self.grad_combine_symm_buf = symm_mem_alloc((rc, h), self.payload_dtype, _EP_GROUP, device=self.device) + mark_not_offload(self.grad_combine_symm_buf) + if self.caller_provides_dispatch_recv_tokens: + self.recv_tokens_symm_buf = None + else: + self.recv_tokens_symm_buf = symm_mem_alloc((rc, h), self.payload_dtype, _EP_GROUP, device=self.device) + mark_not_offload(self.recv_tokens_symm_buf) def __init__( self, @@ -241,7 +253,15 @@ def __init__( alignment: int = 0, payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, + caller_provides_dispatch_recv_tokens: bool = False, + caller_provides_combine_grad_buffer: bool = False, ) -> None: + """``caller_provides_dispatch_recv_tokens`` declares that the caller passes + recv_tokens to ep_dispatch; ``caller_provides_combine_grad_buffer`` that the + caller passes the combine backward grad target to ep_combine (both symm-mem + under zero-copy). This buffer then does not allocate the declared tensor and + the corresponding op requires it to be supplied. 
recv_topk_weights is always + owned by the buffer.""" if device is None: device = torch.device("cuda", torch.cuda.current_device()) alignment = int(alignment) @@ -256,6 +276,8 @@ def __init__( self.payload_dtype = payload_dtype self.device = device self.zero_copy = bool(tex.ep_get_zero_copy()) + self.caller_provides_dispatch_recv_tokens = bool(caller_provides_dispatch_recv_tokens) + self.caller_provides_combine_grad_buffer = bool(caller_provides_combine_grad_buffer) size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment) self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) @@ -550,13 +572,13 @@ def ep_dispatch( topk_weights: torch.Tensor, *, recv_tokens: Optional[torch.Tensor] = None, - recv_topk_weights: Optional[torch.Tensor] = None, ): """Prepare + dispatch with autograd. topk_idx must be int64. - recv_tokens / recv_topk_weights are used as-is if supplied, else taken from - the EpBuffer-owned symm-mem buffers in zero-copy mode or allocated in-flight - otherwise. + recv_tokens is used as-is if supplied, else taken from the EpBuffer-owned + symm-mem buffer (zero-copy) or allocated in-flight (normal mode) -- unless the + buffer was created with caller_provides_dispatch_recv_tokens, in which case it + must be supplied here. recv_topk_weights is always owned by the buffer. Returns (recv_tokens, recv_topk_weights, token_counts); token_counts is non-diff. """ _require_bf16("tokens", tokens) @@ -565,7 +587,7 @@ def ep_dispatch( f"topk_weights must be float32; got dtype={topk_weights.dtype}. " "Cast with topk_weights.float() before calling." 
) - if recv_tokens is None: + if recv_tokens is None and not buffer.caller_provides_dispatch_recv_tokens: recv_tokens = ( buffer.recv_tokens_symm_buf if buffer.zero_copy @@ -576,14 +598,16 @@ def ep_dispatch( device=buffer.device, ) ) - if recv_topk_weights is None: - recv_topk_weights = ( - buffer.recv_topk_weights_symm_buf - if buffer.zero_copy - else torch.empty( - buffer.recv_capacity_per_rank, dtype=torch.float32, device=buffer.device - ) + if recv_tokens is None: + raise ValueError( + "ep_dispatch: buffer was created with caller_provides_dispatch_recv_tokens=True, " + "so recv_tokens must be supplied (symm-mem-backed under zero-copy)." ) + recv_topk_weights = ( + buffer.recv_topk_weights_symm_buf + if buffer.zero_copy + else torch.empty(buffer.recv_capacity_per_rank, dtype=torch.float32, device=buffer.device) + ) return _EpDispatch.apply( buffer.handle_mem, buffer.top_k, @@ -602,21 +626,31 @@ def ep_combine( expert_out: torch.Tensor, *, num_local_tokens: Optional[int] = None, + grad_combine_buffer: Optional[torch.Tensor] = None, ): """Combine with autograd; caller pre-applies topk weighting. - The backward scatters the expert_out grad into buffer.grad_combine_symm_buf - (symm-mem) in zero-copy mode, else into a tensor allocated in-flight. Result - shape is (num_local_tokens, buffer.hidden_dim); defaults to + The backward scatters the expert_out grad into grad_combine_buffer if + supplied, else the EpBuffer-owned symm-mem buffer (zero-copy) or a tensor + allocated in-flight (normal mode). When the buffer was created with + caller_provides_combine_grad_buffer, grad_combine_buffer must be supplied. + Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. 
""" _require_bf16("expert_out", expert_out) if num_local_tokens is None: num_local_tokens = buffer.max_tokens_per_rank + if grad_combine_buffer is None and not buffer.caller_provides_combine_grad_buffer: + grad_combine_buffer = buffer.grad_combine_symm_buf + if grad_combine_buffer is None and buffer.caller_provides_combine_grad_buffer: + raise ValueError( + "ep_combine: buffer was created with caller_provides_combine_grad_buffer=True, " + "so grad_combine_buffer must be supplied (symm-mem-backed under zero-copy)." + ) return _EpCombine.apply( buffer.handle_mem, num_local_tokens, buffer.hidden_dim, - buffer.grad_combine_symm_buf, + grad_combine_buffer, expert_out, ) From 21dea506de21169ac36fa79db195af7e724699e6 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 07:29:36 -0700 Subject: [PATCH 31/39] EP PyTorch: add opt-in caller_provides_dispatch_recv_tokens/caller_provides_combine_grad_buffer CLI flags to ep_moe example and ep_bench (default False) Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/bench/ep_bench.py | 39 ++++++++++++++++++++++----- examples/pytorch/ep/ep_moe.py | 38 +++++++++++++++++++++----- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index 81f5b83883..15707d913b 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -74,6 +74,18 @@ def _parse_args(): default=None, help="Optional suffix for NVTX range names (e.g. 
'fused' / 'unfused').", ) + p.add_argument( + "--caller-provides-dispatch-recv-tokens", + action="store_true", + default=False, + help="Supply recv_tokens to ep_dispatch instead of letting EpBuffer own it.", + ) + p.add_argument( + "--caller-provides-combine-grad-buffer", + action="store_true", + default=False, + help="Supply the combine backward grad buffer to ep_combine.", + ) return p.parse_args() @@ -182,6 +194,8 @@ def main(): recv_capacity_per_rank=recv_pr, hidden_dim=H, num_local_experts=num_local_experts, + caller_provides_dispatch_recv_tokens=args.caller_provides_dispatch_recv_tokens, + caller_provides_combine_grad_buffer=args.caller_provides_combine_grad_buffer, ) tokens = tokens_hbm @@ -189,6 +203,19 @@ def main(): recv_tokens = torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device) recv_w = torch.empty(recv_pr, dtype=torch.float32, device=device) + # Caller-supplied buffers for the autograd ep_dispatch/ep_combine stages + # (normal mode -> plain tensors), reused across iters. Empty when not opted in. + dispatch_recv_kw = ( + {"recv_tokens": torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device)} + if args.caller_provides_dispatch_recv_tokens + else {} + ) + combine_grad_kw = ( + {"grad_combine_buffer": torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device)} + if args.caller_provides_combine_grad_buffer + else {} + ) + # -- Prepare once outside the timed loops ------------------------------ ep_prepare(buffer, topk_idx) torch.cuda.synchronize() @@ -208,8 +235,8 @@ def main(): eo_p = recv_tokens.detach().clone().requires_grad_(True) # Stand-in callables; the cuda-graph branch below swaps in graphed versions. 
- fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w)[0] # noqa: E731 - fwd_bwd_combine_fn = lambda eo: ep_combine(buffer, eo) # noqa: E731 + fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w, **dispatch_recv_kw)[0] # noqa: E731 + fwd_bwd_combine_fn = lambda eo: ep_combine(buffer, eo, **combine_grad_kw) # noqa: E731 def _dispatch_raw(): _ep_dispatch_raw(buffer, topk_idx, tokens, topk_w, recv_tokens, recv_w) @@ -219,7 +246,7 @@ def _combine_raw(): _ep_combine_raw(buffer, expert_out, out_buf) def _ep_dispatch_fwd(): - ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) + ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w, **dispatch_recv_kw) def _ep_dispatch_fwd_bwd(): tokens_p.grad = None @@ -227,7 +254,7 @@ def _ep_dispatch_fwd_bwd(): (0.5 * (r * r).sum(dtype=torch.float32)).backward() def _ep_combine_fwd(): - ep_combine(buffer, recv_tokens) + ep_combine(buffer, recv_tokens, **combine_grad_kw) def _ep_combine_fwd_bwd(): eo_p.grad = None @@ -261,11 +288,11 @@ def _ep_combine_fwd_bwd(): # Graph fwd+bwd of the autograd-wrapped ops via make_graphed_callables. 
class _DispatchMod(torch.nn.Module): def forward(self, x): - return ep_dispatch(buffer, x, topk_idx, topk_w)[0] + return ep_dispatch(buffer, x, topk_idx, topk_w, **dispatch_recv_kw)[0] class _CombineMod(torch.nn.Module): def forward(self, eo): - return ep_combine(buffer, eo) + return ep_combine(buffer, eo, **combine_grad_kw) disp_mod = _DispatchMod().cuda() comb_mod = _CombineMod().cuda() diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index f72912301b..ee009a13ce 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -38,6 +38,18 @@ def _parse_args(): ) p.add_argument("--benchmark-iters", type=int, default=20) p.add_argument("--benchmark-warmup", type=int, default=5) + p.add_argument( + "--caller-provides-dispatch-recv-tokens", + action="store_true", + default=False, + help="Supply recv_tokens to ep_dispatch instead of letting EpBuffer own it.", + ) + p.add_argument( + "--caller-provides-combine-grad-buffer", + action="store_true", + default=False, + help="Supply the combine backward grad buffer to ep_combine.", + ) return p.parse_args() @@ -158,13 +170,27 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, recv_capacity_per_rank=recv_pr, hidden_dim=args.hidden, num_local_experts=num_local_experts, + caller_provides_dispatch_recv_tokens=args.caller_provides_dispatch_recv_tokens, + caller_provides_combine_grad_buffer=args.caller_provides_combine_grad_buffer, + ) + + # Caller-supplied buffers (normal mode -> plain tensors), reused across iters. 
+ dispatch_kw = ( + {"recv_tokens": torch.empty(recv_pr, args.hidden, dtype=torch.bfloat16, device=device)} + if args.caller_provides_dispatch_recv_tokens + else {} + ) + combine_kw = ( + {"grad_combine_buffer": torch.empty(recv_pr, args.hidden, dtype=torch.bfloat16, device=device)} + if args.caller_provides_combine_grad_buffer + else {} ) - recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, topk_w) + recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, topk_w, **dispatch_kw) expert_out = _batched_expert_linear(recv_t, kernels_local, num_local_experts) # Apply per-slot topk weighting before combine. expert_out = expert_out * recv_w_out.unsqueeze(-1).to(expert_out.dtype) - out = ep_combine(buffer, expert_out) + out = ep_combine(buffer, expert_out, **combine_kw) loss = 0.5 * (out.float() ** 2).sum() loss.backward() @@ -183,18 +209,18 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, torch.cuda.synchronize() dist.barrier() for _ in range(args.benchmark_warmup): - rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) + rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w, **dispatch_kw) eo = _batched_expert_linear(rt, kernels_local, num_local_experts) eo = eo * rw.unsqueeze(-1).to(eo.dtype) - ep_combine(buffer, eo) + ep_combine(buffer, eo, **combine_kw) torch.cuda.synchronize() dist.barrier() t0 = time.perf_counter() for _ in range(args.benchmark_iters): - rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) + rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w, **dispatch_kw) eo = _batched_expert_linear(rt, kernels_local, num_local_experts) eo = eo * rw.unsqueeze(-1).to(eo.dtype) - ep_combine(buffer, eo) + ep_combine(buffer, eo, **combine_kw) torch.cuda.synchronize() dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters if rank == 0: From 63d35bd84a1496581d662edc5f5c9bc5bcf2bd8d Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 
Jun 2026 07:43:43 -0700 Subject: [PATCH 32/39] EP PyTorch: rename combine grad buffer API to grad_expert_out across buffer, dispatch/combine, tests and examples Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/bench/ep_bench.py | 8 +++--- examples/pytorch/ep/ep_moe.py | 8 +++--- tests/pytorch/distributed/run_ep.py | 14 +++++----- transformer_engine/pytorch/ep.py | 40 +++++++++++++-------------- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index 15707d913b..f8dc0ca2e3 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -81,7 +81,7 @@ def _parse_args(): help="Supply recv_tokens to ep_dispatch instead of letting EpBuffer own it.", ) p.add_argument( - "--caller-provides-combine-grad-buffer", + "--caller-provides-grad-expert-out", action="store_true", default=False, help="Supply the combine backward grad buffer to ep_combine.", @@ -195,7 +195,7 @@ def main(): hidden_dim=H, num_local_experts=num_local_experts, caller_provides_dispatch_recv_tokens=args.caller_provides_dispatch_recv_tokens, - caller_provides_combine_grad_buffer=args.caller_provides_combine_grad_buffer, + caller_provides_grad_expert_out=args.caller_provides_grad_expert_out, ) tokens = tokens_hbm @@ -211,8 +211,8 @@ def main(): else {} ) combine_grad_kw = ( - {"grad_combine_buffer": torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device)} - if args.caller_provides_combine_grad_buffer + {"grad_expert_out": torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device)} + if args.caller_provides_grad_expert_out else {} ) diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index ee009a13ce..f9a8015e7b 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -45,7 +45,7 @@ def _parse_args(): help="Supply recv_tokens to ep_dispatch instead of letting EpBuffer own it.", ) p.add_argument( - 
"--caller-provides-combine-grad-buffer", + "--caller-provides-grad-expert-out", action="store_true", default=False, help="Supply the combine backward grad buffer to ep_combine.", @@ -171,7 +171,7 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, hidden_dim=args.hidden, num_local_experts=num_local_experts, caller_provides_dispatch_recv_tokens=args.caller_provides_dispatch_recv_tokens, - caller_provides_combine_grad_buffer=args.caller_provides_combine_grad_buffer, + caller_provides_grad_expert_out=args.caller_provides_grad_expert_out, ) # Caller-supplied buffers (normal mode -> plain tensors), reused across iters. @@ -181,8 +181,8 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, else {} ) combine_kw = ( - {"grad_combine_buffer": torch.empty(recv_pr, args.hidden, dtype=torch.bfloat16, device=device)} - if args.caller_provides_combine_grad_buffer + {"grad_expert_out": torch.empty(recv_pr, args.hidden, dtype=torch.bfloat16, device=device)} + if args.caller_provides_grad_expert_out else {} ) diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 60e81e0c8e..9afcfe6dac 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -160,7 +160,7 @@ def _make_buffer( alignment=0, top_k=TOP_K, caller_provides_dispatch_recv_tokens=False, - caller_provides_combine_grad_buffer=False, + caller_provides_grad_expert_out=False, ): return EpBuffer( top_k=top_k, @@ -170,7 +170,7 @@ def _make_buffer( num_local_experts=NUM_LOCAL_EXPERTS, alignment=alignment, caller_provides_dispatch_recv_tokens=caller_provides_dispatch_recv_tokens, - caller_provides_combine_grad_buffer=caller_provides_combine_grad_buffer, + caller_provides_grad_expert_out=caller_provides_grad_expert_out, ) def _expert_out(self, eo): @@ -296,11 +296,11 @@ def test_caller_provides_dispatch_recv_tokens(self): ) @_zero_copy_test_include - def test_caller_provides_combine_grad_buffer(self): 
- """caller_provides_combine_grad_buffer: EpBuffer skips combine-grad allocation + def test_caller_provides_grad_expert_out(self): + """caller_provides_grad_expert_out: EpBuffer skips combine-grad allocation and ep_combine requires a caller-supplied grad buffer (symm-mem under zero-copy).""" - buf = self._make_buffer(caller_provides_combine_grad_buffer=True) - self.assertIsNone(buf.grad_combine_symm_buf) + buf = self._make_buffer(caller_provides_grad_expert_out=True) + self.assertIsNone(buf.grad_expert_out_symm_buf) topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w) @@ -314,7 +314,7 @@ def test_caller_provides_combine_grad_buffer(self): gbuf = symm_mem_alloc((rc, HIDDEN_DIM), torch.bfloat16, self.ep_group) else: gbuf = torch.empty(rc, HIDDEN_DIM, dtype=torch.bfloat16, device=self.cfg.device) - out = ep_combine(buf, eo, grad_combine_buffer=gbuf) + out = ep_combine(buf, eo, grad_expert_out=gbuf) (0.5 * (out.float() ** 2).sum()).backward() torch.cuda.synchronize() torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 2db9de38c5..1c7d053a94 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -210,21 +210,21 @@ class EpBuffer: "token_counts", "zero_copy", "caller_provides_dispatch_recv_tokens", - "caller_provides_combine_grad_buffer", + "caller_provides_grad_expert_out", "recv_tokens_symm_buf", "recv_topk_weights_symm_buf", - "grad_combine_symm_buf", + "grad_expert_out_symm_buf", ) def _alloc_symm_buffers(self) -> None: """Allocate the EpBuffer-owned symm-mem buffers; all None in normal mode. 
recv_topk_weights is always owned; recv_tokens is skipped under caller_provides_dispatch_recv_tokens and the combine grad target under - caller_provides_combine_grad_buffer (the caller supplies those).""" + caller_provides_grad_expert_out (the caller supplies those).""" if not self.zero_copy: self.recv_tokens_symm_buf = None self.recv_topk_weights_symm_buf = None - self.grad_combine_symm_buf = None + self.grad_expert_out_symm_buf = None return if _EP_GROUP is None: raise RuntimeError("ep_bootstrap must be called before constructing a zero-copy EpBuffer") @@ -232,11 +232,11 @@ def _alloc_symm_buffers(self) -> None: # Persistent across microbatches; keep resident under CPU offloading. self.recv_topk_weights_symm_buf = symm_mem_alloc((rc,), torch.float32, _EP_GROUP, device=self.device) mark_not_offload(self.recv_topk_weights_symm_buf) - if self.caller_provides_combine_grad_buffer: - self.grad_combine_symm_buf = None + if self.caller_provides_grad_expert_out: + self.grad_expert_out_symm_buf = None else: - self.grad_combine_symm_buf = symm_mem_alloc((rc, h), self.payload_dtype, _EP_GROUP, device=self.device) - mark_not_offload(self.grad_combine_symm_buf) + self.grad_expert_out_symm_buf = symm_mem_alloc((rc, h), self.payload_dtype, _EP_GROUP, device=self.device) + mark_not_offload(self.grad_expert_out_symm_buf) if self.caller_provides_dispatch_recv_tokens: self.recv_tokens_symm_buf = None else: @@ -254,10 +254,10 @@ def __init__( payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, caller_provides_dispatch_recv_tokens: bool = False, - caller_provides_combine_grad_buffer: bool = False, + caller_provides_grad_expert_out: bool = False, ) -> None: """``caller_provides_dispatch_recv_tokens`` declares that the caller passes - recv_tokens to ep_dispatch; ``caller_provides_combine_grad_buffer`` that the + recv_tokens to ep_dispatch; ``caller_provides_grad_expert_out`` that the caller passes the combine backward grad target to ep_combine (both 
symm-mem under zero-copy). This buffer then does not allocate the declared tensor and the corresponding op requires it to be supplied. recv_topk_weights is always @@ -277,7 +277,7 @@ def __init__( self.device = device self.zero_copy = bool(tex.ep_get_zero_copy()) self.caller_provides_dispatch_recv_tokens = bool(caller_provides_dispatch_recv_tokens) - self.caller_provides_combine_grad_buffer = bool(caller_provides_combine_grad_buffer) + self.caller_provides_grad_expert_out = bool(caller_provides_grad_expert_out) size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment) self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) @@ -626,31 +626,31 @@ def ep_combine( expert_out: torch.Tensor, *, num_local_tokens: Optional[int] = None, - grad_combine_buffer: Optional[torch.Tensor] = None, + grad_expert_out: Optional[torch.Tensor] = None, ): """Combine with autograd; caller pre-applies topk weighting. - The backward scatters the expert_out grad into grad_combine_buffer if + The backward scatters the expert_out grad into grad_expert_out if supplied, else the EpBuffer-owned symm-mem buffer (zero-copy) or a tensor allocated in-flight (normal mode). When the buffer was created with - caller_provides_combine_grad_buffer, grad_combine_buffer must be supplied. + caller_provides_grad_expert_out, grad_expert_out must be supplied. Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. 
""" _require_bf16("expert_out", expert_out) if num_local_tokens is None: num_local_tokens = buffer.max_tokens_per_rank - if grad_combine_buffer is None and not buffer.caller_provides_combine_grad_buffer: - grad_combine_buffer = buffer.grad_combine_symm_buf - if grad_combine_buffer is None and buffer.caller_provides_combine_grad_buffer: + if grad_expert_out is None and not buffer.caller_provides_grad_expert_out: + grad_expert_out = buffer.grad_expert_out_symm_buf + if grad_expert_out is None and buffer.caller_provides_grad_expert_out: raise ValueError( - "ep_combine: buffer was created with caller_provides_combine_grad_buffer=True, " - "so grad_combine_buffer must be supplied (symm-mem-backed under zero-copy)." + "ep_combine: buffer was created with caller_provides_grad_expert_out=True, " + "so grad_expert_out must be supplied (symm-mem-backed under zero-copy)." ) return _EpCombine.apply( buffer.handle_mem, num_local_tokens, buffer.hidden_dim, - grad_combine_buffer, + grad_expert_out, expert_out, ) From 82aabaf73eedef4273598322e2cef82202ed9ea5 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 07:52:39 -0700 Subject: [PATCH 33/39] EP PyTorch: rename eo local to expert_out in tests and examples Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/bench/ep_bench.py | 6 +++--- examples/pytorch/ep/ep_moe.py | 12 ++++++------ tests/pytorch/distributed/run_ep.py | 22 +++++++++++----------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index f8dc0ca2e3..21f73044d7 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -236,7 +236,7 @@ def main(): # Stand-in callables; the cuda-graph branch below swaps in graphed versions. 
fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w, **dispatch_recv_kw)[0] # noqa: E731 - fwd_bwd_combine_fn = lambda eo: ep_combine(buffer, eo, **combine_grad_kw) # noqa: E731 + fwd_bwd_combine_fn = lambda expert_out: ep_combine(buffer, expert_out, **combine_grad_kw) # noqa: E731 def _dispatch_raw(): _ep_dispatch_raw(buffer, topk_idx, tokens, topk_w, recv_tokens, recv_w) @@ -291,8 +291,8 @@ def forward(self, x): return ep_dispatch(buffer, x, topk_idx, topk_w, **dispatch_recv_kw)[0] class _CombineMod(torch.nn.Module): - def forward(self, eo): - return ep_combine(buffer, eo, **combine_grad_kw) + def forward(self, expert_out): + return ep_combine(buffer, expert_out, **combine_grad_kw) disp_mod = _DispatchMod().cuda() comb_mod = _CombineMod().cuda() diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index f9a8015e7b..6093425a66 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -210,17 +210,17 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, dist.barrier() for _ in range(args.benchmark_warmup): rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w, **dispatch_kw) - eo = _batched_expert_linear(rt, kernels_local, num_local_experts) - eo = eo * rw.unsqueeze(-1).to(eo.dtype) - ep_combine(buffer, eo, **combine_kw) + expert_out = _batched_expert_linear(rt, kernels_local, num_local_experts) + expert_out = expert_out * rw.unsqueeze(-1).to(expert_out.dtype) + ep_combine(buffer, expert_out, **combine_kw) torch.cuda.synchronize() dist.barrier() t0 = time.perf_counter() for _ in range(args.benchmark_iters): rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w, **dispatch_kw) - eo = _batched_expert_linear(rt, kernels_local, num_local_experts) - eo = eo * rw.unsqueeze(-1).to(eo.dtype) - ep_combine(buffer, eo, **combine_kw) + expert_out = _batched_expert_linear(rt, kernels_local, num_local_experts) + expert_out = expert_out * 
rw.unsqueeze(-1).to(expert_out.dtype) + ep_combine(buffer, expert_out, **combine_kw) torch.cuda.synchronize() dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters if rank == 0: diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 9afcfe6dac..090a331f0e 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -173,12 +173,12 @@ def _make_buffer( caller_provides_grad_expert_out=caller_provides_grad_expert_out, ) - def _expert_out(self, eo): + def _expert_out(self, expert_out): """Stage the combine input into symm-mem under zero-copy (combine requires it).""" if not ZERO_COPY: - return eo - symm_buf = symm_mem_alloc(tuple(eo.shape), eo.dtype, self.ep_group) - return _StageToSymm.apply(eo, symm_buf) + return expert_out + symm_buf = symm_mem_alloc(tuple(expert_out.shape), expert_out.dtype, self.ep_group) + return _StageToSymm.apply(expert_out, symm_buf) def _stage_grad_symm(self, x, symm_buf=None): """Route x's upstream grad through a symm-mem buffer so dispatch_bwd gets @@ -207,8 +207,8 @@ def _weighted(recv_tokens, recv_w): def _moe_step(self, buffer, topk_idx, tokens, w): recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, w) - eo = self._weighted(recv_t, recv_w_out) - return ep_combine(buffer, eo) + expert_out = self._weighted(recv_t, recv_w_out) + return ep_combine(buffer, expert_out) # Prepare @@ -306,15 +306,15 @@ def test_caller_provides_grad_expert_out(self): recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w) recv_t = self._stage_grad_symm(recv_t) recv_w = self._stage_grad_symm(recv_w) - eo = self._expert_out(self._weighted(recv_t, recv_w)) + expert_out = self._expert_out(self._weighted(recv_t, recv_w)) with self.assertRaises(ValueError): - ep_combine(buf, eo) + ep_combine(buf, expert_out) rc = self.cfg.recv_capacity_per_rank if ZERO_COPY: gbuf = symm_mem_alloc((rc, HIDDEN_DIM), torch.bfloat16, self.ep_group) else: gbuf = torch.empty(rc, HIDDEN_DIM, 
dtype=torch.bfloat16, device=self.cfg.device) - out = ep_combine(buf, eo, grad_expert_out=gbuf) + out = ep_combine(buf, expert_out, grad_expert_out=gbuf) (0.5 * (out.float() ** 2).sum()).backward() torch.cuda.synchronize() torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) @@ -479,8 +479,8 @@ def test_combine_autograd(self): recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w) recv_t = self._stage_grad_symm(recv_t) recv_w = self._stage_grad_symm(recv_w) - eo = self._expert_out(self._weighted(recv_t, recv_w)) - out = ep_combine(buf, eo) + expert_out = self._expert_out(self._weighted(recv_t, recv_w)) + out = ep_combine(buf, expert_out) (0.5 * (out.float() ** 2).sum()).backward() torch.cuda.synchronize() torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) From 6ec78a3b2597849f2cc1cad955c9cc95e23343a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Jun 2026 15:17:02 +0000 Subject: [PATCH 34/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/ep/bench/ep_bench.py | 8 ++++++-- transformer_engine/pytorch/ep.py | 16 ++++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index 21f73044d7..835939b702 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -235,8 +235,12 @@ def main(): eo_p = recv_tokens.detach().clone().requires_grad_(True) # Stand-in callables; the cuda-graph branch below swaps in graphed versions. 
- fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w, **dispatch_recv_kw)[0] # noqa: E731 - fwd_bwd_combine_fn = lambda expert_out: ep_combine(buffer, expert_out, **combine_grad_kw) # noqa: E731 + fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w, **dispatch_recv_kw)[ + 0 + ] # noqa: E731 + fwd_bwd_combine_fn = lambda expert_out: ep_combine( + buffer, expert_out, **combine_grad_kw + ) # noqa: E731 def _dispatch_raw(): _ep_dispatch_raw(buffer, topk_idx, tokens, topk_w, recv_tokens, recv_w) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 1c7d053a94..63c3c3c2a4 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -227,20 +227,28 @@ def _alloc_symm_buffers(self) -> None: self.grad_expert_out_symm_buf = None return if _EP_GROUP is None: - raise RuntimeError("ep_bootstrap must be called before constructing a zero-copy EpBuffer") + raise RuntimeError( + "ep_bootstrap must be called before constructing a zero-copy EpBuffer" + ) rc, h = self.recv_capacity_per_rank, self.hidden_dim # Persistent across microbatches; keep resident under CPU offloading. 
- self.recv_topk_weights_symm_buf = symm_mem_alloc((rc,), torch.float32, _EP_GROUP, device=self.device) + self.recv_topk_weights_symm_buf = symm_mem_alloc( + (rc,), torch.float32, _EP_GROUP, device=self.device + ) mark_not_offload(self.recv_topk_weights_symm_buf) if self.caller_provides_grad_expert_out: self.grad_expert_out_symm_buf = None else: - self.grad_expert_out_symm_buf = symm_mem_alloc((rc, h), self.payload_dtype, _EP_GROUP, device=self.device) + self.grad_expert_out_symm_buf = symm_mem_alloc( + (rc, h), self.payload_dtype, _EP_GROUP, device=self.device + ) mark_not_offload(self.grad_expert_out_symm_buf) if self.caller_provides_dispatch_recv_tokens: self.recv_tokens_symm_buf = None else: - self.recv_tokens_symm_buf = symm_mem_alloc((rc, h), self.payload_dtype, _EP_GROUP, device=self.device) + self.recv_tokens_symm_buf = symm_mem_alloc( + (rc, h), self.payload_dtype, _EP_GROUP, device=self.device + ) mark_not_offload(self.recv_tokens_symm_buf) def __init__( From 3a023ef7164eb9135372289435e5a06f22537302 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 25 Jun 2026 05:09:27 -0700 Subject: [PATCH 35/39] EP PyTorch: move symm_mem_alloc to distributed.py and re-export from ep Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/distributed.py | 21 +++++++++++++++ transformer_engine/pytorch/ep.py | 31 +++-------------------- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 670eecaa5e..949846f839 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1842,6 +1842,27 @@ def get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_gr return msg +def symm_mem_alloc( + shape, + dtype: torch.dtype, + ep_group: dist_group_type, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """Allocate and rendezvous a symm-mem buffer on ep_group. 
Collective on ep_group.""" + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + if not HAS_TORCH_SYMMETRIC: + raise RuntimeError( + "torch.distributed._symmetric_memory is unavailable; symm_mem_alloc " + "requires PyTorch built with NCCL symm-mem support." + ) + if symm_mem.get_backend(device) != "NCCL": + symm_mem.set_backend("NCCL") + t = symm_mem.empty(*shape, dtype=dtype, device=device) + symm_mem.rendezvous(t, group=ep_group.group_name) + return t + + def symmetric_all_reduce( inp: torch.Tensor, tp_group: Optional[dist_group_type] = None, diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 63c3c3c2a4..123976e953 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -15,6 +15,7 @@ import transformer_engine_torch as tex from .cpu_offload import mark_not_offload +from .distributed import symm_mem_alloc __all__ = [ @@ -27,33 +28,9 @@ ] -# Symmetric-memory buffer allocator -# -# Used for the symm-mem zero-copy IO path. Set ``ep_bootstrap(zero_copy=True)`` -# to opt in; the C++ backend then operates the EP group in zero-copy mode. - - -def symm_mem_alloc( - shape, - dtype: torch.dtype, - ep_group: dist.ProcessGroup, - device: Optional[torch.device] = None, -) -> torch.Tensor: - """Allocate and rendezvous a symm-mem buffer on ep_group. Collective on ep_group.""" - if device is None: - device = torch.device("cuda", torch.cuda.current_device()) - try: - from torch.distributed import _symmetric_memory as _symm_mem - except ImportError as e: - raise RuntimeError( - "torch.distributed._symmetric_memory is unavailable; symm_mem_alloc " - "requires PyTorch built with NCCL symm-mem support." 
- ) from e - if _symm_mem.get_backend(device) != "NCCL": - _symm_mem.set_backend("NCCL") - t = _symm_mem.empty(*shape, dtype=dtype, device=device) - _symm_mem.rendezvous(t, group=ep_group.group_name) - return t +# ``symm_mem_alloc`` (imported from .distributed) allocates the symm-mem buffers +# used by the zero-copy IO path. Set ``ep_bootstrap(zero_copy=True)`` to opt in; +# the C++ backend then operates the EP group in zero-copy mode. # Bootstrap From 3ff8a4e4c7628b3931aa44523c652b64f7ef65ae Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 25 Jun 2026 06:27:00 -0700 Subject: [PATCH 36/39] EP PyTorch: replace caller_provides_* bools with dispatch_recv_tokens/combine_grad_expert_out buffers on EpBuffer Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/bench/ep_bench.py | 46 ++++++------- examples/pytorch/ep/ep_moe.py | 32 ++++----- tests/pytorch/distributed/run_ep.py | 67 +++++++++---------- transformer_engine/pytorch/ep.py | 96 +++++++++------------------ 4 files changed, 101 insertions(+), 140 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index 835939b702..2b7a2c62e5 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -188,14 +188,27 @@ def main(): topk_idx, tokens_hbm, topk_w_hbm = _make_inputs(rank, world_size, T, H, K, E, device) + # Caller-supplied buffers for the autograd ep_dispatch/ep_combine stages + # (normal mode -> plain tensors), reused across iters. None when not opted in. 
+ caller_recv_tokens = ( + torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device) + if args.caller_provides_dispatch_recv_tokens + else None + ) + caller_grad_expert_out = ( + torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device) + if args.caller_provides_grad_expert_out + else None + ) + buffer = EpBuffer( top_k=K, max_tokens_per_rank=T, recv_capacity_per_rank=recv_pr, hidden_dim=H, num_local_experts=num_local_experts, - caller_provides_dispatch_recv_tokens=args.caller_provides_dispatch_recv_tokens, - caller_provides_grad_expert_out=args.caller_provides_grad_expert_out, + dispatch_recv_tokens=caller_recv_tokens, + combine_grad_expert_out=caller_grad_expert_out, ) tokens = tokens_hbm @@ -203,19 +216,6 @@ def main(): recv_tokens = torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device) recv_w = torch.empty(recv_pr, dtype=torch.float32, device=device) - # Caller-supplied buffers for the autograd ep_dispatch/ep_combine stages - # (normal mode -> plain tensors), reused across iters. Empty when not opted in. - dispatch_recv_kw = ( - {"recv_tokens": torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device)} - if args.caller_provides_dispatch_recv_tokens - else {} - ) - combine_grad_kw = ( - {"grad_expert_out": torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device)} - if args.caller_provides_grad_expert_out - else {} - ) - # -- Prepare once outside the timed loops ------------------------------ ep_prepare(buffer, topk_idx) torch.cuda.synchronize() @@ -235,12 +235,8 @@ def main(): eo_p = recv_tokens.detach().clone().requires_grad_(True) # Stand-in callables; the cuda-graph branch below swaps in graphed versions. 
- fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w, **dispatch_recv_kw)[ - 0 - ] # noqa: E731 - fwd_bwd_combine_fn = lambda expert_out: ep_combine( - buffer, expert_out, **combine_grad_kw - ) # noqa: E731 + fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w)[0] # noqa: E731 + fwd_bwd_combine_fn = lambda expert_out: ep_combine(buffer, expert_out) # noqa: E731 def _dispatch_raw(): _ep_dispatch_raw(buffer, topk_idx, tokens, topk_w, recv_tokens, recv_w) @@ -250,7 +246,7 @@ def _combine_raw(): _ep_combine_raw(buffer, expert_out, out_buf) def _ep_dispatch_fwd(): - ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w, **dispatch_recv_kw) + ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) def _ep_dispatch_fwd_bwd(): tokens_p.grad = None @@ -258,7 +254,7 @@ def _ep_dispatch_fwd_bwd(): (0.5 * (r * r).sum(dtype=torch.float32)).backward() def _ep_combine_fwd(): - ep_combine(buffer, recv_tokens, **combine_grad_kw) + ep_combine(buffer, recv_tokens) def _ep_combine_fwd_bwd(): eo_p.grad = None @@ -292,11 +288,11 @@ def _ep_combine_fwd_bwd(): # Graph fwd+bwd of the autograd-wrapped ops via make_graphed_callables. 
class _DispatchMod(torch.nn.Module): def forward(self, x): - return ep_dispatch(buffer, x, topk_idx, topk_w, **dispatch_recv_kw)[0] + return ep_dispatch(buffer, x, topk_idx, topk_w)[0] class _CombineMod(torch.nn.Module): def forward(self, expert_out): - return ep_combine(buffer, expert_out, **combine_grad_kw) + return ep_combine(buffer, expert_out) disp_mod = _DispatchMod().cuda() comb_mod = _CombineMod().cuda() diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index 6093425a66..c93f8bb6e4 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -164,33 +164,33 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, kernels_np[rank * num_local_experts : (rank + 1) * num_local_experts] ).to(device=device, dtype=torch.bfloat16) + # Caller-supplied buffers (normal mode -> plain tensors), reused across iters. + recv_tokens = ( + torch.empty(recv_pr, args.hidden, dtype=torch.bfloat16, device=device) + if args.caller_provides_dispatch_recv_tokens + else None + ) + grad_expert_out = ( + torch.empty(recv_pr, args.hidden, dtype=torch.bfloat16, device=device) + if args.caller_provides_grad_expert_out + else None + ) + buffer = EpBuffer( top_k=args.top_k, max_tokens_per_rank=T, recv_capacity_per_rank=recv_pr, hidden_dim=args.hidden, num_local_experts=num_local_experts, - caller_provides_dispatch_recv_tokens=args.caller_provides_dispatch_recv_tokens, - caller_provides_grad_expert_out=args.caller_provides_grad_expert_out, - ) - - # Caller-supplied buffers (normal mode -> plain tensors), reused across iters. 
- dispatch_kw = ( - {"recv_tokens": torch.empty(recv_pr, args.hidden, dtype=torch.bfloat16, device=device)} - if args.caller_provides_dispatch_recv_tokens - else {} - ) - combine_kw = ( - {"grad_expert_out": torch.empty(recv_pr, args.hidden, dtype=torch.bfloat16, device=device)} - if args.caller_provides_grad_expert_out - else {} + dispatch_recv_tokens=recv_tokens, + combine_grad_expert_out=grad_expert_out, ) - recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, topk_w, **dispatch_kw) + recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, topk_w) expert_out = _batched_expert_linear(recv_t, kernels_local, num_local_experts) # Apply per-slot topk weighting before combine. expert_out = expert_out * recv_w_out.unsqueeze(-1).to(expert_out.dtype) - out = ep_combine(buffer, expert_out, **combine_kw) + out = ep_combine(buffer, expert_out) loss = 0.5 * (out.float() ** 2).sum() loss.backward() diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 090a331f0e..0acc00cd57 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -159,8 +159,8 @@ def _make_buffer( self, alignment=0, top_k=TOP_K, - caller_provides_dispatch_recv_tokens=False, - caller_provides_grad_expert_out=False, + dispatch_recv_tokens=None, + combine_grad_expert_out=None, ): return EpBuffer( top_k=top_k, @@ -169,8 +169,8 @@ def _make_buffer( hidden_dim=HIDDEN_DIM, num_local_experts=NUM_LOCAL_EXPERTS, alignment=alignment, - caller_provides_dispatch_recv_tokens=caller_provides_dispatch_recv_tokens, - caller_provides_grad_expert_out=caller_provides_grad_expert_out, + dispatch_recv_tokens=dispatch_recv_tokens, + combine_grad_expert_out=combine_grad_expert_out, ) def _expert_out(self, expert_out): @@ -244,21 +244,21 @@ def test_dispatch_autograd(self): EpBuffer-owned recv tokens (symm-mem under zero-copy) and, in normal mode, a caller-supplied recv_tokens buffer.""" if ZERO_COPY: - cases = [("buffer_owned", {})] + 
cases = [("buffer_owned", None)] else: rt_buf, _rw_buf, _ = self._make_raw_recv() cases = [ - ("default_alloc", {}), - ("caller_recv", {"recv_tokens": rt_buf}), + ("default_alloc", None), + ("caller_recv", rt_buf), ] - for label, recv_kw in cases: + for label, recv_tokens in cases: with self.subTest(case=label): - buf = self._make_buffer() + buf = self._make_buffer(dispatch_recv_tokens=recv_tokens) topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - rt, rw, _tc = ep_dispatch(buf, tokens_p, topk_idx, w, **recv_kw) - if recv_kw: # caller-supplied recv_tokens must be used in place - self.assertEqual(rt.data_ptr(), recv_kw["recv_tokens"].data_ptr()) + rt, rw, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) + if recv_tokens is not None: # caller-supplied recv_tokens must be used in place + self.assertEqual(rt.data_ptr(), recv_tokens.data_ptr()) rt = self._stage_grad_symm(rt) rw = self._stage_grad_symm(rw) (0.5 * (rt.float() ** 2).sum() + 0.0 * rw.float().sum()).backward() @@ -269,23 +269,20 @@ def test_dispatch_autograd(self): @_zero_copy_test_include def test_caller_provides_dispatch_recv_tokens(self): - """caller_provides_dispatch_recv_tokens: EpBuffer skips recv_tokens allocation - (recv_topk_weights stays owned) and ep_dispatch requires caller-supplied - recv_tokens (symm-mem under zero-copy).""" - buf = self._make_buffer(caller_provides_dispatch_recv_tokens=True) - self.assertIsNone(buf.recv_tokens_symm_buf) - if ZERO_COPY: # recv_topk_weights is always buffer-owned in zero-copy - self.assertIsNotNone(buf.recv_topk_weights_symm_buf) - topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) - tokens_p = tokens.detach().clone().requires_grad_(True) - with self.assertRaises(ValueError): - ep_dispatch(buf, tokens_p, topk_idx, w) + """Caller-supplied recv_tokens: EpBuffer adopts it (recv_topk_weights stays + owned) and ep_dispatch returns a view of the caller's 
buffer.""" if ZERO_COPY: rc = self.cfg.recv_capacity_per_rank rt_buf = symm_mem_alloc((rc, HIDDEN_DIM), torch.bfloat16, self.ep_group) else: rt_buf, _rw_buf, _ = self._make_raw_recv() - rt, rw, _ = ep_dispatch(buf, tokens_p, topk_idx, w, recv_tokens=rt_buf) + buf = self._make_buffer(dispatch_recv_tokens=rt_buf) + self.assertEqual(buf.recv_tokens_symm_buf.data_ptr(), rt_buf.data_ptr()) + if ZERO_COPY: # recv_topk_weights is always buffer-owned in zero-copy + self.assertIsNotNone(buf.recv_topk_weights_symm_buf) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + rt, rw, _ = ep_dispatch(buf, tokens_p, topk_idx, w) self.assertEqual(rt.data_ptr(), rt_buf.data_ptr()) rt = self._stage_grad_symm(rt) rw = self._stage_grad_symm(rw) @@ -297,24 +294,22 @@ def test_caller_provides_dispatch_recv_tokens(self): @_zero_copy_test_include def test_caller_provides_grad_expert_out(self): - """caller_provides_grad_expert_out: EpBuffer skips combine-grad allocation - and ep_combine requires a caller-supplied grad buffer (symm-mem under zero-copy).""" - buf = self._make_buffer(caller_provides_grad_expert_out=True) - self.assertIsNone(buf.grad_expert_out_symm_buf) + """Caller-supplied grad_expert_out: EpBuffer adopts it as the combine + backward grad target (symm-mem under zero-copy).""" + rc = self.cfg.recv_capacity_per_rank + if ZERO_COPY: + gbuf = symm_mem_alloc((rc, HIDDEN_DIM), torch.bfloat16, self.ep_group) + else: + gbuf = torch.empty(rc, HIDDEN_DIM, dtype=torch.bfloat16, device=self.cfg.device) + buf = self._make_buffer(combine_grad_expert_out=gbuf) + self.assertEqual(buf.grad_expert_out_symm_buf.data_ptr(), gbuf.data_ptr()) topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w) recv_t = self._stage_grad_symm(recv_t) recv_w = self._stage_grad_symm(recv_w) 
expert_out = self._expert_out(self._weighted(recv_t, recv_w)) - with self.assertRaises(ValueError): - ep_combine(buf, expert_out) - rc = self.cfg.recv_capacity_per_rank - if ZERO_COPY: - gbuf = symm_mem_alloc((rc, HIDDEN_DIM), torch.bfloat16, self.ep_group) - else: - gbuf = torch.empty(rc, HIDDEN_DIM, dtype=torch.bfloat16, device=self.cfg.device) - out = ep_combine(buf, expert_out, grad_expert_out=gbuf) + out = ep_combine(buf, expert_out) (0.5 * (out.float() ** 2).sum()).backward() torch.cuda.synchronize() torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 123976e953..27a2052019 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -186,22 +186,17 @@ class EpBuffer: "device", "token_counts", "zero_copy", - "caller_provides_dispatch_recv_tokens", - "caller_provides_grad_expert_out", "recv_tokens_symm_buf", "recv_topk_weights_symm_buf", "grad_expert_out_symm_buf", ) def _alloc_symm_buffers(self) -> None: - """Allocate the EpBuffer-owned symm-mem buffers; all None in normal mode. - recv_topk_weights is always owned; recv_tokens is skipped under - caller_provides_dispatch_recv_tokens and the combine grad target under - caller_provides_grad_expert_out (the caller supplies those).""" + """Fill in buffer-owned symm-mem buffers the caller did not supply. + recv_topk_weights is always owned. 
In normal mode caller-supplied + tensors are kept as-is and the rest stay None (allocated in-flight).""" if not self.zero_copy: - self.recv_tokens_symm_buf = None self.recv_topk_weights_symm_buf = None - self.grad_expert_out_symm_buf = None return if _EP_GROUP is None: raise RuntimeError( @@ -213,20 +208,16 @@ def _alloc_symm_buffers(self) -> None: (rc,), torch.float32, _EP_GROUP, device=self.device ) mark_not_offload(self.recv_topk_weights_symm_buf) - if self.caller_provides_grad_expert_out: - self.grad_expert_out_symm_buf = None - else: - self.grad_expert_out_symm_buf = symm_mem_alloc( - (rc, h), self.payload_dtype, _EP_GROUP, device=self.device - ) - mark_not_offload(self.grad_expert_out_symm_buf) - if self.caller_provides_dispatch_recv_tokens: - self.recv_tokens_symm_buf = None - else: + if self.recv_tokens_symm_buf is None: self.recv_tokens_symm_buf = symm_mem_alloc( (rc, h), self.payload_dtype, _EP_GROUP, device=self.device ) mark_not_offload(self.recv_tokens_symm_buf) + if self.grad_expert_out_symm_buf is None: + self.grad_expert_out_symm_buf = symm_mem_alloc( + (rc, h), self.payload_dtype, _EP_GROUP, device=self.device + ) + mark_not_offload(self.grad_expert_out_symm_buf) def __init__( self, @@ -238,15 +229,14 @@ def __init__( alignment: int = 0, payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, - caller_provides_dispatch_recv_tokens: bool = False, - caller_provides_grad_expert_out: bool = False, + dispatch_recv_tokens: Optional[torch.Tensor] = None, + combine_grad_expert_out: Optional[torch.Tensor] = None, ) -> None: - """``caller_provides_dispatch_recv_tokens`` declares that the caller passes - recv_tokens to ep_dispatch; ``caller_provides_grad_expert_out`` that the - caller passes the combine backward grad target to ep_combine (both symm-mem - under zero-copy). This buffer then does not allocate the declared tensor and - the corresponding op requires it to be supplied. 
recv_topk_weights is always - owned by the buffer.""" + """Pass ``dispatch_recv_tokens`` (dispatch recv output) and/or + ``combine_grad_expert_out`` (combine backward grad target) to use caller-owned + buffers; the buffer then skips allocating them. Both must be symm-mem-backed + under zero-copy. Whatever is left None is buffer-owned (zero-copy) or allocated + in-flight (normal mode). recv_topk_weights is always owned by the buffer.""" if device is None: device = torch.device("cuda", torch.cuda.current_device()) alignment = int(alignment) @@ -261,8 +251,8 @@ def __init__( self.payload_dtype = payload_dtype self.device = device self.zero_copy = bool(tex.ep_get_zero_copy()) - self.caller_provides_dispatch_recv_tokens = bool(caller_provides_dispatch_recv_tokens) - self.caller_provides_grad_expert_out = bool(caller_provides_grad_expert_out) + self.recv_tokens_symm_buf = dispatch_recv_tokens + self.grad_expert_out_symm_buf = combine_grad_expert_out size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment) self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) @@ -555,16 +545,13 @@ def ep_dispatch( tokens: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - *, - recv_tokens: Optional[torch.Tensor] = None, ): """Prepare + dispatch with autograd. topk_idx must be int64. - recv_tokens is used as-is if supplied, else taken from the EpBuffer-owned - symm-mem buffer (zero-copy) or allocated in-flight (normal mode) -- unless the - buffer was created with caller_provides_dispatch_recv_tokens, in which case it - must be supplied here. recv_topk_weights is always owned by the buffer. - Returns (recv_tokens, recv_topk_weights, token_counts); token_counts is non-diff. + recv_tokens comes from the EpBuffer (caller-supplied or buffer-owned under + zero-copy) or is allocated in-flight (normal mode). recv_topk_weights is always + owned by the buffer. Returns (recv_tokens, recv_topk_weights, token_counts); + token_counts is non-diff. 
""" _require_bf16("tokens", tokens) if topk_weights.dtype is not torch.float32: @@ -572,21 +559,13 @@ def ep_dispatch( f"topk_weights must be float32; got dtype={topk_weights.dtype}. " "Cast with topk_weights.float() before calling." ) - if recv_tokens is None and not buffer.caller_provides_dispatch_recv_tokens: - recv_tokens = ( - buffer.recv_tokens_symm_buf - if buffer.zero_copy - else torch.empty( - buffer.recv_capacity_per_rank, - buffer.hidden_dim, - dtype=buffer.payload_dtype, - device=buffer.device, - ) - ) + recv_tokens = buffer.recv_tokens_symm_buf if recv_tokens is None: - raise ValueError( - "ep_dispatch: buffer was created with caller_provides_dispatch_recv_tokens=True, " - "so recv_tokens must be supplied (symm-mem-backed under zero-copy)." + recv_tokens = torch.empty( + buffer.recv_capacity_per_rank, + buffer.hidden_dim, + dtype=buffer.payload_dtype, + device=buffer.device, ) recv_topk_weights = ( buffer.recv_topk_weights_symm_buf @@ -611,27 +590,18 @@ def ep_combine( expert_out: torch.Tensor, *, num_local_tokens: Optional[int] = None, - grad_expert_out: Optional[torch.Tensor] = None, ): """Combine with autograd; caller pre-applies topk weighting. - The backward scatters the expert_out grad into grad_expert_out if - supplied, else the EpBuffer-owned symm-mem buffer (zero-copy) or a tensor - allocated in-flight (normal mode). When the buffer was created with - caller_provides_grad_expert_out, grad_expert_out must be supplied. - Result shape is (num_local_tokens, buffer.hidden_dim); defaults to - buffer.max_tokens_per_rank rows. + The backward scatters the expert_out grad into the EpBuffer grad target + (caller-supplied or buffer-owned under zero-copy), or a tensor allocated + in-flight (normal mode). Result shape is (num_local_tokens, buffer.hidden_dim); + defaults to buffer.max_tokens_per_rank rows. 
""" _require_bf16("expert_out", expert_out) if num_local_tokens is None: num_local_tokens = buffer.max_tokens_per_rank - if grad_expert_out is None and not buffer.caller_provides_grad_expert_out: - grad_expert_out = buffer.grad_expert_out_symm_buf - if grad_expert_out is None and buffer.caller_provides_grad_expert_out: - raise ValueError( - "ep_combine: buffer was created with caller_provides_grad_expert_out=True, " - "so grad_expert_out must be supplied (symm-mem-backed under zero-copy)." - ) + grad_expert_out = buffer.grad_expert_out_symm_buf return _EpCombine.apply( buffer.handle_mem, num_local_tokens, From 9a885849092982b045711a2174f730c261ae5c2a Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 25 Jun 2026 15:54:00 +0200 Subject: [PATCH 37/39] Cleanup ep_moe.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/ep_moe.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index c93f8bb6e4..149185f251 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -209,18 +209,18 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, torch.cuda.synchronize() dist.barrier() for _ in range(args.benchmark_warmup): - rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w, **dispatch_kw) + rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) expert_out = _batched_expert_linear(rt, kernels_local, num_local_experts) expert_out = expert_out * rw.unsqueeze(-1).to(expert_out.dtype) - ep_combine(buffer, expert_out, **combine_kw) + ep_combine(buffer, expert_out) torch.cuda.synchronize() dist.barrier() t0 = time.perf_counter() for _ in range(args.benchmark_iters): - rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w, **dispatch_kw) + rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) 
expert_out = _batched_expert_linear(rt, kernels_local, num_local_experts) expert_out = expert_out * rw.unsqueeze(-1).to(expert_out.dtype) - ep_combine(buffer, expert_out, **combine_kw) + ep_combine(buffer, expert_out) torch.cuda.synchronize() dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters if rank == 0: From 008c3d20affe5781c07a0434ae8dd905ec504229 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 25 Jun 2026 06:53:19 -0700 Subject: [PATCH 38/39] EP PyTorch: drop redundant warnings reimport and add _EpCombine.backward docstring Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 27a2052019..b41115976a 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -48,8 +48,6 @@ def _check_nccl_runtime_version() -> None: lib = ctypes.CDLL("libnccl.so.2", mode=ctypes.RTLD_GLOBAL) v = ctypes.c_int(0) if lib.ncclGetVersion(ctypes.byref(v)) != 0: - import warnings - warnings.warn("ncclGetVersion failed; skipping NCCL EP version check.") return except OSError: # libnccl not findable; let the C++ side error @@ -511,6 +509,7 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_result): # type: ignore[override] + """Combine bwd; scatters the result grad into the grad target.""" if not g_result.is_contiguous(): g_result = g_result.contiguous() (handle_mem,) = ctx.saved_tensors From 7d1cbc47df757d02f5148fb319e1f1900517cec4 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 25 Jun 2026 18:03:55 +0200 Subject: [PATCH 39/39] Update transformer_engine/pytorch/distributed.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/distributed.py 
b/transformer_engine/pytorch/distributed.py index 949846f839..c050f26869 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1859,7 +1859,7 @@ def symm_mem_alloc( if symm_mem.get_backend(device) != "NCCL": symm_mem.set_backend("NCCL") t = symm_mem.empty(*shape, dtype=dtype, device=device) - symm_mem.rendezvous(t, group=ep_group.group_name) + symm_mem.rendezvous(t, group=ep_group) return t